# Interactive Slider for Grad-CAM Heatmaps

Plots an interactive slider of the Grad-CAM heatmaps of all models for a given patient.

### Import Libraries and Modules

In [None]:
%matplotlib inline

import os

import h5py
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn import metrics
from sklearn.metrics import confusion_matrix, roc_curve, auc

import tensorflow as tf
from tensorflow import keras

# Report the TensorFlow version in use (helps when reproducing results)
print(f"TF  Version {tf.__version__}")

In [None]:
# Check and set the working directory before loading the local helper modules.
# The imports below are resolved relative to the working directory (presumably
# these modules live in OUTPUT_DIR — TODO confirm), so chdir must happen first.
print(os.getcwd())
INPUT_DIR = "/tf/notebooks/schnemau/xAI_stroke_3d/"
OUTPUT_DIR = "/tf/notebooks/bule/explainable_AI/"
# os.getcwd() never carries a trailing slash while OUTPUT_DIR does, so a raw
# string comparison is always unequal; compare normalised paths instead.
if os.path.normpath(os.getcwd()) != os.path.normpath(OUTPUT_DIR):
    os.chdir(OUTPUT_DIR)

import functions_model_definition as md
import functions_read_data as rdat
import functions_slider as sl

## Load Data and Set Up Model

In [None]:
# Define the input data path (trailing separator kept, as downstream code
# may concatenate file names onto it)
DATA_DIR = os.path.join(INPUT_DIR, "data", "")

version = "10Fold_CIB" # one of:
# 10Fold_sigmoid_V0, 10Fold_sigmoid_V1, 10Fold_sigmoid_V2, 10Fold_sigmoid_V2f, 10Fold_sigmoid_V3
# 10Fold_softmax_V0, 10Fold_softmax_V1, andrea
# 10Fold_CIB, 10Fold_CIBLSX
save_plot = False

# Define Model Version
model_version = 1

# Heatmap weighting mode, passed to dir_setup as weight_mode
hm_mode = "wgt"

# Heatmap type ("gc" presumably Grad-CAM — TODO confirm)
hm_type = "gc"
# In the cells visible here this only selects the output-file suffix below
# ("_predcl" vs "_bothcl").
# NOTE(review): the slider call at the end of the notebook passes its own
# pred_hm_only value; keep the two in sync.
pred_hm_only = False

# Define output paths
WEIGHT_DIR, DATA_OUTPUT_DIR, PIC_OUTPUT_DIR, pic_save_name = rdat.dir_setup(
    OUTPUT_DIR, version, model_version, weight_mode = hm_mode,
    hm_type = hm_type, ending = "_predcl" if pred_hm_only else "_bothcl")

# Override the WEIGHT_DIR from dir_setup: the weights for this version are
# read from INPUT_DIR. os.path.join avoids the "//" that
# INPUT_DIR + "/weights/" produced, and the final "" keeps the trailing slash.
WEIGHT_DIR = os.path.join(INPUT_DIR, "weights", version, "")

In [None]:
## load images and ids
(X_in, pat, id_tab, all_results, num_models) = rdat.version_setup(
    DATA_DIR = DATA_DIR, version = version, model_version = model_version)

## load patient data
# Derived from DATA_DIR instead of repeating the absolute path (resolves to the
# same file). Despite the "DIR" suffix, this is the path of the CSV file itself.
PAT_CSV_DIR = os.path.join(DATA_DIR, "baseline_data_zurich_prepared0.csv")
pat_dat = pd.read_csv(PAT_CSV_DIR, sep=";")
assert not pat_dat.empty, f"no rows read from {PAT_CSV_DIR}"

# Patient-level data: display only its size so that individual rows are not
# rendered into shared or committed notebook outputs (use pat_dat.head() locally).
pat_dat.shape

## Define Model

In [None]:
# Versions whose name contains "LSX" also take tabular input. Its width is read
# from the tabular training split. The third argument (1) is presumably the
# split/fold index — TODO confirm. All other versions get no tabular input.
if "LSX" in version:
    (
        (X_train, X_valid, X_test),
        (X_tab_train, X_tab_valid, X_tab_test),
        (y_train, y_valid, y_test),
    ) = rdat.split_data_tabular(id_tab, X_in, 1)
    input_dim_tab = X_tab_train.shape[1]
else:
    input_dim_tab = None

# Architecture settings for the selected version
input_dim, output_dim, LOSS, layer_connection, last_activation = md.model_setup(version)

# Build the model used for the heatmaps. Architecture arguments come from
# model_setup above; the remaining arguments are fixed here.
model_3d = md.model_init(
    version=version,
    input_dim=input_dim,
    input_dim_tab=input_dim_tab,
    output_dim=output_dim,
    last_activation=last_activation,
    layer_connection=layer_connection,
    LOSS=LOSS,
    C=2,
    learning_rate=5 * 1e-5,
    batch_size=6,
)

# Flag for verbose printing (not read in the cells visible here)
check_print = True


In [None]:
# Model-name helper bound to the weight directory and architecture settings;
# presumably used to locate the per-model weight files — TODO confirm.
generate_model_name = md.set_generate_model_name(
    path=WEIGHT_DIR,
    model_version=model_version,
    layer_connection=layer_connection,
    last_activation=last_activation,
)

In [None]:
# Strip plot of the weighted prediction standard deviation stored in
# all_results (presumably the spread across the ensemble models — TODO confirm).
# seaborn is imported in the top import cell.
ax_std = sns.stripplot(data=all_results.y_pred_std_w)
ax_std.set(title="Weighted prediction std per patient", ylabel="y_pred_std_w");


## Grad-CAM Slider

In [None]:
# Patient whose heatmaps are shown (passed as the first, positional argument
# of gradcam_interactive_plot; presumably an id present in `pat` — TODO confirm)
PATIENT_ID = 471

# Candidate layers: skip the first layer and the last six (presumably the input
# layer and the non-convolutional head — TODO confirm), then keep only layers
# whose Keras name starts with "conv".
vis_layers = [
    layer.name
    for layer in model_3d.layers[1:-6]
    if layer.name.startswith("conv")
]
# Fail with a clear message instead of an IndexError on vis_layers[-1]
assert vis_layers, "no layer named 'conv*' found in model_3d.layers[1:-6]"

sl.gradcam_interactive_plot(
    PATIENT_ID,
    vis_layers=vis_layers[-1],  # only the last convolutional layer
    cnn=model_3d, all_results=all_results, pat=pat, X_in=X_in,
    generate_model_name=generate_model_name, num_models=num_models,
    pat_dat=pat_dat,
    # If False, the negative heatmap is also shown.
    # NOTE(review): hard-coded True, while the config cell sets
    # pred_hm_only = False; confirm which is intended.
    pred_hm_only=True)
