In [1]:
import sys
sys.path.append('..')
sys.path.append('../ehrshot')
import copy
from typing import Literal
import argparse
import pandas as pd
import numpy as np
import os

import torch
from torch import nn
from torch.distributions import Distribution
from torch_uncertainty.utils.distributions import cat_dist
from torch_uncertainty.routines import ClassificationRoutine
# from torch_uncertainty.utils import TUTrainer
from lightning.pytorch import Trainer
from torch_uncertainty.models import deep_ensembles, mc_dropout
from torch_uncertainty.transforms import RepeatTarget
import torchvision.transforms as T

from torch.utils.data import TensorDataset, DataLoader
import torch.nn.functional as F
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import roc_auc_score
from tqdm import tqdm

from lightning.pytorch import LightningModule

from torch_uncertainty.metrics.classification import BrierScore, CategoricalNLL
from torch_uncertainty.metrics.classification.adaptive_calibration_error import BinaryAdaptiveCalibrationError

from typing import List, Tuple
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import os

In [2]:
results_dict = {}

unique_tasks_1 = ['value_los', 'value_icu']
unique_tasks_2 = ['value_hypoglycemia', 'value_hyperkalemia', 'value_hyponatremia', 'value_anemia', 'value_thrombocytopenia']
unique_tasks_3 = ['value_new_hypertension', 'value_new_hyperlipidemia', 'value_new_acutemi']


all_tasks = [unique_tasks_1, unique_tasks_2, unique_tasks_3]
all_tasks_name = ['general_operation_v1', 'lab_test', 'new_diagnose']

In [3]:
class MultiTaskModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(MultiTaskModel, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.dropout1 = nn.Dropout(p=0.2)

        self.fc21 = nn.Linear(hidden_size, hidden_size)
        self.dropout21 = nn.Dropout(p=0.2)
        self.fc31 = nn.Linear(hidden_size, num_classes)

        self.fc22 = nn.Linear(hidden_size, hidden_size)
        self.dropout22 = nn.Dropout(p=0.2)
        self.fc32 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)

        x1 = F.relu(self.fc21(x))
        x1 = self.dropout21(x1)
        x1 = self.fc31(x1)

        x2 = F.relu(self.fc22(x))
        x2 = self.dropout22(x2)
        x2 = self.fc32(x2)

        return x1, x2

In [4]:
def train_and_validate(model, train_loader, val_loader, weight_list, num_epochs=100, lr_mult=1):
    criterion1 = nn.CrossEntropyLoss(weight = weight_list[0])
    criterion2 = nn.CrossEntropyLoss(weight = weight_list[1])
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01 * lr_mult, momentum=0.9)
    best_val_loss = float('inf')
    best_model = None

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        for data, target1, target2, in train_loader:
            optimizer.zero_grad()
            outputs = model(data)
            loss1 = criterion1(outputs[0], target1)
            loss2 = criterion2(outputs[1], target2)
            total_loss = loss1 + loss2
            total_loss.backward()
            optimizer.step()

        # Validation phase
        model.eval()
        total_val_loss = 0
        with torch.no_grad():
            for data, target1, target2 in val_loader:
                outputs = model(data)
                loss1 = criterion1(outputs[0], target1)
                loss2 = criterion2(outputs[1], target2)
                total_val_loss += loss1.item() + loss2.item()
        
        average_val_loss = total_val_loss / len(val_loader)
        if average_val_loss < best_val_loss:
            best_val_loss = average_val_loss
            best_model = model.state_dict()
        
        if epoch % 2 == 0:
            print(f'Epoch {epoch+1}: Avg Val Loss: {average_val_loss:.4f}')

    return model.state_dict()

In [5]:
max_epochs = 100
batch_size = 64

# for i in tqdm(range(len(all_tasks))):
# for i in [0]:
i = 0
general_task_name = all_tasks_name[i]

folder_path = f'same_time_data/{general_task_name}'

train_x_name = os.path.join(folder_path, 'x_train.csv')
train_y_name = os.path.join(folder_path, 'y_train.csv')
val_x_name = os.path.join(folder_path, 'x_val.csv')
val_y_name = os.path.join(folder_path, 'y_val.csv')
test_x_name = os.path.join(folder_path, 'x_test.csv')
test_y_name = os.path.join(folder_path, 'y_test.csv')

X_train = pd.read_csv(train_x_name).to_numpy()
X_val = pd.read_csv(val_x_name).to_numpy()
X_test = pd.read_csv(test_x_name).to_numpy()

X_train = torch.tensor(X_train).float()
X_val = torch.tensor(X_val).float()
X_test = torch.tensor(X_test).float()

y_train_list = []
y_val_list = []
y_test_list = []

class_weights_list = []

for j in tqdm(range(len(all_tasks[i]))):
# for j in range(1):
    specific_task_name = all_tasks[i][j]
    y_train = pd.read_csv(train_y_name)[specific_task_name].astype(int).to_numpy()
    y_val = pd.read_csv(val_y_name)[specific_task_name].astype(int).to_numpy()
    y_test = pd.read_csv(test_y_name)[specific_task_name].astype(int).to_numpy()

    assert len(np.unique(y_train)) == 2
    # Create class weights
    class_weights = compute_class_weight('balanced', classes=np.unique(y_train), y=y_train)
    class_weights = torch.tensor(class_weights, dtype=torch.float)
    class_weights_list.append(class_weights)
    
    y_train = torch.tensor(y_train).long().T
    y_val = torch.tensor(y_val).long().T
    y_test = torch.tensor(y_test).long().T

    # Create TensorDatasets
    y_train_list.append(y_train)
    y_val_list.append(y_val)
    y_test_list.append(y_test)

train_dataset = TensorDataset(X_train, y_train_list[0], y_train_list[1])
val_dataset = TensorDataset(X_val, y_val_list[0], y_val_list[1])
test_dataset = TensorDataset(X_test, y_test_list[0], y_test_list[1])

train_dl = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dl = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_dl = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

model = MultiTaskModel(X_train.shape[1], 128, 2)

best_model = train_and_validate(model, train_dl, val_dl, class_weights_list, num_epochs=max_epochs, lr_mult=1)

  y_train = torch.tensor(y_train).long().T
100%|██████████| 2/2 [00:00<00:00, 112.10it/s]


Epoch 1: Avg Val Loss: 1.0723
Epoch 3: Avg Val Loss: 1.0211
Epoch 5: Avg Val Loss: 1.0593
Epoch 7: Avg Val Loss: 1.2705
Epoch 9: Avg Val Loss: 1.2195
Epoch 11: Avg Val Loss: 1.7397
Epoch 13: Avg Val Loss: 1.5661
Epoch 15: Avg Val Loss: 1.8153
Epoch 17: Avg Val Loss: 2.1512
Epoch 19: Avg Val Loss: 2.3695
Epoch 21: Avg Val Loss: 1.9653
Epoch 23: Avg Val Loss: 2.3833
Epoch 25: Avg Val Loss: 2.2461
Epoch 27: Avg Val Loss: 2.3698
Epoch 29: Avg Val Loss: 3.0609
Epoch 31: Avg Val Loss: 3.6845
Epoch 33: Avg Val Loss: 2.6466
Epoch 35: Avg Val Loss: 2.8877
Epoch 37: Avg Val Loss: 2.3694
Epoch 39: Avg Val Loss: 2.7159
Epoch 41: Avg Val Loss: 3.0551
Epoch 43: Avg Val Loss: 2.3862
Epoch 45: Avg Val Loss: 2.8843
Epoch 47: Avg Val Loss: 3.3995
Epoch 49: Avg Val Loss: 3.2771
Epoch 51: Avg Val Loss: 3.0651
Epoch 53: Avg Val Loss: 3.3066
Epoch 55: Avg Val Loss: 3.3818
Epoch 57: Avg Val Loss: 3.3694
Epoch 59: Avg Val Loss: 3.7250
Epoch 61: Avg Val Loss: 3.5283
Epoch 63: Avg Val Loss: 3.8328
Epoch 65: Avg

In [6]:
new_model = MultiTaskModel(X_train.shape[1], 128, 2)
new_model.load_state_dict(best_model)

<All keys matched successfully>

In [7]:
def evaluate_model(model, test_loader):
    model.eval()
    preds = []
    preds_prob = []
    probs = []
    gts = []
    with torch.no_grad():
        correct1 = 0
        total1 = 0
        for inputs, labels1, labels2 in test_loader:
            # inputs, labels = inputs.to(device), labels.to(device)
            outputs1, outputs2 = model(inputs)
            probabilities1 = torch.softmax(outputs1.data, dim=1)
            predicted1 = torch.argmax(probabilities1, dim=1)
            predicted_probabilities1 = probabilities1[torch.arange(probabilities1.shape[0]), predicted1]

            total1 += labels1.size(0)
            correct1 += (predicted1 == labels1).sum().item()

            preds.extend(list(predicted1.cpu().numpy()))
            gts.extend(list(labels1.cpu().numpy()))
            preds_prob.extend(list(predicted_probabilities1.cpu().numpy()))
            probs.extend(list(probabilities1.cpu().numpy()))

        test_accuracy1 = 100 * correct1 / total1
        print(f'Test Accuracy: {test_accuracy1:.3f}%')
        auc_score1 = roc_auc_score(gts, preds)
        print(f'AUC Score: {auc_score1:.3f}')
            
    # return test_accuracy, auc_score * 100, ave_preds_prob, gts, preds, preds_prob, probs
    return gts, preds, preds_prob, probs

In [8]:
def calcualte_brier_score(gt, binary_prob):
    metric_brier = BrierScore(num_classes=2, top_class=False)
    metric_brier.update(torch.tensor(binary_prob), torch.tensor(gt))
    brierScore = metric_brier.compute()
    return np.round(brierScore.item(), 3)

In [9]:
gts, preds, preds_prob, probs = evaluate_model(new_model, test_dl)

Test Accuracy: 77.761%
AUC Score: 0.674


In [10]:
calcualte_brier_score(np.array(gts), probs)

  metric_brier.update(torch.tensor(binary_prob), torch.tensor(gt))


0.388