In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import seaborn as sns
import torch as ch
from sklearn.metrics import matthews_corrcoef, average_precision_score, roc_auc_score, accuracy_score
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

from prediction.outcome_prediction.Transformer.architecture import OPSUMTransformer
from prediction.outcome_prediction.Transformer.lightning_wrapper import LitEncoderRegressionModel
from prediction.outcome_prediction.Transformer.utils.utils import DictLogger

In [None]:
# Input/output locations for this evaluation.
# Every default below can be overridden through an environment variable, so the notebook can be
# run on another machine without editing this cell.

# Pickled cross-validation splits (loaded with torch.load further down).
# Alternative local copy used previously:
#   /Users/jk1/Downloads/train_data_splits_early_neurological_deterioration_ts0.8_rs42_ns5.pth
data_path = os.environ.get(
    'OPSUM_TTE_DATA_PATH',
    '/Users/jk1/temp/opsum_end/preprocessing/gsu_Extraction_20220815_prepro_08062024_083500/early_neurological_deterioration_train_data_splits/train_data_splits_early_neurological_deterioration_ts0.8_rs42_ns5.pth',
)

# Lightning checkpoint of the trained model.
model_path = os.environ.get(
    'OPSUM_TTE_MODEL_PATH',
    '/Users/jk1/temp/opsum_end/training/hyperopt/tte_gridsearch/best_05_03_2025/checkpoints_short_opsum_transformer_tte_20250109_235614_cv_1/short_opsum_transformer_tte_epoch=07_val_mae=3.7831.ckpt',
)

# CSV whose first row holds the hyperparameters of that model.
model_hyperparams_path = os.environ.get(
    'OPSUM_TTE_MODEL_HYPERPARAMS_PATH',
    '/Users/jk1/temp/opsum_end/training/hyperopt/tte_gridsearch/best_05_03_2025/checkpoints_short_opsum_transformer_tte_20250109_235614_cv_1/tte_end_transformer_best_hyperparameters.csv',
)

# Previously saved predictions. Set to '' (empty string) to recompute them with the model instead.
predictions_path = os.environ.get(
    'OPSUM_TTE_PREDICTIONS_PATH',
    '/Users/jk1/Downloads/tte_validation_evaluation_results_6h/predictions.pt',
)

In [None]:
# Run-level settings used by the cells below.
# use_gpu selects the Lightning accelerator and whether input tensors are moved to CUDA.
use_gpu = False
# Number of time steps for which predictions and labels are produced.
n_time_steps = 72
# Lead time (in time steps) before the event: used both to end the label series and as the
# threshold under which a prediction counts as positive.
eval_n_time_steps_before_event = 6

In [None]:
# Read the saved hyperparameters; the first CSV row becomes the configuration dictionary.
hyperparameter_records = pd.read_csv(model_hyperparams_path).to_dict(orient='records')
model_config = hyperparameter_records[0]

In [None]:
# Display the loaded hyperparameter dictionary for inspection.
model_config

In [None]:
# Load the pickled cross-validation splits.
# The later cells treat the contents as NumPy arrays and a pandas DataFrame; if so, they are
# arbitrary pickled objects rather than plain tensors, and weights_only=False is required on
# PyTorch versions where torch.load defaults to weights_only=True.
# SECURITY: this unpickles arbitrary objects - only load files from a trusted source.
splits = ch.load(data_path, weights_only=False)

In [None]:
# Select the cross-validation fold the model was chosen on.
# The fold index is read from a CSV, so it may arrive as a float (e.g. 1.0); cast it to int like
# the other integer hyperparameters before using it as an index.
best_cv_fold = int(model_config['best_cv_fold'])
full_X_train, full_X_val, y_train, y_val = splits[best_cv_fold]

In [None]:
# Prepare the model inputs: keep only the last entry of the 4th axis and cast to float32.
X_train = full_X_train[:, :, :, -1].astype('float32')
X_val = full_X_val[:, :, :, -1].astype('float32')

# Standardise each feature (last axis). The scaler is fitted on the training data only and the
# same transformation is then applied to the validation data.
n_features = X_train.shape[-1]
scaler = StandardScaler()

X_train_flat = X_train.reshape(-1, n_features)
X_train = scaler.fit_transform(X_train_flat).reshape(X_train.shape)

X_val_flat = X_val.reshape(-1, n_features)
X_val = scaler.transform(X_val_flat).reshape(X_val.shape)

Load model

In [None]:
# The Lightning accelerator follows the notebook-level GPU switch.
accelerator = 'gpu' if use_gpu else 'cpu'

# The feed-forward width is a fixed multiple of the model dimension.
ff_factor = 2
ff_dim = ff_factor * model_config['model_dim']
pos_encode_factor = 1

# Size of the last axis of the prepared validation inputs.
input_dim = X_val.shape[-1]

logger = DictLogger(0)
trainer = pl.Trainer(accelerator=accelerator, devices=1, max_epochs=1000,
                     gradient_clip_val=model_config['grad_clip_value'], logger=logger)

# Rebuild the architecture with the saved hyperparameters so the checkpoint weights fit.
# Values come from a CSV, hence the explicit casts.
model_architecture = OPSUMTransformer(
    input_dim=input_dim,
    num_layers=int(model_config['num_layers']),
    model_dim=int(model_config['model_dim']),
    # Dropout is a rate, so it must stay a float: int() would truncate any value below 1 to 0.
    dropout=float(model_config['dropout']),
    ff_dim=int(ff_dim),
    num_heads=int(model_config['num_head']),
    num_classes=1,
    max_dim=500,
    pos_encode_factor=pos_encode_factor
)

# Restore the trained weights into the Lightning wrapper.
trained_model = LitEncoderRegressionModel.load_from_checkpoint(
    checkpoint_path=model_path,
    model=model_architecture,
    lr=model_config['lr'],
    wd=model_config['weight_decay'],
    train_noise=model_config['train_noise'],
)

In [None]:
def _to_numpy(batch_output):
    """Convert the predictions of one batch to a NumPy array (tensors are moved to the CPU first)."""
    if isinstance(batch_output, ch.Tensor):
        return batch_output.detach().cpu().numpy()
    return np.asarray(batch_output)


if predictions_path == '':
    # No saved predictions given: run the model on progressively longer prefixes of the input,
    # keeping for each prefix the prediction made at its last time step.
    pred_over_ts = []
    for ts in tqdm(range(n_time_steps)):
        modified_time_steps = ts + 1

        # Only the first `modified_time_steps` time steps are shown to the model.
        X_val_with_first_n_ts = X_val[:, 0:modified_time_steps, :]
        # Targets are not used for prediction; the zeros only complete the (X, y) dataset layout.
        y_placeholder = ch.zeros((X_val_with_first_n_ts.shape[0], 1))
        if use_gpu:
            val_dataset = TensorDataset(ch.from_numpy(X_val_with_first_n_ts).cuda(), y_placeholder.cuda())
        else:
            val_dataset = TensorDataset(ch.from_numpy(X_val_with_first_n_ts), y_placeholder)

        val_loader = DataLoader(val_dataset, batch_size=1024)

        # trainer.predict returns one entry per batch. Gather all of them: taking only the first
        # entry would silently drop every patient beyond the first batch.
        batch_outputs = [_to_numpy(out) for out in trainer.predict(trained_model, val_loader)]
        if ts == 0:
            # With a single input time step there is one value per patient; flatten each batch.
            # (assumes the model output then has exactly one element per sample - TODO confirm)
            y_pred = np.concatenate([out.reshape(-1) for out in batch_outputs], axis=0)
        else:
            # Keep the prediction at the last available time step.
            y_pred = np.concatenate([out[:, -1] for out in batch_outputs], axis=0)

        pred_over_ts.append(np.squeeze(y_pred))

    # Stack once after the loop and transpose so that rows are patients and columns time steps.
    pred_over_ts_np = np.squeeze(pred_over_ts).T

else:
    # Reuse previously saved predictions.
    # weights_only=False allows non-tensor contents (e.g. NumPy arrays) to be unpickled on
    # PyTorch versions where weights_only=True is the default.
    # SECURITY: only load files from a trusted source.
    predictions_data = ch.load(predictions_path, weights_only=False)
    pred_over_ts_np = np.squeeze(predictions_data).T

        

In [None]:
# Shape of the prediction matrix; later cells index it as [patient, time step].
pred_over_ts_np.shape

In [None]:
# Shape of the scaled validation inputs, for comparison with the prediction matrix.
X_val.shape

In [None]:
# Construct the per-patient classification targets.
# For every validation patient, y is a series over time steps:
#   - no event recorded: zeros for all n_time_steps
#   - event too early to leave a detection window: empty series (patient is ignored)
#   - otherwise: zeros, ending with a single 1 placed eval_n_time_steps_before_event steps
#     before the event, after which the series stops
y_val_list = []
for cid in full_X_val[:, 0, 0, 0]:
    # Event time(s) recorded for this patient; empty when the patient has no event.
    cid_event_times = y_val[y_val.case_admission_id == cid].relative_sample_date_hourly_cat.values

    if len(cid_event_times) == 0:
        cid_y = np.zeros(n_time_steps)
    else:
        # Reduce to a scalar explicitly: testing or int()-casting the raw array only works when
        # it holds exactly one element. If several rows exist, the earliest event is used.
        cid_event_ts = np.min(cid_event_times)
        if cid_event_ts < (eval_n_time_steps_before_event + 1):
            # if the event occurs before a detection window, ignore the patient
            cid_y = np.array([])
        else:
            # let y be 0s until the detection window before the event, then stop the series
            cid_y = np.zeros(int(cid_event_ts) - eval_n_time_steps_before_event - 1)
            cid_y = np.append(cid_y, 1)

    y_val_list.append(cid_y)

In [None]:
# Compute ROC-AUC, AUPRC and MCC for each time step.
# Only patients whose label series still covers the time step are included.
roc_scores = []
auprc_scores = []
mcc_scores = []

for ts in range(n_time_steps):
    pts_idx = [i for i, y in enumerate(y_val_list) if len(y) > ts]
    y_true = np.array([y_val_list[i][ts] for i in pts_idx])
    y_pred = pred_over_ts_np[pts_idx, ts]

    # The metrics are undefined unless both classes are present (this also covers the case where
    # no patient is left at this time step).
    if len(np.unique(y_true)) < 2:
        roc_scores.append(np.nan)
        auprc_scores.append(np.nan)
        mcc_scores.append(np.nan)
        continue

    # Ranking metrics need a continuous score. A smaller prediction means the event is expected
    # sooner, so the negated prediction is used as the risk score (higher = more likely positive).
    risk_score = -y_pred
    roc_scores.append(roc_auc_score(y_true, risk_score))
    auprc_scores.append(average_precision_score(y_true, risk_score))

    # MCC needs hard labels: positive when the prediction falls below the detection window.
    y_pred_bin = (y_pred < eval_n_time_steps_before_event).astype(int)
    mcc_scores.append(matthews_corrcoef(y_true, y_pred_bin))


In [None]:
# Display the per-time-step ROC-AUC values (NaN where only one class was present).
roc_scores

In [None]:
# Number of rows in the prediction matrix; the plotting loop below draws one figure per row.
pred_over_ts_np.shape[0]

In [None]:
# For every patient, plot the prediction over time.
# Points are coloured by whether the prediction falls below the detection window; recorded event
# times are marked with a red vertical line.
time_steps = range(1, pred_over_ts_np.shape[1] + 1)

for i in range(pred_over_ts_np.shape[0]):
    cid = full_X_val[i, 0, 0, 0]
    patient_preds = pred_over_ts_np[i, :]

    fig, ax = plt.subplots()
    # Use the configured detection window instead of a hard-coded threshold.
    sns.scatterplot(x=time_steps, y=patient_preds,
                    hue=patient_preds < eval_n_time_steps_before_event, ax=ax)
    ax.set_ylim(0, 80)
    ax.set_title(f'Prediction over time for patient {cid}')
    ax.set_xlabel('Time step')
    ax.set_ylabel('Prediction')

    # axvline expects a scalar x, so draw one line per recorded event time.
    event_times = y_val[y_val.case_admission_id == cid].relative_sample_date_hourly_cat.values
    for event_ts in event_times:
        ax.axvline(x=float(event_ts), color='red')

    plt.show()
    # Release the figure so that memory does not grow with the number of patients.
    plt.close(fig)

In [None]:
# Scratch check: natural log of (0 + 1); the result is only displayed, not assigned to a name.
np.log(0+1)