In [None]:
# Mount Google Drive at /content/drive; the dataset directories used further
# down (`rgbs_dir`, `masks_dir`) live under this mount point. Colab only.
from google.colab import drive
drive.mount('/content/drive')

**Set-up**

In [None]:
!pip install torchgeometry

import cv2
import matplotlib.pyplot as plt
import numpy as np
import os
from PIL import Image
import random
from sklearn import model_selection
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchgeometry.losses import dice_loss
import torchvision
import torchvision.transforms.functional as TF

In [None]:
# The original paper set aside images which came from the following labs
# as their test data. We follow their example.
testing_sites = ['OL', 'LL', 'C8', 'BH', 'AR', 'A7', 'A1']

# The dataset is loaded from google drive
rgbs_dir = '/content/drive/MyDrive/0_Public-data-Amgad2019_0.25MPP/rgbs_colorNormalized'
masks_dir = '/content/drive/MyDrive/0_Public-data-Amgad2019_0.25MPP/masks'

# Splitting training and testing data
training_names = []
test_names = []

# The rgbs and masks have the exact same filenames, so we do not need to treat
# them separately. os.listdir returns entries in arbitrary order, so the
# listing is sorted first: otherwise the seeded train/validation split below
# would depend on the order in which the filesystem happens to report files.
for filename in sorted(os.listdir(rgbs_dir)):
    # The lab code is the two-character slice filename[5:7].
    if filename[5:7] in testing_sites:
      test_names.append(filename)
    else:
      training_names.append(filename)

# Number of non-test images used for validation and for training
val_size = 22
train_size = 60

# Fixed random_state so the same images land in each subset on every run
train_names, val_names = model_selection.train_test_split(training_names,
                                                          train_size = train_size,
                                                          test_size = val_size,
                                                          random_state = 7)

**Model**

In [None]:
# DeepLabV3 on a ResNet-50 backbone, with a 5-class output head and a
# pretrained backbone.
model = torchvision.models.segmentation.deeplabv3_resnet50(
    pretrained_backbone = True,
    num_classes = 5,
)
device = torch.device("cuda:0")
model.to(device)

# For finetuning we freeze the encoder, which is the first child module of the
# DeepLab model (the decoder comes after it), so only the decoder is trained.
encoder = next(model.children())
for weight in encoder.parameters():
  weight.requires_grad = False

**Auxiliary Code**

In [None]:
# The masks in the dataset are annotated with labels from 1 to 21. The original
# paper specifies how to merge these labels into the five classes used for
# our purposes.

def change_class_vals(mask):
  """Remap the raw labels 1-21 in `mask` to the class ids 0-4, in place.

  `mask` is any array-like supporting boolean-mask assignment
  (`mask[mask == v] = w`). It is modified in place and also returned.
  Values outside 1-21 are left unchanged.
  """
  # Raw labels grouped by target class; every raw label not listed goes to 4.
  raw_labels_by_class = {
      0: (1, 19, 20),
      1: (2,),
      2: (3, 10, 11, 14),
      3: (4,),
  }
  target_of = {raw: cls
               for cls, raws in raw_labels_by_class.items()
               for raw in raws}

  # Raw labels are visited in ascending order. Each target is smaller than the
  # raw label being rewritten, so no pixel is ever remapped a second time.
  for raw in range(1, 22):
    mask[mask == raw] = target_of.get(raw, 4)

  return mask


In [None]:
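# We ran this code once to determine the per-channel mean and std of our
# training data (all non-test images, i.e. `training_names`). It is kept below
# inside a string literal so that it does not execute on a normal run.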

'''
class CalcForNorm(torch.utils.data.Dataset):

    def __init__(self, filenames):

        self.filenames = filenames
    
    def __len__(self):

        return len(self.filenames) 
  
    def __getitem__(self, index):

        rgb_name = os.path.join(rgbs_dir, self.filenames[index])
        rgb = TF.to_tensor(Image.open(rgb_name))

        return rgb
    
dataset = CalcForNorm(training_names)
data_loader = torch.utils.data.DataLoader(dataset, 
                                          batch_size=1, 
                                          num_workers=1, 
                                          shuffle=False)

channel_sums = [0] * 3
channel_squared_sums = [0] * 3
total_pixels = 0

for rgb in data_loader:
  rgb = rgb.squeeze(0)
  total_pixels += rgb.shape[1] * rgb.shape[2]

  for i in range(3):
    channel_sums[i] += torch.sum(rgb[i,:,:])
    channel_squared_sums[i] += torch.sum(rgb[i,:,:] ** 2)

means = [0] * 3
stds = [0] * 3

for i in range(3):
   means[i] = channel_sums[i] / total_pixels
   stds[i] = (channel_squared_sums[i] / total_pixels - means[i]**2) ** 0.5

print(means, stds)
'''

In [None]:
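# We ran this code once to determine the class weights for our cross
# entropy loss function. Following a formula in the original paper, class c
# is weighted by 1 - (pixels of class c) / (total pixels), counted over the
# masks in `train_names`, so less common classes receive more weight. The code
# is kept below inside a string literal so that it does not execute on a
# normal run.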

'''
weights = [0] * 5
# N is the total amount of pixels in the dataset
N = 0

# Find the total number of pixels of each class in the dataset
for filename in train_names:
  mask = cv2.imread(os.path.join(masks_dir, filename))[:,:,0]
  mask = change_class_vals(mask)
  N += np.size(mask)
  weights[0] += np.size(mask[mask == 0])
  weights[1] += np.size(mask[mask == 1])
  weights[2] += np.size(mask[mask == 2])
  weights[3] += np.size(mask[mask == 3])
  weights[4] += np.size(mask[mask == 4])

for i in range(len(weights)):
  weights[i] = 1 - weights[i] / N

print(weights)
'''

**Loss and Optimizer**

In [None]:
# Precomputed per-class weights for the cross entropy loss, indexed by class id
weights = [0.4905573355962124, 0.695194226695798, 0.9165659564679096, 0.9391179214116163, 0.9585645598284636]

# Weighted cross entropy; the weight tensor must live on the model's device.
# (Unweighted alternative: nn.CrossEntropyLoss())
class_weights = torch.tensor(weights, dtype = torch.float32).to(device)
criterion = nn.CrossEntropyLoss(weight = class_weights)

optimizer = torch.optim.Adam(params = model.parameters(), lr = 1e-4)

# Optional plateau scheduler, currently disabled (the matching
# scheduler.step call in the training loop is commented out as well):
#scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
#                                                      threshold = 0.01,
#                                                      patience = 3)

**Data**

In [None]:
# Channel means and stds measured on our own dataset (RGB order). Use these
# for normalization when training from scratch:
# means = [0.7395, 0.5253, 0.7026]
# stds = [0.1912, 0.2321, 0.1716]

# Channel means and stds of the dataset used to pretrain our models (the
# standard ImageNet values). Used for normalization when fine-tuning, so the
# pretrained backbone sees inputs scaled the way it was trained on. Both lists
# are read by MakeDataset via TF.normalize.
means = [0.485, 0.456, 0.406]
stds = [0.229, 0.224, 0.225]

class MakeDataset(Dataset):
  """Paired RGB image / class mask dataset read from `rgbs_dir` and `masks_dir`.

  Each sample is a dict with a normalized 'rgb' tensor and a 'mask' tensor of
  class ids (remapped by `change_class_vals`, dtype long). Training samples
  get a random 768 x 768 crop plus random flips; all other samples get a
  768 x 768 center crop.
  """

  def __init__(self, filenames, training):
    """Keep the list of file names and whether augmentation is applied.

    filenames: file names shared by the rgb and mask directories.
    training: True to use the augmenting transform, False for center crops.
    """
    self.filenames = filenames
    self.training = training

  def __len__(self):
    """Number of samples, one per file name."""
    return len(self.filenames)

  def transform_train(self, rgb, mask):
    """Apply the same random crop and flips to an image and its mask.

    Returns the normalized rgb tensor and the mask tensor.
    """
    # One set of crop parameters is drawn and shared, so the image and its
    # mask stay pixel-aligned.
    crop_box = torchvision.transforms.RandomCrop.get_params(
        rgb, output_size = (768, 768))
    rgb = TF.crop(rgb, *crop_box)
    mask = TF.crop(mask, *crop_box)

    # Horizontal flip first, then vertical flip, each with probability 0.5
    for flip in (TF.hflip, TF.vflip):
      if random.random() > 0.5:
        rgb = flip(rgb)
        mask = flip(mask)

    # Only the rgb is normalized; the mask holds labels, not intensities
    rgb = TF.normalize(TF.to_tensor(rgb), means, stds)
    return rgb, TF.to_tensor(mask)

  def __getitem__(self, index):
    """Load, crop and convert the sample at `index`.

    Returns a dict {'rgb': image tensor, 'mask': class id tensor}.
    """
    name = self.filenames[index]
    rgb = Image.open(os.path.join(rgbs_dir, name))
    mask = Image.open(os.path.join(masks_dir, name)).convert('I;16')

    if self.training:
      rgb, mask = self.transform_train(rgb, mask)
    else:
      crop_size = (768, 768)
      rgb = TF.to_tensor(TF.center_crop(rgb, output_size = crop_size))
      rgb = TF.normalize(rgb, means, stds)
      mask = TF.to_tensor(TF.center_crop(mask, output_size = crop_size))

    # Drop the leading channel axis, then merge the raw labels into classes
    mask = change_class_vals(mask.squeeze(0)).long()

    return {'rgb': rgb, 'mask': mask}

# Training samples are augmented; validation samples are only center-cropped
train_data = MakeDataset(train_names, training = True)
val_data = MakeDataset(val_names, training = False)

batch_size = 12

# Only the training loader is shuffled. The validation metrics are averaged
# over the whole set, so shuffling it adds nothing and only makes the batches
# differ from epoch to epoch.
train_loader = DataLoader(dataset = train_data, batch_size = batch_size,
                          shuffle = True, num_workers = 2)
val_loader = DataLoader(dataset = val_data, batch_size = batch_size,
                        shuffle = False, num_workers = 2)

**Training**

In [None]:
def pixel_acc(pred, truth):
  """Sum, over the batch, of each sample's fraction of correctly labelled pixels.

  pred: class scores with the class axis at dim 1 (batch, classes, ...).
  truth: integer class ids with one entry per predicted pixel.

  Returns a 0-dim tensor. Divide it by the number of samples to get the mean
  pixelwise accuracy.
  """
  # argmax over raw scores gives the predicted class directly; a softmax
  # first would not change the winner and only allocates another tensor.
  labels = torch.argmax(pred, dim = 1)
  truth = truth.reshape(labels.shape)

  # Per-sample accuracy: mean of the correct/incorrect indicator
  per_sample = (labels == truth).flatten(start_dim = 1).float().mean(dim = 1)

  return per_sample.sum()

In [None]:
epochs = 40

# Per-epoch average losses
train_running_loss_history = []
validation_running_loss_history = []

# Divisors for the per-epoch averages, taken from the datasets themselves
n_train_samples = len(train_loader.dataset)
n_val_samples = len(val_loader.dataset)

for epoch in range(epochs):

  # We track the loss, the dice loss, and the pixelwise accuracy for each
  # epoch. Every accumulator is a plain Python float: accumulating tensors
  # would keep each batch's autograd graph alive until the end of the epoch.
  train_loss_running = 0.0
  val_loss_running = 0.0
  train_dice_running = 0.0
  val_dice_running = 0.0
  train_pixel_running = 0.0
  val_pixel_running = 0.0

  model.train()

  for batch in train_loader:
    x_train = batch['rgb'].to(device)
    y_train = batch['mask'].to(device)

    optimizer.zero_grad()
    y_pred = model(x_train)['out']

    # Alternative objective:
    # loss = 0.3 * dice_loss(y_pred, y_train) + 0.7 * criterion(y_pred, y_train)
    loss = criterion(y_pred, y_train)
    loss.backward()
    optimizer.step()

    n_in_batch = len(y_pred)
    train_loss_running += loss.item() * n_in_batch

    # The remaining metrics are for monitoring only, so they are computed on
    # detached predictions with gradient tracking switched off.
    with torch.no_grad():
      y_pred = y_pred.detach()
      train_dice_running += float(dice_loss(y_pred, y_train)) * n_in_batch
      train_pixel_running += float(pixel_acc(y_pred, y_train))

  model.eval()

  with torch.no_grad():

    for sample_batched in val_loader:
      x_val = sample_batched['rgb'].to(device)
      y_val = sample_batched['mask'].to(device)

      y_pred = model(x_val)['out']
      n_in_batch = len(y_pred)

      # Alternative objective:
      # 0.3 * dice_loss(y_pred, y_val) + 0.7 * criterion(y_pred, y_val)
      val_loss_running += criterion(y_pred, y_val).item() * n_in_batch
      val_dice_running += float(dice_loss(y_pred, y_val)) * n_in_batch
      val_pixel_running += float(pixel_acc(y_pred, y_val))

  # Losses are weighted by batch size above, and pixel_acc returns a sum over
  # the batch, so dividing by the sample count gives per-sample means.
  train_loss = train_loss_running / n_train_samples
  val_loss = val_loss_running / n_val_samples
  train_dice_loss = train_dice_running / n_train_samples
  val_dice_loss = val_dice_running / n_val_samples
  train_pixel_acc = train_pixel_running / n_train_samples
  val_pixel_acc = val_pixel_running / n_val_samples

  print("================================================================================")
  print("Epoch {} completed".format(epoch + 1))
  print("Training loss: {}".format(train_loss))
  print("Training DICE loss: {}".format(train_dice_loss))
  print("Training Pixelwise Accuracy: {}".format(train_pixel_acc))
  print("Validation loss: {}".format(val_loss))
  print("Validation DICE loss: {}".format(val_dice_loss))
  print("Validation Pixelwise Accuracy: {}".format(val_pixel_acc))
  print("================================================================================")
  train_running_loss_history.append(train_loss)
  validation_running_loss_history.append(val_loss)

  #scheduler.step(train_loss)

  torch.cuda.empty_cache()

**Evaluation**

In [None]:
def pixel_class_acc(pred, truth, n):
  """Per-sample accuracy on the pixels of class `n`, summed over the batch.

  pred: class scores with the class axis at dim 1 (batch, classes, ...).
  truth: integer class ids with one entry per predicted pixel.
  n: the class id to evaluate.

  For every sample that contains at least one pixel of class `n`, the
  fraction of those pixels that were predicted as `n` is added up.

  Returns (acc, num_cases): the summed fractions as a 0-dim tensor, and the
  number of samples that contained class `n` as an int. Samples without the
  class are discounted, so acc / num_cases is the mean over the samples where
  the class is present.
  """
  # argmax over raw scores gives the predicted class directly; a softmax
  # first would not change the winner.
  labels = torch.argmax(pred, dim = 1)
  truth = truth.reshape(labels.shape)

  # One row per sample: which pixels belong to class n, and which of those
  # were predicted as n
  in_class = (truth == n).flatten(start_dim = 1)
  predicted_n = (labels == n).flatten(start_dim = 1)
  hits = (in_class & predicted_n).sum(dim = 1)
  class_sizes = in_class.sum(dim = 1)

  # If there are no pixels of the given class in a mask, that mask is left
  # out of both the sum and the case count
  present = class_sizes > 0
  acc = (hits[present] / class_sizes[present]).sum()
  num_cases = int(present.sum())

  return (acc, num_cases)

In [None]:
test_data = MakeDataset(test_names, training = False)
# Taken from the dataset rather than hard-coded, so the averages below stay
# correct if the set of test files changes.
test_size = len(test_data)

# Evaluation order does not matter, so the loader is not shuffled
test_loader = DataLoader(dataset = test_data, batch_size = batch_size,
                         shuffle = False, num_workers = 2)

num_classes = 5

# Accumulators are plain Python numbers
test_dice_running = 0.0
test_pixel_running = 0.0
class_pixel_running = [0.0] * num_classes
num_cases_per_class = [0] * num_classes

model.eval()

with torch.no_grad():

  for sample_batched in test_loader:
      x_test = sample_batched['rgb'].to(device)
      y_test = sample_batched['mask'].to(device)

      y_pred = model(x_test)['out']

      # dice_loss is weighted by batch size; pixel_acc is already a batch sum
      test_dice_running += float(dice_loss(y_pred, y_test)) * len(y_pred)
      test_pixel_running += float(pixel_acc(y_pred, y_test))
      for i in range(num_classes):
        acc, num_cases = pixel_class_acc(y_pred, y_test, i)
        class_pixel_running[i] += float(acc)
        num_cases_per_class[i] += num_cases

print("DICE Score: {}".format(1 - (test_dice_running / test_size)))
print("Pixelwise Accuracy: {}".format(test_pixel_running / test_size))
for i in range(num_classes):
  # A class that never occurs in the test masks has no defined accuracy;
  # report nan for it instead of dividing by zero.
  if num_cases_per_class[i] > 0:
    class_acc = class_pixel_running[i] / num_cases_per_class[i]
  else:
    class_acc = float('nan')
  print("Pixelwise Accuracy for Class {}: {}".format(i, class_acc))