In [1]:
import torch.nn as nn
import torch.nn.functional as F
import torch
import numpy as np
import sys
import re
import os
import torch.optim as optim
import time
import nibabel as nib
import matplotlib.pylab as plt
import math
from torch.utils.data import DataLoader
from tqdm import tqdm
from scipy import ndimage
from datetime import datetime
from glob import glob


# 3D UNET
- The 3D UNet model is imported with <mark> from classes.models.unet3d import UNet3D </mark>
There are a few important parameters that are essential for extracting features well.
- The UNET model uses 3D convolution. It has 4 layers in the model.
- Default kernel size for Double convolution is 3 or (3x3x3)
- Number of feature channels: increasing the number of feature channels enables the model to predict more classes.
- channel selector 0: (4, 8, 16, 32, 64) failed to predict any class other than background.
- channel selector 1: (8, 16, 32, 64, 128) segments the kidney well. However, cyst and tumor could not be predicted successfully.


In [2]:
# Default project root; a later cell overrides it when running on Colab.
base_dir = "./"
# Dataset locations, written relative to base_dir (a later cell joins them with it).
raw_dataset_dir = "dataset/"
transformed_dataset_dir_path = "dataset/affine_transformed/"

In [3]:
# Switch for running on Google Colab, where base_dir points into Google Drive.
is_colab = True
if is_colab:
    base_dir = "/content/drive/MyDrive/Colab Notebooks/"
    base_dir_exists = os.path.isdir(base_dir)
    if not base_dir_exists:
        # Imported only when needed, so the import is skipped when base_dir is already available.
        from google.colab import drive
        drive.mount('/content/drive')

# Anchor both dataset locations at the resolved base directory.
raw_dataset_dir = os.path.join(base_dir, raw_dataset_dir)
transformed_dataset_dir_path = os.path.join(base_dir, transformed_dataset_dir_path)

# Fail fast when either dataset directory is missing; otherwise confirm and continue.
if not (os.path.isdir(raw_dataset_dir) and os.path.isdir(transformed_dataset_dir_path)):
    raise Exception("check path for dataset:{} \n path for transformed dataset: {}"
                    .format(raw_dataset_dir, transformed_dataset_dir_path))
print("dataset folder exists, OK")




dataset folder exists, OK


In [4]:
# Add base_dir to the module search path so the project-local `classes`
# package imported below can be resolved from it.
sys.path.append(base_dir)
from classes.dataset_utils.toTorchDataset import ProcessedKit23TorchDataset
from classes.models.unet3d import UNet3D
from classes.epoch_results import EpochResult

In [5]:
# Use the GPU when PyTorch reports CUDA as available; otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

In [6]:
# Both splits are built from the same directory with the same test_size;
# only the train_data flag differs.
split_kwargs = dict(test_size=0.25, dataset_dir=transformed_dataset_dir_path)
training_data = ProcessedKit23TorchDataset(train_data=True, **split_kwargs)
test_data = ProcessedKit23TorchDataset(train_data=False, **split_kwargs)
n_train, n_test = len(training_data), len(test_data)
print("size of training data:{}    size of testing dat:{}".format(n_train, n_test))

size of training data:366    size of testing dat:123


## Reduce Training Cases and Test Cases
- The following cell reduces the number of training and test cases.

In [7]:
is_simplified = True
# Demo mode: keep only the first 100 training cases and the first 10 test cases.
if is_simplified:
    for dataset, case_limit in ((training_data, 100), (test_data, 10)):
        # case_dirs and case_names are truncated together so they stay aligned.
        dataset.case_dirs = dataset.case_dirs[:case_limit]
        dataset.case_names = dataset.case_names[:case_limit]

In [8]:
# Architecture settings: channel preset index and double-convolution kernel size.
channel_selection = 2
ks = 5
# Positional arguments 1 and 4 are presumably the input channel count and the
# number of output classes -- TODO confirm against UNet3D's signature.
model = UNet3D(1, 4, channel_selection=channel_selection, double_conv_kernel_size=ks)
model = model.to(device)
model._initialize_weights()

## Optimizer or Gradient Descent Model
- Choose between Adam and SGD.
- Adjust the learning-rate decay manually: use a higher gamma when the dataset is larger. For 100 cases, gamma 0.95 is used.

In [9]:
# Optimizer switch: Adam when is_ADAM is True, otherwise SGD with momentum.
is_ADAM = True
if is_ADAM:
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0)
else:
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-3)
# Targets equal to -1 are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=-1)
# The learning rate is multiplied by gamma each time scheduler.step() is called.
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)

In [10]:
# Set to True to resume training from a previously saved checkpoint.
continue_from_checkpoint = False
epoch_res = EpochResult()
epoch_start = 0
if continue_from_checkpoint:
    print("Unet3D - loading from trained weight")
    # NOTE(review): these file names do not follow the
    # "Model_UNET_ch{}_ks{}[_SGD]_epoch{}.pth.tar" pattern used when checkpoints
    # are saved -- confirm they point at the intended checkpoint before resuming.
    if is_ADAM:
        checkpoint_ref_filepath = "training_checkpoints/Model_UNET_epoch40.pth.tar"
    else:
        checkpoint_ref_filepath = "training_checkpoints/Model_UNET_SGD_epoch40.pth.tar"
    checkpoint_file = os.path.join(base_dir, checkpoint_ref_filepath)
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint written on a GPU runtime can still be loaded on a CPU-only one.
    checkpoint = torch.load(checkpoint_file, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    # Restore the optimizer state and the per-epoch history stored with the weights.
    optimizer.load_state_dict(checkpoint['optimizer'])
    ep_list = checkpoint['epoch_list']
    loss_list = checkpoint['loss_list']
    lr_list = checkpoint['lr_list']
    epoch_res = EpochResult(_epoch_list =ep_list, _loss_list=loss_list, _lr_list=lr_list)
    # Rebuild the scheduler on top of the restored optimizer.
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
    # Resume with the epoch after the last one recorded in the checkpoint.
    epoch_start = epoch_res.epoch_list[-1] + 1
else:
    print("Unet3D - was initialised with weight")


Unet3D - was initialised with weight


## Training params
Batch size used  
- channel selector 0: batch size 6
- channel selector 1: batch size 3
- channel selector 2: batch size 1

In [11]:
batch_size = 1
num_epochs = 100
# Number of batches per epoch; the last batch may be smaller than batch_size.
total_batches = math.ceil(len(training_data) / batch_size)

# Checkpoint file-name template, filled later with (channel_selection, ks, epoch).
# SGD runs carry an extra "_SGD" tag in the file name.
if is_ADAM:
    checkpoint_template = "training_checkpoints/Model_UNET_ch{}_ks{}_epoch{}.pth.tar"
else:
    checkpoint_template = "training_checkpoints/Model_UNET_ch{}_ks{}_SGD_epoch{}.pth.tar"
model_unet_save_path = os.path.join(base_dir, checkpoint_template)

# Both loaders share batch size and worker count; only the training loader shuffles.
loader_kwargs = dict(batch_size=batch_size, num_workers=2)
train_loader = DataLoader(training_data, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_data, shuffle=False, **loader_kwargs)

## Training Loop
- Per-epoch validation during training is commented out. This is because training is extremely costly and the team has already used a Colab Tesla T4 GPU for the task.
- Please note that compute units are not free in Colab.

In [12]:
# Training loop: one pass over train_loader per epoch, then a scheduler step
# and a checkpoint containing the weights, optimizer state and epoch history.
train_time_start = time.time()
batches_per_epoch = len(train_loader)

for epoch in range(epoch_start, num_epochs):
    model.train()
    # Learning rate in effect for this epoch (read before scheduler.step()).
    current_lr = scheduler.get_last_lr()[0]
    # Sum of per-batch losses, so the recorded epoch loss is the mean over the
    # whole epoch rather than the loss of whichever batch happened to come last.
    epoch_loss_sum = 0.0
    for batch_idx, batch in enumerate(train_loader):
        images, masks = batch
        images, masks = images.to(device), masks.to(device)
        # CrossEntropyLoss needs integer class targets; squeeze(1) drops dim 1
        # of the mask (presumably a singleton channel axis -- TODO confirm).
        masks = masks.long().squeeze(1)
        optimizer.zero_grad()
        outputs = model(images.float())
        loss = criterion(outputs, masks)
        running_loss = loss.item()
        epoch_loss_sum += running_loss
        loss.backward()
        optimizer.step()

        # Average wall-clock time per batch since this cell started running.
        total_processed_batches = (epoch - epoch_start) * batches_per_epoch + 1 + batch_idx
        avg_batch_time = (time.time() - train_time_start) / total_processed_batches
        if batch_idx % 5 == 0:
            print("Epoch:{}/{} batch:{}/{}   Loss:{:.4f}  avg batch time:{:.1f} LR={:.6f}".format(epoch, num_epochs, batch_idx, batches_per_epoch, running_loss, avg_batch_time, current_lr))
    scheduler.step()
    # Mean training loss of the epoch; NaN when the loader produced no batches.
    mean_epoch_loss = epoch_loss_sum / batches_per_epoch if batches_per_epoch else float("nan")
    epoch_res.append_result(epoch, mean_epoch_loss, current_lr)
    model_checkpoint_path = model_unet_save_path.format(channel_selection, ks, epoch)
    # Create the checkpoint directory if needed so torch.save cannot fail on a missing folder.
    os.makedirs(os.path.dirname(model_checkpoint_path), exist_ok=True)
    torch.save({'epoch_list': epoch_res.epoch_list, 'loss_list': epoch_res.loss_list,
                'lr_list': epoch_res.lr_list, 'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()},model_checkpoint_path, _use_new_zipfile_serialization=True)
    # Validation after each epoch (disabled to save compute)
    # model.eval()
    # total_loss = 0.0
    # with torch.no_grad():
    #     for batch in test_loader:
    #         images, masks = batch
    #         images, masks = images.to(device), masks.to(device)
    #         masks = masks.long().squeeze(1)

    #         optimizer.zero_grad()
    #         outputs = model(images.float())
    #         loss = criterion(outputs, masks)
    #         total_loss += loss.item()

    # average_loss = total_loss / len(test_loader)
    # print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {average_loss:.4f}")

print('Finished Training')


Epoch:0/100 batch:0/100   Loss:1.9070  avg batch time:5.3 LR=0.001000
Epoch:0/100 batch:5/100   Loss:1.1907  avg batch time:3.0 LR=0.001000
Epoch:0/100 batch:10/100   Loss:0.7994  avg batch time:2.7 LR=0.001000
Epoch:0/100 batch:15/100   Loss:0.6093  avg batch time:2.6 LR=0.001000
Epoch:0/100 batch:20/100   Loss:0.6072  avg batch time:2.6 LR=0.001000
Epoch:0/100 batch:25/100   Loss:0.4412  avg batch time:2.6 LR=0.001000
Epoch:0/100 batch:30/100   Loss:0.4188  avg batch time:2.6 LR=0.001000
Epoch:0/100 batch:35/100   Loss:0.3485  avg batch time:2.6 LR=0.001000
Epoch:0/100 batch:40/100   Loss:0.3406  avg batch time:2.6 LR=0.001000
Epoch:0/100 batch:45/100   Loss:0.3147  avg batch time:2.5 LR=0.001000
Epoch:0/100 batch:50/100   Loss:0.3212  avg batch time:2.5 LR=0.001000
Epoch:0/100 batch:55/100   Loss:0.2538  avg batch time:2.5 LR=0.001000
Epoch:0/100 batch:60/100   Loss:0.2424  avg batch time:2.5 LR=0.001000
Epoch:0/100 batch:65/100   Loss:0.2334  avg batch time:2.5 LR=0.001000
Epoch:0/

KeyboardInterrupt: ignored