In [None]:
from typing import Tuple
from pathlib import Path
from tqdm.notebook import tqdm

import numpy as np
import matplotlib.pyplot as plt

import aim
import PIL

import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import Dataset, random_split, DataLoader

In [None]:
import os, sys
# Put the parent directory on the import path so the project-local `src.*` imports resolve.
sys.path.append(os.path.abspath('..'))

from src.models.unet import UNet
from src.training.train import train
from src.data.datasets import ACDCDataset
from src.training.metrics import dice_score, DiceLoss, evaluate

In [None]:
# Build the dataset from the data folder located two directory levels up.
# NOTE(review): `tagged=True` / `verbose=1` are ACDCDataset-specific switches whose exact
# meaning is not visible from this notebook — confirm against src/data/datasets.py.
dataset = ACDCDataset(path='../../training/', tagged=True, verbose=1)

In [None]:
# Reproducible 704 / 248 train/validation split driven by a fixed-seed generator.
split_generator = torch.Generator()
split_generator.manual_seed(42)
train_set, val_set = random_split(dataset, lengths=[704, 248], generator=split_generator)

# Validation batches keep their order; only the training batches are shuffled.
loader_val = DataLoader(val_set, batch_size=8, shuffle=False)
loader_train = DataLoader(train_set, batch_size=32, shuffle=True)

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The network is cast to float64, so every batch fed to it must be `.double()` too.
model = UNet(n_channels=1, n_classes=4, bilinear=True).double()
# Load old saved version of the model (the loaded object is used as a state dict below).
# map_location='cpu' lets a checkpoint written on a GPU machine load on a CPU-only one;
# the model itself is still on the CPU here and is moved to `device` in a later cell.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
saved_model = torch.load('../checkpoints/model/model_cine_tag_v1_sd.pt', map_location='cpu')
# Extract UNet if saved model is parallelized
# if isinstance(saved_model, nn.DataParallel):
    # saved_model = saved_model.module
model.load_state_dict(saved_model)

# if device.type == 'cuda':
    # model = nn.DataParallel(model)
    # model.n_classes = model.module.n_classes

In [None]:
# List the entries stored in the loaded checkpoint.
# (The stray closing parenthesis that made this cell a SyntaxError has been removed.)
saved_model.keys()

In [None]:
# Move the model's parameters and buffers onto the selected device.
model = model.to(device)

In [None]:
# Take the first validation batch.
images, targets = next(iter(loader_val))
# float64 images match the .double() model; the targets are integer class indices.
images = images.double().to(device)
targets = targets.long().to(device)
# Forward pass; softmax/argmax are applied separately when a label map is needed.
outputs = model(images)

In [None]:
# Show the target mask of sample 5 from the current batch.
plt.imshow(targets[5].cpu().numpy())

In [None]:
# Render the same mask as an 8-bit image; class indices are scaled by 256 // 4 == 64.
PIL.Image.fromarray(targets[5].cpu().numpy().astype(np.uint8) * (256 // 4))

In [None]:
# Wrap the scaled 8-bit mask (class index * 64) in an aim.Image.
aim.Image(targets[5].cpu().numpy().astype(np.uint8) * (256 // 4))

In [None]:
# Raw class-index values of the first target mask in the batch.
targets[0].cpu().numpy()

In [None]:
# Score the loaded model on the training loader.
# NOTE(review): evaluate() lives in src.training.metrics; whether it switches the model
# between train/eval mode is not visible here — verify before relying on the mode afterwards.
evaluate(model, loader_train, device)

In [None]:
# Fetch one (shuffled) training batch and cast it the same way the training loop does.
images, targets = next(iter(loader_train))
images = images.double().to(device)
targets = targets.long().to(device)

In [None]:
# Forward pass on the training batch fetched above.
# NOTE(review): the eval() call below is commented out, so the mode used here is whatever
# earlier cells left the model in — confirm which mode is intended.
# model.eval()
outputs = model(images)

In [None]:
# Per-pixel label map: softmax over the class axis (dim 1), then the most likely class.
prediction = torch.argmax(F.softmax(outputs, dim=1), dim=1)

In [None]:
# without model.eval()
# Scores the `outputs` already held in the kernel; no forward pass happens in this cell.
# NOTE(review): which mode produced `outputs` depends on the cells run before — verify.
dice_score(outputs, targets)

In [None]:
# with model.eval()
# Switch to evaluation mode and redo the forward pass inside this cell, so the score is
# really computed in eval mode instead of depending on which cells were re-run by hand.
# The model stays in eval mode afterwards; the training loop calls model.train() itself.
model.eval()
dice_score(model(images), targets)

In [None]:
# no grad
# Redo the forward pass with autograd disabled, so the score is actually computed
# without gradient tracking rather than reusing a stale `outputs`.
with torch.no_grad():
    dice_no_grad = dice_score(model(images), targets)
dice_no_grad

In [None]:
# One row per sample of the batch: input image | target mask | predicted mask.
# The grid is sized from the batch itself; the previous fixed 16 rows raised an
# IndexError for the 32-sample training batches. Row height stays at 2.5 in
# (the original 40 in / 16 rows).
n_rows = images.shape[0]
# squeeze=False keeps `ax` two-dimensional even if the batch holds a single sample.
fig, ax = plt.subplots(n_rows, 3, figsize=(10, 2.5 * n_rows), squeeze=False)

for i in range(n_rows):
    # images[i, 0] selects the single input channel of sample i.
    panels = (images[i, 0], targets[i], prediction[i])
    for j, panel in enumerate(panels):
        ax[i, j].imshow(panel.detach().cpu().numpy())
        ax[i, j].axis('off')


In [None]:
# Score the model on the training loader again, right before the fine-tuning cells.
evaluate(model, loader_train, device)

In [None]:
# Optimisation hyper-parameters.
# NOTE(review): these four names were used below but never assigned anywhere in the
# visible notebook, so a fresh kernel failed with NameError. The values are placeholders
# picked during review — confirm them against the original training configuration.
learning_rate = 1e-5
weight_decay = 1e-8
momentum = 0.9
# Mixed precision stays off: the model is cast to float64 with .double().
amp = False

# Training objective is cross-entropy plus a Dice term (combined in the training loop).
criterion = nn.CrossEntropyLoss()
dice_criterion = DiceLoss()
optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate, weight_decay=weight_decay, momentum=momentum)
# 'max' mode: the scheduler is stepped with the validation Dice, which should increase;
# the learning rate is reduced after `patience` epochs without improvement.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=2)
# With enabled=False the scaler is a pass-through, so the loop works with or without AMP.
grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)

In [None]:
# Fine-tune for 4 epochs, reporting summed loss, mean training Dice and validation Dice.
for epoch in range(4):

    # Per-epoch accumulators: one Dice slot per class and the summed batch loss.
    dice = torch.zeros(4)
    acc_loss = 0.

    model.train()

    batch_pbar = tqdm(loader_train, total=len(loader_train), unit='batch', leave=False)
    # Iterate over the progress bar itself; looping over `loader_train` directly left
    # the bar stuck at zero.
    for inputs, targets in batch_pbar:

        batch_pbar.set_description(f'Acummulated loss: {acc_loss:.4f}')
        # move to device
        # target is index of classes
        inputs, targets = inputs.double().to(device), targets.long().to(device)

        # Only the forward pass and the loss belong inside autocast.
        with torch.cuda.amp.autocast(enabled=amp):
            outputs = model(inputs)
            loss = criterion(outputs, targets) + \
                dice_criterion(outputs, targets, exclude_bg=True)

        # Backward pass and optimiser step run outside the autocast context.
        optimizer.zero_grad(set_to_none=True)
        grad_scaler.scale(loss).backward()
        grad_scaler.step(optimizer)
        grad_scaler.update()

        # Detach and move to CPU before accumulating: `dice` lives on the CPU, and keeping
        # the autograd graph of every batch alive would grow memory over the epoch.
        # NOTE(review): assumes dice_score returns a tensor with one value per class — confirm.
        with torch.no_grad():
            dice += dice_score(outputs, targets).detach().cpu()
        acc_loss += loss.item()

    # Tracking training performance
    train_perf = dice / len(loader_train)
    avg_dice = train_perf.mean()

    status = f'Epoch {epoch:03} \t Loss {acc_loss:.4f} \t Dice {avg_dice:.4f}'

    # Tracking validation performance
    val_perf = evaluate(model, loader_val, device)
    avg_val_dice = val_perf.mean()
    # The plateau scheduler watches the mean validation Dice.
    scheduler.step(avg_val_dice)

    status += f'\t Val. Dice {avg_val_dice:.4f}'

    print(status)

In [None]:
# Fetch a fresh training batch and run it through the model; the following cells
# evaluate the individual loss terms on this `output` / `mask` pair.
image, mask = next(iter(loader_train))
output = model(image.double().to(device))

In [None]:
# Cross-entropy term of the training loss on this batch (functional form, default settings).
F.cross_entropy(output, mask.long().to(device))

In [None]:
# Dice loss on the same batch, constructed with exclude_bg=True.
# NOTE(review): the training loop builds DiceLoss() without arguments and passes
# exclude_bg=True at call time instead — confirm which form DiceLoss actually supports.
DiceLoss(exclude_bg=True)(output, mask.long().to(device))

In [None]:
# Dice score from the project's metric function, for comparison with the manual
# computation in the cells below.
dice_score(output, mask.long().to(device))

In [None]:
# Class probabilities: softmax over the class axis (dim 1) of the model output.
input_soft = output.softmax(dim=1)

In [None]:
# NOTE(review): mid-notebook import — consider moving it to the import cell at the top.
from kornia.utils.one_hot import one_hot

# One-hot encode the integer mask with 4 classes, on `device` and in the dtype of `output`,
# so it can be combined element-wise with the softmax probabilities.
target_one_hot = one_hot(mask.long().to(device), 4, device, output.dtype)

In [None]:
dims = (2, 3)
intersection = torch.sum(input_soft * target_one_hot, dims)
cardinality = torch.sum(input_soft + target_one_hot, dims)

dice_score = 2.0 * intersection / (cardinality + 1e-8)

In [None]:
per_class = dice_score.mean(dim=0)
per_class

In [None]:
# Dice loss (1 - Dice) averaged over every class except index 0
# (presumably the background, mirroring exclude_bg=True).
(1. - per_class[1:]).mean()