In [None]:
from typing import Tuple
from pathlib import Path
from tqdm.notebook import tqdm

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from skimage.draw import polygon, polygon2mask
from skimage import measure

from monai.networks import nets, one_hot
from monai.metrics import compute_hausdorff_distance

import plotly.express as px

import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import Dataset, random_split, DataLoader
from torchvision import transforms

import torchio as tio

from kedro.extras.datasets.pandas import CSVDataSet
from kedro.extras.datasets.pickle import PickleDataSet

In [None]:
import os
import sys

# Make the project's `src/` directory importable so the local `tagseg` package resolves.
sys.path.append(os.path.abspath(os.path.join('..', 'src')))

from tagseg.data import ScdEvaluator, MnmEvaluator, TagSegDataSet
from tagseg.data.dmd_dataset import DmdH5DataSet, DmdDataSet
from tagseg.models.trainer import Trainer
from tagseg.models.segmenter import Net
from tagseg.metrics.dice import DiceMetric
from tagseg.pipelines.model_evaluation.nodes import tag_subjects

##### DMD H5 — metric sanity check (Dice / HD95 between two ground-truth masks)

In [None]:
# Build the H5-backed DMD dataset from the raw training directory.
# NOTE(review): `_load_except` is a private helper of DmdH5DataSet; prefer a public loader if one exists.
dmd = (
    DmdH5DataSet(filepath='../data/03_primary/dmd_wtv.pt')
    ._load_except(filepath_raw='../data/01_raw/dmd_alex/train')
)

In [None]:
# Take element [1][0] of samples 0 and 14 (presumably each sample's label — TODO confirm).
# Sample 14's mask stands in for a "prediction" so the metric code can be exercised without a model.
mask, pred = (dmd[index][1][0] for index in (0, 14))

In [None]:
# Shift every value by 1 so no pixel equals skimage's default background (0) and every region
# receives a connected-component id. The cells below treat id 2 as myocardium and id 3 as LV;
# for a binary ring mask this presumably comes from raster-scan order (outside, ring, enclosed
# cavity) — an assumption that breaks if the outside touches no image corner. Verify per image.
nmask, npred = (measure.label(labels + 1) for labels in (mask, pred))

In [None]:
# Project Dice scorer; `include_background=False` presumably drops channel 0 from the score.
dice_metric = DiceMetric(include_background=False)

In [None]:
def region_one_hot(labels: np.ndarray, region: int) -> torch.Tensor:
    """Return a (1, 2, H, W) one-hot tensor for the pixels of ``labels`` equal to ``region``.

    Batch and channel axes are added by indexing rather than a hard-coded
    ``reshape(1, 1, 256, 256)``, so any 2-D label map works.

    Raises:
        ValueError: if ``labels`` is not 2-D.
    """
    if np.ndim(labels) != 2:
        raise ValueError(f'expected a 2-D label map, got shape {np.shape(labels)}')
    return one_hot(torch.as_tensor(labels == region)[None, None], num_classes=2)


# Ids 2 (myocardium) and 3 (LV) follow the connected-component convention of the
# `measure.label` cell — an assumption to verify on new data.
nmask_myo = region_one_hot(nmask, 2)
nmask_lv = region_one_hot(nmask, 3)

npred_myo = region_one_hot(npred, 2)
npred_lv = region_one_hot(npred, 3)

In [None]:
# Score both regions; evaluation order is Dice (myo, lv) and then HD95 (myo, lv).
region_pairs = ((npred_myo, nmask_myo), (npred_lv, nmask_lv))

dices = tuple(dice_metric(y_pred=pred_oh, y=ref_oh) for pred_oh, ref_oh in region_pairs)
hd95 = tuple(
    compute_hausdorff_distance(
        one_hot(pred_oh.argmax(dim=1, keepdim=True), num_classes=2),
        ref_oh,
        include_background=False,
        percentile=95,
    )
    for pred_oh, ref_oh in region_pairs
)

In [None]:
# Dice scores as a (myocardium, LV) tuple.
dices

In [None]:
# 95th-percentile Hausdorff distances (pixel units, no spacing given) as a (myocardium, LV) tuple.
hd95

In [None]:
# Connected-component map of the stand-in prediction, with a colour scale for the component ids.
pred_image = plt.imshow(npred)
plt.colorbar(pred_image)

In [None]:
# Connected-component map of the reference mask. The colour bar matches the prediction view so
# the component ids of the two plots can be compared directly.
plt.imshow(nmask)
plt.colorbar()

### DMD test subject: model prediction vs. H5 contour ROI (Dice, Hausdorff distance)

In [None]:
# Load the SegResNetVAE checkpoint (trained on cine + tag + DMD data, judging by the file name).
model = Net(
    model_type='SegResNetVAE',
    load_model='../data/06_models/model_cine_v6_tag_v1_dmd_v1.pt',
)

In [None]:
# Original DICOM frame of the DMD acquisition (presumably the same subject/slice as the
# `1_C_C_201_base` H5 files used below — confirm).
# The data root used to be a hard-coded path on one machine; set DMD_ORG_DIR to override it.
DMD_ORG_DIR = os.environ.get('DMD_ORG_DIR', '/home/loecher/dmd_org2')
original_path = os.path.join(
    DMD_ORG_DIR,
    'GROUP_1/CONTROL/CHOC/16-000297-201/base/tag_cine_SA__Base_fl2d9_grid_10/im_s4983_0006.dcm',
)

In [None]:
import pydicom as dicom

# Parse the reference DICOM file (header and pixel data).
ds = dicom.dcmread(fp=original_path)

In [None]:
# In-plane pixel size in mm, read from the DICOM header loaded above instead of a pasted literal
# (previously hard-coded as 1.4285714626312). DICOM PixelSpacing is [row spacing, column spacing].
row_spacing, col_spacing = (float(value) for value in ds.PixelSpacing)
if not np.isclose(row_spacing, col_spacing):
    raise ValueError(
        f'anisotropic pixel spacing {row_spacing} x {col_spacing}; use per-axis spacing instead'
    )
pixel_spacing = row_spacing

In [None]:
import h5py


def read_contours(roi_file: h5py.File, key: str) -> np.ndarray:
    """Dereference every per-frame contour stored under ``key`` and stack them.

    ``roi_file[key]`` holds one HDF5 object reference per frame (column 0). Each reference
    points at that frame's point array. Assumes every frame has the same number of points,
    so the stack is (n_frames, n_points, k) — TODO confirm.
    """
    refs = roi_file[key]
    return np.array([np.array(roi_file[refs[i][0]]) for i in range(refs.shape[0])])


def contour_mask(points: np.ndarray, shape) -> np.ndarray:
    """Rasterise one closed contour into a boolean mask of ``shape``.

    ``points[:, 0]`` is the column (x) and ``points[:, 1]`` the row (y) coordinate. The vertices
    are reordered to skimage's (row, col) convention and passed straight to ``polygon2mask``.
    Previously the filled pixel list returned by ``polygon`` was fed back in as polygon vertices.
    """
    return polygon2mask(shape, points[:, [1, 0]])


t = 5

# Read everything needed and close both HDF5 files immediately.
with h5py.File('../data/01_raw/dmd_alex/test/1_C_C_201_base.h5', 'r') as img_hf:
    # After the swap, axis 0 indexes frames (frame `t` is selected below).
    imt = np.array(img_hf['imt']).swapaxes(0, 2)

with h5py.File('../data/01_raw/dmd_alex/test/1_C_C_201_base_roi.h5', 'r') as roi_hf:
    # Each contour set iterates over its own reference list. The outer set used to be sized
    # by 'pts_interp_inner' (copy/paste slip).
    pts_inner = read_contours(roi_hf, 'pts_interp_inner')
    pts_outer = read_contours(roi_hf, 'pts_interp_outer')

image = imt[t]
image = image / image.max()
image = TagSegDataSet._preprocess_image(0.456, 0.224)(image).unsqueeze(0)

inner = contour_mask(pts_inner[t], imt.shape[1:])
outer = contour_mask(pts_outer[t], imt.shape[1:])

# Symmetric difference of the two filled contours: the myocardial ring when inner lies inside outer.
label = (outer ^ inner).astype(np.float64)
# label = TagSegDataSet._preprocess_label()(label).unsqueeze(0)

In [None]:
# Shape of the image stack: axis 0 is indexed by frame `t`, axes 1-2 are the image grid.
np.shape(imt)

In [None]:
# Resize network class maps back to the original (rows, cols) grid of `imt` so they align with
# `label`. The size is taken from the data rather than hard-coded as (224, 180). NEAREST keeps
# class ids intact, whereas Resize's default bilinear mode interpolates between labels.
back = transforms.Resize(
    tuple(int(size) for size in imt.shape[1:]),
    interpolation=transforms.InterpolationMode.NEAREST,
)

In [None]:
# Evaluation only: skip building an autograd graph for the forward pass.
with torch.no_grad():
    output = model.forward(image).sigmoid()

# Hard class map resized onto the label grid, one-hot encoded like the reference.
y_pred = one_hot(back(output.argmax(dim=1, keepdim=True)), num_classes=2)
y = one_hot(torch.tensor(label[None, None, ...]), num_classes=2)

In [None]:
# Maximum (100th-percentile) Hausdorff distance, converted from pixels to mm with `pixel_spacing`.
compute_hausdorff_distance(y_pred, y, include_background=False, percentile=100) * pixel_spacing

In [None]:
# 95th-percentile Hausdorff distance in mm. It used to be reported in pixels, unlike the 100th-percentile cell.
compute_hausdorff_distance(y_pred, y, include_background=False, percentile=95) * pixel_spacing

In [None]:
# `import medpy` alone is not guaranteed to load the `medpy.metric` subpackage; import the function directly.
from medpy.metric.binary import hd

# Cross-check MONAI's Hausdorff distance with MedPy's implementation (mm, via `pixel_spacing`).
hd(
    back(output.argmax(dim=1, keepdim=True))[0, 0].numpy(),
    label,
    voxelspacing=pixel_spacing,
)

In [None]:
# Per-sample Dice over the DMD test set.
# NOTE(review): `dmd_test` is not defined anywhere in this notebook (deleted cell / hidden state).
# Define it explicitly before running, e.g. a DmdH5DataSet over dmd_alex/test — confirm.
# This loop also overwrites the globals `image`, `label`, `output`, `y_pred` and `y` from earlier cells.
dice = []
dice_metric = DiceMetric(include_background=False)

for image, label in tqdm(dmd_test):
    # The forward pass now also runs without autograd; previously only post-processing did.
    with torch.no_grad():
        output = model.forward(image.unsqueeze(0))
        y_pred = output.sigmoid()
        y = one_hot(label.unsqueeze(0), num_classes=2)

        dice.append(dice_metric(y_pred=y_pred, y=y))

dice = np.array(dice)

In [None]:
# Mean Dice per subject. Assumes `dice` is ordered subject-major over 6 subjects with equal frame
# counts (25 each for the original 150 samples); the per-subject frame count is inferred.
dice.reshape((6, -1)).mean(axis=1)

In [None]:
# Predicted class map for test sample 131, overlaid on its first image channel (zeros hidden).
raw_output = model.forward(dmd_test[131][0].unsqueeze(0))
output = raw_output.detach().sigmoid().argmax(dim=1)[0]

plt.imshow(dmd_test[131][0][0], cmap='gray')
plt.imshow(np.ma.masked_where(output == 0, output), cmap='Reds', alpha=0.4)

In [None]:
# Ground-truth overlay for sample 131, the same sample as the prediction view. The mask
# previously came from sample 130 while the image and masking condition came from 131.
plt.imshow(dmd_test[131][0][0], cmap='gray')
gt_mask = dmd_test[131][1][0]
plt.imshow(np.ma.masked_where(gt_mask == 0, gt_mask), cmap='Reds', alpha=0.4)

In [None]:
# Dice of every test sample in dataset order. The x-axis length follows `dice` instead of a hard-coded 150.
plt.plot(np.arange(len(dice)), dice)
plt.xlabel('sample index')
plt.ylabel('Dice');

##### Case-by-case example

In [None]:
# Load the tagged M&Ms test subjects.
mnm_data = (
    MnmEvaluator('../data/03_primary/mnm_test_tagged.pt')
    .load()
)

In [None]:
# Reload a different checkpoint for the M&Ms example; this shadows the DMD `model` above.
model = Net(
    load_model='../data/06_models/model_cine_v6_tag_v1.pt',
    model_type='SegResNetVAE'
)

subject = mnm_data[12]

batch = subject['image'][tio.DATA], subject['mask'][tio.DATA]
batch = Trainer.tensor_tuple_to('cpu', batch)
image, label = batch

# Standardise the image, then rescale it to a fixed target mean/std.
# NOTE(review): presumably these match the training-set intensity statistics — confirm.
TARGET_MEAN, TARGET_STD = 0.72535978, 0.18047594
image = TARGET_STD * (image - image.mean()) / image.std() + TARGET_MEAN

# Evaluation only: no autograd graph is needed.
with torch.no_grad():
    preds = model.forward(image)

y_pred = preds.sigmoid()
y = one_hot(label, num_classes=2)

In [None]:
# Dice on the soft (sigmoid) prediction, presumably excluding the background channel.
DiceMetric(include_background=False)(y_pred=y_pred, y=y)

In [None]:
# 95th-percentile Hausdorff distance (pixel units) between the binarised prediction and the label.
compute_hausdorff_distance(
    y_pred=one_hot(y_pred.argmax(dim=1, keepdim=True), num_classes=2),
    y=y,
    include_background=False,
    percentile=95,
)

In [None]:
# Shape of the normalised input tensor.
image.size()

In [None]:
# Shape of the per-channel probabilities.
torch.sigmoid(preds).size()

In [None]:
# 2x2 grid: top row = predicted channels 0/1, bottom row = one-hot label channels 0/1.
pred = y_pred.detach().cpu().numpy()

fig, ax = plt.subplots(2, 2)

for row, source in enumerate((pred, y)):
    for channel in range(2):
        ax[row, channel].imshow(source[0, channel])

In [None]:
# Hard class map (argmax over channels) of the first batch element.
class_map = preds.sigmoid().argmax(dim=1)[0].detach()
plt.imshow(class_map)
plt.colorbar()

In [None]:
# Ground truth (jet) and prediction (viridis) overlaid on the first image channel; zeros hidden.
gt_mask = label[0, 0]
plt.imshow(image[0, 0], cmap='gray')
plt.imshow(np.ma.masked_where(gt_mask == 0, gt_mask), cmap='jet', alpha=0.4)

# argmax over channels, first batch element (same values as the old unsqueeze/[0, 0] chain).
prediction = y_pred.argmax(dim=1)[0].detach().cpu()
plt.imshow(np.ma.masked_where(prediction == 0, prediction), cmap='viridis', alpha=0.4)

##### Checking results

In [None]:
# Per-dataset result objects and per-case metric tables for both evaluation sets.
# NOTE(review): the result pickles are the `unet` run while the CSVs are `model_cine_v7_tag_v3`;
# confirm these belong together. Pickle loading executes code, so only load trusted files.
dsets = ['mnm', 'scd']
results = {
    name: PickleDataSet(f'../data/07_model_output/{name}_results_unet.pt').load()
    for name in dsets
}
dfs = {
    name: pd.read_csv(f'../data/08_reporting/model_cine_v7_tag_v3/{name}_results.csv', index_col=0)
    for name in dsets
}

In [None]:
# Stack the per-case Dice scores of every dataset into one long table with a `dataset` column.
# The iteration variable is `name`; the old `for ds in dsets` loop overwrote the DICOM dataset `ds`.
dicesets = [dfs[name][['dice']].assign(dataset=name) for name in dsets]

diceset = pd.concat(dicesets)

In [None]:
# Mean Dice per dataset over cases with Dice > 0 (zeros presumably mark failed cases — confirm).
diceset.loc[lambda frame: frame.dice > 0].groupby('dataset').mean()

In [None]:
# Distribution of positive Dice scores per dataset, normalised to probabilities.
px.histogram(
    diceset.loc[lambda frame: frame.dice > 0],
    x='dice',
    color='dataset',
    marginal='rug',
    barmode='overlay',
    nbins=20,
    histnorm='probability',
)

In [None]:
i = 30

# Inspect one M&Ms case: print its Dice, then draw its diagnostic plot.
case = results['mnm'][i]
print(case.dice)
case.plot(figsize=(20, 20))

In [None]:
# Mean Dice per scanner vendor. Only `dice` is averaged: on pandas >= 2.0 a plain
# `.groupby(...).mean()` raises if the table has non-numeric columns (presumably Pathology, Sex, ...).
px.bar(dfs['mnm'].groupby('VendorName', as_index=False)['dice'].mean(), x='VendorName', y='dice')

In [None]:
# Mean Dice per acquisition centre. Only `dice` is averaged so that non-numeric columns do not
# make `.mean()` raise on pandas >= 2.0.
px.bar(dfs['mnm'].groupby('Centre', as_index=False)['dice'].mean(), x='Centre', y='dice')

In [None]:
# Mean Dice per pathology. Only `dice` is averaged so that non-numeric columns do not make
# `.mean()` raise on pandas >= 2.0.
px.bar(dfs['mnm'].groupby('Pathology', as_index=False)['dice'].mean(), x='Pathology', y='dice')

In [None]:
# Positive Dice scores against Age / Height / Weight, one facet column per Sex.
px.scatter(
    dfs['mnm'].loc[lambda frame: frame.dice > 0],
    x=['Age', 'Height', 'Weight'],
    y='dice',
    facet_col='Sex',
)

##### Checking 100 images from dataset

In [None]:
# NOTE(review): disabled cell — `tagged_subjects` is not defined anywhere in this notebook, so this
# 20x5 grid of image/mask overlays cannot run as-is. Define it first, or delete the cell.
# M, N = 20, 5
# fig, ax = plt.subplots(M, N, figsize=(20, 100))

# for i in range(M * N):
#     m, n = i % M, i // M
#     ax[m, n].imshow(tagged_subjects[i].image.data[0][0].cpu(), cmap='gray')
    
#     mask = tagged_subjects[i].mask.data[0][0].cpu()
#     mask = np.ma.masked_where(mask == 0, mask)
#     ax[m, n].imshow(mask, cmap='Reds', alpha=0.8)
    
#     ax[m, n].axis('off')