# Interactive Slider for Grad-CAM Heatmaps

Plots an interactive slider of the Grad-CAM heatmaps from all models for a given patient.

### Import Libraries and Modules

In [None]:
%matplotlib inline

import os
import h5py
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

from sklearn import metrics
from sklearn.metrics import confusion_matrix, roc_curve, auc

import tensorflow as tf
from tensorflow import keras

print(f"TF  Version {tf.__version__}")  # report the TensorFlow version in use

In [None]:
# Check the working directory and switch to OUTPUT_DIR before importing the
# project-local helper modules in this cell (presumably they are found via the
# current directory on sys.path -- confirm for non-Jupyter runs).
print(os.getcwd())
INPUT_DIR = "/tf/notebooks/schnemau/xAI_stroke_3d/"
OUTPUT_DIR = "/tf/notebooks/bule/explainable_AI/"
# os.getcwd() does not return a trailing slash, so a plain string comparison
# against OUTPUT_DIR (which ends in "/") would always differ; compare the
# normalized paths instead.
if os.path.normpath(os.getcwd()) != os.path.normpath(OUTPUT_DIR):
    os.chdir(OUTPUT_DIR)
    
import functions_model_definition as md
import functions_read_data as rdat
import functions_slider as sl

## Set Up Paths and Load Data

In [None]:
## ToDo:
## - hm_type should always be "gc" in this notebook
## - pred_hm_only, norm_hm and hm_mode should only be adjusted in the last chunk
## - pic_save_name is not needed in this notebook: implement a dictionary for paths

# --- Model selection --------------------------------------------------------
# Version string passed to rdat.dir_setup / rdat.version_setup / md.model_setup.
# Options listed by the author:
#   10Fold_sigmoid_V0, 10Fold_sigmoid_V1, 10Fold_sigmoid_V2, 10Fold_sigmoid_V2f,
#   10Fold_sigmoid_V3, 10Fold_softmax_V0, 10Fold_softmax_V1, andrea,
#   10Fold_CIB, 10Fold_CIBLSX
# NOTE(review): "CIB" itself is not in the list above -- confirm it is accepted.
version = "CIB"
model_version = 2

# --- Heatmap settings (passed to rdat.dir_setup below) ----------------------
hm_mode = "wgt"       # weighting mode
hm_type = "gc"        # Grad-CAM
norm_hm = False       # gradcam is normalized over all heatmaps
pred_hm_only = True

# Naming convention: True selects the old one
# (per the author: should be False for CIBLSX with model_version >= 3).
comp_mode = False

# --- Resolve directories ----------------------------------------------------
(DATA_DIR, WEIGHT_DIR, DATA_OUTPUT_DIR,
 PIC_OUTPUT_DIR, pic_save_name) = rdat.dir_setup(
    INPUT_DIR,
    OUTPUT_DIR,
    version,
    model_version,
    weight_mode=hm_mode,
    hm_type=hm_type,
    pred_hm=pred_hm_only,
    hm_norm=norm_hm,
    compatibility_mode=comp_mode,
)

In [None]:
# Load the input images, patient ids and the tables returned by
# rdat.version_setup for the configured version.
X_in, pat_ids, id_tab, all_results_tab, pat_orig_tab, pat_norm_tab, num_models = (
    rdat.version_setup(
        DATA_DIR=DATA_DIR,
        version=version,
        model_version=model_version,
        compatibility_mode=comp_mode,
    )
)

## Define Model

In [None]:
# Query md.model_setup for this version's settings, then build the 3D model.
input_dim_img, output_dim, LOSS, layer_connection, last_activation = md.model_setup(version)

model_3d = md.model_init(
    version=version,
    output_dim=output_dim,
    LOSS=LOSS,
    layer_connection=layer_connection,
    last_activation=last_activation,
    C=2,
    learning_rate=5*1e-5,
    batch_size=6,
    input_dim=input_dim_img,
    # Only "LSX" versions get a tabular input; its width is the number of
    # normalized patient-table columns without the "p_id" column.
    input_dim_tab=(
        None
        if "LSX" not in version
        else pat_norm_tab.drop(columns=["p_id"]).shape[1]
    ),
)

In [None]:
# Obtain the model-name generator used to locate weights in WEIGHT_DIR
# (presumably a callable -- it is passed on as generate_model_name below).
# NOTE(review): the keyword is spelled "compatability_mode" here, unlike the
# "compatibility_mode" used for rdat; kept as-is since it must match
# md.set_generate_model_name's parameter name -- confirm.
generate_model_name = md.set_generate_model_name(
    model_version=model_version,
    layer_connection=layer_connection,
    last_activation=last_activation,
    path=WEIGHT_DIR,
    compatability_mode=comp_mode,
)

## Grad-CAM Slider

In [None]:
# Interactive Grad-CAM slider for one patient across all models.
# The heatmap options reuse the configuration variables (norm_hm,
# pred_hm_only) instead of hard-coded copies, so the plot cannot silently
# drift from the settings used to build the paths in rdat.dir_setup.
sl.gradcam_interactive_plot(
    297,  # patient id
    vis_layers=md.get_last_conv_layer(model_3d),
    cnn=model_3d,
    all_results=all_results_tab,
    pat=pat_ids,
    X_in=X_in,
    generate_model_name=generate_model_name,
    num_models=num_models,
    pat_dat=pat_orig_tab,
    # "mean" or "weighted"; NOTE(review): presumably should correspond to
    # hm_mode ("wgt") from the config cell -- confirm.
    model_mode="weighted",
    normalize_hm=norm_hm,
    pat_norm_table=pat_norm_tab,
    pred_hm_only=pred_hm_only,  # if False, the negative heatmap is shown too
)
