In [None]:
import json
import os
from prediction.utils.scoring import precision, recall, matthews
import numpy as np
import pickle
from torch.utils.data import DataLoader
from prediction.outcome_prediction.Transformer.utils.utils import prepare_dataset, DictLogger
import torch as ch
from torch import optim
import pytorch_lightning as pl
import torchmetrics
from torchmetrics import AUROC

In [None]:
# Preprocessed feature / outcome tables from one preprocessing run. Set OPSUM_PREPRO_DIR to
# point at a different copy of that run instead of editing a user-specific absolute path.
PREPRO_DIR = os.environ.get('OPSUM_PREPRO_DIR', '/Users/jk1/temp/opsum_prepro_output/gsu_prepro_01012023_233050')
features_path = os.path.join(PREPRO_DIR, 'preprocessed_features_01012023_233050.csv')
labels_path = os.path.join(PREPRO_DIR, 'preprocessed_outcomes_01012023_233050.csv')

In [None]:
# Trained checkpoint and the hyperopt-selected configuration it was trained with. Set
# OPSUM_MODEL_DIR to load them from somewhere other than the author's Downloads folder.
MODEL_DIR = os.environ.get('OPSUM_MODEL_DIR', '/Users/jk1/Downloads')
model_path = os.path.join(MODEL_DIR, 'checkpoints_opsum_transformer_20230328_004215_cv_2',
                          'opsum_transformer_epoch=02_val_auroc=0.9227.ckpt')
model_config_path = os.path.join(MODEL_DIR, 'hyperopt_selected_transformer_20230328_004215.json')

In [None]:
# Outcome label passed to `load_data` and stored in the model config
# (presumably 3-month modified Rankin Scale 0-2 — confirm).
outcome = "3M mRS 0-2"

In [None]:
# Load the hyperopt-selected hyper-parameters and add the data-split settings used by
# `load_data` below. The context manager closes the file handle deterministically.
with open(model_config_path, 'r') as config_file:
    model_config = json.load(config_file)
model_config['outcome'] = outcome
model_config['test_size'] = 0.2
model_config['seed'] = 42
model_config['n_splits'] = 5
model_config

In [None]:
from prediction.outcome_prediction.data_loading.data_loader import load_data

# Hold-out test split plus CV folds over the remaining data, driven by the config's split settings.
pids, train_data, test_data, train_splits, test_features_lookup_table = load_data(
    features_path,
    labels_path,
    outcome,
    model_config['test_size'],
    model_config['n_splits'],
    model_config['seed'],
)


In [None]:
# Alias the utility: a later cell defines a local `prepare_dataset` with a different
# signature, which would otherwise shadow this one when this cell is re-run.
from prediction.outcome_prediction.Transformer.utils.utils import prepare_dataset as prepare_transformer_dataset

X_train, y_train = train_data
X_test, y_test = test_data

# Rescaled (and optionally class-balanced) training dataset; the validation part is not needed here.
train_dataset, _ = prepare_transformer_dataset((X_train, X_test, y_train, y_test),
                                               balanced=model_config['balanced'],
                                               rescale=True,
                                               use_gpu=False)

# One batch containing the whole training set. Use the dataset length, not X_train.shape[0]:
# balancing can shrink the dataset, and with drop_last=True a too-large batch size yields no batch.
# The seeded generator makes the shuffled sample order (and anything drawn from it) reproducible.
train_loader = DataLoader(train_dataset, batch_size=len(train_dataset), shuffle=True, drop_last=True,
                          generator=ch.Generator().manual_seed(int(model_config['seed'])))


In [None]:
# The loader yields a single (features, labels) batch; keep only the features.
batch = next(iter(train_loader))
train_sample = batch[0]

In [None]:
# Up to the first 100 rows of the shuffled training batch, kept as a small background set.
N_BACKGROUND = 100
background = train_sample[:N_BACKGROUND]

In [None]:
from torchmetrics import Accuracy


class LitModel(pl.LightningModule):
    """Lightning wrapper that trains a sequence model against a per-sample binary label.

    The sample label ``y`` is repeated along dim 1 of the input, so the wrapped model must
    produce ``batch * x.shape[1]`` logits, one per position along that axis. All of them are
    scored with ``BCEWithLogitsLoss``.

    Args:
        model: the wrapped ``torch.nn.Module`` (e.g. ``OPSUMTransformer``).
        lr: Adam learning rate.
        wd: Adam weight decay.
        train_noise: std of Gaussian noise added to training inputs; ``0`` disables it.
    """

    def __init__(self, model, lr, wd, train_noise):
        super().__init__()
        self.model = model
        self.lr = lr
        self.wd = wd
        self.train_noise = train_noise
        self.criterion = ch.nn.BCEWithLogitsLoss()
        self.train_accuracy = Accuracy(task='binary')
        self.train_accuracy_epoch = Accuracy(task='binary')
        self.val_accuracy_epoch = Accuracy(task='binary')
        self.train_auroc = AUROC(task="binary")
        self.val_auroc = AUROC(task="binary")

    def _flat_logits_and_targets(self, x, y):
        """Flatten the model logits and repeat each sample label once per position along dim 1."""
        logits = self.model(x).ravel()
        targets = y.unsqueeze(1).repeat(1, x.shape[1]).ravel()
        return logits, targets

    def training_step(self, batch, batch_idx, mode='train'):
        """Compute the training loss for an ``(x, y)`` batch. ``mode`` is unused (kept for compatibility)."""
        x, y = batch
        if self.train_noise != 0:
            x = x + ch.randn_like(x) * self.train_noise
        logits, targets = self._flat_logits_and_targets(x, y)
        # BCEWithLogitsLoss already reduces to a 0-dim scalar, which is what Lightning expects.
        loss = self.criterion(logits, targets.float())
        self.train_accuracy(logits, targets)
        self.train_accuracy_epoch(logits, targets)
        return loss

    def validation_step(self, batch, batch_idx, mode='train'):
        """Compute the validation loss and accumulate/log validation AUROC. ``mode`` is unused."""
        x, y = batch
        logits, targets = self._flat_logits_and_targets(x, y)
        loss = self.criterion(logits, targets.float())
        self.val_auroc(ch.sigmoid(logits), targets)
        self.log("val_auroc", self.val_auroc, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def predict_step(self, batch, batch_idx=0):
        """Return raw logits for a batch.

        ``batch`` may be an ``(x, y)`` pair (as produced by a ``TensorDataset`` loader) or a bare
        input tensor. Only the trailing dimension is squeezed, so a batch of size 1 keeps its
        batch axis.
        """
        x = batch[0] if isinstance(batch, (tuple, list)) else batch
        return self.model(x).squeeze(-1)

    def configure_optimizers(self):
        """Adam with weight decay and a per-epoch exponential learning-rate decay of 0.99."""
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.wd)
        return [optimizer], [optim.lr_scheduler.ExponentialLR(optimizer, 0.99)]


In [None]:
from torch.utils.data import TensorDataset
from sklearn.preprocessing import StandardScaler


def prepare_dataset(X_train, X_val, y_train, y_val, balanced=False, random_state=None):
    """Standardise features and wrap train/validation splits as ``TensorDataset`` objects.

    NOTE(review): this definition shadows the ``prepare_dataset`` imported in the first cell
    (which takes a single ``(X_train, X_val, y_train, y_val)`` tuple). Cells calling the
    imported version will break if re-run after this cell.

    The ``StandardScaler`` is fit on the training features only, with every leading axis
    flattened into rows, and is then applied to the validation features.

    Args:
        X_train, X_val: numpy feature arrays whose last axis is the feature axis
            (presumably ``(n_samples, n_timesteps, n_features)`` — TODO confirm).
        y_train, y_val: per-sample 0/1 labels (array-like, aligned with axis 0 of X).
        balanced: if True, keep all negatives and draw the same number of positives
            with replacement, so both classes have equal counts in the training set.
        random_state: optional seed or ``numpy.random.Generator`` for that draw. ``None``
            uses the global ``np.random`` state (the previous behaviour).

    Returns:
        ``(train_dataset, val_dataset)`` with float32 features and int32 labels.

    Raises:
        ValueError: if ``balanced`` is True and ``y_train`` contains no positive samples.
    """
    y_train = np.asarray(y_train)
    y_val = np.asarray(y_val)
    n_features = X_train.shape[-1]

    scaler = StandardScaler()
    # Pool all leading axes into rows so one mean/std is learned per feature.
    X_train = scaler.fit_transform(X_train.reshape(-1, n_features)).reshape(X_train.shape)
    X_val = scaler.transform(X_val.reshape(-1, n_features)).reshape(X_val.shape)

    if balanced:
        rng = np.random if random_state is None else np.random.default_rng(random_state)
        pos_idx = np.flatnonzero(y_train == 1)
        if pos_idx.size == 0:
            raise ValueError("balanced=True requires at least one positive sample in y_train")
        X_train_neg = X_train[y_train == 0]
        X_train_pos = X_train[rng.choice(pos_idx, X_train_neg.shape[0])]
        X_train = np.concatenate([X_train_neg, X_train_pos])
        y_train = np.concatenate([np.zeros(X_train_neg.shape[0]), np.ones(X_train_pos.shape[0])])

    # torch modules default to float32 parameters. StandardScaler returns float64 for float64
    # input, which a float32 model rejects with a dtype-mismatch error, so cast explicitly.
    train_dataset = TensorDataset(ch.from_numpy(X_train.astype(np.float32)),
                                  ch.from_numpy(y_train.astype(np.int32)))
    val_dataset = TensorDataset(ch.from_numpy(X_val.astype(np.float32)),
                                ch.from_numpy(y_val.astype(np.int32)))
    return train_dataset, val_dataset

In [None]:
from prediction.outcome_prediction.Transformer.architecture import OPSUMTransformer

# Feed-forward width is a fixed multiple of the model dimension.
ff_factor = 2
ff_dim = ff_factor * model_config['model_dim']
pos_encode_factor = 1

model = OPSUMTransformer(
            input_dim=84,
            num_layers=int(model_config['num_layers']),
            model_dim=int(model_config['model_dim']),
            # Dropout is a fractional rate; int() would truncate it to 0 and silently disable dropout.
            dropout=float(model_config['dropout']),
            ff_dim=int(ff_dim),
            num_heads=int(model_config['num_head']),
            num_classes=1,
            max_dim=500,
            pos_encode_factor=pos_encode_factor
        )
# Fresh (untrained) Lightning wrapper; the trained weights are restored from the checkpoint separately.
module = LitModel(model, model_config['lr'], model_config['weight_decay'], model_config['train_noise'])

## Load the trained model from its checkpoint

In [None]:
# Restore the trained weights into a LitModel wrapping `model`. LitModel does not save its
# hyper-parameters, so every constructor argument is passed explicitly.
saved_model = LitModel.load_from_checkpoint(
    checkpoint_path=model_path,
    model=model,
    lr=model_config['lr'],
    wd=model_config['weight_decay'],
    train_noise=model_config['train_noise'],
)

## Predict on the held-out test set

In [None]:
# Held-out test split as numpy arrays, under the names the following cells use.
# NOTE(review): assumed to be the arrays unpacked from `test_data` earlier — confirm.
test_X_np = np.asarray(X_test)
test_y_np = np.asarray(y_test)

saved_model.eval()
with ch.no_grad():
    # `predict_step` expects an `(x, y)` batch and a batch index; inputs must be float32 for the model.
    # NOTE(review): these features are not standardised, unlike the trainer-based prediction
    # further down, so treat this output only as a smoke test of the restored model.
    y_hat = saved_model.predict_step((ch.from_numpy(test_X_np).float(), None), 0)


In [None]:
# Training part of the CV fold whose checkpoint was selected (`load_data` returned the folds as `train_splits`).
model_fold_train_X, _, model_fold_train_y, _ = train_splits[int(model_config['best_cv_fold'])]

In [None]:
# Standardise the test set with statistics fitted on the selected fold's training data (no balancing).
train_dataset, test_dataset = prepare_dataset(
    model_fold_train_X,
    test_X_np,
    model_fold_train_y,
    test_y_np,
    balanced=False,
)

In [None]:
# `DataLoader` and `DictLogger` come from the imports in the first cell
# (`prediction.outcome_prediction.Transformer.utils.utils`).
logger = DictLogger(0)

# Inference only, on one CPU device; training options (epochs, gradient clipping) have no effect here.
test_loader = DataLoader(test_dataset, batch_size=1024)
trainer = pl.Trainer(accelerator='cpu', devices=1, logger=logger)
# One tensor of logits per test batch.
predictions = trainer.predict(saved_model, test_loader)

In [None]:
# `trainer.predict` returns one tensor per batch; concatenate them so every test sample is
# scored (taking only element [0] would keep just the first batch of 1024 samples).
y_hat = ch.cat(predictions, dim=0)

In [None]:
# Per-timestep probabilities for all test samples, derived from the raw logits rather than
# from `y_hat` itself, so re-running this cell does not apply the sigmoid/slice twice.
y_hat_proba_timesteps = ch.sigmoid(ch.cat(predictions, dim=0))
# The probability at the last position along dim 1 is the score evaluated below.
y_hat = y_hat_proba_timesteps[:, -1]

In [None]:
# Test-set ROC AUC of the last-timestep probability.
from sklearn.metrics import roc_auc_score

test_auroc = roc_auc_score(test_y_np, y_hat.numpy())
test_auroc

In [None]:
from sklearn.metrics import accuracy_score, matthews_corrcoef, precision_score, recall_score, roc_auc_score

# Metrics of the restored model on `train_dataset` (the training part of the selected CV fold).
# NOTE(review): predictions are computed here from `train_dataset` — confirm this is the
# intended training set for these metrics.
train_fold_predictions = trainer.predict(saved_model, DataLoader(train_dataset, batch_size=1024))
# Last-timestep probability per sample, matching the test-set evaluation.
model_y_train = ch.sigmoid(ch.cat(train_fold_predictions, dim=0))[:, -1].numpy()
# Labels in the same (unshuffled) order as the predictions.
model_y_true_train = train_dataset.tensors[1].numpy()

model_y_pred_train = (model_y_train > 0.5).astype(int)
model_acc_train = accuracy_score(model_y_true_train, model_y_pred_train)
model_precision_train = precision_score(model_y_true_train, model_y_pred_train)
model_sn_train = recall_score(model_y_true_train, model_y_pred_train)  # sensitivity
model_auc_train = roc_auc_score(model_y_true_train, model_y_train)
model_mcc_train = matthews_corrcoef(model_y_true_train, model_y_pred_train)
# Specificity is the recall of the negative class.
model_sp_train = recall_score(model_y_true_train, model_y_pred_train, pos_label=0)

In [None]:
import numpy as np

# Spread of the predicted probability across positions along dim 1, one value per test sample.
# Recomputed from the raw per-batch logits: `y_hat` is 1-D (last timestep only) at this
# point, so `axis=1` on it would raise.
# NOTE(review): computed on probabilities; switch to raw logits if that was the intent.
y_hat_std = np.std(ch.sigmoid(ch.cat(predictions, dim=0)).numpy(), axis=1)

In [None]:
# Typical per-sample spread across timesteps.
y_hat_std_median = np.median(y_hat_std)
y_hat_std_median

In [None]:
# Smallest and largest per-sample spread.
y_hat_std.min(), y_hat_std.max()

In [None]:
# Per-timestep probabilities from the raw logits. `y_hat` already holds sigmoid outputs at this
# point, and applying the sigmoid again would squash every value into [0.5, 0.73].
y_hat_sigm = ch.sigmoid(ch.cat(predictions, dim=0))

In [None]:
# Use a subset of the training data as the DeepExplainer background; fewer instances make explanation cheaper.
import shap

# Background = first rows of the rescaled training batch (`background`), cast to float32 for the model.
explainer = shap.DeepExplainer(saved_model.model, background.float())

In [None]:
# Explain a single test instance; explaining each prediction requires 2 * background-size model runs.
# Use the standardised test features from `test_dataset` so the input is on the scale the model expects.
shap_values = explainer.shap_values(test_dataset.tensors[0][0:1].float())