In [1]:
# header files needed
import numpy as np
import torch
import torch.nn as nn
import torchvision
import random
from PIL import Image
import glob

In [2]:
# ensure the experiment produces same result on each run
SEED = 1234

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# manual_seed_all covers every visible GPU; torch.cuda.manual_seed only seeds the current one.
torch.cuda.manual_seed_all(SEED)

# Seeding alone is not enough on GPU: cuDNN may auto-tune and pick non-deterministic
# convolution kernels. Pin it to deterministic kernels and disable the benchmark search.
# NOTE(review): a few ops (e.g. some scatter/upsample backward kernels) can still be
# non-deterministic; torch.use_deterministic_algorithms(True) would surface those as errors.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [4]:
# Mount Google Drive into the Colab runtime so the dataset stored there can be read.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
# transforms
CROP_SIZE = 256
# Per-channel statistics of ImageNet, matching the normalisation used by VGG-style backbones.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def mask_to_class_indices(mask):
    """Convert a single-channel PIL segmentation mask into per-pixel class indices.

    ``ToTensor`` must not be used on label masks. It divides uint8 values by 255,
    so class ``k`` becomes the float ``k / 255``, which is useless as an NLLLoss target.
    ``np.array`` keeps the stored integer values unchanged.

    Args:
        mask: PIL image whose pixel values are label ids. This assumes a palette ('P')
            or 'L' mode image, which is how VOCSegmentation masks are loaded.

    Returns:
        int64 tensor of shape (H, W) holding the raw label values. Any 255 border/void
        pixels are kept, so the loss must ignore them.
    """
    return torch.as_tensor(np.array(mask), dtype=torch.long)


input_transform = torchvision.transforms.Compose([
    torchvision.transforms.CenterCrop(CROP_SIZE),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
])

target_transform = torchvision.transforms.Compose([
    torchvision.transforms.CenterCrop(CROP_SIZE),
    # Named function instead of a lambda, so it pickles cleanly for DataLoader workers.
    mask_to_class_indices
])

In [None]:
# dataset
# Pascal VOC 2012 segmentation splits, read from the mounted Drive (no download).
# NOTE(review): newer Colab mounts expose Drive as "MyDrive" (no space); confirm this path exists.
VOC_ROOT = "/content/drive/My Drive/"
VOC_COMMON_KWARGS = dict(
    year='2012',
    download=False,
    transform=input_transform,
    target_transform=target_transform,
)

train_dataset = torchvision.datasets.VOCSegmentation(VOC_ROOT, image_set='train', **VOC_COMMON_KWARGS)
val_dataset = torchvision.datasets.VOCSegmentation(VOC_ROOT, image_set='val', **VOC_COMMON_KWARGS)

In [None]:
# data loader
BATCH_SIZE = 1
# NOTE(review): Colab runtimes typically have few CPU cores; DataLoader warns when
# num_workers exceeds what the machine offers. Lower this if that warning appears.
NUM_WORKERS = 8
# Page-locked host memory speeds up host->GPU copies; it is pointless without CUDA.
PIN_MEMORY = torch.cuda.is_available()

# Training batches are shuffled each epoch.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True,
    num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY,
)
# Validation uses a fixed order. Shuffling it adds nothing and makes per-sample
# inspection and run-to-run comparison harder.
val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY,
)

In [None]:
# loss
# Pascal VOC has 21 classes (background + 20 objects). This matches FCN's num_classes default.
NUM_CLASSES = 21
# VOC masks label object borders / "difficult" pixels as 255. They are not a class,
# so the loss must skip them instead of indexing past the last class.
VOC_IGNORE_INDEX = 255

# The weight vector must have exactly NUM_CLASSES entries or NLLLoss raises at runtime.
weight = torch.ones(NUM_CLASSES)
# Background (class 0) is excluded from the loss, as in the original setup.
# NOTE(review): with zero weight the model gets no signal to predict background;
# confirm this is intended.
weight[0] = 0

# NLLLoss handles (N, C, H, W) inputs directly; NLLLoss2d is a deprecated alias.
# NOTE(review): NLLLoss expects log-probabilities, so the model's forward must end in
# log_softmax over the class dim (not visible here). Otherwise use CrossEntropyLoss.
# Move this module to the model's device (loss.to(device)) so `weight` moves with it.
loss = torch.nn.NLLLoss(weight=weight, ignore_index=VOC_IGNORE_INDEX)

In [None]:
# model
class FCN(torch.nn.Module):

  # init function
  def __init__(self, num_classes=21):
    super(FCN, self).__init__()

    # vgg-16 backbone for encoder part
    self.encoder = torch.nn.Sequential(
        
        # block 1
        torch.nn.Conv2d(3, 64, kernel_size=3, padding=100),
        torch.nn.BatchNorm2d(64),
        torch.nn.ReLU(inplace=True),
        torch.nn.Conv2d(64, 64, kernel_size=3, padding=1),
        torch.nn.BatchNorm2d(64),
        torch.nn.ReLU(inplace=True),
        torch.nn.MaxPool2d(kernel_size=2, stride=2),         # 227 x 227 x 64

        # block 2
        torch.nn.Conv2d(64, 128, kernel_size=3, padding=1),
        torch.nn.BatchNorm2d(128),
        torch.nn.ReLU(inplace=True),
        torch.nn.Conv2d(128, 128, kernel_size=3, padding=1),
        torch.nn.BatchNorm2d(128),
        torch.nn.ReLU(inplace=True),
        torch.nn.MaxPool2d(kernel_size=2, stride=2),         # 113 x 113 x 128

        # block 3
        torch.nn.Conv2d(128, 256, kernel_size=3, padding=1),
        torch.nn.BatchNorm2d(256),
        torch.nn.ReLU(inplace=True),
        torch.nn.Conv2d(256, 256, kernel_size=3, padding=1),
        torch.nn.BatchNorm2d(256),
        torch.nn.ReLU(inplace=True),
        torch.nn.Conv2d(256, 256, kernel_size=3, padding=1),
        torch.nn.BatchNorm2d(256),
        torch.nn.ReLU(inplace=True),
        torch.nn.MaxPool2d(kernel_size=2, stride=2),         # 56 x 56 x 256

        # block 4
        torch.nn.Conv2d(256, 512, kernel_size=3, padding=1),
        torch.nn.BatchNorm2d(512),
        torch.nn.ReLU(inplace=True),
        torch.nn.Conv2d(512, 512, kernel_size=3, padding=1),
        torch.nn.BatchNorm2d(512),
        torch.nn.ReLU(inplace=True),
        torch.nn.Conv2d(512, 512, kernel_size=3, padding=1),
        torch.nn.BatchNorm2d(512),
        torch.nn.ReLU(inplace=True),
        torch.nn.MaxPool2d(kernel_size=2, stride=2),         # 28 x 28 x 512

        # block 5
        torch.nn.Conv2d(512, 512, kernel_size=3, padding=1),
        torch.nn.BatchNorm2d(512),
        torch.nn.ReLU(inplace=True),
        torch.nn.Conv2d(512, 512, kernel_size=3, padding=1),
        torch.nn.BatchNorm2d(512),
        torch.nn.ReLU(inplace=True),
        torch.nn.Conv2d(512, 512, kernel_size=3, padding=1),
        torch.nn.BatchNorm2d(512),
        torch.nn.ReLU(inplace=True),
        torch.nn.MaxPool2d(kernel_size=2, stride=2),          # 14 x 14 x 512

        # fc
        torch.nn.Conv2d(512, 4096, kernel_size=7),
        torch.nn.BatchNorm2d(4096),
        torch.nn.ReLU(inplace=True),
        torch.nn.Dropout2d(),
        torch.nn.Conv2d(4096, 4096, kernel_size=1),
        torch.nn.BatchNorm2d(4096),
        torch.nn.Dropout2d()                                  # 8 x 8 x 4096
    )