In [2]:
import math
import numpy as np
import torch
import pkbar
from unet3d.config import *
from tqdm import tqdm
from torch.nn import CrossEntropyLoss
from torch.nn.functional import one_hot
from torch.optim import Adam
#from torch.utils.tensorboard import SummaryWriter
from unet3d.unet3d import UNet3D
from patchify import patchify
import nrrd

%load_ext autoreload
%autoreload 2
%matplotlib inline

In [3]:
#writer = SummaryWriter("runs")
device = torch.device('cpu')
model = UNet3D(in_channels=IN_CHANNELS , num_classes= NUM_CLASSES).to(device)
criterion = CrossEntropyLoss(weight=torch.Tensor(np.array(CE_WEIGHTS)/np.array(CE_WEIGHTS).sum()))
optimizer = Adam(params=model.parameters())

min_valid_loss = math.inf

In [4]:
# Load some data
scan, _ = nrrd.read(DATASET_PATH + '/SAIAD 1/scan.nrrd')
segm, _ = nrrd.read(DATASET_PATH + '/SAIAD 1/segm.nrrd')

scan_patches = patchify(scan, PATCH_SIZE, step=PATCH_SIZE).reshape(-1,64,64,64)
scan_patches = torch.tensor(scan_patches).double()
scan_patches = torch.unsqueeze(scan_patches,1) # add channel dimension
segm_patches = patchify(segm, PATCH_SIZE, step=PATCH_SIZE).reshape(-1,64,64,64)
segm_patches = one_hot(torch.tensor(segm_patches).to(torch.int64), num_classes=NUM_CLASSES).permute(0,4,1,2,3).double()


print([scan_patches.shape, scan_patches.dtype])
print([segm_patches.shape, segm_patches.dtype])


[torch.Size([128, 1, 64, 64, 64]), torch.float64]
[torch.Size([128, 5, 64, 64, 64]), torch.float64]


In [5]:
# Training
model = model.double() # Convert all params to double

for epoch in range(TRAINING_EPOCH):
    # progress bar
    kbar = pkbar.Kbar(target=120, epoch=epoch, num_epochs=TRAINING_EPOCH, width=8, always_stateful=False)

    train_loss = 0.0
    model.train()
    for i in range(50,80):
        image, ground_truth = scan_patches[i:i+TRAIN_BATCH_SIZE], segm_patches[i:i+TRAIN_BATCH_SIZE]
        optimizer.zero_grad()
        target = model(image)
        loss = criterion(target, ground_truth)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        kbar.update(i, values=[("loss", train_loss)])
    
    valid_loss = 0.0
    model.eval()
    for i in range(120,128):
        image, ground_truth = scan_patches[i:i+TRAIN_BATCH_SIZE], segm_patches[i:i+TRAIN_BATCH_SIZE]
        
        target = model(image)
        loss = criterion(target,ground_truth)
        valid_loss = loss.item()
        
    #writer.add_scalar("Loss/Train", train_loss / 120, epoch)
    #writer.add_scalar("Loss/Validation", valid_loss / 8, epoch)
    
    print(f'Epoch {epoch+1} \t\t Training Loss: {train_loss / 8} \t\t Validation Loss: {valid_loss / 8}')
    
    if min_valid_loss > valid_loss:
        print(f'Validation Loss Decreased({min_valid_loss:.6f}--->{valid_loss:.6f}) \t Saving The Model')
        min_valid_loss = valid_loss
        # Saving State Dict
        torch.save(model.state_dict(), f'checkpoints/epoch{epoch}_valLoss{min_valid_loss}.pth')


Epoch: 1/100
 79/120 [====>...] - ETA: 5:47 - loss: 0.1730

NameError: name 'writer' is not defined

In [None]:
pred = model(scan_patches[58:59])

In [None]:
pred.size()

In [None]:
sys.path.insert(1, '../Helpers')
from Visualization import ImageSliceViewer3D
pred_index = np.array(torch.argmax(pred[0], dim=0))
ImageSliceViewer3D(pred_index, np.array(scan_patches[58:59][0,0]))

In [None]:
np.max(pred)

In [None]:
ImageSliceViewer3D(np.array(scan_patches[58:59][0,0]))