In [1]:
import os
import pandas as pd
import numpy as np
import time
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split
from torch.utils.data.sampler import WeightedRandomSampler
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning import Trainer, seed_everything
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.feature_selection import mutual_info_classif, chi2
from sklearn.linear_model import LassoCV
import matplotlib.pyplot as plt
from pytorch_lightning.loggers import TensorBoardLogger
import seaborn as sns
import os
import sys
from sklearn.metrics import f1_score
from sklearn.metrics import recall_score
from sklearn.metrics import precision_score
from torch_explain.models.explainer import Explainer
from torch_explain.logic.metrics import formula_consistency
# from load_datasets import load_mimic
from imblearn.under_sampling import RandomUnderSampler
from imblearn.over_sampling import SMOTEN
from imblearn.combine import SMOTEENN
from torch.nn.functional import one_hot
from func_timeout import func_set_timeout, func_timeout, FunctionTimedOut
import datetime
import time

seed_everything(42)
base_dir = f'./runs'

Global seed set to 42


In [2]:
files = os.listdir("./categorisedData/")


datasets = {file : pd.read_csv("./categorisedData/" + file) for file in files}


print(files)

results_dict = {}

['breastCancer.csv', 'clusteredData.csv', 'clusteredDataSepsis.csv', 'expertLabelledData.csv', 'metricExtractedData.csv', 'staticData.csv']


In [3]:
@func_set_timeout(90)
def explain_with_timeout(model, val_data, train_data, test_data, topk_expl, concepts):

    return model.explain_class(val_dataloaders=val_data, train_dataloaders=train_data, test_dataloaders=test_data, topk_explanations=topk_expl, concept_names=concepts, max_minterm_complexity=5)

In [4]:
# Nodes in each hidden layer, learning rate

hiddenLayers = {
    'breastCancer.csv' : [[20], 0.01],
    'clusteredData.csv' : [[10], 0.01], 
    'clusteredDataSepsis.csv' : [[20, 40, 20], 0.0001],
    'expertLabelledData.csv' : [[20], 0.01],
    'metricExtractedData.csv' : [[20, 20], 0.01],
    'staticData.csv': [[20], 0.01]
}

In [5]:

np.set_printoptions(threshold=sys.maxsize)

# print(os.listdir("."))


for file in files[:1]:

    file = "clusteredDataSepsis.csv"

    if file in hiddenLayers:
        layers = hiddenLayers[file]
    else:
        print("Set layers for " + file)
        layers = [[20], 0.01]

    print(f"Training {file}\n")

    data = datasets[file]

    if "PatientID" in data.columns:
        data = data.drop(columns=["PatientID"])


    targetName = "Mortality14Days"

    targetSeries = data[targetName]
    print(data[targetName].value_counts())
    data = data.drop(columns=[targetName])
    
    dataTensor = torch.FloatTensor(data.to_numpy())
    targetTensor = one_hot(torch.tensor(targetSeries.values).to(torch.long)).to(torch.float)

    n_concepts = dataTensor.shape[1]
    print("There are " + str(n_concepts) + " concepts")
    n_classes = 2
    # print("feature names: ", concept_names)
    # print("features:", n_concepts)
    # print(n_classes)

    

    

    n_splits = 5
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
    
    results_list = []
    feature_selection = []

    splitResults_list = []


    x = dataTensor
    y = targetTensor

    for split, (trainval_index, test_index) in enumerate(skf.split(x.cpu().detach().numpy(),
                                                                y.argmax(dim=1).cpu().detach().numpy())):
        print(f'Split [{split + 1}/{n_splits}]')

        # print(x.shape)

        x_trainval, x_test = torch.FloatTensor(x[trainval_index]), torch.FloatTensor(x[test_index])
        y_trainval, y_test = torch.FloatTensor(y[trainval_index]), torch.FloatTensor(y[test_index])
        x_train, x_val, y_train, y_val = train_test_split(x_trainval, y_trainval, test_size=0.2, random_state=42)
        print(f'{len(y_train)}/{len(y_val)}/{len(y_test)}')

        print(pd.Series(np.argmax(y_train.numpy(), axis=1)).value_counts().values)

        # For oversampling... 
        clf = SMOTEN(random_state=0)

        x_train, y_train = clf.fit_resample(x_train.numpy(), np.argmax(y_train.numpy(), axis=1))

        x_train = torch.FloatTensor(x_train)
        y_train = one_hot(torch.tensor(y_train).to(torch.long)).to(torch.float)

        print(pd.Series(np.argmax(y_train.numpy(), axis=1)).value_counts().values)

        batch_size = 128

        train_data = TensorDataset(x_train, y_train)
        train_loader = DataLoader(train_data, batch_size = batch_size, shuffle=True)



        # For random sampling...
        # class_count = pd.Series(targetSeries).value_counts()
        # print(class_count)
        # weights = 1. / torch.FloatTensor(class_count.values)
        # print(weights)
        # train_weights = np.array([weights[t] for t in torch.argmax(y_train, axis=1).numpy()]).astype(np.float64)
        # sampler = WeightedRandomSampler(train_weights, train_size)
        # train_data = TensorDataset(x_train, y_train)
        # train_loader = DataLoader(train_data, batch_size=train_size, sampler=sampler)

        


        val_data = TensorDataset(x_val, y_val)
        test_data = TensorDataset(x_test, y_test)
        val_loader = DataLoader(val_data, batch_size = batch_size)
        test_loader = DataLoader(test_data, batch_size = batch_size)

        checkpoint_callback = ModelCheckpoint(dirpath=base_dir, monitor='val_loss', mode='min', save_top_k=1)
        early_stopping_callback = EarlyStopping(monitor='val_loss', patience=20, verbose=True, mode='min')

        logger = TensorBoardLogger("./runs/splits/", name=file)

        trainer = Trainer(max_epochs=200, gpus=1, auto_lr_find=True, deterministic=True,
                        check_val_every_n_epoch=1, default_root_dir=base_dir,
                        weights_save_path=base_dir, callbacks=[checkpoint_callback, early_stopping_callback],
                        logger=logger, enable_progress_bar=False, gradient_clip_val=0.5)

        model = Explainer(n_concepts=n_concepts, n_classes=n_classes, l1=1e-3, lr=layers[1],
                        explainer_hidden=layers[0], temperature=0.7)

        trainer.fit(model, train_loader, val_loader)
        # print(f"Gamma: {model.model[0].concept_mask}")
        model.freeze()

        # Precision, Recall, F1
        y_pred = torch.argmax(model(x_test), axis=1)
        y_test_argmax = torch.argmax(y_test, axis=1)

        scores = [f1_score(y_test_argmax.numpy(), y_pred.numpy(), average='macro'), 
                recall_score(y_test_argmax.numpy(), y_pred.numpy(), average='macro'), 
                precision_score(y_test_argmax.numpy(), y_pred.numpy(), average='macro')]

        print(f"Before loading best: {scores}")

        # scores_list.append(scores)
    
        model = model.load_from_checkpoint(checkpoint_callback.best_model_path)

        

        # Precision, Recall, F1

        scores = [f1_score(y_test_argmax.numpy(), y_pred.numpy(), average='macro'), 
                recall_score(y_test_argmax.numpy(), y_pred.numpy(), average='macro'), 
                precision_score(y_test_argmax.numpy(), y_pred.numpy(), average='macro')]

        print(f"{file} split {split+1} scores: {scores}")

        print("\nTesting...\n")
        # test_loader is giving a new batch of testing values, hence why the output here is different than above.
        model_results = trainer.test(model, dataloaders=test_loader)


        print("\nExplaining\n")

        start = time.time()

        try:

            results, f = explain_with_timeout(model, val_data=val_loader, train_data=train_loader, test_data=test_loader,
                                        topk_expl=3,
                                        concepts=data.columns)

        except FunctionTimedOut:
            print("Explanation timed out, skipping...")
            # explanations_list.append(None)
            # results_list.append(None)
            continue

        end = time.time()
        # explanations_list.append(f)

        print(f"Explaining time: {end - start}")
        results['model_accuracy'] = model_results[0]['test_acc_epoch']
        results['extraction_time'] = end - start

        results_list.append(results)
        extracted_concepts = []
        all_concepts = model.model[0].concept_mask[0] > 0.5
        common_concepts = model.model[0].concept_mask[0] > 0.5
        for j in range(n_classes):
            # print(f[j]['explanation'])
            n_used_concepts = sum(model.model[0].concept_mask[j] > 0.5)
            print(f"Number of features that impact on target {j}: {n_used_concepts}")
            print(f"Explanation for target {j}: {f[j]['explanation']}")
            print(f"Explanation accuracy: {f[j]['explanation_accuracy']}")
            extracted_concepts.append(n_used_concepts)
            all_concepts += model.model[0].concept_mask[j] > 0.5
            common_concepts *= model.model[0].concept_mask[j] > 0.5


        results['extracted_concepts'] = np.mean(extracted_concepts)
        results['common_concepts_ratio'] = sum(common_concepts) / sum(all_concepts)



        # prec_rec = precision_recall(y_pred, y_test_argmax, num_classes = n_classes)

        # print(prec_rec)

        # compare against standard feature selection
        i_mutual_info = mutual_info_classif(x_trainval, y_trainval[:, 1])
        i_chi2 = chi2(x_trainval, y_trainval[:, 1])[0]
        i_chi2[np.isnan(i_chi2)] = 0
        lasso = LassoCV(cv=5, random_state=0).fit(x_trainval, y_trainval[:, 1])
        i_lasso = np.abs(lasso.coef_)
        i_mu = model.model[0].concept_mask[1]
        # print(model.model[0].concept_mask)
        df = pd.DataFrame(np.hstack([
            i_mu.numpy(),
            # i_mutual_info / np.max(i_mutual_info),
            # i_chi2 / np.max(i_chi2),
            # i_lasso / np.max(i_lasso),
        ]).T, columns=['feature importance'])
        df['method'] = 'explainer'
        # df.iloc[90:, 1] = 'MI'
        # df.iloc[180:, 1] = 'CHI2'
        # df.iloc[270:, 1] = 'Lasso'
        df['feature'] = np.hstack([np.arange(0, n_concepts)])
        feature_selection.append(df)

        splitResults = [results['model_accuracy'], results['extraction_time'], *scores, f]

        splitResults_list.append(splitResults)
        break


    results_dict[file] = splitResults_list


Training clusteredDataSepsis.csv

0    31606
1     2422
Name: Mortality14Days, dtype: int64
There are 72 concepts
Split [1/5]
21777/5445/6806
[20225  1552]


  rank_zero_deprecation(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


[14768 13050]



  | Name  | Type             | Params
-------------------------------------------
0 | loss  | CrossEntropyLoss | 0     
1 | model | Sequential       | 4.6 K 
-------------------------------------------
4.6 K     Trainable params
0         Non-trainable params
4.6 K     Total params
0.018     Total estimated model params size (MB)
  rank_zero_warn(
  rank_zero_warn(
Metric val_loss improved. New best score: 0.669
Metric val_loss improved by 0.051 >= min_delta = 0.0. New best score: 0.618
Metric val_loss improved by 0.029 >= min_delta = 0.0. New best score: 0.589
Metric val_loss improved by 0.000 >= min_delta = 0.0. New best score: 0.589
Metric val_loss improved by 0.001 >= min_delta = 0.0. New best score: 0.588
Metric val_loss improved by 0.015 >= min_delta = 0.0. New best score: 0.573
Monitored metric val_loss did not improve in the last 10 records. Best score: 0.573. Signaling Trainer to stop.
  rank_zero_warn(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(


Before loading best: [0.5270557005378721, 0.6531393062661937, 0.5489363681944676]
clusteredDataSepsis.csv split 1 scores: [0.5270557005378721, 0.6531393062661937, 0.5489363681944676]

Testing...

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      f1_test_epoch         0.5330842137336731
     test_acc_epoch         0.7358213067054749
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────

Explaining

Explaining time: 5.772283554077148
Number of features that impact on target 0: 58
Explanation for target 0: ~O2Sat_StdDev_high & ~Potassium_Mean_low & ~Temp_StdDev_high
Explanation accuracy: 0.5589558231892459
Number of features that impact on target 1: 50
Explanation for target 1: Temp_St

In [6]:
y_pred = torch.argmax(model(x_train), axis=1)

y = torch.argmax(y_train, axis=1)

print("train f1:" , f1_score(y, y_pred, average='macro'))

train f1: 0.7089564472233105


In [7]:
y_pred = torch.argmax(model(x_test), axis=1)

y = torch.argmax(y_test, axis=1)

print("test f1:", f1_score(y, y_pred, average='macro'))

test f1: 0.5351116208836337


In [8]:
display(pd.Series(y.numpy()).value_counts().values)
display(pd.Series(y_pred.numpy()).value_counts().values)

array([6322,  484], dtype=int64)

array([4956, 1850], dtype=int64)

In [9]:
def removeNoneExplanations(explanations):

    toRemove = []
    for idx, expl in enumerate(explanations):
        if expl['explanation'] == None:
            toRemove.append(idx)
    for i in sorted(toRemove, reverse=True):
        # print(class0Explanations[i])
        del explanations[i]

    return explanations

In [10]:
kFoldMeans = []


for x in results_dict:

    cols = ['file', 'model_accuracy', 'extraction_time', 'f1', 'recall', 'precision']


    rows = []

    class0Explanations = []
    class1Explanations = []

    for split in results_dict[x]:
        row = [x]
        
        row.extend(split[:5])

        # print(row)
        rows.append(row)

        class0Explanations.append(split[5][0])
        class1Explanations.append(split[5][1])


    class0Explanations = removeNoneExplanations(class0Explanations)

    class1Explanations = removeNoneExplanations(class1Explanations)

    

    class0DF = pd.DataFrame(class0Explanations)
    class1DF = pd.DataFrame(class1Explanations)

    average0 = class0DF.mean().values
    average1 = class1DF.mean().values

    if len(class0Explanations) == 0:
        average0 = [0]*4

    if len(class1Explanations) == 0:
        average1 = [0]*4

    df = pd.DataFrame(columns=cols, data=rows)

    df = df.set_index('file')

    combinedCols = list(df.describe().columns)

    # print(combinedCols)

    row = [x]
    row.extend(np.round(df.describe().loc['mean'].values, 2))

    row.extend(list(average0)[1:])
    row.extend(list(average1)[1:])

    # print(row)

    kFoldMeans.append(row)

# print(kFoldMeans)



kFoldMeansCols = list(df.describe().columns)

combinedCols.insert(0, "file")


# print(kFoldMeansCols)
for idx, d in enumerate(results_dict[list(results_dict.keys())[0]][0][5]):
    combinedCols.extend([str(x) + "_" + str(idx) for x in list(d)[2:]])

# print(combinedCols)

totalMeans = pd.DataFrame(columns=combinedCols, data=kFoldMeans)

totalMeans = totalMeans.set_index('file')

cols = totalMeans.columns

cols = [c.replace("explanation", "expl").replace("accuracy", "acc").replace("complexity", "comp") for c in cols]

totalMeans.columns = cols

totalMeans = totalMeans.round(2)

totalMeans = totalMeans.drop("extraction_time", axis=1)

display(totalMeans)



timeNow = datetime.datetime.now().strftime("%Y-%m-%d_%H%M%S")
totalMeans.to_csv(f"./processingCache/totalMeans{timeNow}.csv")

  average0 = class0DF.mean().values
  average1 = class1DF.mean().values


Unnamed: 0_level_0,model_acc,f1,recall,precision,expl_acc_0,expl_fidelity_0,expl_comp_0,expl_acc_1,expl_fidelity_1,expl_comp_1
file,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
clusteredDataSepsis.csv,0.74,0.53,0.65,0.55,0.56,0.86,3.0,0.57,0.85,2.0


In [11]:
results_dict

{'clusteredDataSepsis.csv': [[0.7358213067054749,
   5.772283554077148,
   0.5270557005378721,
   0.6531393062661937,
   0.5489363681944676,
   [{'target_class': 0,
     'explanation': '~O2Sat_StdDev_high & ~Potassium_Mean_low & ~Temp_StdDev_high',
     'explanation_accuracy': 0.5589558231892459,
     'explanation_fidelity': 0.8565971201880693,
     'explanation_complexity': 3},
    {'target_class': 1,
     'explanation': 'Temp_StdDev_high | ICULOS_high',
     'explanation_accuracy': 0.5714128762909251,
     'explanation_fidelity': 0.8521892447840141,
     'explanation_complexity': 2}]]]}