In [1]:
# header files needed
import numpy as np
import torch
import torch.nn as nn
import torchvision

In [2]:
# Seed numpy and torch (CPU + CUDA) with one shared value so that repeated
# runs draw the same random numbers.
SEED = 1234
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

In [None]:
# Mount Google Drive under /content/drive; the VOC data is read from "My Drive".
import google.colab.drive as drive
drive.mount('/content/drive')

In [None]:
# transforms
# Images: resize, center-crop to 224x224, convert to float and normalise with
# the ImageNet channel statistics used by the pretrained VGG backbone.
input_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((256, 256)),
    torchvision.transforms.CenterCrop((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


def _mask_to_tensor(mask):
  """Convert a PIL segmentation mask to a (1, H, W) int64 tensor of raw class ids.

  ``ToTensor`` must not be used on masks: it divides uint8 images by 255, so
  class ids 0..20 would become fractions that truncate to 0 once the training
  loop casts them to long.  The leading singleton dimension mirrors the
  (C, H, W) layout ``ToTensor`` produced, which the training loop reshapes away.
  """
  return torch.as_tensor(np.array(mask), dtype=torch.long).unsqueeze(0)


# Masks: same geometry as the images, but nearest-neighbour resizing so class
# ids are never blended, and no scaling/normalisation of the id values.
target_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((256, 256), interpolation=torchvision.transforms.InterpolationMode.NEAREST),
    torchvision.transforms.CenterCrop((224, 224)),
    _mask_to_tensor
])

In [None]:
# Pascal VOC 2012 segmentation splits, read from the mounted Drive (download=False,
# so the data must already be present there).
voc_kwargs = dict(
    root="/content/drive/My Drive/",
    year='2012',
    download=False,
    transform=input_transform,
    target_transform=target_transform,
)
train_dataset = torchvision.datasets.VOCSegmentation(image_set='train', **voc_kwargs)
val_dataset = torchvision.datasets.VOCSegmentation(image_set='val', **voc_kwargs)

In [None]:
# data loaders
# Only the training split is shuffled. Validation metrics are aggregated over
# the whole split, so shuffling it changes nothing. A shuffling sampler would
# also draw from the global torch RNG every epoch, which perturbs the seeded run.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=8)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=8)

In [None]:
# Pixel-wise cross-entropy; target value 255 (presumably the VOC "void"/boundary
# label) contributes no loss.
criterion = torch.nn.CrossEntropyLoss(ignore_index=255)
# Run on the GPU when one is present, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Hyper-parameters and other hard-coded values used by the cells below.
num_classes = 21      # output channels of the model (VOC: 20 object classes + background)
num_epochs = 100      # full passes over the training split
lr, wd = 1e-5, 5e-5   # Adam learning rate and weight decay

In [None]:
# model
class FCN8(torch.nn.Module):
  """FCN-style semantic segmentation network built on a VGG feature extractor.

  ``pretrained_net.features`` is cut into three consecutive encoder stages at
  child indices -20 and -10.  The decoder adds the outputs of the first two
  stages back in as skip connections, so they must carry 256 and 512 channels
  respectively.  That holds for ``vgg16_bn``, which this notebook passes in.
  Other backbones would need different cut points; verify before swapping.

  Args:
    pretrained_net: model exposing a ``features`` container (e.g. a torchvision VGG).
    num_classes: number of output channels of the final 1x1 convolution.

  ``forward`` returns raw per-pixel logits (no softmax).  Each of the five
  stride-2 transposed convolutions exactly doubles height and width, so the
  output matches the input size only when the encoder downsamples by 32;
  otherwise the skip additions fail with a shape mismatch.
  """

  @staticmethod
  def _upsample_stage(in_channels, out_channels):
    """Modules for one 2x upsampling step: transposed conv -> batch norm -> ReLU."""
    return [
        torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
        torch.nn.BatchNorm2d(out_channels),
        torch.nn.ReLU(inplace=True),
    ]

  # init function
  def __init__(self, pretrained_net, num_classes=num_classes):
    super().__init__()

    # Split the backbone once; the three stages share (not copy) its layers.
    backbone_layers = list(pretrained_net.features.children())
    self.encoder_1 = torch.nn.Sequential(*backbone_layers[:-20])
    self.encoder_2 = torch.nn.Sequential(*backbone_layers[-20:-10])
    self.encoder_3 = torch.nn.Sequential(*backbone_layers[-10:])

    # Convolutional head: 7x7 conv (size-preserving) then 1x1 conv, each
    # followed by ReLU and dropout.
    head = [
        torch.nn.Conv2d(512, 4096, kernel_size=7, padding=3),
        torch.nn.ReLU(inplace=True),
        torch.nn.Dropout(),
        torch.nn.Conv2d(4096, 4096, kernel_size=1),
        torch.nn.ReLU(inplace=True),
        torch.nn.Dropout(),
    ]
    self.encoder_classifier = torch.nn.Sequential(*head)

    # Decoder stages.  The order of modules inside each Sequential defines the
    # state_dict keys (e.g. "decoder_3.0.weight"); keep it stable so saved
    # checkpoints keep loading.
    self.decoder_1 = torch.nn.Sequential(*self._upsample_stage(4096, 512))
    self.decoder_2 = torch.nn.Sequential(*self._upsample_stage(512, 256))
    self.decoder_3 = torch.nn.Sequential(
        *self._upsample_stage(256, 128),
        *self._upsample_stage(128, 64),
        *self._upsample_stage(64, 32),
        torch.nn.Conv2d(32, num_classes, kernel_size=1),
    )

  # forward function
  def forward(self, x):
    """Return per-pixel class logits with ``num_classes`` channels for the batch ``x``."""
    shallow_skip = self.encoder_1(x)
    deep_skip = self.encoder_2(shallow_skip)
    bottleneck = self.encoder_classifier(self.encoder_3(deep_skip))

    # Upsample, fusing with the encoder features of matching resolution.
    fused = self.decoder_1(bottleneck) + deep_skip
    fused = self.decoder_2(fused) + shallow_skip
    return self.decoder_3(fused)

In [None]:
# Wrap an ImageNet-pretrained VGG16-BN backbone in FCN8 and move it to `device`.
# The backbone object is not kept around; only its feature layers live on in FCN8.
model = FCN8(torchvision.models.vgg16_bn(pretrained=True), num_classes)
model.to(device)

In [None]:
# Adam over every model parameter (pretrained encoder included), with weight decay `wd`.
optimizer = torch.optim.Adam(params=model.parameters(), weight_decay=wd, lr=lr)

In [None]:
train_loss_list = []
train_accuracy_list = []
train_iou_list = []
val_loss_list = []
val_accuracy_list = []
val_iou_list = []
best_metric = -1
best_metric_epoch = -1
# Validation accuracy is only considered for checkpointing from this epoch on.
checkpoint_start_epoch = 30


def _confusion_matrix(preds, targets, num_classes):
  """Return the (num_classes, num_classes) int64 confusion matrix of one batch.

  Rows are ground-truth classes and columns are predicted classes.  Pixels whose
  target is outside [0, num_classes) are dropped.  That includes the 255 label
  that ``criterion`` ignores.  Dropped pixels count neither as errors nor
  towards any class's area.
  """
  preds = preds.flatten()
  targets = targets.flatten()
  valid = (targets >= 0) & (targets < num_classes)
  pair_index = targets[valid] * num_classes + preds[valid]
  counts = torch.bincount(pair_index, minlength=num_classes * num_classes)
  return counts.reshape(num_classes, num_classes)


def _foreground_metrics(confusion):
  """Return (pixel accuracy, mean IoU) over the foreground classes 1..C-1.

  Background (class 0) is excluded from both metrics.  Accuracy is the fraction
  of foreground-labelled pixels predicted correctly.  Mean IoU averages the
  per-class IoU over classes that occur in the labels or the predictions
  (union > 0).  Both values are 0.0 when undefined (nothing to measure).
  """
  confusion = confusion.double()
  true_pos = confusion.diagonal()[1:]
  label_area = confusion.sum(dim=1)[1:]
  pred_area = confusion.sum(dim=0)[1:]
  union = label_area + pred_area - true_pos

  labelled = label_area.sum().item()
  accuracy = true_pos.sum().item() / labelled if labelled > 0 else 0.0
  present = union > 0
  mean_iou = (true_pos[present] / union[present]).mean().item() if bool(present.any()) else 0.0
  return accuracy, mean_iou


def _run_epoch(loader, train):
  """Run one pass over ``loader``; return (mean batch loss, accuracy, mean IoU).

  With ``train=True`` the model is in training mode and ``optimizer`` steps once
  per batch.  Otherwise the model is in eval mode with gradients disabled.
  Reads the notebook globals model, criterion, optimizer, device and num_classes.
  """
  model.train(train)
  loss_sum = 0.0
  confusion = torch.zeros(num_classes, num_classes, dtype=torch.long)
  with torch.set_grad_enabled(train):
    for images, labels in loader:
      images = images.to(device)
      # Targets arrive as (B, 1, H, W) or (B, H, W); CrossEntropyLoss wants (B, H, W) int64.
      labels = labels.long().reshape(labels.shape[0], *labels.shape[-2:]).to(device)

      if train:
        optimizer.zero_grad()
      outputs = model(images)
      loss = criterion(outputs, labels)
      if train:
        loss.backward()
        optimizer.step()
      loss_sum += loss.item()

      confusion += _confusion_matrix(outputs.argmax(1), labels, num_classes).cpu()

  mean_loss = loss_sum / max(len(loader), 1)
  accuracy, mean_iou = _foreground_metrics(confusion)
  return mean_loss, accuracy, mean_iou


# training and val loop
for epoch in range(num_epochs):

  # train
  train_loss, train_accuracy, train_iou = _run_epoch(train_loader, train=True)
  train_loss_list.append(train_loss)
  train_accuracy_list.append(train_accuracy)
  train_iou_list.append(train_iou)

  # evaluation
  val_loss, val_accuracy, val_iou = _run_epoch(val_loader, train=False)
  val_loss_list.append(val_loss)
  val_accuracy_list.append(val_accuracy)
  val_iou_list.append(val_iou)

  # Keep the weights with the best validation accuracy.  This is checkpointing,
  # not early stopping: training always runs for num_epochs.
  if best_metric < val_accuracy and epoch >= checkpoint_start_epoch:
    best_metric = val_accuracy
    best_metric_epoch = epoch
    torch.save(model.state_dict(), "best_model.pth")

  print()
  print("Epoch: " + str(epoch))
  print("Training Loss: " + str(train_loss) + "    Validation Loss: " + str(val_loss))
  print("Training Accuracy: " + str(train_accuracy) + "    Validation Accuracy: " + str(val_accuracy))
  print("Training mIoU: " + str(train_iou) + "    Validation mIoU: " + str(val_iou))
  print()

In [None]:
import matplotlib.pyplot as plt

In [None]:
# x-axis for the history plots: one entry per recorded epoch.  It is derived from
# the recorded history rather than a hard-coded 100, so the plots still line up
# when num_epochs changes or training is interrupted early.
e = list(range(len(train_iou_list)))

In [None]:
# Training IoU history, one point per epoch.
fig, ax = plt.subplots()
ax.plot(e, train_iou_list)
ax.set(title="Training IoU per epoch", xlabel="Epoch", ylabel="IoU");

In [None]:
# Validation IoU history, one point per epoch.
fig, ax = plt.subplots()
ax.plot(e, val_iou_list)
ax.set(title="Validation IoU per epoch", xlabel="Epoch", ylabel="IoU");