In [1]:
import os
import sys
from typing import Any, Dict

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import plotly.figure_factory as ff
import plotly.graph_objects as go
import seaborn as sns

import pytorch_lightning as pl
import torch
from torch.utils.data import Subset
from lightning.pytorch.loggers import WandbLogger
from pytorch_lightning.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
)
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.strategies import DeepSpeedStrategy
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torch.utils.data import Subset

from bertviz.transformers_neuron_view import BertModel, BertTokenizer
from bertviz.neuron_view import show
from transformers import AutoTokenizer, AutoModel, utils
from bertviz import model_view, head_view

from sklearn.metrics import (
    auc,
    average_precision_score,
    balanced_accuracy_score,
    f1_score,
    precision_recall_curve,
    precision_score,
    recall_score,
    roc_auc_score,
    roc_curve,
)

utils.logging.set_verbosity_error()  # Suppress standard warnings


ROOT = "/fs01/home/afallah/odyssey/odyssey"
os.chdir(ROOT)

from lib.data import FinetuneDataset
from lib.tokenizer import ConceptTokenizer
from lib.utils import (
    get_run_id,
    load_config,
    load_finetune_data,
    seed_everything,
)
from lib.prediction import (
    load_finetuned_model,
    predict_patient_outcomes
)
from models.big_bird_cehr.model import BigBirdFinetune, BigBirdPretrain
from models.cehr_bert.model import BertFinetune, BertPretrain

[2024-04-10 12:13:14,754] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [2]:
class config:
    """ Save the configuration arguments. """
    model_path = 'test_epoch_end.ckpt'
    vocab_dir = 'data/vocab'
    data_dir = 'data/bigbird_data'
    sequence_file = 'patient_sequences/patient_sequences_2048_mortality.parquet'
    id_file = 'patient_id_dict/dataset_2048_mortality_1month.pkl'
    valid_scheme = 'few_shot'
    num_finetune_patients = '20000'
    label_name = 'label_mortality_1month'
    
    max_len = 2048
    batch_size = 1
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
tokenizer = ConceptTokenizer(data_dir=config.vocab_dir)
tokenizer.fit_on_vocab()

In [4]:
model = torch.load(config.model_path, map_location=config.device)
model.keys()

dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers', 'MixedPrecision'])

In [5]:
test_outputs = torch.load('test_outputs.pt')
test_outputs

{'loss': tensor(0.1638, dtype=torch.float64),
 'preds': tensor([6, 7, 0,  ..., 7, 7, 7]),
 'labels': tensor([[1., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]], dtype=torch.float64),
 'logits': tensor([[ 2.3418, -2.0781,  0.2194,  ..., -5.3945, -7.1797, -3.4180],
         [-1.6533, -3.4277, -6.8086,  ..., -6.8359, -5.2266, -5.6484],
         [ 1.0947, -3.7930, -6.1094,  ..., -6.3867, -6.6836, -5.5508],
         ...,
         [-2.8223, -3.7285, -4.6797,  ..., -7.9922, -5.6992, -6.7812],
         [-3.7148, -5.6328, -6.7188,  ..., -9.4062, -7.6445, -7.6016],
         [-2.5840, -2.2871, -4.6484,  ..., -7.8633, -4.7539, -6.4648]],
        dtype=torch.float16)}

In [67]:
def calculate_metrics(y_true, y_pred, y_prob):
    """
    Calculate and return performance metrics.
    """
    metrics = {
        "Balanced Accuracy": balanced_accuracy_score(y_true, y_pred),
        "F1 Score": f1_score(y_true, y_pred),
        "Precision": precision_score(y_true, y_pred),
        "Recall": recall_score(y_true, y_pred),
        "AUROC": roc_auc_score(y_true, y_prob),
        "Average Precision Score": average_precision_score(y_true, y_pred)
    }
    
    precision, recall, _ = precision_recall_curve(y_true, y_pred)
    metrics["AUC-PR"] = auc(recall, precision)
    
    return metrics

targets = [10]

for i in targets:
    labels = test_outputs['labels'][:, i]
    logits = torch.sigmoid(test_outputs['logits'][:, i])
    preds = (logits >= 0.5).int()

    print(calculate_metrics(labels, preds, logits))

{'Balanced Accuracy': 0.5, 'F1 Score': 0.0, 'Precision': 0.0, 'Recall': 0.0, 'AUROC': 0.8100258785715974, 'Average Precision Score': 0.001364147006900979, 'AUC-PR': 0.5006820735034505}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [68]:
preds.sum()

tensor(0)

In [69]:
labels.sum()

tensor(34., dtype=torch.float64)

In [9]:
preds

tensor([0, 0, 0,  ..., 0, 0, 0], dtype=torch.int32)