In [None]:
import os

import pandas as pd
import torch as ch
import numpy as np
from numba.np.arrayobj import np_concatenate
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
from sklearn.metrics import matthews_corrcoef, average_precision_score, roc_auc_score, accuracy_score
from sklearn.preprocessing import StandardScaler
import pytorch_lightning as pl
from prediction.outcome_prediction.Transformer.utils.utils import DictLogger
from prediction.outcome_prediction.Transformer.architecture import OPSUMTransformer, OPSUM_encoder_decoder
from prediction.outcome_prediction.Transformer.lightning_wrapper import LitEncoderRegressionModel, \
    LitEncoderDecoderModel

In [None]:
# Input locations. The two roots can be overridden through environment variables instead of
# editing machine-specific absolute paths; the defaults reproduce the original locations.
downloads_dir = os.environ.get('OPSUM_DOWNLOADS_DIR', '/Users/jk1/Downloads')
preprocessing_dir = os.environ.get(
    'OPSUM_PREPROCESSING_DIR',
    '/Users/jk1/temp/opsum_end/preprocessing/gsu_Extraction_20220815_prepro_08062024_083500')
checkpoint_dir = os.path.join(downloads_dir, 'checkpoints_short_opsum_transformer_20240822_051920_cv_1')

# Alternative location of the same splits inside the preprocessing output:
# os.path.join(preprocessing_dir, 'early_neurological_deterioration_train_data_splits',
#              'train_data_splits_early_neurological_deterioration_ts0.8_rs42_ns5.pth')
data_path = os.path.join(downloads_dir, 'train_data_splits_early_neurological_deterioration_ts0.8_rs42_ns5.pth')
model_path = os.path.join(checkpoint_dir, 'short_opsum_dec_transformer_epoch=02_val_cos_sim=0.9611.ckpt')
model_hyperparams_path = os.path.join(checkpoint_dir, 'best_enc_dec_df.csv')
normalisation_data_path = os.path.join(preprocessing_dir, 'logs_08062024_083500', 'normalisation_parameters.csv')
outcome_data_path = os.path.join(preprocessing_dir, 'preprocessed_outcomes_short_term_08062024_083500.csv')

In [None]:
# Run-time switches:
#   use_gpu                        -> Lightning accelerator and CUDA placement of the input tensors
#   n_time_steps                   -> time steps per patient (TODO confirm against X_val.shape[1])
#   eval_n_time_steps_before_event -> number of steps forecast recursively; also the offset between
#                                     the last observed step and the step whose outcome is evaluated
use_gpu, n_time_steps, eval_n_time_steps_before_event = False, 72, 6

In [None]:
# Hyperparameters of the selected run: the CSV's first row, as a plain {column: value} dict.
model_config = pd.read_csv(model_hyperparams_path).to_dict(orient='records')[0]

In [None]:
# Show the selected run's hyperparameters
model_config

In [None]:
# Load the cross-validation splits and keep the fold the model was selected on; each fold
# unpacks into (X_train, X_val, y_train, y_val).
# The file presumably holds pickled numpy (object) arrays rather than plain tensors, which
# torch >= 2.6 refuses to load under its weights_only=True default. Only disable that for
# this trusted local file: unpickling untrusted data can execute arbitrary code.
splits = ch.load(data_path, weights_only=False)
# Values read back from CSV may be floats; sequence indices must be int.
full_X_train, full_X_val, y_train, y_val = splits[int(model_config['best_cv_fold'])]


In [None]:
# Lookup tables: per-variable normalisation parameters (to map values back to the original
# scale) and the outcome events (to build ground-truth labels).
normalisation_parameters_df, outcome_df = (
    pd.read_csv(csv_path) for csv_path in (normalisation_data_path, outcome_data_path)
)

In [None]:
# Inspect the feature names of the first sample / first time step; slot -2 of the last axis
# appears to hold each feature's variable name (the next cell matches names against it)
full_X_train[0, 0, :, -2]

In [None]:
# Slot -2 of the last axis holds the feature names; the first sample's first time step is
# enough to locate each named feature on the feature axis.
_feature_names = full_X_train[0, 0, :, -2]


def _feature_index(name):
    """Position of the first feature called `name` (IndexError if it is absent)."""
    return np.flatnonzero(_feature_names == name)[0]


max_NIHSS_idx = _feature_index('max_NIHSS')
min_NIHSS_idx = _feature_index('min_NIHSS')
heart_rate_idx = _feature_index('max_heart_rate')
systolic_blood_pressure_idx = _feature_index('median_systolic_blood_pressure')

max_NIHSS_idx, min_NIHSS_idx

In [None]:
# Model inputs: the last slot of the last axis holds the feature values (the other slots carry
# identifiers and feature names), cast to float32 for torch.
X_train = full_X_train[:, :, :, -1].astype('float32')
X_val = full_X_val[:, :, :, -1].astype('float32')

# Standardise each feature with statistics fitted on the training set only, pooling samples and
# time steps: flatten (samples, time, features) -> (samples * time, features) and back.
n_input_features = X_train.shape[-1]
if X_val.shape[-1] != n_input_features:
    # Reshaping X_val with the training width would silently regroup values across features.
    raise ValueError(f'Validation data has {X_val.shape[-1]} features, training data has {n_input_features}')

scaler = StandardScaler()
X_train = scaler.fit_transform(X_train.reshape(-1, n_input_features)).reshape(X_train.shape)
X_val = scaler.transform(X_val.reshape(-1, n_input_features)).reshape(X_val.shape)

In [None]:
accelerator = 'gpu' if use_gpu else 'cpu'

# Feed-forward width is a fixed multiple of the model width.
ff_factor = 2
ff_dim = ff_factor * model_config['model_dim']
pos_encode_factor = 1

input_dim = X_val.shape[-1]

# In this notebook the Trainer is only used for trainer.predict, so max_epochs and
# gradient clipping have no effect here.
logger = DictLogger(0)
trainer = pl.Trainer(accelerator=accelerator, devices=1, max_epochs=1000,
                     gradient_clip_val=model_config['grad_clip_value'], logger=logger)

# Rebuild the architecture from the selected hyperparameters; the checkpoint supplies the weights.
# Values read from CSV may be floats, hence the explicit casts.
model_architecture = OPSUM_encoder_decoder(
            input_dim=input_dim,
            num_layers=int(model_config['num_layers']),
            num_decoder_layers=int(model_config['num_decoder_layers']),
            model_dim=int(model_config['model_dim']),
            # dropout is a probability: int() would truncate e.g. 0.2 to 0
            dropout=float(model_config['dropout']),
            ff_dim=int(ff_dim),
            num_heads=int(model_config['num_head']),
            pos_encode_factor=pos_encode_factor,
            n_tokens=1,
            max_dim=5000,
            layer_norm_eps=1e-05)

trained_model = LitEncoderDecoderModel.load_from_checkpoint(checkpoint_path=model_path, model=model_architecture,
                                                            lr=model_config['lr'],
                                                            wd=model_config['weight_decay'],
                                                            train_noise=model_config['train_noise'],
                                                            loss_function='mse',
                                                            lr_warmup_steps=model_config['n_lr_warm_up_steps'])

In [None]:
# Index of the last observed time step for the single-forecast analysis below.
ts = 20
# Number of observed steps fed to the model: indices 0 .. ts inclusive.
modified_time_steps = ts + 1

In [None]:
# X_val_with_first_n_ts = X_val[:, 0:modified_time_steps, :]
# y_placeholder = ch.zeros((X_val_with_first_n_ts.shape[0], 1))
# if use_gpu:
#     val_dataset = TensorDataset(ch.from_numpy(X_val_with_first_n_ts).cuda(), y_placeholder.cuda())
# else:
#     val_dataset = TensorDataset(ch.from_numpy(X_val_with_first_n_ts), y_placeholder)

# val_loader = DataLoader(val_dataset, batch_size=1024)
# if ts == 0:
#     y_pred = np.array(trainer.predict(trained_model, val_loader)[0])
# else:
#     y_pred = np.array(trainer.predict(trained_model, val_loader)[0][:, -1])

In [None]:
def predict_n_next_steps(input_data, n_steps, model, trainer, use_gpu=None, batch_size=1024):
    """
    Recursively forecast the next `n_steps` time steps.

    At every step the model is run on the current sequence. The output at its last time step is
    taken as the forecast for the next step and appended to the sequence before the following
    step (autoregressive roll-out).

    :param input_data: Observed prefix, array of shape (n_samples, n_observed_steps, n_features).
    :param n_steps: Number of future steps to forecast (must be >= 1).
    :param model: The model passed to ``trainer.predict``.
    :param trainer: Object exposing ``predict(model, dataloader)`` that returns one output per
        batch; each output is assumed indexable as (batch, time, features) — TODO confirm.
    :param use_gpu: Place input tensors on CUDA. ``None`` falls back to the notebook-level
        ``use_gpu`` flag (False if that is not defined).
    :param batch_size: DataLoader batch size.
    :return: Array of shape (n_samples, n_steps, n_features) with the forecasts in order.
    :raises ValueError: If ``n_steps < 1``.
    """
    if n_steps < 1:
        raise ValueError(f'n_steps must be >= 1, got {n_steps}')
    if use_gpu is None:
        use_gpu = globals().get('use_gpu', False)

    def _last_step_numpy(batch_output):
        # Keep only the output at the last time step, i.e. the forecast for the next step.
        last_step = batch_output[:, -1]
        if isinstance(last_step, ch.Tensor):
            last_step = last_step.detach().cpu().numpy()
        return np.asarray(last_step)

    predictions = []
    sequence = input_data
    for step in tqdm(range(n_steps)):
        if step > 0:
            # Autoregressive roll-out: the previous forecast becomes part of the input.
            sequence = np.concatenate([sequence, predictions[-1][:, np.newaxis]], axis=1)

        # Dummy targets: TensorDataset needs a second tensor with one row per sample.
        inputs = ch.from_numpy(sequence)
        targets = ch.zeros((sequence.shape[0], 1))
        if use_gpu:
            inputs, targets = inputs.cuda(), targets.cuda()
        loader = DataLoader(TensorDataset(inputs, targets), batch_size=batch_size)

        # trainer.predict returns one output per batch: gather ALL batches, not only the first,
        # otherwise more than `batch_size` samples get truncated.
        batch_outputs = trainer.predict(model, loader)
        predictions.append(np.concatenate([_last_step_numpy(out) for out in batch_outputs], axis=0))

    return np.stack(predictions, axis=1)

In [None]:
# Single recursive forecast of eval_n_time_steps_before_event steps from the observed prefix
# (the first modified_time_steps time steps).
X_val_with_first_n_ts = X_val[:, :modified_time_steps, :]
# Dummy targets, kept at module level because predict_n_next_steps may read this name.
y_placeholder = ch.zeros((len(X_val_with_first_n_ts), 1))

predictions_np = predict_n_next_steps(
    input_data=X_val_with_first_n_ts,
    n_steps=eval_n_time_steps_before_event,
    model=trained_model,
    trainer=trainer,
)


In [None]:
# Presumably (n_samples, eval_n_time_steps_before_event, n_features) — TODO confirm
predictions_np.shape

In [None]:
# Feature plotted in the next cell (index on the last axis of X_val / predictions_np)
var_idx = max_NIHSS_idx

In [None]:
# For each validation subject, overlay the observed (StandardScaler-scaled) trajectory of feature
# `var_idx` with the recursive forecast that starts right after the observed prefix.
import matplotlib.pyplot as plt
import seaborn as sns

# Optional cap on the number of figures; None plots every validation subject.
max_subjects_to_plot = None

n_subjects_to_plot = X_val.shape[0] if max_subjects_to_plot is None else min(max_subjects_to_plot, X_val.shape[0])
# Forecasts cover indices modified_time_steps .. modified_time_steps + horizon - 1, directly
# after the observed prefix 0 .. modified_time_steps - 1.
forecast_steps = np.arange(modified_time_steps, modified_time_steps + eval_n_time_steps_before_event)

for i in range(n_subjects_to_plot):
    fig, ax = plt.subplots()
    ax.plot(X_val[i, :, var_idx])
    ax.plot(forecast_steps, predictions_np[i, :, var_idx], 'ro-')
    ax.set_title(f'Subject {i}, Variable {var_idx}')
    ax.set_xlabel('Time step')
    ax.set_ylabel('Value')
    ax.set_ylim(-2.5, 2.5)

    # light grey dashed vertical line at the first forecast time step
    ax.axvline(x=modified_time_steps, color='lightgrey', linestyle='--')

    ax.legend(['Actual', 'Predicted'])
    plt.show()
    # release the figure so looping over many subjects does not accumulate open figures
    plt.close(fig)


Compute the predicted ΔNIHSS: the forecast maximum NIHSS at the last predicted time step minus the minimum NIHSS observed up to the current time step.

In [None]:
def reverse_normalisation(data, variable_name, normalisation_parameters_df):
    """
    Map normalised values of one variable back to its original scale.

    :param data: Normalised value(s) of `variable_name` (scalar or array).
    :param variable_name: Value of the ``variable`` column whose parameters are used.
    :param normalisation_parameters_df: Table with ``variable``, ``original_mean`` and
        ``original_std`` columns.
    :return: ``data * original_std + original_mean``, using the first matching row.
    :raises IndexError: If `variable_name` does not appear in the table.
    """
    matching_rows = normalisation_parameters_df[normalisation_parameters_df.variable == variable_name]
    std, mean = matching_rows.original_std.iloc[0], matching_rows.original_mean.iloc[0]
    return data * std + mean

In [None]:
# Observed side: lowest min_NIHSS over time steps 0..ts, read from the unscaled array
# (still in the preprocessing-normalised space).
norm_min_NIHSS_up_to_current_timestep = full_X_val[:, :ts + 1, min_NIHSS_idx, -1].min(axis=1)

# Forecast side: undo the StandardScaler so the forecasts live in the same normalised space,
# then take max_NIHSS at the last forecast step.
reverse_scaled_predictions = scaler.inverse_transform(
    predictions_np.reshape(-1, X_train.shape[-1])
).reshape(predictions_np.shape)
norm_max_NIHSS_at_last_prediction_timestep = reverse_scaled_predictions[:, -1, max_NIHSS_idx]
# NOTE(review): max_NIHSS and min_NIHSS are de-normalised with separate (possibly different)
# parameters below, so this normalised difference is not on the original scale.
norm_delta_NIHSS_at_last_predicted_ts = norm_max_NIHSS_at_last_prediction_timestep - norm_min_NIHSS_up_to_current_timestep

# Back to the original scale, then the predicted change.
max_NIHSS_at_last_prediction_timestep, min_NIHSS_up_to_current_timestep = (
    reverse_normalisation(norm_max_NIHSS_at_last_prediction_timestep, 'max_NIHSS', normalisation_parameters_df),
    reverse_normalisation(norm_min_NIHSS_up_to_current_timestep, 'min_NIHSS', normalisation_parameters_df),
)
delta_NIHSS_at_last_predicted_ts = max_NIHSS_at_last_prediction_timestep - min_NIHSS_up_to_current_timestep

Compute classification metrics for the forecast started at a single time step (`ts`).

In [None]:
# case_admission_id of each validation patient, presumably stored in slot 0 of the last axis;
# matched against outcome_df['case_admission_id'] to build ground-truth labels
val_patient_cids = full_X_val[:, 0, 0, 0]

In [None]:
# Evaluate the forecast started at `ts` against the outcome at its last forecast step.
evaluated_ts = ts + eval_n_time_steps_before_event
# predicted ΔNIHSS (original scale) at or above which deterioration is flagged
delta_NIHSS_threshold = 4

# Ground truth: 1 if the patient's case_admission_id appears in the outcome table at evaluated_ts.
outcome_at_evaluated_ts_df = outcome_df[outcome_df['relative_sample_date_hourly_cat'] == evaluated_ts]
y_true_at_evaluated_ts = np.isin(val_patient_cids, outcome_at_evaluated_ts_df['case_admission_id'].values).astype(np.int32)

# Score with the de-normalised ΔNIHSS, the same score used in the per-timestep loop. The
# normalised difference subtracts two variables normalised with their own parameters and is
# therefore not a consistent ranking score.
y_pred = np.asarray(delta_NIHSS_at_last_predicted_ts, dtype=float)
y_pred_binary = y_pred >= delta_NIHSS_threshold

if len(np.unique(y_true_at_evaluated_ts)) < 2:
    # ROC AUC / AUPRC are undefined (and MCC uninformative) with a single class present
    print(f'Only one class in ground truth at time step {evaluated_ts}: ROC AUC / AUPRC / MCC not computed')
else:
    print(f'ROC AUC: {roc_auc_score(y_true_at_evaluated_ts, y_pred):.4f}')
    print(f'AUPRC: {average_precision_score(y_true_at_evaluated_ts, y_pred):.4f}')
    print(f'MCC: {matthews_corrcoef(y_true_at_evaluated_ts, y_pred_binary):.4f}')
print(f'Accuracy: {accuracy_score(y_true_at_evaluated_ts, y_pred_binary):.4f}')

# number FP / number TP / number TN / number FN
from sklearn.metrics import confusion_matrix
# labels=[0, 1] keeps the 2x2 layout even when a class is absent from both arrays
tn, fp, fn, tp = confusion_matrix(y_true_at_evaluated_ts, y_pred_binary, labels=[0, 1]).ravel()
print(f'FP: {fp}, TP: {tp}, TN: {tn}, FN: {fn}')


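Loop over start time steps: repeat the recursive forecast and the evaluation for each time step (the loops below currently cover only the first 10; the full range is `n_time_steps`).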

In [None]:
# Recursive forecasts for every start time step ts in [0, n_start_time_steps).
# Set n_start_time_steps = n_time_steps for the full evaluation; 10 keeps the run short.
n_start_time_steps = 10
pred_over_ts = []
for ts in tqdm(range(n_start_time_steps)):
    modified_time_steps = ts + 1
    # predict recursively for eval_n_time_steps_before_event time steps from steps 0..ts
    X_val_with_first_n_ts = X_val[:, 0:modified_time_steps, :]
    # module-level dummy targets; predict_n_next_steps may read this name
    y_placeholder = ch.zeros((X_val_with_first_n_ts.shape[0], 1))

    predictions_np = predict_n_next_steps(X_val_with_first_n_ts, eval_n_time_steps_before_event, trained_model, trainer)
    pred_over_ts.append(predictions_np)

# Presumably (n_start_time_steps, n_samples, eval_n_time_steps_before_event, n_features).
# np.stack keeps every axis, whereas np.squeeze would also drop any axis that happens to have
# size 1 (a single sample, a one-step horizon) and break the [ts][:, -1, idx] indexing later.
pred_over_ts_np = np.stack(pred_over_ts, axis=0)
    


In [None]:
# Raw validation array (extra last axis with id / name / value slots) vs. the model-input array
full_X_val.shape, X_val.shape

In [None]:
# Per-start-time-step evaluation of the recursive forecasts in pred_over_ts_np, plus a long table
# with one row per (start time step, patient) for the pooled metrics computed afterwards.
delta_NIHSS_threshold = 4  # predicted ΔNIHSS (original scale) at or above which deterioration is flagged

roc_scores = []
auprc_scores = []
mcc_scores = []
accuracy_scores = []
# count number of positive samples for each time step
n_pos_samples = []
timesteps = []
# per-time-step frames, concatenated once after the loop (concat inside the loop is quadratic)
timestep_dfs = []

# Evaluate exactly the start time steps that were forecast (first axis of pred_over_ts_np).
for ts in range(len(pred_over_ts_np)):
    evaluated_ts = ts + eval_n_time_steps_before_event

    # GT at the last forecast step: 1 if the patient appears in the outcome table there
    outcome_at_evaluated_ts_df = outcome_df[outcome_df['relative_sample_date_hourly_cat'] == evaluated_ts]
    y_true_at_evaluated_ts = np.isin(val_patient_cids, outcome_at_evaluated_ts_df['case_admission_id'].values).astype(np.int32)

    # forecasts started after observing time steps 0..ts
    predictions_at_ts_np = pred_over_ts_np[ts]

    # observed lowest min_NIHSS so far vs. forecast max_NIHSS at the last forecast step, both in
    # the preprocessing-normalised space (StandardScaler undone for the forecasts)
    norm_min_NIHSS_up_to_current_timestep = np.min(full_X_val[:, 0:ts+1, min_NIHSS_idx, -1], axis=1)
    reverse_scaled_predictions = scaler.inverse_transform(predictions_at_ts_np.reshape(-1, X_train.shape[-1])).reshape(predictions_at_ts_np.shape)
    norm_max_NIHSS_at_last_prediction_timestep = reverse_scaled_predictions[:, -1, max_NIHSS_idx]

    # ΔNIHSS on the original scale; each variable is de-normalised with its own parameters
    max_NIHSS_at_last_prediction_timestep = reverse_normalisation(norm_max_NIHSS_at_last_prediction_timestep, 'max_NIHSS', normalisation_parameters_df)
    min_NIHSS_up_to_current_timestep = reverse_normalisation(norm_min_NIHSS_up_to_current_timestep, 'min_NIHSS', normalisation_parameters_df)
    delta_NIHSS_at_last_predicted_ts = max_NIHSS_at_last_prediction_timestep - min_NIHSS_up_to_current_timestep

    y_pred = np.asarray(delta_NIHSS_at_last_predicted_ts, dtype=float)
    y_pred_binary = y_pred >= delta_NIHSS_threshold

    timestep_dfs.append(pd.DataFrame({'timestep': ts,
                                      'prediction': y_pred,
                                      'true_label': y_true_at_evaluated_ts}))

    timesteps.append(ts)
    n_pos_samples.append(np.sum(y_true_at_evaluated_ts))
    accuracy_scores.append(accuracy_score(y_true_at_evaluated_ts, y_pred_binary))

    if len(np.unique(y_true_at_evaluated_ts)) < 2:
        # ranking metrics are undefined with a single class; MCC treated as undefined too
        roc_scores.append(np.nan)
        auprc_scores.append(np.nan)
        mcc_scores.append(np.nan)
    else:
        roc_scores.append(roc_auc_score(y_true_at_evaluated_ts, y_pred))
        auprc_scores.append(average_precision_score(y_true_at_evaluated_ts, y_pred))
        mcc_scores.append(matthews_corrcoef(y_true_at_evaluated_ts, y_pred_binary))

overall_prediction_df = (pd.concat(timestep_dfs, ignore_index=True) if timestep_dfs
                         else pd.DataFrame(columns=['timestep', 'prediction', 'true_label']))
    

In [None]:
# Long table: one row per (start time step, validation patient) with ΔNIHSS score and label
overall_prediction_df

In [None]:
# Pooled metrics over all evaluated start time steps.
delta_NIHSS_threshold = 4  # predicted ΔNIHSS (original scale) at or above which deterioration is flagged

# Labels are 0/1 ints. Predictions are ΔNIHSS values on the original scale, not probabilities;
# ROC AUC / AUPRC only need a score that ranks patients, so no rescaling to [0, 1] is needed.
overall_prediction_df['true_label'] = overall_prediction_df['true_label'].astype(int)
overall_prediction_df['prediction'] = overall_prediction_df['prediction'].astype(float)

overall_y_true = overall_prediction_df.true_label
overall_y_score = overall_prediction_df.prediction
overall_y_pred_binary = overall_y_score >= delta_NIHSS_threshold
# ranking metrics / MCC are undefined when only one class is present
both_classes_present = overall_y_true.nunique() == 2

overall_results_df = pd.DataFrame({
    'overall_roc': roc_auc_score(overall_y_true, overall_y_score) if both_classes_present else np.nan,
    'overall_auprc': average_precision_score(overall_y_true, overall_y_score) if both_classes_present else np.nan,
    'overall_mcc': matthews_corrcoef(overall_y_true, overall_y_pred_binary) if both_classes_present else np.nan,
    'overall_accuracy': accuracy_score(overall_y_true, overall_y_pred_binary),
    'n_pos_samples': np.sum(overall_y_true),
    'n_samples': len(overall_prediction_df),
    'cv_fold': model_config['best_cv_fold'],
}, index=[0])

In [None]:
# Pooled metrics across all evaluated start time steps
overall_results_df

In [None]:
# Median of each per-time-step metric; NaN entries (single-class time steps) are ignored.
_median_inputs = {
    'median_roc': roc_scores,
    'median_auprc': auprc_scores,
    'median_mcc': mcc_scores,
    'median_accuracy': accuracy_scores,
    'n_pos_samples': n_pos_samples,
}
median_results_df = pd.DataFrame(
    {column: np.nanmedian(values) for column, values in _median_inputs.items()},
    index=[0],
)

median_results_df