In [None]:
from typing import Tuple
import itertools
from pathlib import Path
from tqdm.notebook import tqdm

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from monai.networks import nets, one_hot
from monai.metrics import compute_hausdorff_distance

import plotly.express as px
from plotly.colors import n_colors
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.io as pio

pio.templates.default = "simple_white"

import seaborn as sns

import skimage as skm
from scipy import interpolate

import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import Dataset, random_split, DataLoader
from torchvision import transforms

import torchio as tio

from kedro.extras.datasets.pandas import CSVDataSet
from kedro.extras.datasets.pickle import PickleDataSet

In [None]:
import os
import sys

# Make the project's `src/` directory importable so the local `tagseg` package resolves
# (the path is relative to the notebook's working directory).
sys.path.append(os.path.abspath('../src'))

from tagseg.data import ScdEvaluator, MnmEvaluator
from tagseg.data.dmd_dataset import DmdDataSet, DmdH5DataSet, DmdH5Evaluator
from tagseg.models.trainer import Trainer
from tagseg.models.segmenter import Net
from tagseg.metrics.dice import DiceMetric
from tagseg.pipelines.model_evaluation.nodes import tag_subjects

In [None]:
# Shared Plotly legend layout: horizontal, anchored at its bottom edge at y=1.1 (paper coordinates).
top_h_legend = {'orientation': 'h', 'yanchor': 'bottom', 'y': 1.1}

In [None]:
# Registry of trained models; the columns used below are `model`, `architecture` and `strategy`.
index = pd.read_csv(filepath_or_buffer='../data/07_model_output/index.csv')
index

In [None]:
# Per-subject results of a single checkpoint (E150), unpacked into a DataFrame.
# NOTE(review): this presumably unpickles the file — only load trusted result files.
pf = pd.DataFrame(list(PickleDataSet(filepath='../data/07_model_output/E150/dmd_results.pt').load()))

In [None]:
# Mean LV Dice over subjects with a strictly positive score (zero-Dice cases excluded).
pf.loc[pf['dice_lv'] > 0.0, 'dice_lv'].mean()

In [None]:
# Collect per-subject results for every trained model on both splits.
# Train-split results live in `dmd_results_train.pt`, test-split results in `dmd_results.pt`.
dfs = []

for (_, row), split in itertools.product(index.iterrows(), ['train', 'test']):
    suffix = {'train': '_train'}.get(split, '')
    results_path = f'../data/07_model_output/{row.model}/dmd_results{suffix}.pt'
    frame = pd.DataFrame(list(PickleDataSet(filepath=results_path).load()))

    assert 'voxel_spacing' in frame.columns

    # Tag every row with the model it came from and the split it was evaluated on.
    dfs.append(frame.assign(architecture=row.architecture, strategy=row.strategy, split=split))

In [None]:
# Stack all per-model / per-split frames (original index labels are kept, so they repeat).
df = pd.concat(dfs, axis=0)
df.shape[0]

In [None]:
# Summary table: one row per (architecture, strategy); median / mean / std of each
# metric on each split. Columns are reordered to (metric, split, statistic), with
# metrics ascending and splits descending (test before train).
results = (
    df.pivot_table(
        index=['architecture', 'strategy'],
        columns=['split'],
        values=['dice', 'hd95'],
        aggfunc=['median', 'mean', 'std'],
    )
    .sort_index(axis=1, level=[1, 2], ascending=[True, False])
    .reorder_levels([1, 2, 0], axis=1)
    .sort_index(level=[0, 1], ascending=[False, False])
)

results

In [None]:
# LaTeX export of the summary table. The column spec needs one `l` per row-index
# level plus one `r` per data column (2 metrics x 2 splits x 3 statistics = 12 with
# the pivot above). A hard-coded spec with too few columns makes LaTeX fail with
# "Extra alignment tab", so derive it from the frame itself.
print(results.to_latex(
    float_format="%.3f", bold_rows=True,
    column_format='l' * results.index.nlevels + 'r' * results.shape[1],
    multicolumn_format='c', multirow=True,
    caption='Median, mean and standard deviation of DSC and HD-95 per model architecture '
            'and training strategy on the train and test splits.'
))

In [None]:
# Keep only held-out (test) results for the analyses below. The concatenated index
# repeats across per-model frames, so drop it rather than keep it as a spurious `index` column.
df = df[df.split == 'test'].reset_index(drop=True)

In [None]:
# Patient IDs are labels, not quantities: cast to str so they are plotted as discrete values.
df['patient_id'] = df['patient_id'].astype(str)
df.head()

In [None]:
# Fix the presentation order of training strategies for every table/figure below.
sorter = ['Scratch', 'Cine', 'Physics-driven', 'CycleGAN']

# Strategies outside `sorter` would silently become NaN in the categorical; fail loudly instead.
unknown = set(df.strategy.unique()) - set(sorter)
assert not unknown, f'strategies missing from sorter: {sorted(map(str, unknown))}'

df['strategy'] = pd.Categorical(df['strategy'], categories=sorter)
# Stable sort keeps the existing row order within each strategy (deterministic output).
df = df.sort_values('strategy', kind='stable')

In [None]:
# Split violins of DSC (top row) and HD-95 (bottom row) per training strategy, with
# the `control` group on the left half and other groups on the right half;
# nnUnet in column 1, any other architecture in column 2.
fig = make_subplots(rows=2, cols=2, shared_yaxes=True, shared_xaxes=True, horizontal_spacing=0.02, vertical_spacing=0.10)

# Disease groups that already have a legend entry: each group appears exactly once,
# independent of how many groups/strategies exist or the iteration order.
legend_shown = set()

for m, metric in enumerate(['dice', 'hd95']):
    for a, s, d in itertools.product(df.architecture.unique(), df.strategy.unique(), df.disease.unique()):

        fdf = df[(df.architecture == a) & (df.strategy == s) & (df.disease == d)]
        is_control = d == 'control'

        # Legend entries are taken from the bottom (HD-95) row only.
        show_legend = m == 1 and d not in legend_shown
        if show_legend:
            legend_shown.add(d)

        fig.add_trace(
            go.Violin(
                x=fdf.strategy, y=fdf[metric],
                legendgroup=d, scalegroup=d, name=d.upper(), side='negative' if is_control else 'positive',
                line_color='#197278' if is_control else '#C44536', showlegend=show_legend,
            ), row=m + 1, col=1 if a == 'nnUnet' else 2
        )

# Architecture name as the x-axis title under each bottom-row column.
for a in df.architecture.unique():
    fig.update_xaxes(title_text=a, row=2, col=1 if a == 'nnUnet' else 2)

fig.update_yaxes(title_text='DSC (↑)', range=[0, 1], dtick=0.1, row=1, col=1)
fig.update_yaxes(range=[0, 1], row=1, col=2)
fig.update_yaxes(title_text='HD-95 [mm] (↓)', range=[0, 25], row=2, col=1)
fig.update_yaxes(range=[0, 25], row=2, col=2)
fig.update_traces(meanline_visible=True, width=.9, points=False)
fig.update_layout(violingap=0, violinmode='overlay', legend=top_h_legend)
fig.update_layout(height=800 / 1.62, width=800)
fig.show()

In [None]:
# Export the disease-split violin figure above.
pio.write_image(fig, "../../figures/disease-perf-violin.pdf")

In [None]:
# Preview the long format (one row per patient/model/metric) used by the strip plot below.
pd.melt(df, id_vars=['patient_id', 'architecture', 'strategy'], value_vars=['dice', 'hd95']).head(5)

In [None]:
# Per-patient strip plot: DSC (top row) and HD-95 (bottom row), one column per
# architecture, with a dotted line at each strategy's mean.
colors = ['#219EBC', '#FB8500', '#023047', '#C44536']

strategies = list(df.strategy.unique())
# Pin the facet-column order (nnUnet first) rather than relying on the order in which
# architectures happen to appear in `df`; the mean lines below use the same mapping.
architectures = sorted(df.architecture.unique(), key=lambda a: a != 'nnUnet')
arch_col = {a: k + 1 for k, a in enumerate(architectures)}  # plotly subplot columns are 1-based

fig = px.strip(
    df.melt(id_vars=['patient_id', 'architecture', 'strategy'], value_vars=['dice', 'hd95']),
    x='patient_id', y='value', color='strategy', facet_col='architecture', facet_row='variable',
    category_orders=dict(architecture=architectures, strategy=strategies, variable=['dice', 'hd95']),
    labels=dict(patient_id='Patient ID', variable='Performance metric', strategy='Training strategy', architecture='Model architecture'),
    color_discrete_sequence=colors
)
fig.update_yaxes(matches=None)
fig.update_layout(legend=top_h_legend)

# px lays facet rows out bottom-up: row 1 is the last `variable` category (HD-95), row 2 is DSC.
for row, metric, y_title, y_range in [(1, 'hd95', 'HD-95 [mm] (↓)', [0, 25]), (2, 'dice', 'DSC (↑)', [0, 1])]:
    fig.update_yaxes(title_text=y_title, range=y_range, row=row, col=1)
    fig.update_yaxes(range=y_range, row=row, col=2)
    for a, s in itertools.product(architectures, strategies):
        avg = df[(df.architecture == a) & (df.strategy == s)][metric].mean()
        if pd.isna(avg):
            continue  # no results for this architecture/strategy pair
        # Colour index follows `strategies`, the same order px uses for the traces.
        fig.add_hline(
            y=avg, line_width=3, line_dash="dot", line_color=colors[strategies.index(s)], row=row, col=arch_col[a]
        )

fig.update_layout(height=800 / 1.62, width=800)
fig.show()

In [None]:
# Export the per-patient strip figure above.
pio.write_image(fig, "../../figures/patient-performance-strip.pdf")

In [None]:
# Horizontal violins of DSC per timeframe (0-24, labelled 1-25): one row per
# architecture, one column per training strategy, with the panel-wide mean DSC
# marked by a dotted vertical line.
N_TIMEFRAMES = 25
colors = n_colors('rgb(25, 114, 120)', 'rgb(196, 69, 54)', N_TIMEFRAMES, colortype='rgb')
architectures = df.architecture.unique()
strategies = df.strategy.unique()

fig = make_subplots(
    rows=len(architectures), cols=len(strategies),
    shared_yaxes=True, shared_xaxes=True, horizontal_spacing=0.02, vertical_spacing=0.02,
)

for i, architecture in enumerate(architectures):

    fig.update_yaxes(title_text=f'{architecture}<br>Timeframe', row=i + 1, col=1)

    for j, strategy in enumerate(strategies):

        # Strategy name as the x-axis title under the bottom row only.
        fig.update_xaxes(title_text=f'DSC (↑)<br>{strategy}', row=len(architectures), col=j + 1)

        subset = df[(df.architecture == architecture) & (df.strategy == strategy)]
        # Timeframes may hold different numbers of subjects, so keep a list of 1-D
        # arrays: stacking them into one np.array fails (or yields objects) when ragged.
        per_timeframe = [subset.loc[subset.timeframe == t, 'dice'].to_numpy() for t in range(N_TIMEFRAMES)]

        for t, (dice_values, color) in enumerate(zip(per_timeframe, colors)):
            fig.add_trace(go.Violin(name=t + 1, x=dice_values, line_color=color), row=i + 1, col=j + 1)

        # Axis settings are per panel, so apply them once rather than once per trace.
        fig.update_xaxes(range=[0.1, 1.2], tickvals=np.arange(0.2, 1.1, 0.1), row=i + 1, col=j + 1)
        fig.update_yaxes(range=[-1, 28], tickvals=np.arange(0, 30, 5), ticktext=np.arange(0, 30, 5), row=i + 1, col=j + 1)

        # Mean over all subjects and timeframes in this panel.
        all_values = np.concatenate(per_timeframe)
        if all_values.size == 0:
            continue  # nothing to average for this panel
        mean_dsc = all_values.mean()

        fig.add_vline(
            x=mean_dsc,
            annotation_text=f"     {mean_dsc:.3f}", annotation_position="top right",
            annotation_font_color='rgb(40, 61, 59)',
            line_width=3, line_dash="dot", line_color='rgb(40, 61, 59)', row=i + 1, col=j + 1)

fig.update_traces(orientation='h', side='positive', width=3, points=False)
fig.update_layout(height=800 / 1.62, width=800, showlegend=False)
fig.show()

In [None]:
# Export the per-timeframe violin figure above.
pio.write_image(fig, "../../figures/model-performance-violin.pdf")

In [None]:
fig, ax = plt.subplots(3, len(df.strategy.unique()), figsize=(12, 9))

padding = 50

for m, strategy in enumerate(df.strategy.unique()):

    subs = df[(df.architecture == 'nnUnet') & (df.strategy == strategy)].copy()

    for i, (title, quantile) in enumerate(zip(['Q1', 'Median', 'Q3'], subs.dice.quantile([.25, .5, .75]))):

        subs['diff'] = (subs.dice - quantile).abs()

        subject = subs.sort_values('diff', ascending=True).iloc[0]

        post_process = transforms.Resize(subject.raw_shape)

        image = post_process(subject.image.data)[0, 0].numpy()
        mask = subject['raw_mask'].data[0, 0].numpy()
        pred = subject.pred.data[0, 0].numpy()

        center = [list(map(lambda a: a.mean(), np.where(mask == 1))), list(map(lambda a: a.mean(), np.where(pred == 1)))]

        ax[i, m].imshow(image, cmap='gray')

        # Label only once
        for j, contour in enumerate(skm.measure.find_contours(mask, level=.5)):
            ax[i, m].plot(*contour[:, ::-1].T, c='b', label='Manual annotation' if j == 0 else None)
        for j, contour in enumerate(skm.measure.find_contours(pred, level=.5)):
            ax[i, m].plot(*contour[:, ::-1].T, c='r', label='Model prediction' if j == 0 else None)

        ax[i, m].get_xaxis().set_ticks([])
        ax[i, m].get_yaxis().set_ticks([])

        center_y, center_x = np.array(center).mean(axis=0)
        ax[i, m].set_xlim(center_x - padding, center_x + padding)
        ax[i, m].set_ylim(center_y + padding, center_y - padding)
        
        ax[i, 0].set_ylabel(title)

    ax[0, m].set_title(strategy)

ax[0, 0].legend(loc='upper left')

plt.tight_layout()
plt.savefig('../../figures/qualitative-nnUnet.pdf', bbox_inches='tight')
plt.show()

In [None]:
fig, ax = plt.subplots(3, len(df.strategy.unique()), figsize=(12, 9))

padding = 50

for m, strategy in enumerate(df.strategy.unique()):

    subs = df[(df.architecture == 'ResNetVAE') & (df.strategy == strategy)].copy()

    for i, (title, quantile) in enumerate(zip(['Q1', 'Median', 'Q3'], subs.dice.quantile([.25, .5, .75]))):

        subs['diff'] = (subs.dice - quantile).abs()

        subject = subs.sort_values('diff', ascending=True).iloc[0]

        post_process = transforms.Resize(subject.raw_shape)

        image = post_process(subject.image.data)[0, 0].numpy()
        mask = subject['raw_mask'].data[0, 0].numpy()
        pred = subject.pred.data[0, 0].numpy()

        center = [list(map(lambda a: a.mean(), np.where(mask == 1))), list(map(lambda a: a.mean(), np.where(pred == 1)))]

        ax[i, m].imshow(image, cmap='gray')

        # Label only once
        for j, contour in enumerate(skm.measure.find_contours(mask, level=.5)):
            ax[i, m].plot(*contour[:, ::-1].T, c='b', label='Manual annotation' if j == 0 else None)
        for j, contour in enumerate(skm.measure.find_contours(pred, level=.5)):
            ax[i, m].plot(*contour[:, ::-1].T, c='r', label='Model prediction' if j == 0 else None)

        ax[i, m].get_xaxis().set_ticks([])
        ax[i, m].get_yaxis().set_ticks([])

        center_y, center_x = np.array(center).mean(axis=0)
        ax[i, m].set_xlim(center_x - padding, center_x + padding)
        ax[i, m].set_ylim(center_y + padding, center_y - padding)
        
        ax[i, 0].set_ylabel(title)

    ax[0, m].set_title(strategy)

ax[0, 0].legend(loc='upper left')

plt.tight_layout()
plt.savefig('../../figures/qualitative-ResNetVAE.pdf', bbox_inches='tight')
plt.show()