In [1]:

import sys
sys.path.append('..')
sys.path.append('../ehrshot')
import argparse
import json
import os
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import collections
import pandas as pd
import sklearn
from sklearn import metrics
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from loguru import logger
from sklearn.preprocessing import MaxAbsScaler
from utils import (
    LABELING_FUNCTION_2_PAPER_NAME,
    SHOT_STRATS,
    MODEL_2_INFO,
    get_labels_and_features, 
    process_chexpert_labels, 
    convert_multiclass_to_binary_labels,
    CHEXPERT_LABELS, 
    LR_PARAMS, 
    XGB_PARAMS, 
    RF_PARAMS,
    ProtoNetCLMBRClassifier, 
    get_patient_splits_by_idx
)
from sklearn.model_selection import GridSearchCV, PredefinedSplit
from scipy.sparse import issparse
import scipy
import lightgbm as lgb
import femr
import femr.datasets
from femr.labelers import load_labeled_patients, LabeledPatients

import torch
from torch.utils.data import TensorDataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import roc_auc_score
from tqdm import tqdm

In [2]:
class TwoLayerNet(nn.Module):
    """Two-layer MLP classifier: Linear -> ReLU -> Linear, returning raw logits.

    Args:
        input_size: Number of input features per example.
        hidden_size: Width of the single hidden layer.
        num_classes: Number of output logits per example.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        # Names fc1/fc2 are part of the state_dict keys; keep them stable.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)

In [3]:
def train_model(num_epochs, model, train_loader, val_loader, criterion, optimizer, log_every=10):
    """Train ``model`` in place, measuring validation accuracy after every epoch.

    Each epoch runs one optimization pass over ``train_loader`` and then a
    no-grad accuracy pass over ``val_loader``. The prediction for a sample is
    the argmax over dim 1 of the model output.

    Args:
        num_epochs: Number of passes over ``train_loader``.
        model: ``torch.nn.Module``; its output is assumed to be indexed as
            (batch, class) scores, since argmax is taken over dim 1.
        train_loader: Iterable of ``(inputs, labels)`` batches to train on.
        val_loader: Iterable of ``(inputs, labels)`` batches used only for
            accuracy. May be empty, in which case accuracy is reported as nan.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        log_every: Print validation accuracy on epochs where
            ``epoch % log_every == 0`` (0-based). ``None``/``0`` disables
            printing. Defaults to 10, the previous hard-coded interval.

    Returns:
        The same ``model`` object (left in eval mode if ``num_epochs >= 1``).
    """
    for epoch in range(num_epochs):
        model.train()
        for inputs, labels in train_loader:
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Validation pass: eval-mode layers, no autograd bookkeeping.
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                predicted = model(inputs).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        # An empty validation loader previously raised ZeroDivisionError.
        val_accuracy = 100 * correct / total if total else float('nan')

        if log_every and epoch % log_every == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Validation Accuracy: {val_accuracy:.2f}%')
    return model

In [4]:
def evaluate_model(model, test_loader):
    """Compute accuracy and ROC AUC of ``model`` over ``test_loader``.

    Predictions are the argmax over dim 1 of the model output. ROC AUC is
    computed from softmax class probabilities: the class-1 probability when
    the model emits 2 columns, otherwise one-vs-rest over all columns.

    Args:
        model: ``torch.nn.Module``; its output is assumed to be indexed as
            (batch, class) logits, since argmax/softmax are taken over dim 1.
        test_loader: Iterable of ``(inputs, labels)`` batches.

    Returns:
        ``(test_accuracy, auc_score)`` with accuracy in percent.

    Raises:
        ValueError: If ``test_loader`` yields no samples. ``roc_auc_score``
            also raises ``ValueError`` when only one class is present.
    """
    model.eval()
    logit_batches = []
    label_batches = []
    with torch.no_grad():
        for inputs, labels in test_loader:
            # .cpu() so .numpy() below also works for CUDA tensors.
            logit_batches.append(model(inputs).cpu())
            label_batches.append(labels.cpu())

    if not label_batches:
        raise ValueError('test_loader yielded no batches; cannot evaluate.')
    logits = torch.cat(logit_batches)
    gts = torch.cat(label_batches)
    total = gts.size(0)
    if total == 0:
        raise ValueError('test_loader yielded no samples; cannot evaluate.')

    predicted = logits.argmax(dim=1)
    correct = (predicted == gts).sum().item()
    test_accuracy = 100 * correct / total
    print(f'Test Accuracy: {test_accuracy:.2f}%')

    # ROC AUC needs ranking scores. Feeding hard 0/1 argmax predictions (as
    # before) collapses the ROC curve to one operating point. Softmax in
    # float64 reduces ties caused by float32 saturation.
    probs = torch.softmax(logits.double(), dim=1).numpy()
    y_true = gts.numpy()
    if probs.shape[1] == 2:
        auc_score = roc_auc_score(y_true, probs[:, 1])
    else:
        auc_score = roc_auc_score(y_true, probs, multi_class='ovr')
    return test_accuracy, auc_score

In [5]:
# EHRSHOT labeling functions (task names) to assemble, grouped by prefix.
labeling_functions = (
    # guo_* tasks
    ["guo_los", "guo_readmission", "guo_icu"]
    # new_* tasks
    + [
        "new_hypertension", "new_hyperlipidemia", "new_pancan",
        "new_celiac", "new_lupus", "new_acutemi",
    ]
    # lab_* tasks: multi-class labels, binarized later in the loading loop
    + [
        "lab_thrombocytopenia", "lab_hyperkalemia", "lab_hyponatremia",
        "lab_anemia",
        "lab_hypoglycemia",  # will OOM at 200G on `gpu` partition
    ]
)

In [6]:
# Input (EHRSHOT assets) and output locations, relative to the notebook's
# working directory. Note: despite its name, `path_to_data_csv` is a
# directory path, not a single CSV file.
(
    path_to_database,
    path_to_labels_dir,
    path_to_features_dir,
    path_to_output_dir,
    path_to_output_data_dir,
    path_to_split_csv,
    path_to_data_csv,
) = (
    '../EHRSHOT_ASSETS/femr/extract',
    '../EHRSHOT_ASSETS/benchmark',
    '../EHRSHOT_ASSETS/features',
    '../uncertainty_quantification/single_task_results',
    '../uncertainty_quantification/single_task_data',
    '../EHRSHOT_ASSETS/splits/person_id_map.csv',
    '../uncertainty_quantification',
)


In [7]:
# CLI-style argument parser for one EHRSHOT task. In this notebook it is
# parsed with an explicit per-task argv rather than sys.argv.
parser = argparse.ArgumentParser(description="Run EHRSHOT evaluation benchmark on a specific task.")
# Required string-valued path options, registered in their original order.
for _flag, _help in (
    ("--path_to_database", "Path to FEMR patient database"),
    ("--path_to_labels_dir", "Path to directory containing saved labels"),
    ("--path_to_features_dir", "Path to directory containing saved features"),
    ("--path_to_output_dir", "Path to directory where results will be saved"),
    ("--path_to_split_csv", "Path to CSV of splits"),
):
    parser.add_argument(_flag, required=True, type=str, help=_help)
del _flag, _help
parser.add_argument(
    "--labeling_function",
    required=True,
    type=str,
    help="Labeling function for which we will create k-shot samples.",
    choices=LABELING_FUNCTION_2_PAPER_NAME.keys(),
)
parser.add_argument("--num_threads", type=int, help="Number of threads to use")
parser.add_argument(
    "--is_force_refresh",
    action='store_true',
    default=False,
    help="If set, then overwrite all outputs",
)

_StoreTrueAction(option_strings=['--is_force_refresh'], dest='is_force_refresh', nargs=0, const=True, default=False, type=None, choices=None, required=False, help='If set, then overwrite all outputs', metavar=None)

In [8]:
# Build all-task train/val/test sets from CLMBR features.
# Each per-task X frame gets a `task` column holding its labeling function;
# the matching y frame holds the integer labels in column 0, row-aligned
# with X. Per-task frames are collected in lists and concatenated once after
# the loop, avoiding the quadratic cost of re-concatenating on every task.
X_train_parts: List[pd.DataFrame] = []
X_val_parts: List[pd.DataFrame] = []
X_test_parts: List[pd.DataFrame] = []
y_train_parts: List[pd.DataFrame] = []
y_val_parts: List[pd.DataFrame] = []
y_test_parts: List[pd.DataFrame] = []

# NOTE(review): never populated in this cell; kept for any later cell that
# references it.
task_indicator = []

# The database path does not depend on the task, so open it once instead of
# once per iteration (it was previously re-opened from a hard-coded copy of
# `path_to_database`). NOTE(review): `database` is not read in this cell.
database = femr.datasets.PatientDatabase(path_to_database)

for i, labeling_function in enumerate(labeling_functions):
    # Explicit argv list: building it with str.split() breaks on paths
    # that contain spaces.
    args = parser.parse_args([
        '--labeling_function', labeling_function,
        '--path_to_database', path_to_database,
        '--path_to_labels_dir', path_to_labels_dir,
        '--path_to_features_dir', path_to_features_dir,
        '--path_to_split_csv', path_to_split_csv,
        '--path_to_output_dir', path_to_output_dir,
    ])

    LABELING_FUNCTION: str = args.labeling_function
    IS_FORCE_REFRESH: bool = args.is_force_refresh
    PATH_TO_DATABASE: str = args.path_to_database
    PATH_TO_FEATURES_DIR: str = args.path_to_features_dir
    PATH_TO_LABELS_DIR: str = args.path_to_labels_dir
    PATH_TO_SPLIT_CSV: str = args.path_to_split_csv
    PATH_TO_LABELED_PATIENTS: str = os.path.join(PATH_TO_LABELS_DIR, LABELING_FUNCTION, 'labeled_patients.csv')
    PATH_TO_OUTPUT_DIR: str = args.path_to_output_dir
    PATH_TO_OUTPUT_FILE: str = os.path.join(PATH_TO_OUTPUT_DIR, f'{LABELING_FUNCTION}_results.csv')

    labeled_patients: LabeledPatients = load_labeled_patients(PATH_TO_LABELED_PATIENTS)
    patient_ids, label_values, label_times, feature_matrixes = get_labels_and_features(labeled_patients, PATH_TO_FEATURES_DIR)
    train_pids_idx, val_pids_idx, test_pids_idx = get_patient_splits_by_idx(PATH_TO_SPLIT_CSV, patient_ids)

    sub_tasks: List[str]
    if LABELING_FUNCTION == "chexpert":
        label_values = process_chexpert_labels(label_values)
        sub_tasks = CHEXPERT_LABELS
    elif LABELING_FUNCTION.startswith('lab_'):
        # Lab value is multi-class; convert to binary with threshold=1.
        label_values = convert_multiclass_to_binary_labels(label_values, threshold=1)
        sub_tasks = [LABELING_FUNCTION]
    else:
        # Binary classification
        sub_tasks = [LABELING_FUNCTION]
    # NOTE(review): `sub_tasks` is not used in this cell.

    model = 'clmbr'  # key selecting the CLMBR features in `feature_matrixes`
    X_train = pd.DataFrame(feature_matrixes[model][train_pids_idx])
    X_val = pd.DataFrame(feature_matrixes[model][val_pids_idx])
    X_test = pd.DataFrame(feature_matrixes[model][test_pids_idx])

    X_train['task'] = labeling_function
    X_val['task'] = labeling_function
    X_test['task'] = labeling_function

    y_train = pd.DataFrame(label_values[train_pids_idx].astype(int))
    y_val = pd.DataFrame(label_values[val_pids_idx].astype(int))
    y_test = pd.DataFrame(label_values[test_pids_idx].astype(int))

    X_train_parts.append(X_train)
    X_val_parts.append(X_val)
    X_test_parts.append(X_test)
    y_train_parts.append(y_train)
    y_val_parts.append(y_val)
    y_test_parts.append(y_test)

# pd.concat raises on an empty list; fall back to an empty frame when no
# tasks are configured (the previous initial value). Per-task indices are
# kept as-is (not reset), as before.
X_train_all = pd.concat(X_train_parts) if X_train_parts else pd.DataFrame()
X_val_all = pd.concat(X_val_parts) if X_val_parts else pd.DataFrame()
X_test_all = pd.concat(X_test_parts) if X_test_parts else pd.DataFrame()
y_train_all = pd.concat(y_train_parts) if y_train_parts else pd.DataFrame()
y_val_all = pd.concat(y_val_parts) if y_val_parts else pd.DataFrame()
y_test_all = pd.concat(y_test_parts) if y_test_parts else pd.DataFrame()

In [16]:
# X_train_all.to_csv(os.path.join(path_to_data_csv, 'multi_task_data', 'X_train_all.csv'))
# y_train_all.to_csv(os.path.join(path_to_data_csv, 'multi_task_data', 'y_train_all.csv'))
# X_val_all.to_csv(os.path.join(path_to_data_csv, 'multi_task_data', 'X_val_all.csv'))
# y_val_all.to_csv(os.path.join(path_to_data_csv, 'multi_task_data', 'y_val_all.csv'))
# X_test_all.to_csv(os.path.join(path_to_data_csv, 'multi_task_data', 'X_test_all.csv'))
# y_test_all.to_csv(os.path.join(path_to_data_csv, 'multi_task_data', 'y_test_all.csv'))

Unnamed: 0,0
0,1
1,1
2,1
3,1
4,0
...,...
122103,0
122104,0
122105,0
122106,0
