# Run Simplex for time series RNN

### Imports

In [39]:
import os
import sys
from random import random
import numpy as np
import pandas as pd
import pickle as pkl
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.metrics import roc_auc_score

sys.path.append("..")
from explainers.simplex import Simplex
from models.recurrent_neural_net import MortalityGRU
from experiments.time_series_prostate_cancer import (
    TimeSeriesProstateCancerDataset,
    load_time_series_prostate_cancer,
)


### Load the prostate cancer time series data from file

In [11]:
def load_data(random_seed=42, corpus_size=100, batch_size=50):
    """Load the prostate cancer time series and prepare everything the notebook needs.

    The records are split 75/25 into train and test sets (stratified on the
    label), wrapped in ``DataLoader`` objects, and a corpus plus a matching
    baseline tensor are extracted for SimplEx.

    Args:
        random_seed: seed passed to ``train_test_split``.
        corpus_size: number of training examples requested for the corpus.
        batch_size: batch size of the train and test loaders.

    Returns:
        A tuple ``(train_loader, test_loader, corpus_inputs, corpus_targets,
        test_inputs, test_targets, max_time_points, feature_names,
        class_imbalance_weighting, input_baseline, rescale_dict)`` where

        * ``corpus_inputs`` / ``corpus_targets`` are the first batch of an
          unshuffled loader over the training set,
        * ``test_inputs`` / ``test_targets`` are the first batch of the test loader,
        * ``class_imbalance_weighting`` is the fraction of training labels equal to 0,
        * ``input_baseline`` is the corpus mean feature vector broadcast to
          ``(len(corpus_inputs), max_time_points, n_features)``.
    """

    # LOAD DATA from file
    (
        X,
        y,
        feature_names,
        max_time_points,
        rescale_dict,
    ) = load_time_series_prostate_cancer()

    # Get data into shape and produce corpus
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.25, random_state=random_seed, stratify=y
    )

    print(f"Train set contains {len(y_train)} records")
    print(f"Test set contains {len(y_test)} records")
    print(f"{sum(y_train == 1)} training records with a label of 1")
    print(f"{sum(y_train == 0)} training records with a label of 0")
    print(f"{sum(y_test == 1)} test records with a label of 1")
    print(f"{sum(y_test == 0)} test records with a label of 0")

    class_imbalance_weighting = sum(y_train == 0) / len(y_train)

    train_data = TimeSeriesProstateCancerDataset(X_train, y_train)
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

    test_data = TimeSeriesProstateCancerDataset(X_test, y_test)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)
    # Only the first test batch is explained.
    test_inputs, test_targets = next(iter(test_loader))

    # Unshuffled, so the corpus is the first `corpus_size` training records.
    corpus_loader = DataLoader(train_data, batch_size=corpus_size, shuffle=False)
    corpus_inputs, corpus_targets = next(iter(corpus_loader))

    # Average over time (dim 1) and then over examples (dim 0), and broadcast the
    # resulting feature vector. The leading dimension follows the actual corpus
    # rather than a fixed 100, so the baseline matches for any corpus_size.
    input_baseline = torch.mean(torch.mean(corpus_inputs, 1), 0).expand(
        corpus_inputs.shape[0], max_time_points, -1
    )

    return (
        train_loader,
        test_loader,
        corpus_inputs,
        corpus_targets,
        test_inputs,
        test_targets,
        max_time_points,
        feature_names,
        class_imbalance_weighting,
        input_baseline,
        rescale_dict,
    )

# LOAD data
# Mini-batch size for the loaders and number of corpus examples for SimplEx.
batch_size, corpus_size = 50, 100

# The names below are bound positionally, in the order load_data returns them.
(
    train_loader,
    test_loader,
    corpus_inputs,
    corpus_targets,
    test_inputs,
    test_targets,
    max_time_points,
    feature_names,
    class_imbalance_weighting,
    input_baseline,
    rescale_dict,
) = load_data(corpus_size=corpus_size, batch_size=batch_size, random_seed=5)

Train set contains 438 records
Test set contains 147 records
75 training records with a label of 1
363 training records with a label of 0
25 test records with a label of 1
122 test records with a label of 0


### Get trained mortality model

Use the boolean variable `train_model` to either train a new model or load a pre-trained model from file.

In [19]:
def load_trained_model(model, trained_model_state_path):
    """Load saved weights into ``model`` and switch it to evaluation mode.

    Args:
        model: module whose parameters are overwritten in place.
        trained_model_state_path: path of a ``state_dict`` written by ``torch.save``.

    Returns:
        The same ``model`` instance, in eval mode.
    """
    # map_location="cpu" lets a checkpoint written during GPU training be read
    # on a machine without CUDA; load_state_dict then copies the values onto
    # whatever device the model's parameters already live on.
    state_dict = torch.load(trained_model_state_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    return model


# Get a trained model
# Directory holding the saved model and optimizer state. The literal has no
# placeholders, so it is a plain string; abspath resolves it against the
# current working directory and normalises away the trailing slash.
save_path = os.path.abspath(
    "../demonstrator/resources/trained_models/RNN/Time series Prostate Cancer/"
)
cv = 1  # suffix used in the checkpoint file names (model_cv{cv}.pth)
# Train model if required
train_model = False

if train_model:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Define parameters
    n_epoch_model = 600
    log_interval = 3  # record the training loss every `log_interval` batches
    weight_decay = 1e-6

    # Create the model
    classifier = MortalityGRU(
        input_dim=len(feature_names),
        hidden_dim=5,
        output_dim=1,
        n_layers=1,
    )
    print(classifier)
    classifier.to(device)
    # NOTE(review): a one-element weight tensor broadcasts over the batch, i.e. it
    # rescales every sample's loss by the same factor and does not treat the two
    # classes differently - confirm this is the intended imbalance handling.
    class_weights = torch.FloatTensor([class_imbalance_weighting]).to(device)
    criterion = nn.BCELoss(weight=class_weights)
    optimizer = optim.Adam(classifier.parameters(), weight_decay=weight_decay)

    # Train the model
    print(100 * "-" + "\n" + "Now fitting the model. \n" + 100 * "-")
    train_losses = []
    train_counter = []
    test_losses = []

    def train(epoch):
        """Run one optimisation pass over ``train_loader`` and print epoch metrics."""
        correct = 0
        correct_0 = 0
        correct_1 = 0
        num_targets_0 = 0
        num_targets_1 = 0
        loss_sum = 0.0  # sum of per-sample losses, for a true epoch average
        epoch_outputs = []
        epoch_targets = []

        classifier.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            data = data.to(device)
            target = target.type(torch.LongTensor).unsqueeze(1).to(device)
            optimizer.zero_grad()
            output = classifier(data)
            loss = criterion(output, target.float())
            loss.backward()
            optimizer.step()

            # criterion returns the batch mean, so weight it by the batch length.
            loss_sum += loss.item() * len(data)
            epoch_outputs.append(output.detach().cpu())
            epoch_targets.append(target.cpu())

            pred = output.detach().round()
            num_targets_0 += len((target[target == 0]))
            num_targets_1 += len((target[target == 1]))
            correct_0 += float((pred[target == 0] == target[target == 0]).sum())
            correct_1 += float((pred[target == 1] == target[target == 1]).sum())
            correct += float((pred == target).sum())
            if batch_idx % log_interval == 0:
                train_losses.append(loss.item())
                # Number of training examples seen so far; uses the loader's real
                # batch size rather than a hard-coded one.
                train_counter.append(
                    (batch_idx * train_loader.batch_size)
                    + ((epoch - 1) * len(train_loader.dataset))
                )

        avg_loss = loss_sum / len(train_loader.dataset)
        # Computed once over the whole epoch: a single batch may contain only one
        # class, in which case roc_auc_score raises.
        train_roc_score = roc_auc_score(
            torch.cat(epoch_targets).numpy().ravel(),
            torch.cat(epoch_outputs).numpy().ravel(),
        )
        print(
            "TRAINING:\n"
            f"correct: {correct}/{len(train_loader.dataset)} ({100. * correct / len(train_loader.dataset):.0f}%)\n"
            f"correct 0s: {correct_0}/{num_targets_0} ({100. * correct_0 / num_targets_0:.0f}%)\n"
            f"correct 1s: {correct_1}/{num_targets_1} ({100. * correct_1 / num_targets_1:.0f}%)\n"
            f"Training set: Avg. loss: {avg_loss:.4f}, ROC AUC score: {train_roc_score:.4f}"
        )

    def test():
        """Evaluate the classifier on ``test_loader`` and print test-set metrics."""
        classifier.eval()
        test_loss = 0.0
        correct = 0
        correct_0 = 0
        correct_1 = 0
        num_targets_0 = 0
        num_targets_1 = 0
        all_outputs = []
        all_targets = []
        with torch.no_grad():
            for data, target in test_loader:
                data = data.to(device)
                output = classifier(data)
                target = target.type(torch.LongTensor).unsqueeze(1).to(device)
                # Batch mean times batch length, so the division below yields a
                # per-sample average.
                test_loss += criterion(output, target.float()).item() * len(data)
                pred = output.round()

                all_outputs.append(output.cpu())
                all_targets.append(target.cpu())
                num_targets_0 += len((target[target == 0]))
                num_targets_1 += len((target[target == 1]))
                correct_0 += float((pred[target == 0] == target[target == 0]).sum())
                correct_1 += float((pred[target == 1] == target[target == 1]).sum())
                correct += float((pred == target).sum())

        test_loss /= len(test_loader.dataset)
        test_losses.append(test_loss)
        # ROC AUC over the full test set instead of the last batch only.
        test_roc_score = roc_auc_score(
            torch.cat(all_targets).numpy().ravel(),
            torch.cat(all_outputs).numpy().ravel(),
        )
        print(
            "TESTING:\n"
            f"correct: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f})%\n"
            f"correct 0s: {correct_0}/{num_targets_0} ({100. * correct_0 / num_targets_0:.0f})%\n"
            f"correct 1s: {correct_1}/{num_targets_1} ({100. * correct_1 / num_targets_1:.0f})%\n"
            f"Test set:     Avg. loss: {test_loss:.4f}, ROC AUC score: {test_roc_score:.4f}"
        )

    # Make sure the checkpoint directory exists before the first save.
    os.makedirs(save_path, exist_ok=True)

    print("Pre-training test.")
    test()
    print("\n\n")
    for epoch in range(1, n_epoch_model + 1):
        print(f"\nepoch: {epoch}")
        train(epoch)
        test()
        # Overwrite the checkpoint after every epoch so the latest state is on disk.
        torch.save(
            classifier.state_dict(),
            os.path.join(save_path, f"model_cv{cv}.pth"),
        )
        torch.save(
            optimizer.state_dict(),
            os.path.join(save_path, f"optimizer_cv{cv}.pth"),
        )

# Get a trained model
# The architecture must match the one used when the checkpoint was written.
model = MortalityGRU(
    input_dim=len(feature_names),
    hidden_dim=5,
    output_dim=1,
    n_layers=1,
)  # Model should have the BlackBox interface
TRAINED_MODEL_STATE_PATH = os.path.join(save_path, f"model_cv{cv}.pth")
load_trained_model(model, TRAINED_MODEL_STATE_PATH)

# Compute corpus and test model predictions
# Call the module itself (not .forward) so registered hooks run, and skip
# building an autograd graph since only the rounded outputs are kept.
with torch.no_grad():
    corpus_predictions = model(corpus_inputs).round()
    test_predictions = model(test_inputs).round()


### Define function for sorting examples to match the sorted output from jacobian decomposition

In [13]:
# sort order function for decomposition
def apply_sort_order(in_list, sort_order):
    """Reorder ``in_list`` so that element ``k`` of the result is ``in_list[sort_order[k]]``.

    Args:
        in_list: a list, tuple, NumPy array or torch tensor.
        sort_order: iterable of indices into ``in_list``.

    Returns:
        A list. For tensor input the elements are NumPy values taken from the
        tensor's array view, as before.

    Raises:
        TypeError: if ``in_list`` is none of the supported container types
            (previously the function silently returned ``None``).
    """
    if isinstance(in_list, (list, tuple)):
        return [in_list[idx] for idx in sort_order]
    if torch.is_tensor(in_list):
        # Convert once instead of once per index; detach/cpu make this work for
        # tensors that require grad or live on an accelerator.
        values = in_list.detach().cpu().numpy()
        return [values[idx] for idx in sort_order]
    if isinstance(in_list, np.ndarray):
        return [in_list[idx] for idx in sort_order]
    raise TypeError(
        f"apply_sort_order expects a list, tuple, ndarray or tensor, got {type(in_list).__name__}"
    )




### Fit SimplEx

In [14]:
# Fit SimplEx
# Latent representations of the corpus and of the test batch, detached from autograd
corpus_latents = model.latent_representation(corpus_inputs).detach()
test_latents = model.latent_representation(test_inputs).detach()
# Build the explainer on the corpus, then fit its weights on the test examples
simplex = Simplex(corpus_latent_reps=corpus_latents, corpus_examples=corpus_inputs)
simplex.fit(
    test_examples=test_inputs,
    test_latent_reps=test_latents,
    n_epoch=50000,
    reg_factor=0,
)


Weight Fitting Epoch: 10000/50000 ; Error: 0.0823 ; Regulator: 25.6 ; Reg Factor: 0
Weight Fitting Epoch: 20000/50000 ; Error: 0.00583 ; Regulator: 19.6 ; Reg Factor: 0
Weight Fitting Epoch: 30000/50000 ; Error: 0.00413 ; Regulator: 18.2 ; Reg Factor: 0
Weight Fitting Epoch: 40000/50000 ; Error: 0.00411 ; Regulator: 17 ; Reg Factor: 0
Weight Fitting Epoch: 50000/50000 ; Error: 0.00411 ; Regulator: 16.2 ; Reg Factor: 0


### Compute the SimplEx decomposition for test patient i

In [15]:
# Compute the Integrated Jacobian for a particular example
i = 1  # index of the explained patient within the test batch
simplex.jacobian_projection(model=model, test_id=i, input_baseline=input_baseline)
# `sort_order` is requested via return_id so corpus labels/predictions can be reordered to match `result`
result, sort_order = simplex.decompose(i, return_id=True)

### Display Test Patient
Use the variable `test_time_steps_to_display` to change the number of time steps in the displayed output. A value of 50 or greater will ensure all available time points are displayed.

In [31]:
# set up dataframes
pd.set_option("display.max_columns", None)
test_time_steps_to_display = 3

# Test patient
test_patient_series = simplex.test_examples[i]
# Rows that are not entirely zero count as recorded time steps.
recorded_rows = ~np.all(test_patient_series.numpy() == 0, axis=1)
test_patient_last_time_step_idx = test_patient_series[recorded_rows].shape[0] - 1
# Never ask for more steps than the patient has.
test_time_steps_to_display = min(
    test_time_steps_to_display, test_patient_last_time_step_idx + 1
)

window_start = test_patient_last_time_step_idx - (test_time_steps_to_display - 1)
window_stop = test_patient_last_time_step_idx + 1
# Row labels count back from the last recorded step.
time_labels = []
for steps_back in reversed(range(test_time_steps_to_display)):
    time_labels.append("(t_max)" if steps_back == 0 else f"(t_max) - {steps_back}")

test_patient_df = pd.DataFrame(
    test_patient_series[window_start:window_stop, :].numpy(),
    columns=feature_names,
    index=time_labels,
)
display(test_patient_df.transpose())

Unnamed: 0,(t_max) - 2,(t_max) - 1,(t_max)
Exact age at diagnosis,0.107004,0.107004,0.107004
Number of negative biopsies before diagnosis,0.0,0.0,0.0
Number of MRI-visible lesions,0.0,0.0,0.0
Days since diagnosis.3,0.436568,0.436568,0.436568
Days Since Diagnosis,0.386841,0.422384,0.437593
Repeat PSA,0.071806,0.088701,0.101901
Repeat Biopsy Core Total,0.313726,0.313726,0.313726
Repeat Biopsy Core Positive,0.333333,0.333333,0.333333
Repeat MRI Volume,0.102222,0.102222,0.102222
Repeat MRI PSAd,0.190955,0.190955,0.190955


### Display Corpus decomposition and importances

Define styling functions for displaying the dataframes.

In [None]:
def df_values_to_colors(df):
    """Replace every value in ``df`` with a hex colour reflecting its magnitude.

    Values are normalised against the minimum and maximum of the whole frame
    (NaNs ignored when finding the range), so colours are comparable across
    columns. The colours come from a 256-entry lookup table sampled from the
    0.2-0.75 portion of matplotlib's ``bwr`` colormap.

    Args:
        df: numeric DataFrame. It is modified in place; pass a copy to keep
            the original values.

    Returns:
        The same DataFrame object, now holding hex colour strings.
    """
    # The scale and the lookup table do not depend on the column, so build them once.
    norm = mcolors.Normalize(
        vmin=np.nanmin(df.values),
        vmax=np.nanmax(df.values),
        clip=True,
    )
    lut = plt.cm.bwr(np.linspace(0.2, 0.75, 256))
    lut = np.apply_along_axis(mcolors.to_hex, 1, lut)

    for col in df:
        # Normalised values in [0, 1] become indices 0..255 into the table.
        a = (norm(df[col].values) * 255).astype(np.int16)
        df[col] = lut[a]
    return df


def highlight(x):
    """Styler callback: return the CSS grid held in the global ``importance_df_colors``.

    The values of ``importance_df_colors`` are relabelled with the index and
    columns of ``x``, the frame being styled.
    """
    css_grid = importance_df_colors.values
    return pd.DataFrame(css_grid, index=x.index, columns=x.columns)

In [48]:
# Corpus of patients

# Variables
example_importance_threshold = 0.1  # corpus examples below this weight are not shown
corpus_time_steps_to_display = 10  # maximum number of trailing time steps per example


def _time_step_labels(n_steps):
    """Return the labels '(t_max) - k', ..., '(t_max)' for the last ``n_steps`` steps."""
    return [
        f"(t_max) - {k}" if k != 0 else "(t_max)" for k in reversed(range(n_steps))
    ]


# Patient Feature values
# Index of the last recorded step: rows that are entirely zero are not counted.
last_time_step_idx = [
    result[j][1][~np.all(result[j][1].numpy() == 0, axis=1)].shape[0] - 1
    for j in range(len(result))
]
# First step of each display window, clamped at 0. Without the clamp an example
# with fewer recorded steps than the window gets a negative slice start, which
# wraps around to the end of the sequence.
first_time_step_idx = [
    max(idx - (corpus_time_steps_to_display - 1), 0) for idx in last_time_step_idx
]

corpus_dfs = [
    pd.DataFrame(result[j][1][start : idx + 1].numpy(), columns=feature_names)
    for j, (start, idx) in enumerate(zip(first_time_step_idx, last_time_step_idx))
]
# Multiply the columns listed in rescale_dict by their rescale factor.
for corpus_df in corpus_dfs:
    for col_name, rescale_value in rescale_dict.items():
        corpus_df[col_name] = corpus_df[col_name].apply(lambda x: x * rescale_value)

# Reorder labels and predictions once, to match the order of `result`.
sorted_targets = apply_sort_order(corpus_targets, sort_order)
sorted_predictions = apply_sort_order(corpus_predictions, sort_order)

corpus_data = [
    {
        "feature_vals": corpus_dfs[i].transpose(),
        "Label": sorted_targets[i],
        "Prediction": sorted_predictions[i],
        "Example Importance": result[i][0],
    }
    for i in range(len(corpus_dfs))
]

# Patient importances
importance_dfs = [
    pd.DataFrame(
        result[j][2][start : idx + 1].numpy(),
        columns=[f"{col}_fi" for col in feature_names],
    )
    for j, (start, idx) in enumerate(zip(first_time_step_idx, last_time_step_idx))
]
importance_data = [
    {
        "importance_vals": importance_dfs[i].transpose(),
        "Label": sorted_targets[i],
        "Prediction": sorted_predictions[i],
        "Example Importance": result[i][0],
    }
    for i in range(len(corpus_dfs))
]

corpus_data = [
    example
    for example in corpus_data
    if example["Example Importance"] >= example_importance_threshold
]
importance_data = [
    example
    for example in importance_data
    if example["Example Importance"] >= example_importance_threshold
]

for example_i in range(len(corpus_data)):
    importance_vals = importance_data[example_i]["importance_vals"]
    # Label columns from the number of steps this example actually has, so the
    # labels stay aligned and the global window setting is left untouched.
    time_labels = dict(enumerate(_time_step_labels(importance_vals.shape[1])))

    # `highlight` reads this global when the styled frame is rendered below.
    importance_df_colors = df_values_to_colors(importance_vals.copy())
    # Column-wise Series.map avoids the deprecated DataFrame.applymap.
    importance_df_colors = importance_df_colors.apply(
        lambda col: col.map(lambda x: f"background-color: {x}")
    )
    display_corpus_df = (
        corpus_data[example_i]["feature_vals"]
        .rename(columns=time_labels)
        .style.apply(highlight, axis=None)
    )
    print(f"Corpus Example: {example_i}")
    print(f"Example Importance: {corpus_data[example_i]['Example Importance']}")
    display(display_corpus_df)
    # Raw importances with the same labels; built but not displayed here.
    display_importance_df = importance_vals.rename(columns=time_labels)




Corpus Example: 0
Example Importance: 0.8396381139755249


Unnamed: 0,(t_max) - 14,(t_max) - 13,(t_max) - 12,(t_max) - 11,(t_max) - 10,(t_max) - 9,(t_max) - 8,(t_max) - 7,(t_max) - 6,(t_max) - 5,(t_max) - 4,(t_max) - 3,(t_max) - 2,(t_max) - 1,(t_max)
Exact age at diagnosis,38.214377,38.214377,38.214377,38.214377,38.214377,38.214377,38.214377,38.214377,38.214377,38.214377,38.214377,38.214377,38.214377,38.214377,38.214377
Number of negative biopsies before diagnosis,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Number of MRI-visible lesions,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Days since diagnosis.3,4071.403757,4071.403757,4071.403757,4071.403757,4071.403757,4071.403757,4071.403757,4071.403757,4071.403757,4071.403757,4071.403757,4071.403757,4071.403757,4071.403757,4071.403757
Days Since Diagnosis,2778.000024,2878.999933,2948.000082,3058.000003,3148.99992,3257.999859,3358.999949,3458.999877,3565.000053,3577.99999,3634.000022,3822.999904,3947.000104,3998.000049,4074.999867
Repeat PSA,0.90095,0.870919,0.910961,0.730771,0.870919,0.810855,0.750792,0.780824,0.870919,0.870919,1.091151,1.021077,0.991045,0.991045,0.940993
Repeat Biopsy Core Total,16.470589,16.470589,16.470589,16.470589,16.470589,16.470589,16.470589,16.470589,16.470589,16.470589,16.470589,16.470589,16.470589,16.470589,16.470589
Repeat Biopsy Core Positive,6.0,6.0,6.0,6.0,6.0,6.0,6.0,6.0,6.0,6.0,6.0,6.0,6.0,6.0,6.0
Repeat MRI Volume,56.882874,56.882874,56.882874,56.882874,56.882874,56.882874,56.882874,56.882874,56.882874,56.882874,56.882874,56.882874,56.882874,56.882874,56.882874
Repeat MRI PSAd,0.136743,0.136743,0.136743,0.136743,0.136743,0.136743,0.136743,0.136743,0.136743,0.136743,0.136743,0.136743,0.136743,0.136743,0.136743


### Save the fitted SimplEx explainer to file

In [None]:
# Save SimplEx model
# NOTE: the file is a pickle; only load it back from a trusted location.
explainer_path = f"../demonstrator/resources/trained_models/RNN/Time series Prostate Cancer/simplex.pkl"

with open(explainer_path, "wb") as explainer_file:
    print(f"Saving SimplEx decomposition in {explainer_path}.")
    pkl.dump(simplex, explainer_file)
