In [1]:
from typing import Tuple
import itertools
from pathlib import Path
from tqdm.notebook import tqdm

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from monai.networks import nets, one_hot
from monai.metrics import compute_hausdorff_distance

import plotly.express as px
from plotly.colors import n_colors
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.io as pio

pio.templates.default = "simple_white"

import skimage as skm
from scipy import interpolate

import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import Dataset, random_split, DataLoader
from torchvision import transforms

import torchio as tio

from kedro.extras.datasets.pandas import CSVDataSet
from kedro.extras.datasets.pickle import PickleDataSet

In [2]:
import os, sys
sys.path.append(os.path.abspath('../src'))

from tagseg.data import ScdEvaluator, MnmEvaluator
from tagseg.data.dmd_dataset import DmdH5DataSet, DmdH5Evaluator
from tagseg.models.trainer import Trainer
from tagseg.models.segmenter import Net
from tagseg.metrics.dice import DiceMetric
from tagseg.pipelines.model_evaluation.nodes import tag_subjects
from tagseg.data.dmd_dataset import DmdDataSet

In [3]:
# Shared plotly legend layout: horizontal orientation, anchored by its bottom edge at y=1.1.
top_h_legend = {'orientation': 'h', 'yanchor': "bottom", 'y': 1.1}

In [4]:
# Lookup table pairing each trained model directory name with its gamma value.
index = pd.read_csv(filepath_or_buffer='../data/07_model_output/index_gamma.csv')

index

Unnamed: 0,model,gamma
0,model_dmd_v8,0.001
1,model_dmd_v7,0.005
2,model_dmd_v6,0.05
3,model_dmd_v5,0.1
4,model_dmd_v4,0.5
5,model_dmd_v3,1.0
6,model_dmd_v1,0.0


In [5]:
# File-name suffix of the pickled results for each split; key order fixes the
# order in which the splits are loaded for every model.
suffix_by_split = {'train': '_train', 'test': ''}

dfs = []

# One frame per (model, split), tagged with the model's gamma and the split name.
for (_, row), split in itertools.product(index.iterrows(), suffix_by_split):
    ext = suffix_by_split[split]

    loaded = PickleDataSet(filepath=f'../data/07_model_output/{row.model}/dmd_results{ext}.pt').load()

    dfs.append(pd.DataFrame(list(loaded)).assign(gamma=row.gamma, split=split))

In [6]:
# Stack every per-model / per-split frame into one table. The original row
# labels are kept, so the resulting index is not unique.
df = pd.concat(dfs, axis=0)

# Total number of rows across all models and splits.
df.shape[0]

8750

In [7]:
# Mean and standard deviation of each metric per gamma and split.
results = (
    df.pivot_table(
        index=['gamma'],
        columns=['split'],
        values=['dice', 'hd95'],
        aggfunc=['mean', 'std'],
    )
    # pivot_table yields column levels (statistic, metric, split): order the
    # columns by metric, then split descending ('train' before 'test'), then statistic ...
    .sort_index(axis=1, level=[1, 2, 0], ascending=[True, False, True])
    # ... and move the levels into (metric, split, statistic) order.
    .reorder_levels([1, 2, 0], axis=1)
)

results

Unnamed: 0_level_0,dice,dice,dice,dice,hd95,hd95,hd95,hd95
split,train,train,test,test,train,train,test,test
Unnamed: 0_level_2,mean,std,mean,std,mean,std,mean,std
gamma,Unnamed: 1_level_3,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3
0.0,0.919866,0.020299,0.776869,0.097761,1.901702,0.563619,6.760752,8.413045
0.001,0.851197,0.019864,0.786467,0.042058,2.257392,0.648884,4.501546,5.037456
0.005,0.925525,0.01717,0.788016,0.103722,1.862229,0.560843,6.434096,8.461362
0.05,0.908185,0.019277,0.81353,0.063881,2.216321,0.598796,4.758275,1.716885
0.1,0.905907,0.018609,0.816334,0.083529,2.344565,0.670258,5.327254,4.446488
0.5,0.812901,0.019447,0.7426,0.048762,2.423071,0.84484,5.532859,9.391559
1.0,0.598068,0.034003,0.564278,0.042769,3.933431,3.023045,4.754245,1.894808


In [29]:
# Test-set means ordered best-first: highest Dice first, lowest HD95 first.
dice_test_mean = results[('dice', 'test', 'mean')]
hd95_test_mean = results[('hd95', 'test', 'mean')]

dices = dice_test_mean.sort_values(ascending=False).to_numpy()
hds = hd95_test_mean.sort_values(ascending=True).to_numpy()

In [32]:
# Rank each gamma by its mean test-set score (1 = best): Dice is ranked
# descending (higher is better), HD95 ascending (lower is better).
# `Series.rank` replaces the previous np.isclose lookup into pre-sorted arrays,
# so this cell no longer depends on `dices` / `hds` from another cell and does
# not merge scores that merely fall within np.isclose's tolerance.
# method='min' gives tied scores the smallest rank of the tie group, matching
# the "first matching position" behaviour of the previous lookup.
dice_rank = results[('dice', 'test', 'mean')].rank(ascending=False, method='min').astype(int)
hd_rank = results[('hd95', 'test', 'mean')].rank(ascending=True, method='min').astype(int)

results['dice_rank'] = dice_rank
results['hd_rank'] = hd_rank

# Combined score: sum of the two ranks (lower is better).
results['rank'] = dice_rank + hd_rank

In [33]:
# Display the summary table, which now also carries the dice_rank, hd_rank and rank columns.
results

Unnamed: 0_level_0,dice,dice,dice,dice,hd95,hd95,hd95,hd95,rank,dice_rank,hd_rank
split,train,train,test,test,train,train,test,test,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
Unnamed: 0_level_2,mean,std,mean,std,mean,std,mean,std,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2
gamma,Unnamed: 1_level_3,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3,Unnamed: 9_level_3,Unnamed: 10_level_3,Unnamed: 11_level_3
0.0,0.919866,0.020299,0.776869,0.097761,1.901702,0.563619,6.760752,8.413045,12,5,7
0.001,0.851197,0.019864,0.786467,0.042058,2.257392,0.648884,4.501546,5.037456,5,4,1
0.005,0.925525,0.01717,0.788016,0.103722,1.862229,0.560843,6.434096,8.461362,9,3,6
0.05,0.908185,0.019277,0.81353,0.063881,2.216321,0.598796,4.758275,1.716885,5,2,3
0.1,0.905907,0.018609,0.816334,0.083529,2.344565,0.670258,5.327254,4.446488,5,1,4
0.5,0.812901,0.019447,0.7426,0.048762,2.423071,0.84484,5.532859,9.391559,11,6,5
1.0,0.598068,0.034003,0.564278,0.042769,3.933431,3.023045,4.754245,1.894808,9,7,2


In [135]:
# Export the summary table as LaTeX.
# The column spec is derived from the frame (left-aligned index level(s),
# right-aligned data columns) so it always matches the number of columns,
# whether or not the rank columns have been added to `results`.
latex_column_format = 'l' * results.index.nlevels + 'r' * results.shape[1]

# print() is intentional: it renders the escaped newlines of the LaTeX source.
print(results.to_latex(
    float_format="%.3f", bold_rows=True, column_format=latex_column_format,
    multicolumn_format='c', multirow=True,
    caption='Segmentation performance on the train and test splits for each weight (gamma) of the shape distance loss.'
))

\begin{table}
\centering
\caption{Something retarded}
\begin{tabular}{llrrrrrrrr}
\toprule
{} & \multicolumn{4}{c}{dice} & \multicolumn{4}{c}{hd95} \\
\textbf{split} & \multicolumn{2}{c}{train} & \multicolumn{2}{c}{test} & \multicolumn{2}{c}{train} & \multicolumn{2}{c}{test} \\
{} &  mean &   std &  mean &   std &  mean &   std &  mean &   std \\
\textbf{gamma} &       &       &       &       &       &       &       &       \\
\midrule
\textbf{0.000} & 0.920 & 0.020 & 0.777 & 0.098 & 1.902 & 0.564 & 6.761 & 8.413 \\
\textbf{0.001} & 0.851 & 0.020 & 0.786 & 0.042 & 2.257 & 0.649 & 4.502 & 5.037 \\
\textbf{0.005} & 0.926 & 0.017 & 0.788 & 0.104 & 1.862 & 0.561 & 6.434 & 8.461 \\
\textbf{0.050} & 0.908 & 0.019 & 0.814 & 0.064 & 2.216 & 0.599 & 4.758 & 1.717 \\
\textbf{0.100} & 0.906 & 0.019 & 0.816 & 0.084 & 2.345 & 0.670 & 5.327 & 4.446 \\
\textbf{0.500} & 0.813 & 0.019 & 0.743 & 0.049 & 2.423 & 0.845 & 5.533 & 9.392 \\
\textbf{1.000} & 0.598 & 0.034 & 0.564 & 0.043 & 3.933 & 3.023 & 4.7


In future versions `DataFrame.to_latex` is expected to utilise the base implementation of `Styler.to_latex` for formatting and rendering. The arguments signature may therefore change. It is recommended instead to use `DataFrame.style.to_latex` which also contains additional functionality.



In [96]:
# Reshape the wide summary table into long form: one row per
# (gamma, metric, split) with `mean` and `std` columns.
# Only the 'dice' and 'hd95' column blocks are used. The rank columns that an
# earlier cell adds to `results` have empty split/statistic levels; including
# them would add a sixth column after the pivot and break the renaming below.
metric_stats = results[['dice', 'hd95']]

res = metric_stats.reset_index().melt(id_vars=[('gamma', '', '')])
res.columns = ['gamma', 'metric', 'split', 'statistic', 'value']

# Spread the statistic level back into columns (mean / std), then flatten the header.
res = res.pivot(index=['gamma', 'metric', 'split'], columns=['statistic']).reset_index()
res.columns = ['gamma', 'metric', 'split', 'mean', 'std']

# Within each gamma, list 'train' before 'test' (split sorted descending).
res = res.sort_values(by=['gamma', 'split'], ascending=[True, False])

In [100]:
# Mean of each metric against gamma with std error bars: one facet per metric,
# one colour per split.
top_h_legend = dict(orientation='h', yanchor="bottom", y=1.1)

fig = px.scatter(
    res,
    x='gamma', y='mean', error_y='std',
    facet_col='metric', color='split',
)

fig.update_layout(legend=top_h_legend)

# Clear the `matches` link on the y-axes so each facet can use its own range.
fig.update_yaxes(matches=None)
# NOTE(review): gamma == 0.0 has no position on a log axis, so that model is
# presumably missing from this figure. Confirm that is intended.
fig.update_xaxes(type='log')

fig.show()


distutils Version classes are deprecated. Use packaging.version instead.



In [107]:
# Per-sample Dice scores on the test split, alongside the gamma of the model that produced them.
df.loc[df['split'] == 'test', ['gamma', 'dice']]

Unnamed: 0,gamma,dice
0,0.001,0.683831
1,0.001,0.706052
2,0.001,0.741874
3,0.001,0.759745
4,0.001,0.778860
...,...,...
245,0.000,0.793420
246,0.000,0.789687
247,0.000,0.761279
248,0.000,0.748563


In [136]:
# Violin plots of the per-sample test-set scores: one subplot column per
# metric, one horizontal violin per gamma.
#
# The x-axis settings are keyed by metric name, so titles and ranges cannot
# drift out of sync with the subplot they belong to. The dict order fixes the
# subplot column order, and the cell no longer reads `res`.
violin_axes = {
    'dice': dict(
        title_text='Dice coefficient (DSC)',
        range=[0.1, 1.1], tickvals=np.arange(0.2, 1.1, 0.1)),
    'hd95': dict(
        title_text='Hausdorff distance (95%)',
        range=[0., 15.], tickvals=np.arange(0., 15., 2.5)),
}
metrics = list(violin_axes)
gammas = sorted(df.gamma.unique())

test_scores = df[df.split == 'test']

# One line colour per gamma, interpolated between the two end colours.
colors = n_colors('rgb(25, 114, 120)', 'rgb(40, 61, 59)', len(gammas), colortype='rgb')

fig = make_subplots(
    rows=1, cols=len(metrics),
    shared_yaxes=True, shared_xaxes=False,
    horizontal_spacing=0.02, vertical_spacing=0.0
)

for col, metric in enumerate(metrics, start=1):

    # Kept as a list of 1-D arrays rather than a single 2-D array, so gammas
    # with different numbers of test samples are still handled.
    per_gamma = [test_scores.loc[test_scores.gamma == g, metric].to_numpy() for g in gammas]

    for gamma, values, color in zip(gammas, per_gamma, colors):
        fig.add_trace(go.Violin(name=gamma, x=values, line_color=color), row=1, col=col)

    # Dotted reference line at the mean over all test samples of all gammas.
    overall_mean = np.concatenate(per_gamma).mean()
    fig.add_vline(
        x=overall_mean,
        annotation_text=f"    {overall_mean:.3f}", annotation_position="top right",
        annotation_font_color='rgb(40, 61, 59)',
        line_width=3, line_dash="dot", line_color='rgb(40, 61, 59)', row=1, col=col)

    fig.update_xaxes(row=1, col=col, **violin_axes[metric])

fig.update_yaxes(
    title_text='Weight of Shape Distance Loss (Gamma)', row=1, col=1)

fig.update_layout(height=800 / 1.62, width=800, showlegend=False)
fig.update_traces(meanline_visible=True, orientation='h', side='positive', width=3, points=False)
fig.show()

In [137]:
# Save the current `fig` (the violin figure when cells are run in order) as a PDF.
# The path is relative to the notebook's working directory.
fig.write_image(file="../../figures/sdl-perf-violin.pdf")