In [1]:
import torch.nn as nn
import torch.nn.functional as F
import torch
import numpy as np
import sys
import re
import os
import torch.optim as optim
import time
import nibabel as nib
import matplotlib.pylab as plt
import math
from torch.utils.data import DataLoader
from tqdm import tqdm
from scipy import ndimage
from datetime import datetime
from glob import glob


In [2]:
base_dir = "./"
raw_dataset_dir = "dataset/"
transformed_dataset_dir_path = "dataset/affine_transformed/"

In [3]:
is_colab = True
if is_colab:
    base_dir = "/content/drive/MyDrive/Colab Notebooks/"
    if not os.path.isdir(base_dir):
        from google.colab import drive
        drive.mount('/content/drive')

raw_dataset_dir = os.path.join(base_dir, raw_dataset_dir)
transformed_dataset_dir_path = os.path.join(base_dir, transformed_dataset_dir_path)

if os.path.isdir(raw_dataset_dir) and os.path.isdir(transformed_dataset_dir_path):
    print("dataset folder exists, OK")
else:
    raise Exception("check path for dataset:{} \n path for transformed dataset: {}"
                    .format(raw_dataset_dir, transformed_dataset_dir_path))




dataset folder exists, OK


In [4]:
sys.path.append(base_dir)
from classes.dataset_utils.toTorchDataset import ProcessedKit23TorchDataset
from classes.models.unet3d import UNet3D
from classes.epoch_results import EpochResult

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
training_data = ProcessedKit23TorchDataset(train_data=True, test_size=0.25, dataset_dir =transformed_dataset_dir_path)
test_data = ProcessedKit23TorchDataset(train_data=False, test_size=0.25, dataset_dir =transformed_dataset_dir_path)
print("size of training data:{}    size of testing dat:{}".format(len(training_data), len(test_data)))

size of training data:366    size of testing dat:123


In [7]:
model = UNet3D(1, 4).to(device)

In [8]:
criterion = nn.CrossEntropyLoss(ignore_index=-1)
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

In [9]:
continue_from_checkpoint = True
epoch_res = EpochResult()
epoch_start = 0
if continue_from_checkpoint:
    print("Unet3D - loading from trained weight")
    # this continues from certain training points
    checkpoint_ref_filepath = "training_checkpoints/Model_UNET_epoch39.pth.tar"
    checkpoint_file = os.path.join(base_dir, checkpoint_ref_filepath)
    checkpoint = torch.load(checkpoint_file)
    model.load_state_dict(checkpoint['state_dict'])
    # load additional customised info from checkpoint
    optimizer.load_state_dict(checkpoint['optimizer'])
    ep_list = checkpoint['epoch_list']
    loss_list = checkpoint['loss_list']
    lr_list = checkpoint['lr_list']
    epoch_res = EpochResult(_epoch_list =ep_list, _loss_list=loss_list, _lr_list=lr_list)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
    epoch_start = epoch_res.epoch_list[-1] + 1
else:
    print("Unet3D - initialise weight")


Unet3D - loading from trained weight


In [10]:
batch_size = 4
total_batches = len(training_data) / batch_size
num_epochs = 50
model_unet_save_path = os.path.join(base_dir,"training_checkpoints/Model_UNET_epoch{}.pth.tar")
train_loader = DataLoader(training_data, batch_size=batch_size, shuffle=True, num_workers=2)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=2)

In [11]:
train_time_start = time.time()
batches_per_epoch = len(train_loader)
model.train()
for epoch in range(epoch_start, num_epochs):
    current_lr = scheduler.get_last_lr()
    for batch_idx, batch in enumerate(train_loader):
        images, masks = batch
        images, masks = images.to(device), masks.to(device)
        masks = masks.long().squeeze(1)

        optimizer.zero_grad()
        outputs = model(images.float()).float()
        loss = criterion(outputs, masks)
        running_loss = loss.item()
        loss.backward()
        optimizer.step()

        total_processed_batches = (epoch - epoch_start) * batches_per_epoch + 1 + batch_idx
        avg_batch_time = (time.time() - train_time_start) / total_processed_batches
        if batch_idx % 4 == 0:
            print("Epoch:{}/{} batch:{}/{}   Loss:{:.4f}  avg batch time:{:1f}".format(epoch, num_epochs, batch_idx, total_batches,running_loss, avg_batch_time))
    scheduler.step()
    epoch_res.append_result(epoch, running_loss, current_lr)
    model_checkpoint_path = model_unet_save_path.format(epoch)
    torch.save({'epoch_list': epoch_res.epoch_list, 'loss_list': epoch_res.loss_list,
                'lr_list': epoch_res.lr_list, 'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()},model_checkpoint_path, _use_new_zipfile_serialization=True)

print('Finished Training')


Epoch:40/50 batch:0/91.5   Loss:0.0144  avg batch time:16.387671
Epoch:40/50 batch:4/91.5   Loss:0.0283  avg batch time:5.401740
Epoch:40/50 batch:8/91.5   Loss:0.0180  avg batch time:4.396298


KeyboardInterrupt: ignored