In [2]:
# Libraries
import numpy as np
import pandas as pd
pd.set_option('display.max_columns', 40)
pd.set_option('display.width', 2000)
import math
import time
import random
import random

import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
from torch import optim
from tqdm import tqdm
import torch.nn.functional as F

import pickle
from sklearn import metrics
from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import balanced_accuracy_score
from sklearn.metrics import recall_score
from sklearn.metrics import precision_score
from sklearn.metrics import f1_score
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import confusion_matrix
from sklearn.metrics import average_precision_score

from sklearn.model_selection import StratifiedKFold, train_test_split

import torch.optim as optim

from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

# Remove printing error
pd.options.mode.chained_assignment = None

In [3]:
# Set the random seeds for deterministic results.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

# Set device
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU is available")
else:
    device = torch.device("cpu")
    print("GPU not available, CPU used")

<torch._C.Generator at 0x7fda3bedf930>

GPU not available, CPU used


In [4]:
# Import
path = r'switch_data/chronic_switch_icare_df_preprocessed_2023.csv'
icare_df_preprocessed = pd.read_csv(path)

In [5]:
icare_df_preprocessed.SPELL_IDENTIFIER.nunique()

547

# Main run

In [6]:
def set_transformer_processing_fun(patient_df, snomed_embedding):
    # Str
    patient_df.columns = patient_df.columns.astype(str)
    snomed_embedding['snomed_code'] = snomed_embedding['snomed_code'].astype(str)
    # Filter
    snomed_embedding = snomed_embedding[snomed_embedding['snomed_code'].isin(patient_df.columns.tolist())]
    snomed_embedding.set_index('snomed_code', inplace=True)
    # Get lengths of each patients co-morbidities
    comorbidity_len = np.array(patient_df.sum(axis=1))
    # Add padding embedding 
    padding_df = pd.DataFrame(np.random.choice([0], size=len(snomed_embedding.columns))) # Changed to 0
    padding_df = padding_df.T
    padding_df.index = ['9999999999']
    padding_df.columns = snomed_embedding.columns
    snomed_embedding2 = pd.concat([snomed_embedding, padding_df])
    snomed_embedding2.index = snomed_embedding2.index.astype(str)
    # Get max number of co-morbidities
    max_len = 22 # Define for same for all splits (train val etr)
    # Format patients embeddings into set and pad / create array
    feature_array = np.zeros(shape=(len(patient_df), max_len , 128))
    n = -1
    for index, row in patient_df.iterrows():
        n += 1
        n2 = -1
        code_list = row[row ==1].index.tolist()
        while len(code_list) < max_len:
            code_list.append('9999999999')
        for code in code_list:
            n2 += 1
            feature_array[n, n2] = np.array(snomed_embedding2.loc[code])
    
    # Create mask tensor based on lengths
    comorbidity_len2 = torch.as_tensor(comorbidity_len, dtype=torch.long)
    mask = torch.arange(max_len)[None, :] < comorbidity_len2[:, None]

    return feature_array, mask

In [7]:
# Define how long an epoch takes
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

# Initializing the weights of our model.
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform(m.weight)
        m.bias.data.fill_(0.01)

# Calculate the number of trainable parameters in the model.
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [8]:
columns_to_drop = [
 'Diastolic Blood Pressure2',
 'Diastolic Blood Pressure3',
 'Diastolic Blood Pressure4',
 'Diastolic Blood Pressure5',
 'Diastolic Blood Pressure6',
 'Diastolic Blood Pressure7',
 'Diastolic Blood Pressure8',
 'Diastolic Blood Pressure9',
 'Diastolic Blood Pressure11',
 'Diastolic Blood Pressure12',
 'Diastolic Blood Pressure13',
 'Diastolic Blood Pressure14',
 'Diastolic Blood Pressure15',
 'Diastolic Blood Pressure16',
 'Diastolic Blood Pressure18',
 'Diastolic Blood Pressure19',
 'Diastolic Blood Pressure20',
 'Diastolic Blood Pressure21',
 'Diastolic Blood Pressure2_current_stay',
 'Diastolic Blood Pressure3_current_stay',
 'Diastolic Blood Pressure4_current_stay',
 'Diastolic Blood Pressure5_current_stay',
 'Diastolic Blood Pressure6_current_stay',
 'Diastolic Blood Pressure8_current_stay',
 'Diastolic Blood Pressure12_current_stay',
 'Diastolic Blood Pressure13_current_stay',
 'Diastolic Blood Pressure14_current_stay',
 'Diastolic Blood Pressure16_current_stay',
 'Diastolic Blood Pressure18_current_stay',
 'Diastolic Blood Pressure19_current_stay',
 'Diastolic Blood Pressure20_current_stay',
 'Diastolic Blood Pressure21_current_stay',
 'Glasgow Coma Score0',
 'Glasgow Coma Score1',
 'Glasgow Coma Score2',
 'Glasgow Coma Score3',
 'Glasgow Coma Score4',
 'Glasgow Coma Score5',
 'Glasgow Coma Score6',
 'Glasgow Coma Score7',
 'Glasgow Coma Score8',
 'Glasgow Coma Score9',
 'Glasgow Coma Score10',
 'Glasgow Coma Score11',
 'Glasgow Coma Score12',
 'Glasgow Coma Score13',
 'Glasgow Coma Score14',
 'Glasgow Coma Score15',
 'Glasgow Coma Score16',
 'Glasgow Coma Score17',
 'Glasgow Coma Score18',
 'Glasgow Coma Score19',
 'Glasgow Coma Score20',
 'Glasgow Coma Score21',
 'Glasgow Coma Score0_current_stay',
 'Glasgow Coma Score1_current_stay',
 'Glasgow Coma Score2_current_stay',
 'Glasgow Coma Score3_current_stay',
 'Glasgow Coma Score4_current_stay',
 'Glasgow Coma Score5_current_stay',
 'Glasgow Coma Score6_current_stay',
 'Glasgow Coma Score7_current_stay',
 'Glasgow Coma Score8_current_stay',
 'Glasgow Coma Score10_current_stay',
 'Glasgow Coma Score11_current_stay',
 'Glasgow Coma Score12_current_stay',
 'Glasgow Coma Score13_current_stay',
 'Glasgow Coma Score14_current_stay',
 'Glasgow Coma Score15_current_stay',
 'Glasgow Coma Score16_current_stay',
 'Glasgow Coma Score17_current_stay',
 'Glasgow Coma Score18_current_stay',
 'Glasgow Coma Score19_current_stay',
 'Glasgow Coma Score20_current_stay',
 'Glasgow Coma Score21_current_stay',
 'Heart Rate2',
 'Heart Rate3',
 'Heart Rate4',
 'Heart Rate5',
 'Heart Rate6',
 'Heart Rate7',
 'Heart Rate8',
 'Heart Rate9',
 'Heart Rate11',
 'Heart Rate12',
 'Heart Rate13',
 'Heart Rate14',
 'Heart Rate15',
 'Heart Rate16',
 'Heart Rate18',
 'Heart Rate19',
 'Heart Rate20',
 'Heart Rate21',
 'Heart Rate2_current_stay',
 'Heart Rate3_current_stay',
 'Heart Rate4_current_stay',
 'Heart Rate5_current_stay',
 'Heart Rate6_current_stay',
 'Heart Rate8_current_stay',
 'Heart Rate12_current_stay',
 'Heart Rate13_current_stay',
 'Heart Rate14_current_stay',
 'Heart Rate16_current_stay',
 'Heart Rate18_current_stay',
 'Heart Rate19_current_stay',
 'Heart Rate20_current_stay',
 'Heart Rate21_current_stay',
 'Mean Arterial Pressure2',
 'Mean Arterial Pressure3',
 'Mean Arterial Pressure4',
 'Mean Arterial Pressure5',
 'Mean Arterial Pressure6',
 'Mean Arterial Pressure7',
 'Mean Arterial Pressure8',
 'Mean Arterial Pressure9',
 'Mean Arterial Pressure11',
 'Mean Arterial Pressure12',
 'Mean Arterial Pressure13',
 'Mean Arterial Pressure14',
 'Mean Arterial Pressure15',
 'Mean Arterial Pressure16',
 'Mean Arterial Pressure18',
 'Mean Arterial Pressure19',
 'Mean Arterial Pressure20',
 'Mean Arterial Pressure21',
 'Mean Arterial Pressure2_current_stay',
 'Mean Arterial Pressure3_current_stay',
 'Mean Arterial Pressure4_current_stay',
 'Mean Arterial Pressure5_current_stay',
 'Mean Arterial Pressure6_current_stay',
 'Mean Arterial Pressure8_current_stay',
 'Mean Arterial Pressure12_current_stay',
 'Mean Arterial Pressure13_current_stay',
 'Mean Arterial Pressure14_current_stay',
 'Mean Arterial Pressure16_current_stay',
 'Mean Arterial Pressure18_current_stay',
 'Mean Arterial Pressure19_current_stay',
 'Mean Arterial Pressure20_current_stay',
 'Mean Arterial Pressure21_current_stay',
 'NEWS Conscious Level Score0',
 'NEWS Conscious Level Score1',
 'NEWS Conscious Level Score2',
 'NEWS Conscious Level Score3',
 'NEWS Conscious Level Score4',
 'NEWS Conscious Level Score5',
 'NEWS Conscious Level Score6',
 'NEWS Conscious Level Score7',
 'NEWS Conscious Level Score8',
 'NEWS Conscious Level Score9',
 'NEWS Conscious Level Score10',
 'NEWS Conscious Level Score11',
 'NEWS Conscious Level Score12',
 'NEWS Conscious Level Score13',
 'NEWS Conscious Level Score14',
 'NEWS Conscious Level Score15',
 'NEWS Conscious Level Score16',
 'NEWS Conscious Level Score17',
 'NEWS Conscious Level Score18',
 'NEWS Conscious Level Score19',
 'NEWS Conscious Level Score20',
 'NEWS Conscious Level Score21',
 'NEWS Conscious Level Score0_current_stay',
 'NEWS Conscious Level Score1_current_stay',
 'NEWS Conscious Level Score2_current_stay',
 'NEWS Conscious Level Score3_current_stay',
 'NEWS Conscious Level Score4_current_stay',
 'NEWS Conscious Level Score5_current_stay',
 'NEWS Conscious Level Score6_current_stay',
 'NEWS Conscious Level Score7_current_stay',
 'NEWS Conscious Level Score8_current_stay',
 'NEWS Conscious Level Score9_current_stay',
 'NEWS Conscious Level Score10_current_stay',
 'NEWS Conscious Level Score11_current_stay',
 'NEWS Conscious Level Score12_current_stay',
 'NEWS Conscious Level Score13_current_stay',
 'NEWS Conscious Level Score14_current_stay',
 'NEWS Conscious Level Score15_current_stay',
 'NEWS Conscious Level Score16_current_stay',
 'NEWS Conscious Level Score17_current_stay',
 'NEWS Conscious Level Score18_current_stay',
 'NEWS Conscious Level Score19_current_stay',
 'NEWS Conscious Level Score20_current_stay',
 'NEWS Conscious Level Score21_current_stay',
 'NEWS Supplemental Oxygen Calc0',
 'NEWS Supplemental Oxygen Calc1',
 'NEWS Supplemental Oxygen Calc2',
 'NEWS Supplemental Oxygen Calc3',
 'NEWS Supplemental Oxygen Calc4',
 'NEWS Supplemental Oxygen Calc5',
 'NEWS Supplemental Oxygen Calc6',
 'NEWS Supplemental Oxygen Calc7',
 'NEWS Supplemental Oxygen Calc8',
 'NEWS Supplemental Oxygen Calc9',
 'NEWS Supplemental Oxygen Calc10',
 'NEWS Supplemental Oxygen Calc11',
 'NEWS Supplemental Oxygen Calc12',
 'NEWS Supplemental Oxygen Calc13',
 'NEWS Supplemental Oxygen Calc14',
 'NEWS Supplemental Oxygen Calc15',
 'NEWS Supplemental Oxygen Calc16',
 'NEWS Supplemental Oxygen Calc17',
 'NEWS Supplemental Oxygen Calc18',
 'NEWS Supplemental Oxygen Calc19',
 'NEWS Supplemental Oxygen Calc20',
 'NEWS Supplemental Oxygen Calc21',
 'NEWS Supplemental Oxygen Calc0_current_stay',
 'NEWS Supplemental Oxygen Calc1_current_stay',
 'NEWS Supplemental Oxygen Calc2_current_stay',
 'NEWS Supplemental Oxygen Calc3_current_stay',
 'NEWS Supplemental Oxygen Calc4_current_stay',
 'NEWS Supplemental Oxygen Calc5_current_stay',
 'NEWS Supplemental Oxygen Calc6_current_stay',
 'NEWS Supplemental Oxygen Calc7_current_stay',
 'NEWS Supplemental Oxygen Calc8_current_stay',
 'NEWS Supplemental Oxygen Calc9_current_stay',
 'NEWS Supplemental Oxygen Calc10_current_stay',
 'NEWS Supplemental Oxygen Calc11_current_stay',
 'NEWS Supplemental Oxygen Calc12_current_stay',
 'NEWS Supplemental Oxygen Calc13_current_stay',
 'NEWS Supplemental Oxygen Calc14_current_stay',
 'NEWS Supplemental Oxygen Calc15_current_stay',
 'NEWS Supplemental Oxygen Calc16_current_stay',
 'NEWS Supplemental Oxygen Calc17_current_stay',
 'NEWS Supplemental Oxygen Calc18_current_stay',
 'NEWS Supplemental Oxygen Calc19_current_stay',
 'NEWS Supplemental Oxygen Calc20_current_stay',
 'NEWS Supplemental Oxygen Calc21_current_stay',
 'Respiratory Rate0',
 'Respiratory Rate2',
 'Respiratory Rate3',
 'Respiratory Rate4',
 'Respiratory Rate5',
 'Respiratory Rate6',
 'Respiratory Rate7',
 'Respiratory Rate8',
 'Respiratory Rate9',
 'Respiratory Rate11',
 'Respiratory Rate12',
 'Respiratory Rate13',
 'Respiratory Rate14',
 'Respiratory Rate15',
 'Respiratory Rate16',
 'Respiratory Rate18',
 'Respiratory Rate19',
 'Respiratory Rate20',
 'Respiratory Rate21',
 'Respiratory Rate0_current_stay',
 'Respiratory Rate2_current_stay',
 'Respiratory Rate3_current_stay',
 'Respiratory Rate4_current_stay',
 'Respiratory Rate5_current_stay',
 'Respiratory Rate6_current_stay',
 'Respiratory Rate8_current_stay',
 'Respiratory Rate12_current_stay',
 'Respiratory Rate13_current_stay',
 'Respiratory Rate14_current_stay',
 'Respiratory Rate16_current_stay',
 'Respiratory Rate18_current_stay',
 'Respiratory Rate19_current_stay',
 'Respiratory Rate20_current_stay',
 'Respiratory Rate21_current_stay',
 'SpO20',
 'SpO22',
 'SpO23',
 'SpO24',
 'SpO25',
 'SpO26',
 'SpO27',
 'SpO28',
 'SpO29',
 'SpO211',
 'SpO212',
 'SpO213',
 'SpO214',
 'SpO215',
 'SpO216',
 'SpO218',
 'SpO219',
 'SpO220',
 'SpO221',
 'SpO20_current_stay',
 'SpO22_current_stay',
 'SpO23_current_stay',
 'SpO24_current_stay',
 'SpO25_current_stay',
 'SpO26_current_stay',
 'SpO28_current_stay',
 'SpO212_current_stay',
 'SpO213_current_stay',
 'SpO214_current_stay',
 'SpO216_current_stay',
 'SpO218_current_stay',
 'SpO219_current_stay',
 'SpO220_current_stay',
 'SpO221_current_stay',
 'Systolic Blood Pressure2',
 'Systolic Blood Pressure3',
 'Systolic Blood Pressure4',
 'Systolic Blood Pressure5',
 'Systolic Blood Pressure6',
 'Systolic Blood Pressure7',
 'Systolic Blood Pressure8',
 'Systolic Blood Pressure9',
 'Systolic Blood Pressure11',
 'Systolic Blood Pressure12',
 'Systolic Blood Pressure13',
 'Systolic Blood Pressure14',
 'Systolic Blood Pressure15',
 'Systolic Blood Pressure16',
 'Systolic Blood Pressure18',
 'Systolic Blood Pressure19',
 'Systolic Blood Pressure20',
 'Systolic Blood Pressure21',
 'Systolic Blood Pressure2_current_stay',
 'Systolic Blood Pressure3_current_stay',
 'Systolic Blood Pressure4_current_stay',
 'Systolic Blood Pressure5_current_stay',
 'Systolic Blood Pressure6_current_stay',
 'Systolic Blood Pressure8_current_stay',
 'Systolic Blood Pressure12_current_stay',
 'Systolic Blood Pressure13_current_stay',
 'Systolic Blood Pressure14_current_stay',
 'Systolic Blood Pressure16_current_stay',
 'Systolic Blood Pressure18_current_stay',
 'Systolic Blood Pressure19_current_stay',
 'Systolic Blood Pressure20_current_stay',
 'Systolic Blood Pressure21_current_stay',
 'Temperature0',
 'Temperature2',
 'Temperature3',
 'Temperature4',
 'Temperature5',
 'Temperature6',
 'Temperature7',
 'Temperature8',
 'Temperature9',
 'Temperature11',
 'Temperature12',
 'Temperature13',
 'Temperature14',
 'Temperature15',
 'Temperature16',
 'Temperature18',
 'Temperature19',
 'Temperature20',
 'Temperature21',
 'Temperature0_current_stay',
 'Temperature2_current_stay',
 'Temperature3_current_stay',
 'Temperature4_current_stay',
 'Temperature5_current_stay',
 'Temperature6_current_stay',
 'Temperature8_current_stay',
 'Temperature12_current_stay',
 'Temperature13_current_stay',
 'Temperature14_current_stay',
 'Temperature16_current_stay',
 'Temperature18_current_stay',
 'Temperature19_current_stay',
 'Temperature20_current_stay',
 'Temperature21_current_stay',
 'Diastolic Blood Pressure2_difference',
 'Diastolic Blood Pressure3_difference',
 'Diastolic Blood Pressure4_difference',
 'Diastolic Blood Pressure5_difference',
 'Diastolic Blood Pressure6_difference',
 'Diastolic Blood Pressure7_difference',
 'Diastolic Blood Pressure8_difference',
 'Diastolic Blood Pressure9_difference',
 'Diastolic Blood Pressure11_difference',
 'Diastolic Blood Pressure12_difference',
 'Diastolic Blood Pressure13_difference',
 'Diastolic Blood Pressure14_difference',
 'Diastolic Blood Pressure15_difference',
 'Diastolic Blood Pressure16_difference',
 'Diastolic Blood Pressure18_difference',
 'Diastolic Blood Pressure19_difference',
 'Diastolic Blood Pressure20_difference',
 'Diastolic Blood Pressure21_difference',
 'Diastolic Blood Pressure2_current_stay_difference',
 'Diastolic Blood Pressure3_current_stay_difference',
 'Diastolic Blood Pressure4_current_stay_difference',
 'Diastolic Blood Pressure5_current_stay_difference',
 'Diastolic Blood Pressure6_current_stay_difference',
 'Diastolic Blood Pressure8_current_stay_difference',
 'Diastolic Blood Pressure12_current_stay_difference',
 'Diastolic Blood Pressure13_current_stay_difference',
 'Diastolic Blood Pressure14_current_stay_difference',
 'Diastolic Blood Pressure16_current_stay_difference',
 'Diastolic Blood Pressure18_current_stay_difference',
 'Diastolic Blood Pressure19_current_stay_difference',
 'Diastolic Blood Pressure21_current_stay_difference',
 'Glasgow Coma Score0_difference',
 'Glasgow Coma Score1_difference',
 'Glasgow Coma Score2_difference',
 'Glasgow Coma Score3_difference',
 'Glasgow Coma Score4_difference',
 'Glasgow Coma Score5_difference',
 'Glasgow Coma Score6_difference',
 'Glasgow Coma Score7_difference',
 'Glasgow Coma Score8_difference',
 'Glasgow Coma Score9_difference',
 'Glasgow Coma Score10_difference',
 'Glasgow Coma Score11_difference',
 'Glasgow Coma Score12_difference',
 'Glasgow Coma Score13_difference',
 'Glasgow Coma Score14_difference',
 'Glasgow Coma Score15_difference',
 'Glasgow Coma Score16_difference',
 'Glasgow Coma Score17_difference',
 'Glasgow Coma Score18_difference',
 'Glasgow Coma Score19_difference',
 'Glasgow Coma Score20_difference',
 'Glasgow Coma Score21_difference',
 'Glasgow Coma Score0_current_stay_difference',
 'Glasgow Coma Score1_current_stay_difference',
 'Glasgow Coma Score2_current_stay_difference',
 'Glasgow Coma Score3_current_stay_difference',
 'Glasgow Coma Score4_current_stay_difference',
 'Glasgow Coma Score5_current_stay_difference',
 'Glasgow Coma Score6_current_stay_difference',
 'Glasgow Coma Score7_current_stay_difference',
 'Glasgow Coma Score8_current_stay_difference',
 'Glasgow Coma Score10_current_stay_difference',
 'Glasgow Coma Score11_current_stay_difference',
 'Glasgow Coma Score12_current_stay_difference',
 'Glasgow Coma Score13_current_stay_difference',
 'Glasgow Coma Score14_current_stay_difference',
 'Glasgow Coma Score15_current_stay_difference',
 'Glasgow Coma Score16_current_stay_difference',
 'Glasgow Coma Score17_current_stay_difference',
 'Glasgow Coma Score18_current_stay_difference',
 'Glasgow Coma Score19_current_stay_difference',
 'Glasgow Coma Score20_current_stay_difference',
 'Glasgow Coma Score21_current_stay_difference',
 'Heart Rate2_difference',
 'Heart Rate3_difference',
 'Heart Rate4_difference',
 'Heart Rate5_difference',
 'Heart Rate6_difference',
 'Heart Rate7_difference',
 'Heart Rate8_difference',
 'Heart Rate9_difference',
 'Heart Rate11_difference',
 'Heart Rate12_difference',
 'Heart Rate13_difference',
 'Heart Rate14_difference',
 'Heart Rate15_difference',
 'Heart Rate16_difference',
 'Heart Rate18_difference',
 'Heart Rate19_difference',
 'Heart Rate20_difference',
 'Heart Rate21_difference',
 'Heart Rate2_current_stay_difference',
 'Heart Rate3_current_stay_difference',
 'Heart Rate4_current_stay_difference',
 'Heart Rate5_current_stay_difference',
 'Heart Rate6_current_stay_difference',
 'Heart Rate8_current_stay_difference',
 'Heart Rate12_current_stay_difference',
 'Heart Rate13_current_stay_difference',
 'Heart Rate14_current_stay_difference',
 'Heart Rate16_current_stay_difference',
 'Heart Rate18_current_stay_difference',
 'Heart Rate19_current_stay_difference',
 'Heart Rate21_current_stay_difference',
 'Mean Arterial Pressure2_difference',
 'Mean Arterial Pressure3_difference',
 'Mean Arterial Pressure4_difference',
 'Mean Arterial Pressure5_difference',
 'Mean Arterial Pressure6_difference',
 'Mean Arterial Pressure7_difference',
 'Mean Arterial Pressure8_difference',
 'Mean Arterial Pressure9_difference',
 'Mean Arterial Pressure11_difference',
 'Mean Arterial Pressure12_difference',
 'Mean Arterial Pressure13_difference',
 'Mean Arterial Pressure14_difference',
 'Mean Arterial Pressure15_difference',
 'Mean Arterial Pressure16_difference',
 'Mean Arterial Pressure18_difference',
 'Mean Arterial Pressure19_difference',
 'Mean Arterial Pressure20_difference',
 'Mean Arterial Pressure21_difference',
 'Mean Arterial Pressure2_current_stay_difference',
 'Mean Arterial Pressure3_current_stay_difference',
 'Mean Arterial Pressure4_current_stay_difference',
 'Mean Arterial Pressure5_current_stay_difference',
 'Mean Arterial Pressure6_current_stay_difference',
 'Mean Arterial Pressure8_current_stay_difference',
 'Mean Arterial Pressure12_current_stay_difference',
 'Mean Arterial Pressure13_current_stay_difference',
 'Mean Arterial Pressure14_current_stay_difference',
 'Mean Arterial Pressure16_current_stay_difference',
 'Mean Arterial Pressure18_current_stay_difference',
 'Mean Arterial Pressure19_current_stay_difference',
 'Mean Arterial Pressure21_current_stay_difference',
 'NEWS Conscious Level Score0_difference',
 'NEWS Conscious Level Score1_difference',
 'NEWS Conscious Level Score2_difference',
 'NEWS Conscious Level Score3_difference',
 'NEWS Conscious Level Score4_difference',
 'NEWS Conscious Level Score5_difference',
 'NEWS Conscious Level Score6_difference',
 'NEWS Conscious Level Score7_difference',
 'NEWS Conscious Level Score8_difference',
 'NEWS Conscious Level Score9_difference',
 'NEWS Conscious Level Score10_difference',
 'NEWS Conscious Level Score11_difference',
 'NEWS Conscious Level Score12_difference',
 'NEWS Conscious Level Score13_difference',
 'NEWS Conscious Level Score14_difference',
 'NEWS Conscious Level Score15_difference',
 'NEWS Conscious Level Score16_difference',
 'NEWS Conscious Level Score17_difference',
 'NEWS Conscious Level Score18_difference',
 'NEWS Conscious Level Score19_difference',
 'NEWS Conscious Level Score20_difference',
 'NEWS Conscious Level Score21_difference',
 'NEWS Conscious Level Score0_current_stay_difference',
 'NEWS Conscious Level Score1_current_stay_difference',
 'NEWS Conscious Level Score2_current_stay_difference',
 'NEWS Conscious Level Score3_current_stay_difference',
 'NEWS Conscious Level Score4_current_stay_difference',
 'NEWS Conscious Level Score5_current_stay_difference',
 'NEWS Conscious Level Score6_current_stay_difference',
 'NEWS Conscious Level Score7_current_stay_difference',
 'NEWS Conscious Level Score8_current_stay_difference',
 'NEWS Conscious Level Score9_current_stay_difference',
 'NEWS Conscious Level Score10_current_stay_difference',
 'NEWS Conscious Level Score11_current_stay_difference',
 'NEWS Conscious Level Score12_current_stay_difference',
 'NEWS Conscious Level Score13_current_stay_difference',
 'NEWS Conscious Level Score14_current_stay_difference',
 'NEWS Conscious Level Score15_current_stay_difference',
 'NEWS Conscious Level Score16_current_stay_difference',
 'NEWS Conscious Level Score17_current_stay_difference',
 'NEWS Conscious Level Score18_current_stay_difference',
 'NEWS Conscious Level Score19_current_stay_difference',
 'NEWS Conscious Level Score20_current_stay_difference',
 'NEWS Conscious Level Score21_current_stay_difference',
 'NEWS Supplemental Oxygen Calc0_difference',
 'NEWS Supplemental Oxygen Calc1_difference',
 'NEWS Supplemental Oxygen Calc2_difference',
 'NEWS Supplemental Oxygen Calc3_difference',
 'NEWS Supplemental Oxygen Calc4_difference',
 'NEWS Supplemental Oxygen Calc5_difference',
 'NEWS Supplemental Oxygen Calc6_difference',
 'NEWS Supplemental Oxygen Calc7_difference',
 'NEWS Supplemental Oxygen Calc8_difference',
 'NEWS Supplemental Oxygen Calc9_difference',
 'NEWS Supplemental Oxygen Calc10_difference',
 'NEWS Supplemental Oxygen Calc11_difference',
 'NEWS Supplemental Oxygen Calc12_difference',
 'NEWS Supplemental Oxygen Calc13_difference',
 'NEWS Supplemental Oxygen Calc14_difference',
 'NEWS Supplemental Oxygen Calc15_difference',
 'NEWS Supplemental Oxygen Calc16_difference',
 'NEWS Supplemental Oxygen Calc17_difference',
 'NEWS Supplemental Oxygen Calc18_difference',
 'NEWS Supplemental Oxygen Calc19_difference',
 'NEWS Supplemental Oxygen Calc20_difference',
 'NEWS Supplemental Oxygen Calc21_difference',
 'NEWS Supplemental Oxygen Calc0_current_stay_difference',
 'NEWS Supplemental Oxygen Calc1_current_stay_difference',
 'NEWS Supplemental Oxygen Calc2_current_stay_difference',
 'NEWS Supplemental Oxygen Calc3_current_stay_difference',
 'NEWS Supplemental Oxygen Calc4_current_stay_difference',
 'NEWS Supplemental Oxygen Calc5_current_stay_difference',
 'NEWS Supplemental Oxygen Calc6_current_stay_difference',
 'NEWS Supplemental Oxygen Calc7_current_stay_difference',
 'NEWS Supplemental Oxygen Calc8_current_stay_difference',
 'NEWS Supplemental Oxygen Calc10_current_stay_difference',
 'NEWS Supplemental Oxygen Calc12_current_stay_difference',
 'NEWS Supplemental Oxygen Calc13_current_stay_difference',
 'NEWS Supplemental Oxygen Calc14_current_stay_difference',
 'NEWS Supplemental Oxygen Calc15_current_stay_difference',
 'NEWS Supplemental Oxygen Calc16_current_stay_difference',
 'NEWS Supplemental Oxygen Calc18_current_stay_difference',
 'NEWS Supplemental Oxygen Calc19_current_stay_difference',
 'NEWS Supplemental Oxygen Calc20_current_stay_difference',
 'NEWS Supplemental Oxygen Calc21_current_stay_difference',
 'Respiratory Rate2_difference',
 'Respiratory Rate3_difference',
 'Respiratory Rate4_difference',
 'Respiratory Rate5_difference',
 'Respiratory Rate6_difference',
 'Respiratory Rate7_difference',
 'Respiratory Rate8_difference',
 'Respiratory Rate9_difference',
 'Respiratory Rate11_difference',
 'Respiratory Rate12_difference',
 'Respiratory Rate13_difference',
 'Respiratory Rate14_difference',
 'Respiratory Rate15_difference',
 'Respiratory Rate16_difference',
 'Respiratory Rate18_difference',
 'Respiratory Rate19_difference',
 'Respiratory Rate20_difference',
 'Respiratory Rate21_difference',
 'Respiratory Rate0_current_stay_difference',
 'Respiratory Rate2_current_stay_difference',
 'Respiratory Rate3_current_stay_difference',
 'Respiratory Rate4_current_stay_difference',
 'Respiratory Rate5_current_stay_difference',
 'Respiratory Rate6_current_stay_difference',
 'Respiratory Rate8_current_stay_difference',
 'Respiratory Rate12_current_stay_difference',
 'Respiratory Rate13_current_stay_difference',
 'Respiratory Rate14_current_stay_difference',
 'Respiratory Rate16_current_stay_difference',
 'Respiratory Rate18_current_stay_difference',
 'Respiratory Rate19_current_stay_difference',
 'Respiratory Rate21_current_stay_difference',
 'SpO22_difference',
 'SpO23_difference',
 'SpO24_difference',
 'SpO25_difference',
 'SpO26_difference',
 'SpO27_difference',
 'SpO28_difference',
 'SpO29_difference',
 'SpO211_difference',
 'SpO212_difference',
 'SpO213_difference',
 'SpO214_difference',
 'SpO215_difference',
 'SpO216_difference',
 'SpO218_difference',
 'SpO219_difference',
 'SpO220_difference',
 'SpO221_difference',
 'SpO22_current_stay_difference',
 'SpO23_current_stay_difference',
 'SpO24_current_stay_difference',
 'SpO25_current_stay_difference',
 'SpO26_current_stay_difference',
 'SpO28_current_stay_difference',
 'SpO212_current_stay_difference',
 'SpO213_current_stay_difference',
 'SpO214_current_stay_difference',
 'SpO216_current_stay_difference',
 'SpO218_current_stay_difference',
 'SpO219_current_stay_difference',
 'SpO221_current_stay_difference',
 'Systolic Blood Pressure2_difference',
 'Systolic Blood Pressure3_difference',
 'Systolic Blood Pressure4_difference',
 'Systolic Blood Pressure5_difference',
 'Systolic Blood Pressure6_difference',
 'Systolic Blood Pressure7_difference',
 'Systolic Blood Pressure8_difference',
 'Systolic Blood Pressure9_difference',
 'Systolic Blood Pressure11_difference',
 'Systolic Blood Pressure12_difference',
 'Systolic Blood Pressure13_difference',
 'Systolic Blood Pressure14_difference',
 'Systolic Blood Pressure15_difference',
 'Systolic Blood Pressure16_difference',
 'Systolic Blood Pressure18_difference',
 'Systolic Blood Pressure19_difference',
 'Systolic Blood Pressure20_difference',
 'Systolic Blood Pressure21_difference',
 'Systolic Blood Pressure2_current_stay_difference',
 'Systolic Blood Pressure3_current_stay_difference',
 'Systolic Blood Pressure4_current_stay_difference',
 'Systolic Blood Pressure5_current_stay_difference',
 'Systolic Blood Pressure6_current_stay_difference',
 'Systolic Blood Pressure8_current_stay_difference',
 'Systolic Blood Pressure12_current_stay_difference',
 'Systolic Blood Pressure13_current_stay_difference',
 'Systolic Blood Pressure14_current_stay_difference',
 'Systolic Blood Pressure16_current_stay_difference',
 'Systolic Blood Pressure18_current_stay_difference',
 'Systolic Blood Pressure19_current_stay_difference',
 'Systolic Blood Pressure21_current_stay_difference',
 'Temperature2_difference',
 'Temperature3_difference',
 'Temperature4_difference',
 'Temperature5_difference',
 'Temperature6_difference',
 'Temperature7_difference',
 'Temperature8_difference',
 'Temperature9_difference',
 'Temperature11_difference',
 'Temperature12_difference',
 'Temperature13_difference',
 'Temperature14_difference',
 'Temperature15_difference',
 'Temperature16_difference',
 'Temperature18_difference',
 'Temperature19_difference',
 'Temperature20_difference',
 'Temperature21_difference',
 'Temperature2_current_stay_difference',
 'Temperature3_current_stay_difference',
 'Temperature4_current_stay_difference',
 'Temperature5_current_stay_difference',
 'Temperature6_current_stay_difference',
 'Temperature8_current_stay_difference',
 'Temperature12_current_stay_difference',
 'Temperature13_current_stay_difference',
 'Temperature14_current_stay_difference',
 'Temperature16_current_stay_difference',
 'Temperature18_current_stay_difference',
 'Temperature19_current_stay_difference',
 'Temperature21_current_stay_difference']

In [9]:
def evaluate(model, dataloader, criterion):

    # Set the model to evaluation mode
    model.eval()

    epoch_loss = 0

    batch_prediction_list = []
    batch_label_list = []

    # use the with torch.no_grad() block to ensure no gradients are calculated within the bloc
    with torch.no_grad():
        for batch_idx, sample in enumerate(tqdm(dataloader)):
            labels = sample["labels"]
            features = sample["features"]
            batch_mask = sample["mask"]
            features = [data.float() for data in features]
            features = [data.to(device=device) for data in features]
            labels = labels.float()
            labels = labels.to(device=device)
            batch_mask = batch_mask.to(device)

            # Run model
            output = model(features, batch_mask)

            # Generate loss
            loss = criterion(output, labels)

            epoch_loss += loss.item()

            # Get predictions for performance calc - masking outputs and labels
            sig = torch.nn.Sigmoid()
            output = sig(output)  
            np_predictions = output.cpu().detach().numpy()
            np_labels = labels.cpu().detach().numpy()

            np_predictions = np_predictions.squeeze()
            np_labels = np_labels.squeeze()

            np_predictions = np_predictions.flatten()
            np_labels = np_labels.flatten()
            
            # Create list
            for x in np_predictions:
                batch_prediction_list.append(x)
            for x in np_labels:
                batch_label_list.append(x)

        final_predictions = np.array(batch_prediction_list)

        final_labels = np.array(batch_label_list)

        try:
            auroc = roc_auc_score(final_labels, final_predictions)
        except:
            auroc = np.nan
        
        try:
            final_loss = epoch_loss / len(dataloader)
        except:
            final_loss = np.nan

        return final_loss, auroc, final_predictions, final_labels

In [10]:
class SetTransformer(nn.Module):
    def __init__(self, dim_input, num_outputs, dim_output,
            num_inds=36, dim_hidden=160, num_heads=4, ln=False):
        super(SetTransformer, self).__init__()
        self.enc = nn.Sequential(
                ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=ln),
                ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln))
        self.isab = ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=ln)
        self.pma = PMA(dim_hidden, num_heads, num_outputs, ln=ln)
        self.dec = nn.Sequential(
                #SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
                #SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
                nn.Linear(dim_hidden, dim_output))

    def forward(self, X, batch_mask):
        x = self.isab(X, batch_mask)
        x = self.pma(x, batch_mask)
        return self.dec(x)#, x

class MAB0(nn.Module):
    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
        super(MAB0, self).__init__()
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)
        if ln:
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)
        self.fc_o = nn.Linear(dim_V, dim_V)

    def forward(self, Q, K, mask):
        Q = self.fc_q(Q)
        K, V = self.fc_k(K), self.fc_v(K)

        dim_split = self.dim_V // self.num_heads
        Q_ = torch.cat(Q.split(dim_split, 2), 0)
        K_ = torch.cat(K.split(dim_split, 2), 0)
        V_ = torch.cat(V.split(dim_split, 2), 0)

        # Create new variable for softmax
        WB_ = Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V)
        # Exspand mask dimensions to align
        mask = mask.unsqueeze(1).repeat(self.num_heads, Q.shape[1], 1)
        # Mask for softmax
        WB_[~mask] = float('-inf')
        
        A = torch.softmax(WB_, 2)
        O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
        O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
        O = O + F.relu(self.fc_o(O))
        O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
        return O

class MAB(nn.Module):
    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
        super(MAB, self).__init__()
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)
        if ln:
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)
        self.fc_o = nn.Linear(dim_V, dim_V)

    def forward(self, Q, K):
        Q = self.fc_q(Q)
        K, V = self.fc_k(K), self.fc_v(K)

        dim_split = self.dim_V // self.num_heads
        Q_ = torch.cat(Q.split(dim_split, 2), 0)
        K_ = torch.cat(K.split(dim_split, 2), 0)
        V_ = torch.cat(V.split(dim_split, 2), 0)

        A = torch.softmax(Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V), 2)
        O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
        O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
        O = O + F.relu(self.fc_o(O))
        O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
        return O

class SAB(nn.Module):
    def __init__(self, dim_in, dim_out, num_heads, ln=False):
        super(SAB, self).__init__()
        self.mab = MAB(dim_in, dim_in, dim_out, num_heads, ln=ln)

    def forward(self, X):
        return self.mab(X, X)

class ISAB(nn.Module):
    def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False):
        super(ISAB, self).__init__()
        self.I = nn.Parameter(torch.Tensor(1, num_inds, dim_out))
        nn.init.xavier_uniform_(self.I)
        self.mab0 = MAB0(dim_out, dim_in, dim_out, num_heads, ln=ln)
        self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln)

    def forward(self, X, mask):
        H = self.mab0(self.I.repeat(X.size(0), 1, 1), X, mask)
        return self.mab1(X, H)

class PMA(nn.Module):
    def __init__(self, dim, num_heads, num_seeds, ln=False):
        super(PMA, self).__init__()
        self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim))
        nn.init.xavier_uniform_(self.S)
        self.mab = MAB0(dim, dim, dim, num_heads, ln=ln)

    def forward(self, X, mask):
        return self.mab(self.S.repeat(X.size(0), 1, 1), X, mask)

In [11]:
class Initial_vitals_model(nn.Module):
    def __init__(self, input_dim, output_dim, hid_dim, dropout):
        super().__init__()

        self.layers = nn.Sequential(
            nn.BatchNorm1d(input_dim),
            nn.Linear(input_dim, hid_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hid_dim, output_dim),
            nn.ReLU(),
            nn.Dropout(dropout))

    def forward(self, x):
        
        x1 = self.layers(x)

        return x1


class Chronic_switch_model(nn.Module):
    def __init__(self, 
    final_input_dim, 
    final_output_dim, 
    final_hid_dim, 
    final_hid_dim2,
    demographics_input_dim,
    demographics_output_dim,
    vital_input_dim, 
    vital_hid_dim, 
    vital_output_dim, 
    dropout):
        super().__init__()

        self.final_layers = nn.Sequential(
            nn.Linear(final_input_dim, final_hid_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(final_hid_dim, final_hid_dim2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(final_hid_dim2, final_output_dim),
            nn.ReLU(),
            nn.Dropout(dropout))

        self.vital_model = Initial_vitals_model(vital_input_dim, vital_output_dim, vital_hid_dim, dropout)

        self.set_transformer = SetTransformer(dim_input=128, num_outputs=1, dim_output=128, num_inds=32, dim_hidden=160, num_heads=4, ln=False)

        # Embedding for demographics (passing feature directly)
        self.demographics = nn.Linear(demographics_input_dim, demographics_output_dim)

    def forward(self, inputs: torch.Tensor, batch_mask: torch.Tensor) -> torch.Tensor:

        # Directly pass demographics feature to embedding
        demographics = self.demographics(inputs[1])

        # Pass other inputs through initial models
        vital_embeddings = self.vital_model(inputs[0])
        comorbidity_embeddings = self.set_transformer(inputs[2], batch_mask)
        comorbidity_embeddings = torch.squeeze(comorbidity_embeddings)

        # Concatenate embeddings
        concatenated_embeddings = torch.cat([demographics, vital_embeddings, comorbidity_embeddings], dim=1)

        # Pass through final layers
        output = self.final_layers(concatenated_embeddings)

        return output
    
    def latent_representation(self, inputs: torch.Tensor, batch_mask: torch.Tensor) -> torch.Tensor:

        # Directly pass demographics feature to embedding
        demographics = self.demographics(inputs[1])

        # Pass other inputs through initial models
        vital_embeddings = self.vital_model(inputs[0])
        comorbidity_embeddings = self.set_transformer(inputs[2], batch_mask)
        comorbidity_embeddings = torch.squeeze(comorbidity_embeddings)

        # Concatenate embeddings
        concatenated_embeddings = torch.cat([demographics, vital_embeddings, comorbidity_embeddings], dim=1)
        
        return concatenated_embeddings

In [12]:
class MultiInputDataset(Dataset):

    def __init__(self, dfs_list, labels, comorbidites, padding_mask):
        self.dfs_list = dfs_list
        self.labels = labels
        self.padding_mask = padding_mask
        self.comorbidites = comorbidites

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        features = [torch.tensor(df.iloc[idx].values) for df in self.dfs_list]
        labels = torch.tensor(self.labels[['po_flag']].iloc[idx].to_numpy())
        features.append(self.comorbidites[idx])
        sample = {"labels": labels, "features": features, "mask":self.padding_mask[idx]}
        return sample

In [13]:
def analyze_results_fun(cv_test_results):
    # Assign results
    test_auroc_results, test_accuracy_results,test_balanced_accuracy_results,test_recall_results,test_precision_results,test_f1_results,test_auprc_results,test_cm_results, test_tpr_results, test_fpr_results = [cv_test_results[i] for i in range(len(cv_test_results))]
    print('mean test_auroc:', np.array(test_auroc_results).mean())
    print('std test_auroc:', np.array(test_auroc_results).std())
    print('test_auroc 2.5th percentile:', max(0, np.percentile(test_auroc_results, 2.5)))
    print('test_auroc 97.5th percentile:', min(1, np.percentile(test_auroc_results, 97.5)))
    print('mean test_accuracy:', np.array(test_accuracy_results).mean())
    print('std test_accuracy:', np.array(test_accuracy_results).std())
    print('test_accuracy 2.5th percentile:', max(0, np.percentile(test_accuracy_results, 2.5)))
    print('test_accuracy 97.5th percentile:', min(1, np.percentile(test_accuracy_results, 97.5)))
    print('mean test_balanced_accuracy:', np.array(test_balanced_accuracy_results).mean())
    print('std test_balanced_accuracy:', np.array(test_balanced_accuracy_results).std())
    print('test_balanced_accuracy 2.5th percentile:', max(0, np.percentile(test_balanced_accuracy_results, 2.5)))
    print('test_balanced_accuracy 97.5th percentile:', min(1, np.percentile(test_balanced_accuracy_results, 97.5)))
    print('mean test_recall:', np.array(test_recall_results).mean())
    print('std test_recall:', np.array(test_recall_results).std())
    print('test_recall 2.5th percentile:', max(0, np.percentile(test_recall_results, 2.5)))
    print('test_recall 97.5th percentile:', min(1, np.percentile(test_recall_results, 97.5)))
    print('mean test_precision:', np.array(test_precision_results).mean())
    print('std test_precision:', np.array(test_precision_results).std())
    print('test_precision 2.5th percentile:', max(0, np.percentile(test_precision_results, 2.5)))
    print('test_precision 97.5th percentile:', min(1, np.percentile(test_precision_results, 97.5)))
    print('mean test_f1:', np.array(test_f1_results).mean())
    print('std test_f1:', np.array(test_f1_results).std())
    print('test_f1 2.5th percentile:', max(0, np.percentile(test_f1_results, 2.5)))
    print('test_f1 97.5th percentile:', min(1, np.percentile(test_f1_results, 97.5)))
    print('mean test_auprc:', np.array(test_auprc_results).mean())
    print('std test_auprc:', np.array(test_auprc_results).std())
    print('test_auprc 2.5th percentile:', max(0, np.percentile(test_auprc_results, 2.5)))
    print('test_auprc 97.5th percentile:', min(1, np.percentile(test_auprc_results, 97.5)))
    print('mean test_tpr:', np.array(test_tpr_results).mean())
    print('std test_tpr:', np.array(test_tpr_results).std())
    print('test_tpr 2.5th percentile:', max(0, np.percentile(test_tpr_results, 2.5)))
    print('test_tpr 97.5th percentile:', min(1, np.percentile(test_tpr_results, 97.5)))
    print('mean test_fpr:', np.array(test_fpr_results).mean())
    print('std test_fpr:', np.array(test_fpr_results).std())
    print('test_fpr 2.5th percentile:', max(0, np.percentile(test_fpr_results, 2.5)))
    print('test_fpr 97.5th percentile:', min(1, np.percentile(test_fpr_results, 97.5)))

In [14]:
# Load model

# Hyperparameters
final_input_dim = 268
final_output_dim = 1
final_hid_dim = 512
final_hid_dim2 = 128
demographics_input_dim = 12
demographics_output_dim = 12
vital_input_dim = 253
vital_hid_dim = 512
vital_output_dim = 128
dropout = 0.1

# Define model
model = Chronic_switch_model(
    final_input_dim, 
    final_output_dim, 
    final_hid_dim, 
    final_hid_dim2,
    demographics_input_dim,
    demographics_output_dim,
    vital_input_dim, 
    vital_hid_dim, 
    vital_output_dim, 
    dropout).to(device)

model.load_state_dict(torch.load('chronic_switch_model.pt'))

print(f'The model has {count_parameters(model):,} trainable parameters')

<All keys matched successfully>

The model has 1,126,583 trainable parameters


In [15]:
# Import
path = r'switch_data/chronic_switch_icare_df_preprocessed_2023.csv'
icare_df_preprocessed = pd.read_csv(path)

# Import
path = r'switch_data/chronic_switch_episodes_2023.csv'
episodes = pd.read_csv(path)

# Import
path = r'switch_data/chronic_switch_disease_2023.csv'
disease = pd.read_csv(path)

# Import
path = r'switch_data/chronic_switch_demographics_2023.csv'
demographics = pd.read_csv(path)

# Import
path = r'switch_data/snomed_embedding_128d-copy.csv'
embedding = pd.read_csv(path)

# Import
path = r'switch_data/chronic_switch_problem_dummies_2023.csv'
problem_dummies = pd.read_csv(path)

In [16]:
# Merge
problem_dummies2 = pd.merge(problem_dummies, episodes[['SUBJECT', 'SPELL_IDENTIFIER']])
problem_dummies2 = pd.merge(icare_df_preprocessed[['SPELL_IDENTIFIER', 'date']], problem_dummies2)

# Strip name 
problem_dummies2.columns = problem_dummies2.columns.str.removeprefix('PROBLEM_')

# Convert the date columns to datetime objects if they are not already
problem_dummies2['date'] = pd.to_datetime(problem_dummies2['date'])
problem_dummies2['DT_TM'] = pd.to_datetime(problem_dummies2['DT_TM'])

# Calculate the absolute time difference between 'date' and 'DT_TM'
problem_dummies2['time_diff'] = (problem_dummies2['date'] - problem_dummies2['DT_TM']).abs()

# Filter rows where 'time_diff' is not negative
problem_dummies2 = problem_dummies2[problem_dummies2['time_diff'] >= pd.Timedelta(0)]

# Sort the DataFrame by 'SPELL_IDENTIFIER' and 'time_diff'
problem_dummies2.sort_values(by=['SPELL_IDENTIFIER', 'time_diff'], inplace=True)

# Convert to str
problem_dummies2['date'] = problem_dummies2['date'].astype(str)

# Keep only the rows with the smallest time difference for each 'SPELL_IDENTIFIER'
problem_dummies2 = problem_dummies2.groupby(['SPELL_IDENTIFIER', 'date']).first().reset_index()

### REASON the problem_dummies2 is shorter is not because it is missing some dates! 
### It is because in the final data we have some dates repeated for a spesfic spell 
### if they were admited in between 6am and 12pm the 12hour prediction is done at 6am 
### of the first day, then 48 done at 6am the next day, then the next prediction done 
### at 12pm that day causing there to be two prediction for that day...phew
### So just get a set of co-morbidities for each spell and merge 

# Drop the 'time_diff' column as it's no longer needed
problem_dummies2.drop(columns=['time_diff', 'SUBJECT', 'DT_TM', 'new_subject', 'date'], inplace=True)

# Drop duplicates
problem_dummies2.drop_duplicates(inplace=True)

# Some still got through by having their co-morbidid diagnosis updated during their stay 
# In this case we just use the frst one throughout and remove the others 
# Drop duplicates
problem_dummies2.drop_duplicates(subset=['SPELL_IDENTIFIER'], keep='first', inplace=True)

# Filter for features
X_data = icare_df_preprocessed.drop(columns=['SPELL_IDENTIFIER', 'po_flag'])
X_data = X_data.drop(columns=columns_to_drop)
model_data = pd.concat([icare_df_preprocessed[['SPELL_IDENTIFIER', 'po_flag']], X_data], axis=1)
# Merge
demographics = pd.merge(demographics, episodes[['SUBJECT', 'SPELL_IDENTIFIER']])
demographics.drop(columns=['SUBJECT'], inplace=True)
model_data = pd.merge(model_data, demographics, how='left')
model_data = pd.merge(model_data, disease, how='left')
# Drop 
model_data = model_data.drop(columns=['date', 'ROUTE', '24_hour_flag', '48_hour_flag', 'iv_treatment_length'])
# fillna
model_data['AGE'] = model_data['AGE'].fillna(-1)
model_data['IMDDECIL'] = model_data['IMDDECIL'].fillna(-1)
model_data = model_data.fillna(0)
# Merge co-morbidites
model_data = pd.merge(model_data, problem_dummies2, how='left')
# Rename
model_data.rename(columns={'SPELL_IDENTIFIER': 'stay_id'}, inplace=True)
# Random shuffle
stays = model_data['stay_id'].unique()
random.Random(5).shuffle(stays)
model_data = model_data.set_index("stay_id").loc[stays].reset_index()

  problem_dummies2['time_diff'] = (problem_dummies2['date'] - problem_dummies2['DT_TM']).abs()


In [17]:
# Split up dfs
vitals_data = model_data.iloc[:,2:255]
demographics_data = model_data.iloc[:,255:267]
comorbidity_data = model_data.iloc[:, 267:]

# Get labels
labels = model_data[['po_flag']]

# Preprocess comorbidity data
print('Working on set_transformer_processing_fun...')
comorbidity_data, comorbidity_mask = set_transformer_processing_fun(comorbidity_data, embedding)
print('Done!')

# Loss
criterion = nn.BCEWithLogitsLoss()

# Define dataloaders
batch_size=512
dataset =  MultiInputDataset([vitals_data, demographics_data], labels, comorbidity_data, comorbidity_mask)
dataloader = DataLoader(dataset=dataset, batch_size=batch_size)

Working on set_transformer_processing_fun...
Done!


In [18]:
def final_threshold_fun(predictions, bound=0.7373848): # Set threshold
    new_predictions = [1 if a_ >= bound else 0 for a_ in predictions]
    return new_predictions

In [21]:
loss, auroc, test_predictions, labels_out = evaluate(model, dataloader, criterion)

new_predictions = final_threshold_fun(test_predictions)

print('AUROC result:', auroc)

# Lower bound
print('new_predictions:')
print('AUROC:', roc_auc_score(labels_out, new_predictions))
print('Accuracy:', accuracy_score(labels_out, new_predictions))
print('balanced_accuracy_score:', balanced_accuracy_score(labels_out, new_predictions))
print('Recall:', recall_score(labels_out, new_predictions))
print('Precision:', precision_score(labels_out, new_predictions))
print('F1:', f1_score(labels_out, new_predictions))
print('AUPRC:', average_precision_score(labels_out, new_predictions))
test_cm = confusion_matrix(labels_out, new_predictions)
print('CM:', test_cm)
tn, fp, fn, tp = test_cm.ravel()
test_true_positive_rate = (tp / (tp + fn))
print('TPR:', test_true_positive_rate)
test_false_positive_rate = (fp / (fp + tn))
print('FPR:', test_false_positive_rate)

100%|██████████| 6/6 [00:03<00:00,  1.82it/s]


AUROC result: 0.772141651875208
new_predictions:
AUROC: 0.7032650887567642
Accuracy: 0.6985401459854015
balanced_accuracy_score: 0.7032650887567643
Recall: 0.6094510076441974
Precision: 0.768624014022787
F1: 0.6798449612403101
AUPRC: 0.6735481688968101
CM: [[1037  264]
 [ 562  877]]
TPR: 0.6094510076441974
FPR: 0.20292083013066872


# Early, late, agree

In [19]:
loss, auroc, test_predictions, labels_out = evaluate(model, dataloader, criterion)

new_predictions = final_threshold_fun(test_predictions)

print('AUROC result:', auroc)

100%|██████████| 6/6 [00:04<00:00,  1.26it/s]


AUROC result: 0.772141651875208


In [26]:
# Merge
problem_dummies2 = pd.merge(problem_dummies, episodes[['SUBJECT', 'SPELL_IDENTIFIER']])
problem_dummies2 = pd.merge(icare_df_preprocessed[['SPELL_IDENTIFIER', 'date']], problem_dummies2)

# Strip name 
problem_dummies2.columns = problem_dummies2.columns.str.removeprefix('PROBLEM_')

# Convert the date columns to datetime objects if they are not already
problem_dummies2['date'] = pd.to_datetime(problem_dummies2['date'])
problem_dummies2['DT_TM'] = pd.to_datetime(problem_dummies2['DT_TM'])

# Calculate the absolute time difference between 'date' and 'DT_TM'
problem_dummies2['time_diff'] = (problem_dummies2['date'] - problem_dummies2['DT_TM']).abs()

# Filter rows where 'time_diff' is not negative
problem_dummies2 = problem_dummies2[problem_dummies2['time_diff'] >= pd.Timedelta(0)]

# Sort the DataFrame by 'SPELL_IDENTIFIER' and 'time_diff'
problem_dummies2.sort_values(by=['SPELL_IDENTIFIER', 'time_diff'], inplace=True)

# Convert to str
problem_dummies2['date'] = problem_dummies2['date'].astype(str)

# Keep only the rows with the smallest time difference for each 'SPELL_IDENTIFIER'
problem_dummies2 = problem_dummies2.groupby(['SPELL_IDENTIFIER', 'date']).first().reset_index()

### REASON the problem_dummies2 is shorter is not because it is missing some dates! 
### It is because in the final data we have some dates repeated for a spesfic spell 
### if they were admited in between 6am and 12pm the 12hour prediction is done at 6am 
### of the first day, then 48 done at 6am the next day, then the next prediction done 
### at 12pm that day causing there to be two prediction for that day...phew
### So just get a set of co-morbidities for each spell and merge 

# Drop the 'time_diff' column as it's no longer needed
problem_dummies2.drop(columns=['time_diff', 'SUBJECT', 'DT_TM', 'new_subject', 'date'], inplace=True)

# Drop duplicates
problem_dummies2.drop_duplicates(inplace=True)

# Some still got through by having their co-morbidid diagnosis updated during their stay 
# In this case we just use the frst one throughout and remove the others 
# Drop duplicates
problem_dummies2.drop_duplicates(subset=['SPELL_IDENTIFIER'], keep='first', inplace=True)

# Filter for features
X_data = icare_df_preprocessed.drop(columns=['SPELL_IDENTIFIER', 'po_flag'])
X_data = X_data.drop(columns=columns_to_drop)
model_data = pd.concat([icare_df_preprocessed[['SPELL_IDENTIFIER', 'po_flag']], X_data], axis=1)
# Merge
demographics = pd.merge(demographics, episodes[['SUBJECT', 'SPELL_IDENTIFIER']])
demographics.drop(columns=['SUBJECT'], inplace=True)
model_data = pd.merge(model_data, demographics, how='left')
model_data = pd.merge(model_data, disease, how='left')
# Drop 
model_data = model_data.drop(columns=['ROUTE', '24_hour_flag', '48_hour_flag'])
# fillna
model_data['AGE'] = model_data['AGE'].fillna(-1)
model_data['IMDDECIL'] = model_data['IMDDECIL'].fillna(-1)
model_data = model_data.fillna(0)
# Merge co-morbidites
model_data = pd.merge(model_data, problem_dummies2, how='left')
# Rename
model_data.rename(columns={'SPELL_IDENTIFIER': 'stay_id'}, inplace=True)
# Random shuffle
stays = model_data['stay_id'].unique()
random.Random(5).shuffle(stays)
model_data = model_data.set_index("stay_id").loc[stays].reset_index()

  problem_dummies2['time_diff'] = (problem_dummies2['date'] - problem_dummies2['DT_TM']).abs()


In [21]:
# For only having one positive switch day per stay
def lb_predicted_switch_day_fun(data):
    # Convert to datetime
    data['date'] = pd.to_datetime(data['date'])

    # iv_treatment_length
    cumcount = []
    count = 0
    pos = -1
    flag = 0

    for x in range(len(data)):
        pos += 1
        if pos == len(data) - 1:
            cumcount.append(count) # add count to last one
            break # end
        elif pos == 0:
            cumcount.append(count) # add 0 to first one
            count += 1
        elif data.iloc[x]['stay_id'] == data.iloc[x+1]['stay_id']:
            if data.iloc[x]['lb_prediction'] == 0:
                cumcount.append(count)
                count += 1
            elif flag == 1:
                cumcount.append(999)
                count = 0
                flag = 1
            elif data.iloc[x]['stay_id'] != data.iloc[x-1]['stay_id']:
                if data.iloc[x]['lb_prediction'] == 1:
                    cumcount.append(count)
                    count += 1
                else:
                    cumcount.append(999)
                    count = 0
            else:
                cumcount.append(count)
                count = 0
                flag = 1
        else:
            if data.iloc[x]['lb_prediction'] == 0:
                cumcount.append(count)
                count = 0
                flag = 0
            elif flag == 1:
                cumcount.append(999)
                count = 0
                flag = 0
            else:
                cumcount.append(count)
                count = 0
                flag = 0

    print(len(cumcount))

    data['lb_predicted_switch_day'] = cumcount
    
    return data

In [None]:
# Filter for those who switch
test_stay_id_list = (model_data.groupby(['stay_id'])['po_flag'].nunique() > 1).where(lambda x : x==True).dropna().reset_index()['stay_id'].unique().tolist()
filtered_test_data = model_data[model_data['stay_id'].isin(test_stay_id_list)]


# Find the day they actually switched
test_switch_day = filtered_test_data[filtered_test_data['po_flag'] == 1].drop_duplicates(subset=['stay_id'], keep='first')
test_switch_day = test_switch_day[['stay_id', 'iv_treatment_length']]
test_switch_day.rename(columns={'iv_treatment_length': 'real_switch_day'}, inplace=True)
test_switch_day.reset_index(drop=True, inplace=True)

# Find day we predict they could switch
filtered_test_data.reset_index(inplace=True, drop=True)
filtered_test_data2 = filtered_test_data.drop(columns=['date', 'iv_treatment_length'])

# Get predictions
vitals_test_data = filtered_test_data2.iloc[:,2:255]
demographics_test_data = filtered_test_data2.iloc[:,255:267]
comorbidity_test_data = filtered_test_data2.iloc[:, 267:]
comorbidity_test_data, comorbidity_test_mask = set_transformer_processing_fun(comorbidity_test_data, embedding)
test_labels = filtered_test_data2[['po_flag']]

test_dataset = MultiInputDataset([vitals_test_data, demographics_test_data], test_labels, comorbidity_test_data, comorbidity_test_mask)
temp_test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size)#, collate_fn=test_dataset.collate_fn_padd)


test_loss, test_auroc, test_predictions, test_labels_out = evaluate(model, temp_test_dataloader, criterion)
new_test_predictions = final_threshold_fun(test_predictions)

filtered_test_data['lb_prediction'] = new_test_predictions

# Find the day we predict they switched
filtered_test_data = lb_predicted_switch_day_fun(filtered_test_data)

test_lb_predicted_switch_day = filtered_test_data[filtered_test_data['lb_prediction'] == 1].drop_duplicates(subset=['stay_id'], keep='first')
test_lb_predicted_switch_day = test_lb_predicted_switch_day[['stay_id', 'lb_predicted_switch_day']]
test_lb_predicted_switch_day.reset_index(drop=True, inplace=True)
test_lb_predicted_switch_day

# Merge and work out difference
test_switch_data = pd.merge(test_switch_day, test_lb_predicted_switch_day)
test_switch_data['lb_difference'] = test_switch_data['lb_predicted_switch_day'] - test_switch_data['real_switch_day'] #- test_switch_data['lb_predicted_switch_day']

lb_percentage_agree = len(test_switch_data[test_switch_data['lb_difference'] == 0])/len(test_switch_data)
lb_percentage_early = len(test_switch_data[test_switch_data['lb_difference'] < 0])/len(test_switch_data)
lb_percentage_late = len(test_switch_data[test_switch_data['lb_difference'] > 0])/len(test_switch_data)

lb_percentage_agree
lb_percentage_late
lb_percentage_early

# Re-train and run

In [22]:
# Train
def train(model, dataloader, optimizer, criterion, clip):
    model.train()

    epoch_loss = 0

    batch_prediction_list = []
    batch_label_list = []

    for batch_idx, sample in enumerate(tqdm(dataloader)):
        labels = sample["labels"]
        features = sample["features"]
        batch_mask = sample["mask"]
        features = [data.float() for data in features]
        features = [data.to(device=device) for data in features]
        labels = labels.float()
        labels = labels.to(device=device)
        batch_mask = batch_mask.to(device)

        # zero the gradients calculated from the last batch
        optimizer.zero_grad()

        # Run model
        output = model(features, batch_mask)
        
        # Generate loss
        loss = criterion(output, labels)

        # calculate the gradients
        loss.backward()

        # update the parameters of our model by doing an optimizer step
        optimizer.step()

        epoch_loss += loss.item()

        # Get predictions for performance calc - masking outputs and labels
        sig = torch.nn.Sigmoid()
        output = sig(output)      
        np_predictions = output.cpu().detach().numpy()
        np_labels = labels.cpu().detach().numpy()

        np_predictions = np_predictions.squeeze()
        np_labels = np_labels.squeeze()

        np_predictions = np_predictions.flatten()
        np_labels = np_labels.flatten()
        
        # Create list
        for x in np_predictions:
            batch_prediction_list.append(x)
        for x in np_labels:
            batch_label_list.append(x)

    final_predictions = np.array(batch_prediction_list)

    final_labels = np.array(batch_label_list)

    try:
        auroc = roc_auc_score(final_labels, final_predictions)
    except:
        auroc = np.nan
    
    try:
        final_loss = epoch_loss / len(dataloader)
    except:
        final_loss = np.nan

    return final_loss, auroc, final_predictions, final_labels

In [23]:
# Hyperparameters
final_input_dim = 268
final_output_dim = 1
final_hid_dim = 512
final_hid_dim2 = 128
demographics_input_dim = 12
demographics_output_dim = 12
vital_input_dim = 253
vital_hid_dim = 512
vital_output_dim = 128
dropout = 0.1

# Define model
model = Chronic_switch_model(
    final_input_dim, 
    final_output_dim, 
    final_hid_dim, 
    final_hid_dim2,
    demographics_input_dim,
    demographics_output_dim,
    vital_input_dim, 
    vital_hid_dim, 
    vital_output_dim, 
    dropout).to(device)

model.apply(init_weights)

print(f'The model has {count_parameters(model):,} trainable parameters')

  torch.nn.init.xavier_uniform(m.weight)


Chronic_switch_model(
  (final_layers): Sequential(
    (0): Linear(in_features=268, out_features=512, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=512, out_features=128, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.1, inplace=False)
    (6): Linear(in_features=128, out_features=1, bias=True)
    (7): ReLU()
    (8): Dropout(p=0.1, inplace=False)
  )
  (vital_model): Initial_vitals_model(
    (layers): Sequential(
      (0): BatchNorm1d(253, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1): Linear(in_features=253, out_features=512, bias=True)
      (2): ReLU()
      (3): Dropout(p=0.1, inplace=False)
      (4): Linear(in_features=512, out_features=128, bias=True)
      (5): ReLU()
      (6): Dropout(p=0.1, inplace=False)
    )
  )
  (set_transformer): SetTransformer(
    (enc): Sequential(
      (0): ISAB(
        (mab0): MAB0(
          (fc_q): Linear(in_features=160, out_features=160, bias=True)
          (f

The model has 1,126,583 trainable parameters


In [24]:
class StratifiedKFold3(StratifiedKFold):

    def split(self, X, y, groups=None):
        s = super().split(X, y, groups)
        for train_indxs, test_indxs in s:
            y_train = y[train_indxs]
            train_indxs, cv_indxs = train_test_split(train_indxs,stratify=y_train, test_size=(1 / (self.n_splits - 1)), random_state=0)
            yield train_indxs, cv_indxs, test_indxs

In [25]:
# Function to split data so even distribution between val and test
def cv_data_fun(data, n_cv=10):
    X = data.iloc[:, 2:]
    y = data['po_flag']
    g = StratifiedKFold3(n_cv).split(X,y)
    return g

In [26]:
def new_threshold_fun(predictions, bound=0.5):
    new_predictions = [1 if a_ >= bound else 0 for a_ in predictions]
    return new_predictions

In [27]:
# Function to train and eval model 
def cv_run_2023_fun(data, model):

    overall_best_test_auroc = 0

    actual_test_auroc_results = []

    test_auroc_results = []
    test_accuracy_results = []
    test_balanced_accuracy_results = []
    test_recall_results = []
    test_precision_results = []
    test_f1_results = []
    test_auprc_results = []
    test_cm_results = []
    test_true_positive_rate_results = []
    test_fasle_positive_rate_results = []

    final_threshold = 0

    # Define batch size 
    batch_size = 512

    # Define optimizer and learning_rate
    learning_rate = 0.0001
    optimizer = optim.RMSprop(model.parameters(), lr=learning_rate)

    # Define loss
    criterion = nn.BCEWithLogitsLoss()

    # Define epochs and clip
    N_EPOCHS = 10
    CLIP = 1

    # Split into folds
    split_generator = cv_data_fun(data)

    # Iterate through folds
    for x in range(N_EPOCHS): # Note this only works as number of pslits and epocs are the same
        train_idx, val_idx, test_idx = next(split_generator)

        # Get train val and test
        train_data = data.loc[train_idx]
        valid_data = data.loc[val_idx]
        test_data = data.loc[test_idx]

        #Apply smote - crashes with these features
        #train_data = smote_fun(train_data)

        # Split up dfs
        vitals_train_data = train_data.iloc[:,2:255]
        demographics_train_data = train_data.iloc[:,255:267]
        comorbidity_train_data = train_data.iloc[:, 267:]

        vitals_valid_data = valid_data.iloc[:,2:255]
        demographics_valid_data = valid_data.iloc[:,255:267]
        comorbidity_valid_data = valid_data.iloc[:, 267:]

        vitals_test_data = test_data.iloc[:,2:255]
        demographics_test_data = test_data.iloc[:,255:267]
        comorbidity_test_data = test_data.iloc[:, 267:]

        # Initializing the weights of our model each fold
        model.apply(init_weights)

        # Get labels
        train_labels = train_data[['po_flag']]
        valid_labels = valid_data[['po_flag']]
        test_labels = test_data[['po_flag']]

        # Preprocess comorbidity data
        print('Working on set_transformer_processing_fun...')
        comorbidity_train_data, comorbidity_train_mask = set_transformer_processing_fun(comorbidity_train_data, embedding)
        comorbidity_valid_data, comorbidity_valid_mask = set_transformer_processing_fun(comorbidity_valid_data, embedding)
        comorbidity_test_data, comorbidity_test_mask = set_transformer_processing_fun(comorbidity_test_data, embedding)
        print('Done!')

        # Define dataloaders
        train_dataset =  MultiInputDataset([vitals_train_data, demographics_train_data], train_labels, comorbidity_train_data, comorbidity_train_mask)
        train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size)

        valid_dataset = MultiInputDataset([vitals_valid_data, demographics_valid_data], valid_labels, comorbidity_valid_data, comorbidity_valid_mask)
        valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=batch_size)

        test_dataset = MultiInputDataset([vitals_test_data, demographics_test_data], test_labels, comorbidity_test_data, comorbidity_test_mask)
        test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size)

        # Run
        best_valid_loss = float('inf')
        best_valid_auroc = 0

        optimal_threshold = 0

        for epoch in range(N_EPOCHS):

            start_time = time.time()

            train_loss, train_auroc, train_predictions, train_labels_out = train(model, train_dataloader, optimizer, criterion, CLIP)

            valid_loss, valid_auroc, valid_predictions, valid_labels_out = evaluate(model, valid_dataloader, criterion)

            end_time = time.time()
            
            fpr, tpr, thresholds = roc_curve(valid_labels_out, valid_predictions)
            optimal_idx = np.argmax(tpr - fpr)
            current_threshold = thresholds[optimal_idx]

            epoch_mins, epoch_secs = epoch_time(start_time, end_time)

            print('Train AUROC:', train_auroc)
            print('Valid AUROC:', valid_auroc)
            print(train_predictions)
            print(train_labels_out)
            print('Train loss:', train_loss)
            print('Valid loss:', valid_loss)
            print(current_threshold)

            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                print('BEST VALID LOSS')

            if valid_auroc > best_valid_auroc:
                best_valid_auroc = valid_auroc
                print('UPDATED BEST INTERMEDIATE MODEL')
                torch.save(model.state_dict(), f'chronic_switch_model_intermediate_2023.pt')
                optimal_threshold = current_threshold

        # -----------------------------
        # Evaluate best model on test set
        # -----------------------------

        model.load_state_dict(torch.load(f'chronic_switch_model_intermediate_2023.pt'))

        test_loss, test_auroc, test_predictions, test_labels_out = evaluate(model, test_dataloader, criterion)

        print('Test AUROC result:', test_auroc)
        
        new_test_predictions = new_threshold_fun(test_predictions, optimal_threshold)

        test_accuracy = accuracy_score(test_labels_out, new_test_predictions)
        test_balanced_accuracy = balanced_accuracy_score(test_labels_out, new_test_predictions)
        test_recall = recall_score(test_labels_out, new_test_predictions)
        test_precision = precision_score(test_labels_out, new_test_predictions)
        test_f1 = f1_score(test_labels_out, new_test_predictions)
        test_auprc = average_precision_score(test_labels_out, test_predictions)
        test_cm = confusion_matrix(test_labels_out, new_test_predictions)
        tn, fp, fn, tp = test_cm.ravel()
        test_true_positive_rate = (tp / (tp + fn))
        test_false_positive_rate = (fp / (fp + tn))

        if test_auroc > overall_best_test_auroc:
            overall_best_test_auroc = test_auroc
            print('UPDATED BEST OVERALL MODEL')
            torch.save(model.state_dict(), f'chronic_switch_model_2023.pt') # Hastag out when dont want to change
            final_threshold = optimal_threshold
        
        actual_test_auroc_results.append(test_auroc)
        
        test_auroc_results.append(test_auroc)
        test_accuracy_results.append(test_accuracy)
        test_balanced_accuracy_results.append(test_balanced_accuracy)
        test_recall_results.append(test_recall)
        test_precision_results.append(test_precision)
        test_f1_results.append(test_f1)
        test_auprc_results.append(test_auprc)
        test_cm_results.append(test_cm)
        test_true_positive_rate_results.append(test_true_positive_rate)
        test_fasle_positive_rate_results.append(test_false_positive_rate)

    test_results = [test_auroc_results, test_accuracy_results,
        test_balanced_accuracy_results,
        test_recall_results,
        test_precision_results,
        test_f1_results,
        test_auprc_results,
        test_cm_results,
        test_true_positive_rate_results,
        test_fasle_positive_rate_results
        ]
    
    return test_results, actual_test_auroc_results, final_threshold

In [28]:
# Run model
test_results, actual_test_auroc_results, final_threshold = cv_run_2023_fun(model_data, model)

  torch.nn.init.xavier_uniform(m.weight)
100%|██████████| 5/5 [00:04<00:00,  1.25it/s]
100%|██████████| 1/1 [00:00<00:00,  2.98it/s]
100%|██████████| 5/5 [00:03<00:00,  1.31it/s]
100%|██████████| 1/1 [00:00<00:00,  3.21it/s]
100%|██████████| 5/5 [00:03<00:00,  1.32it/s]
100%|██████████| 1/1 [00:00<00:00,  3.20it/s]
100%|██████████| 5/5 [00:03<00:00,  1.33it/s]
100%|██████████| 1/1 [00:00<00:00,  3.03it/s]
100%|██████████| 5/5 [00:03<00:00,  1.33it/s]
100%|██████████| 1/1 [00:00<00:00,  3.17it/s]
100%|██████████| 5/5 [00:03<00:00,  1.31it/s]
100%|██████████| 1/1 [00:00<00:00,  3.16it/s]
100%|██████████| 5/5 [00:03<00:00,  1.31it/s]
100%|██████████| 1/1 [00:00<00:00,  3.14it/s]
100%|██████████| 5/5 [00:03<00:00,  1.33it/s]
100%|██████████| 1/1 [00:00<00:00,  2.80it/s]
100%|██████████| 5/5 [00:03<00:00,  1.31it/s]
100%|██████████| 1/1 [00:00<00:00,  2.82it/s]
100%|██████████| 5/5 [00:03<00:00,  1.33it/s]
100%|██████████| 1/1 [00:00<00:00,  3.05it/s]
100%|██████████| 1/1 [00:00<00:00,  2.8

Working on set_transformer_processing_fun...
Done!
Train AUROC: 0.522790606839811
Valid AUROC: 0.5340277777777778
[0.5433474 0.5232676 0.6712101 ... 0.5       0.5       0.5      ]
[0. 1. 1. ... 0. 0. 0.]
Train loss: 0.7142103910446167
Valid loss: 0.6916259527206421
0.50500834
BEST VALID LOSS
UPDATED BEST INTERMEDIATE MODEL
Train AUROC: 0.663797341158463
Valid AUROC: 0.7182692307692309
[0.5       0.5       0.5628729 ... 0.5       0.5       0.5      ]
[0. 1. 1. ... 0. 0. 0.]
Train loss: 0.6493540406227112
Valid loss: 0.650580108165741
0.55437297
BEST VALID LOSS
UPDATED BEST INTERMEDIATE MODEL
Train AUROC: 0.7356231185178322
Valid AUROC: 0.7347222222222223
[0.5709091 0.6438756 0.7150964 ... 0.5       0.5       0.5      ]
[0. 1. 1. ... 0. 0. 0.]
Train loss: 0.6173887610435486
Valid loss: 0.636620819568634
0.58532774
BEST VALID LOSS
UPDATED BEST INTERMEDIATE MODEL
Train AUROC: 0.754335911386415
Valid AUROC: 0.7512019230769231
[0.5        0.61641073 0.7705373  ... 0.5        0.5        0.5  

In [29]:
# Save 
with open("chronic_switch_test_results_2023_2", "wb") as fp:   #Pickling
    pickle.dump(test_results, fp)

In [30]:
analyze_results_fun(test_results)

mean test_auroc: 0.7562443559938788
std test_auroc: 0.031113337711807052
test_auroc 2.5th percentile: 0.70499265491453
test_auroc 97.5th percentile: 0.8044350961538461
mean test_accuracy: 0.7036496350364964
std test_accuracy: 0.01895000727846303
test_accuracy 2.5th percentile: 0.6699817518248175
test_accuracy 97.5th percentile: 0.7359489051094891
mean test_balanced_accuracy: 0.7070044025101276
std test_balanced_accuracy: 0.019052215906055027
test_balanced_accuracy 2.5th percentile: 0.6734041132478633
test_balanced_accuracy 97.5th percentile: 0.7389129273504274
mean test_recall: 0.6407439782439782
std test_recall: 0.05060912301847776
test_recall 2.5th percentile: 0.56875
test_recall 97.5th percentile: 0.7121527777777777
mean test_precision: 0.7610007716887507
std test_precision: 0.03852150103709388
test_precision 2.5th percentile: 0.7123214285714285
test_precision 97.5th percentile: 0.8368764302059497
mean test_f1: 0.6934708223793875
std test_f1: 0.024892712631346316
test_f1 2.5th perce