# **LOCAL ATTRIBUTE SEGMENTATION IN DERMOSCOPIC IMAGES**

## Introduction

Skin cancer is one of the most common types of cancer in the world. Early detection through visual analysis of skin lesions is the critical step for preventing and successfully treating it. Dermoscopy is the technique of visually analyzing superficial skin lesions for medical evaluation: it studies the color and shape of the mole, as well as the presence and typicality of certain local structures, such as **milia-like cysts, negative/inverted networks, normal/pigmented networks, streaks and globules**.



Thanks to the increasing capabilities of machine learning algorithms, medical diagnosis can nowadays be performed by convolutional neural networks (CNNs), which achieve state-of-the-art performance in pattern recognition tasks.


![alt txt](https://github.com/CesarCaramazana/DermoscopicSegmentation/blob/main/images/structures.PNG?raw=true)

Image source: https://dermoscopedia.org


The objective of this notebook is the design and implementation of a semantic segmentation algorithm based on a Fully Convolutional Network (FCN) that identifies and classifies at pixel level the local structures (cysts, networks, streaks and globules) present in a pigmented skin lesion. 

## Model proposal: U-net (Inception v3) + ASPP

From the beginning, **U-Net** was considered as the starting point for our proposal. The original architecture was modified to fit the requirements of the ISIC dataset (number of output classes and input resolution), and ResNet-101, pre-trained on ImageNet, was used as the backbone. Additionally, an **Atrous Spatial Pyramid Pooling (ASPP) block**, from DeepLab v3, with dilation rates 2, 3 and 4 was attached to the last volume of the encoder, which slightly improved the results.
The last modification was the replacement of ResNet-101 with **Inception v3**, which parallelizes convolutions and significantly reduces the number of parameters.

The following figure shows the architecture, which is implemented in the architecture subsection below:

![alt txt](https://github.com/CesarCaramazana/DermoscopicSegmentation/blob/main/images/unet_inception.PNG?raw=true)
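
To make the role of the dilation rates concrete, the following minimal sketch computes the effective kernel size of a 3x3 convolution at rates 2, 3 and 4, together with the padding that keeps the spatial size unchanged at stride 1. The helper names `effective_kernel` and `same_padding` are illustrative and are not part of the notebook.

```python
def effective_kernel(kernel_size: int, dilation: int) -> int:
    """Span, in pixels, covered by a dilated kernel along one axis."""
    return dilation * (kernel_size - 1) + 1


def same_padding(kernel_size: int, dilation: int) -> int:
    """Padding that preserves the spatial size at stride 1 (odd kernels only)."""
    return (effective_kernel(kernel_size, dilation) - 1) // 2


for rate in (2, 3, 4):
    # Prints the rate, the effective kernel size and the required padding
    print(rate, effective_kernel(3, rate), same_padding(3, rate))
```

For a 3x3 kernel the required padding equals the dilation rate, so each branch sees a wider context (5, 7 and 9 pixels) without changing the size of the feature map.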

### Import libraries

In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np
import argparse
import collections
import re

from tqdm import tqdm

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import models
from torch.autograd import Function
from torch.autograd import Variable
import torch.cuda.amp as amp

import statistics as stats
import random
import scipy.io
from PIL import Image
import cv2 as cv


from sklearn.metrics import roc_auc_score
from sklearn.metrics import roc_curve
from sklearn.metrics import auc
from sklearn.metrics import confusion_matrix


### Mount Google Drive

I have stored the dataset in my personal Drive account, but the ISIC 2017 database is available here: https://challenge.isic-archive.com/data

It should be downloaded and unzipped into a folder named "db_isic", which contains subfolders with the data and text files (.txt) with the paths.

In [None]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

import os
os.chdir('/content/drive/My Drive/') #Move into Drive in order to read the necessary files (such as the dataset).

In [None]:
#We haven't used Tensorboard in this project due to an error that occurred when we first tried to launch it with Colab.
#The problem, now solved, was that Tensorboard doesn't work if third-party cookies are blocked!

%load_ext tensorboard

import tensorflow as tf
import datetime, os
from torch.utils.tensorboard import SummaryWriter

log_dir = "runs/outputs"
writer = SummaryWriter(log_dir=log_dir)

In [None]:
#Use GPU if available. (Runtime -> Change runtime type -> GPU)
use_gpu = torch.cuda.is_available()
device = torch.device("cuda:0" if use_gpu else "cpu")

### Auxiliary Functions

We define some auxiliary functions that are needed during various stages of the pipeline. 

In [None]:
def one_hot(mask):
  """
    Ground truth masks are not one-hot encoded in the original ISIC DB.
      input: [h, w] mask with integer labels in [0, 6] (0 = background).
      output: [h, w, 6] binary masks (channels last, so that transforms.ToTensor() yields [6, h, w]).
  """
  
  new = np.zeros((6,mask.shape[0],mask.shape[1]))
  new[0, (mask==1)] = 1 #Others
  new[1, (mask==2)] = 1 #Cysts  
  new[2, (mask==3)] = 1 #Neg network 
  new[3, (mask==4)] = 1 #Pigm network
  new[4, (mask==5)] = 1 #Streaks
  new[5, (mask==6)] = 1 #Globules
  
  new = new.transpose(1,2,0)
  return new

In [None]:
def lesion(mask):
  """
    This function generates a binary mask that separates background pixels (healthy skin) and the lesion.
    The mask is used to eliminate the maximum amount of background pixels when reading an image from the dataset.
  
      input: mask(h,w), indices[0,6]: Bg, Others, Cysts...
      output: mask'(h,w), indices[0,1]: Bg, Fg 
  """

  #Every label from 1 to 6 belongs to the lesion; label 0 is healthy skin.
  new = np.zeros(mask.shape)
  new[(mask >= 1) & (mask <= 6)] = 1

  return new

In [None]:
def get_segment_crop(img,tol=0, mask=None):
  """
  This function crops an image to the bounding box of a background-foreground binary mask.
    inputs: image [h, w] or [h, w, c]
             mask [h, w]
    output: image [h', w'] or [h', w', c], where h'w' <= hw
  """
  if mask is None:
    mask = img > tol
  
  return img[np.ix_(mask.any(1), mask.any(0))]
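
As a small illustration of the cropping step, the sketch below applies the same `np.ix_` indexing to a toy image. The array sizes and the position of the lesion are made up for the example.

```python
import numpy as np

image = np.arange(5 * 6 * 3).reshape(5, 6, 3)  # toy [h, w, c] image
fg = np.zeros((5, 6), dtype=bool)
fg[1:4, 2:5] = True  # hypothetical lesion: rows 1-3, columns 2-4

rows = fg.any(axis=1)  # True for every row that contains lesion pixels
cols = fg.any(axis=0)  # True for every column that contains lesion pixels

# np.ix_ builds an open mesh, so only the selected rows and columns are kept
crop = image[np.ix_(rows, cols)]
print(crop.shape)  # (3, 3, 3)
```

The channel axis is untouched because only the first two axes are indexed, which is why the same call works for the RGB image and for the 2D masks.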

In [None]:
def palette(mask):
  """
    This function defines a color palette for the plots, instead of using the default one.
  """

  red = np.zeros(mask.shape) #Red channel
  green = np.zeros(mask.shape) #Green channel
  blue = np.zeros(mask.shape) #Blue channel

  #Background (dark blue)
  red[np.where(mask==0)] = 0.08
  green[np.where(mask==0)] = 0.08
  blue[np.where(mask==0)] = 0.17

  #Foreground (bright grey)
  red[np.where(mask==1)] = 0.88
  green[np.where(mask==1)] = 0.97
  blue[np.where(mask==1)] = 0.98

  rgb = np.stack((red,green,blue)) #3, h, w
  rgb = rgb.transpose(1,2,0) #h, w, 3

  return rgb

In [None]:
def plot_images(y_score, labels, foreground):
  """
    This function plots the images of interest (prediction masks and ground truth label)
  """

  dictionary = ['Others', 'Cysts', 'Negative network', 'Pigment network', 'Streaks', 'Globules']
  y_score = torch.nn.functional.softmax(y_score, dim=0)
  
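  _, y = torch.max(y_score, dim=0) #Predicted class index (0..5) per pixel
  y = np.ascontiguousarray(one_hot(y.numpy() + 1).transpose(2, 0, 1)) #one_hot expects labels 1..6 and returns [h, w, 6]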
  y = y*foreground
  y = postprocessing(y)  
  
  plt.figure(figsize=(12, 6)) #Prediction
  for i in range(6):   
    plt.subplot(1,6,i+1)
    plt.title(dictionary[i])
    plt.imshow(palette(y[i])) #Show using color palette
  plt.show()
  
  
  plt.figure(figsize=(12, 6)) #Ground truth
  for i in range(6):
    plt.subplot(1,6,i+1)
    plt.title(dictionary[i])
    plt.imshow(palette(labels[i])) #Show using color palette
  plt.show() 



### Preprocessing

The preprocessing of the images consists of three operational blocks: 

1. Elimination of background pixels. 

2. Resize + Random low resolution cropping (128x128)*

3. Data augmentation (random flips and color jitter). 

![pre](https://github.com/CesarCaramazana/DermoscopicSegmentation/blob/main/images/preproc_pipeline.png?raw=True)

*\*This is important because of the memory limitations of Google Colab. When running on a local GPU, the resolution should be increased, as long as the batch size is not compromised.*

In [None]:
#ImageNet mean/std 
mean=[0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225]

#ISIC DB mean/std (these values override the ImageNet ones above)
mean = [0.1769, 0.1479, 0.1367]
std = [0.0352, 0.0378, 0.0417]

#These operations are applied to the training set. 
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Resize([256,], interpolation=transforms.InterpolationMode.NEAREST),
     transforms.RandomCrop(size=(128)),
     #transforms.RandomRotation(90),
     transforms.RandomVerticalFlip(p=0.5), 
     transforms.RandomHorizontalFlip(p=0.5)
    ])

#These operations are applied to the validation and test set to evaluate the model.
transform_test = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Resize([256,], interpolation=transforms.InterpolationMode.NEAREST),
     transforms.RandomCrop(size=256)
    ])

#Data augmentation in the RGB channels
data_aug = transforms.Compose(
    [transforms.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.05),
     transforms.ToTensor(),
     transforms.Normalize(mean=mean, std=std),
     transforms.ToPILImage()
     
    ])

### ISIC Dataset class and Dataloaders

The dataset class is defined to fit the modified version of the ISIC DB, so it may not be directly applied to the vanilla version if downloaded from the ISIC website. The main difference is that the indices of the images/labels are stored in text files that are not included in the original database. 

ISIC DB is composed of 2750 high resolution dermoscopic images of pigmented skin lesions. Some examples of these are shown below:

![isic](https://github.com/CesarCaramazana/DermoscopicSegmentation/blob/main/images/isic.jpg?raw=True)

Regarding batch size: we use a batch size of 24 for consistency across Colab sessions. The memory limitation does not allow a larger batch size, although larger batches are recommended when working with Batch Normalization layers.

In [None]:
class db_isic_Dataset(torch.utils.data.Dataset):
  def __init__(self, transform=None, data_aug=None, set='train'):

    self.transform = transform
    self.data_aug = data_aug
    self.subset = set
    self.dataDir = ''
    self.imgRoot = ''
    self.gtRoot = ''
    
    #In the folder where the dataset was stored, I defined .txt files with the indices of the images.
    #These indices have the format: "ISIC_0014349", without the file extension.
    
    if(self.subset == 'train'):
      self.imgRoot = '/content/drive/My Drive/db_isic/ISIC-2017_Training_Data/'
      self.gtRoot = '/content/drive/My Drive/db_isic/ISIC-2017_Training_Part2_GroundTruth/gtann/'
      self.dataDir = '/content/drive/My Drive/db_isic/train.txt'
    
    if(self.subset == 'val'):
      self.imgRoot = '/content/drive/My Drive/db_isic/ISIC-2017_Validation_Data/'
      self.gtRoot = '/content/drive/My Drive/db_isic/ISIC-2017_Validation_Part2_GroundTruth/gtann/'
      self.dataDir = '/content/drive/My Drive/db_isic/val.txt'

    if(self.subset == 'test'):
      self.imgRoot = '/content/drive/My Drive/db_isic/ISIC-2017_Test_v2_Data/'
      self.gtRoot = '/content/drive/My Drive/db_isic/ISIC-2017_Test_v2_Part2_GroundTruth/gtann/'
      self.dataDir = '/content/drive/My Drive/db_isic/test.txt'


    #Read the list of image indices and close the file afterwards
    with open(self.dataDir) as f:
      self.paths = f.readlines()

   
  def __getitem__(self, index):
    imagePath = self.imgRoot + self.paths[index].strip() + '.jpg' #strip() removes the trailing newline
    gtPath =  self.gtRoot + self.paths[index].strip() + '.mat'

    image_t = Image.open(imagePath)
    mask = scipy.io.loadmat(gtPath)['smgt']

    foreground = lesion(mask) #Get binary foreground mask

    mask = get_segment_crop(mask, mask=foreground) #Apply foreground mask to input image and label
    image_t = get_segment_crop(np.asarray(image_t), mask=foreground) 
    foreground = get_segment_crop(foreground, mask=foreground)
    
    if self.data_aug: #Data augmentation
      image_t = self.data_aug(image_t)
    
    mask_t = one_hot(mask) #One-hot encoding   
        
    seed = np.random.randint(2147483647) # make a seed with numpy generator 
    
    if self.transform:
      random.seed(seed) #Apply the same random operations to mask and image
      torch.manual_seed(seed)
      image_t = self.transform(image_t)
      
      random.seed(seed)
      torch.manual_seed(seed)
      mask_t = self.transform(mask_t)

      random.seed(seed)
      torch.manual_seed(seed)
      foreground = self.transform(foreground)


    return image_t, mask_t, foreground

  def __len__(self):
    return len(self.paths)
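
The re-seeding used in `__getitem__` keeps the image and its masks aligned: resetting the random generator before each call makes every call draw the same random parameters. The sketch below shows the idea with a toy `random_crop` helper, which is a hypothetical stand-in for the torchvision transforms and only uses Python's `random` module.

```python
import random

import numpy as np


def random_crop(arr, size):
    """Toy random crop: the window position is drawn with the `random` module."""
    h, w = arr.shape[:2]
    top = random.randint(0, h - size)
    left = random.randint(0, w - size)
    return arr[top:top + size, left:left + size]


image = np.arange(8 * 8 * 3).reshape(8, 8, 3)  # toy [h, w, c] image
mask = image[..., 0].copy()                    # toy mask tied to channel 0

seed = 1234
random.seed(seed)  # same seed before each call ...
image_crop = random_crop(image, 4)
random.seed(seed)  # ... so the same window is drawn again
mask_crop = random_crop(mask, 4)

aligned = np.array_equal(image_crop[..., 0], mask_crop)
print(aligned)  # True
```

Without the second `random.seed(seed)` call the mask would, in general, be cropped at a different position than the image.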

In [None]:
#Dataloaders
trainset = db_isic_Dataset(transform=transform, data_aug=data_aug, set='train')
trainloader = torch.utils.data.DataLoader(trainset, batch_size=24, shuffle=True, num_workers=0) #Batch=32 or higher may generate memory usage problems due to the memory limit of Colab.

validationset = db_isic_Dataset(transform=transform_test, set='val')
validationloader = torch.utils.data.DataLoader(validationset, batch_size=2, shuffle=True)

testset = db_isic_Dataset(transform=transform_test, set='test')
testloader = torch.utils.data.DataLoader(testset, batch_size=2, shuffle=True)

data_loaders = {"train": trainloader, "val": validationloader}

### Architecture

In [None]:
class unet_inception(torch.nn.Module):
  def __init__(self):
    super().__init__()
    out_channels = 6

    #Downloads the pre-trained weights of an Inception v3 model
    self.inception = models.inception_v3(pretrained=True)
    
    #ENCODER
    self.l1 = nn.Sequential(*list(self.inception.children())[0:3])
    self.l2 = nn.Sequential(*list(self.inception.children())[3:6])
    self.l3 = nn.Sequential(*list(self.inception.children())[6:10])

    #ASPP layers
    self.a1 = nn.Sequential( #Dilation 2
        nn.Conv2d(in_channels= 192, out_channels=256, kernel_size=3, stride=1, dilation=2, padding=2),
        nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
        nn.ReLU(),
        nn.Conv2d(in_channels= 256, out_channels=256, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
        nn.ReLU(),
        nn.Conv2d(in_channels= 256, out_channels=256, kernel_size=3, stride=2, padding=0),
        nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
        nn.ReLU(),
    )
    self.a2 = nn.Sequential( #Dilation 3
        nn.Conv2d(in_channels= 192, out_channels=256, kernel_size=3, stride=1, dilation=3, padding=3),
        nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
        nn.ReLU(),
        nn.Conv2d(in_channels= 256, out_channels=256, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
        nn.ReLU(),
        nn.Conv2d(in_channels= 256, out_channels=256, kernel_size=3, stride=2, padding=0),
        nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
        nn.ReLU(),

    )
    self.a3 = nn.Sequential( #Dilation 4
        nn.Conv2d(in_channels= 192, out_channels=256, kernel_size=3, stride=1, dilation=4, padding=4),
        nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
        nn.ReLU(),
        nn.Conv2d(in_channels= 256, out_channels=256, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
        nn.ReLU(),
        nn.Conv2d(in_channels= 256, out_channels=256, kernel_size=3, stride=2, padding=0),
        nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
        nn.ReLU(),

    )

    #DECODER
    self.l5 = nn.Sequential(
        nn.ConvTranspose2d(in_channels=1056, out_channels=800, kernel_size=3, stride=2, padding=0, output_padding=1),
        nn.BatchNorm2d(800, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
        nn.ReLU(),
        nn.Conv2d(in_channels= 800, out_channels=512, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
        nn.ReLU(),
    )
    self.l6 = nn.Sequential(
        nn.ConvTranspose2d(in_channels=704, out_channels=704, kernel_size=7, stride=2, padding=0, output_padding=0),
        nn.BatchNorm2d(704, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
        nn.ReLU(),
        nn.Dropout2d(p=0.5),
        nn.Conv2d(in_channels= 704, out_channels=256, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
        nn.ReLU(),
        
    )
    self.l7 = nn.Sequential(
        nn.ConvTranspose2d(in_channels=320, out_channels=256, kernel_size=7, stride=2, padding=0, output_padding=1),
        nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
        nn.ReLU(),
        nn.Conv2d(in_channels= 256, out_channels=256, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
        nn.ReLU(),
        nn.Conv2d(in_channels=256, out_channels=out_channels, kernel_size=1, stride=1, padding=0),
        nn.BatchNorm2d(out_channels), 
    )


  def forward(self, input):

    printShapes = 0 #For debugging purposes
    
    #Encoder--
    l1 = self.l1(input)
    if(printShapes): print("L1 shape: ", l1.shape)
    l2 = self.l2(l1)
    if(printShapes): print("L2 shape: ", l2.shape)
    l3 = self.l3(l2)
    if(printShapes): print("L3 shape: ", l3.shape)

    #Apply Atrous Spatial Pyramid Pooling
    a1 = self.a1(l2)
    a2 = self.a2(l2)
    a3 = self.a3(l2)

    if(printShapes): print("A1 shape: ", a1.shape)
    if(printShapes): print("A2 shape: ", a2.shape)
    if(printShapes): print("A3 shape: ", a3.shape)


    a = torch.cat((a1, a2, a3), dim=1) #Generate a single volume.
    if(printShapes): print("ASPP: ", a.shape)
    
    x = torch.cat((a, l3), dim=1) #Add layer 3
    if(printShapes): print("AL3: ", x.shape)


    #Decoder--
    x = self.l5(x)
    if(printShapes): print("L5 shape: ", x.shape)
    x = torch.cat((x, l2), dim=1) #Skip connection 1
    x = self.l6(x)
    if(printShapes): print("L6 shape: ", x.shape)
    x = torch.cat((x, l1), dim=1) #Skip connection 2
    x = self.l7(x)
    if(printShapes): print("L7 output: ", x.shape)

    return x

In [None]:
#Debug. Testing the shapes of the feature maps given an input resolution

model = unet_inception()

hw = 256
input  = torch.rand((2,3,hw,hw)) #Generate random 3-channel tensor (input image)

out = model(input)
print(out.shape) #(2, 6, hw, hw)
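
The size bookkeeping behind this check can be reproduced with the standard output-size formulas for `Conv2d` and `ConvTranspose2d`. The sketch below is illustrative only: the helper names are not part of the notebook, and the starting size of 60 for the `l2` feature map is an assumption for a 256x256 input.

```python
def conv_out(size, kernel, stride=1, padding=0, dilation=1):
    """Output size of a Conv2d layer along one spatial axis."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def conv_transpose_out(size, kernel, stride=1, padding=0, output_padding=0, dilation=1):
    """Output size of a ConvTranspose2d layer along one spatial axis."""
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1


l2 = 60  # assumed size of the l2 feature map for a 256x256 input

# One ASPP branch: dilated conv (size kept), 3x3 conv (size kept), strided conv
aspp = conv_out(conv_out(conv_out(l2, 3, padding=2, dilation=2), 3, padding=1), 3, stride=2)

# Decoder: only the transposed convolutions change the spatial size
l5 = conv_transpose_out(aspp, 3, stride=2, output_padding=1)
l6 = conv_transpose_out(l5, 7, stride=2)
l7 = conv_transpose_out(l6, 7, stride=2, output_padding=1)

print(aspp, l5, l6, l7)  # 29 60 125 256
```

Under that assumption, the output of `l5` matches the size of `l2` again, which is what the first skip connection requires, and `l7` recovers the input resolution.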


### Loss function: (Focal) Cross Entropy with class weights

The loss function was designed with the class imbalance problem in mind. We use cross entropy with class weights and apply the Focal Loss modulation (https://ieeexplore.ieee.org/document/8237586), although either one or the other should be enough to downweight easy examples.


The coefficients are calculated as the inverse number of samples and normalized. 
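
A minimal sketch of this computation is shown below, assuming that "normalized" means the weights are scaled to sum to 1. The pixel counts are hypothetical and are not the statistics of the ISIC dataset.

```python
def inverse_frequency_weights(counts):
    """Weights proportional to 1/count, scaled so that they sum to 1."""
    inverse = [1.0 / c for c in counts]
    total = sum(inverse)
    return [v / total for v in inverse]


# Hypothetical pixel counts: others, cysts, neg. network, pig. network, streaks, globules
pixel_counts = [1_000_000, 20_000, 18_000, 230_000, 13_000, 35_000]
weights = inverse_frequency_weights(pixel_counts)

for count, weight in zip(pixel_counts, weights):
    print(count, round(weight, 4))
```

The most frequent class receives the smallest weight and the rarest class the largest, so errors on rare structures contribute more to the loss.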

In [None]:
class Focal_CE_weights(nn.Module):
  """
    This loss function is a variation of the cross entropy loss function, which adds class weights (calculated beforehand) and
    the Focal Loss modification. 
  """

  def __init__(self, **kwargs):
    super(Focal_CE_weights, self).__init__()
    self.kwargs = kwargs       

  def forward(self, inputs, targets, smooth=1, gamma=2):
    """
      inputs: raw, unnormalized scores (logits)*. *nn.CrossEntropyLoss already incorporates LogSoftmax()
      targets: one-hot encoded Ground Truth mask.
          
      gamma: focal loss downweighting parameter.
    """      
       
    #FOCAL CROSS ENTROPY--- 
    _, targets = torch.max(targets, dim=1) #[b, 6, h, w] -> [b, h, w] Undo one-hot encoding
   
    class_weights = torch.tensor([4.7e-3, 0.225, 0.264, 0.02, 0.348, 0.136]) #Calculated as the inverse number of samples and normalized
    class_weights = class_weights.to(device) #Weights to GPU

    BCE_loss = nn.CrossEntropyLoss(reduction='none', weight=class_weights)(inputs, targets)
        
    #Focal loss: easily classified pixels have a lesser contribution
    pt = torch.exp(-BCE_loss)
    F_loss = (1-pt)**gamma * BCE_loss
      
    return torch.mean(F_loss) 
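
The effect of the focal term can be checked on single pixels. The sketch below is a simplified, unweighted version of the loss (no class weights), so that `pt` is exactly the probability assigned to the true class; `focal_term` is an illustrative helper, not part of the notebook.

```python
import math


def focal_term(p_true, gamma=2.0):
    """Focal cross entropy for one pixel, given the true-class probability."""
    ce = -math.log(p_true)   # plain cross entropy
    pt = math.exp(-ce)       # recovers p_true
    return (1.0 - pt) ** gamma * ce


easy = focal_term(0.95)  # confidently correct pixel
hard = focal_term(0.30)  # poorly classified pixel

print(easy, hard)
```

With `gamma=2` the confident pixel is scaled by `(1 - 0.95)**2 = 0.0025`, while the hard pixel keeps about half of its cross entropy, so the gradient is dominated by the hard examples. Setting `gamma=0` recovers the plain cross entropy.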

### Postprocessing

The output masks are postprocessed with morphological operations (erosion followed by dilation, i.e. a morphological opening) in order to reduce the number of false positives generated by the weighting of the loss function. Square structuring elements are used.

In [None]:
def postprocessing(y):
  """
    This function applies morphological operations (erosion + dilation) to a stack of binary masks. 
    The structuring elements are square. The size for each class was obtained by trial and error, maximizing the
    Jaccard score of the validation set during training. 
      input: output predicted masks [6, h, w]
      output: processed masks [6, h, w]
  """
  
  kernel_size = [1, 21, 21, 15, 5, 15] #[others, cysts, n.net, p.net, str, glob]
  for c in range(6):
    kernel = np.ones((kernel_size[c],kernel_size[c]), np.uint8)
    y[c] = cv.erode(y[c], kernel, iterations=1) #Erosion
    y[c] = cv.dilate(y[c], kernel, iterations=1) #Dilation

  return y
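
To show what erosion followed by dilation does to a mask, here is a NumPy-only sketch of a morphological opening with a square structuring element. It is a stand-in for the `cv.erode` / `cv.dilate` calls used above, written for odd kernel sizes, and the toy mask is made up for the example.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def _windows(mask, k, pad_value):
    """All k x k neighbourhoods of `mask`, padded so the output keeps its size."""
    pad = k // 2
    padded = np.pad(mask, pad, mode="constant", constant_values=pad_value)
    return sliding_window_view(padded, (k, k))


def erode(mask, k):
    # A pixel survives only if its whole neighbourhood is foreground
    return _windows(mask, k, 1).min(axis=(2, 3))


def dilate(mask, k):
    # A pixel is set if any pixel in its neighbourhood is foreground
    return _windows(mask, k, 0).max(axis=(2, 3))


def opening(mask, k):
    return dilate(erode(mask, k), k)


mask = np.zeros((9, 9), dtype=np.uint8)
mask[1, 1] = 1        # isolated false positive
mask[4:8, 4:8] = 1    # 4x4 detected structure

opened = opening(mask, 3)
print(int(mask.sum()), int(opened.sum()))  # 17 16
```

The isolated pixel is removed while the 4x4 block is restored to its original extent, which is the behavior exploited to suppress small false positives.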

### Evaluation metrics: Area under the curve and Jaccard Score

In order to evaluate the quality of an image segmentation model the following metrics are used: area under the (ROC) curve, Jaccard score (IoU) and Confusion Matrix.

In [None]:
def get_auc_values(y_true, y_score, auc_list, auc_list_epoch):
  """
  Updates the values of the lists "auc_list" and "auc_list_epoch", which store the AUC values per class.
    inputs:
      y_true (tensor) [batch, classes, h, w]: ground truth one hot label. 
      y_score (tensor) [batch, classes, h, w]: output softmax predictions.     
  """
  y_true = y_true.detach()
  y_score = y_score.detach()

  batch, classes, h, w = y_true.shape

  for batch in range(batch):
    for c in range(classes):
      y_true_v = y_true[batch, c].view(-1, h*w).squeeze() #From [batch, c, h, w] to [h*w] (required for roc_curve())
      y_score_v = y_score[batch, c].view(-1, h*w).squeeze()

      if len(np.unique(y_true_v)) == 2: #Check if the class is present in the image.
        auc_c = roc_auc_score(y_true_v, y_score_v)
        auc_list[c].append(auc_c)
        auc_list_epoch[c].append(auc_c)
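
As a reminder of what the AUC measures, the sketch below computes it directly from its pairwise definition: the fraction of (positive, negative) pixel pairs in which the positive pixel receives the higher score. This is a didactic pure-Python version, not the `roc_auc_score` implementation, and the labels and scores are made up.

```python
def pairwise_auc(y_true, y_score):
    """AUC as the probability that a positive outranks a negative (ties count 0.5)."""
    pos = [s for t, s in zip(y_true, y_score) if t == 1]
    neg = [s for t, s in zip(y_true, y_score) if t == 0]
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


y_true = [0, 0, 1, 1]
y_score = [0.1, 0.4, 0.35, 0.8]

auc_value = pairwise_auc(y_true, y_score)
print(auc_value)  # 0.75
```

The definition needs at least one positive and one negative sample, which is why the function above skips classes that are absent from an image.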

In [None]:
def get_iou(y_pred, y_true, foreground, iou_list):
  """
  Updates the values of the list "iou_list", which stores the jaccard scores per class.
    inputs:
      y_pred (tensor) [b, 6, h, w],  output probabilities.
      y_true (tensor) [b, 6, h, w], Ground Truth labels.
      foreground [b, h, w], binary foreground masks.
      iou_list [num_classes*[]], list that contains the lists with the IoU values
  """
  batch, classes, h, w = y_true.shape
  smooth = 1

  _, y_pred = torch.max(y_pred, dim=1) #Classify pixel in class with highest score

  y_pred = y_pred.numpy()
  y_true = y_true.numpy()
  foreground = foreground.numpy()


  for batch in range(batch):
    y = np.ascontiguousarray(one_hot(y_pred[batch] + 1).transpose(2, 0, 1)) #Predicted indices are 0..5 but one_hot expects 1..6 and returns [h, w, 6]
    y = y * foreground[batch] #Eliminate background (healthy skin) from calculation
    y = postprocessing(y) #Postprocess masks
    
    for c in range(classes):     
      intersection = (y[c] * y_true[batch, c]).sum() 
   
      a = (y[c] * y[c]).sum()  #Area of predicted mask
      b = (y_true[batch, c] * y_true[batch, c]).sum() #Area of ground truth mask   
      iou = intersection / (a+b-intersection + smooth) #IoU

      if((a+b) != 0): #Only consider if A or B are not empty
        iou_list[c].append(iou)
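
The smoothed IoU used above can be checked by hand on a toy pair of masks. In the sketch below, `smoothed_iou` is an illustrative helper that mirrors the formula of `get_iou` for a single class; the masks are made up.

```python
import numpy as np


def smoothed_iou(pred, target, smooth=1.0):
    """Intersection over union with a smoothing term in the denominator."""
    intersection = (pred * target).sum()
    union = pred.sum() + target.sum() - intersection
    return intersection / (union + smooth)


pred = np.zeros((4, 4))
pred[0:2, 0:2] = 1      # 4 predicted pixels
target = np.zeros((4, 4))
target[0:2, 1:3] = 1    # 4 ground truth pixels, 2 of them shared

iou = smoothed_iou(pred, target)  # 2 / (6 + 1)
print(iou)
```

Because the smoothing term is only added to the denominator, even a perfect prediction scores slightly below 1 (here 4 / 5 for a 4-pixel mask), and the effect shrinks as the masks grow.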


### Model and parameters

We initialize the weights of the ASPP and decoder layers with Xavier initialization (https://arxiv.org/abs/1704.08863). 
To compensate for the small batch size, we modify the momentum parameter of the Batch Normalization layers. 

We then set the learning rate for each layer, the optimizer (Adam), the LR scheduler and the regularization (L2). 

In [None]:
def weights_init(m):
    """
      Defines how weights are initialized in Conv2d and ConvTranspose2d layers. 
      In this case, using Xavier (Glorot) normal initialization.
    """
    #Xavier weight initialization
    if isinstance(m, torch.nn.Conv2d):
      torch.nn.init.xavier_normal_(m.weight,1.0)
      if m.bias is not None:
          nn.init.constant_(m.bias.data, 0)
    if isinstance(m, torch.nn.ConvTranspose2d):
      torch.nn.init.xavier_normal_(m.weight, 1.0)
      if m.bias is not None:
          nn.init.constant_(m.bias.data, 0)

In [None]:
def bn_(m):
    """
      Reduce "momentum" parameter in BatchNorm2d layers which supposedly works better for small batch size.
    """
    classname = m.__class__.__name__
    if classname.find('BatchNorm2d') != -1:
        m.momentum = 0.008
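
The momentum of a BatchNorm layer controls how quickly its running statistics follow the batch statistics: each update is `running = (1 - momentum) * running + momentum * batch_stat`. The sketch below compares the PyTorch default (0.1) with the value used here (0.008); the helper names and the constant batch statistic are illustrative.

```python
def update_running_stat(running, batch_stat, momentum):
    """One exponential-moving-average update of a running statistic."""
    return (1.0 - momentum) * running + momentum * batch_stat


def after_n_batches(momentum, n, batch_stat=1.0, start=0.0):
    """Running statistic after n updates with a constant batch statistic."""
    running = start
    for _ in range(n):
        running = update_running_stat(running, batch_stat, momentum)
    return running


default = after_n_batches(0.1, 100)
slow = after_n_batches(0.008, 100)
print(default, slow)
```

With a smaller momentum each noisy small-batch estimate moves the running statistics less, at the cost of needing many more iterations to converge to the dataset statistics.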

In [None]:
def bn_eval(m):
    """
      Due to an error found during the development of the project, this function was defined 
      to change how BatchNorm2d layers behave during evaluation: it keeps them in training mode
      so that they normalize with the statistics of the current batch.
    """
    classname = m.__class__.__name__
    if classname.find('BatchNorm2d') != -1:
        m.track_running_stats = False
        m.train() #m.eval() doesn't work properly. The reason is still unknown.

In [None]:
#ARCHITECTURE
model = unet_inception()

#We define the parameters of each layer separately to control the Learning Rate and weight initialization.
params_e3 = [p for p in model.l3.parameters() if p.requires_grad] #L3 Inception
params_a1 = [p for p in model.a1.parameters() if p.requires_grad] #ASPP
params_a2 = [p for p in model.a2.parameters() if p.requires_grad]
params_a3 = [p for p in model.a3.parameters() if p.requires_grad]
params_d5 = [p for p in model.l5.parameters() if p.requires_grad] #Decoder
params_d6 = [p for p in model.l6.parameters() if p.requires_grad]
params_d7 = [p for p in model.l7.parameters() if p.requires_grad]

#Xavier init. DO NOT APPLY TO PRE-TRAINED LAYERS
model.a1.apply(weights_init)
model.a2.apply(weights_init)
model.a3.apply(weights_init)
model.l5.apply(weights_init)
model.l6.apply(weights_init)
model.l7.apply(weights_init)

#BatchNorm2d momentum
model.apply(bn_)

model.to(device) #To GPU or CPU

#Optimizer: Adam
optimizer = optim.Adam([{'params': params_e3, 'lr': 1e-5}, #Fine-tuning, layer 3
                        {'params': params_a1, 'lr': 1e-3},
                        {'params': params_a2, 'lr': 1e-3},
                        {'params': params_a3, 'lr': 1e-3},
                        {'params': params_d5, 'lr': 1e-3},
                        {'params': params_d6, 'lr': 1e-3},                         
                        {'params': params_d7, 'lr': 1e-3}
                        ], lr=1e-3, weight_decay=1e-6) #L2 regularization
#Loss function
criterion = Focal_CE_weights()

#LR scheduler
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.95)

#Number of epochs
num_epochs = 50
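
The `StepLR` schedule multiplies every learning rate by `gamma` once every `step_size` calls to `scheduler.step()`. The sketch below reproduces that rule in plain Python for the values used above; `step_lr` is an illustrative helper, and `n_steps` counts how many times the scheduler has been stepped.

```python
def step_lr(base_lr, n_steps, step_size=5, gamma=0.95):
    """Learning rate after n_steps calls to scheduler.step() under StepLR."""
    return base_lr * gamma ** (n_steps // step_size)


for n_steps in (0, 4, 5, 50):
    # Decoder/ASPP groups start at 1e-3, the fine-tuned Inception block at 1e-5
    print(n_steps, step_lr(1e-3, n_steps), step_lr(1e-5, n_steps))
```

Both parameter groups decay by the same factor, so the 100x ratio between the decoder and the fine-tuned encoder block is preserved throughout training.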

## Training Loop

The training loop follows a standard structure. To monitor the progress of the training, we calculate the AUC values for each class on the validation set. They are stored in two sets of lists, auc_i (with all the values) and auc_i_epoch (with the values of the current epoch), to facilitate the evaluation and for academic purposes.

In [None]:
freq = 15 #Plot images every freq iterations

checkpoint_path = "/content/drive/My Drive/Checkpoints/model.tar" #Load and save model each epoch in a .tar file.

saveCheckpoint = True #Saves the model / optimizer each epoch
loadCheckpoint = False #Loads the model / optimizer before the loop
#---------------------------------------------------------------------------------
#We store the loss values in lists. Ideally we would use Tensorboard for a "cleaner" visualization of the plots. These are mostly for debugging purposes.
train_loss = []
validation_loss = []
train_loss_epoch = []
val_loss_epoch = []

auc_others, auc_cysts, auc_negNet, auc_pigNet, auc_streaks, auc_globules = [], [], [], [], [], []
auc_others_epoch, auc_cysts_epoch, auc_negNet_epoch, auc_pigNet_epoch, auc_streaks_epoch, auc_globules_epoch = [], [], [], [], [], []

auc_list = [auc_others, auc_cysts, auc_negNet, auc_pigNet, auc_streaks, auc_globules]
auc_list_epoch = [auc_others_epoch, auc_cysts_epoch, auc_negNet_epoch, auc_pigNet_epoch, auc_streaks_epoch, auc_globules_epoch]


if(loadCheckpoint):
  print("Loading checkpoint... | Path: ", checkpoint_path)
  checkpoint = torch.load(checkpoint_path, map_location=device)
  model.load_state_dict(checkpoint['model_state_dict'])
  optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 


for epoch in range(num_epochs):
    optimizer.zero_grad()
    if(epoch > 0):       
      print("__Summary of epoch ", epoch-1)
      print("|--------------------------------|")
      print("| Avg Train Loss       | %.4f  |" % stats.mean(train_loss_epoch))
      print("| Avg Val   Loss       | %.4f  |" % stats.mean(val_loss_epoch))
      print("|----------------------|---------|")
      print("| Avg AUC Other  Val   | %.4f  |" % stats.mean(auc_others_epoch))
      print("| Avg AUC M-Cysts  Val | %.4f  |" % stats.mean(auc_cysts_epoch))
      print("| Avg AUC Neg net  Val | %.4f  |" % stats.mean(auc_negNet_epoch))
      print("| Avg AUC Pig net  Val | %.4f  |" % stats.mean(auc_pigNet_epoch))
      print("| Avg AUC Streaks  Val | %.4f  |" % stats.mean(auc_streaks_epoch))
      print("| Avg AUC Globules Val | %.4f  |" % stats.mean(auc_globules_epoch))
      print("|----------------------|---------|")

      #Tensorboard--
      writer.add_scalar("Loss/train", stats.mean(train_loss_epoch), epoch)
      writer.add_scalar("Loss/validation", stats.mean(val_loss_epoch), epoch)
      #AUC values are only computed during the validation phase
      writer.add_scalar("AUC/val/others", stats.mean(auc_others_epoch), epoch)
      writer.add_scalar("AUC/val/cysts", stats.mean(auc_cysts_epoch), epoch)
      writer.add_scalar("AUC/val/negNet", stats.mean(auc_negNet_epoch), epoch)
      writer.add_scalar("AUC/val/pigNet", stats.mean(auc_pigNet_epoch), epoch)
      writer.add_scalar("AUC/val/streaks", stats.mean(auc_streaks_epoch), epoch)
      writer.add_scalar("AUC/val/globules", stats.mean(auc_globules_epoch), epoch)
      #--

      train_loss_epoch.clear()
      val_loss_epoch.clear()
      auc_others_epoch.clear()
      auc_cysts_epoch.clear()
      auc_negNet_epoch.clear()
      auc_pigNet_epoch.clear()
      auc_streaks_epoch.clear()
      auc_globules_epoch.clear()          
   
      if(saveCheckpoint): #Save model if epoch > 0
          print("Saving checkpoint... | epoch: ", epoch, "| Path: ", checkpoint_path)
          torch.save({'model_state_dict': model.state_dict(),  
                      'optimizer_state_dict': optimizer.state_dict(),
                      'scheduler_state': scheduler.state_dict()},
                      checkpoint_path)

    
   #Training loop--------------------------------------------------------------- 
    
    # Each epoch has a training and a validation phase
    for phase in ['train', 'val']:

        if phase == 'train':
            model.train()  # Set model to training mode
            print("Epoch: ", epoch, "| Training phase |")
        else:
            model.eval()  # Set model to evaluate mode
            model.apply(bn_eval)
            print("Epoch: ", epoch, "| Validation phase |")

        for i, data in enumerate(data_loaders[phase],0):
          inputs, labels, foreground = data  
          labels = labels.long()
          foreground = foreground.long()
          foreground = foreground.squeeze(1)  
                  
          inputs, labels, foreground = inputs.to(device), labels.to(device), foreground.to(device) 

          #Predict output probabilities
          y_score = model(inputs)

          #Calculate Loss
          loss = criterion(y_score, labels)
          
          if phase == 'train':            
            optimizer.zero_grad()
            loss.backward() #Backpropagation
            optimizer.step()  
            scheduler.step() #Sch step

            train_loss.append(loss.item()) #Save loss value every iteration*. *For debugging purposes.
            train_loss_epoch.append(loss.item())

          if phase == 'val':
            validation_loss.append(loss.item())
            val_loss_epoch.append(loss.item())
            get_auc_values(labels.cpu(), y_score.cpu(), auc_list, auc_list_epoch)

          
          if(i%freq==0):
              #Show input image
              plt.figure(figsize=(9,9))
              plt.title('Input')
              plt.imshow(inputs.cpu().data[0].squeeze().permute(1,2,0))                  
              plt.show()

              #Show predicted and Ground truth masks
              plot_images(y_score[0].detach().cpu(), labels[0].cpu(), foreground[0].cpu().numpy())

              #Plot loss curves. In Tensorboard we save only 1 value per epoch; these variables contain the loss values for each iteration.
              plt.figure(figsize=(12,7))
              plt.subplot(121)
              plt.title("Training Loss")
              plt.plot(train_loss, 'r')
              plt.subplot(122)
              plt.title("Validation Loss")
              plt.plot(validation_loss, 'b')
              plt.show()
 
    

## Evaluation

This block evaluates, on the test set, the model saved in a .tar file (checkpoint_path).

In [None]:
checkpoint_path = "/content/drive/My Drive/Checkpoints/model.tar"

print("Loading checkpoint... | Path: ", checkpoint_path)
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

model.eval()
model.apply(bn_eval) #Set BatchNorm2d to evaluation mode

#Area under the curve
auc_others, auc_cysts, auc_negNet, auc_pigNet, auc_streaks, auc_globules = [], [], [], [], [], []
auc_others_epoch, auc_cysts_epoch, auc_negNet_epoch, auc_pigNet_epoch, auc_streaks_epoch, auc_globules_epoch = [], [], [], [], [], []
auc_list = [auc_others, auc_cysts, auc_negNet, auc_pigNet, auc_streaks, auc_globules]
auc_list_epoch = [auc_others_epoch, auc_cysts_epoch, auc_negNet_epoch, auc_pigNet_epoch, auc_streaks_epoch, auc_globules_epoch]

#Jaccard Score
iou_others, iou_cysts, iou_negNet, iou_pigNet, iou_streaks, iou_globules = [], [], [], [], [], []
iou_list = [iou_others, iou_cysts, iou_negNet, iou_pigNet, iou_streaks, iou_globules]


cm_epoch = np.zeros((6,6)) #Confusion Matrix


for i, data in enumerate(testloader,0):
  inputs, labels, foreground = data  
  labels = labels.long()
  foreground = foreground.long()
  foreground = foreground.squeeze(1)    

  inputs, labels, foreground = inputs.to(device), labels.to(device), foreground.to(device) #GPU / CPU  


  with torch.no_grad(): #Gradients are not needed for evaluation
    y_score = model(inputs)
  y_score = torch.nn.functional.softmax(y_score, dim=1) #Raw output -> Softmax() over the class dimension
  get_iou(y_score.cpu(), labels.cpu(), foreground.cpu(), iou_list)
  
  x, pred = torch.max(y_score, dim=1) #Classify in highest score class
  x2, y = torch.max(labels, dim=1) #Undo one-hot

  
  prediction = pred.view(-1)  #prediction
  yf = y.view(-1) #label
  
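  #Calculate the confusion matrix for the batch. Passing all 6 labels keeps the shape at (6,6) even if a class is absent from the batch
  cm = confusion_matrix(y_true=yf.cpu(), y_pred=prediction.cpu(), labels=list(range(6)))
  cm_epoch = cm_epoch + cm  #Accumulate over the whole test set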

  #Show input image
  plt.figure(figsize=(9,9))              
  plt.title('Input')
  plt.imshow(inputs.cpu().data[0].squeeze().permute(1,2,0))     
  plt.show()
    
  #Show prediction and Ground truth masks  
  plot_images(y_score[0].detach().cpu(), labels[0].cpu(), foreground[0].cpu().numpy())

  #Calculate AUC per class
  get_auc_values(labels.cpu(), y_score.cpu(), auc_list, auc_list_epoch)
 
print("|------------------|---------|")
print("| Avg AUC Other    | %.4f  |" % stats.mean(auc_others_epoch))
print("| Avg AUC Cysts    | %.4f  |" % stats.mean(auc_cysts_epoch))
print("| Avg AUC Neg net  | %.4f  |" % stats.mean(auc_negNet_epoch))
print("| Avg AUC Pig net  | %.4f  |" % stats.mean(auc_pigNet_epoch))
print("| Avg AUC Streaks  | %.4f  |" % stats.mean(auc_streaks_epoch))
print("| Avg AUC Globules | %.4f  |" % stats.mean(auc_globules_epoch))
print("|__________________|_________|")
print("| Avg IOU Other    | %.4f  |" % stats.mean(iou_others))
print("| Avg IOU Cysts    | %.4f  |" % stats.mean(iou_cysts))
print("| Avg IOU Neg net  | %.4f  |" % stats.mean(iou_negNet))
print("| Avg IOU Pig net  | %.4f  |" % stats.mean(iou_pigNet))
print("| Avg IOU Streaks  | %.4f  |" % stats.mean(iou_streaks))
print("| Avg IOU Globules | %.4f  |" % stats.mean(iou_globules))
print("|__________________|_________|")
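The IOU rows printed above are Jaccard scores, i.e. the intersection over union between a predicted mask and its ground-truth mask. As a rough illustration of that quantity only, the following sketch computes it for one pair of binary masks in plain Python. The function `jaccard_index` and the toy masks are hypothetical: this is not the notebook's `get_iou` helper, it ignores the foreground mask, and the value returned for two empty masks is an assumed convention.

```python
def jaccard_index(pred_mask, true_mask):
    """Intersection over union of two binary masks (nested lists of 0/1)."""
    intersection = 0
    union = 0
    for pred_row, true_row in zip(pred_mask, true_mask):
        for p, t in zip(pred_row, true_row):
            intersection += int(p == 1 and t == 1)  # pixel set in both masks
            union += int(p == 1 or t == 1)          # pixel set in at least one mask
    # Assumed convention: two empty masks are treated as a perfect match.
    return 1.0 if union == 0 else intersection / union


# Toy 2x3 masks for a single class (hypothetical data).
pred = [[1, 1, 0],
        [0, 1, 0]]
true = [[1, 0, 0],
        [0, 1, 1]]

iou = jaccard_index(pred, true)
print("IoU: %.4f" % iou)
```

For a multi-class prediction, the same computation would be repeated once per class on the corresponding binary masks, which is why the table reports one value per structure.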




In [None]:
#Save values in Tensorboard
writer.add_scalar("AUC/test/others", stats.mean(auc_others_epoch), 1)
writer.add_scalar("AUC/test/cysts", stats.mean(auc_cysts_epoch), 1)
writer.add_scalar("AUC/test/negNet", stats.mean(auc_negNet_epoch), 1)
writer.add_scalar("AUC/test/pigNet", stats.mean(auc_pigNet_epoch), 1)
writer.add_scalar("AUC/test/streaks", stats.mean(auc_streaks_epoch), 1)
writer.add_scalar("AUC/test/globules", stats.mean(auc_globules_epoch), 1)

writer.add_scalar("IOU/test/others", stats.mean(iou_others), 1)
writer.add_scalar("IOU/test/cysts", stats.mean(iou_cysts), 1)
writer.add_scalar("IOU/test/negNet", stats.mean(iou_negNet), 1)
writer.add_scalar("IOU/test/pigNet", stats.mean(iou_pigNet), 1)
writer.add_scalar("IOU/test/streaks", stats.mean(iou_streaks), 1)
writer.add_scalar("IOU/test/globules", stats.mean(iou_globules), 1)

writer.flush()

In [None]:
#Print the confusion matrix. Rows represent the label and columns represent the prediction.
print(cm_epoch)
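With rows as labels and columns as predictions, per-class scores can be read off the accumulated matrix: the diagonal holds the correctly classified pixels, the rest of a row holds the missed pixels of that class, and the rest of a column holds the pixels wrongly assigned to it. The sketch below shows one possible way to derive recall, precision and IoU from such a matrix. The function `per_class_metrics` and the 3x3 toy matrix are hypothetical and use plain lists rather than the notebook's `cm_epoch` array.

```python
def per_class_metrics(cm):
    """Recall, precision and IoU per class from a square confusion matrix.

    Rows are the true labels and columns are the predictions.
    """
    n = len(cm)
    metrics = []
    for k in range(n):
        tp = cm[k][k]                              # diagonal entry
        fn = sum(cm[k]) - tp                       # rest of row k: missed pixels
        fp = sum(cm[r][k] for r in range(n)) - tp  # rest of column k: false alarms
        recall = tp / (tp + fn) if tp + fn else 0.0
        precision = tp / (tp + fp) if tp + fp else 0.0
        iou = tp / (tp + fp + fn) if tp + fp + fn else 0.0
        metrics.append((recall, precision, iou))
    return metrics


# Toy 3-class matrix (hypothetical counts, not results from the model).
toy_cm = [[8, 2, 0],
          [1, 6, 3],
          [0, 1, 9]]

results = per_class_metrics(toy_cm)
for k, (recall, precision, iou) in enumerate(results):
    print("class %d | recall %.3f | precision %.3f | IoU %.3f" % (k, recall, precision, iou))
```

An IoU obtained this way pools all test pixels before dividing, so it will generally differ from an average of per-image IoU values.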

## Launch TensorBoard

Although the code wasn't developed using TensorBoard, it can be launched by executing the following cell. There might be some errors, since this logging implementation hasn't been properly debugged.

In [None]:
%tensorboard --logdir=runs

In [None]:
writer.close()