##### Imports

In [None]:
from operator import itemgetter

import numpy as np
import matplotlib.pyplot as plt

import aim

from skimage.draw import disk

from medpy.metric.binary import dc

from monai.metrics import DiceMetric, HausdorffDistanceMetric, compute_meandice
from monai.transforms import AsDiscrete, EnsureType, Compose
from monai.data import decollate_batch
from monai.losses import DiceLoss
from monai.networks import one_hot

import plotly.express as px

import kornia.augmentation as K

import torch
from torch import nn
from torch.utils.data.dataloader import default_collate
from torch.utils.data import DataLoader, TensorDataset

from kedro.extras.datasets.pickle import PickleDataSet

In [None]:
import os, sys
sys.path.append(os.path.abspath('../src'))

from tagseg.models.segmenter import Net
from tagseg.models.trainer import Trainer
from tagseg.metrics.shape import ShapeDistLoss
from tagseg.pipelines.data_splitting.nodes import split_data

##### Fetch data

In [None]:
# Load the pickled model input. Unpickling can execute arbitrary code, so only
# point this at files produced locally / from a trusted source.
input_path = '../data/05_model_input/model_input.pt'
dataset = PickleDataSet(filepath=input_path).load()

In [None]:
data_params = dict(
    train_val_split=.5,
    batch_size=8
)

# Only the first n_samples examples are used, to keep debugging runs short.
n_samples = 128
# Build through the constructor so TensorDataset checks that every tensor has
# the same first dimension; assigning `.tensors` afterwards skips that check.
# Assumes dataset[:n] yields a tuple of tensors (e.g. dataset is a
# TensorDataset) — TODO confirm against the pipeline output.
ds = TensorDataset(*dataset[:n_samples])

loaders = split_data(ds, data_params)

loader = loaders['loader_val']

In [None]:
# Number of batches in each loader, labelled by loader name.
{name: len(ld) for name, ld in loaders.items()}

In [None]:
# Peek at the first validation batch and show image / label tensor shapes.
first_batch = next(iter(loaders['loader_val']))
img, lab = first_batch
(img.shape, lab.shape)

In [None]:
# Segmentation network with its optimiser hyper-parameters.
model = Net(weight_decay=1e-3, learning_rate=1e-2)

In [None]:
# Train on the first GPU when one is present; otherwise fall back to CPU so the
# notebook still runs. Mixed precision (amp) is only requested on CUDA, since
# presumably the Trainer's AMP path is CUDA-specific — TODO confirm in Trainer.
use_cuda = torch.cuda.is_available()
trainer = Trainer(
    epochs=10,
    device=torch.device('cuda:0' if use_cuda else 'cpu'),
    logger=aim.Run(experiment='Debugging'),
    amp=use_cuda,
)

In [None]:
# Unpack the train / validation loaders by key, then run training.
train_loader, val_loader = itemgetter('loader_train', 'loader_val')(loaders)
trainer.fit(model, train_loader, val_loader)

In [None]:
# A single sample from outside the first 128 examples used to build `ds`.
held_out_index = 150
image, label = dataset[held_out_index]

In [None]:
# Single-sample forward pass for inspection. Run it on whatever device the model
# currently lives on (the trainer may have moved it), without building an
# autograd graph. Assumes Net is a torch nn.Module exposing .parameters().
device = next(model.parameters()).device
with torch.no_grad():
    output = model.forward(image.unsqueeze(0).to(device))
# Keep the result on CPU so it can be inspected / converted to numpy directly.
output = output.cpu()

##### Look at input data

In [None]:
ims, las = next(iter(loader))

# Use the actual batch size rather than a hard-coded 8, so a short batch does
# not index past the end of `ims`.
bs = len(ims)
rows = 2
cols = (bs + rows - 1) // rows  # ceil division: every image gets a panel
# squeeze=False keeps `ax` 2-D even when cols == 1.
fig, ax = plt.subplots(rows, cols, figsize=(10, 10), squeeze=False)

for panel in ax.flat:
    panel.axis('off')

# Fill column by column: image i goes to row i % rows, column i // rows,
# with the label drawn as a translucent red overlay on the grayscale image.
for i in range(bs):
    m, n = i % rows, i // rows
    ax[m, n].imshow(ims[i, 0].numpy(), cmap='gray')
    ax[m, n].imshow(las[i, 0].numpy(), cmap='Reds', alpha=0.5)

##### Look at predictions from loader

In [None]:
proba: float = 0.2

# Each augmentation is applied independently with probability `proba`; the
# same geometric transform is applied to the image ("input") and its mask.
train_aug = K.AugmentationSequential(
    K.RandomHorizontalFlip(p=proba),
    K.RandomVerticalFlip(p=proba),
    K.RandomElasticTransform(p=proba),
    K.RandomGaussianNoise(p=proba),
    K.RandomSharpness(p=proba),
    K.RandomGaussianBlur(kernel_size=(3, 3), sigma=(0.1, 0.1), p=proba),
    data_keys=["input", "mask"],
)

# NOTE: this rebinds ims/las, so re-running the cell stacks augmentations on an
# already-augmented batch; re-run the batch-loading cell first for a fresh one.
ims, las = train_aug(ims, las)

# Run inference on the model's current device (the trainer may have moved it to
# the GPU) without tracking gradients. Autocast only matters on CUDA, so enable
# it there only. Assumes Net is a torch nn.Module exposing .parameters().
device = next(model.parameters()).device
with torch.no_grad(), torch.cuda.amp.autocast(enabled=device.type == 'cuda'):
    output = model.forward(ims.to(device))

# Bring the result back as CPU float32: .numpy() requires a CPU tensor, and the
# CPU losses/metrics evaluated against `las` below expect matching device/dtype.
output = output.float().cpu()
pred = output.numpy()

In [None]:
# One row per sample: label, class-0 output, class-1 output, argmax prediction.
n_show = len(pred)  # follow the batch size instead of assuming 8 rows
# 2.5 inches per row reproduces the original (10, 20) figure for 8 rows;
# squeeze=False keeps `ax` 2-D even for a single-sample batch.
fig, ax = plt.subplots(n_show, 4, figsize=(10, 2.5 * n_show), squeeze=False)

# Hard prediction per pixel, computed once for the whole batch from the numpy
# copy (works regardless of which device `output` lives on).
hard_pred = pred.argmax(axis=1)

for i in range(n_show):
    panels = (las[i, 0], pred[i, 0], pred[i, 1], hard_pred[i])
    for j, panel_img in enumerate(panels):
        ax[i, j].imshow(panel_img)
        ax[i, j].axis('off')

##### Test loss

In [None]:
# Batch loss according to the model's own configured loss function.
batch_loss = model.loss_fn(output, las)
batch_loss

In [None]:
# Standalone loss functions, evaluated per sample in the cells below.
# NOTE: the name `dc` shadows `dc` imported from medpy.metric.binary at the top
# of the notebook; from here on `dc` refers to this DiceLoss instance.
ce = nn.CrossEntropyLoss()
si = ShapeDistLoss(
    include_background=False,
    to_onehot_y=True,
    smooth_k=.2,
)
dc = DiceLoss(
    include_background=False,
    to_onehot_y=True,
    softmax=True,
)

In [None]:
# Compare prediction and label tensor shapes before computing losses.
(output.shape, las.shape)

In [None]:
# Foreground-only Dice, averaged on aggregate(); not-nan counts are not returned.
dcm = DiceMetric(
    reduction="mean",
    include_background=False,
    get_not_nans=False,
)

In [None]:
# Hard predictions: argmax over the class axis, one-hot encoded for MONAI.
# Taken directly on the logits — sigmoid is monotonic so it cannot change the
# argmax in exact arithmetic, and in half precision it can saturate distinct
# logits to the same 1.0 and create spurious ties.
y_pred = one_hot(output.argmax(dim=1).unsqueeze(1), num_classes=2)
y = one_hot(las, num_classes=2)

# Per-sample Dice. Every dcm(...) call also appends its result to the metric's
# internal buffer, so reset afterwards; otherwise a later dcm.aggregate() would
# count each sample twice (once here and once from the batched call).
per_sample_dice = [
    dcm(o.unsqueeze(0), l.unsqueeze(0)).item() for o, l in zip(y_pred, y)
]
dcm.reset()
per_sample_dice

In [None]:
# Functional (stateless) Dice on the same one-hot tensors, foreground only.
compute_meandice(y_pred=y_pred, y=y, include_background=False)

In [None]:
# Batched call: returns per-sample scores and accumulates them into dcm's buffer.
dcm(y_pred=y_pred, y=y)

In [None]:
# Mean Dice over everything accumulated in dcm since its last reset
# (get_not_nans=False, so aggregate() returns a single tensor).
a = dcm.aggregate().item()
print(a)
# Clear the accumulated buffer so later experiments start fresh. `a` is a plain
# Python float, so reset() leaves it unchanged; printing it again would only
# repeat the same number rather than show any post-reset state.
dcm.reset()

In [None]:
# Dice loss for each sample individually (a leading batch axis of 1 is added).
dice_losses = [dc(logits[None], target[None]).item() for logits, target in zip(output, las)]
dice_losses

In [None]:
# Cross-entropy per sample: logits get a batch axis; the label's leading
# singleton axis already plays the batch role for the integer target.
ce_losses = [ce(logits[None], target.long()).item() for logits, target in zip(output, las)]
ce_losses

In [None]:
# Shape-distance loss for each sample individually.
shape_losses = [si(logits[None], target[None]).item() for logits, target in zip(output, las)]
shape_losses