In [None]:
import itertools
from pathlib import Path
from typing import Tuple

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

from monai.networks import nets, one_hot
from monai.metrics import compute_hausdorff_distance

import plotly.express as px
from plotly.colors import n_colors
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.io as pio

pio.templates.default = "simple_white"

import skimage as skm
import skimage.measure  # make `skm.measure` available on scikit-image versions without lazy submodule loading
from scipy import interpolate

import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import Dataset, random_split, DataLoader
from torchvision import transforms

import torchio as tio

from kedro.extras.datasets.pandas import CSVDataSet
from kedro.extras.datasets.pickle import PickleDataSet

In [None]:
import os, sys
sys.path.append(os.path.abspath('../src'))

from tagseg.data import ScdEvaluator, MnmEvaluator
from tagseg.data.dmd_dataset import DmdH5DataSet, DmdH5Evaluator
from tagseg.models.trainer import Trainer
from tagseg.models.segmenter import Net
from tagseg.metrics.dice import DiceMetric
from tagseg.pipelines.model_evaluation.nodes import tag_subjects
from tagseg.data.dmd_dataset import DmdDataSet

In [None]:
# Shared plotly legend layout: horizontal legend anchored just above the plotting area.
top_h_legend = {'orientation': 'h', 'yanchor': 'bottom', 'y': 1.1}

In [None]:
# One row per trained model; the loader below reads its `model` directory name and `gamma`.
index = pd.read_csv(Path('../data/07_model_output') / 'index_gamma.csv')
index

In [None]:
# Load the per-slice evaluation results of every model for both splits and tag each frame
# with the model's gamma and the split it belongs to.
# Train-split results are stored as `dmd_results_train.pt`, test-split as `dmd_results.pt`.
suffix_for_split = {'train': '_train', 'test': ''}

dfs = []
for _, model_row in index.iterrows():
    for split_name, suffix in suffix_for_split.items():
        results_path = f'../data/07_model_output/{model_row.model}/dmd_results{suffix}.pt'
        frame = pd.DataFrame(list(PickleDataSet(filepath=results_path).load()))
        frame['gamma'] = model_row.gamma
        frame['split'] = split_name
        dfs.append(frame)

In [None]:
# Stack all (model, split) result frames; each frame keeps its own row labels.
df = pd.concat(dfs)
df.shape[0]

In [None]:
# Summary table: mean / median / std of Dice and HD95 per gamma, for train and test.
# Columns end up as a (metric, split, statistic) MultiIndex with test listed before train.
# Parenthesised chain instead of a trailing backslash: the old dangling `\` only worked
# because of the blank line after it and would silently swallow the next line otherwise.
results = (
    df.pivot_table(
        index=['gamma'], values=['dice', 'hd95'], columns=['split'],
        aggfunc=['mean', 'median', 'std'],
    )
    .sort_index(level=[1, 2], ascending=[True, False], axis=1)
    .reorder_levels([1, 2, 0], axis=1)
)
results

In [None]:
# Test-set means sorted best-first: Dice descending (higher is better),
# HD95 ascending (lower is better). Used to rank the gammas in the next cell.
dices = results[('dice', 'test', 'mean')].sort_values(ascending=False).to_numpy()
hds = results[('hd95', 'test', 'mean')].sort_values(ascending=True).to_numpy()

In [None]:
# 1-based rank of each gamma on the test set (1 = best): the position of its mean in the
# best-first arrays; near-equal values resolve to the first matching position (np.isclose).
dice_means = results[('dice', 'test', 'mean')]
hd_means = results[('hd95', 'test', 'mean')]

results['dice_rank'] = dice_means.apply(lambda v: np.flatnonzero(np.isclose(dices, v))[0]) + 1
results['hd_rank'] = hd_means.apply(lambda v: np.flatnonzero(np.isclose(hds, v))[0]) + 1

# Combined rank: lower is better overall.
results['rank'] = results['dice_rank'] + results['hd_rank']

In [None]:
# Summary table including the per-metric test ranks (1 = best) and their sum.
results

In [None]:
# LaTeX version of the summary table for the report.
# The tabular column spec is derived from the frame (one 'l' per index level, one 'r' per
# data column). A fixed spec no longer matches once columns such as the ranks are added,
# and LaTeX then fails with "Extra alignment tab".
latex_column_format = 'l' * results.index.nlevels + 'r' * results.shape[1]
print(results.to_latex(
    float_format="%.3f", bold_rows=True, column_format=latex_column_format,
    multicolumn_format='c', multirow=True,
    caption='Dice and HD95 per shape distance loss weight (gamma) on the train and test splits',
))

In [None]:
# Reshape the wide summary table to one row per (gamma, metric, split) with one column per statistic.
long_form = results.reset_index().melt(id_vars=[('gamma',      '',     '')])
long_form.columns = ['gamma', 'metric', 'split', 'statistic', 'value']
res = long_form.pivot(index=['gamma', 'metric', 'split'], columns=['statistic']).reset_index()

In [None]:
# Flatten the pivoted column MultiIndex. The unnamed ('' statistic) column only carries the
# rank values (dice_rank / hd_rank / rank rows), which have no mean/median/std. Keep only the
# real metrics so the faceted plot below does not get empty panels for the rank rows.
# Re-running this cell is safe: the column count stays the same and the filter is idempotent.
res.columns = ['gamma', 'metric', 'split', 'rank_value', 'mean', 'median', 'std']
res = res[res['metric'].isin(['dice', 'hd95'])].sort_values(by=['gamma', 'split'], ascending=[True, False])

In [None]:
# Mean ± std of each metric per gamma, one facet per metric, train vs test as colours.
fig = px.scatter(res, x='gamma', y='mean', facet_col='metric', color='split', error_y='std')
fig.update_layout(legend=top_h_legend)

# Dice and HD95 live on different scales: un-match the facet y-axes and re-enable the tick
# labels that plotly express hides on every facet except the first.
fig.update_yaxes(matches=None, showticklabels=True)
# A log axis cannot place gamma = 0 (zero shape-loss weight, the baseline), so those points
# were silently dropped. A categorical axis keeps the even, log-like spacing and shows them.
fig.update_xaxes(type='category')

fig.show()

In [None]:
# Per-slice test Dice for every gamma.
df.loc[df['split'] == 'test', ['gamma', 'dice']]

In [None]:
# From here on only the held-out test results are analysed (train rows are discarded).
df = df.loc[df['split'] == 'test']

In [None]:
# Median HD95 (and the other numeric per-slice fields) for the two gammas compared below.
# numeric_only=True: some selected columns are presumably non-numeric (e.g. `disease`; verify).
# pandas >= 2 raises TypeError for those, whereas older pandas dropped them with a FutureWarning.
compare_columns = ['hd95', 'timeframe', 'disease', 'patient_id', 'slice', 'gamma']
df[df.gamma.isin([0.05, 0.1])][compare_columns].groupby('gamma').median(numeric_only=True)

In [None]:
# Test HD95 per slice for gamma = 0.05, smallest first (index labels identify the slice).
df[df['gamma'] == 0.05].sort_values(by='hd95')['hd95']

In [None]:
# Test HD95 per slice for gamma = 0.1, smallest first (index labels identify the slice).
df[df['gamma'] == 0.1].sort_values(by='hd95')['hd95']

In [None]:
# HD95 of the two test slices visualised in the next cell. Look them up here: relying on
# `subject_1` / `subject_2` from the cell below raises NameError on Restart & Run All.
subject_1 = df[df.gamma == 0.05].iloc[184]
subject_2 = df[df.gamma == 0.1].iloc[101]
subject_1.hd95, subject_2.hd95

In [None]:
# Close-ups of two test slices (gamma = 0.05 vs gamma = 0.1): manual annotation contours
# in blue, model prediction contours in red.
subject_1 = df[df.gamma == 0.05].iloc[184]
subject_2 = df[df.gamma == 0.1].iloc[101]


def _plot_contour_overlay(ax, subject, xlim, ylim):
    """Show `subject`'s image resized to its raw mask shape with mask (blue) and prediction (red) contours."""
    raw_shape = subject.raw_mask.data[0, 0].shape
    postprocess = transforms.Resize(raw_shape, interpolation=transforms.InterpolationMode.NEAREST)
    ax.imshow(postprocess(subject.image.data)[0, 0], cmap='gray')

    mask = subject.raw_mask.data[0, 0].numpy()
    pred = subject.pred.data[0, 0].numpy()
    # Draw every contour of each array on its own. Zipping the two contour lists dropped
    # contours whenever annotation and prediction had a different number of them.
    for contour in skm.measure.find_contours(mask, level=.5):
        ax.plot(*contour[:, ::-1].T, c='b')
    for contour in skm.measure.find_contours(pred, level=.5):
        ax.plot(*contour[:, ::-1].T, c='r')

    ax.set_title(f'gamma = {subject.gamma}, HD95 = {subject.hd95:.2f}')
    ax.set_xlim(*xlim)
    # Inverted y-limits keep the image orientation (row 0 at the top).
    ax.set_ylim(*ylim)


closeup_fig, closeup_axes = plt.subplots(1, 2, figsize=(10, 5))
_plot_contour_overlay(closeup_axes[0], subject_1, xlim=(70, 140), ylim=(150, 75))
_plot_contour_overlay(closeup_axes[1], subject_2, xlim=(40, 100), ylim=(140, 70))
plt.show()

In [None]:
# Horizontal violins of per-slice test Dice (left) and HD95 (right), one violin per gamma,
# with a dotted line at each metric's mean pooled over all gammas.
metrics = ['dice', 'hd95']
gammas = sorted(df.gamma.unique())

accent_color = 'rgb(40, 61, 59)'
colors = n_colors('rgb(25, 114, 120)', accent_color, len(gammas), colortype='rgb')

fig = make_subplots(
    rows=1, cols=len(metrics),
    shared_yaxes=True, shared_xaxes=False,
    horizontal_spacing=0.02, vertical_spacing=0.0,
)

for col, metric in enumerate(metrics, start=1):
    # One array per gamma, kept as a list: per-gamma sample counts need not match, and
    # np.array over ragged arrays raises on recent NumPy versions.
    per_gamma = [df.loc[(df.gamma == g) & (df.split == 'test'), metric].to_numpy() for g in gammas]

    for gamma, values, color in zip(gammas, per_gamma, colors):
        fig.add_trace(go.Violin(name=str(gamma), x=values, line_color=color), row=1, col=col)

    # Mean over all test slices of all gammas.
    pooled_mean = np.concatenate(per_gamma).mean()
    fig.add_vline(
        x=pooled_mean,
        annotation_text=f"    {pooled_mean:.3f}", annotation_position="top right",
        annotation_font_color=accent_color,
        line_width=3, line_dash="dot", line_color=accent_color, row=1, col=col)

fig.update_xaxes(
    title_text='DSC (↑)',
    range=[0.1, 1.1], tickvals=np.arange(0.2, 1.1, 0.1), row=1, col=1)
fig.update_xaxes(
    title_text='HD-95 [mm] (↓)',
    range=[0., 15.], tickvals=np.arange(0., 15., 2.5), row=1, col=2)

fig.update_yaxes(
    title_text='Weight of Shape Distance Loss (Gamma)', row=1, col=1)

fig.update_layout(height=800 / 1.62, width=800, showlegend=False)
fig.update_traces(
    meanline_visible=False,
    box_visible=True, orientation='h', side='positive', width=3, points=False)
fig.show()

In [None]:
# Export the violin figure as PDF (plotly static export needs an image engine such as kaleido).
fig.write_image(str(Path('../../figures') / 'sdl-perf-violin.pdf'))

In [None]:
# Qualitative comparison: for each gamma, the test slices whose Dice is closest to the first
# quartile, the median and the third quartile. Manual annotation in blue, prediction in red.
# Separate names so the sorted `gammas` list and the plotly `fig` above are not clobbered.
qualitative_gammas = [0.0, 0.05, 1.0]
quantile_rows = {'Q1': .25, 'Median': .5, 'Q3': .75}
padding = 50  # half-size (pixels) of the square crop drawn around the structure


def _foreground_centre(binary: np.ndarray):
    """Return the mean (row, col) of the pixels equal to 1, or None when there are none."""
    rows, cols = np.nonzero(binary == 1)
    if rows.size == 0:
        return None
    return rows.mean(), cols.mean()


qual_fig, ax = plt.subplots(
    len(quantile_rows), len(qualitative_gammas), figsize=(12, 9), squeeze=False)

for m, gamma in enumerate(qualitative_gammas):

    subs = df[df.gamma == gamma]
    if subs.empty:
        raise ValueError(f'No test results for gamma={gamma}')

    dice_quantiles = subs.dice.quantile(list(quantile_rows.values()))
    for i, (title, quantile) in enumerate(zip(quantile_rows, dice_quantiles)):

        # Slice whose Dice is closest to this quantile (first one on ties).
        subject = subs.iloc[(subs.dice - quantile).abs().to_numpy().argmin()]

        post_process = transforms.Resize(subject.raw_shape)

        image = post_process(subject.image.data)[0, 0].numpy()
        mask = subject['raw_mask'].data[0, 0].numpy()
        pred = subject.pred.data[0, 0].numpy()

        ax[i, m].imshow(image, cmap='gray')

        # Label only the first contour of each kind so the legend has one entry per source.
        for j, contour in enumerate(skm.measure.find_contours(mask, level=.5)):
            ax[i, m].plot(*contour[:, ::-1].T, c='b', label='Manual annotation' if j == 0 else None)
        for j, contour in enumerate(skm.measure.find_contours(pred, level=.5)):
            ax[i, m].plot(*contour[:, ::-1].T, c='r', label='Model prediction' if j == 0 else None)

        ax[i, m].set_xticks([])
        ax[i, m].set_yticks([])

        # Crop around the midpoint of the annotation and prediction centres. An empty
        # prediction (or mask) previously gave NaN limits, and set_xlim/set_ylim raise on NaN,
        # so fall back to whichever centre exists, or to the image centre.
        centres = [c for c in (_foreground_centre(mask), _foreground_centre(pred)) if c is not None]
        center_y, center_x = np.mean(centres, axis=0) if centres else np.array(mask.shape[:2]) / 2
        ax[i, m].set_xlim(center_x - padding, center_x + padding)
        # Inverted y-limits keep the image orientation (row 0 at the top).
        ax[i, m].set_ylim(center_y + padding, center_y - padding)

    ax[0, m].set_title(gamma)

for i, title in enumerate(quantile_rows):
    ax[i, 0].set_ylabel(title)

ax[0, 0].legend(loc='upper left')

qual_fig.tight_layout()
qual_fig.savefig('../../figures/qualitative-ShapeLoss.pdf', bbox_inches='tight')
plt.show()