In [None]:
import os

import torch
from sklearn.metrics import (
    auc,
    average_precision_score,
    balanced_accuracy_score,
    f1_score,
    precision_recall_curve,
    precision_score,
    recall_score,
    roc_auc_score,
)


# Repository root. Every relative path in `config` (and "test_outputs.pt" below)
# resolves against it because the working directory is switched here.
# The chdir happens before the project imports below, presumably so the local
# `odyssey` package is importable from the working directory — confirm.
# Set the ODYSSEY_ROOT environment variable to point at a different checkout
# instead of editing this hardcoded default.
ROOT = os.environ.get("ODYSSEY_ROOT", "/fs01/home/afallah/odyssey/odyssey")
os.chdir(ROOT)

from odyssey.data.dataset import FinetuneMultiDataset
from odyssey.data.tokenizer import ConceptTokenizer
from odyssey.models.model_utils import load_finetune_data

In [None]:
class config:
    """Static evaluation settings shared by every cell in this notebook.

    Paths are relative to the repository root selected with ``os.chdir`` above.
    """

    # --- Artifacts -----------------------------------------------------------
    # Saved evaluation outputs (loaded with torch.load further down).
    model_path = "checkpoints/multibird_finetune/multibird_finetune/test_outputs/test_outputs_1ec842db.pt"
    vocab_dir = "odyssey/data/vocab"

    # --- Data split ----------------------------------------------------------
    data_dir = "odyssey/data/bigbird_data"
    sequence_file = "patient_sequences/patient_sequences_2048_multi.parquet"
    id_file = "patient_id_dict/dataset_2048_multi.pkl"
    valid_scheme = "few_shot"
    num_finetune_patients = "all"

    # --- Tasks ---------------------------------------------------------------
    # Previously a single label ("label_mortality_1month") was evaluated.
    tasks = ["mortality_1month", "los_1week", "c0", "c1", "c2"]

    # --- Runtime -------------------------------------------------------------
    max_len = 2048
    batch_size = 1
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
def calculate_metrics(y_true, y_pred, y_prob):
    """
    Calculate and return binary-classification performance metrics.

    Threshold-dependent metrics (balanced accuracy, F1, precision, recall) are
    computed from the hard predictions ``y_pred``. Ranking metrics (AUROC,
    average precision, AUC-PR) are computed from the continuous scores
    ``y_prob`` so that they sweep over all thresholds.

    Args:
        y_true: Ground-truth binary labels, array-like of shape (n_samples,).
        y_pred: Hard 0/1 predictions, array-like of shape (n_samples,).
        y_prob: Positive-class scores/probabilities, array-like of shape
            (n_samples,).

    Returns:
        dict mapping metric name to its value, with keys
        "Balanced Accuracy", "F1 Score", "Precision", "Recall", "AUROC",
        "Average Precision Score" and "AUC-PR" (in that order).
    """
    # The PR curve must be built from scores, not 0/1 predictions. With hard
    # predictions the curve has a single operating point, and AUC-PR becomes a
    # trapezoid through that point rather than the true area.
    curve_precision, curve_recall, _ = precision_recall_curve(y_true, y_prob)

    return {
        "Balanced Accuracy": balanced_accuracy_score(y_true, y_pred),
        "F1 Score": f1_score(y_true, y_pred),
        "Precision": precision_score(y_true, y_pred),
        "Recall": recall_score(y_true, y_pred),
        "AUROC": roc_auc_score(y_true, y_prob),
        # Like AUROC, average precision is a ranking metric and needs scores.
        "Average Precision Score": average_precision_score(y_true, y_prob),
        "AUC-PR": auc(curve_recall, curve_precision),
    }

In [None]:
# Build the concept tokenizer from the saved vocabulary directory.
tokenizer = ConceptTokenizer(data_dir=config.vocab_dir)
tokenizer.fit_on_vocab()

# Split the patient sequences; only the held-out split (`fine_test`) is used below.
fine_tune, fine_test = load_finetune_data(
    config.data_dir,
    config.sequence_file,
    config.id_file,
    config.valid_scheme,
    config.num_finetune_patients,
)

# Multi-task evaluation dataset over every configured task, without rebalancing.
test_dataset = FinetuneMultiDataset(
    data=fine_test,
    tokenizer=tokenizer,
    tasks=config.tasks,
    balance_guide=None,
    max_len=config.max_len,
)

# Task name attached to each dataset index (second element of each index_mapper entry).
tasks = []
for sample_idx in range(len(test_dataset)):
    tasks.append(test_dataset.index_mapper[sample_idx][1])

In [None]:
# Display the tokenizer's vocabulary size, e.g. to compare against the word
# embedding shape of the checkpoint that produced the outputs.
tokenizer.get_vocab_size()

In [None]:
# Load the saved evaluation outputs; the "labels" and "logits" entries are used
# in the next cell. Display the available keys.
# NOTE(review): torch.load unpickles the file — only load trusted artifacts.
test_outputs = torch.load(config.model_path, map_location=config.device)
test_outputs.keys()

In [None]:
# Bring the saved tensors to host memory as NumPy arrays.
labels = test_outputs["labels"].cpu().numpy()
logits = test_outputs["logits"].cpu().numpy()

# Positive-class score = sigmoid of logit column 1, thresholded at 0.5.
# NOTE(review): if the head emits 2-class logits trained with softmax/CE, the
# calibrated probability is softmax(logits)[:, 1] (= sigmoid(l1 - l0)), not
# sigmoid(l1) — confirm against the model's loss.
positive_logit = torch.tensor(logits[:, 1])
probs = torch.sigmoid(positive_logit).cpu().numpy()
preds = (probs >= 0.5).astype(int)

preds

In [None]:
# Evaluate a single task: keep only rows whose dataset example belongs to it.
# Available tasks: 'mortality_1month', 'los_1week', 'c0', 'c1', 'c2'.
# NOTE(review): assumes test_outputs rows are in the same order as
# test_dataset indices (i.e. an unshuffled test loader) — confirm.
task_idx = [row for row, task_name in enumerate(tasks) if task_name == "los_1week"]

calculate_metrics(labels[task_idx], preds[task_idx], probs[task_idx])

In [None]:
# Compare against an older multi-label run whose outputs hold one column per task.
# map_location="cpu": the file may have been saved from GPU (loading would fail
# on a CPU-only host), and sklearn needs host-side tensors anyway.
# NOTE(review): torch.load unpickles the file — only load trusted artifacts.
old_test_outputs = torch.load("test_outputs.pt", map_location="cpu")

# Column index of the task to evaluate.
# NOTE(review): presumably follows the config.tasks order (2 -> "c0") — confirm.
target = 2

labels = old_test_outputs["labels"][:, target]
# Sigmoid turns the raw logits into per-task probabilities. They are named
# accordingly, not `logits`.
old_probs = torch.sigmoid(old_test_outputs["logits"][:, target])
preds = (old_probs >= 0.5).int()
calculate_metrics(labels, preds, old_probs)

In [None]:
# Number of positive predictions produced by the most recently executed cell.
preds.sum()

In [None]:
# Number of positive ground-truth labels, to compare with the prediction count.
labels.sum()

In [None]:
# Display the raw prediction vector (consider slicing, e.g. preds[:20], for large sets).
preds

In [None]:
# model = torch.load('checkpoints/bigbird_finetune_with_condition/mortality_1month_20000_patients/best.ckpt')
# model['state_dict']['model.bert.embeddings.word_embeddings.weight'].shape