In [1]:
# Libraries
import numpy as np
import pandas as pd
pd.set_option('display.max_columns', 40)
pd.set_option('display.width', 2000)
from tqdm import tqdm
import math
import time

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold

from sklearn.metrics import roc_auc_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import recall_score
from sklearn.metrics import precision_score
from sklearn.metrics import f1_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import average_precision_score

import gc

from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Statistical tests (ran after results generated)

In [2]:
set_results = [0.816038,
0.829856,
0.838451,
0.834127,
0.790073,
0.825867,
0.772416,
0.818897,
0.817904,
0.821161]
random_set_results = [0.801023,
0.809279,
0.852223,
0.828857,
0.747952,
0.773643,
0.829476,
0.776294,
0.799354,
0.824476]
lr_results = [0.789008,
0.790284,
0.819746,
0.791620,
0.760105,
0.814236,
0.776551,
0.779840,
0.777884,
0.788120]
charlson_results =[0.656873,
0.663736,
0.645588,
0.660310,
0.625806,
0.668568,
0.652461,
0.639701,
0.628784,
0.654787	
]
 

set_results_8 = [0.967391,
0.911765,
0.957143,
0.910256,
0.577381,
0.854430,
0.794118]
random_set_results_8 = [0.989130,
0.823529,
0.185714,
0.076923,
0.684524,
0.588608,
0.544118]
lr_results_8 = [0.967391,
0.963235,
0.985714,
0.974359,
0.708333,
0.568354,
0.441176]
charlson_results_8 = [0.654787,
0.492647,
0.492857,
0.500000,
0.500000,
0.493671,
0.500000]

In [3]:
from scipy.stats import shapiro 
from scipy.stats import kstest
from scipy.stats import ttest_rel
from scipy.stats import wilcoxon
from scipy.stats import normaltest

In [39]:
shapiro(set_results)
kstest(set_results, 'norm')

ShapiroResult(statistic=0.8530228137969971, pvalue=0.06309852004051208)

KstestResult(statistic=0.7800659597749264, pvalue=6.097117198531377e-07)

In [31]:
shapiro(random_set_results)
kstest(random_set_results, 'norm')

ShapiroResult(statistic=0.9691386222839355, pvalue=0.8827458620071411)

KstestResult(statistic=0.7727554447028735, pvalue=8.696277437320612e-07)

In [32]:
shapiro(lr_results)
kstest(lr_results, 'norm')

ShapiroResult(statistic=0.9283631443977356, pvalue=0.4319867789745331)

KstestResult(statistic=0.7764040879128568, pvalue=7.29211618720243e-07)

In [4]:
shapiro(charlson_results)
kstest(charlson_results, 'norm')
normaltest(charlson_results)

ShapiroResult(statistic=0.9371631145477295, pvalue=0.5219002366065979)

KstestResult(statistic=0.7342789020541408, pvalue=4.915848430562385e-06)



NormaltestResult(statistic=1.005837393893007, pvalue=0.6047629614707875)

In [37]:
shapiro(set_results_8)
kstest(set_results_8, 'norm')

ShapiroResult(statistic=0.8202360272407532, pvalue=0.06455416977405548)

KstestResult(statistic=0.7181589467957652, pvalue=0.0003549931229125905)

In [38]:
shapiro(random_set_results_8)
kstest(random_set_results_8, 'norm')

ShapiroResult(statistic=0.9490331411361694, pvalue=0.7208895087242126)

KstestResult(statistic=0.5306575997994031, pvalue=0.02292443342962655)

In [40]:
shapiro(lr_results_8)
kstest(lr_results_8, 'norm')

ShapiroResult(statistic=0.8019464015960693, pvalue=0.04281892254948616)

KstestResult(statistic=0.6704572067007551, pvalue=0.001245833330910244)

In [5]:
shapiro(charlson_results_8)
kstest(charlson_results_8, 'norm')
normaltest(charlson_results_8)

ShapiroResult(statistic=0.5029751062393188, pvalue=1.8005403035203926e-05)

KstestResult(statistic=0.6888689837361116, pvalue=0.0007837190773760501)

ValueError: skewtest is not valid with less than 8 samples; 7 samples were given.

In [6]:
ttest_rel(set_results, random_set_results) 
ttest_rel(set_results, lr_results) 
ttest_rel(lr_results, random_set_results) 
ttest_rel(random_set_results, lr_results) 

ttest_rel(charlson_results, set_results) 
ttest_rel(charlson_results, random_set_results) 
ttest_rel(charlson_results, lr_results) 

Ttest_relResult(statistic=1.2022451343608076, pvalue=0.2599446970742992)

Ttest_relResult(statistic=5.844200133068764, pvalue=0.00024552765660242667)

Ttest_relResult(statistic=-1.768451442569985, pvalue=0.11077672866130538)

Ttest_relResult(statistic=1.768451442569985, pvalue=0.11077672866130538)

Ttest_relResult(statistic=-25.91095431545953, pvalue=9.159242655591445e-10)

Ttest_relResult(statistic=-16.49170413084455, pvalue=4.938874103199237e-08)

Ttest_relResult(statistic=-30.09847575145, pvalue=2.412157527877996e-10)

In [7]:
wilcoxon(set_results, random_set_results) 
wilcoxon(set_results, lr_results) 
wilcoxon(lr_results, random_set_results) 

wilcoxon(charlson_results, set_results) 
wilcoxon(charlson_results, random_set_results) 
wilcoxon(charlson_results, lr_results) 

WilcoxonResult(statistic=14.0, pvalue=0.193359375)

WilcoxonResult(statistic=1.0, pvalue=0.00390625)

WilcoxonResult(statistic=13.0, pvalue=0.16015625)

WilcoxonResult(statistic=0.0, pvalue=0.001953125)

WilcoxonResult(statistic=0.0, pvalue=0.001953125)

WilcoxonResult(statistic=0.0, pvalue=0.001953125)

In [8]:
ttest_rel(set_results_8, random_set_results_8) 
ttest_rel(set_results_8, lr_results_8) 
ttest_rel(lr_results_8, random_set_results_8) 

ttest_rel(charlson_results_8, set_results_8) 
ttest_rel(charlson_results_8, random_set_results_8) 
ttest_rel(charlson_results_8, lr_results_8) 

Ttest_relResult(statistic=2.1208077690526217, pvalue=0.07819611078960588)

Ttest_relResult(statistic=0.7313937553608119, pvalue=0.49210321625783504)

Ttest_relResult(statistic=1.5458160651212287, pvalue=0.173106274064707)

Ttest_relResult(statistic=-6.893102535820686, pvalue=0.0004603031040224992)

Ttest_relResult(statistic=-0.3288271130612035, pvalue=0.7534604483729107)

Ttest_relResult(statistic=-3.4393097652361853, pvalue=0.01381262106563483)

In [10]:
wilcoxon(set_results_8, random_set_results_8) 
wilcoxon(set_results_8, lr_results_8) 
wilcoxon(lr_results_8, random_set_results_8) 

wilcoxon(charlson_results_8, set_results_8) 
wilcoxon(charlson_results_8, random_set_results_8) 
wilcoxon(charlson_results_8, lr_results_8) 

WilcoxonResult(statistic=4.0, pvalue=0.109375)

WilcoxonResult(statistic=10.0, pvalue=0.916511907863894)

WilcoxonResult(statistic=7.0, pvalue=0.296875)

WilcoxonResult(statistic=0.0, pvalue=0.015625)

WilcoxonResult(statistic=11.0, pvalue=0.6875)

WilcoxonResult(statistic=1.0, pvalue=0.03125)

# LR

In [3]:
# Import
path = r'data/slim_problem_dummies_death.csv'
slim_problem_dummies_death = pd.read_csv(path, index_col=0)
slim_problem_dummies_death.reset_index(inplace=True, drop=True)
# Work out count for problems
problem_sum = pd.DataFrame(slim_problem_dummies_death.iloc[:,2:].sum(axis=0), columns=['Count']).sort_values(by=['Count'], ascending=False)

# Define what is a rare disease
cut_off_list = [45, 8]

In [None]:
# Get rare 10 fold cv results for one hot

results_df = pd.DataFrame()
x_data = slim_problem_dummies_death.iloc[:,2:]
y_data = slim_problem_dummies_death.death_year_label
# Get CV folds
cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=2)
fold_n = 0
for train_index, test_index in cv.split(x_data, y_data):
    fold_n += 1
    x_train  = x_data[x_data.index.isin(list(train_index))]
    y_train  = y_data[y_data.index.isin(list(train_index))]
    x_test  = x_data[x_data.index.isin(list(test_index))]
    y_test  = y_data[y_data.index.isin(list(test_index))]

    # Fit
    LR = LogisticRegression(class_weight='balanced')
    LR.fit(x_train, y_train)

    # Get results for rare diseases 
    for n in cut_off_list:
        # Get filter list 
        filter_list = problem_sum[problem_sum['Count'] > n].index.tolist()
        x_test2 = x_test.copy()
        for code in filter_list:
            x_test2 = x_test2.loc[x_test2[code] != 1]
        x_test_list = x_test2.index.tolist()
        y_test2 = y_test.loc[x_test_list]

        # AUC
        try:
            aucroc = roc_auc_score(y_test2, LR.predict(x_test2))
        except:
            aucroc = np.nan
        # Accuracy
        accuracy = accuracy_score(y_test2, LR.predict(x_test2))
        # Precision
        precision = precision_score(y_test2, LR.predict(x_test2))
        # Recall
        recall = precision_score(y_test2, LR.predict(x_test2))
        # AUPRC
        try:
            auprc = average_precision_score(y_test2, LR.predict(x_test2))
        except:
            auprc = np.nan
        # F1
        f1 = f1_score(y_test2, LR.predict(x_test2))
        # TPR and FPR
        cm = confusion_matrix(y_test2, LR.predict(x_test2))
        try:
            _tp = cm[0, 0]
        except:
            _tp = np.nan
        try:
            _fn = cm[0, 1]
        except:
            _fn = np.nan
        try:
            _fp = cm[1, 0]
        except:
            _fp = np.nan
        try:
            _tn = cm[1, 1]
        except:
            _tn = np.nan
        tpr = _tp / (_tp + _fn)
        fpr = _fp / (_tn + _fp)

        new_row = {'data': 'dummies', 'fold': fold_n, 'subset': n, 'AUROC': aucroc, 'accuracy': accuracy, 'precision': precision, 'recall': recall, 'F1': f1, 'AUPRC': auprc, 'TPR': tpr, 'FPR': fpr}
        results_df = results_df.append(new_row, ignore_index=True)

In [8]:
results_df

Unnamed: 0,data,fold,subset,AUROC,accuracy,precision,recall,F1,AUPRC,TPR,FPR
0,dummies,1,45,0.868043,0.901996,0.086207,0.086207,0.15625,0.073654,0.902752,0.166667
1,dummies,1,8,,0.946237,0.0,0.0,0.0,-0.0,0.946237,
2,dummies,2,45,0.939479,0.880295,0.084507,0.084507,0.155844,0.084507,0.878957,0.0
3,dummies,2,8,,0.92,0.0,0.0,0.0,-0.0,0.92,
4,dummies,3,45,0.769104,0.867961,0.084507,0.084507,0.15,0.062163,0.871542,0.333333
5,dummies,3,8,0.967391,0.935484,0.142857,0.142857,0.25,0.142857,0.934783,0.0
6,dummies,4,45,0.843648,0.885542,0.126984,0.126984,0.219178,0.105603,0.887295,0.2
7,dummies,4,8,0.963235,0.929577,0.375,0.375,0.545455,0.375,0.926471,0.0
8,dummies,5,45,0.785374,0.8998,0.113208,0.113208,0.193548,0.081484,0.904082,0.333333
9,dummies,5,8,0.985714,0.971831,0.333333,0.333333,0.5,0.333333,0.971429,0.0


In [9]:
results_df.groupby(['data', 'subset']).mean()
results_df.groupby(['data', 'subset']).std()

Unnamed: 0_level_0,Unnamed: 1_level_0,fold,AUROC,accuracy,precision,recall,F1,AUPRC,TPR,FPR
data,subset,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
dummies,8,5.5,0.801223,0.91796,0.134286,0.134286,0.201061,0.123484,0.925349,0.328571
dummies,45,5.5,0.785632,0.879998,0.092882,0.092882,0.162536,0.070018,0.883976,0.312713


Unnamed: 0_level_0,Unnamed: 1_level_0,fold,AUROC,accuracy,precision,recall,F1,AUPRC,TPR,FPR
data,subset,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
dummies,8,3.02765,0.227436,0.036319,0.139531,0.139531,0.20815,0.138454,0.030417,0.434796
dummies,45,3.02765,0.079854,0.014437,0.01606,0.01606,0.025554,0.016671,0.013523,0.157587


In [None]:
# Get 10 fold cv results for one hot
# Import
path = r'data/slim_problem_dummies_death.csv'
slim_problem_dummies_death = pd.read_csv(path, index_col=0)
slim_problem_dummies_death.reset_index(inplace=True, drop=True)

results_df = pd.DataFrame()
x_data = slim_problem_dummies_death.iloc[:,2:]
y_data = slim_problem_dummies_death.death_year_label
# Get CV folds
cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=2)
fold_n = 0
for train_index, test_index in cv.split(x_data, y_data):
    fold_n += 1
    x_train  = x_data[x_data.index.isin(list(train_index))]
    y_train  = y_data[y_data.index.isin(list(train_index))]
    x_test  = x_data[x_data.index.isin(list(test_index))]
    y_test  = y_data[y_data.index.isin(list(test_index))]

    # Fit
    LR = LogisticRegression(class_weight='balanced')
    LR.fit(x_train, y_train)

    # AUC
    aucroc = roc_auc_score(y_test, LR.predict(x_test))
    # Accuracy
    accuracy = accuracy_score(y_test, LR.predict(x_test))
    # Precision
    precision = precision_score(y_test, LR.predict(x_test))
    # Recall
    recall = recall_score(y_test, LR.predict(x_test))
    # AUPRC
    auprc = average_precision_score(y_test, LR.predict(x_test))
    # F1
    f1 = f1_score(y_test, LR.predict(x_test))
    # TPR and FPR
    #fpr, tpr, thresholds = roc_curve(y_test, LR.predict(x_test))
    cm = confusion_matrix(y_test, LR.predict(x_test))
    _tp = cm[0, 0]
    _fn = cm[0, 1]
    _fp = cm[1, 0]
    _tn = cm[1, 1]
    tpr = _tp / (_tp + _fn)
    fpr = _fp / (_tn + _fp)

    new_row = {'data': 'dummies', 'fold': fold_n, 'subset':'overall', 'AUROC': aucroc, 'accuracy': accuracy, 'precision': precision, 'recall': recall, 'F1': f1, 'AUPRC': auprc, 'TPR': tpr, 'FPR': fpr}
    results_df = results_df.append(new_row, ignore_index=True)

In [5]:
results_df

Unnamed: 0,data,fold,subset,AUROC,accuracy,precision,recall,F1,AUPRC,TPR,FPR
0,dummies,1,overall,0.789008,0.849878,0.113621,0.724891,0.19645,0.089327,0.853125,0.275109
1,dummies,2,overall,0.790284,0.84822,0.112991,0.729258,0.195665,0.089253,0.85131,0.270742
2,dummies,3,overall,0.819746,0.851758,0.122283,0.786026,0.21164,0.101534,0.853465,0.213974
3,dummies,4,overall,0.79162,0.857948,0.119683,0.721739,0.205318,0.093455,0.861502,0.278261
4,dummies,5,overall,0.760105,0.84743,0.10501,0.668122,0.181495,0.078562,0.852087,0.331878
5,dummies,6,overall,0.814236,0.857601,0.124734,0.768559,0.214634,0.101725,0.859914,0.231441
6,dummies,7,overall,0.776551,0.846324,0.108564,0.703057,0.188084,0.083844,0.850045,0.296943
7,dummies,8,overall,0.77984,0.844444,0.108378,0.71179,0.188113,0.084439,0.84789,0.28821
8,dummies,9,overall,0.777884,0.853068,0.112676,0.69869,0.194057,0.086354,0.857078,0.30131
9,dummies,10,overall,0.78812,0.852294,0.114823,0.720524,0.198079,0.089808,0.855717,0.279476


In [6]:
results_df.groupby(['data', 'subset']).mean()
results_df.groupby(['data', 'subset']).std()

Unnamed: 0_level_0,Unnamed: 1_level_0,fold,AUROC,accuracy,precision,recall,F1,AUPRC,TPR,FPR
data,subset,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
dummies,overall,5.5,0.788739,0.850897,0.114276,0.723266,0.197354,0.08983,0.854213,0.276734


Unnamed: 0_level_0,Unnamed: 1_level_0,fold,AUROC,accuracy,precision,recall,F1,AUPRC,TPR,FPR
data,subset,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
dummies,overall,3.02765,0.017601,0.004532,0.006333,0.033704,0.010562,0.007425,0.004322,0.033704


# Set transformer

In [9]:
# Import
path = r'data/slim_problem_dummies_death.csv'
slim_problem_dummies_death = pd.read_csv(path, index_col=0)
# Work out count for problems
problem_sum = pd.DataFrame(slim_problem_dummies_death.iloc[:,2:].sum(axis=0), columns=['Count']).sort_values(by=['Count'], ascending=False)
# Define what is a rare disease
cut_off_list = [45, 8]
# Get indexes for rare diseases 
filter_list_45 = problem_sum[problem_sum['Count'] > 45].index.tolist()
filter_list_8 = problem_sum[problem_sum['Count'] > 8].index.tolist()

problem_dummies_45 = slim_problem_dummies_death.iloc[:,2:].reset_index().copy()
del slim_problem_dummies_death
print('Working on 45...')
for code in filter_list_45:
    problem_dummies_45 = problem_dummies_45.loc[problem_dummies_45[code] != 1]
    index_list_45 = problem_dummies_45.index.tolist()
del problem_dummies_45

path = r'data/slim_problem_dummies_death.csv'
slim_problem_dummies_death = pd.read_csv(path, index_col=0)
problem_dummies_8 = slim_problem_dummies_death.iloc[:,2:].reset_index().copy()
del slim_problem_dummies_death
print('Working on 8...')
for code in filter_list_8:
    problem_dummies_8 = problem_dummies_8.loc[problem_dummies_8[code] != 1]
    index_list_8 = problem_dummies_8.index.tolist()
del problem_dummies_8

Working on 45...
Working on 8...


In [10]:
# Set transformer initial setup

# Import
path = r'data/slim_problem_dummies_death.csv'
slim_problem_dummies_death = pd.read_csv(path, index_col=0)
# Import
path = r'data/final_trimmed_snomed_embedding_128d.csv'
snomed_embedding = pd.read_csv(path)

random_bool = True
#random_bool = False
if random_bool == True:
    # Get random embeddings for each disease
    snomed_embedding.set_index(['snomed_code'], inplace=True)
    np.random.seed(0)
    random_embedding = pd.DataFrame(np.random.default_rng(seed=0).uniform(low=snomed_embedding.min().min(),high=snomed_embedding.max().max(),size=[len(snomed_embedding),len(snomed_embedding.columns)]))
    random_embedding.index = snomed_embedding.index
    random_embedding.columns = snomed_embedding.columns
    random_embedding.reset_index(inplace=True)
    snomed_embedding = random_embedding

# Create dfs
patient_mortality = slim_problem_dummies_death.iloc[:,:2]
patient_df = slim_problem_dummies_death.iloc[:,2:]
# Str
patient_df.columns = patient_df.columns.astype(str)
snomed_embedding['snomed_code'] = snomed_embedding['snomed_code'].astype(str)
# Filter
snomed_embedding = snomed_embedding[snomed_embedding['snomed_code'].isin(patient_df.columns.tolist())]
snomed_embedding.set_index('snomed_code', inplace=True)

# Get lengths of each patients co-morbidities
comorbidity_len = np.array(patient_df.sum(axis=1))

# Add padding embedding 
padding_df = pd.DataFrame(np.random.choice([0], size=len(snomed_embedding.columns)))
padding_df = padding_df.T
padding_df.index = ['9999999999']
padding_df.columns = snomed_embedding.columns
snomed_embedding2 = pd.concat([snomed_embedding, padding_df])
snomed_embedding2.index = snomed_embedding2.index.astype(str)

# Get max number of co-morbidities
max_len = int(patient_df.sum(axis=1).max())

# Format patients embeddings into set and pad / create array
feature_array = np.zeros(shape=(len(patient_df), max_len , 128))
n = -1
for index, row in patient_df.iterrows():
    n += 1
    n2 = -1
    code_list = row[row ==1].index.tolist()
    while len(code_list) < max_len:
        code_list.append('9999999999')
    for code in code_list:
        n2 += 1
        feature_array[n, n2] = np.array(snomed_embedding2.loc[code])

# Get array for death
mortality_array = np.array(patient_mortality['death_year_label'])
mortality_array = mortality_array.squeeze()

# Create mask tensor based on lengths
comorbidity_len2 = torch.as_tensor(comorbidity_len, dtype=torch.long)
mask = torch.arange(max_len)[None, :] < comorbidity_len2[:, None]

# del so more memory
del slim_problem_dummies_death
del snomed_embedding
del patient_mortality
del snomed_embedding2
del padding_df
del row
del comorbidity_len
del comorbidity_len2
gc.collect()

57187

In [3]:
# Check if any have no co-morbidities
for n in range(mask.shape[0]):
    if torch.all(mask[n] == False):
        print(n)

In [11]:
# Custom dataset class
class DiseaseDataset(Dataset):
    def __init__(self, disease_embeddings, mortality_labels, padding_mask):
        self.disease_embeddings = disease_embeddings
        self.mortality_labels = mortality_labels
        self.padding_mask = padding_mask

    def __len__(self):
        return len(self.mortality_labels)

    def __getitem__(self, idx):
        return self.disease_embeddings[idx], self.mortality_labels[idx], self.padding_mask[idx] 

In [12]:
class SetTransformer(nn.Module):
    def __init__(self, dim_input, num_outputs, dim_output,
            num_inds=36, dim_hidden=160, num_heads=4, ln=False):
        super(SetTransformer, self).__init__()
        self.enc = nn.Sequential(
                ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=ln),
                ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln))
        self.isab = ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=ln)
        self.pma = PMA(dim_hidden, num_heads, num_outputs, ln=ln)
        self.dec = nn.Sequential(
                #SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
                #SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
                nn.Linear(dim_hidden, dim_output))

    def forward(self, X, batch_mask):
        x = self.isab(X, batch_mask)
        x = self.pma(x, batch_mask)
        return self.dec(x), x

class MAB0(nn.Module):
    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
        super(MAB0, self).__init__()
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)
        if ln:
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)
        self.fc_o = nn.Linear(dim_V, dim_V)

    def forward(self, Q, K, mask):
        Q = self.fc_q(Q)
        K, V = self.fc_k(K), self.fc_v(K)

        dim_split = self.dim_V // self.num_heads
        Q_ = torch.cat(Q.split(dim_split, 2), 0)
        K_ = torch.cat(K.split(dim_split, 2), 0)
        V_ = torch.cat(V.split(dim_split, 2), 0)

        # Create new variable for softmax
        WB_ = Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V)

        # Exspand mask dimensions to align
        mask = mask.unsqueeze(1).repeat(self.num_heads, Q.shape[1], 1)
        # Mask for softmax
        WB_[~mask] = float('-inf')

        A = torch.softmax(WB_, 2)
        O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
        O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
        O = O + F.relu(self.fc_o(O))
        O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
        return O

class MAB(nn.Module):
    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
        super(MAB, self).__init__()
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)
        if ln:
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)
        self.fc_o = nn.Linear(dim_V, dim_V)

    def forward(self, Q, K):
        Q = self.fc_q(Q)
        K, V = self.fc_k(K), self.fc_v(K)

        dim_split = self.dim_V // self.num_heads
        Q_ = torch.cat(Q.split(dim_split, 2), 0)
        K_ = torch.cat(K.split(dim_split, 2), 0)
        V_ = torch.cat(V.split(dim_split, 2), 0)

        A = torch.softmax(Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V), 2)
        O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
        O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
        O = O + F.relu(self.fc_o(O))
        O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
        return O

class SAB(nn.Module):
    def __init__(self, dim_in, dim_out, num_heads, ln=False):
        super(SAB, self).__init__()
        self.mab = MAB(dim_in, dim_in, dim_out, num_heads, ln=ln)

    def forward(self, X):
        return self.mab(X, X)

class ISAB(nn.Module):
    def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False):
        super(ISAB, self).__init__()
        self.I = nn.Parameter(torch.Tensor(1, num_inds, dim_out))
        nn.init.xavier_uniform_(self.I)
        self.mab0 = MAB0(dim_out, dim_in, dim_out, num_heads, ln=ln)
        self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln)

    def forward(self, X, mask):
        H = self.mab0(self.I.repeat(X.size(0), 1, 1), X, mask)
        return self.mab1(X, H)

class PMA(nn.Module):
    def __init__(self, dim, num_heads, num_seeds, ln=False):
        super(PMA, self).__init__()
        self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim))
        nn.init.xavier_uniform_(self.S)
        self.mab = MAB0(dim, dim, dim, num_heads, ln=ln)

    def forward(self, X, mask):
        return self.mab(self.S.repeat(X.size(0), 1, 1), X, mask)

In [13]:
# Define how long an epoch takes
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

# Train function 
def train(model, dataloader, optimizer, criterion):
    model.train()
    epoch_loss = 0
    batch_prediction_list = []
    batch_label_list = []
    
    for i, (batch_embeddings, batch_labels, batch_mask) in enumerate(tqdm(dataloader, 0)):
        batch_labels = batch_labels.to(device)
        batch_embeddings = batch_embeddings.to(device)
        batch_mask = batch_mask.to(device)

        optimizer.zero_grad()
        logits, pma = model(batch_embeddings, batch_mask)
        logits = logits.squeeze(-1) # squeeze to remove extra dimensions
        batch_labels = batch_labels.unsqueeze(1)
        
        loss = criterion(logits, batch_labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

        sig = torch.nn.Sigmoid()
        output = sig(logits)      
        np_predictions = output.cpu().detach().numpy()
        np_labels = batch_labels.cpu().detach().numpy()
        np_pma = pma.cpu().detach().numpy()

        np_predictions = np_predictions.squeeze()
        np_labels = np_labels.squeeze()
        np_pma = np_pma.squeeze()

        np_predictions = np_predictions.flatten()
        np_labels = np_labels.flatten()
        
        # Create list
        for x in np_predictions:
            batch_prediction_list.append(x)
        for x in np_labels:
            batch_label_list.append(x)
        if i == 0:
            np_pma_final = np_pma
        else:
            np_pma_final = np.vstack([np_pma_final, np_pma])

    final_predictions = np.array(batch_prediction_list)

    final_labels = np.array(batch_label_list)

    try:
        auroc = roc_auc_score(final_labels, final_predictions)
    except:
        auroc = np.nan
    
    try:
        final_loss = epoch_loss / len(dataloader)
    except:
        final_loss = np.nan

    return final_loss, auroc, final_predictions, final_labels, np_pma_final

# Eval function
def evaluate(model, dataloader, criterion):

    # Set the model to evaluation mode
    model.eval()
    epoch_loss = 0
    batch_prediction_list = []
    batch_label_list = []

    # use the with torch.no_grad() block to ensure no gradients are calculated within the bloc
    with torch.no_grad():
        for i, (batch_embeddings, batch_labels, batch_mask) in enumerate(tqdm(dataloader, 0)):
            batch_labels = batch_labels.to(device)
            batch_embeddings = batch_embeddings.to(device)
            batch_mask = batch_mask.to(device)

            logits, pma = model(batch_embeddings, batch_mask)
            logits = logits.squeeze() # squeeze to remove extra dimensions

            if len(logits.size()) == 0: # Need to add this so tensor with just one element does not cause error
                logits = logits.unsqueeze(dim=0)

            loss = criterion(logits, batch_labels)

            epoch_loss += loss.item()

            sig = torch.nn.Sigmoid()
            output = sig(logits)      
            np_predictions = output.cpu().detach().numpy()
            np_labels = batch_labels.cpu().detach().numpy()
            np_pma = pma.cpu().detach().numpy()

            np_predictions = np_predictions.squeeze()
            np_labels = np_labels.squeeze()

            np_predictions = np_predictions.flatten()
            np_labels = np_labels.flatten()
            np_pma = np_pma.squeeze()
            
            # Create list
            for x in np_predictions:
                batch_prediction_list.append(x)
            for x in np_labels:
                batch_label_list.append(x)
            if i == 0:
                np_pma_final = np_pma
            else:
                np_pma_final = np.vstack([np_pma_final, np_pma])

        final_predictions = np.array(batch_prediction_list)

        final_labels = np.array(batch_label_list)

        try:
            auroc = roc_auc_score(final_labels, final_predictions)
        except:
            auroc = np.nan
        
        try:
            final_loss = epoch_loss / len(dataloader)
        except:
            final_loss = np.nan

    return final_loss, auroc, final_predictions, final_labels, np_pma_final

In [6]:
# CV for set transformer

# Get CV folds and create dictionarys 
cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=2)
fold_n = 0
train_index_dict = {}
test_index_dict = {}
for train_index, test_index in cv.split(feature_array, mortality_array):
    fold_n += 1
    train_index_dict[fold_n] = train_index
    test_index_dict[fold_n] = test_index

In [7]:
# del for memmory 
del test_index
del train_index
del patient_df
del index
del fold_n
del code
del code_list
gc.collect()

190

In [8]:
# Main run
best_test_auroc = 0
results_df = pd.DataFrame()
for fold in range(10):
    fold += 1
    if fold < 10:
        # Get train, validation and test sets
        train_embeddings = feature_array[np.setdiff1d(train_index_dict[fold], test_index_dict[fold+1])]
        val_embeddings = feature_array[test_index_dict[fold+1]]
        test_embeddings = feature_array[test_index_dict[fold]]

        train_labels = mortality_array[np.setdiff1d(train_index_dict[fold], test_index_dict[fold+1])]
        val_labels = mortality_array[test_index_dict[fold+1]]
        test_labels = mortality_array[test_index_dict[fold]]

        # Split masks
        train_padding_mask = mask[np.setdiff1d(train_index_dict[fold], test_index_dict[fold+1])]
        val_padding_mask = mask[test_index_dict[fold+1]]
        test_padding_mask = mask[test_index_dict[fold]]
    else: # Change here to fist cv split for final fold
        # Get train, validation and test sets
        train_embeddings = feature_array[np.setdiff1d(train_index_dict[fold], test_index_dict[1])]
        val_embeddings = feature_array[test_index_dict[1]]
        test_embeddings = feature_array[test_index_dict[fold]]

        train_labels = mortality_array[np.setdiff1d(train_index_dict[fold], test_index_dict[1])]
        val_labels = mortality_array[test_index_dict[1]]
        test_labels = mortality_array[test_index_dict[fold]]

        # Split masks
        train_padding_mask = mask[np.setdiff1d(train_index_dict[fold], test_index_dict[1])]
        val_padding_mask = mask[test_index_dict[1]]
        test_padding_mask = mask[test_index_dict[fold]]
    
    # Create datasets
    train_dataset = DiseaseDataset(train_embeddings.astype(np.float32), train_labels.astype(np.float32), train_padding_mask)
    val_dataset = DiseaseDataset(val_embeddings.astype(np.float32), val_labels.astype(np.float32), val_padding_mask)
    test_dataset = DiseaseDataset(test_embeddings.astype(np.float32), test_labels.astype(np.float32), test_padding_mask)

    # Define batch size
    batch_size = 512

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    # Create weight for loss
    unique, counts = np.unique(train_labels, return_counts=True)
    pos_weight = torch.Tensor([(counts[0] / counts[1])])

    # Define model
    model = SetTransformer(dim_input=128, num_outputs=1, dim_output=1, num_inds=32, dim_hidden=160, num_heads=4, ln=False)

    model = model.to(device)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    # Run
    best_valid_loss = float('inf')
    best_valid_auroc = 0
    num_epochs = 10

    for epoch in range(num_epochs):

        start_time = time.time()

        train_loss, train_auroc, train_predictions, train_labels_out, train_pma_final = train(model, train_loader, optimizer, criterion)
        valid_loss, valid_auroc, valid_predictions, valid_labels_out, valid_pma_final = evaluate(model, val_loader, criterion)
        
        if epoch % 1 == 0:
            print('Epoch:', epoch)
            print(f'train loss: {train_loss:.3f}')
            print(f'train AUROC: {train_auroc:.3f}')
            print(f'valid loss: {valid_loss:.3f}')
            print(f'valid AUROC: {valid_auroc:.3f}')
        
        end_time = time.time()

        epoch_mins, epoch_secs = epoch_time(start_time, end_time)

        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss  
            print('BEST VALID LOSS:', best_valid_loss)

        if valid_auroc > best_valid_auroc:
            best_valid_auroc = valid_auroc
            print('BEST VALID AUROC:', best_valid_auroc)
            print('UPDATED BEST INTERMEDIATE MODEL')
            torch.save(model.state_dict(), f'intermediate_set_transformer_mortality.pt')

    # -----------------------------
    # Evaluate best model on test set
    # -----------------------------

    model.load_state_dict(torch.load(f'intermediate_set_transformer_mortality.pt'))

    test_loss, test_auroc, test_predictions, test_labels_out, test_pma_final = evaluate(model, test_loader, criterion)

    # del for memmory 
    del train_embeddings
    del val_embeddings
    del test_embeddings
    del train_labels
    del val_labels
    del test_labels
    del train_padding_mask
    del val_padding_mask
    del test_padding_mask
    del train_loader
    del val_loader
    del test_loader
    del train_predictions
    del train_labels_out
    del train_pma_final
    del valid_predictions
    del valid_labels_out
    del valid_pma_final
    gc.collect()

    print(f'test predictions: {test_predictions.mean().item():.3f}') 
    print(f'test loss: {test_loss:.3f}')
    print(f'test AUROC: {test_auroc:.3f}')

    if test_auroc > best_test_auroc:
        best_test_auroc = test_auroc
        print('BEST TEST AUROC:', best_test_auroc)
        print('UPDATED BEST MODEL')
        torch.save(model.state_dict(), f'set_transformer_mortality.pt')

    # Get results
    # AUC
    aucroc = roc_auc_score(test_labels_out, test_predictions)
    # Accuracy
    accuracy = accuracy_score(test_labels_out, test_predictions.round())
    # Precision
    precision = precision_score(test_labels_out, test_predictions.round())
    # Recall
    recall = recall_score(test_labels_out, test_predictions.round())
    # AUPRC
    auprc = average_precision_score(test_labels_out, test_predictions)
    # F1
    f1 = f1_score(test_labels_out, test_predictions.round())
    # TPR and FPR
    cm = confusion_matrix(test_labels_out, test_predictions.round())
    _tp = cm[0, 0]
    _fn = cm[0, 1]
    _fp = cm[1, 0]
    _tn = cm[1, 1]
    tpr = _tp / (_tp + _fn)
    fpr = _fp / (_tn + _fp)

    new_row = {'data': 'set', 'fold': fold, 'subset':'overall', 'AUROC': aucroc, 'accuracy': accuracy, 'precision': precision, 'recall': recall, 'F1': f1, 'AUPRC': auprc, 'TPR': tpr, 'FPR': fpr}
    results_df = results_df.append(new_row, ignore_index=True)


100%|██████████| 142/142 [00:58<00:00,  2.43it/s]
100%|██████████| 18/18 [00:02<00:00,  6.05it/s]
100%|██████████| 142/142 [00:55<00:00,  2.55it/s]
100%|██████████| 18/18 [00:03<00:00,  5.82it/s]
100%|██████████| 142/142 [00:55<00:00,  2.58it/s]
100%|██████████| 18/18 [00:02<00:00,  6.01it/s]
100%|██████████| 142/142 [00:54<00:00,  2.62it/s]
100%|██████████| 18/18 [00:02<00:00,  6.08it/s]
100%|██████████| 142/142 [00:54<00:00,  2.60it/s]
100%|██████████| 18/18 [00:03<00:00,  5.97it/s]
100%|██████████| 142/142 [00:54<00:00,  2.60it/s]
100%|██████████| 18/18 [00:03<00:00,  5.88it/s]
100%|██████████| 142/142 [00:54<00:00,  2.63it/s]
100%|██████████| 18/18 [00:02<00:00,  6.02it/s]
100%|██████████| 142/142 [00:53<00:00,  2.65it/s]
100%|██████████| 18/18 [00:02<00:00,  6.05it/s]
100%|██████████| 142/142 [00:52<00:00,  2.70it/s]
100%|██████████| 18/18 [00:02<00:00,  6.06it/s]
100%|██████████| 142/142 [00:53<00:00,  2.65it/s]
100%|██████████| 18/18 [00:02<00:00,  6.07it/s]


Epoch: 0
train loss: 1.142
train AUROC: 0.771
valid loss: 1.247
valid AUROC: 0.781
BEST VALID LOSS: 1.2471927205721538
BEST VALID AUROC: 0.7809249004379689
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 1.035
train AUROC: 0.813
valid loss: 1.003
valid AUROC: 0.840
BEST VALID LOSS: 1.003496958149804
BEST VALID AUROC: 0.8396168477628321
UPDATED BEST INTERMEDIATE MODEL
Epoch: 2
train loss: 0.962
train AUROC: 0.838
valid loss: 1.251
valid AUROC: 0.801
Epoch: 3
train loss: 1.148
train AUROC: 0.748
valid loss: 1.020
valid AUROC: 0.826
Epoch: 4
train loss: 0.977
train AUROC: 0.827
valid loss: 0.948
valid AUROC: 0.843
BEST VALID LOSS: 0.9478144513236152
BEST VALID AUROC: 0.8434519856192856
UPDATED BEST INTERMEDIATE MODEL
Epoch: 5
train loss: 0.944
train AUROC: 0.841
valid loss: 1.130
valid AUROC: 0.812
Epoch: 6
train loss: 1.002
train AUROC: 0.823
valid loss: 0.957
valid AUROC: 0.847
BEST VALID AUROC: 0.8472447777294061
UPDATED BEST INTERMEDIATE MODEL
Epoch: 7
train loss: 0.918
train AUR

<All keys matched successfully>

100%|██████████| 18/18 [00:02<00:00,  6.14it/s]


0

test predictions: 0.254
test loss: 1.114
test AUROC: 0.816
BEST TEST AUROC: 0.8160384390416886
UPDATED BEST MODEL
Epoch: 0
train loss: 1.153
train AUROC: 0.762
valid loss: 1.107
valid AUROC: 0.818
BEST VALID LOSS: 1.1071911685996585
BEST VALID AUROC: 0.81758764950401
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 1.039
train AUROC: 0.807
valid loss: 1.169
valid AUROC: 0.825
BEST VALID AUROC: 0.825494169906983
UPDATED BEST INTERMEDIATE MODEL
Epoch: 2
train loss: 1.038
train AUROC: 0.814
valid loss: 0.984
valid AUROC: 0.853
BEST VALID LOSS: 0.9835855745606952
BEST VALID AUROC: 0.853033763179804
UPDATED BEST INTERMEDIATE MODEL
Epoch: 3
train loss: 0.994
train AUROC: 0.826
valid loss: 0.916
valid AUROC: 0.859
BEST VALID LOSS: 0.9163484970728556
BEST VALID AUROC: 0.8593772055076215
UPDATED BEST INTERMEDIATE MODEL
Epoch: 4
train loss: 0.962
train AUROC: 0.839
valid loss: 1.040
valid AUROC: 0.847
Epoch: 5
train loss: 1.059
train AUROC: 0.798
valid loss: 0.973
valid AUROC: 0.845
Epoch: 6

  results_df = results_df.append(new_row, ignore_index=True)
100%|██████████| 142/142 [00:58<00:00,  2.43it/s]
100%|██████████| 18/18 [00:02<00:00,  6.17it/s]
100%|██████████| 142/142 [00:55<00:00,  2.57it/s]
100%|██████████| 18/18 [00:02<00:00,  6.04it/s]
100%|██████████| 142/142 [00:54<00:00,  2.61it/s]
100%|██████████| 18/18 [00:03<00:00,  5.90it/s]
100%|██████████| 142/142 [00:53<00:00,  2.64it/s]
100%|██████████| 18/18 [00:02<00:00,  6.18it/s]
100%|██████████| 142/142 [00:55<00:00,  2.54it/s]
100%|██████████| 18/18 [00:03<00:00,  5.97it/s]
100%|██████████| 142/142 [00:54<00:00,  2.62it/s]
100%|██████████| 18/18 [00:03<00:00,  5.86it/s]
100%|██████████| 142/142 [00:55<00:00,  2.56it/s]
100%|██████████| 18/18 [00:02<00:00,  6.10it/s]
100%|██████████| 142/142 [00:54<00:00,  2.61it/s]
100%|██████████| 18/18 [00:03<00:00,  5.96it/s]
100%|██████████| 142/142 [00:53<00:00,  2.66it/s]
100%|██████████| 18/18 [00:02<00:00,  6.10it/s]
100%|██████████| 142/142 [00:54<00:00,  2.60it/s]
100%|██

<All keys matched successfully>

100%|██████████| 18/18 [00:03<00:00,  5.95it/s]


0

test predictions: 0.222
test loss: 1.218
test AUROC: 0.830
BEST TEST AUROC: 0.8298562770511315
UPDATED BEST MODEL
Epoch: 0
train loss: 1.139
train AUROC: 0.772
valid loss: 1.018
valid AUROC: 0.821
BEST VALID LOSS: 1.0177795324060652
BEST VALID AUROC: 0.8212299277992583
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 0.983
train AUROC: 0.832
valid loss: 0.997
valid AUROC: 0.835
BEST VALID LOSS: 0.9966390596495734
BEST VALID AUROC: 0.834582626449933
UPDATED BEST INTERMEDIATE MODEL
Epoch: 2
train loss: 0.960
train AUROC: 0.839
valid loss: 1.098
valid AUROC: 0.835
BEST VALID AUROC: 0.8354678746153239
UPDATED BEST INTERMEDIATE MODEL
Epoch: 3
train loss: 0.968
train AUROC: 0.837
valid loss: 0.910
valid AUROC: 0.857
BEST VALID LOSS: 0.9099229309293959
BEST VALID AUROC: 0.8568082734948316
UPDATED BEST INTERMEDIATE MODEL
Epoch: 4
train loss: 0.927
train AUROC: 0.850
valid loss: 1.132
valid AUROC: 0.835
Epoch: 5
train loss: 0.977
train AUROC: 0.836
valid loss: 0.918
valid AUROC: 0.856
Epoch

  results_df = results_df.append(new_row, ignore_index=True)
100%|██████████| 142/142 [00:54<00:00,  2.61it/s]
100%|██████████| 18/18 [00:03<00:00,  5.94it/s]
100%|██████████| 142/142 [00:56<00:00,  2.53it/s]
100%|██████████| 18/18 [00:03<00:00,  6.00it/s]
100%|██████████| 142/142 [00:55<00:00,  2.57it/s]
100%|██████████| 18/18 [00:03<00:00,  5.88it/s]
100%|██████████| 142/142 [00:54<00:00,  2.60it/s]
100%|██████████| 18/18 [00:03<00:00,  5.95it/s]
100%|██████████| 142/142 [00:55<00:00,  2.56it/s]
100%|██████████| 18/18 [00:03<00:00,  5.95it/s]
100%|██████████| 142/142 [00:54<00:00,  2.61it/s]
100%|██████████| 18/18 [00:03<00:00,  5.78it/s]
100%|██████████| 142/142 [00:58<00:00,  2.42it/s]
100%|██████████| 18/18 [00:02<00:00,  6.10it/s]
100%|██████████| 142/142 [00:54<00:00,  2.60it/s]
100%|██████████| 18/18 [00:03<00:00,  5.94it/s]
100%|██████████| 142/142 [00:54<00:00,  2.62it/s]
100%|██████████| 18/18 [00:03<00:00,  5.99it/s]
100%|██████████| 142/142 [00:54<00:00,  2.60it/s]
100%|██

<All keys matched successfully>

100%|██████████| 18/18 [00:02<00:00,  6.05it/s]


0

test predictions: 0.279
test loss: 1.020
test AUROC: 0.838
BEST TEST AUROC: 0.8384507301050521
UPDATED BEST MODEL
Epoch: 0
train loss: 1.155
train AUROC: 0.765
valid loss: 1.040
valid AUROC: 0.806
BEST VALID LOSS: 1.0399916966756184
BEST VALID AUROC: 0.8055938884441943
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 0.965
train AUROC: 0.837
valid loss: 1.028
valid AUROC: 0.815
BEST VALID LOSS: 1.0279488166173298
BEST VALID AUROC: 0.8146113358799799
UPDATED BEST INTERMEDIATE MODEL
Epoch: 2
train loss: 0.972
train AUROC: 0.831
valid loss: 1.299
valid AUROC: 0.791
Epoch: 3
train loss: 0.983
train AUROC: 0.831
valid loss: 1.002
valid AUROC: 0.822
BEST VALID LOSS: 1.0019027921888564
BEST VALID AUROC: 0.8220541353949548
UPDATED BEST INTERMEDIATE MODEL
Epoch: 4
train loss: 0.930
train AUROC: 0.849
valid loss: 1.176
valid AUROC: 0.804
Epoch: 5
train loss: 0.948
train AUROC: 0.846
valid loss: 0.987
valid AUROC: 0.829
BEST VALID LOSS: 0.9868938028812408
BEST VALID AUROC: 0.8286949987715864


  results_df = results_df.append(new_row, ignore_index=True)
100%|██████████| 142/142 [00:58<00:00,  2.42it/s]
100%|██████████| 18/18 [00:02<00:00,  6.13it/s]
100%|██████████| 142/142 [00:53<00:00,  2.66it/s]
100%|██████████| 18/18 [00:02<00:00,  6.18it/s]
100%|██████████| 142/142 [00:54<00:00,  2.58it/s]
100%|██████████| 18/18 [00:02<00:00,  6.01it/s]
100%|██████████| 142/142 [00:53<00:00,  2.64it/s]
100%|██████████| 18/18 [00:02<00:00,  6.04it/s]
100%|██████████| 142/142 [00:54<00:00,  2.62it/s]
100%|██████████| 18/18 [00:02<00:00,  6.12it/s]
100%|██████████| 142/142 [00:54<00:00,  2.59it/s]
100%|██████████| 18/18 [00:03<00:00,  5.91it/s]
100%|██████████| 142/142 [00:54<00:00,  2.59it/s]
100%|██████████| 18/18 [00:02<00:00,  6.00it/s]
100%|██████████| 142/142 [00:54<00:00,  2.62it/s]
100%|██████████| 18/18 [00:02<00:00,  6.14it/s]
100%|██████████| 142/142 [00:54<00:00,  2.61it/s]
100%|██████████| 18/18 [00:02<00:00,  6.03it/s]
100%|██████████| 142/142 [00:55<00:00,  2.58it/s]
100%|██

<All keys matched successfully>

100%|██████████| 18/18 [00:02<00:00,  6.13it/s]


0

test predictions: 0.138
test loss: 1.417
test AUROC: 0.834
Epoch: 0
train loss: 1.137
train AUROC: 0.781
valid loss: 1.299
valid AUROC: 0.814
BEST VALID LOSS: 1.2992880940437317
BEST VALID AUROC: 0.8135728805902724
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 1.085
train AUROC: 0.788
valid loss: 1.064
valid AUROC: 0.837
BEST VALID LOSS: 1.064277172088623
BEST VALID AUROC: 0.8371093842874013
UPDATED BEST INTERMEDIATE MODEL
Epoch: 2
train loss: 0.957
train AUROC: 0.841
valid loss: 0.936
valid AUROC: 0.855
BEST VALID LOSS: 0.9362406763765547
BEST VALID AUROC: 0.8547128484137614
UPDATED BEST INTERMEDIATE MODEL
Epoch: 3
train loss: 0.932
train AUROC: 0.849
valid loss: 1.222
valid AUROC: 0.819
Epoch: 4
train loss: 0.986
train AUROC: 0.835
valid loss: 0.938
valid AUROC: 0.857
BEST VALID AUROC: 0.8572115308411067
UPDATED BEST INTERMEDIATE MODEL
Epoch: 5
train loss: 0.922
train AUROC: 0.853
valid loss: 0.951
valid AUROC: 0.854
Epoch: 6
train loss: 0.935
train AUROC: 0.847
valid loss: 1.

  results_df = results_df.append(new_row, ignore_index=True)
100%|██████████| 142/142 [01:01<00:00,  2.30it/s]
100%|██████████| 18/18 [00:02<00:00,  6.13it/s]
100%|██████████| 142/142 [00:54<00:00,  2.59it/s]
100%|██████████| 18/18 [00:02<00:00,  6.04it/s]
100%|██████████| 142/142 [00:53<00:00,  2.67it/s]
100%|██████████| 18/18 [00:02<00:00,  6.09it/s]
100%|██████████| 142/142 [00:53<00:00,  2.65it/s]
100%|██████████| 18/18 [00:02<00:00,  6.07it/s]
100%|██████████| 142/142 [00:54<00:00,  2.61it/s]
100%|██████████| 18/18 [00:02<00:00,  6.03it/s]
100%|██████████| 142/142 [00:53<00:00,  2.66it/s]
100%|██████████| 18/18 [00:02<00:00,  6.03it/s]
100%|██████████| 142/142 [00:54<00:00,  2.63it/s]
100%|██████████| 18/18 [00:02<00:00,  6.23it/s]
100%|██████████| 142/142 [00:52<00:00,  2.68it/s]
100%|██████████| 18/18 [00:02<00:00,  6.01it/s]
100%|██████████| 142/142 [01:11<00:00,  1.98it/s]
100%|██████████| 18/18 [00:04<00:00,  4.43it/s]
100%|██████████| 142/142 [00:55<00:00,  2.58it/s]
100%|██

<All keys matched successfully>

100%|██████████| 18/18 [00:02<00:00,  6.30it/s]


0

test predictions: 0.499
test loss: 1.173
test AUROC: 0.790
Epoch: 0
train loss: 1.154
train AUROC: 0.769
valid loss: 1.142
valid AUROC: 0.805
BEST VALID LOSS: 1.1415328118536208
BEST VALID AUROC: 0.8045011947312943
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 1.024
train AUROC: 0.812
valid loss: 0.987
valid AUROC: 0.831
BEST VALID LOSS: 0.9870838423569998
BEST VALID AUROC: 0.8309086198971303
UPDATED BEST INTERMEDIATE MODEL
Epoch: 2
train loss: 0.998
train AUROC: 0.822
valid loss: 1.283
valid AUROC: 0.792
Epoch: 3
train loss: 1.070
train AUROC: 0.810
valid loss: 0.973
valid AUROC: 0.838
BEST VALID LOSS: 0.972743401924769
BEST VALID AUROC: 0.8380245028887533
UPDATED BEST INTERMEDIATE MODEL
Epoch: 4
train loss: 0.949
train AUROC: 0.842
valid loss: 1.007
valid AUROC: 0.838
Epoch: 5
train loss: 0.939
train AUROC: 0.845
valid loss: 0.937
valid AUROC: 0.851
BEST VALID LOSS: 0.9374432298872206
BEST VALID AUROC: 0.8513111829226734
UPDATED BEST INTERMEDIATE MODEL
Epoch: 6
train loss: 0.9

  results_df = results_df.append(new_row, ignore_index=True)
100%|██████████| 142/142 [00:57<00:00,  2.49it/s]
100%|██████████| 18/18 [00:02<00:00,  6.06it/s]
100%|██████████| 142/142 [00:54<00:00,  2.62it/s]
100%|██████████| 18/18 [00:02<00:00,  6.22it/s]
100%|██████████| 142/142 [00:51<00:00,  2.74it/s]
100%|██████████| 18/18 [00:02<00:00,  6.25it/s]
100%|██████████| 142/142 [00:52<00:00,  2.73it/s]
100%|██████████| 18/18 [00:02<00:00,  6.28it/s]
100%|██████████| 142/142 [00:52<00:00,  2.70it/s]
100%|██████████| 18/18 [00:02<00:00,  6.22it/s]
100%|██████████| 142/142 [00:52<00:00,  2.70it/s]
100%|██████████| 18/18 [00:02<00:00,  6.18it/s]
100%|██████████| 142/142 [00:56<00:00,  2.54it/s]
100%|██████████| 18/18 [00:02<00:00,  6.04it/s]
100%|██████████| 142/142 [00:52<00:00,  2.69it/s]
100%|██████████| 18/18 [00:03<00:00,  5.99it/s]
100%|██████████| 142/142 [00:52<00:00,  2.70it/s]
100%|██████████| 18/18 [00:02<00:00,  6.28it/s]
100%|██████████| 142/142 [00:53<00:00,  2.64it/s]
100%|██

<All keys matched successfully>

100%|██████████| 18/18 [00:02<00:00,  6.12it/s]


0

test predictions: 0.132
test loss: 1.637
test AUROC: 0.826
Epoch: 0
train loss: 1.148
train AUROC: 0.770
valid loss: 1.049
valid AUROC: 0.808
BEST VALID LOSS: 1.0493533710638683
BEST VALID AUROC: 0.8084727351619524
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 1.025
train AUROC: 0.807
valid loss: 0.998
valid AUROC: 0.828
BEST VALID LOSS: 0.9983502560191684
BEST VALID AUROC: 0.8283888860269933
UPDATED BEST INTERMEDIATE MODEL
Epoch: 2
train loss: 0.951
train AUROC: 0.843
valid loss: 1.016
valid AUROC: 0.825
Epoch: 3
train loss: 0.949
train AUROC: 0.841
valid loss: 1.098
valid AUROC: 0.813
Epoch: 4
train loss: 0.958
train AUROC: 0.842
valid loss: 1.030
valid AUROC: 0.830
BEST VALID AUROC: 0.8295593462462058
UPDATED BEST INTERMEDIATE MODEL
Epoch: 5
train loss: 0.934
train AUROC: 0.848
valid loss: 1.057
valid AUROC: 0.825
Epoch: 6
train loss: 0.972
train AUROC: 0.831
valid loss: 1.002
valid AUROC: 0.827
Epoch: 7
train loss: 0.957
train AUROC: 0.837
valid loss: 0.974
valid AUROC: 0.83

  results_df = results_df.append(new_row, ignore_index=True)
100%|██████████| 142/142 [00:55<00:00,  2.54it/s]
100%|██████████| 18/18 [00:02<00:00,  6.01it/s]
100%|██████████| 142/142 [00:52<00:00,  2.73it/s]
100%|██████████| 18/18 [00:02<00:00,  6.21it/s]
100%|██████████| 142/142 [00:52<00:00,  2.72it/s]
100%|██████████| 18/18 [00:02<00:00,  6.03it/s]
100%|██████████| 142/142 [00:52<00:00,  2.71it/s]
100%|██████████| 18/18 [00:02<00:00,  6.08it/s]
100%|██████████| 142/142 [00:53<00:00,  2.68it/s]
100%|██████████| 18/18 [00:02<00:00,  6.10it/s]
100%|██████████| 142/142 [00:51<00:00,  2.74it/s]
100%|██████████| 18/18 [00:02<00:00,  6.15it/s]
100%|██████████| 142/142 [00:54<00:00,  2.61it/s]
100%|██████████| 18/18 [00:02<00:00,  6.07it/s]
100%|██████████| 142/142 [00:52<00:00,  2.71it/s]
100%|██████████| 18/18 [00:02<00:00,  6.29it/s]
100%|██████████| 142/142 [00:52<00:00,  2.70it/s]
100%|██████████| 18/18 [00:04<00:00,  4.24it/s]
100%|██████████| 142/142 [01:12<00:00,  1.97it/s]
100%|██

<All keys matched successfully>

100%|██████████| 18/18 [00:02<00:00,  6.02it/s]


0

test predictions: 0.677
test loss: 2.177
test AUROC: 0.772
Epoch: 0
train loss: 1.147
train AUROC: 0.769
valid loss: 1.799
valid AUROC: 0.728
BEST VALID LOSS: 1.7990168001916673
BEST VALID AUROC: 0.728135723852622
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 1.124
train AUROC: 0.791
valid loss: 1.056
valid AUROC: 0.808
BEST VALID LOSS: 1.0564256873395708
BEST VALID AUROC: 0.8083461788411701
UPDATED BEST INTERMEDIATE MODEL
Epoch: 2
train loss: 0.963
train AUROC: 0.836
valid loss: 1.071
valid AUROC: 0.817
BEST VALID AUROC: 0.8169331861878759
UPDATED BEST INTERMEDIATE MODEL
Epoch: 3
train loss: 0.934
train AUROC: 0.847
valid loss: 0.994
valid AUROC: 0.833
BEST VALID LOSS: 0.9941513968838586
BEST VALID AUROC: 0.8326118054509862
UPDATED BEST INTERMEDIATE MODEL
Epoch: 4
train loss: 0.911
train AUROC: 0.856
valid loss: 1.070
valid AUROC: 0.825
Epoch: 5
train loss: 0.911
train AUROC: 0.857
valid loss: 1.177
valid AUROC: 0.821
Epoch: 6
train loss: 0.936
train AUROC: 0.848
valid loss: 1.

  results_df = results_df.append(new_row, ignore_index=True)
100%|██████████| 142/142 [00:58<00:00,  2.42it/s]
100%|██████████| 18/18 [00:02<00:00,  6.23it/s]
100%|██████████| 142/142 [00:53<00:00,  2.67it/s]
100%|██████████| 18/18 [00:02<00:00,  6.11it/s]
100%|██████████| 142/142 [00:53<00:00,  2.64it/s]
100%|██████████| 18/18 [00:03<00:00,  5.90it/s]
100%|██████████| 142/142 [00:52<00:00,  2.69it/s]
100%|██████████| 18/18 [00:02<00:00,  6.13it/s]
100%|██████████| 142/142 [00:53<00:00,  2.64it/s]
100%|██████████| 18/18 [00:02<00:00,  6.16it/s]
100%|██████████| 142/142 [00:53<00:00,  2.66it/s]
100%|██████████| 18/18 [00:02<00:00,  6.06it/s]
100%|██████████| 142/142 [00:59<00:00,  2.40it/s]
100%|██████████| 18/18 [00:02<00:00,  6.07it/s]
100%|██████████| 142/142 [00:53<00:00,  2.65it/s]
100%|██████████| 18/18 [00:03<00:00,  5.99it/s]
100%|██████████| 142/142 [00:52<00:00,  2.71it/s]
100%|██████████| 18/18 [00:03<00:00,  5.97it/s]
100%|██████████| 142/142 [00:53<00:00,  2.66it/s]
100%|██

<All keys matched successfully>

100%|██████████| 18/18 [00:02<00:00,  6.09it/s]


0

test predictions: 0.500
test loss: 1.455
test AUROC: 0.819
Epoch: 0
train loss: 1.136
train AUROC: 0.776
valid loss: 1.340
valid AUROC: 0.818
BEST VALID LOSS: 1.3398859699567158
BEST VALID AUROC: 0.818195034435207
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 1.034
train AUROC: 0.811
valid loss: 1.140
valid AUROC: 0.808
BEST VALID LOSS: 1.1399602856900957
Epoch: 2
train loss: 0.991
train AUROC: 0.829
valid loss: 1.116
valid AUROC: 0.825
BEST VALID LOSS: 1.116351111067666
BEST VALID AUROC: 0.8254322728029228
UPDATED BEST INTERMEDIATE MODEL
Epoch: 3
train loss: 0.984
train AUROC: 0.834
valid loss: 1.075
valid AUROC: 0.824
BEST VALID LOSS: 1.0753612849447463
Epoch: 4
train loss: 1.039
train AUROC: 0.810
valid loss: 0.979
valid AUROC: 0.836
BEST VALID LOSS: 0.9787584841251373
BEST VALID AUROC: 0.83620565823156
UPDATED BEST INTERMEDIATE MODEL
Epoch: 5
train loss: 0.935
train AUROC: 0.850
valid loss: 1.530
valid AUROC: 0.779
Epoch: 6
train loss: 1.185
train AUROC: 0.757
valid loss: 0.

  results_df = results_df.append(new_row, ignore_index=True)
100%|██████████| 142/142 [00:55<00:00,  2.54it/s]
100%|██████████| 18/18 [00:02<00:00,  6.01it/s]
100%|██████████| 142/142 [00:53<00:00,  2.65it/s]
100%|██████████| 18/18 [00:02<00:00,  6.02it/s]
100%|██████████| 142/142 [00:53<00:00,  2.66it/s]
100%|██████████| 18/18 [00:02<00:00,  6.19it/s]
100%|██████████| 142/142 [00:52<00:00,  2.71it/s]
100%|██████████| 18/18 [00:02<00:00,  6.17it/s]
100%|██████████| 142/142 [00:53<00:00,  2.63it/s]
100%|██████████| 18/18 [00:02<00:00,  6.12it/s]
100%|██████████| 142/142 [00:52<00:00,  2.70it/s]
100%|██████████| 18/18 [00:02<00:00,  6.23it/s]
100%|██████████| 142/142 [00:53<00:00,  2.63it/s]
100%|██████████| 18/18 [00:02<00:00,  6.26it/s]
100%|██████████| 142/142 [00:53<00:00,  2.65it/s]
100%|██████████| 18/18 [00:02<00:00,  6.13it/s]
100%|██████████| 142/142 [00:52<00:00,  2.70it/s]
100%|██████████| 18/18 [00:02<00:00,  6.27it/s]
100%|██████████| 142/142 [00:53<00:00,  2.64it/s]
100%|██

<All keys matched successfully>

100%|██████████| 18/18 [00:03<00:00,  5.92it/s]


0

test predictions: 0.136
test loss: 1.531
test AUROC: 0.818
Epoch: 0
train loss: 1.143
train AUROC: 0.779
valid loss: 1.050
valid AUROC: 0.807
BEST VALID LOSS: 1.0500128037399716
BEST VALID AUROC: 0.8067867601938098
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 0.994
train AUROC: 0.827
valid loss: 0.999
valid AUROC: 0.822
BEST VALID LOSS: 0.9990914828247495
BEST VALID AUROC: 0.8222362714347482
UPDATED BEST INTERMEDIATE MODEL
Epoch: 2
train loss: 0.965
train AUROC: 0.835
valid loss: 1.004
valid AUROC: 0.827
BEST VALID AUROC: 0.8269908815492897
UPDATED BEST INTERMEDIATE MODEL
Epoch: 3
train loss: 0.953
train AUROC: 0.840
valid loss: 1.010
valid AUROC: 0.826
Epoch: 4
train loss: 0.927
train AUROC: 0.850
valid loss: 0.971
valid AUROC: 0.831
BEST VALID LOSS: 0.9706840051545037
BEST VALID AUROC: 0.8307173072265616
UPDATED BEST INTERMEDIATE MODEL
Epoch: 5
train loss: 0.914
train AUROC: 0.854
valid loss: 1.036
valid AUROC: 0.835
BEST VALID AUROC: 0.8352005083470647
UPDATED BEST INTERMEDI

  results_df = results_df.append(new_row, ignore_index=True)
100%|██████████| 142/142 [00:56<00:00,  2.54it/s]
100%|██████████| 18/18 [00:02<00:00,  6.14it/s]
100%|██████████| 142/142 [00:54<00:00,  2.62it/s]
100%|██████████| 18/18 [00:03<00:00,  5.89it/s]
100%|██████████| 142/142 [00:54<00:00,  2.61it/s]
100%|██████████| 18/18 [00:02<00:00,  6.09it/s]
100%|██████████| 142/142 [00:54<00:00,  2.62it/s]
100%|██████████| 18/18 [00:02<00:00,  6.03it/s]
100%|██████████| 142/142 [00:54<00:00,  2.59it/s]
100%|██████████| 18/18 [00:02<00:00,  6.07it/s]
100%|██████████| 142/142 [00:53<00:00,  2.65it/s]
100%|██████████| 18/18 [00:02<00:00,  6.02it/s]
100%|██████████| 142/142 [01:02<00:00,  2.28it/s]
100%|██████████| 18/18 [00:04<00:00,  4.28it/s]
100%|██████████| 142/142 [01:06<00:00,  2.14it/s]
100%|██████████| 18/18 [00:02<00:00,  6.08it/s]
100%|██████████| 142/142 [00:53<00:00,  2.67it/s]
100%|██████████| 18/18 [00:03<00:00,  5.75it/s]
100%|██████████| 142/142 [00:53<00:00,  2.64it/s]
100%|██

<All keys matched successfully>

100%|██████████| 18/18 [00:03<00:00,  5.98it/s]


0

test predictions: 0.366
test loss: 1.184
test AUROC: 0.821


  results_df = results_df.append(new_row, ignore_index=True)


In [9]:
results_df

Unnamed: 0,data,fold,subset,AUROC,accuracy,precision,recall,F1,AUPRC,TPR,FPR
0,set,1,overall,0.816038,0.778466,0.075562,0.689956,0.136207,0.083026,0.780764,0.310044
1,set,2,overall,0.829856,0.806765,0.082921,0.659389,0.147317,0.120749,0.810593,0.340611
2,set,3,overall,0.838451,0.765974,0.07554,0.733624,0.136975,0.135031,0.766814,0.266376
3,set,4,overall,0.834127,0.906367,0.130539,0.473913,0.204695,0.135992,0.91765,0.526087
4,set,5,overall,0.790073,0.482255,0.044581,0.951965,0.085173,0.082219,0.470054,0.048035
5,set,6,overall,0.825867,0.905583,0.13192,0.489083,0.207792,0.136805,0.916402,0.510917
6,set,7,overall,0.772416,0.307131,0.034392,0.973799,0.066438,0.098275,0.289814,0.026201
7,set,8,overall,0.818897,0.505473,0.04561,0.930131,0.086957,0.092536,0.494442,0.069869
8,set,9,overall,0.817904,0.907573,0.112388,0.384279,0.173913,0.10125,0.921166,0.615721
9,set,10,overall,0.821161,0.64953,0.05801,0.842795,0.108549,0.113101,0.64451,0.157205


In [11]:
# Save
#results_df.to_csv('set_transformer_mortality_10_fold_cv_results')

In [13]:
results_df.groupby(['data', 'subset']).mean()
results_df.groupby(['data', 'subset']).std()

Unnamed: 0_level_0,Unnamed: 1_level_0,fold,AUROC,accuracy,precision,recall,F1,AUPRC,TPR,FPR
data,subset,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
set,overall,5.5,0.816479,0.701512,0.079146,0.712893,0.135402,0.109898,0.701221,0.287107


Unnamed: 0_level_0,Unnamed: 1_level_0,fold,AUROC,accuracy,precision,recall,F1,AUPRC,TPR,FPR
data,subset,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
set,overall,3.02765,0.020358,0.208692,0.035497,0.212995,0.04946,0.02151,0.219287,0.212995


In [9]:
# CV for set transformer - only testing on those with rare co-morbidities

# Get CV folds and create dictionarys 
cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=2)
fold_n = 0
train_index_dict = {}
test_index_dict = {}
test_index_dict_8 = {}
test_index_dict_45 = {}

for train_index, test_index in cv.split(feature_array, mortality_array):
    fold_n += 1
    train_index_dict[fold_n] = train_index
    test_index_dict[fold_n] = test_index
    # Filter for rare diseases in test set
    test_index_8 = np.array(list(set(test_index.tolist()) & set(index_list_8)))
    test_index_45 = np.array(list(set(test_index.tolist()) & set(index_list_45)))
    test_index_dict_8[fold_n] = test_index_8
    test_index_dict_45[fold_n] = test_index_45

In [10]:
# del for memmory 
del test_index
del train_index
del patient_df
del index
del fold_n
del code
del code_list
del path
del n 
del n2
del max_len
gc.collect()

471

In [12]:
# Run
best_test_auroc = 0
results_df = pd.DataFrame()
for fold in range(10):
    fold += 1
    if fold < 10:
        # Get train, validation and test sets
        train_embeddings = feature_array[np.setdiff1d(train_index_dict[fold], test_index_dict[fold+1])]
        val_embeddings = feature_array[test_index_dict[fold+1]]
        test_embeddings_8 = feature_array[test_index_dict_8[fold]]
        test_embeddings_45 = feature_array[test_index_dict_45[fold]]

        train_labels = mortality_array[np.setdiff1d(train_index_dict[fold], test_index_dict[fold+1])]
        val_labels = mortality_array[test_index_dict[fold+1]]
        test_labels_8 = mortality_array[test_index_dict_8[fold]]
        test_labels_45 = mortality_array[test_index_dict_45[fold]]    

        # Split masks
        train_padding_mask = mask[np.setdiff1d(train_index_dict[fold], test_index_dict[fold+1])]
        val_padding_mask = mask[test_index_dict[fold+1]]
        test_padding_mask_8 = mask[test_index_dict_8[fold]]
        test_padding_mask_45 = mask[test_index_dict_45[fold]]

    else: # Change here to fist cv split for final fold
        # Get train, validation and test sets
        train_embeddings = feature_array[np.setdiff1d(train_index_dict[fold], test_index_dict[1])]
        val_embeddings = feature_array[test_index_dict[1]]
        test_embeddings_8 = feature_array[test_index_dict_8[fold]]
        test_embeddings_45 = feature_array[test_index_dict_45[fold]]

        train_labels = mortality_array[np.setdiff1d(train_index_dict[fold], test_index_dict[1])]
        val_labels = mortality_array[test_index_dict[1]]
        test_labels_8 = mortality_array[test_index_dict_8[fold]]
        test_labels_45 = mortality_array[test_index_dict_45[fold]]

        # Split masks
        train_padding_mask = mask[np.setdiff1d(train_index_dict[fold], test_index_dict[1])]
        val_padding_mask = mask[test_index_dict[1]]
        test_padding_mask_8 = mask[test_index_dict_8[fold]]
        test_padding_mask_45 = mask[test_index_dict_45[fold]]
    
    # Create datasets
    train_dataset = DiseaseDataset(train_embeddings.astype(np.float32), train_labels.astype(np.float32), train_padding_mask)
    val_dataset = DiseaseDataset(val_embeddings.astype(np.float32), val_labels.astype(np.float32), val_padding_mask)
    test_dataset_8 = DiseaseDataset(test_embeddings_8.astype(np.float32), test_labels_8.astype(np.float32), test_padding_mask_8)
    test_dataset_45 = DiseaseDataset(test_embeddings_45.astype(np.float32), test_labels_45.astype(np.float32), test_padding_mask_45)

    # Define batch size
    batch_size = 512

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_loader_8 = DataLoader(test_dataset_8, batch_size=batch_size)
    test_loader_45 = DataLoader(test_dataset_45, batch_size=batch_size)

    # Create weight for loss
    unique, counts = np.unique(train_labels, return_counts=True)
    pos_weight = torch.Tensor([(counts[0] / counts[1])])

    # Define model
    model = SetTransformer(dim_input=128, num_outputs=1, dim_output=1, num_inds=32, dim_hidden=160, num_heads=4, ln=False)

    model = model.to(device)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    # Run
    best_valid_loss = float('inf')
    best_valid_auroc = 0
    num_epochs = 10

    for epoch in range(num_epochs):

        start_time = time.time()

        train_loss, train_auroc, train_predictions, train_labels_out, train_pma_final = train(model, train_loader, optimizer, criterion)
        valid_loss, valid_auroc, valid_predictions, valid_labels_out, valid_pma_final = evaluate(model, val_loader, criterion)
        
        if epoch % 1 == 0:
            print('Epoch:', epoch)
            print(f'train loss: {train_loss:.3f}')
            print(f'train AUROC: {train_auroc:.3f}')
            print(f'valid loss: {valid_loss:.3f}')
            print(f'valid AUROC: {valid_auroc:.3f}')
        
        end_time = time.time()

        epoch_mins, epoch_secs = epoch_time(start_time, end_time)

        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss  
            print('BEST VALID LOSS:', best_valid_loss)

        if valid_auroc > best_valid_auroc:
            best_valid_auroc = valid_auroc
            print('BEST VALID AUROC:', best_valid_auroc)
            print('UPDATED BEST INTERMEDIATE MODEL')
            torch.save(model.state_dict(), f'intermediate_set_transformer_mortality.pt')

    # -----------------------------
    # Evaluate best model on test set
    # -----------------------------

    model.load_state_dict(torch.load(f'intermediate_set_transformer_mortality.pt'))

    test_loss_8, test_auroc_8, test_predictions_8, test_labels_out_8, test_pma_final_8 = evaluate(model, test_loader_8, criterion)
    test_loss_45, test_auroc_45, test_predictions_45, test_labels_out_45, test_pma_final_45 = evaluate(model, test_loader_45, criterion)

    # del for memmory 
    del train_embeddings
    del val_embeddings
    del train_labels
    del val_labels
    del train_padding_mask
    del val_padding_mask
    del train_loader
    del val_loader
    del train_predictions
    del train_labels_out
    del train_pma_final
    del valid_predictions
    del valid_labels_out
    del valid_pma_final
    gc.collect()

    print(f'test predictions 8: {test_predictions_8.mean().item():.3f}') 
    print(f'test loss 8: {test_loss_8:.3f}')
    print(f'test AUROC 8: {test_auroc_8:.3f}')

    print(f'test predictions 45: {test_predictions_45.mean().item():.3f}') 
    print(f'test loss 45: {test_loss_45:.3f}')
    print(f'test AUROC 45: {test_auroc_45:.3f}')

    #if test_auroc > best_test_auroc:
    #    best_test_auroc = test_auroc
    #    print('BEST TEST AUROC:', best_test_auroc)
        #print('UPDATED BEST MODEL')
        #torch.save(model.state_dict(), f'set_transformer_mortality.pt')

    # Get results
    # AUC
    try:
        aucroc = roc_auc_score(test_labels_out_8, test_predictions_8)
    except:
        aucroc = np.nan
    # Accuracy
    accuracy = accuracy_score(test_labels_out_8, test_predictions_8.round())
    # Precision
    precision = precision_score(test_labels_out_8, test_predictions_8.round())
    # Recall
    recall = recall_score(test_labels_out_8, test_predictions_8.round())
    # AUPRC
    try:
        auprc = average_precision_score(test_labels_out_8, test_predictions_8)
    except:
        auprc = np.nan
    # F1
    f1 = f1_score(test_labels_out_8, test_predictions_8.round())
    # TPR and FPR
    cm = confusion_matrix(test_labels_out_8, test_predictions_8.round())
    try:
        _tp = cm[0, 0]
    except:
        _tp = np.nan
    try:
        _fn = cm[0, 1]
    except:
        _fn = np.nan
    try:
        _fp = cm[1, 0]
    except:
        _fp = np.nan
    try:
        _tn = cm[1, 1]
    except:
        _tn = np.nan
    tpr = _tp / (_tp + _fn)
    fpr = _fp / (_tn + _fp)

    new_row = {'data': 'set', 'fold': fold, 'subset': 8, 'AUROC': aucroc, 'accuracy': accuracy, 'precision': precision, 'recall': recall, 'F1': f1, 'AUPRC': auprc, 'TPR': tpr, 'FPR': fpr}
    results_df = results_df.append(new_row, ignore_index=True)

    # Get results
    # AUC
    try:
        aucroc = roc_auc_score(test_labels_out_45, test_predictions_45)
    except:
        aucroc = np.nan
    # Accuracy
    accuracy = accuracy_score(test_labels_out_45, test_predictions_45.round())
    # Precision
    precision = precision_score(test_labels_out_45, test_predictions_45.round())
    # Recall
    recall = recall_score(test_labels_out_45, test_predictions_45.round())
    # AUPRC
    try:
        auprc = average_precision_score(test_labels_out_45, test_predictions_45)
    except:
        auprc = np.nan
    # F1
    f1 = f1_score(test_labels_out_45, test_predictions_45.round())
    # TPR and FPR
    cm = confusion_matrix(test_labels_out_45, test_predictions_45.round())
    try:
        _tp = cm[0, 0]
    except:
        _tp = np.nan
    try:
        _fn = cm[0, 1]
    except:
        _fn = np.nan
    try:
        _fp = cm[1, 0]
    except:
        _fp = np.nan
    try:
        _tn = cm[1, 1]
    except:
        _tn = np.nan
    tpr = _tp / (_tp + _fn)
    fpr = _fp / (_tn + _fp)

    new_row = {'data': 'set', 'fold': fold, 'subset': 45, 'AUROC': aucroc, 'accuracy': accuracy, 'precision': precision, 'recall': recall, 'F1': f1, 'AUPRC': auprc, 'TPR': tpr, 'FPR': fpr}
    results_df = results_df.append(new_row, ignore_index=True)


100%|██████████| 142/142 [01:10<00:00,  2.02it/s]
100%|██████████| 18/18 [00:03<00:00,  5.83it/s]
100%|██████████| 142/142 [00:55<00:00,  2.56it/s]
100%|██████████| 18/18 [00:02<00:00,  6.02it/s]
100%|██████████| 142/142 [00:54<00:00,  2.61it/s]
100%|██████████| 18/18 [00:02<00:00,  6.07it/s]
100%|██████████| 142/142 [00:54<00:00,  2.60it/s]
100%|██████████| 18/18 [00:03<00:00,  5.97it/s]
100%|██████████| 142/142 [00:56<00:00,  2.54it/s]
100%|██████████| 18/18 [00:03<00:00,  5.93it/s]
100%|██████████| 142/142 [01:00<00:00,  2.35it/s]
100%|██████████| 18/18 [00:03<00:00,  5.94it/s]
100%|██████████| 142/142 [00:54<00:00,  2.61it/s]
100%|██████████| 18/18 [00:02<00:00,  6.01it/s]
100%|██████████| 142/142 [00:54<00:00,  2.60it/s]
100%|██████████| 18/18 [00:02<00:00,  6.06it/s]
100%|██████████| 142/142 [00:54<00:00,  2.60it/s]
100%|██████████| 18/18 [00:03<00:00,  5.92it/s]
100%|██████████| 142/142 [00:54<00:00,  2.59it/s]
100%|██████████| 18/18 [00:03<00:00,  5.85it/s]


Epoch: 0
train loss: 1.167
train AUROC: 0.757
valid loss: 1.356
valid AUROC: 0.771
BEST VALID LOSS: 1.3558742437097762
BEST VALID AUROC: 0.7708357168292892
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 1.066
train AUROC: 0.806
valid loss: 1.003
valid AUROC: 0.831
BEST VALID LOSS: 1.0033313499556646
BEST VALID AUROC: 0.8308270099495169
UPDATED BEST INTERMEDIATE MODEL
Epoch: 2
train loss: 0.962
train AUROC: 0.837
valid loss: 1.026
valid AUROC: 0.832
BEST VALID AUROC: 0.832040426072499
UPDATED BEST INTERMEDIATE MODEL
Epoch: 3
train loss: 0.973
train AUROC: 0.834
valid loss: 0.954
valid AUROC: 0.841
BEST VALID LOSS: 0.9538111024432712
BEST VALID AUROC: 0.8407101604532332
UPDATED BEST INTERMEDIATE MODEL
Epoch: 4
train loss: 0.985
train AUROC: 0.826
valid loss: 0.976
valid AUROC: 0.851
BEST VALID AUROC: 0.850820640753051
UPDATED BEST INTERMEDIATE MODEL
Epoch: 5
train loss: 0.950
train AUROC: 0.843
valid loss: 1.028
valid AUROC: 0.844
Epoch: 6
train loss: 0.959
train AUROC: 0.839
valid

<All keys matched successfully>

100%|██████████| 1/1 [00:00<00:00, 26.90it/s]
100%|██████████| 2/2 [00:00<00:00, 10.60it/s]


0

test predictions 8: 0.276
test loss 8: 0.509
test AUROC 8: nan
test predictions 45: 0.247
test loss 45: 1.241
test AUROC 45: 0.735
Epoch: 0
train loss: 1.163
train AUROC: 0.758
valid loss: 1.938
valid AUROC: 0.726
BEST VALID LOSS: 1.9375099622541003
BEST VALID AUROC: 0.7258078255929765
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 1.293
train AUROC: 0.747
valid loss: 0.945
valid AUROC: 0.851
BEST VALID LOSS: 0.9449769821431901
BEST VALID AUROC: 0.8510402938349051
UPDATED BEST INTERMEDIATE MODEL
Epoch: 2
train loss: 0.980
train AUROC: 0.835
valid loss: 0.923
valid AUROC: 0.865
BEST VALID LOSS: 0.9230816496743096
BEST VALID AUROC: 0.8646097529930519
UPDATED BEST INTERMEDIATE MODEL
Epoch: 3
train loss: 0.966
train AUROC: 0.837
valid loss: 0.863
valid AUROC: 0.874
BEST VALID LOSS: 0.8627264599005381
BEST VALID AUROC: 0.8737554436571272
UPDATED BEST INTERMEDIATE MODEL
Epoch: 4
train loss: 0.930
train AUROC: 0.849
valid loss: 1.037
valid AUROC: 0.864
Epoch: 5
train loss: 0.955
train A

  _warn_prf(average, modifier, msg_start, len(result))
  fpr = _fp / (_tn + _fp)
  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
100%|██████████| 142/142 [01:18<00:00,  1.81it/s]
100%|██████████| 18/18 [00:03<00:00,  5.92it/s]
100%|██████████| 142/142 [00:57<00:00,  2.47it/s]
100%|██████████| 18/18 [00:03<00:00,  5.82it/s]
100%|██████████| 142/142 [00:55<00:00,  2.57it/s]
100%|██████████| 18/18 [00:03<00:00,  5.83it/s]
100%|██████████| 142/142 [00:55<00:00,  2.54it/s]
100%|██████████| 18/18 [00:03<00:00,  6.00it/s]
100%|██████████| 142/142 [00:55<00:00,  2.58it/s]
100%|██████████| 18/18 [00:02<00:00,  6.09it/s]
100%|██████████| 142/142 [00:54<00:00,  2.59it/s]
100%|██████████| 18/18 [00:03<00:00,  5.94it/s]
100%|██████████| 142/142 [01:00<00:00,  2.34it/s]
100%|██████████| 18/18 [00:03<00:00,  5.96it/s]
100%|██████████| 142/142 [00:55<00:00,  2.55it/s]
100%|██████████| 18/18 [00:02<00:00,  6.12it/s]
100%|████████

<All keys matched successfully>

100%|██████████| 1/1 [00:00<00:00, 21.08it/s]
100%|██████████| 2/2 [00:00<00:00, 10.72it/s]


0

  _warn_prf(average, modifier, msg_start, len(result))
  fpr = _fp / (_tn + _fp)
  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
100%|██████████| 142/142 [01:13<00:00,  1.94it/s]
100%|██████████| 18/18 [00:03<00:00,  5.98it/s]
100%|██████████| 142/142 [00:56<00:00,  2.53it/s]
100%|██████████| 18/18 [00:03<00:00,  5.90it/s]
100%|██████████| 142/142 [00:56<00:00,  2.52it/s]
100%|██████████| 18/18 [00:03<00:00,  5.99it/s]
100%|██████████| 142/142 [00:59<00:00,  2.37it/s]
100%|██████████| 18/18 [00:03<00:00,  5.43it/s]
100%|██████████| 142/142 [01:00<00:00,  2.36it/s]
100%|██████████| 18/18 [00:03<00:00,  5.45it/s]
100%|██████████| 142/142 [00:59<00:00,  2.38it/s]
100%|██████████| 18/18 [00:03<00:00,  5.59it/s]
100%|██████████| 142/142 [01:07<00:00,  2.11it/s]
100%|██████████| 18/18 [00:03<00:00,  5.43it/s]
100%|██████████| 142/142 [01:00<00:00,  2.36it/s]
100%|██████████| 18/18 [00:03<00:00,  5.50it/s]
100%|████████

test predictions 8: 0.341
test loss 8: 0.574
test AUROC 8: nan
test predictions 45: 0.267
test loss 45: 0.479
test AUROC 45: 0.869
Epoch: 0
train loss: 1.162
train AUROC: 0.757
valid loss: 1.153
valid AUROC: 0.796
BEST VALID LOSS: 1.1525114046202765
BEST VALID AUROC: 0.7957540144401483
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 1.029
train AUROC: 0.809
valid loss: 0.969
valid AUROC: 0.841
BEST VALID LOSS: 0.9692818754249148
BEST VALID AUROC: 0.840856298824272
UPDATED BEST INTERMEDIATE MODEL
Epoch: 2
train loss: 0.960
train AUROC: 0.837
valid loss: 1.084
valid AUROC: 0.824
Epoch: 3
train loss: 0.997
train AUROC: 0.826
valid loss: 0.972
valid AUROC: 0.848
BEST VALID AUROC: 0.8481828986822378
UPDATED BEST INTERMEDIATE MODEL
Epoch: 4
train loss: 0.938
train AUROC: 0.847
valid loss: 1.055
valid AUROC: 0.843
Epoch: 5
train loss: 0.967
train AUROC: 0.838
valid loss: 0.967
valid AUROC: 0.852
BEST VALID LOSS: 0.9667700926462809
BEST VALID AUROC: 0.8517658111733606
UPDATED BEST INTERME

<All keys matched successfully>

100%|██████████| 1/1 [00:00<00:00, 18.66it/s]
100%|██████████| 2/2 [00:00<00:00, 10.13it/s]


0

test predictions 8: 0.217
test loss 8: 0.393
test AUROC 8: 0.967
test predictions 45: 0.217
test loss 45: 0.803
test AUROC 45: 0.923
Epoch: 0
train loss: 1.140
train AUROC: 0.775
valid loss: 1.476
valid AUROC: 0.763
BEST VALID LOSS: 1.4763859974013434
BEST VALID AUROC: 0.7632047032390492
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 1.111
train AUROC: 0.776
valid loss: 1.077
valid AUROC: 0.801
BEST VALID LOSS: 1.077010389831331
BEST VALID AUROC: 0.8005232150357825
UPDATED BEST INTERMEDIATE MODEL
Epoch: 2
train loss: 0.989
train AUROC: 0.827
valid loss: 1.175
valid AUROC: 0.789
Epoch: 3
train loss: 1.048
train AUROC: 0.802
valid loss: 1.041
valid AUROC: 0.805
BEST VALID LOSS: 1.0408413608868916
BEST VALID AUROC: 0.8052263550194564
UPDATED BEST INTERMEDIATE MODEL
Epoch: 4
train loss: 0.980
train AUROC: 0.830
valid loss: 1.178
valid AUROC: 0.791
Epoch: 5
train loss: 0.976
train AUROC: 0.834
valid loss: 1.003
valid AUROC: 0.820
BEST VALID LOSS: 1.0028717087374792
BEST VALID AUROC: 0

  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
100%|██████████| 142/142 [01:15<00:00,  1.88it/s]
100%|██████████| 18/18 [00:03<00:00,  5.20it/s]
100%|██████████| 142/142 [01:04<00:00,  2.20it/s]
100%|██████████| 18/18 [00:03<00:00,  5.34it/s]
100%|██████████| 142/142 [00:58<00:00,  2.43it/s]
100%|██████████| 18/18 [00:03<00:00,  5.77it/s]
100%|██████████| 142/142 [00:55<00:00,  2.58it/s]
100%|██████████| 18/18 [00:03<00:00,  5.92it/s]
100%|██████████| 142/142 [00:55<00:00,  2.57it/s]
100%|██████████| 18/18 [00:03<00:00,  5.80it/s]
100%|██████████| 142/142 [00:59<00:00,  2.38it/s]
100%|██████████| 18/18 [00:03<00:00,  5.29it/s]
100%|██████████| 142/142 [01:02<00:00,  2.27it/s]
100%|██████████| 18/18 [00:03<00:00,  5.90it/s]
100%|██████████| 142/142 [00:55<00:00,  2.57it/s]
100%|██████████| 18/18 [00:03<00:00,  5.97it/s]
100%|██████████| 142/142 [00:55<00:00,  2.57it/s]
100%|██████████| 18/18 [00:03<00:00,  5.81it

<All keys matched successfully>

100%|██████████| 1/1 [00:00<00:00, 29.36it/s]
100%|██████████| 1/1 [00:00<00:00,  5.78it/s]


0

  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
100%|██████████| 142/142 [01:17<00:00,  1.84it/s]
100%|██████████| 18/18 [00:03<00:00,  5.67it/s]
100%|██████████| 142/142 [01:00<00:00,  2.33it/s]
100%|██████████| 18/18 [00:03<00:00,  5.63it/s]
100%|██████████| 142/142 [00:56<00:00,  2.51it/s]
100%|██████████| 18/18 [00:03<00:00,  5.65it/s]
100%|██████████| 142/142 [00:57<00:00,  2.45it/s]
100%|██████████| 18/18 [00:03<00:00,  5.71it/s]
100%|██████████| 142/142 [00:56<00:00,  2.49it/s]
100%|██████████| 18/18 [00:03<00:00,  5.56it/s]
100%|██████████| 142/142 [00:56<00:00,  2.52it/s]
100%|██████████| 18/18 [00:03<00:00,  5.73it/s]
100%|██████████| 142/142 [01:01<00:00,  2.30it/s]
100%|██████████| 18/18 [00:03<00:00,  5.79it/s]
100%|██████████| 142/142 [00:56<00:00,  2.53it/s]
100%|██████████| 18/18 [00:03<00:00,  5.81it/s]
100%|██████████| 142/142 [00:55<00:00,  2.56it/s]
100%|██████████| 18/18 [00:03<00:00,  5.81it

test predictions 8: 0.417
test loss 8: 1.264
test AUROC 8: 0.912
test predictions 45: 0.377
test loss 45: 0.867
test AUROC 45: 0.893
Epoch: 0
train loss: 1.138
train AUROC: 0.778
valid loss: 1.288
valid AUROC: 0.815
BEST VALID LOSS: 1.2881134284867182
BEST VALID AUROC: 0.8150935377519238
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 1.091
train AUROC: 0.788
valid loss: 1.030
valid AUROC: 0.845
BEST VALID LOSS: 1.029807448387146
BEST VALID AUROC: 0.8447681468390145
UPDATED BEST INTERMEDIATE MODEL
Epoch: 2
train loss: 1.011
train AUROC: 0.815
valid loss: 1.004
valid AUROC: 0.839
BEST VALID LOSS: 1.0039128462473552
Epoch: 3
train loss: 0.960
train AUROC: 0.839
valid loss: 0.974
valid AUROC: 0.850
BEST VALID LOSS: 0.973584363857905
BEST VALID AUROC: 0.8498482314806742
UPDATED BEST INTERMEDIATE MODEL
Epoch: 4
train loss: 0.932
train AUROC: 0.852
valid loss: 1.101
valid AUROC: 0.847
Epoch: 5
train loss: 1.065
train AUROC: 0.793
valid loss: 0.978
valid AUROC: 0.854
BEST VALID AUROC: 0.

<All keys matched successfully>

100%|██████████| 1/1 [00:00<00:00, 28.34it/s]
100%|██████████| 1/1 [00:00<00:00,  6.06it/s]


0

test predictions 8: 0.351
test loss 8: 0.654
test AUROC 8: 0.957
test predictions 45: 0.352
test loss 45: 0.924
test AUROC 45: 0.778
Epoch: 0
train loss: 1.133
train AUROC: 0.774
valid loss: 1.214
valid AUROC: 0.806
BEST VALID LOSS: 1.2135267588827345
BEST VALID AUROC: 0.806418609673559
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 1.029
train AUROC: 0.818
valid loss: 1.003
valid AUROC: 0.835
BEST VALID LOSS: 1.0031103955374823
BEST VALID AUROC: 0.8345651316780129
UPDATED BEST INTERMEDIATE MODEL
Epoch: 2
train loss: 0.989
train AUROC: 0.828
valid loss: 0.948
valid AUROC: 0.846
BEST VALID LOSS: 0.9480303393469917
BEST VALID AUROC: 0.8460577334580239
UPDATED BEST INTERMEDIATE MODEL
Epoch: 3
train loss: 0.943
train AUROC: 0.845
valid loss: 0.933
valid AUROC: 0.854
BEST VALID LOSS: 0.9330031606886122
BEST VALID AUROC: 0.854269282130941
UPDATED BEST INTERMEDIATE MODEL
Epoch: 4
train loss: 0.935
train AUROC: 0.848
valid loss: 1.324
valid AUROC: 0.802
Epoch: 5
train loss: 1.022
train A

  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
100%|██████████| 142/142 [01:06<00:00,  2.14it/s]
100%|██████████| 18/18 [00:03<00:00,  5.79it/s]
100%|██████████| 142/142 [00:57<00:00,  2.45it/s]
100%|██████████| 18/18 [00:03<00:00,  5.74it/s]
100%|██████████| 142/142 [00:56<00:00,  2.50it/s]
100%|██████████| 18/18 [00:03<00:00,  5.91it/s]
100%|██████████| 142/142 [00:55<00:00,  2.57it/s]
100%|██████████| 18/18 [00:03<00:00,  5.89it/s]
100%|██████████| 142/142 [00:55<00:00,  2.55it/s]
100%|██████████| 18/18 [00:03<00:00,  5.78it/s]
100%|██████████| 142/142 [00:56<00:00,  2.53it/s]
100%|██████████| 18/18 [00:03<00:00,  5.85it/s]
100%|██████████| 142/142 [01:03<00:00,  2.23it/s]
100%|██████████| 18/18 [00:03<00:00,  5.78it/s]
100%|██████████| 142/142 [00:56<00:00,  2.50it/s]
100%|██████████| 18/18 [00:03<00:00,  5.81it/s]
100%|██████████| 142/142 [00:55<00:00,  2.57it/s]
100%|██████████| 18/18 [00:03<00:00,  5.86it

<All keys matched successfully>

  0%|          | 0/1 [00:00<?, ?it/s]100%|██████████| 1/1 [00:00<00:00, 24.24it/s]
  0%|          | 0/2 [00:00<?, ?it/s] 50%|█████     | 1/2 [00:00<00:00,  5.35it/s]100%|██████████| 2/2 [00:00<00:00, 10.24it/s]


0

test predictions 8: 0.276
test loss 8: 0.602
test AUROC 8: 0.910
test predictions 45: 0.235
test loss 45: 0.636
test AUROC 45: 0.871
Epoch: 0
train loss: 1.140
train AUROC: 0.773
valid loss: 1.013
valid AUROC: 0.826
BEST VALID LOSS: 1.0127150515715282
BEST VALID AUROC: 0.8256893480690131
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 0.976
train AUROC: 0.836
valid loss: 1.250
valid AUROC: 0.769
Epoch: 2
train loss: 1.149
train AUROC: 0.753
valid loss: 1.515
valid AUROC: 0.756
Epoch: 3
train loss: 1.191
train AUROC: 0.755
valid loss: 1.041
valid AUROC: 0.816
Epoch: 4
train loss: 0.955
train AUROC: 0.839
valid loss: 0.987
valid AUROC: 0.836
BEST VALID LOSS: 0.9870909485552046
BEST VALID AUROC: 0.835846050055873
UPDATED BEST INTERMEDIATE MODEL
Epoch: 5
train loss: 0.924
train AUROC: 0.850
valid loss: 1.044
valid AUROC: 0.829
Epoch: 6
train loss: 0.930
train AUROC: 0.850
valid loss: 1.045
valid AUROC: 0.829
Epoch: 7
train loss: 0.971
train AUROC: 0.831
valid loss: 0.998
valid AUROC: 

  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
100%|██████████| 142/142 [01:19<00:00,  1.78it/s]
100%|██████████| 18/18 [00:03<00:00,  5.72it/s]
100%|██████████| 142/142 [00:57<00:00,  2.49it/s]
100%|██████████| 18/18 [00:03<00:00,  5.78it/s]
100%|██████████| 142/142 [00:54<00:00,  2.58it/s]
100%|██████████| 18/18 [00:02<00:00,  6.01it/s]
100%|██████████| 142/142 [00:54<00:00,  2.61it/s]
100%|██████████| 18/18 [00:03<00:00,  5.92it/s]
100%|██████████| 142/142 [00:54<00:00,  2.59it/s]
100%|██████████| 18/18 [00:03<00:00,  5.92it/s]
100%|██████████| 142/142 [00:54<00:00,  2.59it/s]
100%|██████████| 18/18 [00:03<00:00,  5.73it/s]
100%|██████████| 142/142 [01:02<00:00,  2.28it/s]
100%|██████████| 18/18 [00:03<00:00,  5.76it/s]
100%|██████████| 142/142 [00:54<00:00,  2.59it/s]
100%|██████████| 18/18 [00:03<00:00,  5.81it/s]
100%|██████████| 142/142 [00:54<00:00,  2.60it/s]
100%|██████████| 18/18 [00:03<00:00,  5.78it

<All keys matched successfully>

  0%|          | 0/1 [00:00<?, ?it/s]100%|██████████| 1/1 [00:00<00:00, 22.94it/s]
  0%|          | 0/2 [00:00<?, ?it/s] 50%|█████     | 1/2 [00:00<00:00,  5.84it/s]100%|██████████| 2/2 [00:00<00:00, 11.04it/s]


0

test predictions 8: 0.214
test loss 8: 2.464
test AUROC 8: 0.577
test predictions 45: 0.190
test loss 45: 0.797
test AUROC 45: 0.713
Epoch: 0
train loss: 1.137
train AUROC: 0.776
valid loss: 1.096
valid AUROC: 0.797
BEST VALID LOSS: 1.096046722597546
BEST VALID AUROC: 0.7967441095586428
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 0.973
train AUROC: 0.834
valid loss: 1.030
valid AUROC: 0.820
BEST VALID LOSS: 1.0297419130802155
BEST VALID AUROC: 0.8202612954612098
UPDATED BEST INTERMEDIATE MODEL
Epoch: 2
train loss: 0.953
train AUROC: 0.840
valid loss: 1.103
valid AUROC: 0.806
Epoch: 3
train loss: 0.949
train AUROC: 0.839
valid loss: 1.534
valid AUROC: 0.787
Epoch: 4
train loss: 1.034
train AUROC: 0.825
valid loss: 0.996
valid AUROC: 0.831
BEST VALID LOSS: 0.9956837197144827
BEST VALID AUROC: 0.8306661568089777
UPDATED BEST INTERMEDIATE MODEL
Epoch: 5
train loss: 0.916
train AUROC: 0.854
valid loss: 1.025
valid AUROC: 0.829
Epoch: 6
train loss: 0.908
train AUROC: 0.858
valid los

  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
100%|██████████| 142/142 [01:12<00:00,  1.96it/s]
100%|██████████| 18/18 [00:03<00:00,  5.48it/s]
100%|██████████| 142/142 [00:54<00:00,  2.58it/s]
100%|██████████| 18/18 [00:03<00:00,  5.76it/s]
100%|██████████| 142/142 [00:54<00:00,  2.58it/s]
100%|██████████| 18/18 [00:03<00:00,  5.89it/s]
100%|██████████| 142/142 [00:56<00:00,  2.52it/s]
100%|██████████| 18/18 [00:03<00:00,  5.30it/s]
100%|██████████| 142/142 [00:59<00:00,  2.38it/s]
100%|██████████| 18/18 [00:03<00:00,  5.29it/s]
100%|██████████| 142/142 [01:00<00:00,  2.33it/s]
100%|██████████| 18/18 [00:03<00:00,  5.33it/s]
100%|██████████| 142/142 [01:08<00:00,  2.08it/s]
100%|██████████| 18/18 [00:03<00:00,  5.25it/s]
100%|██████████| 142/142 [01:00<00:00,  2.33it/s]
100%|██████████| 18/18 [00:03<00:00,  5.62it/s]
100%|██████████| 142/142 [01:01<00:00,  2.32it/s]
100%|██████████| 18/18 [00:03<00:00,  5.12it

<All keys matched successfully>

  0%|          | 0/1 [00:00<?, ?it/s]100%|██████████| 1/1 [00:00<00:00, 23.76it/s]
  0%|          | 0/2 [00:00<?, ?it/s] 50%|█████     | 1/2 [00:00<00:00,  5.20it/s]100%|██████████| 2/2 [00:00<00:00,  9.84it/s]


0

test predictions 8: 0.282
test loss 8: 1.622
test AUROC 8: 0.854
test predictions 45: 0.234
test loss 45: 1.109
test AUROC 45: 0.778
Epoch: 0
train loss: 1.147
train AUROC: 0.770
valid loss: 1.280
valid AUROC: 0.790
BEST VALID LOSS: 1.2802615099483066
BEST VALID AUROC: 0.7900663442411178
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 1.038
train AUROC: 0.815
valid loss: 1.181
valid AUROC: 0.803
BEST VALID LOSS: 1.1805301308631897
BEST VALID AUROC: 0.8032717904722655
UPDATED BEST INTERMEDIATE MODEL
Epoch: 2
train loss: 1.041
train AUROC: 0.808
valid loss: 0.966
valid AUROC: 0.834
BEST VALID LOSS: 0.965732991695404
BEST VALID AUROC: 0.8344962810768828
UPDATED BEST INTERMEDIATE MODEL
Epoch: 3
train loss: 0.961
train AUROC: 0.836
valid loss: 1.054
valid AUROC: 0.816
Epoch: 4
train loss: 1.007
train AUROC: 0.818
valid loss: 1.009
valid AUROC: 0.831
Epoch: 5
train loss: 1.035
train AUROC: 0.817
valid loss: 0.977
valid AUROC: 0.835
BEST VALID AUROC: 0.8350475812139897
UPDATED BEST INTER

  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
100%|██████████| 142/142 [01:24<00:00,  1.68it/s]
100%|██████████| 18/18 [00:03<00:00,  5.30it/s]
100%|██████████| 142/142 [00:59<00:00,  2.37it/s]
100%|██████████| 18/18 [00:03<00:00,  5.83it/s]
100%|██████████| 142/142 [00:54<00:00,  2.58it/s]
100%|██████████| 18/18 [00:03<00:00,  6.00it/s]
100%|██████████| 142/142 [00:55<00:00,  2.57it/s]
100%|██████████| 18/18 [00:03<00:00,  5.83it/s]
100%|██████████| 142/142 [00:54<00:00,  2.61it/s]
100%|██████████| 18/18 [00:03<00:00,  5.91it/s]
100%|██████████| 142/142 [00:54<00:00,  2.60it/s]
100%|██████████| 18/18 [00:02<00:00,  6.05it/s]
100%|██████████| 142/142 [01:05<00:00,  2.17it/s]
100%|██████████| 18/18 [00:03<00:00,  5.51it/s]
100%|██████████| 142/142 [00:58<00:00,  2.44it/s]
100%|██████████| 18/18 [00:03<00:00,  5.94it/s]
100%|██████████| 142/142 [00:54<00:00,  2.60it/s]
100%|██████████| 18/18 [00:03<00:00,  5.95it

<All keys matched successfully>

  0%|          | 0/1 [00:00<?, ?it/s]100%|██████████| 1/1 [00:00<00:00, 33.76it/s]
  0%|          | 0/1 [00:00<?, ?it/s]100%|██████████| 1/1 [00:00<00:00,  5.84it/s]100%|██████████| 1/1 [00:00<00:00,  5.80it/s]


0

test predictions 8: 0.286
test loss 8: 1.088
test AUROC 8: 0.794
test predictions 45: 0.255
test loss 45: 1.998
test AUROC 45: 0.664
Epoch: 0
train loss: 1.159
train AUROC: 0.772
valid loss: 1.299
valid AUROC: 0.764
BEST VALID LOSS: 1.2987295985221863
BEST VALID AUROC: 0.7639103300343273
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 1.078
train AUROC: 0.797
valid loss: 0.984
valid AUROC: 0.830
BEST VALID LOSS: 0.9837486942609152
BEST VALID AUROC: 0.82957322916775
UPDATED BEST INTERMEDIATE MODEL
Epoch: 2
train loss: 0.983
train AUROC: 0.831
valid loss: 1.044
valid AUROC: 0.805
Epoch: 3
train loss: 0.978
train AUROC: 0.827
valid loss: 1.107
valid AUROC: 0.818
Epoch: 4
train loss: 0.957
train AUROC: 0.842
valid loss: 1.003
valid AUROC: 0.827
Epoch: 5
train loss: 0.984
train AUROC: 0.825
valid loss: 0.976
valid AUROC: 0.830
BEST VALID LOSS: 0.9755645659234788
BEST VALID AUROC: 0.8300219454973099
UPDATED BEST INTERMEDIATE MODEL
Epoch: 6
train loss: 0.933
train AUROC: 0.848
valid loss

  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
100%|██████████| 142/142 [01:22<00:00,  1.73it/s]
100%|██████████| 18/18 [00:03<00:00,  5.77it/s]
100%|██████████| 142/142 [00:55<00:00,  2.57it/s]
100%|██████████| 18/18 [00:03<00:00,  5.90it/s]
100%|██████████| 142/142 [00:55<00:00,  2.58it/s]
100%|██████████| 18/18 [00:03<00:00,  5.83it/s]
100%|██████████| 142/142 [00:55<00:00,  2.58it/s]
100%|██████████| 18/18 [00:03<00:00,  5.99it/s]
100%|██████████| 142/142 [00:56<00:00,  2.51it/s]
100%|██████████| 18/18 [00:03<00:00,  5.88it/s]
100%|██████████| 142/142 [00:55<00:00,  2.58it/s]
100%|██████████| 18/18 [00:03<00:00,  5.69it/s]
100%|██████████| 142/142 [01:01<00:00,  2.30it/s]
100%|██████████| 18/18 [00:03<00:00,  5.96it/s]
100%|██████████| 142/142 [00:54<00:00,  2.59it/s]
100%|██████████| 18/18 [00:03<00:00,  5.88it/s]
100%|██████████| 142/142 [00:54<00:00,  2.59it/s]
100%|██████████| 18/18 [00:03<00:00,  5.86it

<All keys matched successfully>

  0%|          | 0/1 [00:00<?, ?it/s]100%|██████████| 1/1 [00:00<00:00, 23.15it/s]
  0%|          | 0/1 [00:00<?, ?it/s]100%|██████████| 1/1 [00:00<00:00,  5.80it/s]100%|██████████| 1/1 [00:00<00:00,  5.76it/s]


0

test predictions 8: 0.290
test loss 8: 0.464
test AUROC 8: nan
test predictions 45: 0.285
test loss 45: 1.092
test AUROC 45: 0.734


  _warn_prf(average, modifier, msg_start, len(result))
  fpr = _fp / (_tn + _fp)
  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)


In [13]:
results_df

Unnamed: 0,data,fold,subset,AUROC,accuracy,precision,recall,F1,AUPRC,TPR,FPR
0,set,1,8,,0.752688,0.0,0.0,0.0,-0.0,0.752688,
1,set,1,45,0.735168,0.791289,0.017699,0.333333,0.033613,0.110905,0.79633,0.666667
2,set,2,8,,0.706667,0.0,0.0,0.0,-0.0,0.706667,
3,set,2,45,0.86887,0.769797,0.03876,0.833333,0.074074,0.156789,0.769088,0.166667
4,set,3,8,0.967391,0.870968,0.076923,1.0,0.142857,0.25,0.869565,0.0
5,set,3,45,0.922596,0.84466,0.082353,0.777778,0.148936,0.214279,0.84585,0.222222
6,set,4,8,0.911765,0.647887,0.107143,1.0,0.193548,0.268687,0.632353,0.0
7,set,4,45,0.893443,0.7249,0.055944,0.8,0.104575,0.226468,0.723361,0.2
8,set,5,8,0.957143,0.647887,0.038462,1.0,0.074074,0.25,0.642857,0.0
9,set,5,45,0.778345,0.673347,0.041667,0.777778,0.079096,0.055274,0.671429,0.222222


In [14]:
results_df.groupby(['data', 'subset']).mean()
results_df.groupby(['data', 'subset']).std()

Unnamed: 0_level_0,Unnamed: 1_level_0,fold,AUROC,accuracy,precision,recall,F1,AUPRC,TPR,FPR
data,subset,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
set,8,5.5,0.853212,0.74568,0.054376,0.59,0.096823,0.153954,0.745516,0.157143
set,45,5.5,0.795891,0.783094,0.055151,0.659466,0.100948,0.141016,0.785284,0.340534


Unnamed: 0_level_0,Unnamed: 1_level_0,fold,AUROC,accuracy,precision,recall,F1,AUPRC,TPR,FPR
data,subset,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
set,8,3.02765,0.135461,0.068243,0.044508,0.462961,0.077117,0.123463,0.073729,0.269921
set,45,3.02765,0.087543,0.059741,0.023401,0.182321,0.040909,0.055891,0.061222,0.182321


In [14]:
# CV for set transformer - random embedding results - on whole dataset and those with rare co-morbidities

# Get CV folds and create dictionarys 
cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=2)
fold_n = 0
train_index_dict = {}
test_index_dict = {}
test_index_dict_8 = {}
test_index_dict_45 = {}

for train_index, test_index in cv.split(feature_array, mortality_array):
    fold_n += 1
    train_index_dict[fold_n] = train_index
    test_index_dict[fold_n] = test_index
    # Filter for rare diseases in test set
    test_index_8 = np.array(list(set(test_index.tolist()) & set(index_list_8)))
    test_index_45 = np.array(list(set(test_index.tolist()) & set(index_list_45)))
    test_index_dict_8[fold_n] = test_index_8
    test_index_dict_45[fold_n] = test_index_45

In [15]:
# del for memmory 
del test_index
del train_index
del patient_df
del index
del fold_n
del code
del code_list
del path
del n 
del n2
del max_len
gc.collect()

0

In [16]:
# Run
best_test_auroc = 0
results_df = pd.DataFrame()
for fold in range(10):
    fold += 1
    if fold < 10:
        # Get train, validation and test sets
        train_embeddings = feature_array[np.setdiff1d(train_index_dict[fold], test_index_dict[fold+1])]
        val_embeddings = feature_array[test_index_dict[fold+1]]
        test_embeddings = feature_array[test_index_dict[fold]]
        test_embeddings_8 = feature_array[test_index_dict_8[fold]]
        test_embeddings_45 = feature_array[test_index_dict_45[fold]]

        train_labels = mortality_array[np.setdiff1d(train_index_dict[fold], test_index_dict[fold+1])]
        val_labels = mortality_array[test_index_dict[fold+1]]
        test_labels = mortality_array[test_index_dict[fold]]
        test_labels_8 = mortality_array[test_index_dict_8[fold]]
        test_labels_45 = mortality_array[test_index_dict_45[fold]]    

        # Split masks
        train_padding_mask = mask[np.setdiff1d(train_index_dict[fold], test_index_dict[fold+1])]
        val_padding_mask = mask[test_index_dict[fold+1]]
        test_padding_mask = mask[test_index_dict[fold]]
        test_padding_mask_8 = mask[test_index_dict_8[fold]]
        test_padding_mask_45 = mask[test_index_dict_45[fold]]

    else: # Change here to fist cv split for final fold
        # Get train, validation and test sets
        train_embeddings = feature_array[np.setdiff1d(train_index_dict[fold], test_index_dict[1])]
        val_embeddings = feature_array[test_index_dict[1]]
        test_embeddings = feature_array[test_index_dict[fold]]
        test_embeddings_8 = feature_array[test_index_dict_8[fold]]
        test_embeddings_45 = feature_array[test_index_dict_45[fold]]

        train_labels = mortality_array[np.setdiff1d(train_index_dict[fold], test_index_dict[1])]
        val_labels = mortality_array[test_index_dict[1]]
        test_labels = mortality_array[test_index_dict[fold]]
        test_labels_8 = mortality_array[test_index_dict_8[fold]]
        test_labels_45 = mortality_array[test_index_dict_45[fold]]

        # Split masks
        train_padding_mask = mask[np.setdiff1d(train_index_dict[fold], test_index_dict[1])]
        val_padding_mask = mask[test_index_dict[1]]
        test_padding_mask = mask[test_index_dict[fold]]
        test_padding_mask_8 = mask[test_index_dict_8[fold]]
        test_padding_mask_45 = mask[test_index_dict_45[fold]]
    
    # Create datasets
    train_dataset = DiseaseDataset(train_embeddings.astype(np.float32), train_labels.astype(np.float32), train_padding_mask)
    val_dataset = DiseaseDataset(val_embeddings.astype(np.float32), val_labels.astype(np.float32), val_padding_mask)
    test_dataset = DiseaseDataset(test_embeddings.astype(np.float32), test_labels.astype(np.float32), test_padding_mask)
    test_dataset_8 = DiseaseDataset(test_embeddings_8.astype(np.float32), test_labels_8.astype(np.float32), test_padding_mask_8)
    test_dataset_45 = DiseaseDataset(test_embeddings_45.astype(np.float32), test_labels_45.astype(np.float32), test_padding_mask_45)

    # Define batch size
    batch_size = 512

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)
    test_loader_8 = DataLoader(test_dataset_8, batch_size=batch_size)
    test_loader_45 = DataLoader(test_dataset_45, batch_size=batch_size)

    # Create weight for loss
    unique, counts = np.unique(train_labels, return_counts=True)
    pos_weight = torch.Tensor([(counts[0] / counts[1])])

    # Define model
    model = SetTransformer(dim_input=128, num_outputs=1, dim_output=1, num_inds=32, dim_hidden=160, num_heads=4, ln=False)

    model = model.to(device)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    # Run
    best_valid_loss = float('inf')
    best_valid_auroc = 0
    num_epochs = 10

    for epoch in range(num_epochs):

        start_time = time.time()

        train_loss, train_auroc, train_predictions, train_labels_out, train_pma_final = train(model, train_loader, optimizer, criterion)
        valid_loss, valid_auroc, valid_predictions, valid_labels_out, valid_pma_final = evaluate(model, val_loader, criterion)
        
        if epoch % 1 == 0:
            print('Epoch:', epoch)
            print(f'train loss: {train_loss:.3f}')
            print(f'train AUROC: {train_auroc:.3f}')
            print(f'valid loss: {valid_loss:.3f}')
            print(f'valid AUROC: {valid_auroc:.3f}')
        
        end_time = time.time()

        epoch_mins, epoch_secs = epoch_time(start_time, end_time)

        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss  
            print('BEST VALID LOSS:', best_valid_loss)

        if valid_auroc > best_valid_auroc:
            best_valid_auroc = valid_auroc
            print('BEST VALID AUROC:', best_valid_auroc)
            print('UPDATED BEST INTERMEDIATE MODEL')
            torch.save(model.state_dict(), f'intermediate_set_transformer_mortality.pt')

    # -----------------------------
    # Evaluate best model on test set
    # -----------------------------

    model.load_state_dict(torch.load(f'intermediate_set_transformer_mortality.pt'))

    test_loss, test_auroc, test_predictions, test_labels_out, test_pma_final = evaluate(model, test_loader, criterion)
    test_loss_8, test_auroc_8, test_predictions_8, test_labels_out_8, test_pma_final_8 = evaluate(model, test_loader_8, criterion)
    test_loss_45, test_auroc_45, test_predictions_45, test_labels_out_45, test_pma_final_45 = evaluate(model, test_loader_45, criterion)

    # del for memmory 
    del train_embeddings
    del val_embeddings
    del train_labels
    del val_labels
    del train_padding_mask
    del val_padding_mask
    del train_loader
    del val_loader
    del train_predictions
    del train_labels_out
    del train_pma_final
    del valid_predictions
    del valid_labels_out
    del valid_pma_final
    gc.collect()

    print(f'test predictions: {test_predictions.mean().item():.3f}') 
    print(f'test loss: {test_loss:.3f}')
    print(f'test AUROC: {test_auroc:.3f}')

    print(f'test predictions 8: {test_predictions_8.mean().item():.3f}') 
    print(f'test loss 8: {test_loss_8:.3f}')
    print(f'test AUROC 8: {test_auroc_8:.3f}')

    print(f'test predictions 45: {test_predictions_45.mean().item():.3f}') 
    print(f'test loss 45: {test_loss_45:.3f}')
    print(f'test AUROC 45: {test_auroc_45:.3f}')

    if test_auroc > best_test_auroc:
        best_test_auroc = test_auroc
        print('BEST TEST AUROC:', best_test_auroc)
        print('UPDATED BEST MODEL')
        torch.save(model.state_dict(), f'set_transformer_random_mortality.pt')

    # Get results
    # AUC
    aucroc = roc_auc_score(test_labels_out, test_predictions)
    # Accuracy
    accuracy = accuracy_score(test_labels_out, test_predictions.round())
    # Precision
    precision = precision_score(test_labels_out, test_predictions.round())
    # Recall
    recall = recall_score(test_labels_out, test_predictions.round())
    # AUPRC
    auprc = average_precision_score(test_labels_out, test_predictions)
    # F1
    f1 = f1_score(test_labels_out, test_predictions.round())
    # TPR and FPR
    cm = confusion_matrix(test_labels_out, test_predictions.round())
    _tp = cm[0, 0]
    _fn = cm[0, 1]
    _fp = cm[1, 0]
    _tn = cm[1, 1]
    tpr = _tp / (_tp + _fn)
    fpr = _fp / (_tn + _fp)

    new_row = {'data': 'set', 'fold': fold, 'subset':'overall', 'AUROC': aucroc, 'accuracy': accuracy, 'precision': precision, 'recall': recall, 'F1': f1, 'AUPRC': auprc, 'TPR': tpr, 'FPR': fpr}
    results_df = results_df.append(new_row, ignore_index=True)

    # Get results - rare
    # AUC
    try:
        aucroc = roc_auc_score(test_labels_out_8, test_predictions_8)
    except:
        aucroc = np.nan
    # Accuracy
    accuracy = accuracy_score(test_labels_out_8, test_predictions_8.round())
    # Precision
    precision = precision_score(test_labels_out_8, test_predictions_8.round())
    # Recall
    recall = recall_score(test_labels_out_8, test_predictions_8.round())
    # AUPRC
    try:
        auprc = average_precision_score(test_labels_out_8, test_predictions_8)
    except:
        auprc = np.nan
    # F1
    f1 = f1_score(test_labels_out_8, test_predictions_8.round())
    # TPR and FPR
    cm = confusion_matrix(test_labels_out_8, test_predictions_8.round())
    try:
        _tp = cm[0, 0]
    except:
        _tp = np.nan
    try:
        _fn = cm[0, 1]
    except:
        _fn = np.nan
    try:
        _fp = cm[1, 0]
    except:
        _fp = np.nan
    try:
        _tn = cm[1, 1]
    except:
        _tn = np.nan
    tpr = _tp / (_tp + _fn)
    fpr = _fp / (_tn + _fp)

    new_row = {'data': 'set', 'fold': fold, 'subset': 8, 'AUROC': aucroc, 'accuracy': accuracy, 'precision': precision, 'recall': recall, 'F1': f1, 'AUPRC': auprc, 'TPR': tpr, 'FPR': fpr}
    results_df = results_df.append(new_row, ignore_index=True)

    # Get results
    # AUC
    try:
        aucroc = roc_auc_score(test_labels_out_45, test_predictions_45)
    except:
        aucroc = np.nan
    # Accuracy
    accuracy = accuracy_score(test_labels_out_45, test_predictions_45.round())
    # Precision
    precision = precision_score(test_labels_out_45, test_predictions_45.round())
    # Recall
    recall = recall_score(test_labels_out_45, test_predictions_45.round())
    # AUPRC
    try:
        auprc = average_precision_score(test_labels_out_45, test_predictions_45)
    except:
        auprc = np.nan
    # F1
    f1 = f1_score(test_labels_out_45, test_predictions_45.round())
    # TPR and FPR
    cm = confusion_matrix(test_labels_out_45, test_predictions_45.round())
    try:
        _tp = cm[0, 0]
    except:
        _tp = np.nan
    try:
        _fn = cm[0, 1]
    except:
        _fn = np.nan
    try:
        _fp = cm[1, 0]
    except:
        _fp = np.nan
    try:
        _tn = cm[1, 1]
    except:
        _tn = np.nan
    tpr = _tp / (_tp + _fn)
    fpr = _fp / (_tn + _fp)

    new_row = {'data': 'set', 'fold': fold, 'subset': 45, 'AUROC': aucroc, 'accuracy': accuracy, 'precision': precision, 'recall': recall, 'F1': f1, 'AUPRC': auprc, 'TPR': tpr, 'FPR': fpr}
    results_df = results_df.append(new_row, ignore_index=True)


100%|██████████| 142/142 [01:39<00:00,  1.43it/s]
100%|██████████| 18/18 [00:04<00:00,  3.74it/s]
100%|██████████| 142/142 [01:35<00:00,  1.49it/s]
100%|██████████| 18/18 [00:04<00:00,  3.79it/s]
100%|██████████| 142/142 [01:35<00:00,  1.49it/s]
100%|██████████| 18/18 [00:04<00:00,  3.79it/s]
100%|██████████| 142/142 [01:37<00:00,  1.46it/s]
100%|██████████| 18/18 [00:04<00:00,  3.75it/s]
100%|██████████| 142/142 [01:35<00:00,  1.49it/s]
100%|██████████| 18/18 [00:04<00:00,  3.76it/s]
100%|██████████| 142/142 [01:34<00:00,  1.50it/s]
100%|██████████| 18/18 [00:04<00:00,  3.72it/s]
100%|██████████| 142/142 [01:34<00:00,  1.50it/s]
100%|██████████| 18/18 [00:04<00:00,  3.78it/s]
100%|██████████| 142/142 [01:47<00:00,  1.32it/s]
100%|██████████| 18/18 [00:04<00:00,  3.74it/s]
100%|██████████| 142/142 [01:34<00:00,  1.50it/s]
100%|██████████| 18/18 [00:04<00:00,  3.69it/s]
100%|██████████| 142/142 [01:34<00:00,  1.50it/s]
100%|██████████| 18/18 [00:04<00:00,  3.74it/s]


Epoch: 0
train loss: 1.177
train AUROC: 0.737
valid loss: 1.320
valid AUROC: 0.759
BEST VALID LOSS: 1.320160620742374
BEST VALID AUROC: 0.7593043014858653
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 1.186
train AUROC: 0.748
valid loss: 1.120
valid AUROC: 0.778
BEST VALID LOSS: 1.1204609904024336
BEST VALID AUROC: 0.7779186000842953
UPDATED BEST INTERMEDIATE MODEL
Epoch: 2
train loss: 1.061
train AUROC: 0.797
valid loss: 1.071
valid AUROC: 0.800
BEST VALID LOSS: 1.0706020891666412
BEST VALID AUROC: 0.8004143444606068
UPDATED BEST INTERMEDIATE MODEL
Epoch: 3
train loss: 1.029
train AUROC: 0.814
valid loss: 1.052
valid AUROC: 0.814
BEST VALID LOSS: 1.0522677169905768
BEST VALID AUROC: 0.8143879454784896
UPDATED BEST INTERMEDIATE MODEL
Epoch: 4
train loss: 1.040
train AUROC: 0.808
valid loss: 1.064
valid AUROC: 0.814
Epoch: 5
train loss: 0.999
train AUROC: 0.827
valid loss: 1.059
valid AUROC: 0.821
BEST VALID AUROC: 0.8207281685390421
UPDATED BEST INTERMEDIATE MODEL
Epoch: 6
train

<All keys matched successfully>

100%|██████████| 18/18 [00:04<00:00,  3.62it/s]
100%|██████████| 1/1 [00:00<00:00, 18.29it/s]
100%|██████████| 2/2 [00:00<00:00,  5.92it/s]


0

test predictions: 0.131
test loss: 1.622
test AUROC: 0.801
test predictions 8: 0.148
test loss 8: 0.225
test AUROC 8: nan
test predictions 45: 0.122
test loss 45: 2.035
test AUROC 45: 0.615
BEST TEST AUROC: 0.801023281245589
UPDATED BEST MODEL
Epoch: 0
train loss: 1.171
train AUROC: 0.742
valid loss: 1.134
valid AUROC: 0.778
BEST VALID LOSS: 1.1337205900086298
BEST VALID AUROC: 0.7780094824755472
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 1.074
train AUROC: 0.795
valid loss: 1.075
valid AUROC: 0.813
BEST VALID LOSS: 1.0748276511828105
BEST VALID AUROC: 0.8130326339599018
UPDATED BEST INTERMEDIATE MODEL
Epoch: 2
train loss: 1.080
train AUROC: 0.789
valid loss: 1.128
valid AUROC: 0.796
Epoch: 3
train loss: 1.044
train AUROC: 0.808
valid loss: 1.028
valid AUROC: 0.829
BEST VALID LOSS: 1.0280612177318997
BEST VALID AUROC: 0.829378092044299
UPDATED BEST INTERMEDIATE MODEL
Epoch: 4
train loss: 1.022
train AUROC: 0.817
valid loss: 1.088
valid AUROC: 0.801
Epoch: 5
train loss: 1.035


  results_df = results_df.append(new_row, ignore_index=True)
  _warn_prf(average, modifier, msg_start, len(result))
  fpr = _fp / (_tn + _fp)
  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
100%|██████████| 142/142 [01:42<00:00,  1.38it/s]
100%|██████████| 18/18 [00:04<00:00,  3.69it/s]
100%|██████████| 142/142 [01:35<00:00,  1.48it/s]
100%|██████████| 18/18 [00:04<00:00,  3.66it/s]
100%|██████████| 142/142 [01:36<00:00,  1.48it/s]
100%|██████████| 18/18 [00:04<00:00,  3.76it/s]
100%|██████████| 142/142 [01:35<00:00,  1.48it/s]
100%|██████████| 18/18 [00:04<00:00,  3.71it/s]
100%|██████████| 142/142 [01:34<00:00,  1.50it/s]
100%|██████████| 18/18 [00:04<00:00,  3.74it/s]
100%|██████████| 142/142 [01:35<00:00,  1.49it/s]
100%|██████████| 18/18 [00:04<00:00,  3.72it/s]
100%|██████████| 142/142 [01:35<00:00,  1.49it/s]
100%|██████████| 18/18 [00:04<00:00,  3.63it/s]
100%|██████████| 142/142 [01:34<00:00,  1.50it/s]


<All keys matched successfully>

100%|██████████| 18/18 [00:05<00:00,  3.59it/s]
100%|██████████| 1/1 [00:00<00:00, 20.87it/s]
100%|██████████| 2/2 [00:00<00:00,  6.70it/s]


0

test predictions: 0.274
test loss: 1.153
test AUROC: 0.809
test predictions 8: 0.315
test loss 8: 0.615
test AUROC 8: nan
test predictions 45: 0.282
test loss 45: 0.582
test AUROC 45: 0.838
BEST TEST AUROC: 0.8092792159647921
UPDATED BEST MODEL
Epoch: 0
train loss: 1.173
train AUROC: 0.739
valid loss: 1.130
valid AUROC: 0.764
BEST VALID LOSS: 1.1301849153306749
BEST VALID AUROC: 0.763811104316263
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 1.091
train AUROC: 0.783
valid loss: 1.095
valid AUROC: 0.791
BEST VALID LOSS: 1.0948341588179271
BEST VALID AUROC: 0.7914362719955811
UPDATED BEST INTERMEDIATE MODEL
Epoch: 2
train loss: 1.039
train AUROC: 0.808
valid loss: 1.033
valid AUROC: 0.817
BEST VALID LOSS: 1.0332855979601543
BEST VALID AUROC: 0.8170088475499092
UPDATED BEST INTERMEDIATE MODEL
Epoch: 3
train loss: 0.989
train AUROC: 0.828
valid loss: 1.026
valid AUROC: 0.829
BEST VALID LOSS: 1.0256448222531214
BEST VALID AUROC: 0.8291340349562062
UPDATED BEST INTERMEDIATE MODEL
Epoc

  results_df = results_df.append(new_row, ignore_index=True)
  _warn_prf(average, modifier, msg_start, len(result))
  fpr = _fp / (_tn + _fp)
  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
100%|██████████| 142/142 [01:53<00:00,  1.25it/s]
100%|██████████| 18/18 [00:05<00:00,  3.13it/s]
100%|██████████| 142/142 [01:37<00:00,  1.45it/s]
100%|██████████| 18/18 [00:04<00:00,  3.72it/s]
100%|██████████| 142/142 [01:36<00:00,  1.48it/s]
100%|██████████| 18/18 [00:04<00:00,  3.69it/s]
100%|██████████| 142/142 [01:36<00:00,  1.47it/s]
100%|██████████| 18/18 [00:04<00:00,  3.75it/s]
100%|██████████| 142/142 [01:34<00:00,  1.50it/s]
100%|██████████| 18/18 [00:04<00:00,  3.71it/s]
100%|██████████| 142/142 [01:35<00:00,  1.49it/s]
100%|██████████| 18/18 [00:04<00:00,  3.70it/s]
100%|██████████| 142/142 [01:36<00:00,  1.47it/s]
100%|██████████| 18/18 [00:04<00:00,  3.78it/s]
100%|██████████| 142/142 [01:35<00:00,  1.48it/s]


<All keys matched successfully>

100%|██████████| 18/18 [00:04<00:00,  3.73it/s]
100%|██████████| 1/1 [00:00<00:00, 18.36it/s]
100%|██████████| 2/2 [00:00<00:00,  5.41it/s]


0

test predictions: 0.240
test loss: 0.985
test AUROC: 0.852
test predictions 8: 0.235
test loss 8: 0.373
test AUROC 8: 0.989
test predictions 45: 0.168
test loss 45: 0.686
test AUROC 45: 0.727
BEST TEST AUROC: 0.8522225078290104
UPDATED BEST MODEL
Epoch: 0
train loss: 1.190
train AUROC: 0.723
valid loss: 1.152
valid AUROC: 0.755
BEST VALID LOSS: 1.1524707012706332
BEST VALID AUROC: 0.7554889779598823
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 1.093
train AUROC: 0.786
valid loss: 1.144
valid AUROC: 0.758
BEST VALID LOSS: 1.1444272067811754
BEST VALID AUROC: 0.7583930864089904
UPDATED BEST INTERMEDIATE MODEL
Epoch: 2
train loss: 1.057
train AUROC: 0.802
valid loss: 1.177
valid AUROC: 0.747
Epoch: 3
train loss: 1.062
train AUROC: 0.800
valid loss: 1.150
valid AUROC: 0.776
BEST VALID AUROC: 0.7763581400233002
UPDATED BEST INTERMEDIATE MODEL
Epoch: 4
train loss: 1.008
train AUROC: 0.824
valid loss: 1.109
valid AUROC: 0.789
BEST VALID LOSS: 1.1085581448343065
BEST VALID AUROC: 0.789

  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
100%|██████████| 142/142 [01:41<00:00,  1.39it/s]
100%|██████████| 18/18 [00:04<00:00,  3.73it/s]
100%|██████████| 142/142 [01:36<00:00,  1.47it/s]
100%|██████████| 18/18 [00:04<00:00,  3.71it/s]
100%|██████████| 142/142 [01:34<00:00,  1.50it/s]
100%|██████████| 18/18 [00:04<00:00,  3.82it/s]
100%|██████████| 142/142 [01:35<00:00,  1.49it/s]
100%|██████████| 18/18 [00:04<00:00,  3.79it/s]
100%|██████████| 142/142 [01:34<00:00,  1.50it/s]
100%|██████████| 18/18 [00:04<00:00,  3.71it/s]
100%|██████████| 142/142 [01:35<00:00,  1.49it/s]
100%|██████████| 18/18 [00:04<00:00,  3.78it/s]
100%|██████████| 142/142 [01:34<00:00,  1.50it/s]
100%|██████████| 18/18 [00:04<00:00,  3.78it/s]
100%|██████████| 142/142 [01:35<00:00,  1.48it/s]
100%|██████████| 18/18 [00:04<00:00,  3.82it/s]
100%|██████████| 142/142 [01:36<0

<All keys matched successfully>

100%|██████████| 18/18 [00:04<00:00,  3.75it/s]
100%|██████████| 1/1 [00:00<00:00, 24.27it/s]
100%|██████████| 1/1 [00:00<00:00,  3.79it/s]


0

test predictions: 0.168
test loss: 1.256
test AUROC: 0.829
test predictions 8: 0.132
test loss 8: 2.185
test AUROC 8: 0.824
test predictions 45: 0.105
test loss 45: 1.916
test AUROC 45: 0.662
Epoch: 0
train loss: 1.190
train AUROC: 0.726
valid loss: 1.107
valid AUROC: 0.777
BEST VALID LOSS: 1.1072234610716503
BEST VALID AUROC: 0.7771112863471735
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 1.076
train AUROC: 0.790
valid loss: 1.308
valid AUROC: 0.775
Epoch: 2
train loss: 1.072
train AUROC: 0.800
valid loss: 1.016
valid AUROC: 0.820
BEST VALID LOSS: 1.0161157217290666
BEST VALID AUROC: 0.8199165471274935
UPDATED BEST INTERMEDIATE MODEL
Epoch: 3
train loss: 1.053
train AUROC: 0.808
valid loss: 1.047
valid AUROC: 0.805
Epoch: 4
train loss: 1.016
train AUROC: 0.821
valid loss: 1.036
valid AUROC: 0.810
Epoch: 5
train loss: 0.992
train AUROC: 0.828
valid loss: 0.991
valid AUROC: 0.830
BEST VALID LOSS: 0.9908427364296384
BEST VALID AUROC: 0.8298434168918758
UPDATED BEST INTERMEDIATE M

  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
100%|██████████| 142/142 [01:39<00:00,  1.43it/s]
100%|██████████| 18/18 [00:04<00:00,  3.73it/s]
100%|██████████| 142/142 [01:35<00:00,  1.49it/s]
100%|██████████| 18/18 [00:04<00:00,  3.75it/s]
100%|██████████| 142/142 [01:35<00:00,  1.49it/s]
100%|██████████| 18/18 [00:04<00:00,  3.72it/s]
100%|██████████| 142/142 [01:35<00:00,  1.49it/s]
100%|██████████| 18/18 [00:04<00:00,  3.68it/s]
100%|██████████| 142/142 [01:35<00:00,  1.49it/s]
100%|██████████| 18/18 [00:04<00:00,  3.77it/s]
100%|██████████| 142/142 [01:34<00:00,  1.50it/s]
100%|██████████| 18/18 [00:04<00:00,  3.67it/s]
100%|██████████| 142/142 [01:35<00:00,  1.49it/s]
100%|██████████| 18/18 [00:04<00:00,  3.67it/s]
100%|██████████| 142/142 [01:35<00:00,  1.49it/s]
100%|██████████| 18/18 [00:04<00:00,  3.80it/s]
100%|██████████| 142/142 [01:34<0

<All keys matched successfully>

100%|██████████| 18/18 [00:04<00:00,  3.77it/s]
100%|██████████| 1/1 [00:00<00:00, 22.47it/s]
100%|██████████| 1/1 [00:00<00:00,  3.59it/s]


0

test predictions: 0.334
test loss: 1.280
test AUROC: 0.748
test predictions 8: 0.284
test loss 8: 2.694
test AUROC 8: 0.186
test predictions 45: 0.298
test loss 45: 2.338
test AUROC 45: 0.444
Epoch: 0
train loss: 1.183
train AUROC: 0.732
valid loss: 1.182
valid AUROC: 0.766
BEST VALID LOSS: 1.1819382674164243
BEST VALID AUROC: 0.7656001097646993
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 1.107
train AUROC: 0.781
valid loss: 1.188
valid AUROC: 0.766
BEST VALID AUROC: 0.765790563405955
UPDATED BEST INTERMEDIATE MODEL
Epoch: 2
train loss: 1.095
train AUROC: 0.793
valid loss: 1.016
valid AUROC: 0.823
BEST VALID LOSS: 1.0158656338850658
BEST VALID AUROC: 0.8225137998399099
UPDATED BEST INTERMEDIATE MODEL
Epoch: 3
train loss: 0.999
train AUROC: 0.827
valid loss: 1.014
valid AUROC: 0.835
BEST VALID LOSS: 1.0139889849556818
BEST VALID AUROC: 0.835475544662741
UPDATED BEST INTERMEDIATE MODEL
Epoch: 4
train loss: 0.993
train AUROC: 0.830
valid loss: 0.998
valid AUROC: 0.834
BEST VALID 

  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
100%|██████████| 142/142 [01:36<00:00,  1.47it/s]
100%|██████████| 18/18 [00:04<00:00,  3.73it/s]
100%|██████████| 142/142 [01:34<00:00,  1.49it/s]
100%|██████████| 18/18 [00:04<00:00,  3.70it/s]
100%|██████████| 142/142 [01:34<00:00,  1.51it/s]
100%|██████████| 18/18 [00:04<00:00,  3.75it/s]
100%|██████████| 142/142 [01:34<00:00,  1.50it/s]
100%|██████████| 18/18 [00:04<00:00,  3.76it/s]
100%|██████████| 142/142 [01:35<00:00,  1.48it/s]
100%|██████████| 18/18 [00:05<00:00,  3.36it/s]
100%|██████████| 142/142 [01:39<00:00,  1.42it/s]
100%|██████████| 18/18 [00:05<00:00,  3.31it/s]
100%|██████████| 142/142 [01:41<00:00,  1.40it/s]
100%|██████████| 18/18 [00:04<00:00,  3.74it/s]
100%|██████████| 142/142 [01:35<00:00,  1.49it/s]
100%|██████████| 18/18 [00:04<00:00,  3.77it/s]
100%|██████████| 142/142 [01:34<0

<All keys matched successfully>

100%|██████████| 18/18 [00:04<00:00,  3.81it/s]
100%|██████████| 1/1 [00:00<00:00, 22.05it/s]
100%|██████████| 2/2 [00:00<00:00,  7.11it/s]


0

test predictions: 0.204
test loss: 1.373
test AUROC: 0.774
test predictions 8: 0.168
test loss 8: 2.646
test AUROC 8: 0.077
test predictions 45: 0.195
test loss 45: 1.778
test AUROC 45: 0.662
Epoch: 0
train loss: 1.162
train AUROC: 0.748
valid loss: 1.227
valid AUROC: 0.753
BEST VALID LOSS: 1.227254592710071
BEST VALID AUROC: 0.7531371603040125
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 1.100
train AUROC: 0.780
valid loss: 1.132
valid AUROC: 0.781
BEST VALID LOSS: 1.1323311295774248
BEST VALID AUROC: 0.7811199763827579
UPDATED BEST INTERMEDIATE MODEL
Epoch: 2
train loss: 1.037
train AUROC: 0.809
valid loss: 1.116
valid AUROC: 0.790
BEST VALID LOSS: 1.1157179971536
BEST VALID AUROC: 0.7900921013005333
UPDATED BEST INTERMEDIATE MODEL
Epoch: 3
train loss: 0.977
train AUROC: 0.834
valid loss: 1.239
valid AUROC: 0.783
Epoch: 4
train loss: 0.993
train AUROC: 0.829
valid loss: 1.077
valid AUROC: 0.800
BEST VALID LOSS: 1.0773493746916454
BEST VALID AUROC: 0.8003181492165892
UPDATED B

  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
100%|██████████| 142/142 [01:37<00:00,  1.46it/s]
100%|██████████| 18/18 [00:04<00:00,  3.72it/s]
100%|██████████| 142/142 [01:34<00:00,  1.50it/s]
100%|██████████| 18/18 [00:04<00:00,  3.75it/s]
100%|██████████| 142/142 [01:36<00:00,  1.47it/s]
100%|██████████| 18/18 [00:04<00:00,  3.72it/s]
100%|██████████| 142/142 [01:35<00:00,  1.48it/s]
100%|██████████| 18/18 [00:04<00:00,  3.76it/s]
100%|██████████| 142/142 [01:36<00:00,  1.47it/s]
100%|██████████| 18/18 [00:04<00:00,  3.73it/s]
100%|██████████| 142/142 [01:34<00:00,  1.50it/s]
100%|██████████| 18/18 [00:04<00:00,  3.77it/s]
100%|██████████| 142/142 [01:35<00:00,  1.49it/s]
100%|██████████| 18/18 [00:04<00:00,  3.68it/s]
100%|██████████| 142/142 [01:35<00:00,  1.49it/s]
100%|██████████| 18/18 [00:04<00:00,  3.73it/s]
100%|██████████| 142/142 [01:35<0

<All keys matched successfully>

100%|██████████| 18/18 [00:04<00:00,  3.72it/s]
100%|██████████| 1/1 [00:00<00:00, 21.79it/s]
100%|██████████| 2/2 [00:00<00:00,  6.78it/s]


0

test predictions: 0.393
test loss: 1.116
test AUROC: 0.829
test predictions 8: 0.397
test loss 8: 1.481
test AUROC 8: 0.685
test predictions 45: 0.377
test loss 45: 0.977
test AUROC 45: 0.571
Epoch: 0
train loss: 1.181
train AUROC: 0.735
valid loss: 1.128
valid AUROC: 0.777
BEST VALID LOSS: 1.1277635163731046
BEST VALID AUROC: 0.7773678662851979
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 1.094
train AUROC: 0.780
valid loss: 1.112
valid AUROC: 0.792
BEST VALID LOSS: 1.1120229628351
BEST VALID AUROC: 0.792026605061064
UPDATED BEST INTERMEDIATE MODEL
Epoch: 2
train loss: 1.025
train AUROC: 0.815
valid loss: 1.056
valid AUROC: 0.807
BEST VALID LOSS: 1.0560295581817627
BEST VALID AUROC: 0.8069869986289319
UPDATED BEST INTERMEDIATE MODEL
Epoch: 3
train loss: 0.979
train AUROC: 0.834
valid loss: 1.060
valid AUROC: 0.807
BEST VALID AUROC: 0.8071559054993303
UPDATED BEST INTERMEDIATE MODEL
Epoch: 4
train loss: 1.015
train AUROC: 0.821
valid loss: 1.157
valid AUROC: 0.779
Epoch: 5
trai

  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
100%|██████████| 142/142 [01:41<00:00,  1.40it/s]
100%|██████████| 18/18 [00:04<00:00,  3.73it/s]
100%|██████████| 142/142 [01:37<00:00,  1.45it/s]
100%|██████████| 18/18 [00:04<00:00,  3.72it/s]
100%|██████████| 142/142 [01:45<00:00,  1.35it/s]
100%|██████████| 18/18 [00:05<00:00,  3.29it/s]
100%|██████████| 142/142 [01:46<00:00,  1.33it/s]
100%|██████████| 18/18 [00:05<00:00,  3.27it/s]
100%|██████████| 142/142 [01:48<00:00,  1.31it/s]
100%|██████████| 18/18 [00:05<00:00,  3.08it/s]
100%|██████████| 142/142 [01:46<00:00,  1.33it/s]
100%|██████████| 18/18 [00:05<00:00,  3.42it/s]
100%|██████████| 142/142 [01:48<00:00,  1.30it/s]
100%|██████████| 18/18 [00:05<00:00,  3.46it/s]
100%|██████████| 142/142 [01:44<00:00,  1.36it/s]
100%|██████████| 18/18 [00:05<00:00,  3.38it/s]
100%|██████████| 142/142 [01:45<0

<All keys matched successfully>

100%|██████████| 18/18 [00:05<00:00,  3.01it/s]
100%|██████████| 1/1 [00:00<00:00, 20.51it/s]
100%|██████████| 2/2 [00:00<00:00,  6.32it/s]


0

  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
100%|██████████| 142/142 [01:52<00:00,  1.26it/s]
100%|██████████| 18/18 [00:05<00:00,  3.49it/s]
100%|██████████| 142/142 [01:44<00:00,  1.36it/s]
100%|██████████| 18/18 [00:05<00:00,  3.53it/s]
100%|██████████| 142/142 [01:43<00:00,  1.37it/s]
100%|██████████| 18/18 [00:05<00:00,  3.39it/s]
100%|██████████| 142/142 [01:42<00:00,  1.38it/s]
100%|██████████| 18/18 [00:05<00:00,  3.43it/s]
100%|██████████| 142/142 [01:43<00:00,  1.37it/s]
100%|██████████| 18/18 [00:05<00:00,  3.44it/s]
100%|██████████| 142/142 [01:42<00:00,  1.38it/s]
100%|██████████| 18/18 [00:05<00:00,  3.53it/s]
100%|██████████| 142/142 [01:42<00:00,  1.39it/s]
100%|██████████| 18/18 [00:05<00:00,  3.52it/s]
100%|██████████| 142/142 [01:43<00:00,  1.37it/s]
100%|██████████| 18/18 [00:05<00:00,  3.40it/s]
100%|██████████| 142/142 [01:45<0

test predictions: 0.579
test loss: 1.470
test AUROC: 0.776
test predictions 8: 0.709
test loss 8: 2.169
test AUROC 8: 0.589
test predictions 45: 0.643
test loss 45: 2.114
test AUROC 45: 0.624
Epoch: 0
train loss: 1.178
train AUROC: 0.739
valid loss: 1.101
valid AUROC: 0.791
BEST VALID LOSS: 1.1008850965234969
BEST VALID AUROC: 0.7907820932960319
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 1.090
train AUROC: 0.787
valid loss: 1.081
valid AUROC: 0.798
BEST VALID LOSS: 1.0812691781255934
BEST VALID AUROC: 0.7982645190562614
UPDATED BEST INTERMEDIATE MODEL
Epoch: 2
train loss: 1.111
train AUROC: 0.773
valid loss: 1.246
valid AUROC: 0.764
Epoch: 3
train loss: 1.136
train AUROC: 0.773
valid loss: 1.309
valid AUROC: 0.769
Epoch: 4
train loss: 1.119
train AUROC: 0.787
valid loss: 1.090
valid AUROC: 0.809
BEST VALID AUROC: 0.8090339418603729
UPDATED BEST INTERMEDIATE MODEL
Epoch: 5
train loss: 1.018
train AUROC: 0.818
valid loss: 1.025
valid AUROC: 0.830
BEST VALID LOSS: 1.024587879578

<All keys matched successfully>

100%|██████████| 18/18 [00:05<00:00,  3.46it/s]
100%|██████████| 1/1 [00:00<00:00, 25.92it/s]
100%|██████████| 1/1 [00:00<00:00,  3.44it/s]


0

test predictions: 0.205
test loss: 1.281
test AUROC: 0.799
test predictions 8: 0.191
test loss 8: 2.536
test AUROC 8: 0.544
test predictions 45: 0.174
test loss 45: 2.402
test AUROC 45: 0.534
Epoch: 0
train loss: 1.191
train AUROC: 0.724
valid loss: 1.114
valid AUROC: 0.794
BEST VALID LOSS: 1.1136738856633503
BEST VALID AUROC: 0.7936803802499439
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 1.089
train AUROC: 0.786
valid loss: 1.119
valid AUROC: 0.787
Epoch: 2
train loss: 1.075
train AUROC: 0.793
valid loss: 1.065
valid AUROC: 0.810
BEST VALID LOSS: 1.0649199055300818
BEST VALID AUROC: 0.8097772118471016
UPDATED BEST INTERMEDIATE MODEL
Epoch: 3
train loss: 0.991
train AUROC: 0.829
valid loss: 1.040
valid AUROC: 0.822
BEST VALID LOSS: 1.0399571690294478
BEST VALID AUROC: 0.8222749026419288
UPDATED BEST INTERMEDIATE MODEL
Epoch: 4
train loss: 0.969
train AUROC: 0.837
valid loss: 0.981
valid AUROC: 0.834
BEST VALID LOSS: 0.9814841515488095
BEST VALID AUROC: 0.8339809013254961
UPDAT

  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
100%|██████████| 142/142 [01:52<00:00,  1.27it/s]
100%|██████████| 18/18 [00:05<00:00,  3.51it/s]
100%|██████████| 142/142 [01:43<00:00,  1.37it/s]
100%|██████████| 18/18 [00:05<00:00,  3.51it/s]
100%|██████████| 142/142 [01:42<00:00,  1.38it/s]
100%|██████████| 18/18 [00:05<00:00,  3.53it/s]
100%|██████████| 142/142 [01:41<00:00,  1.40it/s]
100%|██████████| 18/18 [00:05<00:00,  3.35it/s]
100%|██████████| 142/142 [01:42<00:00,  1.38it/s]
100%|██████████| 18/18 [00:05<00:00,  3.44it/s]
100%|██████████| 142/142 [01:42<00:00,  1.38it/s]
100%|██████████| 18/18 [00:05<00:00,  3.43it/s]
100%|██████████| 142/142 [01:44<00:00,  1.36it/s]
100%|██████████| 18/18 [00:05<00:00,  3.38it/s]
100%|██████████| 142/142 [01:43<00:00,  1.38it/s]
100%|██████████| 18/18 [00:05<00:00,  3.48it/s]
100%|██████████| 142/142 [01:42<0

<All keys matched successfully>

100%|██████████| 18/18 [00:05<00:00,  3.48it/s]
100%|██████████| 1/1 [00:00<00:00, 20.56it/s]
100%|██████████| 1/1 [00:00<00:00,  3.65it/s]


0

test predictions: 0.304
test loss: 1.024
test AUROC: 0.824
test predictions 8: 0.311
test loss 8: 0.513
test AUROC 8: nan
test predictions 45: 0.301
test loss 45: 1.235
test AUROC 45: 0.669


  results_df = results_df.append(new_row, ignore_index=True)
  _warn_prf(average, modifier, msg_start, len(result))
  fpr = _fp / (_tn + _fp)
  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)


In [17]:
results_df

Unnamed: 0,data,fold,subset,AUROC,accuracy,precision,recall,F1,AUPRC,TPR,FPR
0,set,1,overall,0.801023,0.934336,0.122153,0.257642,0.16573,0.100178,0.951911,0.742358
1,set,1,8,,0.892473,0.0,0.0,0.0,-0.0,0.892473,
2,set,1,45,0.615291,0.934664,0.0,0.0,0.0,0.018356,0.944954,1.0
3,set,2,overall,0.809279,0.793832,0.083078,0.71179,0.148791,0.115004,0.795962,0.28821
4,set,2,8,,0.746667,0.0,0.0,0.0,-0.0,0.746667,
5,set,2,45,0.838454,0.766114,0.031008,0.666667,0.059259,0.041074,0.767225,0.333333
6,set,3,overall,0.852223,0.839155,0.10143,0.681223,0.17657,0.127085,0.843257,0.318777
7,set,3,8,0.98913,0.817204,0.055556,1.0,0.105263,0.5,0.815217,0.0
8,set,3,45,0.726943,0.904854,0.045455,0.222222,0.075472,0.093181,0.916996,0.777778
9,set,4,overall,0.828857,0.90681,0.135553,0.495652,0.212885,0.141549,0.917536,0.504348


In [18]:
results_df.groupby(['data', 'subset']).mean()
results_df.groupby(['data', 'subset']).std()

Unnamed: 0_level_0,Unnamed: 1_level_0,fold,AUROC,accuracy,precision,recall,F1,AUPRC,TPR,FPR
data,subset,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
set,8,5.5,0.556078,0.760136,0.041033,0.283333,0.058679,0.159266,0.76306,0.595238
set,45,5.5,0.634672,0.78602,0.038258,0.333016,0.060523,0.051705,0.793498,0.666984
set,overall,5.5,0.804258,0.772035,0.084724,0.613321,0.141723,0.108583,0.776162,0.386679


Unnamed: 0_level_0,Unnamed: 1_level_0,fold,AUROC,accuracy,precision,recall,F1,AUPRC,TPR,FPR
data,subset,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
set,8,3.02765,0.327496,0.20508,0.078267,0.416111,0.094579,0.22292,0.223687,0.449868
set,45,3.02765,0.107444,0.187253,0.040618,0.281479,0.052552,0.025499,0.195832,0.281479
set,overall,3.02765,0.031446,0.165882,0.030799,0.204433,0.041329,0.025596,0.174637,0.204433
