# Heatmaps
First prototypes <br>
date: Nov 2, 2021

In [None]:
%config Completer.use_jedi = False
%matplotlib inline

In [None]:
import numpy as np
import pandas as pd
import cv2
import copy
import time
import ast
import json
import tensorflow as tf
from sklearn.metrics import confusion_matrix
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.figure_factory as ff

In [None]:
import sys
import os

# Add the sibling `data_utils` and `models` directories (resolved against the
# current working directory) to the module search path.
sys.path.append(os.path.abspath('../data_utils'))
sys.path.append(os.path.abspath('../models'))

In [None]:
from model_zoo.utils import check_gpu
from model_zoo.losses.dice import DiceLoss, DiceCoefficient
from data_utils.DataContainer import DataContainer
from data_utils.TestSet import TestSet

In [None]:
import logging
# Configure the root logger so that records of level INFO and above are emitted.
logging.basicConfig(level=logging.INFO)

# TestSet

In [None]:
# Model-name groups used to expand the "Baseline" and "DA" aliases further down.
MODELS_SIMPLE1 = ["XNet_T2_relu", "XNet_T2_leaky", "XNet_T2_selu"]
MODELS_SIMPLE2 = ["XNet_T1_relu", "XNet_T1_leaky", "XNet_T1_selu"]
# T1 models are listed before the T2 models.
MODELS_SIMPLE = MODELS_SIMPLE2 + MODELS_SIMPLE1
MODELS_CG = ["CG_XNet_T1_relu", "CG_XNet_T2_relu"]
MODELS_DA = ["SegmS2T_GAN1_relu", "SegmS2T_GAN2_relu", "SegmS2T_GAN5_relu"]
MODELS_GAN = ["GAN_1+XNet_T1_relu", "GAN_2+XNet_T1_relu", "GAN_5+XNet_T1_relu"]
# Full list: simple, CG, GAN, then the three SegmS2T models.
MODELS = MODELS_SIMPLE + MODELS_CG + MODELS_GAN + MODELS_DA
MODELS_BASELINE = MODELS_SIMPLE + MODELS_CG
# From here on MODELS_DA also contains the GAN pipelines (SegmS2T first).
MODELS_DA = MODELS_DA + MODELS_GAN

In [None]:
# Create the TestSet for the processed test directory.
# NOTE(review): `load=True` presumably loads previously stored results instead of recomputing them — confirm in TestSet.
testset = TestSet("/tf/workdir/data/VS_segm/VS_registered/test_processed/", load=True)

In [None]:
# Round-trip the summary table through JSON: serialise it with to_json() and parse
# the string back with read_json(). `intermediate` is read again by a later cell.
# NOTE(review): newer pandas releases deprecate passing a literal JSON string to
# read_json (wrap it in io.StringIO) — check against the installed version.
df_total = testset.df_total
intermediate = df_total.to_json()
df_total = pd.read_json(intermediate)

In [None]:
# Display the summary table after the JSON round-trip.
df_total

# First overview heatmap

In [None]:
# Parameters of the overview heatmap.
# slice_type: suffix of the metric column that is read (e.g. "dice_only_tumor")
# metric:     key into the column and colour lookups below
slice_type, metric = "only_tumor", "DSC"
# Lower and upper bound of the colour range.
slider_values = [0, 1.0]
# Patient whose evaluation file is read in the detail section.
patient_id = 200

In [None]:
# First three model names known to the test set plus the "Baseline" group alias.
# The slice copies the data, so testset.all_models itself is not modified
# (assumes all_models is a list — TODO confirm).
models = [*testset.all_models[:3], "Baseline"]
models

In [None]:
# One colour list per metric: ASSD uses the reversed Plasma palette, all others the plain one.
base_colorscale = {
    name: [*(px.colors.sequential.Plasma_r if name == "ASSD" else px.colors.sequential.Plasma)]
    for name in ("DSC", "ASSD", "ACC", "TPR", "TNR")
}
lookup_color = base_colorscale[metric]

# Spread the palette over the slider range in nine equal steps.
steps = (slider_values[1]-slider_values[0])/9
# Pad the scale down to 0 with the first colour when the range starts above 0.
colorscale = [[0, lookup_color[0]]] if slider_values[0] != 0 else []
colorscale += [
    [value/1, lookup_color[position]]
    for position, value in enumerate(np.arange(slider_values[0], slider_values[1], steps))
]
colorscale.append([slider_values[1]/1, lookup_color[-1]])
# Pad the scale up to 1 with the last colour when the range ends elsewhere.
if slider_values[1] != 1:
    colorscale.append([1, lookup_color[-1]])

In [None]:
# Rebuild the summary frame from its JSON string.
df_total = pd.read_json(intermediate)
# Column of df_total that holds each metric for the chosen slice type.
lookup = dict(
    DSC=f"dice_{slice_type}",
    ASSD=f"assd_{slice_type}",
    ACC=f"acc_{slice_type}",
    TPR=f"tpr_{slice_type}",
    TNR=f"tnr_{slice_type}",
)
# The first row's cell of that column is turned into its own DataFrame.
# NOTE(review): presumably that cell holds nested per-patient records — confirm against TestSet.
df_metric = pd.DataFrame(df_total.iloc[0][lookup[metric]])
df_metric.head()

In [None]:
# Restrict df_metric to the requested model columns unless "All" was requested.
if "All" not in models:
    models_selected = ["id"] + models
    print(models_selected)
    # Replace each group alias by the model names it stands for.
    for group_name, group_members in (("Baseline", MODELS_BASELINE), ("DA", MODELS_DA)):
        if group_name in models:
            models_selected.extend(group_members)
            models_selected.remove(group_name)
            print(models_selected)
    # Drop duplicates while keeping the first occurrence of every name.
    models_selected = list(dict.fromkeys(models_selected))
    print(models_selected)
    df_metric = df_metric[models_selected]
# Model columns only, i.e. everything after the leading "id" column.
models_selected = list(df_metric.columns[1:])

In [None]:
# Patients to highlight in the overview heatmap.
selected_ids = ['204','205','206']
# Row positions are taken from enumerate() rather than from the index labels, because
# they are fed to .iloc, which is strictly positional.
row_ids = df_metric["id"].tolist()
# Positions of the selected patients plus the last row of the table.
idx_selected2 = [pos for pos, row_id in enumerate(row_ids) if row_id in selected_ids] + [len(df_metric)-1]
# Positions of all other rows; the final entry is dropped so the last row is never blanked.
idx_selected = [pos for pos, row_id in enumerate(row_ids) if row_id not in selected_ids][:-1]
# Blank the metric values (every column after "id") of the non-selected rows.
# np.nan is used because the np.NaN alias no longer exists in NumPy 2.0.
df_metric.iloc[idx_selected, 1:] = np.nan

In [None]:
# Preview the table after the non-selected rows were blanked.
df_metric.head()

In [None]:
# Heatmap inputs: model names on x, ids on y (last row excluded), one value row per id.
x = list(df_metric.columns[1:])
y = list(df_metric["id"].values[:-1])
z = [list(df_metric.iloc[row_pos][1:].values) for row_pos in range(len(df_metric)-1)]

# One hover label per cell, laid out exactly like z.
hovertext = [
    [
        'Model: {}<br />ID: {}<br />{}: {}'.format(
            model_name, row_id, metric, np.round(z[row_pos][col_pos], decimals=5))
        for col_pos, model_name in enumerate(x)
    ]
    for row_pos, row_id in enumerate(y)
]

In [None]:
# Inverse mask for the grey overlay: cells blanked in df_metric become 1, then the
# rows listed in idx_selected2 are blanked here instead.
# fillna() returns a new frame, so df_metric itself is left untouched.
df_metric2 = df_metric.fillna(value=1)
# np.nan is used because the np.NaN alias no longer exists in NumPy 2.0.
df_metric2.iloc[idx_selected2, 1:] = np.nan
# First two entries of the Greys palette, used as colour scale of the overlay.
colorscale_nan = px.colors.colorbrewer.Greys[0:2]
x2 = list(df_metric2.columns[1:])
y2 = list(df_metric2["id"].values[:-1])
z2 = [list(df_metric2.iloc[idx][1:].values) for idx in range(len(df_metric2)-1)]

In [None]:
# Two stacked panels with a shared x axis: a thin top row (10 % of the height)
# and a tall bottom row (90 %).
fig = make_subplots(rows=2, cols=1, 
                    row_heights=[0.1, 0.9], vertical_spacing=0.05, shared_xaxes=True )

# Annotated one-row heatmap built from the last row of df_metric, labelled "mean";
# the annotation texts are the same values rounded to 3 decimals.
trace = ff.create_annotated_heatmap(x=list(df_metric.columns)[1:],
                                    y=["mean"],
                                    z=[list(df_metric.iloc[-1][1:].values)],
                                    hoverinfo='skip',
                                    coloraxis="coloraxis",
                                    annotation_text=[
                                        [np.round(x, 3) for x in list(df_metric.iloc[-1][1:].values)]])
# Put the heatmap trace into the top panel and merge the helper figure's layout
# into ours (presumably this is what carries the cell annotations over — confirm).
fig.add_trace(trace.data[0],
              1,1)
fig.layout.update(trace.layout)
# Bottom panel: the x/y/z values prepared earlier with custom hover text, drawn on
# the shared colour axis.
fig.add_trace(go.Heatmap(
                x=x,
                y=y,
                z=z,
                hoverongaps=False,
                hoverinfo='text',
                text=hovertext,
                coloraxis = "coloraxis"), 2,1)
# Second bottom-panel heatmap from x2/y2/z2 with the two-colour grey scale,
# without hover information and without its own colour bar.
fig.add_trace(go.Heatmap(
                x=x2,
                y=y2,
                z=z2,
                hoverongaps=False,
                hoverinfo='skip',
                colorscale=colorscale_nan,
showscale=False),2,1)

# Model names are shown once, on top of the upper panel; the lower x axis hides its ticks.
fig.update_layout(xaxis2 = {'showticklabels': False},
                 xaxis1 = {'side': 'top', 'showticklabels': True},
                 yaxis2 = {'title': 'Patient ID'})
# Shared colour axis: custom colour scale, colour bar titled with the metric and
# ticks every 0.1 from 0 up to (not including) 1.
fig.update_layout(coloraxis = {'colorscale': colorscale, 
                               'colorbar': dict(title=metric, tickvals=np.arange(0,1,0.1), tickmode="array")})
# Tight margins except for the top, which leaves room for the title and the x tick labels.
fig.update_layout(margin=dict(l=5,
                                r=5,
                                b=5,
                                t=150,
                                pad=4),
                  title="Something")                  
fig.show()

# First detail heatmap

In [None]:
# Parameters of the detail heatmap: slice filter, metric key and requested model groups.
slice_type, metric = "only_tumor", "ASSD"
models = ["DA"]

In [None]:
def select_model_detail_list(models):
    """Translate a model selection into the column list of the detail table.

    Args:
        models: list of model names; may contain the group aliases
            "Baseline" and "DA", or "All".

    Returns:
        None when "All" is part of the selection (no column filtering).
        Otherwise a list starting with "slice" and "VS_class_gt", followed by
        the requested names with "Baseline" / "DA" expanded to MODELS_BASELINE /
        MODELS_DA, duplicates removed and first occurrences kept in order.
    """
    if "All" in models:
        return None

    # Concatenation builds a new list, so the caller's list is never modified.
    columns = ["slice", "VS_class_gt"] + models
    if "Baseline" in models:
        columns.extend(MODELS_BASELINE)
        columns.remove("Baseline")
    if "DA" in models:
        columns.extend(MODELS_DA)
        columns.remove("DA")
    # dict keys are unique and keep insertion order: order-preserving de-duplication.
    return list(dict.fromkeys(columns))

In [None]:
def get_colorscale_tickvals(metric, slider_values, slider_max):
    """Build a colour scale and colour-bar tick values for a metric.

    Args:
        metric: metric key; "ASSD" selects the reversed Plasma palette and
            ticks every 20 units, any other value the plain palette and
            ticks every 0.1 between 0 and 1.
        slider_values: sequence whose first two items are the lower and upper
            bound of the value range the palette is spread over.
        slider_max: maximum metric value; stops are divided by it to obtain
            the normalised positions of the colour scale.

    Returns:
        Tuple ``(colorscale, tickvals)``: ``colorscale`` is a list of
        ``[position, colour]`` pairs, ``tickvals`` a numpy array.
    """
    palette = list(px.colors.sequential.Plasma)
    # A real list is required here: reversed() alone yields an iterator, which cannot
    # be indexed and made every ASSD call fail with a TypeError.
    lookup_color = palette[::-1] if metric == "ASSD" else palette

    lower, upper = slider_values[0], slider_values[1]
    steps = (upper - lower) / 9
    colorscale = []
    # Pad down to position 0 with the first colour when the range starts above 0.
    if lower != 0:
        colorscale.append([0, lookup_color[0]])
    # A zero-width range has no intermediate stops (np.arange rejects a zero step).
    if steps > 0:
        # zip() bounds the loop by the palette length, so a rounding-induced extra
        # element of np.arange can never index past the last colour.
        for color, value in zip(lookup_color, np.arange(lower, upper, steps)):
            colorscale.append([value / slider_max, color])
    colorscale.append([upper / slider_max, lookup_color[-1]])
    # Positions are normalised by slider_max, so the scale is complete only when the
    # upper bound equals slider_max; otherwise pad up to position 1 with the last colour.
    if upper != slider_max:
        colorscale.append([1, lookup_color[-1]])
    tickvals = np.arange(0, slider_max, 20) if metric == "ASSD" else np.arange(0, 1, 0.1)
    return colorscale, tickvals

In [None]:
# Evaluation results of the selected patient.
df = pd.read_json(f"/tf/workdir/data/VS_segm/VS_registered/test_processed/vs_gk_{patient_id}/evaluation.json")

if metric in ["DSC", "ASSD"]:
    # Segmentation metrics: one column per model, identified by "dice" / "assd" in the name.
    lookup = {"DSC": "dice", "ASSD": "assd"}
    cols = [c for c in df.columns if lookup[metric] in c]
    # Keep only the part after the last "-" as column name.
    # rename()/drop() return new frames instead of mutating a slice of df in place.
    df_metric = df[["slice", "VS_class_gt"] + cols].rename(columns={k: k.split("-")[-1] for k in cols})
    models_selected = select_model_detail_list(models)
    if models_selected is not None:
        df_metric = df_metric[models_selected]
    if slice_type == "only_tumor":
        df_metric = df_metric[df_metric["VS_class_gt"] == 1]
    df_metric = df_metric.drop(columns=["VS_class_gt"])
    # Summary row: column means of every model column ("slice" is the first column and
    # is skipped); its "slice" entry carries the metric name.
    summary_row = {"slice": metric, **dict(df_metric.mean().iloc[1:])}
    # pd.concat replaces DataFrame.append, which no longer exists in pandas 2.0.
    df_metric = pd.concat([df_metric, pd.DataFrame([summary_row])], ignore_index=True)

elif metric in ["ACC", "TPR", "TNR"]:
    # Classification metrics: computed from the per-slice class predictions of each model.
    cols = [c for c in df.columns if "class_pred-" in c]
    df_metric = df[["slice", "VS_class_gt"] + cols].rename(columns={k: k.split("-")[-1] for k in cols})
    models_selected = select_model_detail_list(models)
    if models_selected is not None:
        df_metric = df_metric[models_selected]
    if slice_type == "only_tumor":
        df_metric = df_metric[df_metric["VS_class_gt"] == 1]
    lookup = {"ACC": TestSet().calculate_accuracy,
              "TPR": TestSet().calculate_tpr,
              "TNR": TestSet().calculate_tnr}
    # Use the function that belongs to the requested metric; previously accuracy was
    # computed for TPR and TNR as well.
    score = lookup[metric]
    model_cols = list(df_metric.columns)[2:]
    ground_truth = df_metric["VS_class_gt"].values
    # The flattened 2x2 confusion matrix (labels 0 and 1) of each model is passed to the scorer.
    values = [score(confusion_matrix(ground_truth, predictions.values, labels=[0, 1]).ravel())
              for _, predictions in df_metric[model_cols].items()]
    df_metric = df_metric.drop(columns=["VS_class_gt"])
    summary_row = {"slice": metric, **dict(zip(model_cols, values))}
    df_metric = pd.concat([df_metric, pd.DataFrame([summary_row])], ignore_index=True)


In [None]:
# Display the per-slice table built above.
df_metric

In [None]:
# Round-trip the detail table through its JSON representation.
info = df_metric.to_json()

df_metric = pd.read_json(info)
# The "slice" entry of the last row holds the metric name.
metric = df_metric["slice"].iloc[-1]
df_metric

In [None]:
# Slices to keep in the detail heatmap (compared as strings).
selected_ids = ['30','31','34', '35']
# Row positions come from enumerate() rather than from the index labels, because
# they are fed to .iloc, which is strictly positional.
slice_labels = df_metric["slice"].tolist()
# Positions of the selected slices plus the last row of the table.
idx_selected2 = [pos for pos, label in enumerate(slice_labels) if str(label) in selected_ids] + [len(df_metric)-1]
df_metric = df_metric.iloc[idx_selected2, :]

In [None]:
# Row positions kept by the selection above.
idx_selected2

In [None]:
# Table reduced to the selected slices plus its last row.
df_metric

In [None]:
# Colour scale and colour-bar ticks of the detail heatmap.
plasma = [*px.colors.sequential.Plasma]
# ASSD uses the palette in reverse order.
lookup_color = plasma[::-1] if metric == "ASSD" else plasma
steps = (slider_values[1] - slider_values[0]) / 9
# Pad down to 0 with the first colour when the range starts above 0.
colorscale = [[0, lookup_color[0]]] if slider_values[0] != 0 else []
colorscale += [
    [value / 1, lookup_color[position]]
    for position, value in enumerate(np.arange(slider_values[0], slider_values[1], steps))
]
colorscale.append([slider_values[1] / 1, lookup_color[-1]])
# Pad up to 1 with the last colour when the range ends elsewhere.
if slider_values[1] != 1:
    colorscale.append([1, lookup_color[-1]])
# ASSD: ticks every 20 units below 362; other metrics: ticks every 0.1 below 1.
if metric == "ASSD":
    tickvals = np.arange(0, 362, 20)
else:
    tickvals = np.arange(0, 1, 0.1)

In [None]:
# create figure
fig = make_subplots(rows=2, cols=1,
                    row_heights=[0.1, 0.9], vertical_spacing=0.05, shared_xaxes=True)
# create annotated heatmap with total values
round_dec = 2 if len(df_metric.columns) >= 8 else 3
trace = ff.create_annotated_heatmap(x=list(df_metric.columns)[1:],
                                    y=["mean"],
                                    z=[list(df_metric.iloc[-1][1:].values)],
                                    hoverinfo='skip',
                                    coloraxis="coloraxis",
                                    annotation_text=[
                                        [np.round(x, round_dec) for x in list(df_metric.iloc[-1][1:].values)]])
fig.add_trace(trace.data[0],
              1, 1)
fig.layout.update(trace.layout)

# prepare x,y,z for heatmap
x = list(df_metric.columns)[1:]
y = [str(x) for x in list(df_metric["slice"].values[:-1])]
z = [list(df_metric.iloc[idx][1:].values) for idx in range(len(df_metric) - 1)]
# create hovertext
hovertext = list()
for yi, yy in enumerate(y):
    hovertext.append(list())
    for xi, xx in enumerate(x):
        hovertext[-1].append(
            'Model: {}<br />Slice: {}<br />{}: {}'.format(xx, yy, metric, np.round(z[yi][xi], decimals=5)))

In [None]:
# heatmap for patient data
fig.add_trace(go.Heatmap(x=x,
                         y=y,
                         z=z,
                         hoverongaps=True,
                         hoverinfo='text',
                         text=hovertext,
                         coloraxis="coloraxis"), 2, 1);

# update layout
fig.update_layout(xaxis2={'showticklabels': False},
                  xaxis1={'side': 'top', 'showticklabels': True},
                  yaxis2={'title': 'Patient ID'},
                  coloraxis={'colorscale': colorscale,
                             'colorbar': dict(title=metric, tickvals=tickvals, tickmode="array")},
                  margin=dict(l=5,
                              r=5,
                              b=5,
                              t=5,
                              pad=4)
                  )