In [81]:
import os

import torch

from sklearn.metrics import (
    auc,
    average_precision_score,
    balanced_accuracy_score,
    f1_score,
    precision_recall_curve,
    precision_score,
    recall_score,
    roc_auc_score,
)


ROOT = "/fs01/home/afallah/odyssey/odyssey"
os.chdir(ROOT)

from odyssey.data.dataset import FinetuneMultiDataset, FinetuneDatasetDecoder
from odyssey.data.tokenizer import ConceptTokenizer
from odyssey.models.model_utils import load_finetune_data, load_pretrain_data

In [82]:
class config:
    """Save the configuration arguments."""

    # test_outputs = "checkpoints/mamba_finetune/test_outputs/test_outputs_f8471ffd.pt"
    test_outputs = "checkpoints/mamba_finetune_with_embeddings/test_outputs/test_outputs_f521131a.pt"
    vocab_dir = "odyssey/data/vocab"
    data_dir = "odyssey/data/bigbird_data"
    sequence_file = "patient_sequences_2048_multi.parquet"
    id_file = "dataset_2048_multi.pkl"
    valid_scheme = "few_shot"
    num_finetune_patients = "all"
    # label_name = "label_mortality_1month"
    tasks = ['mortality_1month', 'readmission_1month', 'los_1week', 'c0', 'c1', 'c2']

    max_len = 2048
    batch_size = 1
    device = torch.device("cpu")

In [83]:
def calculate_metrics(y_true, y_pred, y_prob):
    """
    Calculate and return performance metrics.
    """
    metrics = {
        "Balanced Accuracy": balanced_accuracy_score(y_true, y_pred),
        "F1 Score": f1_score(y_true, y_pred),
        "Precision": precision_score(y_true, y_pred),
        "Recall": recall_score(y_true, y_pred),
        "AUROC": roc_auc_score(y_true, y_prob),
        "Average Precision Score": average_precision_score(y_true, y_pred),
    }

    precision, recall, _ = precision_recall_curve(y_true, y_pred)
    metrics["AUC-PR"] = auc(recall, precision)

    return metrics

In [84]:
# Load tokenizer
tokenizer = ConceptTokenizer(data_dir=config.vocab_dir)
tokenizer.fit_on_vocab()

# Load pretrain data
# pretrain = load_pretrain_data(
#     config.data_dir,
#     'patient_sequences/'+config.sequence_file,
#     'patient_id_dict/'+config.id_file,
# )

# Load test data
fine_tune, fine_test = load_finetune_data(
    config.data_dir,
    config.sequence_file,
    config.id_file,
    config.valid_scheme,
    config.num_finetune_patients,
)

test_dataset = FinetuneDatasetDecoder(
    data=fine_test,
    tokenizer=tokenizer,
    tasks=config.tasks,
    balance_guide=None,
    max_len=config.max_len,
)

# train_dataset = FinetuneMultiDataset(
#             data=fine_tune,
#             tokenizer=tokenizer,
#             tasks=['mortality_1month', 'readmission_1month', 'los_1week', 'c0', 'c1', 'c2'],
#             balance_guide={'mortality_1month':0.5, 'readmission_1month':0.5, 'los_1week':0.5, 'c0':0.5, 'c1':0.5, 'c2':0.5},
#             max_len=config.max_len,
# )

In [None]:
"""
Dataset Specifications
------------------------

Current Approach:
    - Pretrain: 141234 Patients
    - Test: 24924 Patients, 132682 Datapoints
    - Finetune: 139514 Unique Patients, 434270 Datapoints
        - Mortality: 26962 Patients
        - Readmission: 48898 Patients
        - Length of Stay: 72686 Patients
        - Condition 0: 122722 Patients
        - Condition 1: 94048 Patients
        - Condition 2: 68954 Patients

New Approach:
    - Pretrain: . Patients
    - Test: . Patients, . Datapoints
    - Finetune: . Unique Patients, . Datapoints
        - Mortality: . Patients
        - Readmission: . Patients
        - Length of Stay: . Patients
        - Condition 0: . Patients
        - Condition 1: . Patients
        - Condition 2: . Patients
"""

In [5]:
# Tasks we have are: tasks = ['mortality_1month', 'readmission_1month', 'los_1week', 'c0', 'c1', 'c2']

tasks = [test_dataset.index_mapper[i][1] for i in range(len(test_dataset))]

task2index = {task:[] for task in config.tasks}

for i, task in enumerate(tasks):
    task2index[task].append(i)

In [9]:
test_outputs = torch.load(config.test_outputs, map_location=config.device)
test_outputs.keys()

labels = test_outputs["labels"].cpu().numpy()
logits = test_outputs["logits"].cpu().numpy()
probs = torch.sigmoid(torch.tensor(logits[:, 1])).cpu().numpy()
preds = (probs >= 0.5).astype(int)

for task, task_idx in task2index.items():
    print(f'Task: {task}')
    print(calculate_metrics(labels[task_idx], preds[task_idx], probs[task_idx]))
    print()

Task: mortality_1month
{'Balanced Accuracy': 0.9046091364922854, 'F1 Score': 0.6942788074133763, 'Precision': 0.573024740622506, 'Recall': 0.8806214227309894, 'AUROC': 0.9721118612424308, 'Average Precision Score': 0.5163334778180794, 'AUC-PR': 0.7326808894122636}

Task: readmission_1month
{'Balanced Accuracy': 0.6449577401436226, 'F1 Score': 0.5728762397585166, 'Precision': 0.5361178369652946, 'Recall': 0.6150462962962963, 'AUROC': 0.6999407465024068, 'Average Precision Score': 0.47572974018385267, 'AUC-PR': 0.6485782917207787}

Task: los_1week
{'Balanced Accuracy': 0.8231965165355855, 'F1 Score': 0.7355366174738447, 'Precision': 0.6648296593186372, 'Recall': 0.8230733447046054, 'AUROC': 0.9084304816895977, 'Average Precision Score': 0.6000398760543761, 'AUC-PR': 0.7703696543617022}

Task: c0
{'Balanced Accuracy': 0.8015438292543531, 'F1 Score': 0.7808362369337979, 'Precision': 0.738933311351084, 'Recall': 0.8277772647520547, 'AUROC': 0.8877116519104709, 'Average Precision Score': 0.6

In [None]:
labels.sum()

In [None]:
preds

In [None]:
# model = torch.load('checkpoints/bigbird_finetune_with_condition/mortality_1month_20000_patients/best.ckpt')
# model['state_dict']['model.bert.embeddings.word_embeddings.weight'].shape