In [None]:
import argparse
import json

import shap
import os

from prediction.outcome_prediction.LSTM.testing.shap_helper_functions import check_shap_version_compatibility
from prediction.utils.scoring import precision, recall, matthews
import numpy as np
import pickle
from tqdm import tqdm
from prediction.outcome_prediction.data_loading.data_formatting import format_to_2d_table_with_time
import torch as ch
import pytorch_lightning as pl
import torchmetrics
from torchmetrics import AUROC

In [None]:
# Shap values require very specific package versions.
# check_shap_version_compatibility() is a project helper that checks the
# installed environment before any SHAP computation runs.
# NOTE(review): confirm whether it raises or only warns on a mismatch.
check_shap_version_compatibility()

In [None]:
# Preprocessed OPSUM features and outcomes (CSV output of the preprocessing run).
# Set OPSUM_PREPRO_DIR to the directory holding a local copy of that output;
# the fallback below is a machine-specific absolute path.
prepro_dir = os.environ.get('OPSUM_PREPRO_DIR',
                            '/Users/jk1/temp/opsum_prepro_output/gsu_prepro_01012023_233050')
features_path = os.path.join(prepro_dir, 'preprocessed_features_01012023_233050.csv')
labels_path = os.path.join(prepro_dir, 'preprocessed_outcomes_01012023_233050.csv')

In [None]:
# Trained model artifacts: the hyper-parameter JSON and the Lightning checkpoint.
# Set OPSUM_MODEL_DIR to the directory holding both files; the fallback below
# is a machine-specific absolute path.
model_dir = os.environ.get('OPSUM_MODEL_DIR', '/Users/jk1/Downloads')
model_config_path = os.path.join(model_dir, 'params_opsum_transformer_20230306_221654.json')
model_path = os.path.join(model_dir, 'opsum_transformer_20230306_221654_epoch=15_val_auroc=0.9012.ckpt')

In [None]:
# load model config from json; the context manager closes the file handle
# (json.load(open(...)) leaves it open until garbage collection)
with open(model_config_path, 'r') as model_config_file:
    model_config = json.load(model_config_file)
model_config

In [None]:
# load the dataset
from sklearn.model_selection import train_test_split
from prediction.outcome_prediction.data_loading.data_formatting import features_to_numpy, \
    link_patient_id_to_outcome, numpy_to_lookup_table

X, y = format_to_2d_table_with_time(feature_df_path=features_path, outcome_df_path=labels_path,
                                    outcome=model_config['outcome'])

n_time_steps = X.relative_sample_date_hourly_cat.max() + 1
n_channels = X.sample_label.unique().shape[0]

# Reduce every patient to a single outcome (to avoid duplicates)
all_pids_with_outcome = link_patient_id_to_outcome(y, model_config['outcome'])
pid_train, pid_test, y_pid_train, y_pid_test = train_test_split(all_pids_with_outcome.patient_id.tolist(),
                                                                all_pids_with_outcome.outcome.tolist(),
                                                                stratify=all_pids_with_outcome.outcome.tolist(),
                                                                test_size=model_config['test_size'],
                                                                random_state=model_config['seed'])

# Split rows by patient_id so that all rows of a patient land in the same split.
test_X_df = X[X.patient_id.isin(pid_test)]
test_y_df = y[y.patient_id.isin(pid_test)]
train_X_df = X[X.patient_id.isin(pid_train)]
train_y_df = y[y.patient_id.isin(pid_train)]

# Fields stacked along the last axis of the arrays built by features_to_numpy:
# [..., 0] holds case_admission_id and [..., -1] holds value.
numpy_columns = ['case_admission_id', 'relative_sample_date_hourly_cat', 'sample_label', 'value']
train_X_np = features_to_numpy(train_X_df, numpy_columns)
test_X_np = features_to_numpy(test_X_df, numpy_columns)


def _outcomes_for_cases(y_df, case_admission_ids):
    """Return the outcome of each case_admission_id in ``y_df`` as a float32 array.

    For each id, the outcome of its first row in ``y_df`` is used. The lookup
    table is built once, so the cost is linear instead of one full-table scan
    per id. A missing id raises KeyError.
    """
    outcome_by_case = (y_df.drop_duplicates(subset='case_admission_id')
                       .set_index('case_admission_id')['outcome'])
    return outcome_by_case.loc[list(case_admission_ids)].to_numpy().astype('float32')


train_y_np = _outcomes_for_cases(train_y_df, train_X_np[:, 0, 0, 0])
test_y_np = _outcomes_for_cases(test_y_df, test_X_np[:, 0, 0, 0])

# create look-up table for case_admission_ids, sample_labels and relative_sample_date_hourly_cat
test_features_lookup_table = numpy_to_lookup_table(test_X_np)
train_features_lookup_table = numpy_to_lookup_table(train_X_np)

# Remove the case_admission_id, sample_label, and time_step_label columns from the data
test_X_np = test_X_np[:, :, :, -1].astype('float32')
train_X_np = train_X_np[:, :, :, -1].astype('float32')

In [None]:
# define the LightningModule
class LitModel(pl.LightningModule):
    """LightningModule that trains ``model`` as a binary classifier.

    The model's output is flattened and compared against the batch labels
    broadcast along axis 1 of ``x``. Each label is repeated ``x.shape[1]`` times,
    so every position on that axis is scored against its sequence's label.

    Args:
        model: network called as ``model(x)``; its output is flattened to one logit per target.
        lr: Adam learning rate.
        wd: Adam weight decay.
        train_noise: std of Gaussian noise added to inputs in ``training_step`` (0 disables it).
    """

    def __init__(self, model, lr, wd, train_noise):
        super().__init__()
        self.model = model
        self.lr = lr
        self.wd = wd
        self.train_noise = train_noise
        self.criterion = ch.nn.BCEWithLogitsLoss()
        self.train_auroc = AUROC(task="binary")
        self.val_auroc = AUROC(task="binary")

    def _flat_logits_and_targets(self, x, y):
        """Return (flattened logits, flattened labels broadcast along axis 1 of x)."""
        logits = self.model(x).reshape(-1)
        # repeat each label once per position on axis 1, row-major, to align with the logits
        targets = y.unsqueeze(1).repeat(1, x.shape[1]).reshape(-1)
        return logits, targets

    def training_step(self, batch, batch_idx, mode='train'):
        x, y = batch
        if self.train_noise != 0:
            # input-noise augmentation, training only
            x = x + ch.randn_like(x) * self.train_noise
        logits, targets = self._flat_logits_and_targets(x, y)
        step_loss = self.criterion(logits, targets.float()).ravel()
        self.train_auroc(ch.sigmoid(logits), targets)
        self.log("train_auroc", self.train_auroc, on_step=False, on_epoch=True, prog_bar=True)
        return step_loss

    def validation_step(self, batch, batch_idx, mode='train'):
        x, y = batch
        logits, targets = self._flat_logits_and_targets(x, y)
        step_loss = self.criterion(logits, targets.float()).ravel()
        self.val_auroc(ch.sigmoid(logits), targets)
        self.log("val_auroc", self.val_auroc, on_step=False, on_epoch=True, prog_bar=True)
        return step_loss

    def predict_step(self, x):
        """Return the raw logits for ``x`` as a flat 1-D tensor (no sigmoid)."""
        return self.model(x).reshape(-1)

    def configure_optimizers(self):
        return ch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.wd)

In [None]:
from prediction.outcome_prediction.Transformer.architecture import OPSUMTransformer

# Feed-forward width is configured as a multiple of the model dimension.
ff_dim = model_config['ff_factor'] * model_config['model_dim']

# Number of input features: the last axis of the arrays that are fed to the
# model (one entry per sample_label after the value-column slice). Deriving it
# keeps the architecture in sync with the data instead of a hard-coded 84; if it
# disagrees with the checkpoint, loading the weights reports the size mismatch.
input_dim = train_X_np.shape[-1]

model = OPSUMTransformer(
    input_dim=input_dim,
    num_layers=model_config['num_layers'],
    model_dim=model_config['model_dim'],
    dropout=model_config['dropout'],
    ff_dim=ff_dim,
    num_heads=model_config['num_heads'],
    num_classes=1,
    max_dim=model_config['max_dim'],
    pos_encode_factor=model_config['pos_encode_factor'],
)

# Wrap the freshly built (not yet trained) network in the LightningModule.
module = LitModel(model, model_config['lr'], model_config['weight_decay'], model_config['train_noise'])

Load the trained model from its checkpoint.

In [None]:
# Raw checkpoint dict, mapped to CPU. torch.load unpickles the file, so only
# load checkpoints from a trusted source.
checkpoint = ch.load(model_path, map_location='cpu')

In [None]:
# Restore the trained weights. LitModel does not call save_hyperparameters(),
# so its constructor arguments must be passed again here. Note that `model` is
# the same object built above, so its weights are overwritten in place.
# map_location maps the tensors to CPU, so a checkpoint saved on GPU also loads on
# a CPU-only machine (consistent with the ch.load call above).
saved_model = LitModel.load_from_checkpoint(
    checkpoint_path=model_path,
    map_location=ch.device('cpu'),
    model=model,
    lr=model_config['lr'],
    wd=model_config['weight_decay'],
    train_noise=model_config['train_noise'],
)

Predict on the test set with the loaded model.

In [None]:
# Evaluation mode (e.g. disables dropout); no autograd graph is needed for inference.
saved_model.eval()
with ch.no_grad():
    test_X_tensor = ch.from_numpy(test_X_np)
    # predict_step returns one flat logit vector, ordered like the training
    # targets (row-major over axis 0, then axis 1 of the input). Reshape it to
    # one row per test case so later cells can reduce over axis 1.
    y_hat = saved_model.predict_step(test_X_tensor).reshape(test_X_tensor.shape[0],
                                                            test_X_tensor.shape[1])

In [None]:
import numpy as np

# Spread of the logits along axis 1 for each row (requires a 2-D y_hat).
y_hat_std = y_hat.numpy().std(axis=1)

In [None]:
# Median per-row spread, taken over all rows.
np.median(y_hat_std, axis=None)

In [None]:
# Range (min, max) of the per-row spread.
y_hat_std.min(), y_hat_std.max()

In [None]:
# Convert logits to probabilities.
y_hat_sigm = y_hat.sigmoid()

In [None]:
 # Use the training data for deep explainer => can use fewer instances
explainer = shap.DeepExplainer(saved_model.model, ch.from_numpy(train_X_np))

In [None]:
# Explain only the first test case. Each explained instance needs about
# 2 * (background size) model runs, so explaining the whole test set is costly.
first_test_case = ch.from_numpy(test_X_np[:1])
shap_values = explainer.shap_values(first_test_case)