In [None]:
import os

import torch
from sklearn.metrics import (
    auc,
    average_precision_score,
    balanced_accuracy_score,
    f1_score,
    precision_recall_curve,
    precision_score,
    recall_score,
    roc_auc_score,
)
from transformers import utils

from tqdm import tqdm

ROOT = "/fs01/home/afallah/odyssey/odyssey"
os.chdir(ROOT)

from odyssey.data.tokenizer import ConceptTokenizer
from odyssey.models.model_utils import load_finetune_data
from odyssey.data.dataset import FinetuneMultiDataset

In [2]:
class config:
    """Save the configuration arguments."""

    model_path = "checkpoints/multibird_finetune/multibird_finetune/test_outputs/test_outputs_1ec842db.pt"
    vocab_dir = "odyssey/data/vocab"
    data_dir = "odyssey/data/bigbird_data"
    sequence_file = "patient_sequences/patient_sequences_2048_multi.parquet"
    id_file = "patient_id_dict/dataset_2048_multi.pkl"
    valid_scheme = "few_shot"
    num_finetune_patients = "all"
    # label_name = "label_mortality_1month"
    tasks = ['mortality_1month', 'los_1week', 'c0', 'c1', 'c2']

    max_len = 2048
    batch_size = 1
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
def calculate_metrics(y_true, y_pred, y_prob):
    """
    Calculate and return performance metrics.
    """
    metrics = {
        "Balanced Accuracy": balanced_accuracy_score(y_true, y_pred),
        "F1 Score": f1_score(y_true, y_pred),
        "Precision": precision_score(y_true, y_pred),
        "Recall": recall_score(y_true, y_pred),
        "AUROC": roc_auc_score(y_true, y_prob),
        "Average Precision Score": average_precision_score(y_true, y_pred),
    }

    precision, recall, _ = precision_recall_curve(y_true, y_pred)
    metrics["AUC-PR"] = auc(recall, precision)

    return metrics

In [4]:
# Load tokenizer
tokenizer = ConceptTokenizer(data_dir=config.vocab_dir)
tokenizer.fit_on_vocab()

# Load test data
fine_tune, fine_test = load_finetune_data(
    config.data_dir,
    config.sequence_file,
    config.id_file,
    config.valid_scheme,
    config.num_finetune_patients,
)

test_dataset = FinetuneMultiDataset(
            data=fine_test,
            tokenizer=tokenizer,
            tasks=config.tasks,
            balance_guide=None,
            max_len=config.max_len,
)

# train_dataset = FinetuneMultiDataset(
#             data=fine_tune,
#             tokenizer=tokenizer,
#             tasks=['los_1week'],
#             balance_guide={'los_1week': 0.5},
#             max_len=config.max_len,
# )

tasks = [test_dataset.index_mapper[i][1] for i in range(len(test_dataset))]

FileNotFoundError: Sequence file not found: odyssey/data/bigbird_data/patient_sequences/patient_sequences/patient_sequences_2048_multi.parquet

In [5]:
tokenizer.get_vocab_size()

20600

In [25]:
test_outputs = torch.load(config.model_path, map_location=config.device)
test_outputs.keys()

dict_keys(['labels', 'logits'])

In [26]:
labels = test_outputs['labels'].cpu().numpy()
logits = test_outputs['logits'].cpu().numpy()
probs = torch.sigmoid(torch.tensor(logits[:, 1])).cpu().numpy()
preds = (probs >= 0.5).astype(int)

preds

array([0, 0, 0, ..., 0, 0, 0])

In [27]:
# Tasks we have are: tasks = ['mortality_1month', 'los_1week', 'c0', 'c1', 'c2']

task_idx = []
for i, task in enumerate(tasks):
    if task == 'los_1week':
        task_idx.append(i)


calculate_metrics(labels[task_idx], preds[task_idx], probs[task_idx])

{'Balanced Accuracy': 0.8043450887917278,
 'F1 Score': 0.7121043864519712,
 'Precision': 0.6445533358462119,
 'Recall': 0.7954721662273221,
 'AUROC': 0.8889347846976665,
 'Average Precision Score': 0.5738031917760823,
 'AUC-PR': 0.7505522277674918}

In [24]:
old_test_outputs = torch.load('test_outputs.pt')
target = 2

labels = old_test_outputs["labels"][:, target]
logits = torch.sigmoid(old_test_outputs["logits"][:, target])
preds = (logits >= 0.5).int()
calculate_metrics(labels, preds, logits)

{'Balanced Accuracy': 0.8220779574282759,
 'F1 Score': 0.7388243277631691,
 'Precision': 0.7633654688869412,
 'Recall': 0.7158119658119658,
 'AUROC': 0.9255784114401215,
 'Average Precision Score': 0.6157970244149283,
 'AUC-PR': 0.7742741610984505}

In [68]:
preds.sum()

tensor(0)

In [69]:
labels.sum()

tensor(34., dtype=torch.float64)

In [9]:
preds

tensor([0, 0, 0,  ..., 0, 0, 0], dtype=torch.int32)

In [None]:
# model = torch.load('checkpoints/bigbird_finetune_with_condition/mortality_1month_20000_patients/best.ckpt')
# model['state_dict']['model.bert.embeddings.word_embeddings.weight'].shape