In [1]:
from typing import Tuple
from pathlib import Path
from tqdm.notebook import tqdm

import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import Dataset, random_split, DataLoader

import aim

In [11]:
import os, sys
sys.path.append(os.path.abspath('..'))

from src.models.unet import UNet
from src.training.metrics import evaluate, dice_loss, dice
from src.training.train import train
from src.data.datasets import ACDCDataset

In [4]:
dataset = ACDCDataset(path='../../training/', verbose=1)

Loaded saved dataset from /worskpace/tagroi/checkpoints/acdc_dataset.pt


In [9]:
train_set, val_set, _ = random_split(dataset, [8, 4, 940], generator=torch.Generator().manual_seed(42))
loader_train = DataLoader(train_set, batch_size=4, shuffle=True)
loader_val = DataLoader(val_set, batch_size=4, shuffle=False)

In [13]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = UNet(n_channels=1, n_classes=4, bilinear=True).double()

if device.type == 'cuda':
    model = nn.DataParallel(model).to(device)
    model.n_classes = model.module.n_classes

In [14]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

In [15]:
dice_score = torch.zeros(4)

In [19]:
for epoch in range(40):
    
    acc_loss = 0.

    model.train()

    for inputs, targets in loader_train:
        # move to device
        # target is index of classes
        inputs, targets = inputs.to(device), targets.long().to(device)
        optimizer.zero_grad()
        
        outputs = model(inputs)
        loss = criterion(outputs, targets) + dice_loss(outputs, targets, True)
        
        dice_score += dice(outputs, targets)
        acc_loss += loss.item()

        loss.backward()
        optimizer.step()

    print(f'======= Epoch {epoch}')
    print(f'Accumulated loss {acc_loss}')

    train_perf = dice_score / len(loader_train)
    avg_dice = train_perf.mean()
    print(f'Training performance {train_perf}, {avg_dice}')

    val_perf = evaluate(model, loader_val, device)
    avg_val_dice = val_perf.mean()
    print(f'Validation performance {val_perf}, {avg_val_dice}')

Accumulated loss 0.8283676990545923
Training performance tensor([44.5099, 12.0919,  4.7984, 11.5418]), 18.235490798950195


                                                                                         

Validation performance tensor([0.9851, 0.0000, 0.0000, 0.0000]), 0.24627509713172913
Accumulated loss 1.032632250051536
Training performance tensor([45.5054, 12.3852,  5.3171, 12.1109]), 18.829662322998047


                                                                                         

Validation performance tensor([0.9568, 0.0658, 0.0000, 0.0000]), 0.2556353807449341
Accumulated loss 1.0901274357167152
Training performance tensor([46.5008, 12.6119,  5.8070, 12.6817]), 19.400360107421875


                                                                                         

Validation performance tensor([0.9851, 0.0000, 0.0000, 0.0000]), 0.24627509713172913
Accumulated loss 1.087571846049049
Training performance tensor([47.4928, 12.8928,  6.1683, 13.3587]), 19.97814178466797


                                                                                         

Validation performance tensor([0.9851, 0.0000, 0.0000, 0.0000]), 0.24627509713172913
Accumulated loss 1.0583362358629944
Training performance tensor([48.4848, 13.2248,  6.6707, 13.9041]), 20.57109260559082


                                                                                         

Validation performance tensor([0.9851, 0.0000, 0.0000, 0.0000]), 0.24627509713172913
Accumulated loss 1.1455142620910275
Training performance tensor([49.4775, 13.5555,  7.0586, 14.3917]), 21.120819091796875


                                                                                         

Validation performance tensor([0.9698, 0.0000, 0.0000, 0.0000]), 0.24243803322315216
Accumulated loss 1.0149418322802188
Training performance tensor([50.4698, 13.9595,  7.5345, 14.9657]), 21.732357025146484


                                                                                         

Validation performance tensor([0.9851, 0.0000, 0.0000, 0.0000]), 0.24627509713172913
Accumulated loss 0.8977391808722479
Training performance tensor([51.4641, 14.4997,  8.0243, 15.5736]), 22.390419006347656


                                                                                         

Validation performance tensor([0.9841, 0.0000, 0.0000, 0.0000]), 0.24601994454860687
Accumulated loss 0.7676395275313795
Training performance tensor([52.4573, 15.1806,  8.5943, 16.2119]), 23.111032485961914


                                                                                         

Validation performance tensor([0.9741, 0.0000, 0.0000, 0.0000]), 0.24352462589740753
Accumulated loss 0.7692211616290908
Training performance tensor([53.4518, 15.8329,  9.1497, 16.8746]), 23.827251434326172


                                                                                         

Validation performance tensor([0.9667, 0.0000, 0.0000, 0.0000]), 0.24168749153614044
Accumulated loss 0.8400401719165698
Training performance tensor([54.4468, 16.4048,  9.6122, 17.5421]), 24.501497268676758


                                                                                         

Validation performance tensor([0.9694, 0.0000, 0.0000, 0.0000]), 0.24235956370830536
Accumulated loss 0.8175725096366172
Training performance tensor([55.4415, 16.9442, 10.1787, 18.1922]), 25.18915367126465


                                                                                         

Validation performance tensor([0.9702, 0.0109, 0.0355, 0.0000]), 0.25415781140327454
Accumulated loss 0.8977398665321363
Training performance tensor([56.4356, 17.4718, 10.7613, 18.6913]), 25.8399658203125


                                                                                         

Validation performance tensor([0.9826, 0.0000, 0.0000, 0.0000]), 0.2456551194190979
Accumulated loss 1.0200799499578777
Training performance tensor([57.4212, 17.8620, 11.2775, 19.2960]), 26.464176177978516


                                                                                         

Validation performance tensor([0.9709, 0.0766, 0.0655, 0.1482]), 0.3152843117713928
Accumulated loss 0.7351758282744126
Training performance tensor([58.4139, 18.4856, 11.8638, 20.0083]), 27.192873001098633


                                                                                         

Validation performance tensor([0.9707, 0.0000, 0.0000, 0.0000]), 0.24266959726810455
Accumulated loss 0.9849214536034709
Training performance tensor([59.4050, 19.0154, 12.1786, 20.6410]), 27.80998992919922


                                                                                         

Validation performance tensor([9.7785e-01, 9.2251e-03, 0.0000e+00, 5.6227e-04]), 0.24691037833690643
Accumulated loss 0.8737969637295273
Training performance tensor([60.3971, 19.6042, 12.5996, 21.2904]), 28.472810745239258


                                                                                         

Validation performance tensor([0.9851, 0.0000, 0.0000, 0.0000]), 0.24627509713172913
Accumulated loss 0.8397681078475018
Training performance tensor([61.3902, 20.0625, 13.1459, 21.9956]), 29.1485595703125


                                                                                         

Validation performance tensor([0.9848, 0.0408, 0.0000, 0.0000]), 0.2563990354537964
Accumulated loss 0.8143875165502492
Training performance tensor([62.3830, 20.6866, 13.6300, 22.6498]), 29.83734130859375


                                                                                         

Validation performance tensor([0.9848, 0.0000, 0.0000, 0.0000]), 0.24621020257472992
Accumulated loss 0.9283637831343667
Training performance tensor([63.3755, 21.2576, 14.1406, 23.0878]), 30.465391159057617


                                                                                         

Validation performance tensor([0.9803, 0.0000, 0.0000, 0.0000]), 0.24508239328861237
Accumulated loss 0.7938802548243492
Training performance tensor([64.3700, 21.8199, 14.7156, 23.7199]), 31.156354904174805


                                                                                         

Validation performance tensor([0.9760, 0.0000, 0.0000, 0.0000]), 0.24399150907993317
Accumulated loss 0.7361454228123092
Training performance tensor([65.3647, 22.4337, 15.2538, 24.4361]), 31.872055053710938


                                                                                         

Validation performance tensor([0.9618, 0.1286, 0.0604, 0.1028]), 0.31341126561164856
Accumulated loss 0.7040246859217633
Training performance tensor([66.3591, 23.1392, 15.7567, 25.1426]), 32.599388122558594


                                                                                         

Validation performance tensor([0.9623, 0.0000, 0.0000, 0.0000]), 0.24058303236961365
Accumulated loss 0.5439285152600883
Training performance tensor([67.3550, 23.9073, 16.4358, 25.8982]), 33.399078369140625


                                                                                         

Validation performance tensor([0.9758, 0.0000, 0.0000, 0.0000]), 0.24395450949668884
Accumulated loss 0.5300298683996058
Training performance tensor([68.3510, 24.7005, 17.1528, 26.6159]), 34.20504379272461


                                                                                         

Validation performance tensor([0.9545, 0.0858, 0.0649, 0.0775]), 0.2956569194793701
Accumulated loss 0.5014237571100146
Training performance tensor([69.3475, 25.5205, 17.8069, 27.4020]), 35.019203186035156


                                                                                         

Validation performance tensor([0.9809, 0.0000, 0.0000, 0.0000]), 0.24521568417549133
Accumulated loss 0.46136526158211866
Training performance tensor([70.3446, 26.3587, 18.4848, 28.2026]), 35.84768295288086


                                                                                         

Validation performance tensor([0.9851, 0.0000, 0.0000, 0.0000]), 0.24627509713172913
Accumulated loss 0.4290432355345166
Training performance tensor([71.3415, 27.1906, 19.2219, 29.0127]), 36.69165802001953


                                                                                         

Validation performance tensor([0.9851, 0.0000, 0.0000, 0.0000]), 0.24627509713172913
Accumulated loss 0.4289609405988998
Training performance tensor([72.3382, 28.0382, 19.9821, 29.7749]), 37.53334045410156


                                                                                         

Validation performance tensor([0.9693, 0.0850, 0.0584, 0.1210]), 0.3084093928337097
Accumulated loss 0.38863115219559147
Training performance tensor([73.3353, 28.9231, 20.6977, 30.6181]), 38.39354705810547


                                                                                         

Validation performance tensor([0.9621, 0.0021, 0.0000, 0.0000]), 0.2410452663898468
Accumulated loss 0.39194855192567535
Training performance tensor([74.3327, 29.8092, 21.4156, 31.4400]), 39.249385833740234


                                                                                         

Validation performance tensor([0.9851, 0.0000, 0.0000, 0.0000]), 0.24627509713172913
Accumulated loss 0.36381415205801426
Training performance tensor([75.3297, 30.6921, 22.1700, 32.2841]), 40.118988037109375


                                                                                         

Validation performance tensor([0.9851, 0.0000, 0.0000, 0.0000]), 0.24627509713172913
Accumulated loss 0.27423658896085046
Training performance tensor([76.3274, 31.6150, 22.9739, 33.1938]), 41.027523040771484


                                                                                         

Validation performance tensor([0.9851, 0.0000, 0.0000, 0.0000]), 0.24627509713172913
Accumulated loss 0.31567123565975774
Training performance tensor([77.3248, 32.5015, 23.7438, 34.0971]), 41.916786193847656


                                                                                         

Validation performance tensor([0.9851, 0.0000, 0.0000, 0.0000]), 0.24627509713172913
Accumulated loss 0.28473805494503973
Training performance tensor([78.3221, 33.3996, 24.5405, 35.0164]), 42.819644927978516


                                                                                         

Validation performance tensor([0.9851, 0.0000, 0.0000, 0.0000]), 0.24627509713172913
Accumulated loss 0.25933002618555717
Training performance tensor([79.3198, 34.3185, 25.3525, 35.9387]), 43.73237609863281


                                                                                         

Validation performance tensor([0.9851, 0.0000, 0.0000, 0.0000]), 0.24627509713172913
Accumulated loss 0.2859986650074239
Training performance tensor([80.3173, 35.2179, 26.1549, 36.8359]), 44.631507873535156


                                                                                         

Validation performance tensor([0.9851, 0.0000, 0.0000, 0.0000]), 0.24627509713172913
Accumulated loss 0.24214123561873727
Training performance tensor([81.3151, 36.1400, 26.9825, 37.7624]), 45.55000305175781


                                                                                         

Validation performance tensor([0.9851, 0.0000, 0.0000, 0.0000]), 0.24627509713172913
Accumulated loss 0.23200964322905798
Training performance tensor([82.3129, 37.0566, 27.8264, 38.6968]), 46.47315216064453


                                                                                         

Validation performance tensor([0.9851, 0.0000, 0.0000, 0.0000]), 0.24627509713172913
Accumulated loss 0.22857779994756433
Training performance tensor([83.3106, 37.9857, 28.6558, 39.6358]), 47.396968841552734


                                                                                         

Validation performance tensor([0.9851, 0.0000, 0.0000, 0.0000]), 0.24627509713172913


