In [3]:
import sys
sys.path.append('..')
sys.path.append('../ehrshot')
import copy
from typing import Literal
import argparse
import pandas as pd
import numpy as np
import os

import torch
from torch import nn
from torch.distributions import Distribution
from torch_uncertainty.utils.distributions import cat_dist
from torch_uncertainty.routines import ClassificationRoutine
from torch_uncertainty.utils import TUTrainer
from torch_uncertainty.models import deep_ensembles, mc_dropout
from torch_uncertainty.transforms import RepeatTarget
import torchvision.transforms as T

from torch.utils.data import TensorDataset, DataLoader
import torch.nn.functional as F
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import roc_auc_score
from tqdm import tqdm

In [None]:
# Per-task evaluation results, keyed by task name; populated by the training
# cell and written to JSON at the end of the notebook.
results_dict = {}

# Task groups. Each group's data is read from multi_task_data_uq/<group name>.
unique_tasks_1 = ['guo_los', 'guo_readmission', 'guo_icu']
unique_tasks_2 = [
    'new_hypertension',
    'new_hyperlipidemia',
    'new_pancan',
    'new_celiac',
    'new_lupus',
    'new_acutemi',
]
unique_tasks_3 = [
    'lab_thrombocytopenia',
    'lab_hyperkalemia',
    'lab_hyponatremia',
    'lab_anemia',
    'lab_hypoglycemia',  # will OOM at 200G on `gpu` partition
]

all_tasks = [unique_tasks_1, unique_tasks_2, unique_tasks_3]
# Group names follow the variable names above: unique_tasks_1, unique_tasks_2, ...
all_tasks_name = [f'unique_tasks_{k}' for k in range(1, len(all_tasks) + 1)]

# Every labeling function, in group order (single source of truth: the groups above).
labeling_functions = unique_tasks_1 + unique_tasks_2 + unique_tasks_3

In [None]:
class TwoLayerNet(nn.Module):
    """Two-layer MLP classifier: Linear -> ReLU -> Dropout -> Linear.

    ``forward`` returns raw, unnormalised logits of shape ``(batch, num_classes)``,
    which is the input ``nn.CrossEntropyLoss`` expects.

    Args:
        input_size: Number of input features.
        hidden_size: Width of the hidden layer.
        num_classes: Number of output logits.
        dropout: Dropout probability applied after the hidden activation
            (default 0.5, the value that was previously hard-coded).
    """

    def __init__(self, input_size, hidden_size, num_classes, dropout: float = 0.5):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.dropout1 = nn.Dropout(p=dropout)
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        # No activation on the output layer. A ReLU here would clamp every
        # negative logit to 0, collapsing class scores and zeroing gradients
        # for those units; CrossEntropyLoss applies log-softmax itself.
        return self.fc2(x)

def optim_recipe(model, lr_mult: float = 1.0, base_lr: float = 0.05, t_max: int = 10):
    """Build the SGD + cosine-annealing optimisation recipe for a Lightning routine.

    Args:
        model: Module whose parameters are optimised.
        lr_mult: Multiplier applied to ``base_lr``.
        base_lr: Learning rate before ``lr_mult`` is applied (default 0.05).
        t_max: ``T_max`` of ``CosineAnnealingLR`` (default 10). This is usually
            set to the number of training epochs. The cosine continues past
            ``t_max``, so with a smaller value the LR climbs back up again.

    Returns:
        ``{"optimizer": ..., "lr_scheduler": ...}``. ``"lr_scheduler"`` is the
        key Lightning reads from ``configure_optimizers``. The previous
        ``"scheduler"`` key was ignored (with a warning), so the LR never decayed.
    """
    optimizer = torch.optim.SGD(model.parameters(), lr=base_lr * lr_mult)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=t_max)
    return {"optimizer": optimizer, "lr_scheduler": scheduler}

max_epochs = 50
batch_size = 64
hidden_size = 128
num_classes = 2
num_estimators = 5  # ensemble size; RepeatTarget below uses the same count
lr_mult = 2.0
seed = 0
# Indices into `all_tasks` / `all_tasks_name` to run. Only group 0 is run here;
# use `range(len(all_tasks))` to run every group.
task_group_indices = [0]


def _load_split(folder, split, suffix):
    """Read ``X_<split>_<suffix>.csv`` / ``y_<split>_<suffix>.csv`` from ``folder``.

    Returns ``(X, y)`` as a float32 feature tensor and a 1-D int64 label tensor.
    """
    x = pd.read_csv(os.path.join(folder, f'X_{split}_{suffix}.csv')).to_numpy()
    y = pd.read_csv(os.path.join(folder, f'y_{split}_{suffix}.csv')).to_numpy().reshape(-1)
    return torch.tensor(x).float(), torch.tensor(y).long()


for i in task_group_indices:
    general_task_name = all_tasks_name[i]
    current_all_tasks_list = all_tasks[i]
    folder_path = f'multi_task_data_uq/{general_task_name}'

    # Seed per group so each group's run is reproducible on its own.
    torch.manual_seed(seed)
    np.random.seed(seed)

    X_train, y_train = _load_split(folder_path, 'train', 'all')
    X_val, y_val = _load_split(folder_path, 'val', 'all')

    # CrossEntropyLoss indexes `weight` by class id, so the training labels must
    # be exactly 0..num_classes-1 for the balanced weights to line up.
    train_labels = y_train.numpy()
    present_classes = np.unique(train_labels)
    assert np.array_equal(present_classes, np.arange(num_classes)), (
        f'{general_task_name}: expected training labels {list(range(num_classes))}, '
        f'found {present_classes.tolist()}'
    )
    class_weights = torch.tensor(
        compute_class_weight('balanced', classes=present_classes, y=train_labels),
        dtype=torch.float,
    )

    train_dl = DataLoader(TensorDataset(X_train, y_train), batch_size=batch_size, shuffle=True)
    val_dl = DataLoader(TensorDataset(X_val, y_val), batch_size=batch_size, shuffle=False)

    input_size = X_train.shape[1]
    model = TwoLayerNet(input_size, hidden_size, num_classes)
    ensemble = deep_ensembles(
        model,
        num_estimators=num_estimators,
        task="classification",
        reset_model_parameters=True,
    )

    trainer = TUTrainer(accelerator="gpu", max_epochs=max_epochs)
    ens_routine = ClassificationRoutine(
        is_ensemble=True,
        num_classes=num_classes,
        model=ensemble,
        loss=nn.CrossEntropyLoss(weight=class_weights),
        # Same count as the ensemble so the two cannot drift apart.
        format_batch_fn=RepeatTarget(num_estimators),
        optim_recipe=optim_recipe(ensemble, lr_mult),
        eval_ood=False,
    )

    trainer.fit(ens_routine, train_dataloaders=train_dl, val_dataloaders=val_dl)

    # Evaluate the group-level ensemble on each task's test split.
    # NOTE(review): per-task test files are assumed to be named
    # X_test_<task>.csv / y_test_<task>.csv next to the *_all.csv files; the
    # previous cell referenced an undefined `test_dl` — confirm against the data export.
    for task in current_all_tasks_list:
        X_test, y_test = _load_split(folder_path, 'test', task)
        test_dl = DataLoader(TensorDataset(X_test, y_test), batch_size=batch_size, shuffle=False)
        ens_perf = trainer.test(ens_routine, dataloaders=[test_dl])
        results_dict[task] = ens_perf

In [None]:
# Persist the per-task test metrics collected in `results_dict`.
# (A multi-task variant previously wrote to
#  results_model_uq/results_multi_task_deep_ensemble.json.)
# NOTE(review): the training data is read from multi_task_data_uq/ but this file
# is named "single_task" — confirm the intended output name.
import json

results_path = 'results_model_uq_v2/results_single_task_deep_ensemble.json'
# open(..., 'w') does not create missing parent directories.
os.makedirs(os.path.dirname(results_path), exist_ok=True)
with open(results_path, 'w') as f:
    json.dump(results_dict, f, indent=2)