In [3]:
from typing import Tuple
from pathlib import Path
from tqdm.notebook import tqdm

import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import Dataset, random_split, DataLoader

In [4]:
import os, sys
# Add the parent of the current working directory to the import path so that
# the `src` package imported below can be resolved.
sys.path.append(os.path.abspath('..'))

# Project-local modules: network, metrics, training helper, and dataset.
from src.models.unet import UNet
from src.training.metrics import evaluate, dice_loss, dice
from src.training.train import train
from src.data.datasets import ACDCDataset

In [10]:
# Build the ACDC dataset from the raw training folder.
# NOTE(review): recompute=True presumably forces re-processing on every run
# (the recorded output shows ~41 s of processing) -- confirm whether a cached
# load is possible once the checkpoint file exists.
dataset = ACDCDataset(
    path='../../training/',
    recompute=True,
    tagged=False,
    verbose=1,
)

Processing patient002...: 100%|██████████| 100/100 [00:41<00:00,  2.39it/s]


Saved dataset of 952 images to /worskpace/dev/tagroi/checkpoints/acdc_dataset_cine.pt


In [4]:
# Tiny train/validation subsets for a quick sanity run; the remaining samples
# are split off and discarded. The size of the unused part is derived from the
# dataset itself so the split keeps working if the number of images changes.
n_train, n_val = 8, 4
n_unused = len(dataset) - n_train - n_val
if n_unused < 0:
    raise ValueError(
        f'dataset has {len(dataset)} samples, need at least {n_train + n_val}')

# Fixed seed so the same samples are selected on every run.
train_set, val_set, _ = random_split(
    dataset, [n_train, n_val, n_unused],
    generator=torch.Generator().manual_seed(42))
loader_train = DataLoader(train_set, batch_size=4, shuffle=True)
loader_val = DataLoader(val_set, batch_size=4, shuffle=False)

In [5]:
# Prefer the GPU when one is visible, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Single-channel input, four output classes, float64 parameters.
model = UNet(n_channels=1, n_classes=4, bilinear=True).double()

if device.type == 'cuda':
    # Wrap for multi-GPU execution and move to the device. The wrapper hides
    # the inner module's attributes, so re-expose `n_classes` on it.
    model = nn.DataParallel(model).to(device)
    model.n_classes = model.module.n_classes

In [6]:
# Optimizer hyper-parameters (consumed by the RMSprop optimizer).
# NOTE(review): 0.1 is a large step size for RMSprop; in the recorded run the
# loss does not decrease steadily and the Dice score plateaus -- consider a
# smaller value.
learning_rate = 0.1
weight_decay = 0.001
momentum = 0.9

# Toggle for automatic mixed precision (autocast + gradient scaling).
amp = True

In [7]:
# Pixel-wise classification loss; a Dice term is added to it in the training loop.
criterion = nn.CrossEntropyLoss()

optimizer = optim.RMSprop(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
    momentum=momentum,
)

# 'max' mode: the scheduler is stepped with the validation Dice (higher is
# better) and lowers the LR after 2 epochs without improvement.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=2)

# Loss scaler for mixed-precision training; a no-op when `amp` is False.
grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)

In [8]:
# Train for a few epochs on the small subset, reporting the summed training
# loss, the mean training Dice, and the mean validation Dice after each epoch.
for epoch in range(4):

    # Running sums over the batches of this epoch.
    dice_score = torch.zeros(4)
    acc_loss = 0.

    model.train()

    batch_pbar = tqdm(loader_train, total=len(loader_train), unit='batch', leave=False)
    # Iterate over the progress bar itself (not the bare loader) so that it
    # actually advances with every batch.
    for inputs, targets in batch_pbar:

        batch_pbar.set_description(f'Acummulated loss: {acc_loss:.4f}')
        # move to device
        # target is index of classes
        inputs, targets = inputs.double().to(device), targets.long().to(device)

        # Only the forward pass and the loss run under autocast; the backward
        # pass and the optimizer step are performed outside of it.
        with torch.cuda.amp.autocast(enabled=amp):
            outputs = model(inputs)
            loss = criterion(outputs, targets) + \
                dice_loss(F.softmax(outputs, dim=1), targets, exclude_bg=True)

        optimizer.zero_grad(set_to_none=True)
        grad_scaler.scale(loss).backward()
        grad_scaler.step(optimizer)
        grad_scaler.update()

        # Metric bookkeeping needs no gradients; disabling autograd keeps the
        # running sum from holding on to each batch's computation graph.
        with torch.no_grad():
            dice_score += dice(F.softmax(outputs, dim=1), targets)
        acc_loss += loss.item()

    # Tracking training performance
    train_perf = dice_score / len(loader_train)
    avg_dice = train_perf.mean()

    status = f'Epoch {epoch:03} \t Loss {acc_loss:.4f} \t Dice {avg_dice:.4f}'

    # Tracking validation performance
    val_perf = evaluate(model, loader_val, device)
    avg_val_dice = val_perf.mean()
    # The scheduler was created in 'max' mode, so it is fed the validation Dice.
    scheduler.step(avg_val_dice)

    status += f'\t Val. Dice {avg_val_dice:.4f}'

    print(status)

  0%|          | 0/2 [00:00<?, ?batch/s]



Epoch 000 	 Loss 6.9501 	 Dice 0.1680	 Val. Dice 0.2463


  0%|          | 0/2 [00:00<?, ?batch/s]


[A
[A
[A

Epoch 001 	 Loss 9.4449 	 Dice 0.2440	 Val. Dice 0.2463


  0%|          | 0/2 [00:00<?, ?batch/s]



[A[A

[A[A

[A[A

Epoch 002 	 Loss 3.5048 	 Dice 0.2440	 Val. Dice 0.2463


  0%|          | 0/2 [00:00<?, ?batch/s]




[A[A[A


[A[A[A


[A[A[A

Epoch 003 	 Loss 4.7997 	 Dice 0.2440	 Val. Dice 0.2463
