# Set exactly one of these flags to True to choose which model to train

In [None]:
# Select the architecture to train: exactly one flag must be True.
TRAIN_RESNET = False
TRAIN_UNODE = True
TRAIN_UNET = False

# Later cells build `net` in one `if` per flag and pick the checkpoint filename with an
# if/elif chain, so they rely on exactly one flag being set.
if sum((TRAIN_RESNET, TRAIN_UNODE, TRAIN_UNET)) != 1:
    raise ValueError('Set exactly one of TRAIN_RESNET, TRAIN_UNODE, TRAIN_UNET to True')

---

In [None]:
import os
import glob
import random

import torch
import torchvision
import torch.utils.data
import torch.nn.functional as F

import cv2
import PIL
import scipy.ndimage
import skimage.measure
import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm, tqdm_notebook

%matplotlib inline

from models import ConvODEUNet, ConvResUNet, ODEBlock, Unet

from IPython.display import clear_output

# Commands to download the dataset

In [None]:
# Download and unzip the Warwick QU (GlaS) dataset unless the extracted folder already exists.
# The `!` lines are IPython shell escapes, so this cell only runs inside Jupyter/IPython.
if not os.path.exists('Warwick QU Dataset (Released 2016_07_08)'):
    !wget https://warwick.ac.uk/fac/sci/dcs/research/tia/glascontest/download/warwick_qu_dataset_released_2016_07_08.zip
    !unzip warwick_qu_dataset_released_2016_07_08.zip     

## Define datasets

In [None]:
# A value of 0 disables OpenCV's internal multithreading. Presumably this avoids CPU
# oversubscription alongside the DataLoader worker processes (num_workers=10 below).
cv2.setNumThreads(0)

In [None]:
from augmentations import ElasticTransformations, RandomRotationWithMask

In [None]:
class GLaSDataLoader(object):
    """Dataset over the GlaS (Warwick QU) training images.

    Each item is ``(patch, label)``:

    * ``patch``: float tensor of shape (3, patch_size[0], patch_size[1]) with pixel
      values scaled from [0, 255] to [0, 1].
    * ``label``: float tensor of shape (2, patch_size[0], patch_size[1]). Channel 0 is
      the union of all glands, each eroded with a 13x13 square structuring element.
      Channel 1 is the binary gland mask.

    Images and annotations are resized by a factor of 512/775 and reflect-padded.

    In training mode a random patch is cropped after random flips, rotation, elastic
    deformation and colour jitter. In validation mode the top-left ``patch_size``
    window of the padded image is returned without augmentation.

    Args:
        patch_size: (height, width) of the returned patch.
        dataset_repeat: how many times each image appears per epoch.
        images: zero-based image indices; index ``k`` loads ``train_{k + 1}.bmp``.
        validation: disable augmentation and random cropping.
    """

    def __init__(self, patch_size, dataset_repeat=1, images=np.arange(0, 70), validation=False):
        self.image_fname = 'Warwick QU Dataset (Released 2016_07_08)/train_'
        self.images = images

        self.patch_size = patch_size
        self.repeat = dataset_repeat
        self.validation = validation

        # Geometric augmentations. They are applied to the image and the encoded label
        # stacked into one 4-channel tensor, so both receive the same random transform.
        self.image_mask_transforms = torchvision.transforms.Compose([
            torchvision.transforms.ToPILImage(),
            torchvision.transforms.RandomHorizontalFlip(),
            torchvision.transforms.RandomVerticalFlip(),
            RandomRotationWithMask(45, resample=False, expand=False, center=None),
            ElasticTransformations(2000, 60),
            torchvision.transforms.ToTensor()
        ])
        # Photometric augmentation, applied to the RGB image only.
        self.image_transforms = torchvision.transforms.Compose([
            torchvision.transforms.ToPILImage(),
            torchvision.transforms.ColorJitter(brightness=0.3, contrast=0.2, saturation=0.1, hue=0.1),
            torchvision.transforms.ToTensor()
        ])

    def __getitem__(self, index):
        # Map the (repeated) dataset index to an image index; files are numbered from 1.
        image_idx = self.images[index // self.repeat]
        index_str = str(image_idx.item() + 1)

        image_path = self.image_fname + index_str + '.bmp'
        mask_path = self.image_fname + index_str + '_anno.bmp'

        ratio = (775 / 512)
        # `with` closes the underlying file once the resized copy has been made.
        with PIL.Image.open(image_path) as image_file:
            new_size = (int(round(image_file.size[0] / ratio)),
                        int(round(image_file.size[1] / ratio)))
            image = np.array(image_file.resize(new_size))
        # The annotation stores one integer id per gland. resize's default filter is
        # interpolating for most image modes and would blend neighbouring ids into
        # spurious labels, so nearest-neighbour resampling is requested explicitly.
        with PIL.Image.open(mask_path) as mask_file:
            mask = np.array(mask_file.resize(new_size, resample=PIL.Image.NEAREST))

        if not self.validation:
            # Large reflect padding so random crops may extend past the original border.
            pad_h = max(self.patch_size[0] - image.shape[0], 128)
            pad_w = max(self.patch_size[1] - image.shape[1], 128)
        else:
            # Enough padding for a patch_size crop at the top-left corner to fit.
            pad_h = max((self.patch_size[0] - image.shape[0]) // 2 + 1, 0)
            pad_w = max((self.patch_size[1] - image.shape[1]) // 2 + 1, 0)

        padded_image = np.pad(image, ((pad_h, pad_h), (pad_w, pad_w), (0, 0)), mode='reflect')
        mask = np.pad(mask, ((pad_h, pad_h), (pad_w, pad_w)), mode='reflect')

        if not self.validation:
            loc_y = random.randint(0, padded_image.shape[0] - self.patch_size[0])
            loc_x = random.randint(0, padded_image.shape[1] - self.patch_size[1])
        else:
            loc_y, loc_x = 0, 0

        patch = torch.from_numpy(padded_image.transpose(2, 0, 1)).float() / 255
        # An annotation without any gland would otherwise produce 0/0 = NaN labels.
        n_glands = max(int(mask.max()), 1)
        # Scale instance ids into [0, 1] so the label can be stacked with the image and
        # sent through ToPILImage together with it.
        label = torch.from_numpy(mask).float() / n_glands

        if not self.validation:
            stacked = torch.cat((patch, label[None, :, :]))
            stacked = self.image_mask_transforms(stacked)
            # Undo the scaling and snap back to integer instance ids.
            patch, label = stacked[0:3], torch.round(stacked[3] * n_glands)
            patch = self.image_transforms(patch)
        else:
            label *= n_glands

        # Channel 0 target: each gland eroded with a 13x13 square structuring element.
        eroded = torch.zeros(label.shape)
        for gland_id in np.unique(mask):
            if gland_id == 0:
                continue  # background
            in_gland = label == gland_id
            gland_eroded = scipy.ndimage.binary_erosion(in_gland.numpy(),
                                                        structure=np.ones((13, 13)),
                                                        border_value=1)
            gland_eroded = torch.from_numpy(gland_eroded.astype(np.float32))
            eroded[in_gland] = gland_eroded[in_gland]

        label = (label > 0).float()
        label = torch.stack((eroded, label))

        patch = patch[:, loc_y:loc_y + self.patch_size[0], loc_x:loc_x + self.patch_size[1]]
        label = label[:, loc_y:loc_y + self.patch_size[0], loc_x:loc_x + self.patch_size[1]]

        return patch, label.float()

    def __len__(self):
        return len(self.images) * self.repeat

In [None]:
torch.manual_seed(0)

# 10 validation images drawn from the 85 training images (zero-based indices).
# NOTE(review): random_ samples with replacement, so duplicates are possible and the
# validation set may hold fewer than 10 distinct images. Kept as-is so the seeded split
# stays reproducible.
val_set_idx = torch.LongTensor(10).random_(0, 85)
train_set_idx = torch.arange(0, 85)

# Drop every validation index from the training indices.
overlapping = (train_set_idx[..., None] == val_set_idx).any(-1)
# `~` inverts the boolean mask; current PyTorch rejects `1 - mask` for bool tensors.
train_set_idx = torch.masked_select(train_set_idx, ~overlapping)

Windows users: when using multiprocessing (`num_workers > 0`), you may need to define the dataset class in a separate Python file so that the worker processes can import it.

In [None]:
# Both splits use 352x512 patches; the loaders feed one image per batch.
patch_size = (352, 512)
trainset = GLaSDataLoader(patch_size, dataset_repeat=1, images=train_set_idx)
valset = GLaSDataLoader(patch_size, dataset_repeat=1, images=val_set_idx, validation=True)

loader_kwargs = dict(batch_size=1, num_workers=10)
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **loader_kwargs)
valloader = torch.utils.data.DataLoader(valset, shuffle=False, **loader_kwargs)

# Plotting train data

In [None]:
# 5 rows x 3 pairs of (image, label channel 0). Every trainset[row] call draws a fresh
# random augmentation, so the three pairs in a row show the same image differently.
fig, ax = plt.subplots(nrows=5, ncols=6, figsize=(24, 15))

for row in range(5):
    for pair in range(3):
        img, target = trainset[row]
        img_ax, lbl_ax = ax[row, 2 * pair], ax[row, 2 * pair + 1]
        img_ax.imshow(img.numpy().transpose(1, 2, 0))
        lbl_ax.imshow(target[0])
        img_ax.axis('off')
        lbl_ax.axis('off')

plt.show();

In [None]:
# One training sample: the image, label channel 0, and the sum of both label channels.
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(24, 15))

img, target = trainset[0]
ax[1].imshow(target[0].numpy())
ax[2].imshow(target.sum(dim=0))
ax[0].imshow(img.numpy().transpose(1, 2, 0))

# Plotting validation data

In [None]:
# Validation samples are not augmented, so the three pairs in a row are identical.
# Each pair shows the image and label channel 1 (the binary gland mask).
fig, ax = plt.subplots(nrows=5, ncols=6, figsize=(24, 15))

for row in range(5):
    for pair in range(3):
        img, target = valset[row]
        img_ax, lbl_ax = ax[row, 2 * pair], ax[row, 2 * pair + 1]
        img_ax.imshow(img.numpy().transpose(1, 2, 0))
        lbl_ax.imshow(target[1])
        img_ax.axis('off')
        lbl_ax.axis('off')

plt.show(); 

# Define network

In [None]:
device = torch.device('cuda')

if TRAIN_UNODE:
    # ODE-based U-Net with 2 output channels (matching the 2-channel labels).
    net = ConvODEUNet(
        num_filters=16,
        output_dim=2,
        time_dependent=True,
        non_linearity='lrelu',
        adjoint=True,
        tol=1e-3,
    )
    net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)

In [None]:
if TRAIN_RESNET:
    # Residual U-Net baseline with 2 output channels.
    net = ConvResUNet(
        num_filters=16,
        output_dim=2,
        non_linearity='lrelu',
    )
    net.to(device)

In [None]:
if TRAIN_UNET:
    # Plain U-Net baseline with 2 output classes.
    net = Unet(
        depth=5,
        in_ch=3,
        out_ch=64,
        n_classes=2,
    )
    net.cuda()
    net.to(device)

---

In [None]:
# He (Kaiming) normal initialisation for every Conv2d weight; biases start at zero.
for m in net.modules():
    if isinstance(m, torch.nn.Conv2d):
        torch.nn.init.kaiming_normal_(m.weight)
        # Convolutions created with bias=False have m.bias == None.
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)

In [None]:
def count_parameters(model):
    """Return the total number of scalar elements in ``model``'s trainable parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [None]:
# Number of trainable parameters of the selected network (shown as the cell output).
count_parameters(net)

# Train model

In [None]:
# Default loss: BCE on logits, i.e. an independent sigmoid per output channel.
criterion = torch.nn.BCEWithLogitsLoss()
val_criterion = torch.nn.BCEWithLogitsLoss()

if TRAIN_UNET:
    cross_entropy = torch.nn.BCEWithLogitsLoss()

    def criterion(conf, labels):
        """BCE-with-logits loss after center-cropping ``labels`` to ``conf``'s spatial size.

        Args:
            conf: logits; dims 2 and 3 are height and width.
            labels: targets with the same leading dims. Presumably at least as large
                spatially as ``conf`` (the U-Net output may be cropped) -- confirm.

        Returns:
            Scalar loss tensor.
        """
        out_h, out_w = conf.shape[2:4]
        label_h, label_w = labels.shape[2:4]

        # Center-crop offsets; each axis must use its own dimension.
        h = (label_h - out_h) // 2
        w = (label_w - out_w) // 2  # net.crop_left_top

        return cross_entropy(conf, labels[:, :, h:h + out_h, w:w + out_w])

In [None]:
# Adam over all network parameters. The learning rate is overwritten by run(), which
# sets every param_group's 'lr' before training.
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)

In [None]:
# Let cuDNN benchmark convolution algorithms and reuse the fastest one; input sizes are
# fixed here (352x512 patches), so the search cost is paid only once.
torch.backends.cudnn.benchmark = True

In [None]:
# Per-epoch histories appended to by run() and plot_losses().
losses = []
val_losses = []
# One independent list per ODE block (9 in total) of function-evaluation counts.
nfe = [[] for _ in range(9)]

In [None]:
# Gradients of this many single-image batches are summed before each optimizer step,
# giving an effective mini-batch of 8.
accumulate_batch = 8
accumulated = 0

# Checkpoint path for the best model of the selected architecture.
if TRAIN_RESNET:
    filename = 'best_border_resnet_model.pt'
elif TRAIN_UNODE:
    filename = 'best_border_unode_model.pt'
elif TRAIN_UNET:
    filename = 'best_border_unet_model.pt'

def run(lr=1e-3, epochs=100):
    """Train ``net`` for ``epochs`` epochs at learning rate ``lr``.

    Gradients of ``accumulate_batch`` consecutive batches are summed before each
    optimizer step. After every epoch:

    * the mean training and validation losses are appended to ``losses`` and
      ``val_losses``;
    * the whole model is saved to ``filename`` when the epoch's mean validation loss
      is the lowest so far and below 0.4;
    * the dashboard is redrawn via ``plot_losses``.
    """
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    for epoch in range(epochs):
        # --- training loop with gradient accumulation ---
        net.train()
        running_loss = 0.0
        # Reset the counter together with the gradients. Otherwise a partial window
        # left over from the previous epoch (whose gradients are discarded here) makes
        # the first step of this epoch use fewer than accumulate_batch samples.
        optimizer.zero_grad()
        accumulated = 0
        for data in tqdm(trainloader):
            inputs, labels = data[0].cuda(), data[1].cuda()
            outputs = net(inputs)
            # Scale so the accumulated gradient is the mean over the window.
            loss = criterion(outputs, labels) / accumulate_batch
            loss.backward()
            accumulated += 1
            if accumulated == accumulate_batch:
                optimizer.step()
                optimizer.zero_grad()
                accumulated = 0

            running_loss += loss.item() * accumulate_batch

        losses.append(running_loss / len(trainloader))

        # --- validation loop ---
        net.eval()
        with torch.no_grad():
            running_loss = 0.0
            for data in valloader:
                inputs, labels = data[0].cuda(), data[1].cuda()
                outputs = net(inputs)
                loss = criterion(outputs, labels)
                running_loss += loss.item()

            val_loss = running_loss / len(valloader)
            val_losses.append(val_loss)
            # Threshold on the epoch's mean validation loss, not the last batch's loss.
            if np.argmin(val_losses) == len(val_losses) - 1 and val_loss < 0.4:
                torch.save(net, filename)

            plot_losses(inputs, outputs)
                
def plot_losses(inputs, outputs):
    """Record per-block NFE counts (U-NODE only) and redraw the live training dashboard.

    Panels:
        0: training and validation loss curves.
        1: channel argmax of the softmax of ``outputs`` for the first sample.
        2: the first image in ``inputs``.
        3: (U-NODE only) function-evaluation counts of each ODE block.

    Args:
        inputs: input batch whose first image is shown.
        outputs: network output for ``inputs``.
    """
    # (attribute on `net`, legend label) for each ODE block, in the order of `nfe`.
    ode_blocks = (
        ('odeblock_down1', 'down1'),
        ('odeblock_down2', 'down2'),
        ('odeblock_down3', 'down3'),
        ('odeblock_down4', 'down4'),
        ('odeblock_embedding', 'embed'),
        ('odeblock_up1', 'up1'),
        ('odeblock_up2', 'up2'),
        ('odeblock_up3', 'up3'),
        ('odeblock_up4', 'up4'),
    )

    if TRAIN_UNODE:
        for history, (attr, _) in zip(nfe, ode_blocks):
            history.append(getattr(net, attr).odefunc.nfe)

    clear_output(wait=True)

    n_cols = 4 if TRAIN_UNODE else 3
    fig, ax = plt.subplots(nrows=1, ncols=n_cols, figsize=(15,5))

    # First matching flag wins, as in an if/elif chain.
    titles = ((TRAIN_UNODE, 'U-NODE'), (TRAIN_RESNET, 'RESNET'), (TRAIN_UNET, 'UNET'))
    chosen = next((title for flag, title in titles if flag), None)
    if chosen is not None:
        fig.suptitle(chosen, fontsize=16)

    ax[0].plot(np.arange(len(losses)), losses, label="loss")
    ax[0].plot(np.arange(len(val_losses)), val_losses, label="val_loss")

    if TRAIN_UNODE:
        x_axis = np.arange(len(nfe[0]))
        for history, (_, label) in zip(nfe, ode_blocks):
            ax[3].plot(x_axis, history, label=label)
        ax[3].legend()

    prediction = torch.softmax(outputs, dim=1).argmax(dim=1)[0].detach().cpu().numpy()

    ax[0].legend()
    ax[1].imshow(prediction)
    ax[2].imshow(inputs.detach().cpu()[0].numpy().transpose(1, 2, 0))

    plt.show();

In [None]:
# The ODE and ResNet U-Nets train at 1e-3, the plain U-Net at 1e-4.
lr = 1e-3 if (TRAIN_UNODE or TRAIN_RESNET) else 1e-4

# 600 epochs in total; re-running this cell only runs the epochs not yet in `losses`.
remaining_epochs = 600 - len(losses)
run(lr, remaining_epochs)

# Calculate results

In [None]:
# load best model
# The checkpoint is a whole pickled module (torch.save(net, filename) in run()), so it
# must be loaded with weights_only=False on PyTorch versions where that defaults to True.
# Only load checkpoints you produced yourself: unpickling can execute arbitrary code.
try:
    net = torch.load(filename, weights_only=False)
except TypeError:
    # Older PyTorch releases do not accept the weights_only argument.
    net = torch.load(filename)

In [None]:
# Recompute the validation loss of the reloaded model as a sanity check.
net.eval()  # evaluation-mode forward passes for a pure inference pass
with torch.no_grad():
    running_loss = 0.0
    for data in tqdm(valloader):
        inputs, labels = data[0].cuda(), data[1].cuda()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        running_loss += loss.item()

    print("Check validation loss:", running_loss / len(valloader))

# Visualize results on validation set

In [None]:
from inference_utils import inference_image, postprocess

In [None]:
fig, ax = plt.subplots(nrows=5, ncols=3, figsize=(4*3,3*5))

ax[0, 0].set_title('Image')
ax[0, 1].set_title('Ground-truth')
ax[0, 2].set_title('Trained network')

for row in range(5):
    # val_set_idx holds zero-based indices, while the files are numbered from 1
    # (GLaSDataLoader loads train_{index + 1}.bmp). Using the raw index would show a
    # different image, or fail on the non-existent train_0.bmp.
    index = val_set_idx[row].item() + 1
    image = PIL.Image.open(f'Warwick QU Dataset (Released 2016_07_08)/train_{index}.bmp')
    gt = PIL.Image.open(f'Warwick QU Dataset (Released 2016_07_08)/train_{index}_anno.bmp')

    # One inference per row; all three columns show the same image.
    with torch.no_grad():
        result, input_image = inference_image(net, image, shouldpad=TRAIN_UNET)
        result = postprocess(result, gt)

    ax[row, 0].imshow(image)
    ax[row, 1].imshow(np.array(gt) > 0)
    ax[row, 2].imshow(image)
    ax[row, 2].imshow(result, alpha=0.5)  # prediction overlaid on the image
    for col in range(3):
        ax[row, col].set_axis_off()

plt.show(); 

# Calculate metrics on test set

In [None]:
from metrics import ObjectDice, ObjectHausdorff, F1score

In [None]:
# Select which trained model to evaluate on the test set: exactly one flag must be True.
TEST_RESNET = False
TEST_UNODE = True
TEST_UNET = False

# The following cells pick the checkpoint and result folder with if/elif chains.
if sum((TEST_RESNET, TEST_UNODE, TEST_UNET)) != 1:
    raise ValueError('Set exactly one of TEST_RESNET, TEST_UNODE, TEST_UNET to True')

In [None]:
# Checkpoint matching the selected TEST_* flag (first match wins).
checkpoint = None
if TEST_UNODE: checkpoint = 'best_border_unode_model.pt'
elif TEST_RESNET: checkpoint = 'best_border_resnet_model.pt'
elif TEST_UNET: checkpoint = 'best_border_unet_model.pt'

if checkpoint is not None:
    # Whole pickled module (see torch.save(net, ...) during training), so
    # weights_only=False is required where PyTorch defaults it to True. Only load
    # checkpoints you produced yourself: unpickling can execute arbitrary code.
    try:
        net = torch.load(checkpoint, weights_only=False)
    except TypeError:
        # Older PyTorch releases do not accept the weights_only argument.
        net = torch.load(checkpoint)

In [None]:
dice, hausdorff, f1, dice_full = 0, 0, 0, 0

if TEST_UNODE: folder = 'results_unode'
elif TEST_UNET: folder = 'results_unet'
elif TEST_RESNET: folder = 'results_resnet'

# GlaS test split: testA_1..testA_60 followed by testB_1..testB_20.
n_test_a, n_test_b = 60, 20
images = ([f'testA_{k}_anno.bmp' for k in range(1, n_test_a + 1)]
          + [f'testB_{k}_anno.bmp' for k in range(1, n_test_b + 1)])
n_images = len(images)

# Pure inference: no autograd graph needs to be recorded.
with torch.no_grad():
    for i, fname in tqdm_notebook(enumerate(images), total=n_images):
        gt = PIL.Image.open('Warwick QU Dataset (Released 2016_07_08)/' + fname)
        image = PIL.Image.open('Warwick QU Dataset (Released 2016_07_08)/' + fname.replace('_anno', ''))
        result, resized = inference_image(net, image, shouldpad=TEST_UNET)
        result = postprocess(result, gt)

        # Label the connected regions of the annotation before object-level scoring.
        gt = skimage.measure.label(np.array(gt))

        f1_img = F1score(result, gt)
        hausdorff_img = ObjectHausdorff(result, gt)
        dice_img = ObjectDice(result, gt)

        f1 += f1_img
        hausdorff += hausdorff_img
        dice += dice_img

        # Snapshot the running sums after the last testA image to split A/B scores.
        if i == n_test_a - 1:
            diceA = dice
            hausdorffA = hausdorff
            f1A = f1

        print(i, f1_img, hausdorff_img, dice_img)

diceB = dice - diceA
hausdorffB = hausdorff - hausdorffA
f1B = f1 - f1A

print('ObjectDice:', dice / n_images, 'A', diceA / n_test_a, 'B', diceB / n_test_b)
print('Hausdorff:', hausdorff / n_images, 'A', hausdorffA / n_test_a, 'B', hausdorffB / n_test_b)
print('F1:', f1 / n_images, 'A', f1A / n_test_a, 'B', f1B / n_test_b)