In [1]:
# imports: libraries used throughout the notebook
import numpy as np
import torch
import torch.nn as nn
import torchvision
import random
from PIL import Image
import glob

In [2]:
# ensure the experiment produces same result on each run
# One seed is shared by every random number generator the notebook touches.
SEED = 1234
random.seed(SEED)                 # Python's built-in RNG
np.random.seed(SEED)              # NumPy's legacy global RNG
torch.manual_seed(SEED)           # PyTorch default generator
torch.cuda.manual_seed(SEED)      # current CUDA device (silently ignored without CUDA)
torch.cuda.manual_seed_all(SEED)  # every CUDA device, in case more than one is visible

# Seeding alone does not make GPU runs repeatable: cuDNN may pick different
# (and non-deterministic) convolution algorithms from run to run. Trading a
# little speed for determinism keeps the promise made by the comment above.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# Mount Google Drive inside the Colab VM; the dataset paths used further down
# live under this mount point.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
# transforms
# Image and mask go through the same center crop so that they stay pixel-aligned.
CROP_SIZE = 256

# image pipeline: crop -> tensor -> per-channel normalization
# (the mean/std values look like the usual ImageNet statistics -- confirm if it matters)
input_transform = torchvision.transforms.Compose([
    torchvision.transforms.CenterCrop(CROP_SIZE),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# mask pipeline: crop -> tensor, no normalization
# NOTE(review): ToTensor rescales 8-bit pixel values into [0, 1], so if the masks
# store integer class ids they come out as id / 255 floats with a leading channel
# axis -- confirm the training loop converts them back before computing the loss.
target_transform = torchvision.transforms.Compose([
    torchvision.transforms.CenterCrop(CROP_SIZE),
    torchvision.transforms.ToTensor(),
])

In [None]:
# dataset
# Pascal VOC 2012 segmentation splits, read from the mounted Drive folder
# (download=False: the data must already be present under VOC_ROOT).
VOC_ROOT = "/content/drive/My Drive/"

# arguments shared by both splits; only `image_set` differs
voc_common_kwargs = dict(
    year='2012',
    download=False,
    transform=input_transform,
    target_transform=target_transform,
)

train_dataset = torchvision.datasets.VOCSegmentation(VOC_ROOT, image_set='train', **voc_common_kwargs)
val_dataset = torchvision.datasets.VOCSegmentation(VOC_ROOT, image_set='val', **voc_common_kwargs)

In [None]:
# data loader
# Loader settings are kept in one place so both splits stay in sync.
BATCH_SIZE = 1
NUM_WORKERS = 8

# training: reshuffle the samples every epoch
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

# validation: keep a fixed sample order so evaluation runs are directly
# comparable from epoch to epoch and from run to run (shuffling buys nothing here)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

In [None]:
# loss
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Per-class weights: every class counts equally except index 0, whose weight of 0
# removes it from the loss entirely.
# NOTE(review): 22 matches the default `num_classes` of the UNet below; which label
# index 0 stands for is not visible here -- confirm against the target encoding.
NUM_CLASSES = 22
weight = torch.ones(NUM_CLASSES)
weight[0] = 0

# torch.nn.NLLLoss handles (N, C, H, W) inputs with (N, H, W) integer targets directly;
# the old NLLLoss2d alias is deprecated and NLLLoss is its drop-in replacement.
# NLLLoss expects log-probabilities as input (e.g. the output of log_softmax).
criterion = torch.nn.NLLLoss(weight=weight.to(device))

In [None]:
class UNet(torch.nn.Module):
  """U-Net style encoder/decoder for dense per-pixel classification.

  Layout (channel counts):
    encoder : 3 -> 64 -> 128 -> 256 -> 512, every block ends in a 2x2 max-pool
    center  : 512 -> 1024 -> 1024, dropout, then a stride-2 transposed conv to 512
    decoder : 1024 -> 512 -> 256, 512 -> 256 -> 128, 256 -> 128 -> 64 (each ending
              in a stride-2 transposed conv), followed by a plain 128 -> 64 block
    final   : 1x1 convolution to ``num_classes`` channels

  Each decoder block consumes the features coming from below concatenated with
  the output of the matching encoder block, which is why its input width is
  twice the width of the features it receives from the previous stage.

  Args:
    num_classes: number of output channels (one score map per class).
  """

  @staticmethod
  def _double_conv(in_channels, out_channels):
    """Return the layers of two 3x3 conv + batch-norm + ReLU stages (spatial size kept)."""
    return [
        torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        torch.nn.BatchNorm2d(out_channels),
        torch.nn.ReLU(inplace=True),
        torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        torch.nn.BatchNorm2d(out_channels),
        torch.nn.ReLU(inplace=True),
    ]

  @staticmethod
  def _up_conv(in_channels, out_channels):
    """Return the layers of a stride-2 transposed conv (2x upsampling) + batch-norm + ReLU."""
    return [
        torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
        torch.nn.BatchNorm2d(out_channels),
        torch.nn.ReLU(inplace=True),
    ]

  @staticmethod
  def _concat_skip(features, skip):
    """Resize ``skip`` to the spatial size of ``features`` and stack both along channels."""
    skip = torch.nn.functional.interpolate(
        skip, size=features.shape[2:], mode='bilinear', align_corners=False
    )
    return torch.cat([features, skip], dim=1)

  # init function
  def __init__(self, num_classes=22):
    super(UNet, self).__init__()

    # NOTE: layers are created in a fixed order and keep their positions inside
    # each Sequential, so parameter initialisation under a fixed seed and the
    # state_dict keys (e.g. "center.7.weight") do not depend on the helpers.

    # encoder blocks 1-4: double conv, then halve the resolution
    self.encoder_block_1 = torch.nn.Sequential(
        *self._double_conv(3, 64),
        torch.nn.MaxPool2d(kernel_size=2, stride=2)
    )
    self.encoder_block_2 = torch.nn.Sequential(
        *self._double_conv(64, 128),
        torch.nn.MaxPool2d(kernel_size=2, stride=2)
    )
    self.encoder_block_3 = torch.nn.Sequential(
        *self._double_conv(128, 256),
        torch.nn.MaxPool2d(kernel_size=2, stride=2)
    )
    self.encoder_block_4 = torch.nn.Sequential(
        *self._double_conv(256, 512),
        torch.nn.MaxPool2d(kernel_size=2, stride=2)
    )

    # center block: widest features, dropout, then the first upsampling step
    self.center = torch.nn.Sequential(
        *self._double_conv(512, 1024),
        torch.nn.Dropout(),
        *self._up_conv(1024, 512)
    )

    # decoder blocks 1-3: fuse concatenated features, then double the resolution
    self.decoder_block_1 = torch.nn.Sequential(
        *self._double_conv(1024, 512),
        *self._up_conv(512, 256)
    )
    self.decoder_block_2 = torch.nn.Sequential(
        *self._double_conv(512, 256),
        *self._up_conv(256, 128)
    )
    self.decoder_block_3 = torch.nn.Sequential(
        *self._double_conv(256, 128),
        *self._up_conv(128, 64)
    )

    # decoder block 4: already at full resolution, so no upsampling stage
    self.decoder_block_4 = torch.nn.Sequential(
        *self._double_conv(128, 64)
    )

    # final block: 1x1 conv producing one score map per class
    self.final_block = torch.nn.Sequential(
        torch.nn.Conv2d(64, num_classes, kernel_size=1)
    )


  def forward(self, x):
    """Run the full encoder -> center -> decoder pass.

    Args:
      x: image batch of shape (N, 3, H, W). H and W must be at least 16 so that
        four successive 2x2 poolings still leave a non-empty feature map.

    Returns:
      Tensor of shape (N, num_classes, H, W) holding per-pixel log-probabilities
      (log_softmax over the class axis), i.e. the input format expected by an
      NLL criterion. Applying log_softmax again downstream is harmless because
      the operation is idempotent.
    """
    # encoder path; every block output is kept for its skip connection
    enc_1 = self.encoder_block_1(x)
    enc_2 = self.encoder_block_2(enc_1)
    enc_3 = self.encoder_block_3(enc_2)
    enc_4 = self.encoder_block_4(enc_3)

    # bottleneck (ends with the first 2x upsampling)
    out = self.center(enc_4)

    # decoder path: encoder outputs are taken after pooling, so each one is
    # resized to the current decoder resolution before being concatenated
    decoder_stages = (
        (enc_4, self.decoder_block_1),
        (enc_3, self.decoder_block_2),
        (enc_2, self.decoder_block_3),
        (enc_1, self.decoder_block_4),
    )
    for skip, block in decoder_stages:
      out = block(self._concat_skip(out, skip))

    out = self.final_block(out)

    # inputs whose height/width are not multiples of 16 lose pixels to pooling;
    # bring the score maps back to the exact input resolution in that case
    if out.shape[2:] != x.shape[2:]:
      out = torch.nn.functional.interpolate(
          out, size=x.shape[2:], mode='bilinear', align_corners=False
      )

    return torch.nn.functional.log_softmax(out, dim=1)