In [1]:
# header files
import os
import random
from random import shuffle

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from PIL import Image

In [2]:
# Ensure the experiment produces the same result on each run: seed every
# random number generator the notebook draws from (Python, NumPy, PyTorch).
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# seeds every visible GPU (not only the current one); a no-op without CUDA
torch.cuda.manual_seed_all(SEED)

# cuDNN autotuning can pick different, non-deterministic kernels from run to
# run; give up a little speed in exchange for repeatable results.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# Colab only: mount the user's Google Drive at /content/drive so that the
# notebook can read files stored there.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
class ImageFolder(torch.utils.data.Dataset):
  """
  Dataset of lesion photographs paired with their segmentation masks.

  Every ``*.jpg`` file found in ``root_images`` is matched with the mask
  ``ISIC_<id>_segmentation.png`` in ``root_gt``, where ``<id>`` is the part of
  the image file name after its last underscore.
  """

  def __init__(self, root_images, root_gt, image_size=224, mode='train'):
    """
    Collects the image paths and stores the preprocessing settings.

    Args:
      root_images: directory that contains the ``.jpg`` input images.
      root_gt: directory that contains the ``*_segmentation.png`` masks.
      image_size: side length (in pixels) of the square samples returned.
      mode: random augmentation is only applied when this equals ``'train'``.
    """
    self.gt_paths = root_gt
    # os.listdir returns names in arbitrary order and the folder may contain
    # files that are not images; keep only .jpg files, sorted, so that an
    # index always refers to the same sample.
    self.image_paths = [
        os.path.join(root_images, name)
        for name in sorted(os.listdir(root_images))
        if name.lower().endswith('.jpg')
    ]
    self.image_size = image_size
    self.mode = mode
    self.rotation_list = [0, 90, 180, 270]
    self.augmentation_prob = 0.4

  def _gt_path_for(self, image_path):
    """Returns the path of the mask that belongs to ``image_path``."""
    # Work on the bare file name so underscores in directory names cannot
    # leak into the extracted id.
    stem = os.path.splitext(os.path.basename(image_path))[0]
    image_id = stem.split('_')[-1]
    return os.path.join(self.gt_paths, 'ISIC_' + image_id + '_segmentation.png')

  def __getitem__(self, index):
    """
    Reads the image and mask at ``index``, preprocesses both and returns them.

    In ``'train'`` mode the pair is, with probability ``augmentation_prob``,
    rotated, cropped, shifted and flipped with identical parameters; colour
    jitter is applied to the image only.

    Returns:
      image: float tensor of shape (3, image_size, image_size), normalised
        from [0, 1] to [-1, 1].
      GT: float tensor of shape (1, image_size, image_size) holding 0.0 for
        background and 1.0 for foreground.
    """
    transforms = torchvision.transforms
    # The flip helpers for PIL images live in torchvision.transforms.functional
    # (torch.nn.functional has no hflip/vflip).
    tv_functional = torchvision.transforms.functional

    image_path = self.image_paths[index]
    gt_path = self._gt_path_for(image_path)

    # Normalize() below needs exactly three channels; the mask needs one.
    image = Image.open(image_path).convert('RGB')
    GT = Image.open(gt_path).convert('L')

    # PIL's size is (width, height), so this is height / width
    aspect_ratio = image.size[1]/image.size[0]
    Transform = []
    ResizeRange = random.randint(300, 320)
    Transform.append(transforms.Resize((int(ResizeRange*aspect_ratio), ResizeRange)))
    p_transform = random.random()

    if (self.mode == 'train') and p_transform <= self.augmentation_prob:
      RotationDegree = self.rotation_list[random.randint(0, 3)]
      if (RotationDegree == 90) or (RotationDegree == 270):
        aspect_ratio = 1/aspect_ratio

      # A (d, d) range pins RandomRotation to exactly d degrees, so image and
      # mask receive the same rotation although they are transformed separately.
      Transform.append(transforms.RandomRotation((RotationDegree, RotationDegree)))
      RotationRange = random.randint(-10, 10)
      Transform.append(transforms.RandomRotation((RotationRange, RotationRange)))
      CropRange = random.randint(250, 270)
      Transform.append(transforms.CenterCrop((int(CropRange*aspect_ratio), CropRange)))
      Transform = transforms.Compose(Transform)

      image = Transform(image)
      GT = Transform(GT)

      # random shift: trim up to 20 pixels from each border
      ShiftRange_left = random.randint(0, 20)
      ShiftRange_upper = random.randint(0, 20)
      ShiftRange_right = image.size[0] - random.randint(0, 20)
      ShiftRange_lower = image.size[1] - random.randint(0, 20)
      shift_box = (ShiftRange_left, ShiftRange_upper, ShiftRange_right, ShiftRange_lower)
      image = image.crop(box=shift_box)
      GT = GT.crop(box=shift_box)

      if random.random() < 0.5:
        image = tv_functional.hflip(image)
        GT = tv_functional.hflip(GT)

      if random.random() < 0.5:
        image = tv_functional.vflip(image)
        GT = tv_functional.vflip(GT)

      # photometric jitter must not touch the mask
      image = transforms.ColorJitter(brightness=0.2, contrast=0.2, hue=0.02)(image)
      # the geometric transforms above are already applied; start a new list
      Transform = []

    # height: 256 * aspect ratio rounded down to a multiple of 16 (at least 16)
    resized_height = max(16, int(256*aspect_ratio) - int(256*aspect_ratio) % 16)
    Transform.append(transforms.Resize((resized_height, 256)))
    Transform.append(transforms.Resize((self.image_size, self.image_size)))
    Transform.append(transforms.ToTensor())
    Transform = transforms.Compose(Transform)
    image = Transform(image)
    GT = Transform(GT)
    # Resizing interpolates gray values along the mask border; snap them back
    # to a strict 0/1 mask.
    GT = (GT > 0.5).float()
    Norm_ = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    image = Norm_(image)
    return image, GT

  def __len__(self):
    """
    Returns the total number of images.
    """
    return len(self.image_paths)



def get_loader(image_path, gt_path, batch_size, mode, num_workers=4, isshuffle=True):
  """
  Builds and returns a DataLoader over an ImageFolder dataset.

  Args:
    image_path: directory with the input images.
    gt_path: directory with the ground-truth masks.
    batch_size: number of samples per batch.
    mode: forwarded to ImageFolder as its ``mode`` argument.
    num_workers: number of worker processes used for loading.
    isshuffle: whether the samples are reshuffled every epoch.

  Returns:
    A torch.utils.data.DataLoader yielding the dataset's (image, mask) batches.
  """
  dataset = ImageFolder(root_images=image_path, root_gt=gt_path, mode=mode)
  # honour the caller's num_workers instead of a hard-coded worker count
  data_loader = torch.utils.data.DataLoader(
      dataset=dataset,
      batch_size=batch_size,
      shuffle=isshuffle,
      num_workers=num_workers,
  )
  return data_loader

In [None]:
# training batches: 4 samples each, shuffled, 'train' mode
train_loader = get_loader(
    image_path="/content/drive/My Drive/isic_training/",
    gt_path="/content/drive/My Drive/isic_training_gt/",
    batch_size=4,
    mode='train',
    num_workers=4,
    isshuffle=True,
)
# validation batches: 4 samples each, fixed order, "valid" mode
val_loader = get_loader(
    image_path="/content/drive/My Drive/isic_valid/",
    gt_path="/content/drive/My Drive/isic_valid_gt/",
    batch_size=4,
    mode="valid",
    num_workers=4,
    isshuffle=False,
)

In [None]:
# define loss for two-class problem
# The network's last layer is a plain ConvTranspose2d without a sigmoid, i.e.
# it emits raw logits. BCEWithLogitsLoss applies the sigmoid internally,
# whereas BCELoss only accepts inputs that already lie in [0, 1].
criterion = torch.nn.BCEWithLogitsLoss()

In [None]:
class SegNet(torch.nn.Module):
  """
  SegNet-style encoder/decoder for dense prediction.

  The encoder consists of the five convolutional stages taken from
  ``pretrained_net.features`` (layout of torchvision's ``vgg16_bn``: 44 layers,
  each stage followed by a max-pool). The decoder mirrors the encoder: every
  stage first un-pools with the indices recorded by the matching encoder
  max-pool and then applies transposed-convolution blocks.
  """

  def __init__(self, pretrained_net, num_classes=1):
    """
    Args:
      pretrained_net: network whose ``features`` children provide the encoder
        layers (expected layout: ``vgg16_bn``).
      num_classes: number of output channels of the last decoder layer.
    """
    super(SegNet, self).__init__()

    # Take the layers from the network passed in by the caller rather than
    # from a global variable.
    vgg_layers = list(pretrained_net.features.children())

    # encoder 1, encoder 2, encoder 3, encoder 4, encoder 5
    # (for a 44-layer vgg16_bn these slices are layers 0-5, 7-12, 14-22,
    # 24-32 and 34-42, i.e. every stage without its trailing max-pool)
    self.encoder_block_1 = torch.nn.Sequential(*vgg_layers[:-38])
    self.encoder_block_2 = torch.nn.Sequential(*vgg_layers[-37:-31])
    self.encoder_block_3 = torch.nn.Sequential(*vgg_layers[-30:-21])
    self.encoder_block_4 = torch.nn.Sequential(*vgg_layers[-20:-11])
    self.encoder_block_5 = torch.nn.Sequential(*vgg_layers[-10:-1])

    # max-pool layer with return_indices as true
    self.max_pool_layer = torch.nn.Sequential(
        torch.nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
    )

    unit = self._deconv_bn_relu

    # decoder block 1
    self.decoder_block_1 = torch.nn.Sequential(*unit(512, 512), *unit(512, 512), *unit(512, 512))

    # decoder block 2
    self.decoder_block_2 = torch.nn.Sequential(*unit(512, 512), *unit(512, 512), *unit(512, 256))

    # decoder block 3
    self.decoder_block_3 = torch.nn.Sequential(*unit(256, 256), *unit(256, 256), *unit(256, 128))

    # decoder block 4
    self.decoder_block_4 = torch.nn.Sequential(*unit(128, 128), *unit(128, 64))

    # decoder block 5 (the last layer has no norm/activation: it emits raw logits)
    self.decoder_block_5 = torch.nn.Sequential(
        *unit(64, 64),
        torch.nn.ConvTranspose2d(64, num_classes, kernel_size=3, padding=1)
    )

  @staticmethod
  def _deconv_bn_relu(in_channels, out_channels):
    """Returns [ConvTranspose2d 3x3, BatchNorm2d, ReLU]; keeps the spatial size."""
    return [
        torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, padding=1),
        torch.nn.BatchNorm2d(out_channels),
        torch.nn.ReLU(inplace=True),
    ]

  def forward(self, x):
    """
    Args:
      x: input batch of shape (B, 3, H, W).

    Returns:
      Raw (pre-sigmoid) logits of shape (B, num_classes, H, W).
    """
    unpool = torch.nn.functional.max_unpool2d

    # encoder: remember each stage's pre-pool feature map (for its spatial
    # size) and the pooling indices needed to invert the pooling later
    feat_1 = self.encoder_block_1(x)
    pooled_1, idx_1 = self.max_pool_layer(feat_1)

    feat_2 = self.encoder_block_2(pooled_1)
    pooled_2, idx_2 = self.max_pool_layer(feat_2)

    feat_3 = self.encoder_block_3(pooled_2)
    pooled_3, idx_3 = self.max_pool_layer(feat_3)

    feat_4 = self.encoder_block_4(pooled_3)
    pooled_4, idx_4 = self.max_pool_layer(feat_4)

    feat_5 = self.encoder_block_5(pooled_4)
    pooled_5, idx_5 = self.max_pool_layer(feat_5)

    # decoder: every stage consumes the output of the previous decoder stage,
    # so the deepest features propagate all the way to the prediction
    dec_1 = self.decoder_block_1(
        unpool(pooled_5, idx_5, kernel_size=2, stride=2, output_size=feat_5.shape[-2:]))
    dec_2 = self.decoder_block_2(
        unpool(dec_1, idx_4, kernel_size=2, stride=2, output_size=feat_4.shape[-2:]))
    dec_3 = self.decoder_block_3(
        unpool(dec_2, idx_3, kernel_size=2, stride=2, output_size=feat_3.shape[-2:]))
    dec_4 = self.decoder_block_4(
        unpool(dec_3, idx_2, kernel_size=2, stride=2, output_size=feat_2.shape[-2:]))

    output = self.decoder_block_5(
        unpool(dec_4, idx_1, kernel_size=2, stride=2, output_size=feat_1.shape[-2:]))
    return output

In [None]:
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# `model` first holds the pretrained VGG16 (batch-norm variant) backbone ...
model = torchvision.models.vgg16_bn(pretrained=True)
# ... and is then rebound to the SegNet built from it (second argument: num_classes=1).
model = SegNet(model, 1)
# Move the network to `device`; as the last expression of the cell the
# returned module is also displayed.
model.to(device)

In [None]:
# Adam over all parameters of the network: learning rate 1e-3, betas (0.5, 0.999)
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=0.001,
    betas=[0.5, 0.999],
)

In [None]:
NUM_EPOCHS = 250

# per-epoch history; one value is appended to every list after each epoch
train_loss_list = []
train_accuracy_list = []
train_iou_list = []
val_loss_list = []
val_accuracy_list = []
val_iou_list = []


def _loss_from_logits(logits, targets):
  """Evaluates the global `criterion` on raw (pre-sigmoid) network outputs."""
  # torch.nn.BCELoss is defined on probabilities, so the logits are squashed
  # first; any other criterion receives the logits unchanged.
  if isinstance(criterion, torch.nn.BCELoss):
    return criterion(torch.sigmoid(logits), targets)
  return criterion(logits, targets)


def _batch_metrics(logits, targets):
  """
  Returns (pixel accuracy, foreground IoU) for one batch, rounded to 6 digits.

  A pixel is predicted as foreground when its sigmoid probability exceeds 0.5.
  """
  predicted = torch.sigmoid(logits) > 0.5
  actual = targets > 0.5
  accuracy = (predicted == actual).float().mean().item()
  intersection = (predicted & actual).sum().item()
  union = (predicted | actual).sum().item()
  # no foreground in either prediction or mask: count it as a perfect match
  # instead of dividing by zero
  iou = intersection / union if union > 0 else 1.0
  return round(accuracy, 6), round(iou, 6)


def _run_epoch(loader, training):
  """
  Runs one pass over `loader` and returns (mean loss, mean accuracy, mean IoU),
  averaged over the batches.

  With `training=True` the model is put in train mode and the optimizer is
  stepped after every batch; otherwise the model is in eval mode and no
  gradients are recorded.
  """
  model.train(training)
  loss_sum, accuracy_sum, iou_sum = 0.0, 0.0, 0.0

  with torch.set_grad_enabled(training):
    for images, labels in loader:
      images = images.to(device)
      # BCE needs float 0/1 targets with the same (B, 1, H, W) layout as the
      # network output, so keep one mask channel and binarise it.
      labels = (labels[:, :1] > 0.5).float().to(device)

      outputs = model(images)
      loss = _loss_from_logits(outputs, labels)
      if training:
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

      loss_sum += loss.item()
      accuracy, iou = _batch_metrics(outputs.detach(), labels)
      accuracy_sum += accuracy
      iou_sum += iou

  # an empty loader yields zeros instead of a ZeroDivisionError
  num_batches = float(max(len(loader), 1))
  return loss_sum / num_batches, accuracy_sum / num_batches, iou_sum / num_batches


# training and val loop
for epoch in range(0, NUM_EPOCHS):

  # train
  train_loss, train_accuracy, train_iou = _run_epoch(train_loader, training=True)
  train_loss_list.append(train_loss)
  train_accuracy_list.append(train_accuracy)
  train_iou_list.append(train_iou)

  # evaluation code
  val_loss, val_accuracy, val_iou = _run_epoch(val_loader, training=False)
  val_loss_list.append(val_loss)
  val_accuracy_list.append(val_accuracy)
  val_iou_list.append(val_iou)

  print()
  print("Epoch: " + str(epoch))
  print("Training Loss: " + str(train_loss) + "    Validation Loss: " + str(val_loss))
  print("Training Accuracy: " + str(train_accuracy) + "    Validation Accuracy: " + str(val_accuracy))
  print("Training mIoU: " + str(train_iou) + "    Validation mIoU: " + str(val_iou))
  print()

In [None]:
import matplotlib.pyplot as plt
# x-axis values for the curves: one entry per epoch, 0 .. 249
e = list(range(0, 250))

In [None]:
# Training loss per epoch. `e` is cut to the number of recorded epochs so the
# curve can still be drawn when training stopped early.
fig, ax = plt.subplots()
ax.plot(e[:len(train_loss_list)], train_loss_list)
ax.set(title="Training loss", xlabel="Epoch", ylabel="Loss")
plt.show()

In [None]:
# Validation loss per epoch. `e` is cut to the number of recorded epochs so the
# curve can still be drawn when training stopped early.
fig, ax = plt.subplots()
ax.plot(e[:len(val_loss_list)], val_loss_list)
ax.set(title="Validation loss", xlabel="Epoch", ylabel="Loss")
plt.show()