In [1]:
# Colab only: mount Google Drive at /content/drive. The dataset and checkpoint
# paths used further down all live under "drive/My Drive/".
from google.colab import drive
drive.mount('/content/drive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=email%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdocs.test%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive.photos.readonly%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fpeopleapi.readonly&response_type=code

Enter your authorization code:
··········
Mounted at /content/drive


Imports

In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import argparse
import cv2
import glob
import re
import cv2
from sys import stdout
from itertools import islice
from tqdm import tqdm
#from model import UNet
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from torchvision.transforms import ToPILImage, ToTensor, Normalize, Compose
from torch.utils.data import Dataset
from scipy.ndimage import zoom
import matplotlib.pyplot as plt
import random
import copy
import math
import sys
from google.colab import files

Model

In [0]:

class UNet(nn.Module):
  """U-Net style encoder/decoder producing a 2-class per-pixel probability map.

  The encoder is one plain double-conv block followed by four down-sampling
  blocks (64 -> 128 -> 256 -> 512 -> 1024 channels). The decoder mirrors it
  with four up-sampling blocks that each concatenate the matching encoder
  output, and a final 1x1 convolution maps 64 channels to 2 classes.

  Args:
    transfer_learning: accepted for call-site compatibility; this class does
      not read it.
  """

  def __init__(self, transfer_learning=False):
    super().__init__()

    # Encoder.
    self.initial_conv = down_block(3, 64, False)
    self.down1 = down_block(64, 128, True)
    self.down2 = down_block(128, 256, True)
    self.down3 = down_block(256, 512, True)
    self.down4 = down_block(512, 1024, True)

    # Decoder: up_block(c) takes 2*c channels in and returns c channels.
    self.up1 = up_block(512)
    self.up2 = up_block(256)
    self.up3 = up_block(128)
    self.up4 = up_block(64)
    self.out = nn.Conv2d(64, 2, 1)

  def forward(self, x):
    """Return softmax class probabilities of shape (batch, 2, H, W).

    Args:
      x: input batch with 3 channels. The spatial size is halved four times
        on the way down, so sizes divisible by 16 keep the skip connections
        aligned.
    """
    # Skip activations are kept in locals rather than on `self`, so they (and
    # the autograd graph hanging off them) are released once the call returns.
    skip1 = self.initial_conv(x)
    skip2 = self.down1(skip1)
    skip3 = self.down2(skip2)
    skip4 = self.down3(skip3)
    bottleneck = self.down4(skip4)

    x = self.up1(bottleneck, skip4)
    x = self.up2(x, skip3)
    x = self.up3(x, skip2)
    x = self.up4(x, skip1)

    return F.softmax(self.out(x), dim=1)

class down_block(nn.Module):
  """Encoder stage: optional 2x2 max-pool followed by two conv-BN-ReLU layers.

  Args:
    in_channels: channels of the incoming feature map.
    out_channels: channels produced by both 3x3 convolutions.
    downsample: when truthy, the input is max-pooled (kernel 2, stride 2)
      before the convolutions, halving its spatial size.
  """

  def __init__(self, in_channels, out_channels, downsample):
    super().__init__()
    # padding=1 with a 3x3 kernel keeps the spatial size unchanged.
    self.double_conv = nn.Sequential(
      nn.Conv2d(in_channels, out_channels, 3, padding=1),
      nn.BatchNorm2d(out_channels),
      nn.ReLU(inplace=True),
      nn.Conv2d(out_channels, out_channels, 3, padding=1),
      nn.BatchNorm2d(out_channels),
      nn.ReLU(inplace=True)
    )
    self.downsample = downsample
    self.max_pool = torch.nn.MaxPool2d(kernel_size=2, stride=2)

  def enable(self):
    """Mark every parameter of this block as trainable."""
    for param in self.parameters():
      param.requires_grad = True

  def forward(self, x):
    """Apply the optional pooling, then the double convolution."""
    if self.downsample:
      x = self.max_pool(x)
    return self.double_conv(x)

class up_block(nn.Module):
  """Decoder stage: 2x transposed-conv up-sampling, skip concat, double conv.

  Args:
    out_channels: channels returned by the block. The block expects
      ``2 * out_channels`` channels on its main input and ``out_channels``
      channels on the skip input.
  """

  def __init__(self, out_channels):
    super().__init__()

    # Doubles the spatial size and halves the channel count.
    self.up_sampling = nn.ConvTranspose2d(out_channels*2, out_channels, 2, stride=2)

    self.double_conv = nn.Sequential(
      nn.Conv2d(out_channels*2, out_channels, 3, padding=1),
      nn.BatchNorm2d(out_channels),
      nn.ReLU(inplace=True),
      nn.Conv2d(out_channels, out_channels, 3, padding=1),
      nn.BatchNorm2d(out_channels),
      nn.ReLU(inplace=True)
    )

  def enable(self):
    """Mark every parameter of this block as trainable."""
    for param in self.parameters():
      param.requires_grad = True

  def forward(self, x, prev_x):
    """Up-sample ``x``, concatenate it with ``prev_x`` and convolve.

    Args:
      x: feature map from the previous (deeper) stage.
      prev_x: encoder feature map used as the skip connection.

    Returns:
      Feature map with ``out_channels`` channels and ``prev_x``'s spatial size.
    """
    x = self.up_sampling(x)

    # Max-pooling floors odd sizes, so the up-sampled map can come out smaller
    # than the skip tensor. Zero-pad (or crop, for negative values) to the
    # skip's size so the concatenation is always valid. When the sizes already
    # match this is skipped.
    diff_h = prev_x.shape[-2] - x.shape[-2]
    diff_w = prev_x.shape[-1] - x.shape[-1]
    if diff_h != 0 or diff_w != 0:
      x = F.pad(x, (diff_w // 2, diff_w - diff_w // 2,
                    diff_h // 2, diff_h - diff_h // 2))

    x = torch.cat((x, prev_x), dim=1)
    return self.double_conv(x)

def weights_init(m):
  """Initialise a (transposed) convolution with N(0, sqrt(2 / (in_channels*9))).

  Intended for ``module.apply(weights_init)``: modules that are neither
  ``nn.Conv2d`` nor ``nn.ConvTranspose2d`` are left untouched.

  Args:
    m: the module visited by ``apply``.
  """
  if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
    # NOTE(review): the factor 9 is the fan-in of a 3x3 kernel, but it is also
    # applied to the 1x1 output conv and the 2x2 transposed convs. Kept as-is
    # so initialisation stays reproducible; confirm whether that is intended.
    std = math.sqrt(2/(m.in_channels*9))
    with torch.no_grad():
      m.weight.normal_(0, std)
      # Layers built with bias=False have no bias tensor to initialise.
      if m.bias is not None:
        m.bias.normal_(0, std)


Data Augmentation


In [0]:
class VerticalFlip:
  """Randomly flip an image/mask pair around the horizontal axis.

  Args:
    probability: chance, per call, that the pair is flipped.
  """

  def __init__(self, probability=0.3):
    self.probability = probability

  def __call__(self, images):
    """Return ``(img, mask)``, flipped together or left untouched.

    Args:
      images: sequence whose first item is the image and second the mask.
    """
    img, mask = images[0], images[1]
    if random.random() < self.probability:
      img = cv2.flip(img, 0)
      flipped = cv2.flip(mask, 0)
      # Restore the trailing channel axis from the flipped mask's own size
      # instead of assuming a 128x128 mask.
      mask = flipped.reshape(flipped.shape[0], flipped.shape[1], 1)
    return img, mask

class HorizontalFlip:
  """Randomly flip an image/mask pair around the vertical axis.

  Args:
    probability: chance, per call, that the pair is flipped.
  """

  def __init__(self, probability=0.3):
    self.probability = probability

  def __call__(self, images):
    """Return ``(img, mask)``, flipped together or left untouched.

    Args:
      images: sequence whose first item is the image and second the mask.
    """
    img, mask = images[0], images[1]
    if random.random() < self.probability:
      img = cv2.flip(img, 1)
      flipped = cv2.flip(mask, 1)
      # Restore the trailing channel axis from the flipped mask's own size
      # instead of assuming a 128x128 mask.
      mask = flipped.reshape(flipped.shape[0], flipped.shape[1], 1)
    return img, mask
  
class Rotate:
  """Randomly rotate an image/mask pair by 90, 180 or 270 degrees.

  Args:
    probability: total chance, per call, that a rotation happens. The three
      angles share it equally.
  """

  def __init__(self, probability=0.15):
    self.probability = probability

  def __call__(self, images):
    """Return ``(img, mask)`` rotated by the same angle, or untouched.

    The image and the mask go through the identical ``np.rot90`` call, so
    they cannot drift apart in angle or direction, and the rotation is exact
    (no interpolation, no blank border). New arrays are returned; the
    caller's arrays are never modified.

    Args:
      images: sequence whose first item is the image and second the mask,
        both laid out as (H, W, C).
    """
    img, mask = images[0], images[1]
    draw = random.random()
    if draw < self.probability/3:
      quarter_turns = 1
    elif draw < (2*self.probability)/3:
      quarter_turns = 2
    elif draw < self.probability:
      quarter_turns = 3
    else:
      return img, mask

    # rot90 returns a strided view; make it contiguous so downstream tensor
    # conversion does not have to deal with negative strides.
    img = np.ascontiguousarray(np.rot90(img, k=quarter_turns, axes=(1, 0)))
    mask = np.ascontiguousarray(np.rot90(mask, k=quarter_turns, axes=(1, 0)))
    return img, mask

class Zoom:
  """Randomly zoom into the centre of an image/mask pair.

  Args:
    probability: chance, per call, that the pair is zoomed. The zoom factor
      is tied to it: ``1 + probability``.
  """

  def __init__(self, probability=0.1):
    self.probability = probability

  def __call__(self, images):
    """Return ``(img, mask)``, centre-zoomed together or left untouched.

    Both arrays are enlarged by the same factor and cropped back to the
    image's original height and width, so the output sizes equal the input
    sizes. The mask keeps its channel axis and its label values.

    Args:
      images: sequence whose first item is the image and second the mask,
        both laid out as (H, W, C).
    """
    img, mask = images[0], images[1]
    if random.random() < self.probability:
      zoom_magnitude = 1 + self.probability
      height, width = img.shape[0], img.shape[1]
      # Cubic spline for the image (the scipy default).
      img = self._zoom_and_crop(img, zoom_magnitude, height, width, order=3)
      # Nearest neighbour for the mask, so no new label values are invented.
      mask = self._zoom_and_crop(mask, zoom_magnitude, height, width, order=0)
    return img, mask

  @staticmethod
  def _zoom_and_crop(array, magnitude, height, width, order):
    """Enlarge the first two axes by ``magnitude`` and centre-crop to (height, width)."""
    factors = (magnitude, magnitude) + (1,) * (array.ndim - 2)
    enlarged = zoom(array, factors, order=order)
    top = (enlarged.shape[0] - height) // 2
    left = (enlarged.shape[1] - width) // 2
    return enlarged[top:top + height, left:left + width]

# Augmentation pipeline for one [image, mask] pair. The transforms are listed
# in the order they are applied, and each one decides independently (using its
# own default probability: flips 0.3, rotation 0.15, zoom 0.1) whether to fire.
train_transform = Compose([
  VerticalFlip(),
  HorizontalFlip(),
  Rotate(),
  Zoom()
])

Training and Testing


In [0]:

## Alter these as required for testing.
## The CSC420_A3 Drive folder is available through this shareable link:
## https://drive.google.com/drive/folders/1mt-_YWULaaFSe8WqWVfiiwpc4Yldmewg?usp=sharing
# Glob patterns for the cat training images and their masks.
train_inputs_path = "drive/My Drive/CSC420_A3/cat_data/Train/input/*.jpg"
train_masks_path = "drive/My Drive/CSC420_A3/cat_data/Train/mask/*.jpg"

# Glob patterns for the extra data set used to pre-train the transfer model
# (note that these masks are .png, not .jpg).
transfer_train_inputs_path = "drive/My Drive/CSC420_A3/extra_data/input/*.jpg"
transfer_train_masks_path = "drive/My Drive/CSC420_A3/extra_data/mask/*.png"

# Glob patterns for the cat test images and their masks.
test_inputs_path = "drive/My Drive/CSC420_A3/cat_data/Test/input/*.jpg"
test_masks_path = "drive/My Drive/CSC420_A3/cat_data/Test/mask/*.jpg"

def dice(expected, x):
  """Mean soft Dice coefficient over classes.

  For each class c the score is ``2 * sum(x_c * y_c) / (sum(x_c) + sum(y_c))``
  with the sums taken over the batch and both spatial axes; the per-class
  scores are then averaged.

  Args:
    expected: integer class labels, shaped (B, H, W) or (B, 1, H, W).
    x: predicted class probabilities, shaped (B, C, H, W).

  Returns:
    Scalar tensor; close to 1 for a perfect prediction.
  """
  num_classes = x.shape[1]
  s = 1e-7  # keeps the division defined when a class is absent everywhere

  # Index with int64 labels living on the prediction's device: uint8/bool
  # tensors would be interpreted as a mask instead of as indices.
  labels = expected.long().to(x.device)
  if labels.dim() == x.dim():
    # Drop the singleton channel axis of a (B, 1, H, W) label tensor.
    labels = labels.squeeze(1)

  one_hot_expected = torch.eye(num_classes, device=x.device)[labels]
  one_hot_expected = one_hot_expected.permute(0, 3, 1, 2).float()

  numerator = 2. * torch.sum(x * one_hot_expected, (0,2,3))
  denominator = torch.sum(x + one_hot_expected, (0,2,3)) + s
  return (numerator / denominator).mean()

def dice_loss(expected, x):
  """Loss form of the Dice coefficient: one minus ``dice(expected, x)``."""
  overlap = dice(expected, x)
  return 1 - overlap

def digits_in_string(text):
  """Build a natural-sort key: digit runs become ints, the rest is lower-cased.

  The string is split on runs of digits (the runs are kept), so "img12.JPG"
  yields ['img', 12, '.jpg'] and "a2" sorts before "a10".
  """
  key = []
  for chunk in re.split(r'(\d+)', text):
    if chunk.isdigit():
      key.append(int(chunk))
    else:
      key.append(chunk.lower())
  return key

def read_data(inputs, masks, augment_data = False, transfer_learning=False):
  """Load image/mask pairs from disk as a list of ``[image, mask]`` tensors.

  Files matching the two glob patterns are paired by natural sort order.
  Every image is resized to 128x128 and scaled to [0, 1]; every mask is read
  as grayscale, resized to 128x128 and binarised to {0, 1}.

  Args:
    inputs: glob pattern of the input images.
    masks: glob pattern of the mask images.
    augment_data: when True, a second copy of every pair is passed through
      ``train_transform`` and appended, doubling the data set.
    transfer_learning: when True, masks are thresholded as stored rather
      than being divided by 255 first.

  Returns:
    List of ``[float image tensor, long mask tensor]`` items: the
    un-augmented pairs first, followed by the augmented ones (if any).

  Raises:
    FileNotFoundError: no file matches ``inputs``.
    ValueError: the two patterns match a different number of files.
    IOError: a matched file cannot be decoded as an image.
  """
  img_filenames = sorted(glob.glob(inputs), key=digits_in_string)
  mask_filenames = sorted(glob.glob(masks), key=digits_in_string)

  if not img_filenames:
    raise FileNotFoundError(f"No input images match {inputs!r}")
  if len(img_filenames) != len(mask_filenames):
    raise ValueError(
      f"Found {len(img_filenames)} images for {inputs!r} but "
      f"{len(mask_filenames)} masks for {masks!r}")

  to_tensor = transforms.ToTensor()
  train_data = []
  originals = []  # numpy copies kept only when they are needed for augmentation

  # Read, resize and reshape the images and their corresponding masks.
  for img_path, mask_path in zip(img_filenames, mask_filenames):
    img = cv2.imread(img_path)
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    # cv2.imread signals failure by returning None instead of raising.
    if img is None or mask is None:
      raise IOError(f"Could not read image pair {img_path!r} / {mask_path!r}")

    img = cv2.resize(img, (128,128))/255
    mask = cv2.resize(mask, (128,128)).reshape(128,128,1)
    # The masks for use in transfer learning are already normalized
    if not transfer_learning:
      mask = mask/255
    mask = np.where(mask > 0.5, 1, 0)

    if augment_data:
      # Copy before conversion: the tensors may share memory with the arrays.
      originals.append((img.copy(), mask.copy()))
    train_data.append([to_tensor(img).float(), to_tensor(mask).long()])

  # If the data is augmented, it is doubled in size and the second half
  # undergoes random transformations.
  for img, mask in originals:
    img, mask = train_transform([img, mask])
    train_data.append([to_tensor(img).float(), to_tensor(mask).long()])

  return train_data
    
def train(train_data, model=None, num_epochs=20, batch_size=1, lr=0.02, transfer_learning=False):
  """Train a segmentation model with the Dice loss and SGD (momentum 0.99).

  Args:
    train_data: sequence of ``[image, mask]`` tensor pairs; masks carry a
      leading channel axis that is dropped before computing the loss.
    model: model to train. When None, a new ``UNet`` is created and
      initialised with ``weights_init``.
    num_epochs: number of passes over ``train_data``.
    batch_size: mini-batch size.
    lr: SGD learning rate.
    transfer_learning: only used when ``model`` is given. The given
      (pre-trained) model is fine-tuned: every parameter is frozen, then the
      four up blocks and the output layer are unfrozen, and ``out``, ``up4``,
      ``up3`` and ``up2`` are re-initialised (``up1`` keeps its weights).

  Returns:
    The trained model; the same object that was passed in, if any.
  """
  if model is None:
    model = UNet()
    # Initialize weights
    model.apply(weights_init)
  elif transfer_learning:
    # Fine-tune the model that was passed in. Creating a new UNet here would
    # throw away the pre-trained weights the caller just loaded.
    # Freeze all layers
    for param in model.parameters():
      param.requires_grad = False
    # Unfreeze the last layers
    for block in (model.up4, model.up3, model.up2, model.up1):
      block.enable()
    # Reset the weights of the last layers
    for layer in (model.out, model.up4, model.up3, model.up2):
      layer.apply(weights_init)
    for param in model.out.parameters():
      param.requires_grad = True
  model.train()

  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  model = model.to(device)

  # Frozen parameters never receive a gradient, so SGD leaves them untouched.
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.99)

  train_loader = DataLoader(train_data, shuffle=True, batch_size=batch_size)
  for epoch in range(num_epochs):
    print("Epoch: ", epoch+1)
    running_loss = 0
    for x, mask_batch in train_loader:
      x = x.float().to(device)
      real_mask = mask_batch[:,0,:,:].to(device)

      optimizer.zero_grad()
      output = model(x)

      #### Calculate loss
      loss = dice_loss(real_mask, output)
      # loss = F.cross_entropy(output.float(), real_mask.long())
      loss.backward()
      optimizer.step()

      running_loss += loss.item()

    print(f"Training loss: {running_loss/len(train_loader)}")

  return model

def test(test_data, model):
  model.eval()
  test_loader = DataLoader(test_data, shuffle=False, batch_size=1)
  running_loss = 0
  accuracy = 0
  stop = 0
  for i, data in enumerate(test_loader):
    
    x, y_expected = data[0], data[1][0,0,:,:]
    
    if torch.cuda.is_available() and model.cuda():
        x = x.cuda()
        y_expected = y_expected.cuda()
    
    output = model(x.float())

    # Get final cat segmentation image
    _, predicted = torch.max(output.data, 1, keepdim=True)

    # Sorensen-dice coefficient for only the final product, rather than
    # the two segment dice score computed during training
    intersection = torch.sum(predicted[0,0,:,:]*y_expected[:,:])
    cardinality = torch.sum(predicted[0,0,:,:] + y_expected[:,:])
    score = (2. * intersection) / (cardinality + 1e-7)

    accuracy += score.item()
  print(f"\n\nTest Accuracy: {accuracy/len(test_loader)}")


def visualize_segmentation(model, test_data, samples=10):
  """Overlay the predicted segmentation outline on randomly picked inputs.

  Args:
    model: segmentation model returning per-class scores of shape (1, C, H, W).
    test_data: sequence of ``[image, mask]`` tensor pairs (only the images
      are used).
    samples: maximum number of images to process.

  Returns:
    List with one (H, W, 3) float array per processed image: the input image
    plus a white (255) outline of the longest predicted contour.
  """
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  model = model.to(device)
  model.eval()

  test_loader = DataLoader(test_data, shuffle=True, batch_size=1)
  segmentations = []
  with torch.no_grad():
    for sample_idx, data in enumerate(test_loader):
      if sample_idx == samples:
        break
      x = data[0].float().to(device)
      output = model(x)
      _, predicted = torch.max(output, 1, keepdim=True)
      predicted = np.uint8(predicted[0,0,:,:].cpu().numpy())

      # Find the contours of the mask. A 1px background border lets regions
      # touching the image edge form closed contours; the (-1, -1) offset
      # maps the contour points back to un-bordered image coordinates.
      border = cv2.copyMakeBorder(predicted, 1, 1, 1, 1, cv2.BORDER_CONSTANT, value=0)
      # findContours returns (image, contours, hierarchy) in OpenCV 3 and
      # (contours, hierarchy) in OpenCV 4; [-2] is the contour list in both.
      contours = cv2.findContours(border, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE,
                                  offset=(-1, -1))[-2]

      # Sometimes there are other small contours in the image. Keep only the
      # one with the most points, which should be the cat.
      if len(contours) > 1:
        contours = [max(contours, key=lambda contour: contour.shape[0])]

      input_img = data[0][0,:,:,:].permute(1,2,0).detach().cpu().numpy()
      blank = np.zeros(predicted.shape)
      outline = cv2.drawContours(blank, list(contours), -1, (255, 255, 255), 2)
      outline = np.stack((outline,)*3, axis=-1)
      segmented = input_img + outline
      segmentations.append(segmented)

      ## Download the image.
      ## Uncomment this if you would like to try downloading segmentation images.
      # cv2.imwrite("seg" + str(sample_idx) + ".png", segmented*255)
      # files.download("seg" + str(sample_idx) + ".png")

  return segmentations

  




Main:


In [0]:

# Seed every random number generator the notebook draws from. The augmentation
# transforms use Python's `random`, which torch.manual_seed does not cover, so
# seeding torch alone leaves the augmented data set different on every run.
SEED = 1234
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)
if torch.cuda.is_available():
  # Make newly created float tensors (including model parameters) live on the GPU.
  torch.set_default_tensor_type(torch.cuda.FloatTensor)
  torch.cuda.manual_seed_all(SEED)
  print('Running on GPU: {}.'.format(torch.cuda.get_device_name()))
else:
  print('Running on CPU.')

# Set to True to train the model for 1.1.
# If augment_data is also True, it is trained on augmented data for 1.2.
train_model = False

# Set to True to load the saved model from 1.1.
# If augment_data is also True, the model trained on augmented data (1.2) is loaded instead.
load_model = False

# Train the base model (on the extra data set) whose weights are later transferred.
train_transfer_model = False

# Load the base model and fine-tune it for cat segmentation via transfer learning.
transfer_learning_training = False

# Load the model which was trained with transfer learning.
load_model_using_transfer_learning = True 

# Evaluate whichever model was trained or loaded above on the test set.
test_model = True

# Whether to train/load with augmented data for 1.2.
augment_data = False

# Build the visualized segmentation images.
# The download command inside visualize_segmentation must be uncommented to save them.
visualize = False

# Stays None unless one of the train/load sections below provides a model.
model = None

#### 1.1/1.2
if train_model:
  train_data = read_data(train_inputs_path, train_masks_path, augment_data=augment_data)
  model = train(train_data, num_epochs=50, lr=0.01, batch_size=8)
  # To keep the weights, uncomment the line matching augment_data:
  # torch.save(model.state_dict(), 'drive/My Drive/model_weights_11.pth')
  # torch.save(model.state_dict(), 'drive/My Drive/model_weights_augmented_data_12.pth')


# Checkpoints are loaded with map_location="cpu" so that weights saved on a GPU
# machine also load on a CPU-only runtime; load_state_dict then copies them
# onto whatever device the model's parameters live on.
if load_model:
  model = UNet()
  if not augment_data:
    model.load_state_dict(torch.load('drive/My Drive/model_weights_11.pth', map_location="cpu"))
  else:
    model.load_state_dict(torch.load('drive/My Drive/model_weights_augmented_data_12.pth', map_location="cpu"))


#### 3.3 TRANSFER LEARNING
# train the model on a larger, similar data set
if train_transfer_model:
  train_data = read_data(transfer_train_inputs_path, transfer_train_masks_path, augment_data=False, transfer_learning=True)
  model = train(train_data, lr=0.1, batch_size=16, num_epochs=30)
  torch.save(model.state_dict(), 'drive/My Drive/transfer_model.pth')

# load the transfer model and train it on the cat data
if transfer_learning_training:
  train_data = read_data(train_inputs_path, train_masks_path, augment_data=False)
  model = UNet()
  model.load_state_dict(torch.load('drive/My Drive/transfer_model.pth', map_location="cpu"))
  model = train(train_data, num_epochs=40, batch_size=4, model=model, transfer_learning=True, lr=0.01)
  # torch.save(model.state_dict(), 'drive/My Drive/model_with_transfer_learning_weights.pth')

# load the transfer learning model for testing
if load_model_using_transfer_learning:
  model = UNet()
  model.load_state_dict(torch.load('drive/My Drive/model_with_transfer_learning_weights.pth', map_location="cpu"))

if test_model or visualize:
  # Fail with a clear message instead of a NameError when every flag is off.
  if model is None:
    raise RuntimeError("No model available: enable one of the train/load flags before testing or visualizing.")
  test_data = read_data(test_inputs_path, test_masks_path)

if test_model:
  test(test_data, model)

if visualize:
  visualize_segmentation(model, test_data)
