Workflow

1. Split the dataset into training, validation, and test sets by randomly moving some percentage of the provided data into a validation directory and a test directory—split that data evenly between the two, but keep the training set the largest. Make sure that the training, validation, and test images do not overlap.
2. Use the DataLoader class to create the loading mechanism for the training and validation data using the Dataset class built in Step 2.
3. Build a training loop using MSE as the loss function. Choose an optimizer (either SGD or Adam).
4. Instantiate a model and train the network with the created routine.

In [1]:
import os
import numpy as np
from nifti_dataset import NiftiDataset, RandomCrop3D, ToTensor
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision import transforms

import matplotlib.pyplot as plt
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
# Root directory of the paired volumes. The path is machine-specific.
input_dir = "/home/nbaranov/projects/04_cv/MedicalImageAnalysis/data/small_data/small/"
# Presumably a matplotlib figure size (width, height); not used in the cells shown here.
f_size = (14, 8)

# Each modality is read from its own sub-directory of input_dir.
t1_dir, t2_dir = (os.path.join(input_dir, modality) for modality in ('t1', 't2'))

In [3]:
# Hyper-parameters shared by the data-loading and training cells.
n_epochs = 50      # number of passes of the training loop
batch_size = 16    # batch size for both loaders; also the first entry of the model's shape argument
n_jobs = 8         # passed to DataLoader as num_workers
valid_split = 0.1  # fraction of the dataset held out for validation

In [4]:
# Per-sample transform pipeline. Judging by the class names, this takes a
# random (32, 32, 90) crop and then converts the result to tensors -- confirm
# against nifti_dataset.
crop = RandomCrop3D((32, 32, 90))
tfms = transforms.Compose([crop, ToTensor()])

# One dataset object backs both the training and the validation loader.
# Keep preload=False if you have limited CPU memory.
dataset = NiftiDataset(t1_dir, t2_dir, tfms, preload=False)

# NOTE: despite its name, num_train is the size of the *whole* dataset,
# measured before any samples are held out for validation.
num_train = len(dataset)
val_size = int(valid_split * num_train)  # int() truncates, so this can be 0 for tiny datasets
indices = list(range(num_train))

In [5]:
# Draw the validation indices without replacement; every remaining index is
# used for training, so the two subsets cannot overlap.
valid_idx = np.random.choice(indices, size=val_size, replace=False)
held_out = set(valid_idx)
train_idx = list(set(indices) - held_out)

# Both loaders read the same dataset and differ only in their sampler.
loader_kwargs = dict(batch_size=batch_size, num_workers=n_jobs, pin_memory=True)
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)
train_loader = DataLoader(dataset, sampler=train_sampler, **loader_kwargs)
valid_loader = DataLoader(dataset, sampler=valid_sampler, **loader_kwargs)

In [6]:
# This notebook trains on the first CUDA device only. Raise explicitly instead
# of using `assert`: assert statements are removed when Python runs with -O,
# and a bare assert gives no hint about what went wrong.
if not torch.cuda.is_available():
    raise RuntimeError('CUDA is not available; this notebook requires a GPU.')
device = torch.device('cuda:0')
# Let cuDNN benchmark its convolution algorithms and pick the fastest one.
torch.backends.cudnn.benchmark = True

In [7]:
from model import SimpleEncDec

# Shape handed to the model constructor. Presumably (batch, channels, *crop
# size), matching the (32, 32, 90) crop used by the dataset transform --
# confirm against SimpleEncDec.
input_shape = (batch_size, 1, 32, 32, 90)
model = SimpleEncDec(input_shape)
model.to(device)

# Smooth L1 is the loss in use; torch.nn.MSELoss() is the alternative.
criterion = torch.nn.SmoothL1Loss()
optimizer = torch.optim.AdamW(model.parameters(), weight_decay=1e-6)

In [2]:
# One inner list of per-batch losses is appended per epoch.
train_losses, valid_losses = [], []
n_batches = len(train_loader)  # not used in this cell; kept in case other cells read it

for t in range(1, n_epochs + 1):
    # ---- training pass ----
    t_losses = []
    model.train()
    for src, tgt in train_loader:
        src, tgt = src.float().to(device), tgt.float().to(device)

        out = model(src)
        loss = criterion(out, tgt)
        t_losses.append(loss.item())

        loss.backward()
        optimizer.step()
        # Reset gradients so the next batch does not accumulate onto them.
        optimizer.zero_grad()

    train_losses.append(t_losses)

    # Check for divergence before validating: once a training loss is NaN/Inf
    # there is no point spending a validation pass on the broken model.
    if not np.all(np.isfinite(t_losses)):
        raise RuntimeError('NaN or Inf in training loss, cannot recover. Exiting.')

    # ---- validation pass (no gradients, eval-mode layers) ----
    v_losses = []
    model.eval()
    with torch.no_grad():
        for src, tgt in valid_loader:
            src, tgt = src.float().to(device), tgt.float().to(device)

            out = model(src)
            loss = criterion(out, tgt)
            v_losses.append(loss.item())
    valid_losses.append(v_losses)

    # The validation subset can be empty when the dataset is very small;
    # report NaN in that case instead of triggering numpy's empty-mean warning.
    valid_mean = np.mean(v_losses) if v_losses else float('nan')
    log = f'Epoch: {t} - Training Loss: {round(np.mean(t_losses), 5)}, ' \
          f'Validation Loss: {round(valid_mean, 5)}'
    print(log)

In [1]:
# Create the output directory first: torch.save does not create missing
# parent directories, and failing here would discard the trained weights.
os.makedirs('results', exist_ok=True)
torch.save(model.state_dict(), 'results/trained.pth')