In [1]:
# header files
import os
import random
from random import shuffle

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from PIL import Image

In [2]:
# Seed every RNG the experiment touches so repeated runs are as reproducible
# as possible (Python `random`, NumPy, and torch on CPU and all CUDA devices).
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
# cuDNN autotunes and may pick non-deterministic convolution algorithms by
# default; pin it so GPU runs repeat too (at some speed cost).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# Mount Google Drive in Colab; the dataset folders used below live under it.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

In [None]:
class ImageFolder(torch.utils.data.Dataset):
    """ISIC segmentation dataset pairing each input image with its mask.

    For an image file named ``<prefix>_<id>.<ext>`` (e.g. ``ISIC_0000000.jpg``)
    the mask is read from ``<root_gt>/ISIC_<id>_segmentation.png``.
    Each item is ``(image, GT)``:
      * image: 3-channel float tensor, normalized to [-1, 1],
      * GT:    1-channel float tensor with values in [0, 1],
    both of spatial size ``image_size x image_size``.
    """

    # Files accepted from root_images; anything else in the folder (e.g. text
    # files shipped next to the images) is skipped instead of crashing
    # __getitem__.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_images, root_gt, image_size=224, mode='train'):
        """Initializes image paths and preprocessing settings.

        Arguments:
        ----------
        root_images: str
            Directory containing the input images.
        root_gt: str
            Directory containing the ``*_segmentation.png`` masks.
        image_size: int
            Side length of the square output tensors (default 224).
        mode: str
            'train' enables random augmentation (applied with probability
            ``augmentation_prob``); any other value disables it.
        """
        self.gt_paths = root_gt
        # Sorted so the index -> file mapping is the same on every run and
        # machine (os.listdir order is arbitrary), keeping seeded runs repeatable.
        self.image_paths = [
            os.path.join(root_images, name)
            for name in sorted(os.listdir(root_images))
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        ]
        self.image_size = image_size
        self.mode = mode
        self.rotation_list = [0, 90, 180, 270]
        self.augmentation_prob = 0.4

    def __getitem__(self, index):
        """Loads one image/mask pair, augments it in train mode, and tensorizes it."""
        T = torchvision.transforms
        nearest = T.InterpolationMode.NEAREST

        image_path = self.image_paths[index]
        # ".../ISIC_0000000.jpg" -> "0000000"
        stem = os.path.splitext(os.path.basename(image_path))[0]
        image_id = stem.split('_')[-1]
        gt_path = os.path.join(self.gt_paths, 'ISIC_' + image_id + '_segmentation.png')

        # convert() loads the pixel data so the files can be closed right away.
        # RGB guarantees the 3 channels Normalize expects; 'L' gives a
        # single-channel mask.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        with Image.open(gt_path) as gt_img:
            GT = gt_img.convert('L')

        aspect_ratio = image.size[1] / image.size[0]  # height / width
        resize_width = random.randint(300, 320)
        resize_size = (int(resize_width * aspect_ratio), resize_width)
        image_ops = [T.Resize(resize_size)]
        # Masks are resized with nearest-neighbour so they stay binary.
        gt_ops = [T.Resize(resize_size, interpolation=nearest)]

        p_transform = random.random()
        if self.mode == 'train' and p_transform <= self.augmentation_prob:
            rotation = self.rotation_list[random.randint(0, 3)]
            tilt = random.randint(-10, 10)
            crop_width = random.randint(250, 270)
            # Tilt and crop in the un-rotated frame, then apply the quarter
            # turn with expand=True so a 90/270 rotation really swaps the
            # canvas. Rotating first with a fixed canvas made CenterCrop
            # zero-pad large black bands into the sample.
            geometric = [
                T.RandomRotation((tilt, tilt)),
                T.CenterCrop((int(crop_width * aspect_ratio), crop_width)),
                T.RandomRotation((rotation, rotation), expand=True),
            ]
            if rotation in (90, 270):
                # a quarter turn swaps width and height
                aspect_ratio = 1 / aspect_ratio
            image = T.Compose(image_ops + geometric)(image)
            GT = T.Compose(gt_ops + geometric)(GT)

            # Random shift: trim up to 20 px from each side, identically for
            # image and mask.
            left = random.randint(0, 20)
            upper = random.randint(0, 20)
            right = image.size[0] - random.randint(0, 20)
            lower = image.size[1] - random.randint(0, 20)
            box = (left, upper, right, lower)
            image = image.crop(box=box)
            GT = GT.crop(box=box)

            # Photometric jitter on the image only.
            image = T.ColorJitter(brightness=0.2, contrast=0.2, hue=0.02)(image)
            # The initial resize was already applied as part of the compose above.
            image_ops, gt_ops = [], []

        # Intermediate resize to width 256 with height rounded down to a
        # multiple of 16 (at least 16), then to the final square size.
        final_height = int(256 * aspect_ratio)
        final_height = max(16, final_height - final_height % 16)
        final_size = (final_height, 256)
        out_size = (self.image_size, self.image_size)
        image_ops += [T.Resize(final_size), T.Resize(out_size), T.ToTensor()]
        gt_ops += [
            T.Resize(final_size, interpolation=nearest),
            T.Resize(out_size, interpolation=nearest),
            T.ToTensor(),
        ]
        image = T.Compose(image_ops)(image)
        GT = T.Compose(gt_ops)(GT)
        image = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(image)
        return image, GT

    def __len__(self):
        """Returns the total number of image files."""
        return len(self.image_paths)



def get_loader(image_path, gt_path, batch_size, mode, num_workers=4, isshuffle=True):
    """Builds and returns a DataLoader over an ImageFolder dataset.

    Arguments:
    ----------
    image_path: str
        Directory of input images.
    gt_path: str
        Directory of ground-truth masks.
    batch_size: int
        Samples per batch.
    mode: str
        Passed to ImageFolder; 'train' enables augmentation.
    num_workers: int
        Number of worker processes used to load data (previously ignored and
        hard-coded to 4).
    isshuffle: bool
        Whether to reshuffle the samples every epoch.
    """
    dataset = ImageFolder(root_images=image_path, root_gt=gt_path, mode=mode)
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=isshuffle,
        num_workers=num_workers,
    )

In [None]:
# Training data: augmented and shuffled every epoch.
train_loader = get_loader(
    image_path="/content/drive/My Drive/isic_training/",
    gt_path="/content/drive/My Drive/isic_training_gt/",
    batch_size=8,
    mode='train',
    num_workers=8,
    isshuffle=True,
)
# Validation data: no augmentation, fixed order.
val_loader = get_loader(
    image_path="/content/drive/My Drive/isic_valid/",
    gt_path="/content/drive/My Drive/isic_valid_gt/",
    batch_size=8,
    mode="valid",
    num_workers=8,
    isshuffle=False,
)

In [None]:
# define loss for two-class problem
# Two-class (foreground/background) loss: binary cross-entropy averaged over
# all elements. BCELoss takes probabilities, so model outputs are passed
# through torch.sigmoid before reaching it.
criterion = nn.BCELoss(reduction='mean')

In [3]:
class U_Net(nn.Module):
    """
    U-Net-style encoder/decoder for per-pixel segmentation.

    Every encoder stage ends in a 2x max-pool, so the skip connections are the
    pooled feature maps. The decoder mirrors them with nearest-neighbour
    upsampling, and a last 2x upsample + 1x1 convolution restores the input
    resolution. Because of the five 2x poolings, input height and width should
    be divisible by 32 for the skip concatenations to line up.

    Arguments:
    ----------
    in_channels: int
        Number of input channels (default 3).
    out_channels: int
        Number of output channels (default 1). Outputs are raw scores; no
        activation is applied inside the network.
    """

    def __init__(self, in_channels=3, out_channels=1):
        super(U_Net, self).__init__()

        # Encoder: five double-conv stages, each halving the spatial size.
        self.conv1 = self._down_block(in_channels, 64)
        self.conv2 = self._down_block(64, 128)
        self.conv3 = self._down_block(128, 256)
        self.conv4 = self._down_block(256, 512)
        self.conv5 = self._down_block(512, 1024)

        # Decoder: up* doubles the spatial size and halves the channels, and
        # up_conv* fuses the result with the matching skip. Modules are created
        # in this exact order so seeded weight initialisation (and state_dict
        # keys) stay the same.
        self.up5 = self._up_block(1024, 512)
        self.up_conv5 = self._fuse_block(1024, 512)
        self.up4 = self._up_block(512, 256)
        self.up_conv4 = self._fuse_block(512, 256)
        self.up3 = self._up_block(256, 128)
        self.up_conv3 = self._fuse_block(256, 128)
        self.up2 = self._up_block(128, 64)
        self.up_conv2 = self._fuse_block(128, 64)

        # Undo the first encoder pooling and project to out_channels.
        self.final = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(64, out_channels, 1),
        )

    @staticmethod
    def _conv_bn_relu(in_ch, out_ch):
        """Returns [3x3 same-padded conv, BatchNorm, in-place ReLU] as a list."""
        return [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]

    @classmethod
    def _down_block(cls, in_ch, out_ch):
        """Two conv-BN-ReLU units followed by a 2x2 max-pool."""
        return nn.Sequential(
            *cls._conv_bn_relu(in_ch, out_ch),
            *cls._conv_bn_relu(out_ch, out_ch),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

    @classmethod
    def _up_block(cls, in_ch, out_ch):
        """2x nearest upsample followed by one conv-BN-ReLU unit."""
        return nn.Sequential(
            nn.Upsample(scale_factor=2),
            *cls._conv_bn_relu(in_ch, out_ch),
        )

    @classmethod
    def _fuse_block(cls, in_ch, out_ch):
        """Two conv-BN-ReLU units applied to the [skip, upsampled] concatenation."""
        return nn.Sequential(
            *cls._conv_bn_relu(in_ch, out_ch),
            *cls._conv_bn_relu(out_ch, out_ch),
        )

    def forward(self, x):
        """
        Runs the U-Net.

        Arguments:
        ----------
        x: tensor of shape [batch_size, in_channels, height, width]

        Returns:
        --------
        tensor of shape [batch_size, out_channels, height, width]
        """
        # Encoder: keep the output of the first four stages as skips.
        skips = []
        features = x
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            features = stage(features)
            skips.append(features)
        features = self.conv5(features)

        # Decoder: deepest skip first; concatenate [skip, upsampled] on channels.
        decoder = (
            (self.up5, self.up_conv5),
            (self.up4, self.up_conv4),
            (self.up3, self.up_conv3),
            (self.up2, self.up_conv2),
        )
        for (upsample, fuse), skip in zip(decoder, reversed(skips)):
            features = fuse(torch.cat((skip, upsample(features)), dim=1))

        return self.final(features)

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model = U_Net()
model.to(device)

In [None]:
# optimizer to be used
# Adam with learning rate 1e-3 and beta1 = 0.5 (lower than the 0.9 default).
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001, betas=[0.5, 0.999])

In [None]:
NUM_EPOCHS = 250


def run_epoch(loader, training):
  """Runs one pass over `loader`; returns (mean batch loss, pixel accuracy).

  With training=True the model is put in train mode and updated through
  `optimizer`. Otherwise it is put in eval mode and gradients are disabled.
  Pixel accuracy is the fraction of pixels whose thresholded prediction
  (probability > 0.5) matches the thresholded ground truth (value > 0.5).

  Raises ValueError if the loader yields no batches.
  """
  model.train(training)
  loss_sum = 0.0
  n_batches = 0
  correct = 0
  total = 0
  with torch.set_grad_enabled(training):
    for images, labels in loader:
      images = images.to(device)
      labels = labels.to(device)

      # The network returns raw scores; BCELoss expects probabilities.
      outputs = torch.sigmoid(model(images))
      loss = criterion(outputs.view(outputs.size(0), -1),
                       labels.view(labels.size(0), -1))

      if training:
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
      loss_sum += loss.item()
      n_batches += 1

      # Threshold the mask at 0.5 instead of comparing it to its own max.
      # The max rule counts every pixel of an empty mask as foreground and
      # any below-max (e.g. interpolated) pixel as background.
      predicted = outputs > 0.5
      target = labels > 0.5
      correct += (predicted == target).sum().item()
      total += predicted.numel()

  if n_batches == 0:
    raise ValueError("loader yielded no batches")
  return loss_sum / n_batches, correct / total


train_loss_list = []
train_accuracy_list = []
val_loss_list = []
val_accuracy_list = []

for epoch in range(NUM_EPOCHS):
  train_loss, train_accuracy = run_epoch(train_loader, training=True)
  train_loss_list.append(train_loss)
  train_accuracy_list.append(train_accuracy)

  val_loss, val_accuracy = run_epoch(val_loader, training=False)
  val_loss_list.append(val_loss)
  val_accuracy_list.append(val_accuracy)

  print()
  print("Epoch: " + str(epoch))
  print("Training Loss: " + str(train_loss) + "    Validation Loss: " + str(val_loss))
  print("Training Accuracy: " + str(train_accuracy) + "    Validation Accuracy: " + str(val_accuracy))
  print()

In [None]:
import matplotlib.pyplot as plt

# x-axis for the loss curves: one entry per completed epoch. Deriving it from
# the recorded history (instead of a hard-coded 250) keeps the plots working
# when training was stopped early.
e = list(range(len(train_loss_list)))

In [None]:
fig, ax = plt.subplots()
ax.plot(e, train_loss_list)
ax.set(title="Training loss per epoch", xlabel="Epoch", ylabel="Mean BCE loss")
plt.show()

In [None]:
fig, ax = plt.subplots()
ax.plot(e, val_loss_list)
ax.set(title="Validation loss per epoch", xlabel="Epoch", ylabel="Mean BCE loss")
plt.show()