In [None]:
import json

import shap
import torch
from torch.utils.data import DataLoader

from prediction.outcome_prediction.Transformer.architecture import OPSUMTransformer
from prediction.outcome_prediction.Transformer.lightning_wrapper import LitModel
from prediction.outcome_prediction.Transformer.utils.utils import prepare_dataset
from prediction.outcome_prediction.data_loading.data_loader import load_data


In [None]:
# Preprocessed feature / outcome CSVs, both written into the same preprocessing output directory.
# NOTE(review): absolute local path — consider making this directory configurable.
prepro_dir = '/Users/jk1/temp/opsum_prepro_output/gsu_prepro_01012023_233050'
features_path = f'{prepro_dir}/preprocessed_features_01012023_233050.csv'
labels_path = f'{prepro_dir}/preprocessed_outcomes_01012023_233050.csv'

In [None]:
# Config (JSON) and checkpoint of the selected transformer, both under one test-set-evaluation run directory.
# NOTE(review): the checkpoint folder is for cv_2; presumably this must be the same fold as
# model_config['best_cv_fold'] used when loading the training data — confirm.
eval_run_dir = '/Users/jk1/temp/opsum_prediction_output/transformer/3M_mrs02/transformer_20230402_184459_test_set_evaluation'
model_config_path = f'{eval_run_dir}/hyperopt_selected_transformer_20230402_184459.json'
model_weights_path = (
    f'{eval_run_dir}/trained_models/checkpoints_opsum_transformer_20230402_184459_cv_2/'
    'opsum_transformer_epoch=14_val_auroc=0.9222.ckpt'
)

In [None]:
# Arguments forwarded to load_data. Presumably these must match the training run so that
# the same test split and CV folds are reproduced — confirm against the training script.
seed = 42
n_splits = 5
test_size = 0.2
outcome = '3M mRS 0-2'

# How many training samples form the SHAP background set.
n_samples_background = 1

## Load model

In [None]:
# Load the selected model's JSON config (hyperparameters such as model_dim, num_layers, dropout, lr).
# A context manager closes the file handle instead of leaking it until garbage collection.
with open(model_config_path, 'r') as config_file:
    model_config = json.load(config_file)

In [None]:
# Rebuild the transformer from the selected hyperparameters, then restore the trained weights into it.
# NOTE(review): ff_factor and pos_encode_factor are not read from model_config here; presumably they
# must equal the values used at training time — confirm.
ff_factor = 2
ff_dim = ff_factor * model_config['model_dim']
pos_encode_factor = 1

model_architecture = OPSUMTransformer(
    input_dim=84,  # NOTE(review): hard-coded feature count; presumably equals X_train.shape[-1] — confirm
    num_layers=int(model_config['num_layers']),
    model_dim=int(model_config['model_dim']),
    # dropout is presumably a fractional rate (e.g. 0.2); int() would truncate it to 0 and silently
    # build the network without dropout, so keep it as a float.
    dropout=float(model_config['dropout']),
    ff_dim=int(ff_dim),
    num_heads=int(model_config['num_head']),
    num_classes=1,
    max_dim=500,
    pos_encode_factor=pos_encode_factor
)

trained_model = LitModel.load_from_checkpoint(checkpoint_path=model_weights_path, model=model_architecture,
                                              lr=model_config['lr'],
                                              wd=model_config['weight_decay'],
                                              train_noise=model_config['train_noise'])
# This notebook only runs inference: disable dropout and other train-time behaviour so outputs are
# deterministic. The trailing ';' suppresses the module repr that eval() would otherwise display.
trained_model.eval();


## Load data

In [None]:
# Reproduce the data splits; the last two names are kept for later use even though this cell does not need them.
pids, train_data, test_data, train_splits, test_features_lookup_table = load_data(features_path, labels_path, outcome, test_size, n_splits, seed)
# Take the training part of the CV fold recorded in the model config; each split presumably unpacks as
# (X_train, X_val, y_train, y_val) — the validation parts are discarded.
fold_X_train, _, fold_y_train, _ = train_splits[int(model_config['best_cv_fold'])]

In [None]:
# Held-out test set, plus the training part of the selected CV fold (the pool the SHAP background is drawn from).
(X_test, y_test) = test_data
X_train = fold_X_train
y_train = fold_y_train

## Prepare data

In [None]:
# Index of the last time step to keep: the next cell slices time steps 0..ts (inclusive) of every sample.
ts = 0

In [None]:
# Keep only the first ts + 1 entries along axis 1 (presumably the time axis) of every sample, in both splits.
modified_time_steps = ts + 1

X_train_with_first_n_ts = X_train[:, :modified_time_steps, :]
X_test_with_first_n_ts = X_test[:, :modified_time_steps, :]

# Turn the truncated arrays into datasets; rescaling and optional class balancing are left to prepare_dataset.
train_dataset, test_dataset = prepare_dataset(
    (X_train_with_first_n_ts, X_test_with_first_n_ts, y_train, y_test),
    balanced=model_config['balanced'],
    rescale=True,
    use_gpu=False,
)

In [None]:
# Background for DeepExplainer: a random draw of n_samples_background training samples.
# The whole training set is fetched as one shuffled batch. The generator is seeded so the draw is reproducible.
# batch_size is taken from the dataset itself rather than X_train. With drop_last=True, a batch larger than
# the dataset (e.g. if prepare_dataset resamples when balancing) would leave the loader empty, and
# next(iter(...)) would raise StopIteration.
background_generator = torch.Generator().manual_seed(seed)
train_loader = DataLoader(train_dataset, batch_size=len(train_dataset), shuffle=True, drop_last=True,
                          generator=background_generator)

batch = next(iter(train_loader))
train_sample, _ = batch
background = train_sample[:n_samples_background]

# Samples to explain. NOTE(review): only the first batch is used, so at most 1024 test samples are explained.
test_loader = DataLoader(test_dataset, batch_size=1024)
batch = next(iter(test_loader))
test_samples, _ = batch

In [None]:
# Initialize DeepExplainer on the inner network (presumably the OPSUMTransformer passed as model= when loading),
# moved to the background's device. eval() disables dropout so the explained forward passes are deterministic.
explainer = shap.DeepExplainer(trained_model.model.to(background.device).eval(), background)

In [None]:
# Compute SHAP attributions for the selected test samples and keep them in a named variable,
# so later analysis does not depend on the transient `_` / Out[N] output and no huge array is dumped.
shap_values = explainer.shap_values(test_samples)