In [1]:
import os

import torch

from sklearn.metrics import (
    auc,
    average_precision_score,
    balanced_accuracy_score,
    f1_score,
    precision_recall_curve,
    precision_score,
    recall_score,
    roc_auc_score,
)


# Working directory that the relative paths in `config` (vocab, data and
# checkpoint files) are resolved against. Set ODYSSEY_ROOT to point at your own
# checkout instead of editing this user-specific absolute path.
ROOT = os.environ.get("ODYSSEY_ROOT", "/fs01/home/afallah/odyssey/odyssey")
# NOTE(review): presumably also lets the `odyssey.*` imports below resolve from
# ROOT via the notebook's cwd-relative sys.path entry — confirm.
os.chdir(ROOT)

from odyssey.data.dataset import FinetuneMultiDataset, FinetuneDatasetDecoder
from odyssey.data.tokenizer import ConceptTokenizer
from odyssey.models.model_utils import load_finetune_data, load_pretrain_data
from odyssey.evals.evaluation import calculate_metrics

In [7]:
class config:
    """Settings for evaluating saved fine-tuning test outputs.

    Relative paths are resolved against the working directory set by
    ``os.chdir(ROOT)`` in the first cell. Select the evaluated run by name from
    ``test_output_runs`` instead of commenting path lines in and out, so the
    file behind every printed result is explicit in the code.
    """

    # Saved test outputs of earlier fine-tuning runs, keyed by run label.
    test_output_runs = {
        "Original": "checkpoints/mamba_finetune_with_embeddings/test_outputs/test_outputs_f521131a.pt",
        "Original 2": "checkpoints/mamba_finetune_with_embeddings/test_outputs/test_outputs_dd3a541e.pt",
        "Mamba new 3": "checkpoints/mamba_finetune_with_embeddings/test_outputs/test_outputs_8d4a82a7.pt",
        "Mamba new 4": "checkpoints/mamba_finetune_with_embeddings/test_outputs/test_outputs_189dcaa6.pt",
        "Mamba new 5": "checkpoints/mamba_finetune_with_embeddings/test_outputs/test_outputs_cc5497f1.pt",
        "Mamba new 6": "checkpoints/mamba_finetune_with_embeddings/test_outputs/test_outputs_f65e9b55.pt",
        "Multibird 4": "checkpoints/multibird_finetune/test_outputs/test_outputs_e112ebb5.pt",
        "Multibird 5": "checkpoints/multibird_finetune/test_outputs/test_outputs_782a830c.pt",
        "Multibird 7": "checkpoints/multibird_finetune/test_outputs/test_outputs_849dfdc5.pt",
        "Multibird 8": "checkpoints/multibird_finetune/test_outputs/test_outputs_102b8d16.pt",
        "Multihead 3": "checkpoints/mamba_finetune_with_embeddings_multihead/test_outputs/test_outputs_65922b58.pt",
        "Multihead 4": "checkpoints/mamba_finetune_with_embeddings_multihead/test_outputs/test_outputs_543b92de.pt",
        "Multihead 6": "checkpoints/mamba_finetune_with_embeddings_multihead/test_outputs/test_outputs_d3d3deb6.pt",
        # Shared copy stored outside ROOT (absolute path).
        "Multibird v1": "/ssd003/projects/aieng/public/odyssey/checkpoints/multibird_finetune/v1/test_outputs/test_outputs_1ec842db.pt",
    }
    # File loaded by test_model().
    test_outputs = test_output_runs["Multibird 8"]

    # Vocabulary and patient-sequence data locations.
    vocab_dir = "odyssey/data/vocab"
    data_dir = "odyssey/data/bigbird_data"
    sequence_file = "patient_sequences_2048_multi_v2.parquet"
    id_file = "dataset_2048_multi_v2.pkl"
    valid_scheme = "few_shot"
    num_finetune_patients = "all"

    # Tasks of the multi-task test set. task2index is built in this order, so
    # it is also the order in which test_model() prints per-task metrics.
    tasks = [
        "mortality_1month",
        "readmission_1month",
        "los_1week",
        "c0",
        "c1",
        "c2",
    ]

    max_len = 2048
    batch_size = 1
    device = torch.device("cpu")

In [4]:
# Fine-tuning split and held-out test split of the patient sequences.
fine_tune, fine_test = load_finetune_data(
    config.data_dir,
    config.sequence_file,
    config.id_file,
    config.valid_scheme,
    config.num_finetune_patients,
)

# Concept tokenizer fitted on the vocabulary files in config.vocab_dir.
tokenizer = ConceptTokenizer(data_dir=config.vocab_dir)
tokenizer.fit_on_vocab()

# Multi-task test dataset over the held-out split (balance_guide=None:
# presumably no per-task rebalancing — confirm in FinetuneMultiDataset).
test_dataset = FinetuneMultiDataset(
    data=fine_test,
    tokenizer=tokenizer,
    max_len=config.max_len,
    tasks=config.tasks,
    balance_guide=None,
)

In [62]:
def label_summary(df, label="label_readmission_1month", tokens_col="event_tokens_2048"):
    """Summarize the rows of ``df`` that carry a valid value for ``label``.

    Rows whose ``label`` equals -1 are excluded (presumably -1 marks "no label
    for this task" — confirm against the data pipeline).

    Args:
        df: Patient-level frame such as ``fine_tune`` or ``fine_test`` with
            ``label``, ``num_visits`` and ``tokens_col`` columns.
        label: Label column used both to filter rows and to average.
        tokens_col: Column holding each patient's token sequence.

    Returns:
        ``(n_patients, total_visits, total_tokens, mean_label)`` as plain Python
        numbers; ``mean_label`` is the positive rate for 0/1 labels and NaN when
        no row is labelled.
    """
    labelled = df.loc[df[label] != -1]
    return (
        len(labelled),
        int(labelled["num_visits"].sum()),
        int(labelled[tokens_col].map(len).sum()),
        float(labelled[label].mean()),
    )


# Usage: label_summary(fine_tune), label_summary(fine_test)

(31775, 122109, 20851581, 0.4911723052714398)

In [8]:
# Task name of every test sample, in dataset order (second field of the
# sample's index_mapper entry).
tasks = [
    test_dataset.index_mapper[position][1]
    for position in range(len(test_dataset))
]

# Sample positions grouped by task. Keys follow config.tasks order; a sample
# whose task is missing from config.tasks raises KeyError here.
task2index = {name: [] for name in config.tasks}
for position, name in enumerate(tasks):
    task2index[name].append(position)


def test_model():
    test_outputs = torch.load(config.test_outputs, map_location=config.device)
    test_outputs.keys()

    labels = test_outputs["labels"].cpu().numpy()
    logits = test_outputs["logits"].cpu().numpy()
    probs = torch.sigmoid(torch.tensor(logits[:, 1])).cpu().numpy()
    preds = (probs >= 0.5).astype(int)

    for task, task_idx in task2index.items():
        print(f"Task: {task}")
        print(calculate_metrics(labels[task_idx], preds[task_idx], probs[task_idx]))
        print()

In [66]:
# Saved output below came from the "Multihead 3" run. test_model() reads
# config.test_outputs, so re-running evaluates whichever file is configured now.
test_model()  # Multihead 3

Task: mortality_1month
{'Balanced Accuracy': 0.8655722026231953, 'F1 Score': 0.7755274261603375, 'Precision': 0.8012205754141238, 'Recall': 0.7514309076042518, 'AUROC': 0.9689994603245022, 'Average Precision Score': 0.6264560624156887, 'AUC-PR': 0.8643278717781356}

Task: readmission_1month
{'Balanced Accuracy': 0.624845239226117, 'F1 Score': 0.5496580918267665, 'Precision': 0.5174739423666462, 'Recall': 0.5861111111111111, 'AUROC': 0.6787862460815047, 'Average Precision Score': 0.4609969980142222, 'AUC-PR': 0.5676588678703538}

Task: los_1week
{'Balanced Accuracy': 0.7985457953399953, 'F1 Score': 0.7134502923976608, 'Precision': 0.6906662795761359, 'Recall': 0.7377888044658086, 'AUROC': 0.8908909347887108, 'Average Precision Score': 0.5878710119255549, 'AUC-PR': 0.8054669969100448}

Task: c0
{'Balanced Accuracy': 0.7566179832371729, 'F1 Score': 0.7189705670523839, 'Precision': 0.7488497699539908, 'Recall': 0.6913842460060947, 'AUROC': 0.8520657893413595, 'Average Precision Score': 0.6

In [11]:
# Saved output below came from the "Multihead 6" run. test_model() reads
# config.test_outputs, so re-running evaluates whichever file is configured now.
test_model()  # Multihead 6

Task: mortality_1month
{'Balanced Accuracy': 0.8780054916237896, 'F1 Score': 0.7883060335890525, 'Precision': 0.7997475809844342, 'Recall': 0.7771872444807849, 'AUROC': 0.9726084128993358, 'Average Precision Score': 0.6434200928266727, 'AUC-PR': 0.8776861929897538}

Task: readmission_1month
{'Balanced Accuracy': 0.6276137155251575, 'F1 Score': 0.5369158878504673, 'Precision': 0.5419811320754717, 'Recall': 0.5319444444444444, 'AUROC': 0.6863716679596382, 'Average Precision Score': 0.4666421834766151, 'AUC-PR': 0.5772680378049946}

Task: los_1week
{'Balanced Accuracy': 0.8087185027022964, 'F1 Score': 0.7203293583191368, 'Precision': 0.664223065846315, 'Recall': 0.7867886494030082, 'AUROC': 0.8948934510936986, 'Average Precision Score': 0.5862753152097165, 'AUC-PR': 0.8089246395692433}

Task: c0
{'Balanced Accuracy': 0.7832233228926138, 'F1 Score': 0.7517542196093305, 'Precision': 0.7724836792360908, 'Recall': 0.7321082279065473, 'AUROC': 0.8719728806194893, 'Average Precision Score': 0.6

In [29]:
# Saved output below came from the "Original 2" run. test_model() reads
# config.test_outputs, so re-running evaluates whichever file is configured now.
# NOTE(review): every metric except AUC-PR equals the "Original" output further
# down, so both runs may have evaluated the same outputs with a different
# AUC-PR computation — confirm.
test_model()  # Original 2

Task: mortality_1month
{'Balanced Accuracy': 0.9046091364922854, 'F1 Score': 0.6942788074133763, 'Precision': 0.573024740622506, 'Recall': 0.8806214227309894, 'AUROC': 0.9721118612424308, 'Average Precision Score': 0.5163334778180794, 'AUC-PR': 0.8741390162810497}

Task: readmission_1month
{'Balanced Accuracy': 0.6447714796662094, 'F1 Score': 0.572943722943723, 'Precision': 0.5363728470111448, 'Recall': 0.6148664343786295, 'AUROC': 0.6995091168843479, 'Average Precision Score': 0.476031563625974, 'AUC-PR': 0.5895119974155991}

Task: los_1week
{'Balanced Accuracy': 0.8228229864371363, 'F1 Score': 0.7352697095435685, 'Precision': 0.6644169478815148, 'Recall': 0.8230376219228983, 'AUROC': 0.9082735533076256, 'Average Precision Score': 0.59976906348086, 'AUC-PR': 0.8323074699576767}

Task: c0
{'Balanced Accuracy': 0.8014407368019587, 'F1 Score': 0.7808064094748759, 'Precision': 0.7390372568414112, 'Recall': 0.8275798412405391, 'AUROC': 0.8875025728612694, 'Average Precision Score': 0.68656

In [12]:
# Original
# NOTE(review): this cell contains no code; the output below is left over from
# an earlier test_model() run and is not reproduced by re-running the cell.

Task: mortality_1month
{'Balanced Accuracy': 0.9046091364922854, 'F1 Score': 0.6942788074133763, 'Precision': 0.573024740622506, 'Recall': 0.8806214227309894, 'AUROC': 0.9721118612424308, 'Average Precision Score': 0.5163334778180794, 'AUC-PR': 0.7326808894122636}

Task: readmission_1month
{'Balanced Accuracy': 0.6447714796662094, 'F1 Score': 0.572943722943723, 'Precision': 0.5363728470111448, 'Recall': 0.6148664343786295, 'AUROC': 0.6995091168843479, 'Average Precision Score': 0.476031563625974, 'AUC-PR': 0.6487365925382458}

Task: los_1week
{'Balanced Accuracy': 0.8228229864371363, 'F1 Score': 0.7352697095435685, 'Precision': 0.6644169478815148, 'Recall': 0.8230376219228983, 'AUROC': 0.9082735533076256, 'Average Precision Score': 0.59976906348086, 'AUC-PR': 0.7701917442678005}

Task: c0
{'Balanced Accuracy': 0.8014407368019587, 'F1 Score': 0.7808064094748759, 'Precision': 0.7390372568414112, 'Recall': 0.8275798412405391, 'AUROC': 0.8875025728612694, 'Average Precision Score': 0.68656

In [22]:
# Saved output below came from the "New 3" run. test_model() reads
# config.test_outputs, so re-running evaluates whichever file is configured now.
test_model()  # New 3

Task: mortality_1month
{'Balanced Accuracy': 0.9031064952616157, 'F1 Score': 0.7067020570670206, 'Precision': 0.5946398659966499, 'Recall': 0.8708094848732625, 'AUROC': 0.972161114452456, 'Average Precision Score': 0.5304965781636694, 'AUC-PR': 0.8680130845060448}

Task: readmission_1month
{'Balanced Accuracy': 0.617183815538263, 'F1 Score': 0.4710590252965557, 'Precision': 0.612184249628529, 'Recall': 0.3828106852497096, 'AUROC': 0.6963517269590036, 'Average Precision Score': 0.4686953537010699, 'AUC-PR': 0.583781093737475}

Task: los_1week
{'Balanced Accuracy': 0.8264205588789526, 'F1 Score': 0.737919737919738, 'Precision': 0.6598315635298425, 'Recall': 0.8369716674407803, 'AUROC': 0.9093339907980309, 'Average Precision Score': 0.6010216112926607, 'AUC-PR': 0.831706725413936}

Task: c0
{'Balanced Accuracy': 0.7992014533014051, 'F1 Score': 0.7732314631564413, 'Precision': 0.7711374277058661, 'Recall': 0.7753369023444711, 'AUROC': 0.8856379346465069, 'Average Precision Score': 0.695548

In [17]:
# Saved output below came from the "New 5" run. test_model() reads
# config.test_outputs, so re-running evaluates whichever file is configured now.
test_model()  # New 5

Task: mortality_1month
{'Balanced Accuracy': 0.8622720192950359, 'F1 Score': 0.7799742157284056, 'Precision': 0.8220108695652174, 'Recall': 0.7420278004905969, 'AUROC': 0.9702276440443593, 'Average Precision Score': 0.6352718810920918, 'AUC-PR': 0.8674110558595907}

Task: readmission_1month
{'Balanced Accuracy': 0.6262484101727832, 'F1 Score': 0.5413209048539275, 'Precision': 0.5318293500111682, 'Recall': 0.5511574074074074, 'AUROC': 0.68287825353852, 'Average Precision Score': 0.4641395019286542, 'AUC-PR': 0.5680185975838656}

Task: los_1week
{'Balanced Accuracy': 0.799364691960775, 'F1 Score': 0.7138704628456436, 'Precision': 0.687284730195178, 'Recall': 0.7425957512792681, 'AUROC': 0.8938146814638854, 'Average Precision Score': 0.5872443663133886, 'AUC-PR': 0.808685370471159}

Task: c0
{'Balanced Accuracy': 0.7842020547679327, 'F1 Score': 0.7565109549400578, 'Precision': 0.7526046426613051, 'Recall': 0.7604580293655924, 'AUROC': 0.871564490669245, 'Average Precision Score': 0.676400

In [13]:
# Saved output below came from the "New 4" run. test_model() reads
# config.test_outputs, so re-running evaluates whichever file is configured now.
test_model()  # New 4

Task: mortality_1month
{'Balanced Accuracy': 0.8753458000943887, 'F1 Score': 0.7728013029315961, 'Precision': 0.7696674776966748, 'Recall': 0.7759607522485691, 'AUROC': 0.9705738897457072, 'Average Precision Score': 0.6192185949683511, 'AUC-PR': 0.8678600307736943}

Task: readmission_1month
{'Balanced Accuracy': 0.6246481600223763, 'F1 Score': 0.5342481683916734, 'Precision': 0.5368076653423697, 'Recall': 0.531712962962963, 'AUROC': 0.680164563608921, 'Average Precision Score': 0.4638541245326711, 'AUC-PR': 0.5693479032442244}

Task: los_1week
{'Balanced Accuracy': 0.7979240340522251, 'F1 Score': 0.7127627627627627, 'Precision': 0.6908746907291515, 'Recall': 0.7360831136610327, 'AUROC': 0.8915162383799904, 'Average Precision Score': 0.587355733904388, 'AUC-PR': 0.804778620144726}

Task: c0
{'Balanced Accuracy': 0.7675343469447702, 'F1 Score': 0.7329269454233737, 'Precision': 0.7567115743927623, 'Recall': 0.710591929079324, 'AUROC': 0.8575170903900623, 'Average Precision Score': 0.66345

In [9]:
# Saved output below came from the "New 6" run. test_model() reads
# config.test_outputs, so re-running evaluates whichever file is configured now.
test_model()  # New 6

Task: mortality_1month
{'Balanced Accuracy': 0.8690388246976402, 'F1 Score': 0.7823516993877981, 'Precision': 0.8088171104321257, 'Recall': 0.7575633687653311, 'AUROC': 0.9719834173099352, 'Average Precision Score': 0.6365225435731867, 'AUC-PR': 0.8732467381784418}

Task: readmission_1month
{'Balanced Accuracy': 0.6299158315126183, 'F1 Score': 0.5478495834271561, 'Precision': 0.5333187198597107, 'Recall': 0.5631944444444444, 'AUROC': 0.6873234698077959, 'Average Precision Score': 0.46679360953816706, 'AUC-PR': 0.5735717003134541}

Task: los_1week
{'Balanced Accuracy': 0.8025800848789157, 'F1 Score': 0.7169839161871034, 'Precision': 0.6838845883180859, 'Recall': 0.7534501473096604, 'AUROC': 0.8937000531366002, 'Average Precision Score': 0.5889010984765001, 'AUC-PR': 0.8078436780997615}

Task: c0
{'Balanced Accuracy': 0.7889936829917931, 'F1 Score': 0.7573392078923423, 'Precision': 0.7866096299243932, 'Recall': 0.7301689906731924, 'AUROC': 0.8796977562547927, 'Average Precision Score': 0

In [8]:
# Saved output below came from the "Multibird 4" run. test_model() reads
# config.test_outputs, so re-running evaluates whichever file is configured now.
# NOTE(review): the task order in this output (los_1week before
# readmission_1month) differs from the current config.tasks, so it was produced
# with a different configuration.
test_model()  # Multibird 4

Task: mortality_1month
{'Balanced Accuracy': 0.878934427535469, 'F1 Score': 0.7252909661924996, 'Precision': 0.6616110549376475, 'Recall': 0.8025347506132461, 'AUROC': 0.9646008285597611, 'Average Precision Score': 0.5503447748695037, 'AUC-PR': 0.8353257173010042}

Task: los_1week
{'Balanced Accuracy': 0.7773360432023431, 'F1 Score': 0.6808632680863268, 'Precision': 0.6332355294745265, 'Recall': 0.7362381764614669, 'AUROC': 0.8699980601281244, 'Average Precision Score': 0.5449804048782951, 'AUC-PR': 0.7670872175086823}

Task: readmission_1month
{'Balanced Accuracy': 0.6258331420263239, 'F1 Score': 0.5516272029408584, 'Precision': 0.5175491986204098, 'Recall': 0.5905092592592592, 'AUROC': 0.679461807204754, 'Average Precision Score': 0.461641584029276, 'AUC-PR': 0.5637195916045253}

Task: c0
{'Balanced Accuracy': 0.7902966398445754, 'F1 Score': 0.7589046342397548, 'Precision': 0.7879510885773934, 'Recall': 0.7319235386462277, 'AUROC': 0.8802952315807759, 'Average Precision Score': 0.693

In [7]:
# Saved output below came from the "Multibird 5" run. test_model() reads
# config.test_outputs, so re-running evaluates whichever file is configured now.
test_model()  # Multibird 5

Task: mortality_1month
{'Balanced Accuracy': 0.8307360510289447, 'F1 Score': 0.7477946166025786, 'Precision': 0.8369620253164557, 'Recall': 0.6757972199509403, 'AUROC': 0.9617723793090831, 'Average Precision Score': 0.597433332750795, 'AUC-PR': 0.8299543351752259}

Task: readmission_1month
{'Balanced Accuracy': 0.6081461823037059, 'F1 Score': 0.47329276538201487, 'Precision': 0.5691056910569106, 'Recall': 0.4050925925925926, 'AUROC': 0.6729919360797103, 'Average Precision Score': 0.4572118704615439, 'AUC-PR': 0.5584318732316119}

Task: los_1week
{'Balanced Accuracy': 0.7567032514271541, 'F1 Score': 0.6653495962892974, 'Precision': 0.7458116695551704, 'Recall': 0.6005582260815631, 'AUROC': 0.8733449329098977, 'Average Precision Score': 0.5671902052200841, 'AUC-PR': 0.772768315710113}

Task: c0
{'Balanced Accuracy': 0.7907783355108081, 'F1 Score': 0.7676427590484207, 'Precision': 0.7377384196185286, 'Recall': 0.8000738757041278, 'AUROC': 0.878586426793819, 'Average Precision Score': 0.67

In [6]:
# Saved output below came from the "Multibird 7" run. test_model() reads
# config.test_outputs, so re-running evaluates whichever file is configured now.
test_model()  # Multibird 7

Task: mortality_1month
{'Balanced Accuracy': 0.8715696903457233, 'F1 Score': 0.7514382067050188, 'Precision': 0.7298651252408478, 'Recall': 0.7743254292722813, 'AUROC': 0.9675795019198202, 'Average Precision Score': 0.5873004542897322, 'AUC-PR': 0.8468842847116425}

Task: readmission_1month
{'Balanced Accuracy': 0.6297042063793631, 'F1 Score': 0.5435356200527705, 'Precision': 0.5387764384807824, 'Recall': 0.5483796296296296, 'AUROC': 0.6840408229367869, 'Average Precision Score': 0.467530227703297, 'AUC-PR': 0.5705826681588523}

Task: los_1week
{'Balanced Accuracy': 0.781210473498072, 'F1 Score': 0.6900330231161813, 'Precision': 0.6686545454545455, 'Recall': 0.7128236935959064, 'AUROC': 0.8750743499816753, 'Average Precision Score': 0.5623933955603996, 'AUC-PR': 0.7840703889371606}

Task: c0
{'Balanced Accuracy': 0.7974596742399855, 'F1 Score': 0.7718961418906565, 'Precision': 0.7642798949941161, 'Recall': 0.7796657124388217, 'AUROC': 0.8842594008160198, 'Average Precision Score': 0.69

In [9]:
# Saved output below came from the "Multibird 8" run. test_model() reads
# config.test_outputs, so re-running evaluates whichever file is configured now.
# NOTE(review): this output is identical, digit for digit, to the "Multibird 7"
# output above — confirm config.test_outputs was actually switched before running.
test_model()  # Multibird 8

Task: mortality_1month
{'Balanced Accuracy': 0.8715696903457233, 'F1 Score': 0.7514382067050188, 'Precision': 0.7298651252408478, 'Recall': 0.7743254292722813, 'AUROC': 0.9675795019198202, 'Average Precision Score': 0.5873004542897322, 'AUC-PR': 0.8468842847116425}

Task: readmission_1month
{'Balanced Accuracy': 0.6297042063793631, 'F1 Score': 0.5435356200527705, 'Precision': 0.5387764384807824, 'Recall': 0.5483796296296296, 'AUROC': 0.6840408229367869, 'Average Precision Score': 0.467530227703297, 'AUC-PR': 0.5705826681588523}

Task: los_1week
{'Balanced Accuracy': 0.781210473498072, 'F1 Score': 0.6900330231161813, 'Precision': 0.6686545454545455, 'Recall': 0.7128236935959064, 'AUROC': 0.8750743499816753, 'Average Precision Score': 0.5623933955603996, 'AUC-PR': 0.7840703889371606}

Task: c0
{'Balanced Accuracy': 0.7974596742399855, 'F1 Score': 0.7718961418906565, 'Precision': 0.7642798949941161, 'Recall': 0.7796657124388217, 'AUROC': 0.8842594008160198, 'Average Precision Score': 0.69

In [None]:
# Disabled scratch code: loads a saved fine-tuning checkpoint and inspects the
# shape of its word-embedding weight.
# model = torch.load('checkpoints/bigbird_finetune_with_condition/mortality_1month_20000_patients/best.ckpt')
# model['state_dict']['model.bert.embeddings.word_embeddings.weight'].shape