In [0]:
from os import path
from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag
# Environment bootstrap: install dependencies and fetch the dataset when a GPU
# runtime is detected.
# NOTE(review): `platform` is computed here but not referenced by any cell
# visible in this notebook; the wheel URL below is hard-coded instead.
platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())
# The presence of the nvidia-smi binary is used as the "GPU runtime" signal.
if path.exists('/opt/bin/nvidia-smi'):
  # Pinned PyTorch 0.4.0 wheel (CPython 3.6, cu80 build) plus torchvision.
  # NOTE(review): the wheel is downloaded over plain http — consider https.
  !pip install http://download.pytorch.org/whl/cu80/torch-0.4.0-cp36-cp36m-linux_x86_64.whl torchvision
  !pip install dotted pyfastnoisesimd tqdm Pillow==4.0.0 PIL image
  # -nc: skip the download when warick.zip is already present.
  !wget -nc https://warwick.ac.uk/fac/sci/dcs/research/tia/glascontest/download/warwick_qu_dataset_released_2016_07_08.zip -O warick.zip
  # -q quiet, -o overwrite existing files without prompting.
  !unzip -q -o warick.zip
  # Rename the extracted folder to a path without spaces; later cells glob
  # 'warick_data/...'.
  # NOTE(review): presumably not idempotent — if warick_data already exists,
  # re-running this cell would move the folder inside it; verify.
  !mv 'Warwick QU Dataset (Released 2016_07_08)' warick_data
else:
  # Nothing is installed or downloaded without a GPU runtime.
  print('Select GPU backend')

In [0]:
import os
import pickle
import random
from glob import glob
from random import uniform, randint

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms.functional as F_img
import tqdm
from dotted.collection import DottedDict
from PIL import Image
from pyfastnoisesimd import generate
from torch.utils.data import DataLoader, Dataset

def conv(in_c, out_c):
  """Build a double 3x3 convolution block.

  The block is two repetitions of Conv2d(3x3, padding=1, no bias) -> ELU ->
  BatchNorm2d. The first convolution maps `in_c` to `out_c` channels, the
  second keeps `out_c`. Padding of 1 keeps the spatial size unchanged.

  Args:
    in_c: number of input channels.
    out_c: number of output channels.

  Returns:
    An `nn.Sequential` with six modules, indexed 0..5.
  """
  layers = []
  # Modules are created in the same order they are applied, so seeded weight
  # initialisation and the Sequential indices stay stable.
  for channels_in in (in_c, out_c):
    layers += [
      nn.Conv2d(channels_in, out_c, 3, padding=1, bias=False),
      nn.ELU(inplace=True),
      nn.BatchNorm2d(out_c),
    ]
  return nn.Sequential(*layers)

class UNet512(nn.Module):
  """U-Net style encoder/decoder producing a single-channel logit map.

  Six `conv` encoder stages are separated by 2x2 max-pooling. Five decoder
  stages each upsample bilinearly by 2 and concatenate the encoder feature
  map of matching resolution before convolving. A final 1x1 convolution maps
  16 channels to 1. No sigmoid is applied here.

  Shape comments assume a (3, 512, 512) input. Five pooling steps followed by
  five 2x upsamplings mean height and width must be multiples of 32 for the
  skip concatenations to line up.
  """

  def __init__(self):
    super(UNet512, self).__init__()
    # Encoder: stage k maps widths[k-1] -> widths[k] channels.
    #   down1: (  3, 512, 512) --> ( 16, 512, 512)
    #   down2: ( 16, 256, 256) --> ( 32, 256, 256)
    #   down3: ( 32, 128, 128) --> ( 64, 128, 128)
    #   down4: ( 64,  64,  64) --> (128,  64,  64)
    #   down5: (128,  32,  32) --> (256,  32,  32)
    #   down6: (256,  16,  16) --> (512,  16,  16)
    widths = (3, 16, 32, 64, 128, 256, 512)
    for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:]), start=1):
      setattr(self, f'down{stage}', conv(c_in, c_out))

    # Decoder: input is the upsampled previous output concatenated with the
    # skip connection, output width equals the skip width.
    #   up1: (768,  32,  32) --> (256,  32,  32)
    #   up2: (384,  64,  64) --> (128,  64,  64)
    #   up3: (192, 128, 128) --> ( 64, 128, 128)
    #   up4: ( 96, 256, 256) --> ( 32, 256, 256)
    #   up5: ( 48, 512, 512) --> ( 16, 512, 512)
    carried = widths[-1]
    for stage, skip in enumerate(reversed(widths[1:-1]), start=1):
      setattr(self, f'up{stage}', conv(carried + skip, skip))
      carried = skip

    self.tail = nn.Conv2d(carried, 1, 1)
    self.downpool = nn.MaxPool2d(kernel_size=2)
    self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

  def forward(self, x):
    """Return the logit map for batch `x`; output has 1 channel at input resolution."""
    skips = []
    features = self.down1(x)
    # Contracting path: remember each stage's output before pooling it away.
    for encoder in (self.down2, self.down3, self.down4, self.down5, self.down6):
      skips.append(features)
      features = encoder(self.downpool(features))
    # Expanding path: pop skips in reverse so resolutions match.
    for decoder in (self.up1, self.up2, self.up3, self.up4, self.up5):
      merged = torch.cat([self.upsample(features), skips.pop()], dim=1)
      features = decoder(merged)
    return self.tail(features)
  
class UNet256(nn.Module):
  """U-Net style encoder/decoder producing a single-channel logit map.

  Six `conv` encoder stages are separated by 2x2 max-pooling. Five decoder
  stages each upsample bilinearly by 2 and concatenate the encoder feature
  map of matching resolution before convolving. A final 1x1 convolution maps
  16 channels to 1. No sigmoid is applied here.

  Shape comments assume a (3, 256, 256) input. Five pooling steps followed by
  five 2x upsamplings mean height and width must be multiples of 32 for the
  skip concatenations to line up.
  """

  def __init__(self):
    super(UNet256, self).__init__()
    # Encoder: stage k maps widths[k-1] -> widths[k] channels.
    #   down1: (  3, 256, 256) --> ( 16, 256, 256)
    #   down2: ( 16, 128, 128) --> ( 32, 128, 128)
    #   down3: ( 32,  64,  64) --> ( 64,  64,  64)
    #   down4: ( 64,  32,  32) --> (128,  32,  32)
    #   down5: (128,  16,  16) --> (256,  16,  16)
    #   down6: (256,   8,   8) --> (512,   8,   8)
    widths = (3, 16, 32, 64, 128, 256, 512)
    for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:]), start=1):
      setattr(self, f'down{stage}', conv(c_in, c_out))

    # Decoder: input is the upsampled previous output concatenated with the
    # skip connection, output width equals the skip width.
    #   up1: (768,  16,  16) --> (256,  16,  16)
    #   up2: (384,  32,  32) --> (128,  32,  32)
    #   up3: (192,  64,  64) --> ( 64,  64,  64)
    #   up4: ( 96, 128, 128) --> ( 32, 128, 128)
    #   up5: ( 48, 256, 256) --> ( 16, 256, 256)
    carried = widths[-1]
    for stage, skip in enumerate(reversed(widths[1:-1]), start=1):
      setattr(self, f'up{stage}', conv(carried + skip, skip))
      carried = skip

    self.tail = nn.Conv2d(carried, 1, 1)
    self.downpool = nn.MaxPool2d(kernel_size=2)
    self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

  def forward(self, x):
    """Return the logit map for batch `x`; output has 1 channel at input resolution."""
    skips = []
    features = self.down1(x)
    # Contracting path: remember each stage's output before pooling it away.
    for encoder in (self.down2, self.down3, self.down4, self.down5, self.down6):
      skips.append(features)
      features = encoder(self.downpool(features))
    # Expanding path: pop skips in reverse so resolutions match.
    for decoder in (self.up1, self.up2, self.up3, self.up4, self.up5):
      merged = torch.cat([self.upsample(features), skips.pop()], dim=1)
      features = decoder(merged)
    return self.tail(features)

class WarickData(Dataset):
  """Image / annotation pairs discovered through a glob over annotation files.

  Every matched file whose base name splits into exactly three '_'-separated
  parts (e.g. `train_1_anno.bmp`) yields the pair
  `<dir>/<a>_<b>.bmp`, `<dir>/<a>_<b>_anno.bmp`, where `<dir>` is the matched
  file's own directory.

  Args:
    img_glob: glob pattern selecting the annotation files.
    img_size: side length both image and mask are resized to.
    augment: when True (default) each item is independently flipped
      horizontally and vertically with probability 0.5, using the `random`
      module. Pass False for a deterministic view, e.g. for evaluation.
  """

  def __init__(self, img_glob, img_size=256, augment=True):
    pairs = []
    # glob returns matches in arbitrary, filesystem-dependent order; sorting
    # makes item indices (and therefore seeded shuffles) reproducible.
    for file in sorted(glob(img_glob)):
      # Split off the real parent directory instead of the first path
      # component, so nested or absolute patterns resolve correctly.
      directory, filename = os.path.split(file)
      parts = filename.split('_')
      if len(parts) == 3:
        stem = f'{parts[0]}_{parts[1]}'
        img_file = os.path.join(directory, f'{stem}.bmp')
        mask_file = os.path.join(directory, f'{stem}_anno.bmp')
        pairs.append((img_file, mask_file))
    self.files = pairs
    # Per-channel statistics passed to F_img.normalize in __getitem__.
    # NOTE(review): these look like 0-255 scale values, while to_tensor
    # rescales 8-bit images to [0, 1] before normalize runs. The plotting cell
    # inverts the transform with the same constants, so changing the scale
    # here requires changing it there as well — confirm intent.
    self.img_mean, self.img_std = [200.248, 131.253, 199.778], [41.787, 62.667, 32.977]
    # Stored but not used anywhere in this class.
    self.mask_mean, self.mask_std = 2.512, 4.168
    self.img_size = img_size
    self.augment = augment

  def __len__(self):
    """Number of image/mask pairs found."""
    return len(self.files)

  def __getitem__(self, idx):
    """Return `(image_tensor, mask_tensor)` for pair `idx`.

    The image is normalised with `img_mean` / `img_std`; the mask is converted
    to grayscale and returned as produced by `to_tensor`.
    """
    img_path, mask_path = self.files[idx]
    img = Image.open(img_path)
    img_mask = F_img.to_grayscale(Image.open(mask_path))

    if self.augment:
      # The same flip is applied to image and mask so they stay aligned.
      if random.random() > .5:
        img = F_img.hflip(img)
        img_mask = F_img.hflip(img_mask)
      if random.random() > .5:
        img = F_img.vflip(img)
        img_mask = F_img.vflip(img_mask)

    target_size = [self.img_size, self.img_size]
    img = F_img.resize(img, target_size)
    # NOTE(review): the mask is resized with the default interpolation and is
    # not binarised; presumably annotation pixels are label ids — confirm the
    # target the loss expects.
    img_mask = F_img.resize(img_mask, target_size)
    img_tensor = F_img.normalize(F_img.to_tensor(img), self.img_mean, self.img_std)
    return img_tensor, F_img.to_tensor(img_mask)

In [0]:
def set_seed(seed):
  """Seed Python's `random`, NumPy, and PyTorch (CPU and all CUDA devices)."""
  seeders = (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
  )
  for seeder in seeders:
    seeder(seed)

def plot_metrics(metric_logs, ylim=None):
  """Plot one line per metric log on the current pyplot figure and show it.

  Args:
    metric_logs: iterable of mappings with keys 'epoch' (x values),
      'metric' (y values) and 'label' (legend entry).
    ylim: optional sequence unpacked into `plt.ylim`, e.g. `(low, high)`.
  """
  for series in metric_logs:
    epochs, values = series['epoch'], series['metric']
    plt.plot(epochs, values, label=series['label'])
  plt.xlabel('Epochs')
  plt.ylabel('Metric')
  if ylim is not None:
    plt.ylim(*ylim)
  plt.legend()
  plt.show()

def dice(pred, targs):
    """Dice coefficient between thresholded predictions and targets.

    Args:
        pred: tensor of raw scores; values > 0 count as foreground.
        targs: tensor of the same shape holding the target mask.

    Returns:
        A scalar tensor `2 * |pred ∩ targs| / (|pred| + |targs|)`. When both
        the thresholded prediction and the target sum to zero the result is 1
        (perfect agreement) instead of the NaN produced by 0 / 0.
    """
    pred = (pred > 0).float()
    denominator = (pred + targs).sum()
    if denominator == 0:
        # Nothing predicted and nothing to find: avoid 0 / 0.
        return torch.ones_like(denominator)
    return 2. * (pred * targs).sum() / denominator
  
def dice_loss(input, target):
    """Smoothed soft Dice loss, `1 - (2*I + 1) / (sum(input) + sum(target) + 1)`.

    Args:
        input: tensor of predictions (the callers in this notebook pass
            sigmoid outputs).
        target: tensor with the same number of elements as `input`.

    Returns:
        A scalar tensor; 0 when `input` and `target` are identical binary
        masks. The smoothing term of 1 keeps the ratio defined when both
        tensors are all zeros.
    """
    smooth = 1.

    # reshape (not view) so non-contiguous tensors, e.g. transposed or sliced
    # ones, are accepted; contiguous inputs are still flattened without a copy.
    iflat = input.reshape(-1)
    tflat = target.reshape(-1)
    intersection = (iflat * tflat).sum()

    return 1.0 - (((2. * intersection + smooth) /
              (iflat.sum() + tflat.sum() + smooth)))
  
def train(model, epochs=1):
  """Run the plain training loop for `epochs` epochs and log train/test loss.

  `model` is the notebook's attribute-style configuration object. This
  function reads `net`, `loader`, `loader_test`, `optimizer`, `scheduler`,
  `img_cuda`, `mask_cuda`, `batch_size`, `eval_test` and `epochs_trained`,
  appends to `train_metric_log` / `test_metric_log`, and increments
  `epochs_trained` once per epoch.

  The logged value is the sum of per-batch losses (Dice loss x 100) divided by
  `batch_size` times the number of batches, i.e. the mean batch loss divided
  by `batch_size`.

  Args:
    model: configuration object described above.
    epochs: number of epochs to run.

  Raises:
    ZeroDivisionError: if a loader yields no batches.
  """
  for _ in tqdm.trange(epochs, desc='epochs'):
    metric = 0
    samples_seen = 0
    model.net.train()
    for img, mask in model.loader:
      # Batches are copied into preallocated buffers, so every batch must
      # match the buffer shape (the notebook's loaders use drop_last=True).
      model.img_cuda.copy_(img)
      del img
      model.mask_cuda.copy_(mask)
      del mask
      model.optimizer.zero_grad()
      prediction = model.net(model.img_cuda)
      # torch.sigmoid replaces the deprecated nn.functional.sigmoid.
      loss = dice_loss(torch.sigmoid(prediction), model.mask_cuda) * 100
      metric += loss.item()
      samples_seen += model.batch_size
      loss.backward()
      model.optimizer.step()
      # The learning-rate schedule advances once per batch, not per epoch.
      model.scheduler.step()
    model.train_metric_log['epoch'].append(model.epochs_trained)
    model.train_metric_log['metric'].append(metric / samples_seen)

    # Evaluate on the test loader every `eval_test` epochs.
    if model.epochs_trained % model.eval_test == 0:
      metric = 0
      samples_seen = 0
      model.net.eval()
      with torch.no_grad():
        for img, mask in model.loader_test:
          model.img_cuda.copy_(img)
          del img
          model.mask_cuda.copy_(mask)
          del mask
          prediction = model.net(model.img_cuda)
          loss = dice_loss(torch.sigmoid(prediction), model.mask_cuda) * 100
          metric += loss.item()
          samples_seen += model.batch_size
      model.test_metric_log['epoch'].append(model.epochs_trained)
      model.test_metric_log['metric'].append(metric / samples_seen)
    model.epochs_trained += 1

def generate_noise():
  """Return the first slice of a size=[1, 256, 256] Perlin-noise volume.

  Frequency is drawn uniformly from [.001, .05] and the generator seed from
  0..100000 using the `random` module, so results follow `random.seed`.
  """
  frequency = uniform(.001, .05)
  noise_seed = randint(0, 100000)
  volume = generate(size=[1, 256, 256], noiseType='Perlin',
                    freq=frequency, seed=noise_seed)
  return volume[0]
    
def interpolate(a, b, f):
  """Linear blend of `a` and `b`: returns `a` at f=0 and `b` at f=1.

  Works for scalars and for any operands supporting `*` and `+`, with the
  usual broadcasting rules of those types.
  """
  weight_b = f
  weight_a = 1.0 - weight_b
  return a * weight_a + b * weight_b
    
def train_mixup(model, epochs=1):
  """Train with mixup: each batch is a random scalar blend of two batches.

  Batches are drawn in lockstep from `model.loader` and `model.loader_other`;
  one blend factor from `random.random()` per batch pair is applied to both
  the images and the masks. Evaluation uses unmodified test batches.

  Besides the two loaders, this function reads `net`, `loader_test`,
  `optimizer`, `scheduler`, `img_cuda`, `mask_cuda`, `batch_size`, `eval_test`
  and `epochs_trained`, appends to `train_metric_log` / `test_metric_log`, and
  increments `epochs_trained` once per epoch.

  The logged value is the sum of per-batch losses (Dice loss x 100) divided by
  `batch_size` times the number of batches.

  Args:
    model: the notebook's attribute-style configuration object.
    epochs: number of epochs to run.

  Raises:
    ZeroDivisionError: if a loader yields no batches.
  """
  for _ in tqdm.trange(epochs, desc='epochs'):
    metric = 0
    samples_seen = 0
    model.net.train()
    for (img_a, mask_a), (img_b, mask_b) in zip(model.loader, model.loader_other):
      # One scalar blend factor shared by the images and masks of this batch.
      mixup_lerp = random.random()
      model.img_cuda.copy_(interpolate(img_a, img_b, mixup_lerp))
      del img_a; del img_b
      model.mask_cuda.copy_(interpolate(mask_a, mask_b, mixup_lerp))
      del mask_a; del mask_b
      model.optimizer.zero_grad()
      prediction = model.net(model.img_cuda)
      # torch.sigmoid replaces the deprecated nn.functional.sigmoid.
      loss = dice_loss(torch.sigmoid(prediction), model.mask_cuda) * 100
      metric += loss.item()
      samples_seen += model.batch_size
      loss.backward()
      model.optimizer.step()
      # The learning-rate schedule advances once per batch, not per epoch.
      model.scheduler.step()
    model.train_metric_log['epoch'].append(model.epochs_trained)
    model.train_metric_log['metric'].append(metric / samples_seen)

    # Evaluate on the test loader every `eval_test` epochs.
    if model.epochs_trained % model.eval_test == 0:
      metric = 0
      samples_seen = 0
      model.net.eval()
      with torch.no_grad():
        for img, mask in model.loader_test:
          model.img_cuda.copy_(img)
          del img
          model.mask_cuda.copy_(mask)
          del mask
          prediction = model.net(model.img_cuda)
          loss = dice_loss(torch.sigmoid(prediction), model.mask_cuda) * 100
          metric += loss.item()
          samples_seen += model.batch_size
      model.test_metric_log['epoch'].append(model.epochs_trained)
      model.test_metric_log['metric'].append(metric / samples_seen)
    model.epochs_trained += 1

def train_noisy_mixup(model, epochs=1):
  """Train with spatially varying mixup driven by a noise field.

  Batches are drawn in lockstep from `model.loader` and `model.loader_other`.
  For each batch pair one noise field from `generate_noise()` is used as the
  per-pixel blend factor for both images and masks; the same field is
  broadcast over every sample and channel in the batch. Evaluation uses
  unmodified test batches.

  Besides the two loaders, this function reads `net`, `loader_test`,
  `optimizer`, `scheduler`, `img_cuda`, `mask_cuda`, `batch_size`, `eval_test`
  and `epochs_trained`, appends to `train_metric_log` / `test_metric_log`, and
  increments `epochs_trained` once per epoch.

  The logged value is the sum of per-batch losses (Dice loss x 100) divided by
  `batch_size` times the number of batches.

  Args:
    model: the notebook's attribute-style configuration object.
    epochs: number of epochs to run.

  Raises:
    ZeroDivisionError: if a loader yields no batches.
  """
  for _ in tqdm.trange(epochs, desc='epochs'):
    metric = 0
    samples_seen = 0
    model.net.train()
    for (img_a, mask_a), (img_b, mask_b) in zip(model.loader, model.loader_other):
      # NOTE(review): the noise values are used directly as blend weights; if
      # the generator returns values outside [0, 1] the blend extrapolates
      # rather than interpolates — confirm the generator's output range.
      mixup_lerp = torch.tensor(generate_noise())
      model.img_cuda.copy_(interpolate(img_a, img_b, mixup_lerp))
      del img_a; del img_b
      model.mask_cuda.copy_(interpolate(mask_a, mask_b, mixup_lerp))
      del mask_a; del mask_b
      model.optimizer.zero_grad()
      prediction = model.net(model.img_cuda)
      # torch.sigmoid replaces the deprecated nn.functional.sigmoid.
      loss = dice_loss(torch.sigmoid(prediction), model.mask_cuda) * 100
      metric += loss.item()
      samples_seen += model.batch_size
      loss.backward()
      model.optimizer.step()
      # The learning-rate schedule advances once per batch, not per epoch.
      model.scheduler.step()
    model.train_metric_log['epoch'].append(model.epochs_trained)
    model.train_metric_log['metric'].append(metric / samples_seen)

    # Evaluate on the test loader every `eval_test` epochs.
    if model.epochs_trained % model.eval_test == 0:
      metric = 0
      samples_seen = 0
      model.net.eval()
      with torch.no_grad():
        for img, mask in model.loader_test:
          model.img_cuda.copy_(img)
          del img
          model.mask_cuda.copy_(mask)
          del mask
          prediction = model.net(model.img_cuda)
          loss = dice_loss(torch.sigmoid(prediction), model.mask_cuda) * 100
          metric += loss.item()
          samples_seen += model.batch_size
      model.test_metric_log['epoch'].append(model.epochs_trained)
      model.test_metric_log['metric'].append(metric / samples_seen)
    model.epochs_trained += 1
    
def pickle_history(model, file):
  """Pickle the train and test metric logs of `model` to `file`.

  Args:
    model: mapping providing 'train_metric_log' and 'test_metric_log'.
    file: path of the output file; it is created or overwritten.

  Raises:
    KeyError: if either log is missing; raised before the file is opened.
  """
  history = {
    'train_metric_log': model['train_metric_log'],
    'test_metric_log': model['test_metric_log'],
  }
  # The context manager closes (and flushes) the handle even if dump fails.
  with open(file, 'wb') as handle:
    pickle.dump(history, handle)

In [0]:
# Datasets are built from the annotation files; WarickData derives the image
# path from each annotation path.
train_data, test_data = WarickData('warick_data/train*anno.bmp'), WarickData('warick_data/test*anno.bmp')

# Single attribute-style container holding everything the train functions read.
model = DottedDict()
model['seed'] = 53
# Seed before the network is constructed so weight initialisation is covered.
set_seed(model.seed)
model['batch_size'] = 21
model['net'] = UNet256().cuda()
model['optimizer'] = optim.Adam(model.net.parameters(), lr=0.05)
# Cosine annealing down to .00005. The scheduler is stepped once per batch by
# the train functions, so T_max is expressed in batches:
# len(train_data) / batch_size batches per epoch times 1000.
# NOTE(review): the expression is a float, and 1000 is presumably the planned
# number of epochs while the training cell below runs 200 — confirm.
model['scheduler'] = optim.lr_scheduler.CosineAnnealingLR(model.optimizer, len(train_data) * 1000 / model.batch_size, .00005)
# Third positional DataLoader argument is `shuffle`.
model['loader'] = DataLoader(train_data, model.batch_size, True, drop_last=True, pin_memory=True)
# Second, independently shuffled loader over the same training set; it is
# zipped with `loader` by train_mixup and train_noisy_mixup.
model['loader_other'] = DataLoader(train_data, model.batch_size, True, drop_last=True, pin_memory=True)
# drop_last=True also on the test set: batches are copied into the fixed-size
# buffers below, so a smaller final batch is skipped rather than evaluated.
model['loader_test'] = DataLoader(test_data, model.batch_size, False, drop_last=True, pin_memory=True)
# Preallocated GPU buffers reused for every batch (train and test).
model['img_cuda'] = torch.empty([model.batch_size, 3, 256, 256]).cuda()
model['mask_cuda'] = torch.empty([model.batch_size, 1, 256, 256]).cuda()
model['epochs_trained'] = 0
# Evaluate on the test loader whenever epochs_trained % eval_test == 0.
model['eval_test'] = 1
# Each log holds parallel 'epoch' / 'metric' lists plus the legend label used
# by plot_metrics.
model['train_metric_log'] = {'label':'Train Loss', 'epoch':[], 'metric':[]}
model['test_metric_log'] = {'label':'Test Loss', 'epoch':[], 'metric':[]}

In [0]:
# Run 200 epochs of the plain (non-mixup) training loop; metric logs and
# model.epochs_trained are updated in place on `model`.
train(model, 200)

In [0]:
# Qualitative check: overlay the target mask and the predicted map on the
# first image of one batch.
with torch.no_grad():
  model.net.eval()
  # The batch comes from the *training* loader, so it is shuffled and goes
  # through WarickData's random flips.
  batch = next(iter(model.loader))

  # Raw network output; no sigmoid is applied before plotting.
  predicted_masks = model.net(batch[0].cuda()).cpu()
  # Undo F_img.normalize for display: move channels last, then multiply by the
  # per-channel std and add the mean. The constants duplicate
  # WarickData.img_std / img_mean and must be kept in sync with them.
  img = (np.moveaxis(batch[0][0].numpy(), 0, 2) * 
         np.array([41.787, 62.667, 32.977]) + 
         np.array([200.248, 131.253, 199.778]))
  mask, mask_prediction = batch[1][0].numpy(), predicted_masks[0].numpy()

# Figure 1: image with the target mask overlaid.
plt.imshow(img)
plt.imshow(mask[0], cmap='inferno', alpha=.5)
plt.show()
# Figure 2: the same image with the predicted map overlaid.
plt.imshow(img)
plt.imshow(mask_prediction[0], cmap='inferno', alpha=.5)
plt.show()

In [0]:
# Train vs. test loss curves, with the y-axis limited to the range 0..3.
plot_metrics([model.train_metric_log, model.test_metric_log], (0, 3))

In [0]:
# Colab-only helper used to download files to the local machine; the model
# download cell below reuses this import.
from google.colab import files

# File name records the run type, seed, resolution and number of epochs.
hist_file = f'standard_seed53_res256_epoch{model.epochs_trained}.hist'

# pickle_history indexes its argument with string keys, so the DottedDict is
# converted to plain Python containers first.
pickle_history(model.to_python(), hist_file)
files.download(hist_file)

In [0]:
# Relies on `files` imported in the previous cell.
model_file = f'standard_seed53_res256_epoch{model.epochs_trained}.model'

# NOTE(review): this serialises the entire configuration object (network,
# optimizer, scheduler, data loaders, GPU buffers and logs), not just the
# network weights; loading it back presumably needs the same class
# definitions and a CUDA device — consider saving state_dict()s instead.
torch.save(model.to_python(), model_file)
files.download(model_file)