# Set exactly one of these flags to True to choose which model to train

In [None]:
# Model-selection switches for the training part of this notebook.
TRAIN_RESNET = False
TRAIN_UNODE = True
TRAIN_UNET = False

def get_title():
    """Return the display title for the model chosen by the TRAIN_* flags.

    Flags are checked in the order U-NODE, RESNET, UNET and the first enabled
    one wins; ``None`` is returned when no flag is set.
    """
    choices = (
        (TRAIN_UNODE, 'U-NODE'),
        (TRAIN_RESNET, 'RESNET'),
        (TRAIN_UNET, 'UNET'),
    )
    return next((title for enabled, title in choices if enabled), None)

---

In [None]:
import os
import glob
import random

import torch
import torch.utils.data

import PIL
import skimage.measure
import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm, tqdm_notebook

%matplotlib inline

from models import ConvODEUNet, ConvResUNet, ODEBlock, Unet
from dataloader import GLaSDataLoader
from train_utils import plot_losses

from IPython.display import clear_output

# Commands to download the dataset

In [None]:
# Download and unzip the dataset archive unless the folder
# 'Warwick QU Dataset (Released 2016_07_08)' already exists (presumably the
# archive extracts to exactly that folder — confirm if the check misfires).
# The `!` lines are IPython shell escapes, so this only runs in Jupyter/IPython.
if not os.path.exists('Warwick QU Dataset (Released 2016_07_08)'):
    !wget https://warwick.ac.uk/fac/sci/dcs/research/tia/glascontest/download/warwick_qu_dataset_released_2016_07_08.zip
    !unzip warwick_qu_dataset_released_2016_07_08.zip     

## Define datasets

In [None]:
# Reproducible split of the training images into train / validation indices.
torch.manual_seed(0)

NUM_IMAGES = 85      # indices 0..NUM_IMAGES-1 are split
NUM_VAL_IMAGES = 10  # size of the validation subset

# Sample WITHOUT replacement: drawing each validation index independently
# (e.g. with ``random_``) can pick the same image twice, which shrinks the
# validation set and double-weights that image in the validation loss.
# A random permutation guarantees NUM_VAL_IMAGES distinct validation images
# and a disjoint training set made of all remaining indices.
perm = torch.randperm(NUM_IMAGES)
val_set_idx = perm[:NUM_VAL_IMAGES]
# Keep training indices in ascending order.
train_set_idx = torch.sort(perm[NUM_VAL_IMAGES:])[0]

In [None]:
# Build the train / validation datasets with a shared (352, 512) size argument
# (assumed to be (height, width) — confirm in GLaSDataLoader).
crop_size = (352, 512)
trainset = GLaSDataLoader(crop_size, dataset_repeat=1, images=train_set_idx)
valset = GLaSDataLoader(crop_size, dataset_repeat=1, images=val_set_idx, validation=True)

# One image per batch; larger effective batches come from gradient
# accumulation in the training loop. Only the training data is shuffled.
loader_kwargs = dict(batch_size=1, num_workers=10)
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **loader_kwargs)
valloader = torch.utils.data.DataLoader(valset, shuffle=False, **loader_kwargs)

# Plotting train data

In [None]:
fig, ax = plt.subplots(nrows=5, ncols=6, figsize=(24, 15))

# Training samples 0-4, one per row. Each row holds three (image, target
# channel 0) pairs; trainset[row] is fetched anew for every pair, so if the
# dataset augments randomly (not visible here) the three pairs will differ.
for row in range(5):
    for pair in range(3):
        sample = trainset[row]
        img_ax, mask_ax = ax[row, 2 * pair], ax[row, 2 * pair + 1]
        img_ax.imshow(np.transpose(sample[0].numpy(), (1, 2, 0)))  # CHW -> HWC
        mask_ax.imshow(sample[1][0])
        for panel in (img_ax, mask_ax):
            panel.axis('off')

plt.show();

In [None]:
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(24, 15))

# One training sample: image | target channel 0 | sum over target channels.
sample = trainset[0]
img_t, target = sample[0], sample[1]
ax[0].imshow(np.transpose(img_t.numpy(), (1, 2, 0)))  # CHW -> HWC
ax[1].imshow(target[0].numpy())
ax[2].imshow(target.sum(dim=0))

# Plotting validation data

In [None]:
fig, ax = plt.subplots(nrows=5, ncols=6, figsize=(24, 15))

# Validation samples 0-4, one per row. Each row holds three (image, target
# channel 1) pairs, re-fetching valset[row] for every pair.
# NOTE(review): if the validation dataset is deterministic, the three pairs
# in a row are identical — confirm whether that is intended.
for row in range(5):
    for pair in range(3):
        sample = valset[row]
        img_ax, mask_ax = ax[row, 2 * pair], ax[row, 2 * pair + 1]
        img_ax.imshow(np.transpose(sample[0].numpy(), (1, 2, 0)))  # CHW -> HWC
        mask_ax.imshow(sample[1][1])
        for panel in (img_ax, mask_ax):
            panel.axis('off')

plt.show(); 

# Define network

In [None]:
device = torch.device('cuda')

# U-NODE model (ConvODEUNet); nn.Module.to moves it in place and returns it.
if TRAIN_UNODE:
    net = ConvODEUNet(
        num_filters=16,
        output_dim=2,
        time_dependent=True,
        non_linearity='lrelu',
        adjoint=True,
        tol=1e-3,
    ).to(device)

In [None]:
# Residual U-Net baseline (ConvResUNet), moved to the GPU.
if TRAIN_RESNET:
    net = ConvResUNet(
        num_filters=16,
        output_dim=2,
        non_linearity='lrelu',
    ).to(device)

In [None]:
# Plain U-Net baseline. It is moved to the GPU once via `device`, like the
# other model cells (a separate .cuda() call before .to(device) is redundant).
if TRAIN_UNET:
    net = Unet(depth=5, num_filters=64, output_dim=2).to(device)

---

In [None]:
# Re-initialise every Conv2d in the network: Kaiming-normal weights, zero bias.
for m in net.modules():
    if isinstance(m, torch.nn.Conv2d):
        torch.nn.init.kaiming_normal_(m.weight)
        # A Conv2d built with bias=False has m.bias = None, which constant_
        # cannot fill; skip it instead of crashing.
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)

In [None]:
def count_parameters(model):
    """Return the total number of scalar entries in *model*'s trainable parameters.

    Parameters with ``requires_grad=False`` (frozen) are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [None]:
# Number of trainable parameters, shown as the cell's output.
num_trainable = count_parameters(net)
num_trainable

# Train model

In [None]:
criterion = torch.nn.BCEWithLogitsLoss()
# NOTE(review): val_criterion is not used by the loops in this notebook —
# validation calls `criterion`, so the U-Net label crop below applies there too.
val_criterion = torch.nn.BCEWithLogitsLoss()

if TRAIN_UNET:
    cross_entropy = torch.nn.BCEWithLogitsLoss()

    def criterion(conf, labels):
        """BCE-with-logits loss after center-cropping *labels* to *conf*'s spatial size.

        The U-Net output is presumably smaller than its input (it is run with
        padding at inference time), so the labels are cropped to the output's
        (H, W) — dims 2 and 3 — around their center before comparing.

        Args:
            conf: network logits, indexed as (N, C, H, W).
            labels: targets with the same N, C and H, W >= conf's.

        Returns:
            Scalar loss tensor.
        """
        out_h, out_w = conf.shape[2:4]
        label_h, label_w = labels.shape[2:4]

        # Each offset must use its own axis; computing the vertical offset from
        # the width difference mis-aligns the crop for non-square images.
        h = (label_h - out_h) // 2
        w = (label_w - out_w) // 2

        return cross_entropy(conf, labels[:, :, h:h + out_h, w:w + out_w])

In [None]:
# Adam over all network parameters; run() overwrites the learning rate.
optimizer = torch.optim.Adam(params=net.parameters(), lr=1e-4)

In [None]:
# Let cuDNN auto-tune its convolution algorithms for the input shapes it sees.
cudnn = torch.backends.cudnn
cudnn.benchmark = True

In [None]:
# Per-epoch mean training / validation losses, appended by run().
losses, val_losses = [], []
# Nine independent lists for the U-NODE model, passed to plot_losses
# (presumably per-ODE-block function-evaluation counts — confirm); None otherwise.
nfe = [[] for _ in range(9)] if TRAIN_UNODE else None

In [None]:
accumulate_batch = 8  # mini-batch size by gradient accumulation
accumulated = 0

# Checkpoint path for the model selected by the TRAIN_* flags.
if TRAIN_RESNET: filename = 'best_border_resnet_model.pt'
elif TRAIN_UNODE: filename = 'best_border_unode_model.pt'
elif TRAIN_UNET: filename = 'best_border_unet_model.pt'

def run(lr=1e-3, epochs=100, save_threshold=0.4):
    """Train the module-level ``net`` for *epochs* epochs, validating after each.

    Uses the global ``optimizer``, ``criterion``, ``trainloader`` and
    ``valloader``; appends each epoch's mean training / validation loss to
    ``losses`` / ``val_losses``. The whole module is saved to ``filename``
    whenever the epoch's mean validation loss is the lowest so far and below
    *save_threshold*. After every epoch the output is cleared and
    ``plot_losses`` is called.

    Args:
        lr: learning rate written into every optimizer param group.
        epochs: number of epochs to run.
        save_threshold: a new best model is only checkpointed when its mean
            validation loss is below this value.
    """
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    for _ in range(epochs):

        # training loop with gradient accumulation
        running_loss = 0.0
        # Zero the gradients AND reset the counter together: if the counter
        # carried leftover samples over from the previous epoch while their
        # gradients were discarded, the first step of this epoch would use
        # fewer than accumulate_batch samples. Samples left over at the end
        # of an epoch (< accumulate_batch) do not produce an optimizer step.
        optimizer.zero_grad()
        accumulated = 0
        for data in tqdm(trainloader):
            inputs, labels = data[0].cuda(), data[1].cuda()
            outputs = net(inputs)
            # Scale so the accumulated gradient averages over accumulate_batch samples.
            loss = criterion(outputs, labels) / accumulate_batch
            loss.backward()
            accumulated += 1
            if accumulated == accumulate_batch:
                optimizer.step()
                optimizer.zero_grad()
                accumulated = 0

            # Undo the scaling so the logged value is the per-sample loss.
            running_loss += loss.item() * accumulate_batch

        losses.append(running_loss / len(trainloader))

        # validation loop
        # NOTE(review): net is never switched to eval() mode here; if the model
        # contains BatchNorm/Dropout layers this affects validation — confirm.
        with torch.no_grad():
            running_loss = 0.0
            for data in valloader:
                inputs, labels = data[0].cuda(), data[1].cuda()
                outputs = net(inputs)
                running_loss += criterion(outputs, labels).item()

            val_loss = running_loss / len(valloader)
            val_losses.append(val_loss)
            # Checkpoint on the epoch's MEAN validation loss, not the loss of
            # whichever validation batch happened to come last.
            if np.argmin(val_losses) == len(val_losses) - 1 and val_loss < save_threshold:
                torch.save(net, filename)

            clear_output(wait=True)
            # inputs/outputs are those of the last validation batch.
            plot_losses(inputs, outputs, losses, val_losses, get_title(), nfe, net=net)

In [None]:
# Train up to a 600-epoch total; epochs already recorded in `losses` are
# skipped, so re-running this cell resumes rather than restarts.
total_epochs = 600
lr = 1e-3 if (TRAIN_UNODE or TRAIN_RESNET) else 1e-4
run(lr, total_epochs - len(losses))

# Calculate results

In [None]:
# load best model
# The checkpoint was written with torch.save(net, filename), i.e. a whole
# pickled nn.Module. Since torch 2.6, torch.load defaults to
# weights_only=True, which refuses to unpickle such objects, so request full
# unpickling explicitly. Only do this for trusted files: unpickling can run
# arbitrary code.
try:
    net = torch.load(filename, weights_only=False)
except TypeError:
    # torch releases before 1.13 do not accept the weights_only argument.
    net = torch.load(filename)

In [None]:
# Sanity check: mean validation loss of the reloaded checkpoint.
with torch.no_grad():
    total_val_loss = 0.0
    for batch in tqdm(valloader):
        images, targets = batch[0].cuda(), batch[1].cuda()
        total_val_loss += criterion(net(images), targets).item()

    print("Check validation loss:", total_val_loss / len(valloader))

# Visualize results on validation set

In [None]:
from inference_utils import inference_image, postprocess

In [None]:
fig, ax = plt.subplots(nrows=5, ncols=3, figsize=(4*3,3*5))

ax[0, 0].set_title('Image')
ax[0, 1].set_title('Ground-truth')
ax[0, 2].set_title(get_title())

data_dir = 'Warwick QU Dataset (Released 2016_07_08)'
for row in range(5):
    # NOTE(review): val_set_idx is drawn from 0..84 and used directly as the
    # file number; presumably the dataset files are train_1..train_85, so an
    # index of 0 would not match a file. Confirm GLaSDataLoader uses the same
    # index -> filename mapping.
    index = int(val_set_idx[row])
    image = PIL.Image.open(f'{data_dir}/train_{index}.bmp')
    gt = PIL.Image.open(f'{data_dir}/train_{index}_anno.bmp')

    # Load and run inference once per image; only the prediction column
    # needs the result, so repeating it for every column is wasted work.
    with torch.no_grad():
        result, _ = inference_image(net, image, shouldpad=TRAIN_UNET)
        result = postprocess(result, gt)

    ax[row, 0].imshow(image)
    ax[row, 1].imshow(np.array(gt) > 0)  # any non-zero annotation label as foreground
    ax[row, 2].imshow(image)
    ax[row, 2].imshow(result, alpha=0.5)  # prediction overlaid on the image

    for col in range(3):
        ax[row, col].set_axis_off()

plt.show(); 

# Calculate metrics on test set

In [None]:
from metrics import ObjectDice, ObjectHausdorff, F1score

In [None]:
# Choose which trained model the test-set evaluation below uses
# (set exactly one to True).
TEST_RESNET, TEST_UNODE, TEST_UNET = False, True, False

In [None]:
# Checkpoint written by the training loop for the selected model
# (first enabled flag wins: U-NODE, then RESNET, then UNET).
checkpoint = None
if TEST_UNODE: checkpoint = 'best_border_unode_model.pt'
elif TEST_RESNET: checkpoint = 'best_border_resnet_model.pt'
elif TEST_UNET: checkpoint = 'best_border_unet_model.pt'

if checkpoint is not None:
    # These files hold a whole pickled nn.Module (torch.save(net, ...)); torch
    # >= 2.6 defaults to weights_only=True and would refuse them. Only load
    # trusted files this way: unpickling can run arbitrary code.
    try:
        net = torch.load(checkpoint, weights_only=False)
    except TypeError:
        # torch releases before 1.13 do not accept the weights_only argument.
        net = torch.load(checkpoint)

In [None]:
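# To evaluate the pretrained weights released with the paper instead, uncomment
# the matching block below (it loads a state_dict into a freshly built model).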

# if TEST_UNODE: 
#     net = ConvODEUNet(num_filters=16, output_dim=2, time_dependent=True, 
#                       non_linearity='lrelu', adjoint=True, tol=1e-3)
#     net = net.cuda()
#     state_dict = torch.load('best_border_unode_paper.pt')
#     net.load_state_dict(state_dict)
    
# if TEST_RESNET:
#     net = ConvResUNet(num_filters=16, output_dim=2, non_linearity='lrelu')
#     net = net.cuda()
#     state_dict = torch.load('best_border_resnet_paper.pt')
#     net.load_state_dict(state_dict)

# if TEST_UNET:
#     net = Unet(depth=5, num_filters=64, output_dim=2).cuda()
#     net = net.cuda()
#     state_dict = torch.load('best_border_unet_paper.pt')
#     net.load_state_dict(state_dict)

In [None]:
dice, hausdorff, f1, dice_full = 0, 0, 0, 0

# Results folder name for the selected model (checked in the same order as
# the checkpoint-loading cell; not used further in this cell).
if TEST_UNODE: folder = 'results_unode'
elif TEST_RESNET: folder = 'results_resnet'
elif TEST_UNET: folder = 'results_unet'

# Test images: part A (testA_1..testA_60) followed by part B (testB_1..testB_20).
# Metrics are reported over all images and per part.
num_test_a, num_test_b = 60, 20
data_dir = 'Warwick QU Dataset (Released 2016_07_08)'
images = ([f'testA_{k}_anno.bmp' for k in range(1, num_test_a + 1)]
          + [f'testB_{k}_anno.bmp' for k in range(1, num_test_b + 1)])
num_test = len(images)

diceA = hausdorffA = f1A = 0
for i, fname in tqdm_notebook(enumerate(images), total=num_test):
    gt = PIL.Image.open(f'{data_dir}/{fname}')
    image = PIL.Image.open(f"{data_dir}/{fname.replace('_anno', '')}")

    # Pure inference: no autograd graph is needed (matches the visualization cell).
    with torch.no_grad():
        result, resized = inference_image(net, image, shouldpad=TEST_UNET)
        result = postprocess(result, gt)

    # Label connected components of the annotation before computing metrics.
    gt = skimage.measure.label(np.array(gt))

    f1_img = F1score(result, gt)
    hausdorff_img = ObjectHausdorff(result, gt)
    dice_img = ObjectDice(result, gt)

    f1 += f1_img
    hausdorff += hausdorff_img
    dice += dice_img

    # Snapshot the running totals after the last part-A image;
    # part-B totals are then total minus part A.
    if i == num_test_a - 1:
        diceA, hausdorffA, f1A = dice, hausdorff, f1

    print(i, f1_img, hausdorff_img, dice_img)

diceB = dice - diceA
hausdorffB = hausdorff - hausdorffA
f1B = f1 - f1A

print('ObjectDice:', dice / num_test, 'A', diceA / num_test_a, 'B', diceB / num_test_b)
print('Hausdorff:', hausdorff / num_test, 'A', hausdorffA / num_test_a, 'B', hausdorffB / num_test_b)
print('F1:', f1 / num_test, 'A', f1A / num_test_a, 'B', f1B / num_test_b)