In [None]:
import os

import pandas as pd
import torch as ch
import numpy as np
from numba.np.arrayobj import np_concatenate
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
from sklearn.metrics import matthews_corrcoef, average_precision_score, roc_auc_score, accuracy_score
from sklearn.preprocessing import StandardScaler
import pytorch_lightning as pl
from prediction.outcome_prediction.Transformer.utils.utils import DictLogger
from prediction.outcome_prediction.Transformer.architecture import OPSUMTransformer, OPSUM_encoder_decoder
from prediction.outcome_prediction.Transformer.lightning_wrapper import LitEncoderRegressionModel, \
    LitEncoderDecoderModel

In [None]:
# Input locations. Each one can be overridden through an environment variable so the
# notebook can run on another machine without editing this cell; when the variable is
# not set, the original local path is used unchanged.
# Alternative split file (inside the full preprocessing output directory):
# data_path = '/Users/jk1/temp/opsum_end/preprocessing/gsu_Extraction_20220815_prepro_08062024_083500/early_neurological_deterioration_train_data_splits/train_data_splits_early_neurological_deterioration_ts0.8_rs42_ns5.pth'
data_path = os.environ.get(
    'OPSUM_DATA_PATH',
    '/Users/jk1/Downloads/train_data_splits_early_neurological_deterioration_ts0.8_rs42_ns5.pth')
# Checkpoint of the trained encoder-decoder model.
model_path = os.environ.get(
    'OPSUM_MODEL_PATH',
    '/Users/jk1/Downloads/checkpoints_short_opsum_transformer_20240822_051920_cv_1/short_opsum_dec_transformer_epoch=02_val_cos_sim=0.9611.ckpt')
# CSV holding the hyperparameters that belong to the checkpoint above.
model_hyperparams_path = os.environ.get(
    'OPSUM_MODEL_HYPERPARAMS_PATH',
    '/Users/jk1/Downloads/checkpoints_short_opsum_transformer_20240822_051920_cv_1/best_enc_dec_df.csv')

In [None]:
# Run-time settings for this notebook.
# use_gpu: selects the 'gpu' accelerator and moves tensors with .cuda() when True; CPU otherwise.
use_gpu = False
# Not referenced in the visible cells — presumably the full length of each series; TODO confirm.
n_time_steps = 72
# Number of future time steps that are predicted recursively and drawn in the plots.
eval_n_time_steps_before_event = 20
# Not referenced in the visible cells — TODO confirm intended use.
target_timeseries_length = 1

In [None]:
# Read the hyperparameter table and keep its first row as a plain dict.
hyperparam_records = pd.read_csv(model_hyperparams_path).to_dict(orient='records')
model_config = hyperparam_records[0]

In [None]:
# Display the loaded hyperparameter record for inspection.
model_config

In [None]:
# NOTE: ch.load unpickles the file content — only load split files from a trusted source.
splits = ch.load(data_path)
# Keep the train/validation split of the cross-validation fold named in the model config.
best_fold_split = splits[model_config['best_cv_fold']]
full_X_train, full_X_val, y_train, y_val = best_fold_split

In [None]:
# Inspect channel -2 of the last axis for the first sample and first time step;
# the following cell searches this slice for the 'max_NIHSS' label.
full_X_train[0, 0, :, -2]

In [None]:
# Locate the position of the 'max_NIHSS' entry among the labels of the first sample / time step.
feature_labels = full_X_train[0, 0, :, -2]
max_NIHSS_idx = np.flatnonzero(feature_labels == 'max_NIHSS')[0]
max_NIHSS_idx

In [None]:
# Build the model inputs: take the value channel (last index of the last axis) as float32.
raw_train_values = full_X_train[:, :, :, -1].astype('float32')
raw_val_values = full_X_val[:, :, :, -1].astype('float32')

# Standardise per feature: flatten every leading axis so each row is one feature vector,
# fit on the training split only, apply the same statistics to the validation split,
# then restore the original array shapes.
n_features = raw_train_values.shape[-1]
scaler = StandardScaler()
X_train = scaler.fit_transform(raw_train_values.reshape(-1, n_features)).reshape(raw_train_values.shape)
X_val = scaler.transform(raw_val_values.reshape(-1, n_features)).reshape(raw_val_values.shape)

In [None]:
# Build everything needed for inference: the Lightning trainer, the encoder-decoder
# architecture described by model_config, and the trained weights from the checkpoint.
accelerator = 'gpu' if use_gpu else 'cpu'

# Feed-forward width is a fixed multiple of the model dimension.
ff_factor = 2
ff_dim = ff_factor * model_config['model_dim']
pos_encode_factor = 1

# Number of features per time step, taken from the scaled validation data.
input_dim = X_val.shape[-1]

logger = DictLogger(0)
trainer = pl.Trainer(accelerator=accelerator, devices=1, max_epochs=1000,
                     gradient_clip_val=model_config['grad_clip_value'], logger=logger)

model_architecture = OPSUM_encoder_decoder(
            input_dim=input_dim,
            num_layers=int(model_config['num_layers']),
            num_decoder_layers=int(model_config['num_decoder_layers']),
            model_dim=int(model_config['model_dim']),
            # dropout is a probability: int() would truncate any fractional value to 0,
            # so it is passed on as a float.
            dropout=float(model_config['dropout']),
            ff_dim=int(ff_dim),
            num_heads=int(model_config['num_head']),
            pos_encode_factor=pos_encode_factor,
            n_tokens=1,
            max_dim=5000,
            layer_norm_eps=1e-05)

# Restore the trained weights into the freshly built architecture.
trained_model = LitEncoderDecoderModel.load_from_checkpoint(checkpoint_path=model_path, model=model_architecture,
                                              lr=model_config['lr'],
                                              wd=model_config['weight_decay'],
                                              train_noise=model_config['train_noise'],
                                            lr_warmup_steps=model_config['n_lr_warm_up_steps'])

In [None]:
# Index of the last observed time step that is handed to the model.
ts = 0
# Number of leading time steps kept as model input (used as the slice end `0:modified_time_steps`).
modified_time_steps = ts+1

In [None]:
# Single forward pass using only the first `modified_time_steps` observed time steps.
X_val_with_first_n_ts = X_val[:, 0:modified_time_steps, :]
# Dummy targets: TensorDataset needs a second tensor, its values are not meaningful here.
y_placeholder = ch.zeros((X_val_with_first_n_ts.shape[0], 1))
if use_gpu:
    val_dataset = TensorDataset(ch.from_numpy(X_val_with_first_n_ts).cuda(), y_placeholder.cuda())
else:
    val_dataset = TensorDataset(ch.from_numpy(X_val_with_first_n_ts), y_placeholder)

val_loader = DataLoader(val_dataset, batch_size=1024)
# trainer.predict returns one output per batch; gather all of them so that samples beyond
# the first batch are not silently dropped (axis 0 is taken to be the sample axis).
batch_outputs = trainer.predict(trained_model, val_loader)
if ts == 0:
    y_pred = np.concatenate([np.array(batch_out) for batch_out in batch_outputs], axis=0)
else:
    # with more than one input step, keep only the output at the last position
    y_pred = np.concatenate([np.array(batch_out[:, -1]) for batch_out in batch_outputs], axis=0)

In [None]:
def predict_n_next_steps(input_data, n_steps, model, trainer):
    """
    Recursively predict the next n_steps using the model and trainer.

    At every step the model is run on the current sequence, the output at the last
    position is kept as the prediction for that step, and the prediction is appended
    along the time axis (axis 1) to form the input of the following step.

    :param input_data: numpy array with the observed time steps to start from; axis 0 is
        treated as the sample axis and axis 1 as the time axis. It is not modified.
    :param n_steps: The number of steps to predict.
    :param model: The model to use for prediction.
    :param trainer: The trainer to use for prediction; ``trainer.predict(model, loader)``
        is expected to return one output per batch.
    :return: numpy array with the n_steps predictions stacked along axis 1. For
        n_steps <= 0 an empty array with a zero-length axis 1 is returned.
    """
    if n_steps <= 0:
        # nothing to predict: return an empty result instead of failing on an empty stack
        return np.empty((input_data.shape[0], 0, input_data.shape[-1]), dtype=input_data.dtype)

    # Dummy targets required by TensorDataset. They are built here from input_data rather
    # than read from a notebook-level variable, so their length always matches the input.
    placeholder_targets = ch.zeros((input_data.shape[0], 1))
    if use_gpu:
        placeholder_targets = placeholder_targets.cuda()

    predictions = []
    input_np = input_data
    for _ in tqdm(range(n_steps)):
        # the first prediction relies only on past data; later ones also see earlier predictions
        if predictions:
            input_np = np.concatenate([input_np, np.expand_dims(predictions[-1], axis=1)], axis=1)

        input_tensor = ch.from_numpy(input_np)
        if use_gpu:
            input_tensor = input_tensor.cuda()
        input_loader = DataLoader(TensorDataset(input_tensor, placeholder_targets), batch_size=1024)

        # Keep the last position of every batch and join the batches along the sample axis;
        # using only the first batch would drop samples beyond the batch size.
        batch_outputs = trainer.predict(model, input_loader)
        y_pred = np.concatenate([np.array(batch_out[:, -1]) for batch_out in batch_outputs], axis=0)

        predictions.append(y_pred)

    # one slice per predicted step along axis 1
    return np.stack(predictions, axis=1)

In [None]:
# Recursive multi-step forecast: seed the model with the first `modified_time_steps`
# observations and roll forward for `eval_n_time_steps_before_event` steps.
X_val_with_first_n_ts = X_val[:, :modified_time_steps, :]
y_placeholder = ch.zeros((len(X_val_with_first_n_ts), 1))

predictions_np = predict_n_next_steps(
    input_data=X_val_with_first_n_ts,
    n_steps=eval_n_time_steps_before_event,
    model=trained_model,
    trainer=trainer,
)


In [None]:
# Check the shape of the recursive predictions.
predictions_np.shape

In [None]:
# Feature column to plot: the index located for 'max_NIHSS' in the training data.
var_idx = max_NIHSS_idx

In [None]:
# One figure per validation subject: observed trajectory of the selected variable
# together with the recursively predicted values.
import matplotlib.pyplot as plt
import seaborn as sns

# predictions start at time step modified_time_steps
forecast_steps = np.arange(modified_time_steps, modified_time_steps + eval_n_time_steps_before_event)

for subject_idx, observed_series in enumerate(X_val[:, :, var_idx]):
    fig, ax = plt.subplots()

    ax.plot(observed_series)
    ax.plot(forecast_steps, predictions_np[subject_idx, :, var_idx], 'ro-')

    ax.set_title(f'Subject {subject_idx}, Variable {var_idx}')
    ax.set_xlabel('Time step')
    ax.set_ylabel('Value')
    # fixed range on the standardised scale so figures are comparable across subjects
    ax.set_ylim(-2.5, 2.5)

    ax.legend(['Actual', 'Predicted'])
    plt.show()
