In [1]:
# Libraries
import numpy as np
import pandas as pd
pd.set_option('display.max_columns', 40)
pd.set_option('display.width', 2000)
from tqdm import tqdm
import math
import time

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import StratifiedKFold

from sklearn.metrics import roc_auc_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import recall_score
from sklearn.metrics import precision_score
from sklearn.metrics import f1_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import average_precision_score

import gc

from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Import
path = r'data/los.csv'
los_labels = pd.read_csv(path)

# Group by and take average
los_labels = los_labels.groupby(['new_subject', 'SUBJECT', 'PROBLEM_DT_TM'])['los_days', 'los_hours', 'los_long'].mean()
los_labels = los_labels.reset_index()[['new_subject', 'los_days', 'los_hours', 'los_long']]

# Import
path = r'data/trimmed_patient_embedding_128d.csv'
patients_embeddings = pd.read_csv(path, index_col=0)

# Fill in nan
los_labels.los_hours.fillna(0, inplace=True)
los_labels.los_long.fillna(0, inplace=True)
# Merge
patients_embeddings = pd.merge(los_labels[['new_subject', 'los_hours', 'los_long']], patients_embeddings.reset_index()).set_index('new_subject', inplace=True)

# Import
path = r'data/final_problem_dummies.csv'
problem_dummies = pd.read_csv(path)

# Drop columns
problem_dummies.drop(columns=['SUBJECT', 'PROBLEM_DT_TM'], inplace=True)
# Remove prefix
problem_dummies.columns = problem_dummies.columns.str.strip('PROBLEM_')
# Merge
problem_dummies = pd.merge(los_labels[['new_subject', 'los_hours', 'los_long']], problem_dummies)
# Set index
problem_dummies.set_index('new_subject', inplace=True)

# Import
path = r'data/final_trimmed_snomed_embedding_128d.csv'
snomed_embeddings = pd.read_csv(path, index_col=0)

# Get list 
snomed_embeddings.index = snomed_embeddings.index.astype(str)
snomed_list = snomed_embeddings.index.tolist()
problem_dummies_list = problem_dummies.columns.tolist()
overlap_list = list(set(snomed_list) & set(problem_dummies_list))
overlap_list = ['los_hours', 'los_long'] + overlap_list

# Filter
problem_dummies = problem_dummies[overlap_list]


In [3]:
del los_labels
del patients_embeddings
del snomed_embeddings
del snomed_list
del problem_dummies_list
del overlap_list
gc.collect()

7

Statistical tests (ran after results generated)

In [2]:
set_results = [0.747882,
0.757391,
0.761086,
0.720488,
0.749033,
0.766235,
0.757610,
0.774033,
0.746641,
0.748128]
random_set_results = [0.745206,
0.755118,
0.705928,
0.722723,
0.705892,
0.748889,
0.742211,
0.748850,
0.742797,
0.740542]
lr_results = [0.475000,
0.730232,
0.694128,
0.725744,
0.723272,
0.738767,
0.678671,
0.736754,
0.723696,
0.720582]
charlson_results = [0.589538,
0.604296,
0.596816,
0.599770,
0.583316,
0.602164,
0.599521,
0.603749,
0.614322,
0.607869]

set_results_8 = [0.476526,
0.919048,
0.301136,
0.463675,
0.831276,
0.469136,
0.764706,
0.577778,
0.554430,
0.768382]
random_set_results_8 = [0.204225,
0.776190,
0.384470,
0.480769,
0.308642,
0.382716,
0.460784,
0.944444,
0.468354,
0.757353]
lr_results_8 = [0.955556,
0.786164,
0.636667,
0.989130,
0.500000,
0.488372,
0.487179,
0.500000,
0.456522,
0.475000]
charlson_results_8 = [0.607869,
0.500000,
0.484848,
0.487179,
0.493827,
0.493827,
0.500000,
0.488889,
0.587342,
0.492647]

In [4]:
from scipy.stats import shapiro 
from scipy.stats import kstest
from scipy.stats import normaltest
from scipy.stats import ttest_rel
from scipy.stats import wilcoxon

In [10]:
shapiro(set_results)
kstest(set_results, 'norm')
normaltest(set_results)

ShapiroResult(statistic=0.91277015209198, pvalue=0.3005659282207489)

KstestResult(statistic=0.7643877072403327, pvalue=1.291738076015592e-06)



NormaltestResult(statistic=4.611049210161704, pvalue=0.09970648040487014)

In [11]:
shapiro(random_set_results)
kstest(random_set_results, 'norm')
normaltest(random_set_results)

ShapiroResult(statistic=0.8209559321403503, pvalue=0.026025492697954178)

KstestResult(statistic=0.7598723485150173, pvalue=1.5918509884571782e-06)

NormaltestResult(statistic=2.335046864196942, pvalue=0.3111365386566636)

In [12]:
shapiro(lr_results)
kstest(lr_results, 'norm')
normaltest(lr_results)

ShapiroResult(statistic=0.5631038546562195, pvalue=2.015241261688061e-05)

KstestResult(statistic=0.6826065133587409, pvalue=3.751089511108325e-05)

NormaltestResult(statistic=24.531976093259708, pvalue=4.709221045991683e-06)

In [5]:
shapiro(charlson_results)
kstest(charlson_results, 'norm')
normaltest(charlson_results)

ShapiroResult(statistic=0.9653530120849609, pvalue=0.8447386026382446)

KstestResult(statistic=0.7201597032392104, pvalue=8.822979890802452e-06)



NormaltestResult(statistic=0.9814405359297835, pvalue=0.6121852978910722)

In [15]:
shapiro(set_results_8)
kstest(set_results_8, 'norm')
normaltest(set_results_8)

ShapiroResult(statistic=0.9421727657318115, pvalue=0.5774344801902771)

KstestResult(statistic=0.6183446048356221, pvalue=0.0003261227478991447)

NormaltestResult(statistic=0.7324600778071833, pvalue=0.6933432869336758)

In [16]:
shapiro(random_set_results_8)
kstest(random_set_results_8, 'norm')
normaltest(random_set_results_8)

ShapiroResult(statistic=0.9161323308944702, pvalue=0.3258223533630371)

KstestResult(statistic=0.580911162069865, pvalue=0.0009861795674738682)

NormaltestResult(statistic=1.1380770259852806, pvalue=0.5660694455470265)

In [17]:
shapiro(lr_results_8)
kstest(lr_results_8, 'norm')
normaltest(lr_results_8)

ShapiroResult(statistic=0.7687857151031494, pvalue=0.0060443882830441)

KstestResult(statistic=0.6759926728028078, pvalue=4.768463566879488e-05)

NormaltestResult(statistic=2.526126778016009, pvalue=0.2827864134695235)

In [6]:
shapiro(charlson_results_8)
kstest(charlson_results_8, 'norm')
normaltest(charlson_results_8)

ShapiroResult(statistic=0.6320642232894897, pvalue=0.00013367459177970886)

KstestResult(statistic=0.6861079145284064, pvalue=3.297805441868364e-05)



NormaltestResult(statistic=7.78817716890301, pvalue=0.020361924168909893)

In [7]:
ttest_rel(set_results, random_set_results) 
ttest_rel(set_results, lr_results) 
ttest_rel(lr_results, random_set_results) 
ttest_rel(random_set_results, lr_results) 

ttest_rel(charlson_results, set_results) 
ttest_rel(charlson_results, random_set_results) 
ttest_rel(charlson_results, lr_results) 

Ttest_relResult(statistic=2.829592962153314, pvalue=0.01973551454323084)

Ttest_relResult(statistic=2.327709444684237, pvalue=0.04491527408950964)

Ttest_relResult(statistic=-1.5644104361008246, pvalue=0.1521601324968958)

Ttest_relResult(statistic=1.5644104361008246, pvalue=0.1521601324968958)

Ttest_relResult(statistic=-29.598950578378286, pvalue=2.800428635087982e-10)

Ttest_relResult(statistic=-28.860055928422543, pvalue=3.5083266455280124e-10)

Ttest_relResult(statistic=-3.94128045770264, pvalue=0.0033993534478292046)

In [8]:
wilcoxon(set_results, random_set_results) 
wilcoxon(set_results, lr_results) 
wilcoxon(lr_results, random_set_results)

wilcoxon(charlson_results, set_results) 
wilcoxon(charlson_results, random_set_results) 
wilcoxon(charlson_results, lr_results) 

WilcoxonResult(statistic=1.0, pvalue=0.00390625)

WilcoxonResult(statistic=1.0, pvalue=0.00390625)

WilcoxonResult(statistic=6.0, pvalue=0.02734375)

WilcoxonResult(statistic=0.0, pvalue=0.001953125)

WilcoxonResult(statistic=0.0, pvalue=0.001953125)

WilcoxonResult(statistic=5.0, pvalue=0.01953125)

In [9]:
ttest_rel(set_results_8, random_set_results_8) 
ttest_rel(set_results_8, lr_results_8) 
ttest_rel(lr_results_8, random_set_results_8) 

ttest_rel(charlson_results_8, set_results_8) 
ttest_rel(charlson_results_8, random_set_results_8) 
ttest_rel(charlson_results_8, lr_results_8) 

Ttest_relResult(statistic=1.2545714022215622, pvalue=0.24123516463731395)

Ttest_relResult(statistic=-0.14667884025148195, pvalue=0.8866185689670323)

Ttest_relResult(statistic=1.005560171919287, pvalue=0.34089546463234344)

Ttest_relResult(statistic=-1.4867739148193868, pvalue=0.17124386193426686)

Ttest_relResult(statistic=-0.03895618664411773, pvalue=0.9697757731803747)

Ttest_relResult(statistic=-1.7847326104121752, pvalue=0.10796609718098465)

In [10]:
wilcoxon(set_results_8, random_set_results_8) 
wilcoxon(set_results_8, lr_results_8) 
wilcoxon(lr_results_8, random_set_results_8) 

wilcoxon(charlson_results_8, set_results_8) 
wilcoxon(charlson_results_8, random_set_results_8) 
wilcoxon(charlson_results_8, lr_results_8) 

WilcoxonResult(statistic=14.0, pvalue=0.193359375)

WilcoxonResult(statistic=27.0, pvalue=1.0)

WilcoxonResult(statistic=17.0, pvalue=0.322265625)

WilcoxonResult(statistic=17.0, pvalue=0.322265625)

WilcoxonResult(statistic=25.0, pvalue=0.845703125)

WilcoxonResult(statistic=16.0, pvalue=0.275390625)

# LR

In [4]:
# Drop those with no co-morbidities
drop_index_list = problem_dummies.iloc[:,2:][(problem_dummies.iloc[:,2:] == 0).all(axis=1)].index.to_list()
problem_dummies = problem_dummies[~problem_dummies.index.isin(drop_index_list)]

In [5]:
# Work out count for problems
problem_sum = pd.DataFrame(problem_dummies.iloc[:,2:].sum(axis=0), columns=['Count']).sort_values(by=['Count'], ascending=False)
# Define what is a rare disease
cut_off_list = [45, 8]

In [6]:
# Reset index
problem_dummies.reset_index(inplace=True, drop=True)

In [7]:
problem_dummies.los_long.round().value_counts()

0.0    53402
1.0     3293
Name: los_long, dtype: int64

In [None]:
# Get rare 10 fold cv results for one hot

results_df = pd.DataFrame()
x_data = problem_dummies.iloc[:,2:]
y_data = problem_dummies.los_long.round()
# Get CV folds
cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=2)
fold_n = 0
for train_index, test_index in cv.split(x_data, y_data):
    fold_n += 1
    x_train  = x_data[x_data.index.isin(list(train_index))]
    y_train  = y_data[y_data.index.isin(list(train_index))]
    x_test  = x_data[x_data.index.isin(list(test_index))]
    y_test  = y_data[y_data.index.isin(list(test_index))]

    # Fit
    LR = LogisticRegression(class_weight='balanced')
    LR.fit(x_train, y_train)

    # Get results for rare diseases 
    for n in cut_off_list:
        # Get filter list 
        filter_list = problem_sum[problem_sum['Count'] > n].index.tolist()
        x_test2 = x_test.copy()
        for code in filter_list:
            x_test2 = x_test2.loc[x_test2[code] != 1]
        x_test_list = x_test2.index.tolist()
        y_test2 = y_test.loc[x_test_list]

        # AUC
        try:
            aucroc = roc_auc_score(y_test2, LR.predict(x_test2))
        except:
            aucroc = np.nan
        # Accuracy
        accuracy = accuracy_score(y_test2, LR.predict(x_test2))
        # Precision
        precision = precision_score(y_test2, LR.predict(x_test2))
        # Recall
        recall = precision_score(y_test2, LR.predict(x_test2))
        # AUPRC
        try:
            auprc = average_precision_score(y_test2, LR.predict(x_test2))
        except:
            auprc = np.nan
        # F1
        f1 = f1_score(y_test2, LR.predict(x_test2))
        # TPR and FPR
        cm = confusion_matrix(y_test2, LR.predict(x_test2))
        try:
            _tp = cm[0, 0]
        except:
            _tp = np.nan
        try:
            _fn = cm[0, 1]
        except:
            _fn = np.nan
        try:
            _fp = cm[1, 0]
        except:
            _fp = np.nan
        try:
            _tn = cm[1, 1]
        except:
            _tn = np.nan
        tpr = _tp / (_tp + _fn)
        fpr = _fp / (_tn + _fp)

        new_row = {'data': 'dummies', 'fold': fold_n, 'subset': n, 'AUROC': aucroc, 'accuracy': accuracy, 'precision': precision, 'recall': recall, 'F1': f1, 'AUPRC': auprc, 'TPR': tpr, 'FPR': fpr}
        results_df = results_df.append(new_row, ignore_index=True)

In [11]:
results_df.groupby(['data', 'subset']).mean()
results_df.groupby(['data', 'subset']).std()

Unnamed: 0_level_0,Unnamed: 1_level_0,fold,AUROC,accuracy,precision,recall,F1,AUPRC,TPR,FPR
data,subset,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
dummies,8,5.5,0.550889,0.881187,0.070742,0.070742,0.101667,0.060143,0.910112,0.808333
dummies,45,5.5,0.647394,0.822181,0.102456,0.102456,0.165464,0.069421,0.836882,0.542094


Unnamed: 0_level_0,Unnamed: 1_level_0,fold,AUROC,accuracy,precision,recall,F1,AUPRC,TPR,FPR
data,subset,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
dummies,8,3.02765,0.108442,0.023184,0.074888,0.074888,0.108994,0.029641,0.026671,0.22923
dummies,45,3.02765,0.043717,0.021077,0.036339,0.036339,0.050315,0.025553,0.02102,0.088371


In [None]:
# Get 10 fold cv results for one hot

results_df = pd.DataFrame()
x_data = problem_dummies.iloc[:,2:]
y_data = problem_dummies.los_long.round()
# Get CV folds
cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=2)
fold_n = 0
for train_index, test_index in cv.split(x_data, y_data):
    fold_n += 1
    x_train  = x_data[x_data.index.isin(list(train_index))]
    y_train  = y_data[y_data.index.isin(list(train_index))]
    x_test  = x_data[x_data.index.isin(list(test_index))]
    y_test  = y_data[y_data.index.isin(list(test_index))]
    # Split data

    # Fit
    LR = LogisticRegression(class_weight='balanced')
    LR.fit(x_train, y_train)

    # AUC
    aucroc = roc_auc_score(y_test, LR.predict(x_test))
    # Accuracy
    accuracy = accuracy_score(y_test, LR.predict(x_test))
    # Precision
    precision = precision_score(y_test, LR.predict(x_test))
    # Recall
    recall = recall_score(y_test, LR.predict(x_test))
    # AUPRC
    auprc = average_precision_score(y_test, LR.predict(x_test))
    # F1
    f1 = f1_score(y_test, LR.predict(x_test))
    # TPR and FPR
    cm = confusion_matrix(y_test, LR.predict(x_test))
    _tp = cm[0, 0]
    _fn = cm[0, 1]
    _fp = cm[1, 0]
    _tn = cm[1, 1]
    tpr = _tp / (_tp + _fn)
    fpr = _fp / (_tn + _fp)

    new_row = {'data': 'dummies', 'fold': fold_n, 'subset':'overall', 'AUROC': aucroc, 'accuracy': accuracy, 'precision': precision, 'recall': recall, 'F1': f1, 'AUPRC': auprc, 'TPR': tpr, 'FPR': fpr}
    results_df = results_df.append(new_row, ignore_index=True)

In [13]:
results_df.groupby(['data', 'subset']).mean()
results_df.groupby(['data', 'subset']).std()

Unnamed: 0_level_0,Unnamed: 1_level_0,fold,AUROC,accuracy,precision,recall,F1,AUPRC,TPR,FPR
data,subset,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
dummies,overall,5.5,0.716109,0.730629,0.138937,0.699679,0.231824,0.114717,0.732538,0.300321


Unnamed: 0_level_0,Unnamed: 1_level_0,fold,AUROC,accuracy,precision,recall,F1,AUPRC,TPR,FPR
data,subset,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
dummies,overall,3.02765,0.01094,0.0063,0.004233,0.021951,0.006829,0.004451,0.006775,0.021951


# Set transformer

In [4]:
# Drop those with no co-morbidities
drop_index_list = problem_dummies.iloc[:,2:][(problem_dummies.iloc[:,2:] == 0).all(axis=1)].index.to_list()
problem_dummies = problem_dummies[~problem_dummies.index.isin(drop_index_list)]

In [5]:
# Work out count for problems
problem_sum = pd.DataFrame(problem_dummies.iloc[:,2:].sum(axis=0), columns=['Count']).sort_values(by=['Count'], ascending=False)
# Define what is a rare disease
cut_off_list = [45, 8]

In [6]:
# Get indexes for rare diseases 
filter_list_45 = problem_sum[problem_sum['Count'] > 45].index.tolist()
filter_list_8 = problem_sum[problem_sum['Count'] > 8].index.tolist()

problem_dummies_45 = problem_dummies.iloc[:,2:].reset_index().copy()
problem_dummies_8 = problem_dummies.iloc[:,2:].reset_index().copy()
print('Working on 45...')
for code in filter_list_45:
    problem_dummies_45 = problem_dummies_45.loc[problem_dummies_45[code] != 1]
    index_list_45 = problem_dummies_45.index.tolist()
print('Working on 8...')
for code in filter_list_8:
    problem_dummies_8 = problem_dummies_8.loc[problem_dummies_8[code] != 1]
    index_list_8 = problem_dummies_8.index.tolist()

Working on 45...
Working on 8...


In [7]:
# Set transformer initial setup

# Import
path = r'data/final_trimmed_snomed_embedding_128d.csv'
snomed_embedding = pd.read_csv(path)

#random_bool = True
random_bool = False
if random_bool == True:
    # Get random embeddings for each disease
    snomed_embedding.set_index(['snomed_code'], inplace=True)
    np.random.seed(0)
    random_embedding = pd.DataFrame(np.random.default_rng(seed=0).uniform(low=snomed_embedding.min().min(),high=snomed_embedding.max().max(),size=[len(snomed_embedding),len(snomed_embedding.columns)]))
    random_embedding.index = snomed_embedding.index
    random_embedding.columns = snomed_embedding.columns
    random_embedding.reset_index(inplace=True)
    snomed_embedding = random_embedding

# Create dfs
patient_los = problem_dummies.iloc[:,:2]
patient_df = problem_dummies.iloc[:,2:]
# Str
patient_df.columns = patient_df.columns.astype(str)
snomed_embedding['snomed_code'] = snomed_embedding['snomed_code'].astype(str)
# Filter
snomed_embedding = snomed_embedding[snomed_embedding['snomed_code'].isin(patient_df.columns.tolist())]
snomed_embedding.set_index('snomed_code', inplace=True)

# Get lengths of each patients co-morbidities
comorbidity_len = np.array(patient_df.sum(axis=1))

# Add padding embedding 
padding_df = pd.DataFrame(np.random.choice([0], size=len(snomed_embedding.columns))) # Changed to 0
padding_df = padding_df.T
padding_df.index = ['9999999999']
padding_df.columns = snomed_embedding.columns
snomed_embedding2 = pd.concat([snomed_embedding, padding_df])
snomed_embedding2.index = snomed_embedding2.index.astype(str)

# Get max number of co-morbidities
max_len = int(patient_df.sum(axis=1).max())

# Format patients embeddings into set and pad / create array
feature_array = np.zeros(shape=(len(patient_df), max_len , 128))
n = -1
for index, row in patient_df.iterrows():
    n += 1
    n2 = -1
    code_list = row[row == 1].index.tolist()
    while len(code_list) < max_len:
        code_list.append('9999999999')
    for code in code_list:
        n2 += 1
        feature_array[n, n2] = np.array(snomed_embedding2.loc[code])

# Get array for death
los_array = np.array(patient_los['los_long'].round())
los_array = los_array.squeeze()

# Create mask tensor based on lengths
comorbidity_len2 = torch.as_tensor(comorbidity_len, dtype=torch.long)
mask = torch.arange(max_len)[None, :] < comorbidity_len2[:, None]

# del so more memory
del problem_dummies
del snomed_embedding
del patient_los
del snomed_embedding2
del padding_df
del row
del comorbidity_len
del comorbidity_len2
gc.collect()

0

In [8]:
# Custom dataset class
class DiseaseDataset(Dataset):
    def __init__(self, disease_embeddings, mortality_labels, padding_mask):
        self.disease_embeddings = disease_embeddings
        self.mortality_labels = mortality_labels
        self.padding_mask = padding_mask

    def __len__(self):
        return len(self.mortality_labels)

    def __getitem__(self, idx):
        return self.disease_embeddings[idx], self.mortality_labels[idx], self.padding_mask[idx] 

In [9]:
class SetTransformer(nn.Module):
    def __init__(self, dim_input, num_outputs, dim_output,
            num_inds=36, dim_hidden=160, num_heads=4, ln=False):
        super(SetTransformer, self).__init__()
        self.enc = nn.Sequential(
                ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=ln),
                ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln))
        self.isab = ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=ln)
        self.pma = PMA(dim_hidden, num_heads, num_outputs, ln=ln)
        self.dec = nn.Sequential(
                #SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
                #SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
                nn.Linear(dim_hidden, dim_output))

    def forward(self, X, batch_mask):
        x = self.isab(X, batch_mask)
        x = self.pma(x, batch_mask)
        return self.dec(x), x

class MAB0(nn.Module):
    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
        super(MAB0, self).__init__()
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)
        if ln:
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)
        self.fc_o = nn.Linear(dim_V, dim_V)

    def forward(self, Q, K, mask):
        Q = self.fc_q(Q)
        K, V = self.fc_k(K), self.fc_v(K)

        dim_split = self.dim_V // self.num_heads
        Q_ = torch.cat(Q.split(dim_split, 2), 0)
        K_ = torch.cat(K.split(dim_split, 2), 0)
        V_ = torch.cat(V.split(dim_split, 2), 0)

        # Create new variable for softmax
        WB_ = Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V)

        # Exspand mask dimensions to align
        mask = mask.unsqueeze(1).repeat(self.num_heads, Q.shape[1], 1)
        # Mask for softmax
        WB_[~mask] = float('-inf')

        A = torch.softmax(WB_, 2)
        O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
        O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
        O = O + F.relu(self.fc_o(O))
        O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
        return O

class MAB(nn.Module):
    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
        super(MAB, self).__init__()
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)
        if ln:
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)
        self.fc_o = nn.Linear(dim_V, dim_V)

    def forward(self, Q, K):
        Q = self.fc_q(Q)
        K, V = self.fc_k(K), self.fc_v(K)

        dim_split = self.dim_V // self.num_heads
        Q_ = torch.cat(Q.split(dim_split, 2), 0)
        K_ = torch.cat(K.split(dim_split, 2), 0)
        V_ = torch.cat(V.split(dim_split, 2), 0)

        A = torch.softmax(Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V), 2)
        O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
        O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
        O = O + F.relu(self.fc_o(O))
        O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
        #print('MAB', O)
        return O

class SAB(nn.Module):
    def __init__(self, dim_in, dim_out, num_heads, ln=False):
        super(SAB, self).__init__()
        self.mab = MAB(dim_in, dim_in, dim_out, num_heads, ln=ln)

    def forward(self, X):
        return self.mab(X, X)

class ISAB(nn.Module):
    def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False):
        super(ISAB, self).__init__()
        self.I = nn.Parameter(torch.Tensor(1, num_inds, dim_out))
        nn.init.xavier_uniform_(self.I)
        self.mab0 = MAB0(dim_out, dim_in, dim_out, num_heads, ln=ln)
        self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln)

    def forward(self, X, mask):
        H = self.mab0(self.I.repeat(X.size(0), 1, 1), X, mask)
        return self.mab1(X, H)

class PMA(nn.Module):
    def __init__(self, dim, num_heads, num_seeds, ln=False):
        super(PMA, self).__init__()
        self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim))
        nn.init.xavier_uniform_(self.S)
        self.mab = MAB0(dim, dim, dim, num_heads, ln=ln)

    def forward(self, X, mask):
        return self.mab(self.S.repeat(X.size(0), 1, 1), X, mask)

In [10]:
# Define how long an epoch takes
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

# Train function 
def train(model, dataloader, optimizer, criterion):
    model.train()
    epoch_loss = 0
    batch_prediction_list = []
    batch_label_list = []
    
    for i, (batch_embeddings, batch_labels, batch_mask) in enumerate(tqdm(dataloader, 0)):
        batch_labels = batch_labels.to(device)
        batch_embeddings = batch_embeddings.to(device)
        batch_mask = batch_mask.to(device)

        optimizer.zero_grad()
        logits, pma = model(batch_embeddings, batch_mask)
        logits = logits.squeeze(-1) # squeeze to remove extra dimensions
        batch_labels = batch_labels.unsqueeze(1)
        
        loss = criterion(logits, batch_labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

        sig = torch.nn.Sigmoid()
        output = sig(logits)      
        np_predictions = output.cpu().detach().numpy()
        np_labels = batch_labels.cpu().detach().numpy()
        np_pma = pma.cpu().detach().numpy()

        np_predictions = np_predictions.squeeze()
        np_labels = np_labels.squeeze()
        np_pma = np_pma.squeeze()

        np_predictions = np_predictions.flatten()
        np_labels = np_labels.flatten()
        
        # Create list
        for x in np_predictions:
            batch_prediction_list.append(x)
        for x in np_labels:
            batch_label_list.append(x)
        if i == 0:
            np_pma_final = np_pma
        else:
            np_pma_final = np.vstack([np_pma_final, np_pma])

    final_predictions = np.array(batch_prediction_list)

    final_labels = np.array(batch_label_list)

    try:
        auroc = roc_auc_score(final_labels, final_predictions)
    except:
        auroc = np.nan
    
    try:
        final_loss = epoch_loss / len(dataloader)
    except:
        final_loss = np.nan

    return final_loss, auroc, final_predictions, final_labels, np_pma_final

# Eval function
def evaluate(model, dataloader, criterion):

    # Set the model to evaluation mode
    model.eval()
    epoch_loss = 0
    batch_prediction_list = []
    batch_label_list = []

    with torch.no_grad():
        for i, (batch_embeddings, batch_labels, batch_mask) in enumerate(tqdm(dataloader, 0)):
            batch_labels = batch_labels.to(device)
            batch_embeddings = batch_embeddings.to(device)
            batch_mask = batch_mask.to(device)

            logits, pma = model(batch_embeddings, batch_mask)
            logits = logits.squeeze() # squeeze to remove extra dimensions

            if len(logits.size()) == 0: # Need to add this so tensor with just one element does not cause error
                logits = logits.unsqueeze(dim=0)

            loss = criterion(logits, batch_labels)

            epoch_loss += loss.item()

            sig = torch.nn.Sigmoid()
            output = sig(logits)      
            np_predictions = output.cpu().detach().numpy()
            np_labels = batch_labels.cpu().detach().numpy()
            np_pma = pma.cpu().detach().numpy()

            np_predictions = np_predictions.squeeze()
            np_labels = np_labels.squeeze()

            np_predictions = np_predictions.flatten()
            np_labels = np_labels.flatten()
            np_pma = np_pma.squeeze()
            
            # Create list
            for x in np_predictions:
                batch_prediction_list.append(x)
            for x in np_labels:
                batch_label_list.append(x)
            if i == 0:
                np_pma_final = np_pma
            else:
                np_pma_final = np.vstack([np_pma_final, np_pma])

        final_predictions = np.array(batch_prediction_list)

        final_labels = np.array(batch_label_list)

        try:
            auroc = roc_auc_score(final_labels, final_predictions)
        except:
            auroc = np.nan
        
        try:
            final_loss = epoch_loss / len(dataloader)
        except:
            final_loss = np.nan

    return final_loss, auroc, final_predictions, final_labels, np_pma_final

In [11]:
# CV for set transformer

# Get CV folds and create dictionarys 
cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=2)
fold_n = 0
train_index_dict = {}
test_index_dict = {}
for train_index, test_index in cv.split(feature_array, los_array):
    fold_n += 1
    train_index_dict[fold_n] = train_index
    test_index_dict[fold_n] = test_index

In [12]:
# del for memmory 
del test_index
del train_index
del patient_df
del index
del fold_n
del code
del code_list
del path
del n 
del n2
del max_len
gc.collect()

0

In [30]:
np.max(feature_array)
np.min(feature_array)

1.2494696

-1.530874

In [15]:
# Check if any have no co-morbidities - this was causing nan outputs
for n in range(mask.shape[0]):
    if torch.all(mask[n] == False):
        print(n)

In [12]:
np.isnan(feature_array).any()
np.isinf(feature_array).any()

False

False

In [13]:
# CV for set transformer

# Get CV folds and create dictionarys 
cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=2)
fold_n = 0
train_index_dict = {}
test_index_dict = {}
test_index_dict_8 = {}
test_index_dict_45 = {}

for train_index, test_index in cv.split(feature_array, los_array):
    fold_n += 1
    train_index_dict[fold_n] = train_index
    test_index_dict[fold_n] = test_index
    # Filter for rare diseases in test set
    test_index_8 = np.array(list(set(test_index.tolist()) & set(index_list_8)))
    test_index_45 = np.array(list(set(test_index.tolist()) & set(index_list_45)))
    test_index_dict_8[fold_n] = test_index_8
    test_index_dict_45[fold_n] = test_index_45

In [15]:
# Normal (not random) - all and rare

# Run
best_test_auroc = 0
results_df = pd.DataFrame()
for fold in range(10):
    fold += 1
    if fold < 10:
        # Get train, validation and test sets
        train_embeddings = feature_array[np.setdiff1d(train_index_dict[fold], test_index_dict[fold+1])]
        val_embeddings = feature_array[test_index_dict[fold+1]]
        test_embeddings = feature_array[test_index_dict[fold]]
        test_embeddings_8 = feature_array[test_index_dict_8[fold]]
        test_embeddings_45 = feature_array[test_index_dict_45[fold]]

        train_labels = los_array[np.setdiff1d(train_index_dict[fold], test_index_dict[fold+1])]
        val_labels = los_array[test_index_dict[fold+1]]
        test_labels = los_array[test_index_dict[fold]]
        test_labels_8 = los_array[test_index_dict_8[fold]]
        test_labels_45 = los_array[test_index_dict_45[fold]]    

        # Split masks
        train_padding_mask = mask[np.setdiff1d(train_index_dict[fold], test_index_dict[fold+1])]
        val_padding_mask = mask[test_index_dict[fold+1]]
        test_padding_mask = mask[test_index_dict[fold]]
        test_padding_mask_8 = mask[test_index_dict_8[fold]]
        test_padding_mask_45 = mask[test_index_dict_45[fold]]

    else: # Change here to fist cv split for final fold
        # Get train, validation and test sets
        train_embeddings = feature_array[np.setdiff1d(train_index_dict[fold], test_index_dict[1])]
        val_embeddings = feature_array[test_index_dict[1]]
        test_embeddings = feature_array[test_index_dict[fold]]
        test_embeddings_8 = feature_array[test_index_dict_8[fold]]
        test_embeddings_45 = feature_array[test_index_dict_45[fold]]

        train_labels = los_array[np.setdiff1d(train_index_dict[fold], test_index_dict[1])]
        val_labels = los_array[test_index_dict[1]]
        test_labels = los_array[test_index_dict[fold]]
        test_labels_8 = los_array[test_index_dict_8[fold]]
        test_labels_45 = los_array[test_index_dict_45[fold]]

        # Split masks
        train_padding_mask = mask[np.setdiff1d(train_index_dict[fold], test_index_dict[1])]
        val_padding_mask = mask[test_index_dict[1]]
        test_padding_mask = mask[test_index_dict[fold]]
        test_padding_mask_8 = mask[test_index_dict_8[fold]]
        test_padding_mask_45 = mask[test_index_dict_45[fold]]
    
    # Create datasets
    train_dataset = DiseaseDataset(train_embeddings.astype(np.float32), train_labels.astype(np.float32), train_padding_mask)
    val_dataset = DiseaseDataset(val_embeddings.astype(np.float32), val_labels.astype(np.float32), val_padding_mask)
    test_dataset = DiseaseDataset(test_embeddings.astype(np.float32), test_labels.astype(np.float32), test_padding_mask)
    test_dataset_8 = DiseaseDataset(test_embeddings_8.astype(np.float32), test_labels_8.astype(np.float32), test_padding_mask_8)
    test_dataset_45 = DiseaseDataset(test_embeddings_45.astype(np.float32), test_labels_45.astype(np.float32), test_padding_mask_45)

    # Define batch size
    batch_size = 512

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)
    test_loader_8 = DataLoader(test_dataset_8, batch_size=batch_size)
    test_loader_45 = DataLoader(test_dataset_45, batch_size=batch_size)

    # Create weight for loss
    unique, counts = np.unique(train_labels, return_counts=True)
    pos_weight = torch.Tensor([(counts[0] / counts[1])])

    # Define model
    model = SetTransformer(dim_input=128, num_outputs=1, dim_output=1, num_inds=32, dim_hidden=160, num_heads=4, ln=False)

    model = model.to(device)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    # Run
    best_valid_loss = float('inf')
    best_valid_auroc = 0
    num_epochs = 10

    for epoch in range(num_epochs):

        start_time = time.time()

        train_loss, train_auroc, train_predictions, train_labels_out, train_pma_final = train(model, train_loader, optimizer, criterion)
        valid_loss, valid_auroc, valid_predictions, valid_labels_out, valid_pma_final = evaluate(model, val_loader, criterion)
        
        if epoch % 1 == 0:
            print('Epoch:', epoch)
            print(f'train loss: {train_loss:.3f}')
            print(f'train AUROC: {train_auroc:.3f}')
            print(f'valid loss: {valid_loss:.3f}')
            print(f'valid AUROC: {valid_auroc:.3f}')
        
        end_time = time.time()

        epoch_mins, epoch_secs = epoch_time(start_time, end_time)

        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss  
            print('BEST VALID LOSS:', best_valid_loss)

        if valid_auroc > best_valid_auroc:
            best_valid_auroc = valid_auroc
            print('BEST VALID AUROC:', best_valid_auroc)
            print('UPDATED BEST INTERMEDIATE MODEL')
            torch.save(model.state_dict(), f'intermediate_set_transformer_long_los.pt')

    # -----------------------------
    # Evaluate best model on test set
    # -----------------------------

    model.load_state_dict(torch.load(f'intermediate_set_transformer_long_los.pt'))

    test_loss, test_auroc, test_predictions, test_labels_out, test_pma_final = evaluate(model, test_loader, criterion)
    test_loss_8, test_auroc_8, test_predictions_8, test_labels_out_8, test_pma_final_8 = evaluate(model, test_loader_8, criterion)
    test_loss_45, test_auroc_45, test_predictions_45, test_labels_out_45, test_pma_final_45 = evaluate(model, test_loader_45, criterion)

    # del for memmory 
    del train_embeddings
    del val_embeddings
    del train_labels
    del val_labels
    del train_padding_mask
    del val_padding_mask
    del train_loader
    del val_loader
    del train_predictions
    del train_labels_out
    del train_pma_final
    del valid_predictions
    del valid_labels_out
    del valid_pma_final
    gc.collect()

    print(f'test predictions: {test_predictions.mean().item():.3f}') 
    print(f'test loss: {test_loss:.3f}')
    print(f'test AUROC: {test_auroc:.3f}')

    print(f'test predictions 8: {test_predictions_8.mean().item():.3f}') 
    print(f'test loss 8: {test_loss_8:.3f}')
    print(f'test AUROC 8: {test_auroc_8:.3f}')

    print(f'test predictions 45: {test_predictions_45.mean().item():.3f}') 
    print(f'test loss 45: {test_loss_45:.3f}')
    print(f'test AUROC 45: {test_auroc_45:.3f}')

    if test_auroc > best_test_auroc:
        best_test_auroc = test_auroc
        print('BEST TEST AUROC:', best_test_auroc)
        print('UPDATED BEST MODEL')
        torch.save(model.state_dict(), f'set_transformer_random_long_los.pt')

    # Get results
    # AUC
    aucroc = roc_auc_score(test_labels_out, test_predictions)
    # Accuracy
    accuracy = accuracy_score(test_labels_out, test_predictions.round())
    # Precision
    precision = precision_score(test_labels_out, test_predictions.round())
    # Recall
    recall = recall_score(test_labels_out, test_predictions.round())
    # AUPRC
    auprc = average_precision_score(test_labels_out, test_predictions)
    # F1
    f1 = f1_score(test_labels_out, test_predictions.round())
    # TPR and FPR
    cm = confusion_matrix(test_labels_out, test_predictions.round())
    _tp = cm[0, 0]
    _fn = cm[0, 1]
    _fp = cm[1, 0]
    _tn = cm[1, 1]
    tpr = _tp / (_tp + _fn)
    fpr = _fp / (_tn + _fp)

    new_row = {'data': 'set', 'fold': fold, 'subset':'overall', 'AUROC': aucroc, 'accuracy': accuracy, 'precision': precision, 'recall': recall, 'F1': f1, 'AUPRC': auprc, 'TPR': tpr, 'FPR': fpr}
    results_df = results_df.append(new_row, ignore_index=True)

    # Get results - rare
    # AUC
    try:
        aucroc = roc_auc_score(test_labels_out_8, test_predictions_8)
    except:
        aucroc = np.nan
    # Accuracy
    accuracy = accuracy_score(test_labels_out_8, test_predictions_8.round())
    # Precision
    precision = precision_score(test_labels_out_8, test_predictions_8.round())
    # Recall
    recall = recall_score(test_labels_out_8, test_predictions_8.round())
    # AUPRC
    try:
        auprc = average_precision_score(test_labels_out_8, test_predictions_8)
    except:
        auprc = np.nan
    # F1
    f1 = f1_score(test_labels_out_8, test_predictions_8.round())
    # TPR and FPR
    cm = confusion_matrix(test_labels_out_8, test_predictions_8.round())
    try:
        _tp = cm[0, 0]
    except:
        _tp = np.nan
    try:
        _fn = cm[0, 1]
    except:
        _fn = np.nan
    try:
        _fp = cm[1, 0]
    except:
        _fp = np.nan
    try:
        _tn = cm[1, 1]
    except:
        _tn = np.nan
    tpr = _tp / (_tp + _fn)
    fpr = _fp / (_tn + _fp)

    new_row = {'data': 'set', 'fold': fold, 'subset': 8, 'AUROC': aucroc, 'accuracy': accuracy, 'precision': precision, 'recall': recall, 'F1': f1, 'AUPRC': auprc, 'TPR': tpr, 'FPR': fpr}
    results_df = results_df.append(new_row, ignore_index=True)

    # Get results
    # AUC
    try:
        aucroc = roc_auc_score(test_labels_out_45, test_predictions_45)
    except:
        aucroc = np.nan
    # Accuracy
    accuracy = accuracy_score(test_labels_out_45, test_predictions_45.round())
    # Precision
    precision = precision_score(test_labels_out_45, test_predictions_45.round())
    # Recall
    recall = recall_score(test_labels_out_45, test_predictions_45.round())
    # AUPRC
    try:
        auprc = average_precision_score(test_labels_out_45, test_predictions_45)
    except:
        auprc = np.nan
    # F1
    f1 = f1_score(test_labels_out_45, test_predictions_45.round())
    # TPR and FPR
    cm = confusion_matrix(test_labels_out_45, test_predictions_45.round())
    try:
        _tp = cm[0, 0]
    except:
        _tp = np.nan
    try:
        _fn = cm[0, 1]
    except:
        _fn = np.nan
    try:
        _fp = cm[1, 0]
    except:
        _fp = np.nan
    try:
        _tn = cm[1, 1]
    except:
        _tn = np.nan
    tpr = _tp / (_tp + _fn)
    fpr = _fp / (_tn + _fp)

    new_row = {'data': 'set', 'fold': fold, 'subset': 45, 'AUROC': aucroc, 'accuracy': accuracy, 'precision': precision, 'recall': recall, 'F1': f1, 'AUPRC': auprc, 'TPR': tpr, 'FPR': fpr}
    results_df = results_df.append(new_row, ignore_index=True)


100%|██████████| 89/89 [00:30<00:00,  2.89it/s]
100%|██████████| 12/12 [00:01<00:00,  7.05it/s]
100%|██████████| 89/89 [00:31<00:00,  2.86it/s]
100%|██████████| 12/12 [00:01<00:00,  7.06it/s]
100%|██████████| 89/89 [00:30<00:00,  2.93it/s]
100%|██████████| 12/12 [00:01<00:00,  7.06it/s]
100%|██████████| 89/89 [00:29<00:00,  2.97it/s]
100%|██████████| 12/12 [00:01<00:00,  7.11it/s]
100%|██████████| 89/89 [00:30<00:00,  2.95it/s]
100%|██████████| 12/12 [00:01<00:00,  7.15it/s]
100%|██████████| 89/89 [00:29<00:00,  2.97it/s]
100%|██████████| 12/12 [00:01<00:00,  6.62it/s]
100%|██████████| 89/89 [00:30<00:00,  2.92it/s]
100%|██████████| 12/12 [00:01<00:00,  6.92it/s]
100%|██████████| 89/89 [00:30<00:00,  2.92it/s]
100%|██████████| 12/12 [00:01<00:00,  6.95it/s]
100%|██████████| 89/89 [00:30<00:00,  2.91it/s]
100%|██████████| 12/12 [00:01<00:00,  6.87it/s]
100%|██████████| 89/89 [00:30<00:00,  2.91it/s]
100%|██████████| 12/12 [00:01<00:00,  6.48it/s]


Epoch: 0
train loss: 1.212
train AUROC: 0.696
valid loss: 1.325
valid AUROC: 0.719
BEST VALID LOSS: 1.3253378470738728
BEST VALID AUROC: 0.7194735455320971
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 1.158
train AUROC: 0.719
valid loss: 1.162
valid AUROC: 0.752
BEST VALID LOSS: 1.1617638220389683
BEST VALID AUROC: 0.752378372502901
UPDATED BEST INTERMEDIATE MODEL
Epoch: 2
train loss: 1.103
train AUROC: 0.749
valid loss: 1.191
valid AUROC: 0.753
BEST VALID AUROC: 0.7528319378279742
UPDATED BEST INTERMEDIATE MODEL
Epoch: 3
train loss: 1.107
train AUROC: 0.750
valid loss: 1.170
valid AUROC: 0.764
BEST VALID AUROC: 0.7635382420445381
UPDATED BEST INTERMEDIATE MODEL
Epoch: 4
train loss: 1.089
train AUROC: 0.761
valid loss: 1.125
valid AUROC: 0.771
BEST VALID LOSS: 1.1254191795984905
BEST VALID AUROC: 0.7710303217240717
UPDATED BEST INTERMEDIATE MODEL
Epoch: 5
train loss: 1.078
train AUROC: 0.765
valid loss: 1.131
valid AUROC: 0.773
BEST VALID AUROC: 0.7733049774383973
UPDATED BEST 

<All keys matched successfully>

100%|██████████| 12/12 [00:01<00:00,  7.05it/s]
100%|██████████| 1/1 [00:00<00:00, 43.74it/s]
100%|██████████| 1/1 [00:00<00:00,  8.27it/s]


378

test predictions: 0.302
test loss: 1.186
test AUROC: 0.748
test predictions 8: 0.269
test loss 8: 1.470
test AUROC 8: 0.477
test predictions 45: 0.240
test loss 45: 0.974
test AUROC 45: 0.552
BEST TEST AUROC: 0.7478819865136874
UPDATED BEST MODEL
Epoch: 0
train loss: 1.208
train AUROC: 0.691
valid loss: 1.146
valid AUROC: 0.726
BEST VALID LOSS: 1.1460019201040268
BEST VALID AUROC: 0.7260285438656225
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 1.124
train AUROC: 0.737
valid loss: 1.127
valid AUROC: 0.747
BEST VALID LOSS: 1.1269170641899109
BEST VALID AUROC: 0.7471776756327319
UPDATED BEST INTERMEDIATE MODEL
Epoch: 2
train loss: 1.107
train AUROC: 0.748
valid loss: 1.127
valid AUROC: 0.747
BEST VALID LOSS: 1.1267147064208984
BEST VALID AUROC: 0.7474869481330155
UPDATED BEST INTERMEDIATE MODEL
Epoch: 3
train loss: 1.094
train AUROC: 0.755
valid loss: 1.125
valid AUROC: 0.752
BEST VALID LOSS: 1.1250528991222382
BEST VALID AUROC: 0.7516005561230281
UPDATED BEST INTERMEDIATE MODEL
E

  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
100%|██████████| 89/89 [00:30<00:00,  2.94it/s]
100%|██████████| 12/12 [00:01<00:00,  6.88it/s]
100%|██████████| 89/89 [00:30<00:00,  2.91it/s]
100%|██████████| 12/12 [00:01<00:00,  7.03it/s]
100%|██████████| 89/89 [00:29<00:00,  2.97it/s]
100%|██████████| 12/12 [00:01<00:00,  7.14it/s]
100%|██████████| 89/89 [00:29<00:00,  2.97it/s]
100%|██████████| 12/12 [00:01<00:00,  7.11it/s]
100%|██████████| 89/89 [00:29<00:00,  2.99it/s]
100%|██████████| 12/12 [00:01<00:00,  7.13it/s]
100%|██████████| 89/89 [00:30<00:00,  2.95it/s]
100%|██████████| 12/12 [00:01<00:00,  7.00it/s]
100%|██████████| 89/89 [00:29<00:00,  2.99it/s]
100%|██████████| 12/12 [00:01<00:00,  7.12it/s]
100%|██████████| 89/89 [00:30<00:00,  2.95it/s]
100%|██████████| 12/12 [00:01<00:00,  7.14it/s]
100%|██████████| 89/89 [00:29<00:00,  2.99it/s]
1

<All keys matched successfully>

100%|██████████| 12/12 [00:01<00:00,  7.12it/s]
100%|██████████| 1/1 [00:00<00:00, 38.80it/s]
100%|██████████| 1/1 [00:00<00:00,  8.58it/s]


0

test predictions: 0.434
test loss: 1.160
test AUROC: 0.757
test predictions 8: 0.478
test loss 8: 0.876
test AUROC 8: 0.919
test predictions 45: 0.468
test loss 45: 1.039
test AUROC 45: 0.634
BEST TEST AUROC: 0.7573909237993182
UPDATED BEST MODEL
Epoch: 0
train loss: 1.208
train AUROC: 0.696
valid loss: 1.236
valid AUROC: 0.716
BEST VALID LOSS: 1.236232782403628
BEST VALID AUROC: 0.7159116445352401
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 1.123
train AUROC: 0.738
valid loss: 1.187
valid AUROC: 0.734
BEST VALID LOSS: 1.1866472413142521
BEST VALID AUROC: 0.7337348201112246
UPDATED BEST INTERMEDIATE MODEL
Epoch: 2
train loss: 1.100
train AUROC: 0.754
valid loss: 1.181
valid AUROC: 0.737
BEST VALID LOSS: 1.1806623737017314
BEST VALID AUROC: 0.7368161956645102
UPDATED BEST INTERMEDIATE MODEL
Epoch: 3
train loss: 1.091
train AUROC: 0.758
valid loss: 1.178
valid AUROC: 0.740
BEST VALID LOSS: 1.1777698993682861
BEST VALID AUROC: 0.7404151061173533
UPDATED BEST INTERMEDIATE MODEL
Ep

  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
100%|██████████| 89/89 [00:30<00:00,  2.94it/s]
100%|██████████| 12/12 [00:01<00:00,  6.02it/s]
100%|██████████| 89/89 [00:29<00:00,  2.98it/s]
100%|██████████| 12/12 [00:01<00:00,  7.14it/s]
100%|██████████| 89/89 [00:30<00:00,  2.95it/s]
100%|██████████| 12/12 [00:01<00:00,  7.06it/s]
100%|██████████| 89/89 [00:30<00:00,  2.94it/s]
100%|██████████| 12/12 [00:01<00:00,  7.09it/s]
100%|██████████| 89/89 [00:30<00:00,  2.89it/s]
100%|██████████| 12/12 [00:01<00:00,  7.08it/s]
100%|██████████| 89/89 [00:30<00:00,  2.94it/s]
100%|██████████| 12/12 [00:01<00:00,  7.07it/s]
100%|██████████| 89/89 [00:30<00:00,  2.92it/s]
100%|██████████| 12/12 [00:01<00:00,  7.02it/s]
100%|██████████| 89/89 [00:30<00:00,  2.94it/s]
100%|██████████| 12/12 [00:01<00:00,  7.07it/s]
100%|██████████| 89/89 [00:30<00:00,  2.90it/s]
1

<All keys matched successfully>

100%|██████████| 12/12 [00:01<00:00,  6.92it/s]
100%|██████████| 1/1 [00:00<00:00, 44.53it/s]
100%|██████████| 1/1 [00:00<00:00,  9.70it/s]


0

test predictions: 0.333
test loss: 1.141
test AUROC: 0.761
test predictions 8: 0.305
test loss 8: 2.279
test AUROC 8: 0.301
test predictions 45: 0.284
test loss 45: 1.615
test AUROC 45: 0.669
BEST TEST AUROC: 0.7610864260583362
UPDATED BEST MODEL
Epoch: 0
train loss: 1.210
train AUROC: 0.679
valid loss: 1.128
valid AUROC: 0.716
BEST VALID LOSS: 1.1284069915612538
BEST VALID AUROC: 0.7162169447281806
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 1.136
train AUROC: 0.728
valid loss: 1.068
valid AUROC: 0.747
BEST VALID LOSS: 1.068473642071088
BEST VALID AUROC: 0.747057087731245
UPDATED BEST INTERMEDIATE MODEL
Epoch: 2
train loss: 1.100
train AUROC: 0.753
valid loss: 1.076
valid AUROC: 0.754
BEST VALID AUROC: 0.7538684598796959
UPDATED BEST INTERMEDIATE MODEL
Epoch: 3
train loss: 1.098
train AUROC: 0.753
valid loss: 1.055
valid AUROC: 0.756
BEST VALID LOSS: 1.0549821654955547
BEST VALID AUROC: 0.7558449665191238
UPDATED BEST INTERMEDIATE MODEL
Epoch: 4
train loss: 1.076
train AUROC:

  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
100%|██████████| 89/89 [00:30<00:00,  2.92it/s]
100%|██████████| 12/12 [00:01<00:00,  7.03it/s]
100%|██████████| 89/89 [00:30<00:00,  2.95it/s]
100%|██████████| 12/12 [00:01<00:00,  6.89it/s]
100%|██████████| 89/89 [00:30<00:00,  2.92it/s]
100%|██████████| 12/12 [00:01<00:00,  6.99it/s]
100%|██████████| 89/89 [00:30<00:00,  2.89it/s]
100%|██████████| 12/12 [00:01<00:00,  7.00it/s]
100%|██████████| 89/89 [00:30<00:00,  2.94it/s]
100%|██████████| 12/12 [00:01<00:00,  7.05it/s]
100%|██████████| 89/89 [00:30<00:00,  2.92it/s]
100%|██████████| 12/12 [00:01<00:00,  7.05it/s]
100%|██████████| 89/89 [00:29<00:00,  2.98it/s]
100%|██████████| 12/12 [00:01<00:00,  7.17it/s]
100%|██████████| 89/89 [00:29<00:00,  2.97it/s]
100%|██████████| 12/12 [00:01<00:00,  7.10it/s]
100%|██████████| 89/89 [00:29<00:00,  2.99it/s]
1

<All keys matched successfully>

  0%|          | 0/12 [00:00<?, ?it/s]  8%|▊         | 1/12 [00:00<00:01,  6.76it/s] 17%|█▋        | 2/12 [00:00<00:01,  6.66it/s] 25%|██▌       | 3/12 [00:00<00:01,  6.55it/s] 33%|███▎      | 4/12 [00:00<00:01,  6.61it/s] 42%|████▏     | 5/12 [00:00<00:01,  6.63it/s] 50%|█████     | 6/12 [00:00<00:00,  6.63it/s] 58%|█████▊    | 7/12 [00:01<00:00,  6.62it/s] 67%|██████▋   | 8/12 [00:01<00:00,  6.63it/s] 75%|███████▌  | 9/12 [00:01<00:00,  6.60it/s] 83%|████████▎ | 10/12 [00:01<00:00,  6.56it/s] 92%|█████████▏| 11/12 [00:01<00:00,  6.59it/s]100%|██████████| 12/12 [00:01<00:00,  7.14it/s]
  0%|          | 0/1 [00:00<?, ?it/s]100%|██████████| 1/1 [00:00<00:00, 39.79it/s]
  0%|          | 0/1 [00:00<?, ?it/s]100%|██████████| 1/1 [00:00<00:00,  9.30it/s]100%|██████████| 1/1 [00:00<00:00,  9.16it/s]


0

test predictions: 0.532
test loss: 1.226
test AUROC: 0.720
test predictions 8: 0.639
test loss 8: 1.372
test AUROC 8: 0.464
test predictions 45: 0.634
test loss 45: 1.253
test AUROC 45: 0.598
Epoch: 0
train loss: 1.209
train AUROC: 0.690
valid loss: 1.197
valid AUROC: 0.724
BEST VALID LOSS: 1.196710119644801
BEST VALID AUROC: 0.7241911706112041
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 1.150
train AUROC: 0.721
valid loss: 1.156
valid AUROC: 0.740
BEST VALID LOSS: 1.1556655565897624
BEST VALID AUROC: 0.7404451692223627
UPDATED BEST INTERMEDIATE MODEL
Epoch: 2
train loss: 1.107
train AUROC: 0.749
valid loss: 1.135
valid AUROC: 0.753
BEST VALID LOSS: 1.1353984673817952
BEST VALID AUROC: 0.7533255353300776
UPDATED BEST INTERMEDIATE MODEL
Epoch: 3
train loss: 1.093
train AUROC: 0.756
valid loss: 1.151
valid AUROC: 0.755
BEST VALID AUROC: 0.7547069772207233
UPDATED BEST INTERMEDIATE MODEL
Epoch: 4
train loss: 1.096
train AUROC: 0.756
valid loss: 1.129
valid AUROC: 0.759
BEST VALID

  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
100%|██████████| 89/89 [00:30<00:00,  2.95it/s]
100%|██████████| 12/12 [00:01<00:00,  7.09it/s]
100%|██████████| 89/89 [00:30<00:00,  2.95it/s]
100%|██████████| 12/12 [00:01<00:00,  7.07it/s]
100%|██████████| 89/89 [00:29<00:00,  2.98it/s]
100%|██████████| 12/12 [00:01<00:00,  6.84it/s]
100%|██████████| 89/89 [00:30<00:00,  2.95it/s]
100%|██████████| 12/12 [00:01<00:00,  7.15it/s]
100%|██████████| 89/89 [00:30<00:00,  2.96it/s]
100%|██████████| 12/12 [00:01<00:00,  6.93it/s]
100%|██████████| 89/89 [00:30<00:00,  2.95it/s]
100%|██████████| 12/12 [00:01<00:00,  7.00it/s]
100%|██████████| 89/89 [00:30<00:00,  2.95it/s]
100%|██████████| 12/12 [00:01<00:00,  7.12it/s]
100%|██████████| 89/89 [00:29<00:00,  2.99it/s]
100%|██████████| 12/12 [00:01<00:00,  7.14it/s]
100%|██████████| 89/89 [00:30<00:00,  2.96it/s]
1

<All keys matched successfully>

  0%|          | 0/12 [00:00<?, ?it/s]  8%|▊         | 1/12 [00:00<00:01,  6.78it/s] 17%|█▋        | 2/12 [00:00<00:01,  6.75it/s] 25%|██▌       | 3/12 [00:00<00:01,  6.76it/s] 33%|███▎      | 4/12 [00:00<00:01,  6.71it/s] 42%|████▏     | 5/12 [00:00<00:01,  6.70it/s] 50%|█████     | 6/12 [00:00<00:00,  6.70it/s] 58%|█████▊    | 7/12 [00:01<00:00,  6.62it/s] 67%|██████▋   | 8/12 [00:01<00:00,  6.65it/s] 75%|███████▌  | 9/12 [00:01<00:00,  6.68it/s] 83%|████████▎ | 10/12 [00:01<00:00,  6.68it/s] 92%|█████████▏| 11/12 [00:01<00:00,  6.67it/s]100%|██████████| 12/12 [00:01<00:00,  7.22it/s]
  0%|          | 0/1 [00:00<?, ?it/s]100%|██████████| 1/1 [00:00<00:00, 36.89it/s]
  0%|          | 0/1 [00:00<?, ?it/s]100%|██████████| 1/1 [00:00<00:00,  9.00it/s]100%|██████████| 1/1 [00:00<00:00,  8.90it/s]


0

test predictions: 0.519
test loss: 1.128
test AUROC: 0.749
test predictions 8: 0.526
test loss 8: 0.967
test AUROC 8: 0.831
test predictions 45: 0.513
test loss 45: 1.082
test AUROC 45: 0.709
Epoch: 0
train loss: 1.208
train AUROC: 0.693
valid loss: 1.142
valid AUROC: 0.741
BEST VALID LOSS: 1.1420462727546692
BEST VALID AUROC: 0.7405479093382512
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 1.116
train AUROC: 0.745
valid loss: 1.261
valid AUROC: 0.726
Epoch: 2
train loss: 1.140
train AUROC: 0.734
valid loss: 1.129
valid AUROC: 0.747
BEST VALID LOSS: 1.1288850555817287
BEST VALID AUROC: 0.7474351968853523
UPDATED BEST INTERMEDIATE MODEL
Epoch: 3
train loss: 1.112
train AUROC: 0.743
valid loss: 1.106
valid AUROC: 0.754
BEST VALID LOSS: 1.1056880702575047
BEST VALID AUROC: 0.7543355759707661
UPDATED BEST INTERMEDIATE MODEL
Epoch: 4
train loss: 1.089
train AUROC: 0.761
valid loss: 1.123
valid AUROC: 0.752
Epoch: 5
train loss: 1.098
train AUROC: 0.753
valid loss: 1.099
valid AUROC: 0

  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
100%|██████████| 89/89 [00:30<00:00,  2.91it/s]
100%|██████████| 12/12 [00:01<00:00,  7.04it/s]
100%|██████████| 89/89 [00:30<00:00,  2.96it/s]
100%|██████████| 12/12 [00:01<00:00,  7.13it/s]
100%|██████████| 89/89 [00:30<00:00,  2.95it/s]
100%|██████████| 12/12 [00:01<00:00,  7.16it/s]
100%|██████████| 89/89 [00:30<00:00,  2.96it/s]
100%|██████████| 12/12 [00:01<00:00,  7.14it/s]
100%|██████████| 89/89 [00:30<00:00,  2.96it/s]
100%|██████████| 12/12 [00:01<00:00,  7.18it/s]
100%|██████████| 89/89 [00:29<00:00,  2.97it/s]
100%|██████████| 12/12 [00:01<00:00,  7.15it/s]
100%|██████████| 89/89 [00:30<00:00,  2.97it/s]
100%|██████████| 12/12 [00:01<00:00,  7.14it/s]
100%|██████████| 89/89 [00:30<00:00,  2.91it/s]
100%|██████████| 12/12 [00:01<00:00,  7.17it/s]
100%|██████████| 89/89 [00:29<00:00,  2.98it/s]
1

<All keys matched successfully>

  0%|          | 0/12 [00:00<?, ?it/s]  8%|▊         | 1/12 [00:00<00:01,  6.77it/s] 17%|█▋        | 2/12 [00:00<00:01,  6.57it/s] 25%|██▌       | 3/12 [00:00<00:01,  6.53it/s] 33%|███▎      | 4/12 [00:00<00:01,  6.56it/s] 42%|████▏     | 5/12 [00:00<00:01,  6.42it/s] 50%|█████     | 6/12 [00:00<00:00,  6.50it/s] 58%|█████▊    | 7/12 [00:01<00:00,  6.58it/s] 67%|██████▋   | 8/12 [00:01<00:00,  6.58it/s] 75%|███████▌  | 9/12 [00:01<00:00,  6.60it/s] 83%|████████▎ | 10/12 [00:01<00:00,  6.63it/s] 92%|█████████▏| 11/12 [00:01<00:00,  6.61it/s]100%|██████████| 12/12 [00:01<00:00,  7.09it/s]
  0%|          | 0/1 [00:00<?, ?it/s]100%|██████████| 1/1 [00:00<00:00, 38.41it/s]
  0%|          | 0/1 [00:00<?, ?it/s]100%|██████████| 1/1 [00:00<00:00,  9.50it/s]100%|██████████| 1/1 [00:00<00:00,  9.40it/s]


0

test predictions: 0.427
test loss: 1.120
test AUROC: 0.766
test predictions 8: 0.469
test loss 8: 0.816
test AUROC 8: 0.469
test predictions 45: 0.447
test loss 45: 0.956
test AUROC 45: 0.553
BEST TEST AUROC: 0.7662349305010074
UPDATED BEST MODEL
Epoch: 0
train loss: 1.212
train AUROC: 0.691
valid loss: 1.086
valid AUROC: 0.733
BEST VALID LOSS: 1.0855499332149823
BEST VALID AUROC: 0.7330205594071241
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 1.144
train AUROC: 0.731
valid loss: 1.081
valid AUROC: 0.761
BEST VALID LOSS: 1.0807827214399974
BEST VALID AUROC: 0.7611428912946963
UPDATED BEST INTERMEDIATE MODEL
Epoch: 2
train loss: 1.113
train AUROC: 0.745
valid loss: 1.053
valid AUROC: 0.767
BEST VALID LOSS: 1.0532234758138657
BEST VALID AUROC: 0.7672816843687033
UPDATED BEST INTERMEDIATE MODEL
Epoch: 3
train loss: 1.097
train AUROC: 0.753
valid loss: 1.048
valid AUROC: 0.771
BEST VALID LOSS: 1.0479139139254887
BEST VALID AUROC: 0.7706083580934167
UPDATED BEST INTERMEDIATE MODEL
E

  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
100%|██████████| 89/89 [00:30<00:00,  2.91it/s]
100%|██████████| 12/12 [00:01<00:00,  6.93it/s]
100%|██████████| 89/89 [00:30<00:00,  2.90it/s]
100%|██████████| 12/12 [00:01<00:00,  6.98it/s]
100%|██████████| 89/89 [00:29<00:00,  2.98it/s]
100%|██████████| 12/12 [00:01<00:00,  7.13it/s]
100%|██████████| 89/89 [00:30<00:00,  2.96it/s]
100%|██████████| 12/12 [00:01<00:00,  7.09it/s]
100%|██████████| 89/89 [00:30<00:00,  2.96it/s]
100%|██████████| 12/12 [00:01<00:00,  7.05it/s]
100%|██████████| 89/89 [00:30<00:00,  2.97it/s]
100%|██████████| 12/12 [00:01<00:00,  7.12it/s]
100%|██████████| 89/89 [00:29<00:00,  2.99it/s]
100%|██████████| 12/12 [00:01<00:00,  7.00it/s]
100%|██████████| 89/89 [00:29<00:00,  2.98it/s]
100%|██████████| 12/12 [00:01<00:00,  6.95it/s]
100%|██████████| 89/89 [00:29<00:00,  3.01it/s]
1

<All keys matched successfully>

  0%|          | 0/12 [00:00<?, ?it/s]  8%|▊         | 1/12 [00:00<00:02,  4.66it/s] 17%|█▋        | 2/12 [00:00<00:02,  4.75it/s] 25%|██▌       | 3/12 [00:00<00:01,  4.77it/s] 33%|███▎      | 4/12 [00:00<00:01,  4.77it/s] 42%|████▏     | 5/12 [00:01<00:01,  4.74it/s] 50%|█████     | 6/12 [00:01<00:01,  4.77it/s] 58%|█████▊    | 7/12 [00:01<00:01,  4.79it/s] 67%|██████▋   | 8/12 [00:01<00:00,  5.15it/s] 75%|███████▌  | 9/12 [00:01<00:00,  5.42it/s] 83%|████████▎ | 10/12 [00:01<00:00,  5.75it/s] 92%|█████████▏| 11/12 [00:02<00:00,  5.90it/s]100%|██████████| 12/12 [00:02<00:00,  5.66it/s]
  0%|          | 0/1 [00:00<?, ?it/s]100%|██████████| 1/1 [00:00<00:00, 42.57it/s]
  0%|          | 0/1 [00:00<?, ?it/s]100%|██████████| 1/1 [00:00<00:00,  9.37it/s]100%|██████████| 1/1 [00:00<00:00,  9.21it/s]


0

test predictions: 0.311
test loss: 1.205
test AUROC: 0.758
test predictions 8: 0.281
test loss 8: 0.990
test AUROC 8: 0.765
test predictions 45: 0.238
test loss 45: 1.651
test AUROC 45: 0.544
Epoch: 0
train loss: 1.209
train AUROC: 0.687
valid loss: 1.142
valid AUROC: 0.723
BEST VALID LOSS: 1.1419608195622761
BEST VALID AUROC: 0.7228879933517753
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 1.132
train AUROC: 0.733
valid loss: 1.148
valid AUROC: 0.736
BEST VALID AUROC: 0.7362228635178671
UPDATED BEST INTERMEDIATE MODEL
Epoch: 2
train loss: 1.117
train AUROC: 0.742
valid loss: 1.143
valid AUROC: 0.745
BEST VALID AUROC: 0.7451823139009369
UPDATED BEST INTERMEDIATE MODEL
Epoch: 3
train loss: 1.095
train AUROC: 0.754
valid loss: 1.144
valid AUROC: 0.750
BEST VALID AUROC: 0.7499721093314208
UPDATED BEST INTERMEDIATE MODEL
Epoch: 4
train loss: 1.091
train AUROC: 0.756
valid loss: 1.133
valid AUROC: 0.754
BEST VALID LOSS: 1.1333007911841075
BEST VALID AUROC: 0.7541642475780654
UPDATED 

  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
100%|██████████| 89/89 [00:32<00:00,  2.72it/s]
100%|██████████| 12/12 [00:02<00:00,  5.78it/s]
100%|██████████| 89/89 [00:31<00:00,  2.86it/s]
100%|██████████| 12/12 [00:01<00:00,  7.18it/s]
100%|██████████| 89/89 [00:29<00:00,  2.99it/s]
100%|██████████| 12/12 [00:01<00:00,  7.12it/s]
100%|██████████| 89/89 [00:29<00:00,  2.98it/s]
100%|██████████| 12/12 [00:01<00:00,  7.17it/s]
100%|██████████| 89/89 [00:29<00:00,  2.98it/s]
100%|██████████| 12/12 [00:01<00:00,  7.16it/s]
100%|██████████| 89/89 [00:29<00:00,  3.00it/s]
100%|██████████| 12/12 [00:01<00:00,  7.20it/s]
100%|██████████| 89/89 [00:29<00:00,  2.98it/s]
100%|██████████| 12/12 [00:01<00:00,  7.10it/s]
100%|██████████| 89/89 [00:29<00:00,  2.99it/s]
100%|██████████| 12/12 [00:01<00:00,  7.17it/s]
100%|██████████| 89/89 [00:29<00:00,  2.97it/s]
1

<All keys matched successfully>

100%|██████████| 12/12 [00:01<00:00,  6.78it/s]
100%|██████████| 1/1 [00:00<00:00, 34.96it/s]
100%|██████████| 1/1 [00:00<00:00,  8.61it/s]


0

test predictions: 0.513
test loss: 1.135
test AUROC: 0.774
test predictions 8: 0.553
test loss 8: 1.105
test AUROC 8: 0.578
test predictions 45: 0.533
test loss 45: 1.115
test AUROC 45: 0.557
BEST TEST AUROC: 0.7740329337568161
UPDATED BEST MODEL
Epoch: 0
train loss: 1.207
train AUROC: 0.695
valid loss: 1.130
valid AUROC: 0.725
BEST VALID LOSS: 1.1301259348789852
BEST VALID AUROC: 0.7249297041312341
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 1.120
train AUROC: 0.743
valid loss: 1.076
valid AUROC: 0.744
BEST VALID LOSS: 1.0763489504655201
BEST VALID AUROC: 0.7443712077228692
UPDATED BEST INTERMEDIATE MODEL
Epoch: 2
train loss: 1.106
train AUROC: 0.751
valid loss: 1.158
valid AUROC: 0.731
Epoch: 3
train loss: 1.146
train AUROC: 0.726
valid loss: 1.103
valid AUROC: 0.736
Epoch: 4
train loss: 1.105
train AUROC: 0.750
valid loss: 1.076
valid AUROC: 0.746
BEST VALID LOSS: 1.0759353935718536
BEST VALID AUROC: 0.7456274262035677
UPDATED BEST INTERMEDIATE MODEL
Epoch: 5
train loss: 1.

  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
100%|██████████| 89/89 [00:30<00:00,  2.96it/s]
100%|██████████| 12/12 [00:01<00:00,  7.12it/s]
100%|██████████| 89/89 [00:29<00:00,  3.00it/s]
100%|██████████| 12/12 [00:01<00:00,  7.07it/s]
100%|██████████| 89/89 [00:30<00:00,  2.96it/s]
100%|██████████| 12/12 [00:01<00:00,  7.14it/s]
100%|██████████| 89/89 [00:29<00:00,  2.97it/s]
100%|██████████| 12/12 [00:01<00:00,  6.95it/s]
100%|██████████| 89/89 [00:29<00:00,  2.98it/s]
100%|██████████| 12/12 [00:01<00:00,  7.05it/s]
100%|██████████| 89/89 [00:29<00:00,  2.98it/s]
100%|██████████| 12/12 [00:01<00:00,  7.13it/s]
100%|██████████| 89/89 [00:29<00:00,  3.00it/s]
100%|██████████| 12/12 [00:01<00:00,  7.14it/s]
100%|██████████| 89/89 [00:30<00:00,  2.95it/s]
100%|██████████| 12/12 [00:01<00:00,  7.11it/s]
100%|██████████| 89/89 [00:29<00:00,  2.98it/s]
1

<All keys matched successfully>

100%|██████████| 12/12 [00:01<00:00,  7.07it/s]
100%|██████████| 1/1 [00:00<00:00, 38.24it/s]
100%|██████████| 1/1 [00:00<00:00,  9.09it/s]


0

test predictions: 0.548
test loss: 1.217
test AUROC: 0.747
test predictions 8: 0.572
test loss 8: 1.402
test AUROC 8: 0.554
test predictions 45: 0.574
test loss 45: 1.314
test AUROC 45: 0.604
Epoch: 0
train loss: 1.198
train AUROC: 0.708
valid loss: 1.125
valid AUROC: 0.740
BEST VALID LOSS: 1.124985138575236
BEST VALID AUROC: 0.74040242683058
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 1.111
train AUROC: 0.749
valid loss: 1.132
valid AUROC: 0.748
BEST VALID AUROC: 0.7479610901274707
UPDATED BEST INTERMEDIATE MODEL
Epoch: 2
train loss: 1.110
train AUROC: 0.748
valid loss: 1.109
valid AUROC: 0.750
BEST VALID LOSS: 1.1093477408091228
BEST VALID AUROC: 0.7502989149146735
UPDATED BEST INTERMEDIATE MODEL
Epoch: 3
train loss: 1.094
train AUROC: 0.756
valid loss: 1.090
valid AUROC: 0.754
BEST VALID LOSS: 1.0902640124162037
BEST VALID AUROC: 0.7543036064987887
UPDATED BEST INTERMEDIATE MODEL
Epoch: 4
train loss: 1.081
train AUROC: 0.765
valid loss: 1.106
valid AUROC: 0.753
Epoch: 5
tra

  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
100%|██████████| 89/89 [00:29<00:00,  3.00it/s]
100%|██████████| 12/12 [00:01<00:00,  7.14it/s]
100%|██████████| 89/89 [00:29<00:00,  2.97it/s]
100%|██████████| 12/12 [00:01<00:00,  7.08it/s]
100%|██████████| 89/89 [00:29<00:00,  2.98it/s]
100%|██████████| 12/12 [00:01<00:00,  7.05it/s]
100%|██████████| 89/89 [00:30<00:00,  2.92it/s]
100%|██████████| 12/12 [00:01<00:00,  7.02it/s]
100%|██████████| 89/89 [00:29<00:00,  3.00it/s]
100%|██████████| 12/12 [00:01<00:00,  6.96it/s]
100%|██████████| 89/89 [00:29<00:00,  2.99it/s]
100%|██████████| 12/12 [00:01<00:00,  7.15it/s]
100%|██████████| 89/89 [00:29<00:00,  2.99it/s]
100%|██████████| 12/12 [00:01<00:00,  7.05it/s]
100%|██████████| 89/89 [00:29<00:00,  2.98it/s]
100%|██████████| 12/12 [00:01<00:00,  7.14it/s]
100%|██████████| 89/89 [00:29<00:00,  2.97it/s]
1

<All keys matched successfully>

100%|██████████| 12/12 [00:01<00:00,  7.03it/s]
100%|██████████| 1/1 [00:00<00:00, 44.93it/s]
100%|██████████| 1/1 [00:00<00:00,  9.73it/s]


0

test predictions: 0.400
test loss: 1.097
test AUROC: 0.748
test predictions 8: 0.238
test loss 8: 1.148
test AUROC 8: 0.768
test predictions 45: 0.195
test loss 45: 1.194
test AUROC 45: 0.741


  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)


In [16]:
results_df

Unnamed: 0,data,fold,subset,AUROC,accuracy,precision,recall,F1,AUPRC,TPR,FPR
0,set,1,overall,0.747882,0.799295,0.130594,0.43465,0.200843,0.150022,0.821756,0.56535
1,set,1,8,0.476526,0.864865,0.0,0.0,0.0,0.05752,0.901408,1.0
2,set,1,45,0.55188,0.931298,0.0,0.0,0.0,0.035175,0.958115,1.0
3,set,2,overall,0.757391,0.553263,0.100725,0.844985,0.179994,0.175319,0.535293,0.155015
4,set,2,8,0.919048,0.479452,0.073171,1.0,0.136364,0.344444,0.457143,0.0
5,set,2,45,0.633656,0.48731,0.043478,0.692308,0.081818,0.067011,0.480315,0.307692
6,set,3,overall,0.761086,0.717637,0.126396,0.651515,0.211718,0.158006,0.721723,0.348485
7,set,3,8,0.301136,0.728571,0.0,0.0,0.0,0.049177,0.772727,1.0
8,set,3,45,0.668624,0.779553,0.136364,0.428571,0.206897,0.158546,0.804795,0.571429
9,set,4,overall,0.720488,0.424515,0.08088,0.857576,0.147819,0.15442,0.397753,0.142424


In [17]:
results_df.groupby(['data', 'subset']).mean()
results_df.groupby(['data', 'subset']).std()

Unnamed: 0_level_0,Unnamed: 1_level_0,fold,AUROC,accuracy,precision,recall,F1,AUPRC,TPR,FPR
data,subset,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
set,8,5.5,0.612609,0.55383,0.054973,0.53,0.090925,0.162252,0.556003,0.47
set,45,5.5,0.615869,0.585422,0.056156,0.522018,0.092404,0.111725,0.589264,0.477982
set,overall,5.5,0.752853,0.582184,0.104106,0.768583,0.180791,0.155475,0.570686,0.231417


Unnamed: 0_level_0,Unnamed: 1_level_0,fold,AUROC,accuracy,precision,recall,F1,AUPRC,TPR,FPR
data,subset,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
set,8,3.02765,0.19777,0.2435,0.065615,0.426426,0.092775,0.155665,0.26736,0.426426
set,45,3.02765,0.070323,0.245368,0.046454,0.302303,0.063187,0.089432,0.266255,0.302303
set,overall,3.02765,0.014478,0.125501,0.016717,0.151343,0.020264,0.010098,0.142043,0.151343


In [20]:
# Random - all and rare

# Run
best_test_auroc = 0
results_df = pd.DataFrame()
for fold in range(10):
    fold += 1
    if fold < 10:
        # Get train, validation and test sets
        train_embeddings = feature_array[np.setdiff1d(train_index_dict[fold], test_index_dict[fold+1])]
        val_embeddings = feature_array[test_index_dict[fold+1]]
        test_embeddings = feature_array[test_index_dict[fold]]
        test_embeddings_8 = feature_array[test_index_dict_8[fold]]
        test_embeddings_45 = feature_array[test_index_dict_45[fold]]

        train_labels = los_array[np.setdiff1d(train_index_dict[fold], test_index_dict[fold+1])]
        val_labels = los_array[test_index_dict[fold+1]]
        test_labels = los_array[test_index_dict[fold]]
        test_labels_8 = los_array[test_index_dict_8[fold]]
        test_labels_45 = los_array[test_index_dict_45[fold]]    

        # Split masks
        train_padding_mask = mask[np.setdiff1d(train_index_dict[fold], test_index_dict[fold+1])]
        val_padding_mask = mask[test_index_dict[fold+1]]
        test_padding_mask = mask[test_index_dict[fold]]
        test_padding_mask_8 = mask[test_index_dict_8[fold]]
        test_padding_mask_45 = mask[test_index_dict_45[fold]]

    else: # Change here to fist cv split for final fold
        # Get train, validation and test sets
        train_embeddings = feature_array[np.setdiff1d(train_index_dict[fold], test_index_dict[1])]
        val_embeddings = feature_array[test_index_dict[1]]
        test_embeddings = feature_array[test_index_dict[fold]]
        test_embeddings_8 = feature_array[test_index_dict_8[fold]]
        test_embeddings_45 = feature_array[test_index_dict_45[fold]]

        train_labels = los_array[np.setdiff1d(train_index_dict[fold], test_index_dict[1])]
        val_labels = los_array[test_index_dict[1]]
        test_labels = los_array[test_index_dict[fold]]
        test_labels_8 = los_array[test_index_dict_8[fold]]
        test_labels_45 = los_array[test_index_dict_45[fold]]

        # Split masks
        train_padding_mask = mask[np.setdiff1d(train_index_dict[fold], test_index_dict[1])]
        val_padding_mask = mask[test_index_dict[1]]
        test_padding_mask = mask[test_index_dict[fold]]
        test_padding_mask_8 = mask[test_index_dict_8[fold]]
        test_padding_mask_45 = mask[test_index_dict_45[fold]]
    
    # Create datasets
    train_dataset = DiseaseDataset(train_embeddings.astype(np.float32), train_labels.astype(np.float32), train_padding_mask)
    val_dataset = DiseaseDataset(val_embeddings.astype(np.float32), val_labels.astype(np.float32), val_padding_mask)
    test_dataset = DiseaseDataset(test_embeddings.astype(np.float32), test_labels.astype(np.float32), test_padding_mask)
    test_dataset_8 = DiseaseDataset(test_embeddings_8.astype(np.float32), test_labels_8.astype(np.float32), test_padding_mask_8)
    test_dataset_45 = DiseaseDataset(test_embeddings_45.astype(np.float32), test_labels_45.astype(np.float32), test_padding_mask_45)

    # Define batch size
    batch_size = 512

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)
    test_loader_8 = DataLoader(test_dataset_8, batch_size=batch_size)
    test_loader_45 = DataLoader(test_dataset_45, batch_size=batch_size)

    # Create weight for loss
    unique, counts = np.unique(train_labels, return_counts=True)
    pos_weight = torch.Tensor([(counts[0] / counts[1])])

    # Define model
    model = SetTransformer(dim_input=128, num_outputs=1, dim_output=1, num_inds=32, dim_hidden=160, num_heads=4, ln=False)

    model = model.to(device)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    # Run
    best_valid_loss = float('inf')
    best_valid_auroc = 0
    num_epochs = 10

    for epoch in range(num_epochs):

        start_time = time.time()

        train_loss, train_auroc, train_predictions, train_labels_out, train_pma_final = train(model, train_loader, optimizer, criterion)
        valid_loss, valid_auroc, valid_predictions, valid_labels_out, valid_pma_final = evaluate(model, val_loader, criterion)
        
        if epoch % 1 == 0:
            print('Epoch:', epoch)
            print(f'train loss: {train_loss:.3f}')
            print(f'train AUROC: {train_auroc:.3f}')
            print(f'valid loss: {valid_loss:.3f}')
            print(f'valid AUROC: {valid_auroc:.3f}')
        
        end_time = time.time()

        epoch_mins, epoch_secs = epoch_time(start_time, end_time)

        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss  
            print('BEST VALID LOSS:', best_valid_loss)

        if valid_auroc > best_valid_auroc:
            best_valid_auroc = valid_auroc
            print('BEST VALID AUROC:', best_valid_auroc)
            print('UPDATED BEST INTERMEDIATE MODEL')
            torch.save(model.state_dict(), f'intermediate_set_transformer_mortality.pt')

    # -----------------------------
    # Evaluate best model on test set
    # -----------------------------

    model.load_state_dict(torch.load(f'intermediate_set_transformer_mortality.pt'))

    test_loss, test_auroc, test_predictions, test_labels_out, test_pma_final = evaluate(model, test_loader, criterion)
    test_loss_8, test_auroc_8, test_predictions_8, test_labels_out_8, test_pma_final_8 = evaluate(model, test_loader_8, criterion)
    test_loss_45, test_auroc_45, test_predictions_45, test_labels_out_45, test_pma_final_45 = evaluate(model, test_loader_45, criterion)

    # del for memmory 
    del train_embeddings
    del val_embeddings
    del train_labels
    del val_labels
    del train_padding_mask
    del val_padding_mask
    del train_loader
    del val_loader
    del train_predictions
    del train_labels_out
    del train_pma_final
    del valid_predictions
    del valid_labels_out
    del valid_pma_final
    gc.collect()

    print(f'test predictions: {test_predictions.mean().item():.3f}') 
    print(f'test loss: {test_loss:.3f}')
    print(f'test AUROC: {test_auroc:.3f}')

    print(f'test predictions 8: {test_predictions_8.mean().item():.3f}') 
    print(f'test loss 8: {test_loss_8:.3f}')
    print(f'test AUROC 8: {test_auroc_8:.3f}')

    print(f'test predictions 45: {test_predictions_45.mean().item():.3f}') 
    print(f'test loss 45: {test_loss_45:.3f}')
    print(f'test AUROC 45: {test_auroc_45:.3f}')

    if test_auroc > best_test_auroc:
        best_test_auroc = test_auroc
        print('BEST TEST AUROC:', best_test_auroc)
        print('UPDATED BEST MODEL')
        torch.save(model.state_dict(), f'set_transformer_random_mortality.pt')

    # Get results
    # AUC
    aucroc = roc_auc_score(test_labels_out, test_predictions)
    # Accuracy
    accuracy = accuracy_score(test_labels_out, test_predictions.round())
    # Precision
    precision = precision_score(test_labels_out, test_predictions.round())
    # Recall
    recall = recall_score(test_labels_out, test_predictions.round())
    # AUPRC
    auprc = average_precision_score(test_labels_out, test_predictions)
    # F1
    f1 = f1_score(test_labels_out, test_predictions.round())
    # TPR and FPR
    cm = confusion_matrix(test_labels_out, test_predictions.round())
    _tp = cm[0, 0]
    _fn = cm[0, 1]
    _fp = cm[1, 0]
    _tn = cm[1, 1]
    tpr = _tp / (_tp + _fn)
    fpr = _fp / (_tn + _fp)

    new_row = {'data': 'set', 'fold': fold, 'subset':'overall', 'AUROC': aucroc, 'accuracy': accuracy, 'precision': precision, 'recall': recall, 'F1': f1, 'AUPRC': auprc, 'TPR': tpr, 'FPR': fpr}
    results_df = results_df.append(new_row, ignore_index=True)

    # Get results - rare
    # AUC
    try:
        aucroc = roc_auc_score(test_labels_out_8, test_predictions_8)
    except:
        aucroc = np.nan
    # Accuracy
    accuracy = accuracy_score(test_labels_out_8, test_predictions_8.round())
    # Precision
    precision = precision_score(test_labels_out_8, test_predictions_8.round())
    # Recall
    recall = recall_score(test_labels_out_8, test_predictions_8.round())
    # AUPRC
    try:
        auprc = average_precision_score(test_labels_out_8, test_predictions_8)
    except:
        auprc = np.nan
    # F1
    f1 = f1_score(test_labels_out_8, test_predictions_8.round())
    # TPR and FPR
    cm = confusion_matrix(test_labels_out_8, test_predictions_8.round())
    try:
        _tp = cm[0, 0]
    except:
        _tp = np.nan
    try:
        _fn = cm[0, 1]
    except:
        _fn = np.nan
    try:
        _fp = cm[1, 0]
    except:
        _fp = np.nan
    try:
        _tn = cm[1, 1]
    except:
        _tn = np.nan
    tpr = _tp / (_tp + _fn)
    fpr = _fp / (_tn + _fp)

    new_row = {'data': 'set', 'fold': fold, 'subset': 8, 'AUROC': aucroc, 'accuracy': accuracy, 'precision': precision, 'recall': recall, 'F1': f1, 'AUPRC': auprc, 'TPR': tpr, 'FPR': fpr}
    results_df = results_df.append(new_row, ignore_index=True)

    # Get results
    # AUC
    try:
        aucroc = roc_auc_score(test_labels_out_45, test_predictions_45)
    except:
        aucroc = np.nan
    # Accuracy
    accuracy = accuracy_score(test_labels_out_45, test_predictions_45.round())
    # Precision
    precision = precision_score(test_labels_out_45, test_predictions_45.round())
    # Recall
    recall = recall_score(test_labels_out_45, test_predictions_45.round())
    # AUPRC
    try:
        auprc = average_precision_score(test_labels_out_45, test_predictions_45)
    except:
        auprc = np.nan
    # F1
    f1 = f1_score(test_labels_out_45, test_predictions_45.round())
    # TPR and FPR
    cm = confusion_matrix(test_labels_out_45, test_predictions_45.round())
    try:
        _tp = cm[0, 0]
    except:
        _tp = np.nan
    try:
        _fn = cm[0, 1]
    except:
        _fn = np.nan
    try:
        _fp = cm[1, 0]
    except:
        _fp = np.nan
    try:
        _tn = cm[1, 1]
    except:
        _tn = np.nan
    tpr = _tp / (_tp + _fn)
    fpr = _fp / (_tn + _fp)

    new_row = {'data': 'set', 'fold': fold, 'subset': 45, 'AUROC': aucroc, 'accuracy': accuracy, 'precision': precision, 'recall': recall, 'F1': f1, 'AUPRC': auprc, 'TPR': tpr, 'FPR': fpr}
    results_df = results_df.append(new_row, ignore_index=True)


100%|██████████| 89/89 [00:30<00:00,  2.93it/s]
100%|██████████| 12/12 [00:01<00:00,  7.10it/s]
100%|██████████| 89/89 [00:29<00:00,  2.97it/s]
100%|██████████| 12/12 [00:01<00:00,  7.16it/s]
100%|██████████| 89/89 [00:30<00:00,  2.91it/s]
100%|██████████| 12/12 [00:01<00:00,  6.89it/s]
100%|██████████| 89/89 [00:30<00:00,  2.94it/s]
100%|██████████| 12/12 [00:01<00:00,  6.95it/s]
100%|██████████| 89/89 [00:31<00:00,  2.85it/s]
100%|██████████| 12/12 [00:01<00:00,  7.04it/s]
100%|██████████| 89/89 [00:30<00:00,  2.95it/s]
100%|██████████| 12/12 [00:01<00:00,  6.92it/s]
100%|██████████| 89/89 [00:30<00:00,  2.96it/s]
100%|██████████| 12/12 [00:01<00:00,  7.04it/s]
100%|██████████| 89/89 [00:29<00:00,  2.98it/s]
100%|██████████| 12/12 [00:01<00:00,  7.07it/s]
100%|██████████| 89/89 [00:29<00:00,  2.97it/s]
100%|██████████| 12/12 [00:01<00:00,  7.12it/s]
100%|██████████| 89/89 [00:30<00:00,  2.97it/s]
100%|██████████| 12/12 [00:01<00:00,  7.13it/s]


Epoch: 0
train loss: 1.211
train AUROC: 0.669
valid loss: 1.200
valid AUROC: 0.719
BEST VALID LOSS: 1.2002125680446625
BEST VALID AUROC: 0.7188088475400198
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 1.171
train AUROC: 0.705
valid loss: 1.186
valid AUROC: 0.730
BEST VALID LOSS: 1.1864332457383473
BEST VALID AUROC: 0.7300597716011199
UPDATED BEST INTERMEDIATE MODEL
Epoch: 2
train loss: 1.137
train AUROC: 0.730
valid loss: 1.189
valid AUROC: 0.739
BEST VALID AUROC: 0.7386023927989535
UPDATED BEST INTERMEDIATE MODEL
Epoch: 3
train loss: 1.126
train AUROC: 0.739
valid loss: 1.169
valid AUROC: 0.749
BEST VALID LOSS: 1.1690785586833954
BEST VALID AUROC: 0.7492731288438523
UPDATED BEST INTERMEDIATE MODEL
Epoch: 4
train loss: 1.131
train AUROC: 0.736
valid loss: 1.141
valid AUROC: 0.754
BEST VALID LOSS: 1.1411865254243214
BEST VALID AUROC: 0.7543673446624126
UPDATED BEST INTERMEDIATE MODEL
Epoch: 5
train loss: 1.106
train AUROC: 0.750
valid loss: 1.131
valid AUROC: 0.757
BEST VALID LO

<All keys matched successfully>

100%|██████████| 12/12 [00:01<00:00,  7.09it/s]
100%|██████████| 1/1 [00:00<00:00, 41.30it/s]
100%|██████████| 1/1 [00:00<00:00,  8.59it/s]


0

test predictions: 0.426
test loss: 1.170
test AUROC: 0.745
test predictions 8: 0.451
test loss 8: 2.562
test AUROC 8: 0.204
test predictions 45: 0.409
test loss 45: 1.385
test AUROC 45: 0.488
BEST TEST AUROC: 0.7452061218229797
UPDATED BEST MODEL
Epoch: 0
train loss: 1.227
train AUROC: 0.646
valid loss: 1.196
valid AUROC: 0.695
BEST VALID LOSS: 1.196043421824773
BEST VALID AUROC: 0.6952057087731245
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 1.174
train AUROC: 0.702
valid loss: 1.165
valid AUROC: 0.716
BEST VALID LOSS: 1.1654473890860875
BEST VALID AUROC: 0.7155410850073772
UPDATED BEST INTERMEDIATE MODEL
Epoch: 2
train loss: 1.148
train AUROC: 0.723
valid loss: 1.144
valid AUROC: 0.727
BEST VALID LOSS: 1.144082595904668
BEST VALID AUROC: 0.7272077516740438
UPDATED BEST INTERMEDIATE MODEL
Epoch: 3
train loss: 1.119
train AUROC: 0.741
valid loss: 1.140
valid AUROC: 0.729
BEST VALID LOSS: 1.1398323824008305
BEST VALID AUROC: 0.7294969356486212
UPDATED BEST INTERMEDIATE MODEL
Epo

  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
100%|██████████| 89/89 [00:30<00:00,  2.94it/s]
100%|██████████| 12/12 [00:01<00:00,  6.73it/s]
100%|██████████| 89/89 [00:31<00:00,  2.87it/s]
100%|██████████| 12/12 [00:01<00:00,  6.99it/s]
100%|██████████| 89/89 [00:30<00:00,  2.93it/s]
100%|██████████| 12/12 [00:01<00:00,  7.03it/s]
100%|██████████| 89/89 [00:30<00:00,  2.92it/s]
100%|██████████| 12/12 [00:01<00:00,  7.10it/s]
100%|██████████| 89/89 [00:29<00:00,  3.00it/s]
100%|██████████| 12/12 [00:01<00:00,  7.08it/s]
100%|██████████| 89/89 [00:29<00:00,  2.97it/s]
100%|██████████| 12/12 [00:01<00:00,  7.15it/s]
100%|██████████| 89/89 [00:29<00:00,  2.98it/s]
100%|██████████| 12/12 [00:01<00:00,  7.07it/s]
100%|██████████| 89/89 [00:30<00:00,  2.92it/s]
100%|██████████| 12/12 [00:01<00:00,  7.02it/s]
100%|██████████| 89/89 [00:30<00:00,  2.96it/s]
1

<All keys matched successfully>

100%|██████████| 12/12 [00:01<00:00,  7.06it/s]
100%|██████████| 1/1 [00:00<00:00, 33.49it/s]
100%|██████████| 1/1 [00:00<00:00,  8.46it/s]


0

test predictions: 0.383
test loss: 1.219
test AUROC: 0.755
test predictions 8: 0.361
test loss 8: 0.934
test AUROC 8: 0.776
test predictions 45: 0.362
test loss 45: 1.025
test AUROC 45: 0.698
BEST TEST AUROC: 0.7551179753572326
UPDATED BEST MODEL
Epoch: 0
train loss: 1.205
train AUROC: 0.672
valid loss: 1.284
valid AUROC: 0.685
BEST VALID LOSS: 1.283976713816325
BEST VALID AUROC: 0.6848487685847237
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 1.155
train AUROC: 0.718
valid loss: 1.223
valid AUROC: 0.706
BEST VALID LOSS: 1.2234223534663518
BEST VALID AUROC: 0.7062038928611963
UPDATED BEST INTERMEDIATE MODEL
Epoch: 2
train loss: 1.125
train AUROC: 0.738
valid loss: 1.237
valid AUROC: 0.709
BEST VALID AUROC: 0.7087050278061514
UPDATED BEST INTERMEDIATE MODEL
Epoch: 3
train loss: 1.113
train AUROC: 0.746
valid loss: 1.242
valid AUROC: 0.711
BEST VALID AUROC: 0.7114087504256044
UPDATED BEST INTERMEDIATE MODEL
Epoch: 4
train loss: 1.111
train AUROC: 0.747
valid loss: 1.205
valid AURO

  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
100%|██████████| 89/89 [00:30<00:00,  2.94it/s]
100%|██████████| 12/12 [00:01<00:00,  7.06it/s]
100%|██████████| 89/89 [00:29<00:00,  2.97it/s]
100%|██████████| 12/12 [00:01<00:00,  7.09it/s]
100%|██████████| 89/89 [00:30<00:00,  2.96it/s]
100%|██████████| 12/12 [00:01<00:00,  7.17it/s]
100%|██████████| 89/89 [00:29<00:00,  2.98it/s]
100%|██████████| 12/12 [00:01<00:00,  6.99it/s]
100%|██████████| 89/89 [00:29<00:00,  2.97it/s]
100%|██████████| 12/12 [00:01<00:00,  7.08it/s]
100%|██████████| 89/89 [00:29<00:00,  2.99it/s]
100%|██████████| 12/12 [00:01<00:00,  6.97it/s]
100%|██████████| 89/89 [00:30<00:00,  2.95it/s]
100%|██████████| 12/12 [00:01<00:00,  7.06it/s]
100%|██████████| 89/89 [00:29<00:00,  2.99it/s]
100%|██████████| 12/12 [00:01<00:00,  7.12it/s]
100%|██████████| 89/89 [00:29<00:00,  2.98it/s]
1

<All keys matched successfully>

100%|██████████| 12/12 [00:01<00:00,  7.01it/s]
100%|██████████| 1/1 [00:00<00:00, 42.77it/s]
100%|██████████| 1/1 [00:00<00:00, 10.32it/s]


0

  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
100%|██████████| 89/89 [00:29<00:00,  2.97it/s]
100%|██████████| 12/12 [00:01<00:00,  7.12it/s]
100%|██████████| 89/89 [00:29<00:00,  3.00it/s]
100%|██████████| 12/12 [00:01<00:00,  6.91it/s]
100%|██████████| 89/89 [00:29<00:00,  3.00it/s]
100%|██████████| 12/12 [00:01<00:00,  7.09it/s]
100%|██████████| 89/89 [00:29<00:00,  2.97it/s]
100%|██████████| 12/12 [00:01<00:00,  7.09it/s]
100%|██████████| 89/89 [00:29<00:00,  2.97it/s]
100%|██████████| 12/12 [00:01<00:00,  7.16it/s]
100%|██████████| 89/89 [00:29<00:00,  2.98it/s]
100%|██████████| 12/12 [00:01<00:00,  7.09it/s]
100%|██████████| 89/89 [00:29<00:00,  2.99it/s]
100%|██████████| 12/12 [00:01<00:00,  7.18it/s]
100%|██████████| 89/89 [00:29<00:00,  2.98it/s]
100%|██████████| 12/12 [00:01<00:00,  7.13it/s]
100%|██████████| 89/89 [00:30<00:00,  2.96it/s]
1

test predictions: 0.497
test loss: 1.238
test AUROC: 0.706
test predictions 8: 0.592
test loss 8: 1.699
test AUROC 8: 0.384
test predictions 45: 0.600
test loss 45: 1.770
test AUROC 45: 0.465
Epoch: 0
train loss: 1.206
train AUROC: 0.676
valid loss: 1.111
valid AUROC: 0.721
BEST VALID LOSS: 1.1110415508349736
BEST VALID AUROC: 0.721474860969243
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 1.158
train AUROC: 0.713
valid loss: 1.101
valid AUROC: 0.729
BEST VALID LOSS: 1.100803608695666
BEST VALID AUROC: 0.7288576779026218
UPDATED BEST INTERMEDIATE MODEL
Epoch: 2
train loss: 1.132
train AUROC: 0.736
valid loss: 1.084
valid AUROC: 0.740
BEST VALID LOSS: 1.083978494008382
BEST VALID AUROC: 0.7401254114175462
UPDATED BEST INTERMEDIATE MODEL
Epoch: 3
train loss: 1.113
train AUROC: 0.745
valid loss: 1.075
valid AUROC: 0.742
BEST VALID LOSS: 1.074554483095805
BEST VALID AUROC: 0.7420837589376915
UPDATED BEST INTERMEDIATE MODEL
Epoch: 4
train loss: 1.092
train AUROC: 0.759
valid loss: 1.

<All keys matched successfully>

  0%|          | 0/12 [00:00<?, ?it/s]  8%|▊         | 1/12 [00:00<00:01,  6.59it/s] 17%|█▋        | 2/12 [00:00<00:01,  6.36it/s] 25%|██▌       | 3/12 [00:00<00:01,  6.47it/s] 33%|███▎      | 4/12 [00:00<00:01,  6.53it/s] 42%|████▏     | 5/12 [00:00<00:01,  6.54it/s] 50%|█████     | 6/12 [00:00<00:00,  6.53it/s] 58%|█████▊    | 7/12 [00:01<00:00,  6.55it/s] 67%|██████▋   | 8/12 [00:01<00:00,  6.53it/s] 75%|███████▌  | 9/12 [00:01<00:00,  6.48it/s] 83%|████████▎ | 10/12 [00:01<00:00,  6.53it/s] 92%|█████████▏| 11/12 [00:01<00:00,  6.56it/s]100%|██████████| 12/12 [00:01<00:00,  7.05it/s]
  0%|          | 0/1 [00:00<?, ?it/s]100%|██████████| 1/1 [00:00<00:00, 39.96it/s]
  0%|          | 0/1 [00:00<?, ?it/s]100%|██████████| 1/1 [00:00<00:00,  8.79it/s]100%|██████████| 1/1 [00:00<00:00,  8.70it/s]


0

  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
100%|██████████| 89/89 [00:29<00:00,  3.00it/s]
100%|██████████| 12/12 [00:01<00:00,  7.18it/s]
100%|██████████| 89/89 [00:29<00:00,  3.00it/s]
100%|██████████| 12/12 [00:01<00:00,  7.16it/s]
100%|██████████| 89/89 [00:30<00:00,  2.96it/s]
100%|██████████| 12/12 [00:01<00:00,  7.09it/s]
100%|██████████| 89/89 [00:29<00:00,  2.98it/s]
100%|██████████| 12/12 [00:01<00:00,  7.13it/s]
100%|██████████| 89/89 [00:29<00:00,  3.01it/s]
100%|██████████| 12/12 [00:01<00:00,  7.02it/s]
100%|██████████| 89/89 [00:29<00:00,  3.00it/s]
100%|██████████| 12/12 [00:01<00:00,  7.18it/s]
100%|██████████| 89/89 [00:29<00:00,  3.00it/s]
100%|██████████| 12/12 [00:01<00:00,  7.04it/s]
100%|██████████| 89/89 [00:29<00:00,  2.97it/s]
100%|██████████| 12/12 [00:01<00:00,  7.18it/s]
100%|██████████| 89/89 [00:29<00:00,  3.00it/s]
1

test predictions: 0.473
test loss: 1.180
test AUROC: 0.723
test predictions 8: 0.457
test loss 8: 1.169
test AUROC 8: 0.481
test predictions 45: 0.435
test loss 45: 1.069
test AUROC 45: 0.548
Epoch: 0
train loss: 1.210
train AUROC: 0.675
valid loss: 1.307
valid AUROC: 0.703
BEST VALID LOSS: 1.3067410935958226
BEST VALID AUROC: 0.7031490841615153
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 1.164
train AUROC: 0.718
valid loss: 1.274
valid AUROC: 0.719
BEST VALID LOSS: 1.2735552787780762
BEST VALID AUROC: 0.7189130038819258
UPDATED BEST INTERMEDIATE MODEL
Epoch: 2
train loss: 1.130
train AUROC: 0.737
valid loss: 1.201
valid AUROC: 0.732
BEST VALID LOSS: 1.2010574887196224
BEST VALID AUROC: 0.7316374099245245
UPDATED BEST INTERMEDIATE MODEL
Epoch: 3
train loss: 1.126
train AUROC: 0.740
valid loss: 1.207
valid AUROC: 0.734
BEST VALID AUROC: 0.7338345115717815
UPDATED BEST INTERMEDIATE MODEL
Epoch: 4
train loss: 1.107
train AUROC: 0.750
valid loss: 1.207
valid AUROC: 0.741
BEST VALI

<All keys matched successfully>

  0%|          | 0/12 [00:00<?, ?it/s]  8%|▊         | 1/12 [00:00<00:01,  6.53it/s] 17%|█▋        | 2/12 [00:00<00:01,  6.60it/s] 25%|██▌       | 3/12 [00:00<00:01,  6.64it/s] 33%|███▎      | 4/12 [00:00<00:01,  6.60it/s] 42%|████▏     | 5/12 [00:00<00:01,  6.60it/s] 50%|█████     | 6/12 [00:00<00:00,  6.59it/s] 58%|█████▊    | 7/12 [00:01<00:00,  6.64it/s] 67%|██████▋   | 8/12 [00:01<00:00,  6.64it/s] 75%|███████▌  | 9/12 [00:01<00:00,  6.67it/s] 83%|████████▎ | 10/12 [00:01<00:00,  6.67it/s] 92%|█████████▏| 11/12 [00:01<00:00,  6.62it/s]100%|██████████| 12/12 [00:01<00:00,  7.15it/s]
  0%|          | 0/1 [00:00<?, ?it/s]100%|██████████| 1/1 [00:00<00:00, 37.74it/s]
  0%|          | 0/1 [00:00<?, ?it/s]100%|██████████| 1/1 [00:00<00:00,  8.94it/s]100%|██████████| 1/1 [00:00<00:00,  8.82it/s]


0

  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
100%|██████████| 89/89 [00:29<00:00,  3.00it/s]
100%|██████████| 12/12 [00:01<00:00,  7.14it/s]
100%|██████████| 89/89 [00:29<00:00,  3.01it/s]
100%|██████████| 12/12 [00:01<00:00,  7.19it/s]
100%|██████████| 89/89 [00:30<00:00,  2.94it/s]
100%|██████████| 12/12 [00:01<00:00,  7.03it/s]
100%|██████████| 89/89 [00:30<00:00,  2.96it/s]
100%|██████████| 12/12 [00:01<00:00,  7.06it/s]
100%|██████████| 89/89 [00:30<00:00,  2.96it/s]
100%|██████████| 12/12 [00:01<00:00,  7.10it/s]
100%|██████████| 89/89 [00:29<00:00,  2.97it/s]
100%|██████████| 12/12 [00:01<00:00,  7.03it/s]
100%|██████████| 89/89 [00:30<00:00,  2.94it/s]
100%|██████████| 12/12 [00:01<00:00,  6.95it/s]
100%|██████████| 89/89 [00:30<00:00,  2.88it/s]
100%|██████████| 12/12 [00:01<00:00,  7.01it/s]
100%|██████████| 89/89 [00:30<00:00,  2.91it/s]
1

test predictions: 0.547
test loss: 1.245
test AUROC: 0.706
test predictions 8: 0.621
test loss 8: 1.714
test AUROC 8: 0.309
test predictions 45: 0.583
test loss 45: 1.427
test AUROC 45: 0.543
Epoch: 0
train loss: 1.212
train AUROC: 0.668
valid loss: 1.188
valid AUROC: 0.698
BEST VALID LOSS: 1.1876536160707474
BEST VALID AUROC: 0.698197067495418
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 1.162
train AUROC: 0.720
valid loss: 1.169
valid AUROC: 0.717
BEST VALID LOSS: 1.1686742355426152
BEST VALID AUROC: 0.7168397026513211
UPDATED BEST INTERMEDIATE MODEL
Epoch: 2
train loss: 1.137
train AUROC: 0.734
valid loss: 1.171
valid AUROC: 0.716
Epoch: 3
train loss: 1.138
train AUROC: 0.730
valid loss: 1.160
valid AUROC: 0.718
BEST VALID LOSS: 1.160011500120163
BEST VALID AUROC: 0.718288879022802
UPDATED BEST INTERMEDIATE MODEL
Epoch: 4
train loss: 1.130
train AUROC: 0.734
valid loss: 1.167
valid AUROC: 0.726
BEST VALID AUROC: 0.725690151748005
UPDATED BEST INTERMEDIATE MODEL
Epoch: 5
trai

<All keys matched successfully>

  0%|          | 0/12 [00:00<?, ?it/s]  8%|▊         | 1/12 [00:00<00:01,  6.51it/s] 17%|█▋        | 2/12 [00:00<00:01,  6.33it/s] 25%|██▌       | 3/12 [00:00<00:01,  6.34it/s] 33%|███▎      | 4/12 [00:00<00:01,  6.39it/s] 42%|████▏     | 5/12 [00:00<00:01,  5.93it/s] 50%|█████     | 6/12 [00:00<00:01,  5.86it/s] 58%|█████▊    | 7/12 [00:01<00:00,  6.02it/s] 67%|██████▋   | 8/12 [00:01<00:00,  6.06it/s] 75%|███████▌  | 9/12 [00:01<00:00,  6.17it/s] 83%|████████▎ | 10/12 [00:01<00:00,  6.24it/s] 92%|█████████▏| 11/12 [00:01<00:00,  6.24it/s]100%|██████████| 12/12 [00:01<00:00,  6.66it/s]
  0%|          | 0/1 [00:00<?, ?it/s]100%|██████████| 1/1 [00:00<00:00, 34.65it/s]
  0%|          | 0/1 [00:00<?, ?it/s]100%|██████████| 1/1 [00:00<00:00,  8.92it/s]100%|██████████| 1/1 [00:00<00:00,  8.82it/s]


0

  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
100%|██████████| 89/89 [00:31<00:00,  2.81it/s]
100%|██████████| 12/12 [00:01<00:00,  6.99it/s]
100%|██████████| 89/89 [00:30<00:00,  2.92it/s]
100%|██████████| 12/12 [00:01<00:00,  6.98it/s]
100%|██████████| 89/89 [00:30<00:00,  2.93it/s]
100%|██████████| 12/12 [00:01<00:00,  6.95it/s]
100%|██████████| 89/89 [00:30<00:00,  2.92it/s]
100%|██████████| 12/12 [00:01<00:00,  7.02it/s]
100%|██████████| 89/89 [00:30<00:00,  2.96it/s]
100%|██████████| 12/12 [00:01<00:00,  7.07it/s]
100%|██████████| 89/89 [00:30<00:00,  2.92it/s]
100%|██████████| 12/12 [00:01<00:00,  7.10it/s]
100%|██████████| 89/89 [00:30<00:00,  2.96it/s]
100%|██████████| 12/12 [00:01<00:00,  6.98it/s]
100%|██████████| 89/89 [00:29<00:00,  2.98it/s]
100%|██████████| 12/12 [00:01<00:00,  7.05it/s]
100%|██████████| 89/89 [00:29<00:00,  2.98it/s]
1

test predictions: 0.341
test loss: 1.319
test AUROC: 0.749
test predictions 8: 0.285
test loss 8: 0.784
test AUROC 8: 0.383
test predictions 45: 0.256
test loss 45: 0.974
test AUROC 45: 0.569
Epoch: 0
train loss: 1.208
train AUROC: 0.672
valid loss: 1.129
valid AUROC: 0.701
BEST VALID LOSS: 1.1291442414124806
BEST VALID AUROC: 0.701134125656
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 1.144
train AUROC: 0.724
valid loss: 1.094
valid AUROC: 0.725
BEST VALID LOSS: 1.0940777485569317
BEST VALID AUROC: 0.7254212060152773
UPDATED BEST INTERMEDIATE MODEL
Epoch: 2
train loss: 1.122
train AUROC: 0.739
valid loss: 1.095
valid AUROC: 0.738
BEST VALID AUROC: 0.7380881231287639
UPDATED BEST INTERMEDIATE MODEL
Epoch: 3
train loss: 1.108
train AUROC: 0.749
valid loss: 1.110
valid AUROC: 0.731
Epoch: 4
train loss: 1.143
train AUROC: 0.731
valid loss: 1.082
valid AUROC: 0.746
BEST VALID LOSS: 1.0823510537544887
BEST VALID AUROC: 0.7455241168903612
UPDATED BEST INTERMEDIATE MODEL
Epoch: 5
trai

<All keys matched successfully>

100%|██████████| 12/12 [00:01<00:00,  7.02it/s]
100%|██████████| 1/1 [00:00<00:00, 44.57it/s]
100%|██████████| 1/1 [00:00<00:00, 10.08it/s]


0

  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
100%|██████████| 89/89 [00:30<00:00,  2.95it/s]
100%|██████████| 12/12 [00:01<00:00,  6.99it/s]
100%|██████████| 89/89 [00:29<00:00,  3.00it/s]
100%|██████████| 12/12 [00:01<00:00,  6.44it/s]
100%|██████████| 89/89 [00:32<00:00,  2.75it/s]
100%|██████████| 12/12 [00:01<00:00,  6.87it/s]
100%|██████████| 89/89 [00:31<00:00,  2.81it/s]
100%|██████████| 12/12 [00:01<00:00,  7.15it/s]
100%|██████████| 89/89 [00:30<00:00,  2.96it/s]
100%|██████████| 12/12 [00:01<00:00,  7.15it/s]
100%|██████████| 89/89 [00:29<00:00,  3.00it/s]
100%|██████████| 12/12 [00:01<00:00,  7.18it/s]
100%|██████████| 89/89 [00:29<00:00,  2.98it/s]
100%|██████████| 12/12 [00:01<00:00,  7.14it/s]
100%|██████████| 89/89 [00:29<00:00,  2.99it/s]
100%|██████████| 12/12 [00:01<00:00,  7.20it/s]
100%|██████████| 89/89 [00:29<00:00,  2.99it/s]
1

test predictions: 0.295
test loss: 1.259
test AUROC: 0.742
test predictions 8: 0.213
test loss 8: 1.664
test AUROC 8: 0.461
test predictions 45: 0.218
test loss 45: 1.307
test AUROC 45: 0.600
Epoch: 0
train loss: 1.211
train AUROC: 0.667
valid loss: 1.175
valid AUROC: 0.710
BEST VALID LOSS: 1.1753203372160594
BEST VALID AUROC: 0.7100204341837142
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 1.156
train AUROC: 0.716
valid loss: 1.164
valid AUROC: 0.731
BEST VALID LOSS: 1.1635485241810481
BEST VALID AUROC: 0.7311954282071423
UPDATED BEST INTERMEDIATE MODEL
Epoch: 2
train loss: 1.124
train AUROC: 0.738
valid loss: 1.149
valid AUROC: 0.741
BEST VALID LOSS: 1.1489960551261902
BEST VALID AUROC: 0.74117061120408
UPDATED BEST INTERMEDIATE MODEL
Epoch: 3
train loss: 1.114
train AUROC: 0.744
valid loss: 1.140
valid AUROC: 0.744
BEST VALID LOSS: 1.140350451072057
BEST VALID AUROC: 0.7442567990619628
UPDATED BEST INTERMEDIATE MODEL
Epoch: 4
train loss: 1.110
train AUROC: 0.746
valid loss: 1

<All keys matched successfully>

100%|██████████| 12/12 [00:01<00:00,  7.06it/s]
100%|██████████| 1/1 [00:00<00:00, 36.81it/s]
100%|██████████| 1/1 [00:00<00:00,  8.56it/s]


0

  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
100%|██████████| 89/89 [00:29<00:00,  2.99it/s]
100%|██████████| 12/12 [00:01<00:00,  7.11it/s]
100%|██████████| 89/89 [00:29<00:00,  3.00it/s]
100%|██████████| 12/12 [00:01<00:00,  7.01it/s]
100%|██████████| 89/89 [00:29<00:00,  2.98it/s]
100%|██████████| 12/12 [00:01<00:00,  7.08it/s]
100%|██████████| 89/89 [00:29<00:00,  3.00it/s]
100%|██████████| 12/12 [00:01<00:00,  7.16it/s]
100%|██████████| 89/89 [00:29<00:00,  3.00it/s]
100%|██████████| 12/12 [00:01<00:00,  7.05it/s]
100%|██████████| 89/89 [00:29<00:00,  3.00it/s]
100%|██████████| 12/12 [00:01<00:00,  7.07it/s]
100%|██████████| 89/89 [00:29<00:00,  3.00it/s]
100%|██████████| 12/12 [00:01<00:00,  7.19it/s]
100%|██████████| 89/89 [00:29<00:00,  3.00it/s]
100%|██████████| 12/12 [00:01<00:00,  7.03it/s]
100%|██████████| 89/89 [00:29<00:00,  2.99it/s]
1

test predictions: 0.253
test loss: 1.186
test AUROC: 0.749
test predictions 8: 0.189
test loss 8: 0.383
test AUROC 8: 0.944
test predictions 45: 0.187
test loss 45: 0.923
test AUROC 45: 0.677
Epoch: 0
train loss: 1.212
train AUROC: 0.676
valid loss: 1.172
valid AUROC: 0.705
BEST VALID LOSS: 1.1723968386650085
BEST VALID AUROC: 0.7045939915531119
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 1.165
train AUROC: 0.714
valid loss: 1.157
valid AUROC: 0.724
BEST VALID LOSS: 1.1573457767566044
BEST VALID AUROC: 0.7240135810480061
UPDATED BEST INTERMEDIATE MODEL
Epoch: 2
train loss: 1.138
train AUROC: 0.729
valid loss: 1.151
valid AUROC: 0.737
BEST VALID LOSS: 1.1511441469192505
BEST VALID AUROC: 0.736618740252496
UPDATED BEST INTERMEDIATE MODEL
Epoch: 3
train loss: 1.109
train AUROC: 0.749
valid loss: 1.154
valid AUROC: 0.736
Epoch: 4
train loss: 1.118
train AUROC: 0.742
valid loss: 1.149
valid AUROC: 0.739
BEST VALID LOSS: 1.1486188471317291
BEST VALID AUROC: 0.7385841785913505
UPDATE

<All keys matched successfully>

  0%|          | 0/12 [00:00<?, ?it/s]  8%|▊         | 1/12 [00:00<00:01,  6.68it/s] 17%|█▋        | 2/12 [00:00<00:01,  6.62it/s] 25%|██▌       | 3/12 [00:00<00:01,  6.67it/s] 33%|███▎      | 4/12 [00:00<00:01,  6.67it/s] 42%|████▏     | 5/12 [00:00<00:01,  6.61it/s] 50%|█████     | 6/12 [00:00<00:00,  6.63it/s] 58%|█████▊    | 7/12 [00:01<00:00,  6.65it/s] 67%|██████▋   | 8/12 [00:01<00:00,  6.52it/s] 75%|███████▌  | 9/12 [00:01<00:00,  6.48it/s] 83%|████████▎ | 10/12 [00:01<00:00,  6.50it/s] 92%|█████████▏| 11/12 [00:01<00:00,  6.52it/s]100%|██████████| 12/12 [00:01<00:00,  7.10it/s]
  0%|          | 0/1 [00:00<?, ?it/s]100%|██████████| 1/1 [00:00<00:00, 35.59it/s]
  0%|          | 0/1 [00:00<?, ?it/s]100%|██████████| 1/1 [00:00<00:00,  8.81it/s]100%|██████████| 1/1 [00:00<00:00,  8.72it/s]


0

  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
100%|██████████| 89/89 [00:29<00:00,  3.01it/s]
100%|██████████| 12/12 [00:01<00:00,  7.16it/s]
100%|██████████| 89/89 [00:29<00:00,  3.00it/s]
100%|██████████| 12/12 [00:01<00:00,  7.21it/s]
100%|██████████| 89/89 [00:29<00:00,  3.00it/s]
100%|██████████| 12/12 [00:01<00:00,  7.16it/s]
100%|██████████| 89/89 [00:29<00:00,  2.98it/s]
100%|██████████| 12/12 [00:01<00:00,  7.16it/s]
100%|██████████| 89/89 [00:29<00:00,  3.02it/s]
100%|██████████| 12/12 [00:01<00:00,  7.23it/s]
100%|██████████| 89/89 [00:29<00:00,  3.00it/s]
100%|██████████| 12/12 [00:01<00:00,  7.21it/s]
100%|██████████| 89/89 [00:29<00:00,  3.02it/s]
100%|██████████| 12/12 [00:01<00:00,  7.19it/s]
100%|██████████| 89/89 [00:30<00:00,  2.95it/s]
100%|██████████| 12/12 [00:01<00:00,  7.05it/s]
100%|██████████| 89/89 [00:29<00:00,  3.01it/s]
1

test predictions: 0.392
test loss: 1.197
test AUROC: 0.743
test predictions 8: 0.274
test loss 8: 2.036
test AUROC 8: 0.468
test predictions 45: 0.277
test loss 45: 1.330
test AUROC 45: 0.582
Epoch: 0
train loss: 1.213
train AUROC: 0.668
valid loss: 1.210
valid AUROC: 0.695
BEST VALID LOSS: 1.20969125131766
BEST VALID AUROC: 0.6947229353245439
UPDATED BEST INTERMEDIATE MODEL
Epoch: 1
train loss: 1.162
train AUROC: 0.712
valid loss: 1.162
valid AUROC: 0.722
BEST VALID LOSS: 1.1619826455911
BEST VALID AUROC: 0.7215256867644859
UPDATED BEST INTERMEDIATE MODEL
Epoch: 2
train loss: 1.131
train AUROC: 0.737
valid loss: 1.135
valid AUROC: 0.736
BEST VALID LOSS: 1.134886662165324
BEST VALID AUROC: 0.7360517280725067
UPDATED BEST INTERMEDIATE MODEL
Epoch: 3
train loss: 1.118
train AUROC: 0.744
valid loss: 1.133
valid AUROC: 0.744
BEST VALID LOSS: 1.1331208844979603
BEST VALID AUROC: 0.7442341148277163
UPDATED BEST INTERMEDIATE MODEL
Epoch: 4
train loss: 1.120
train AUROC: 0.743
valid loss: 1.11

<All keys matched successfully>

100%|██████████| 12/12 [00:01<00:00,  7.15it/s]
100%|██████████| 1/1 [00:00<00:00, 45.40it/s]
100%|██████████| 1/1 [00:00<00:00,  9.07it/s]


0

  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)
  results_df = results_df.append(new_row, ignore_index=True)


test predictions: 0.389
test loss: 1.224
test AUROC: 0.741
test predictions 8: 0.464
test loss 8: 1.056
test AUROC 8: 0.757
test predictions 45: 0.369
test loss 45: 1.395
test AUROC 45: 0.569


In [21]:
results_df

Unnamed: 0,data,fold,subset,AUROC,accuracy,precision,recall,F1,AUPRC,TPR,FPR
0,set,1,overall,0.745206,0.598765,0.105434,0.790274,0.186047,0.133456,0.586969,0.209726
1,set,1,8,0.204225,0.527027,0.0,0.0,0.0,0.032988,0.549296,1.0
2,set,1,45,0.488339,0.615776,0.027027,0.363636,0.050314,0.11916,0.623037,0.636364
3,set,2,overall,0.755118,0.647972,0.116782,0.772036,0.202875,0.143531,0.64033,0.227964
4,set,2,8,0.77619,0.69863,0.086957,0.666667,0.153846,0.119744,0.7,0.333333
5,set,2,45,0.698062,0.703046,0.066667,0.615385,0.120301,0.081267,0.706037,0.384615
6,set,3,overall,0.705928,0.489418,0.085354,0.8,0.154251,0.122334,0.470225,0.2
7,set,3,8,0.38447,0.342857,0.043478,0.5,0.08,0.054819,0.333333,0.5
8,set,3,45,0.465427,0.373802,0.064677,0.619048,0.117117,0.105033,0.356164,0.380952
9,set,4,overall,0.722723,0.559083,0.095149,0.772727,0.169435,0.14008,0.54588,0.227273


In [22]:
results_df.groupby(['data', 'subset']).mean()
results_df.groupby(['data', 'subset']).std()

Unnamed: 0_level_0,Unnamed: 1_level_0,fold,AUROC,accuracy,precision,recall,F1,AUPRC,TPR,FPR
data,subset,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
set,8,5.5,0.516795,0.638347,0.057155,0.436667,0.098353,0.114314,0.645232,0.563333
set,45,5.5,0.573995,0.691051,0.049497,0.380783,0.084323,0.073626,0.701001,0.619217
set,overall,5.5,0.735816,0.632163,0.110014,0.695001,0.186081,0.135336,0.628273,0.304999


Unnamed: 0_level_0,Unnamed: 1_level_0,fold,AUROC,accuracy,precision,recall,F1,AUPRC,TPR,FPR
data,subset,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
set,8,3.02765,0.233651,0.218914,0.059931,0.392035,0.102035,0.124107,0.231067,0.392035
set,45,3.02765,0.072782,0.189061,0.025875,0.267785,0.043078,0.030454,0.206499,0.267785
set,overall,3.02765,0.017869,0.128342,0.020121,0.154091,0.022738,0.009659,0.145021,0.154091
