In [1]:
# header files needed
import glob
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision
from PIL import Image

In [2]:
# Seed every RNG the experiment touches so reruns are as repeatable as possible.
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)          # also seeds the CUDA generators
torch.cuda.manual_seed_all(SEED)  # explicit for multi-GPU; no-op without CUDA

# Seeding alone does not make GPU runs repeatable: cuDNN may choose
# non-deterministic convolution algorithms, and benchmark-mode autotuning can
# pick different algorithms on different runs. These flags trade some speed
# for repeatability. Some CUDA ops (e.g. the backward pass of bilinear
# interpolation) still have no deterministic implementation, so bit-exact
# GPU reproducibility is not guaranteed.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# Mount Google Drive so the dataset root under "My Drive" is reachable from Colab.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
# transforms
# Input images: centre crop, scale to [0, 1] tensors, then normalise with the
# commonly used ImageNet per-channel mean / std.
input_transform = torchvision.transforms.Compose([
    torchvision.transforms.CenterCrop(256),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


def mask_to_label_tensor(mask, void_label=255, void_index=21):
  """Convert a segmentation mask into an int64 class-index tensor.

  ToTensor() must not be used on label masks: it rescales uint8 images by
  1/255, so after .long() every class id below 255 truncates to 0 and only
  the void value 255 survives (as 1).

  Args:
    mask: PIL image (or array-like) whose pixel values are class ids;
      presumably VOC palette indices 0..20 plus 255 for void pixels.
    void_label: pixel value marking void / boundary pixels.
    void_index: class index void pixels are remapped to, so every label fits
      inside a 22-way output (indices 0..21).

  Returns:
    int64 tensor of shape (1, H, W). The singleton leading dim makes batches
    (B, 1, H, W), matching code that indexes labels[:, 0].
  """
  labels = torch.from_numpy(np.array(mask, dtype=np.int64))
  labels[labels == void_label] = void_index
  return labels.unsqueeze(0)


# Label masks: same centre crop as the inputs (no interpolation, so ids are
# preserved), then integer class ids without any rescaling.
# A module-level function (not a lambda) keeps the transform picklable for
# DataLoader worker processes.
target_transform = torchvision.transforms.Compose([
    torchvision.transforms.CenterCrop(256),
    mask_to_label_tensor,
])

In [None]:
# dataset: Pascal VOC 2012 segmentation splits read from the mounted Drive.
# download=False, so the data must already exist under VOC_ROOT.
VOC_ROOT = "/content/drive/My Drive/"
voc_common_kwargs = dict(
    year='2012',
    download=False,
    transform=input_transform,
    target_transform=target_transform,
)
train_dataset = torchvision.datasets.VOCSegmentation(VOC_ROOT, image_set='train', **voc_common_kwargs)
val_dataset = torchvision.datasets.VOCSegmentation(VOC_ROOT, image_set='val', **voc_common_kwargs)

In [None]:
# data loader
# Never start more worker processes than there are CPUs to run them
# (os.cpu_count() may return None, hence the fallback to 1).
NUM_WORKERS = min(8, os.cpu_count() or 1)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=NUM_WORKERS)
# The validation loss is averaged over all batches, so order is irrelevant;
# keeping it unshuffled makes every validation pass visit data in the same order.
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=2, shuffle=False, num_workers=NUM_WORKERS)

In [None]:
# loss
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Per-class weights for the model's 22 output channels (all 1 = unweighted;
# edit entries here to rebalance classes). Class 0 is deliberately NOT zeroed:
# with a zero weight, any batch whose pixels are all class 0 makes the
# weighted-mean loss 0/0 = NaN, and the network would never learn class 0
# (background in VOC).
weight = torch.ones(22)

# Index 21 is the extra 22nd output slot; pixels labelled 21 (e.g. a remapped
# void border) are excluded from both the loss and its mean.
VOID_INDEX = 21

# NLLLoss takes (N, C, H, W) log-probabilities with (N, H, W) integer targets;
# it is the documented replacement for the deprecated NLLLoss2d.
criterion = torch.nn.NLLLoss(weight=weight.to(device), ignore_index=VOID_INDEX)

In [3]:
class UNet(torch.nn.Module):
  """U-Net style encoder/decoder producing per-pixel class log-probabilities.

  Architecture notes (the original layer layout is preserved):
    * every 3x3 convolution is unpadded, so each one trims 2 pixels from the
      height and width, and is followed by BatchNorm + ReLU;
    * each encoder block ends in a 2x2 max-pool, and the *pooled* encoder
      output is what feeds the matching skip connection;
    * skip features are bilinearly resized to the decoder's spatial size
      (rather than cropped) and concatenated on the channel axis;
    * the final logits are bilinearly resized back to the input resolution.

  Because of the unpadded convolutions, each spatial dimension of the input
  must be at least 188 pixels.

  Args:
    num_classes: number of output channels / classes.
    in_channels: number of input image channels.
    base_channels: width of the first encoder block; deeper blocks use 2x,
      4x, 8x and 16x this width. The default (64) reproduces the original
      layer sizes and state-dict parameter names.
  """

  def __init__(self, num_classes=22, in_channels=3, base_channels=64):
    super().__init__()
    c1, c2, c3, c4, c5 = (base_channels * k for k in (1, 2, 4, 8, 16))

    # Encoder. Spatial sizes in the comments are for a 256 x 256 input.
    self.encoder_block_1 = torch.nn.Sequential(               # 256 -> 126
        *self._double_conv(in_channels, c1),
        torch.nn.MaxPool2d(kernel_size=2, stride=2),
    )
    self.encoder_block_2 = torch.nn.Sequential(               # 126 -> 61
        *self._double_conv(c1, c2),
        torch.nn.MaxPool2d(kernel_size=2, stride=2),
    )
    self.encoder_block_3 = torch.nn.Sequential(               # 61 -> 28
        *self._double_conv(c2, c3),
        torch.nn.MaxPool2d(kernel_size=2, stride=2),
    )
    self.encoder_block_4 = torch.nn.Sequential(               # 28 -> 12
        *self._double_conv(c3, c4),
        torch.nn.MaxPool2d(kernel_size=2, stride=2),
    )

    # Bottleneck: 12 -> 8, then the transposed conv doubles it to 16.
    self.center = torch.nn.Sequential(
        *self._double_conv(c4, c5),
        torch.nn.Dropout(),
        *self._upsample(c5, c4),
    )

    # Decoder blocks consume [previous output, resized skip] concatenated on
    # channels, hence the doubled input width of their first conv.
    self.decoder_block_1 = torch.nn.Sequential(               # 16 -> 12 -> 24
        *self._double_conv(2 * c4, c4),
        *self._upsample(c4, c3),
    )
    self.decoder_block_2 = torch.nn.Sequential(               # 24 -> 20 -> 40
        *self._double_conv(2 * c3, c3),
        *self._upsample(c3, c2),
    )
    self.decoder_block_3 = torch.nn.Sequential(               # 40 -> 36 -> 72
        *self._double_conv(2 * c2, c2),
        *self._upsample(c2, c1),
    )
    self.decoder_block_4 = torch.nn.Sequential(               # 72 -> 68
        *self._double_conv(2 * c1, c1),
    )

    # 1x1 conv mapping the last feature map to per-class logits.
    self.final_block = torch.nn.Sequential(
        torch.nn.Conv2d(c1, num_classes, kernel_size=1),
    )

  @staticmethod
  def _double_conv(in_channels, out_channels):
    """Two (unpadded 3x3 conv -> BatchNorm -> ReLU) stages; trims 4 px per dim."""
    layers = []
    for stage_in in (in_channels, out_channels):
      layers += [
          torch.nn.Conv2d(stage_in, out_channels, kernel_size=3),
          torch.nn.BatchNorm2d(out_channels),
          torch.nn.ReLU(inplace=True),
      ]
    return layers

  @staticmethod
  def _upsample(in_channels, out_channels):
    """2x transposed-conv upsampling followed by BatchNorm + ReLU."""
    return [
        torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
        torch.nn.BatchNorm2d(out_channels),
        torch.nn.ReLU(inplace=True),
    ]

  @staticmethod
  def _resize(features, size):
    """Bilinearly resize `features` to spatial `size` (H, W).

    align_corners=True reproduces the deprecated F.upsample_bilinear exactly.
    """
    return torch.nn.functional.interpolate(
        features, size=tuple(size), mode='bilinear', align_corners=True)

  def _merge(self, decoder_features, skip_features):
    """Concatenate decoder features with skip features resized to match them."""
    skip = self._resize(skip_features, decoder_features.shape[-2:])
    return torch.cat([decoder_features, skip], dim=1)

  def forward(self, x):
    """Return per-pixel class log-probabilities.

    Args:
      x: image batch of shape (N, in_channels, H, W) with H, W >= 188.

    Returns:
      Tensor of shape (N, num_classes, H, W): log-softmax over the class
      dimension, at the same spatial size as `x`.
    """
    enc_1 = self.encoder_block_1(x)
    enc_2 = self.encoder_block_2(enc_1)
    enc_3 = self.encoder_block_3(enc_2)
    enc_4 = self.encoder_block_4(enc_3)

    cen = self.center(enc_4)

    dec_1 = self.decoder_block_1(self._merge(cen, enc_4))
    dec_2 = self.decoder_block_2(self._merge(dec_1, enc_3))
    dec_3 = self.decoder_block_3(self._merge(dec_2, enc_2))
    dec_4 = self.decoder_block_4(self._merge(dec_3, enc_1))

    # Upsample to the input's own size (not a hard-coded 256 x 256).
    logits = self._resize(self.final_block(dec_4), x.shape[-2:])
    return torch.nn.functional.log_softmax(logits, dim=1)

In [None]:
# Build the network and move its parameters to the selected device.
# Module.to() returns the module itself, so the chained assignment yields the
# same object while avoiding a bare trailing `model.to(device)` expression,
# which would dump the full layer-by-layer repr into the cell output.
model = UNet().to(device)

In [None]:
# SGD with momentum and L2 weight decay over every model parameter.
optimizer = torch.optim.SGD(
    params=model.parameters(),
    lr=1e-4,
    momentum=0.9,
    weight_decay=5e-4,
)

In [None]:
NUM_EPOCHS = 40


def run_epoch(model, loader, criterion, device, optimizer=None):
  """Run one pass over `loader` and return the mean of the per-batch losses.

  With an `optimizer`, the pass trains: train mode, backward and a step per
  batch. Without one, the pass evaluates: eval mode and no autograd graph.

  Labels are cast to int64 and channel 0 is selected, so the loader is
  expected to yield label batches shaped (B, 1, H, W) — TODO confirm if the
  target transform changes.

  Returns 0.0 for an empty loader instead of raising ZeroDivisionError.
  """
  is_training = optimizer is not None
  model.train(is_training)
  total_loss = 0.0
  with torch.set_grad_enabled(is_training):
    for images, labels in loader:
      images = images.to(device)
      targets = labels.to(device).long()[:, 0]

      outputs = model(images)
      loss = criterion(outputs, targets)

      if is_training:
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

      total_loss += loss.item()
  return total_loss / max(len(loader), 1)


# training and val loop
loss_train = []
loss_valid = []

for epoch in range(NUM_EPOCHS):
  training_loss = run_epoch(model, train_loader, criterion, device, optimizer)
  valid_loss = run_epoch(model, val_loader, criterion, device)
  loss_train.append(training_loss)
  loss_valid.append(valid_loss)

  print()
  print(f"Epoch{epoch}:")
  print(f"Training Loss: {training_loss}    Validation Loss: {valid_loss}")
  print()

In [None]:
# x-axis for the loss curves: one epoch index per recorded loss value, so it
# always matches the lengths of loss_train / loss_valid even if the number of
# training epochs changes.
e = list(range(len(loss_train)))

In [None]:
# Training loss curve (mean per-batch loss for each epoch).
fig, ax = plt.subplots()
ax.plot(e, loss_train)
ax.set(title="Training loss per epoch", xlabel="Epoch", ylabel="Mean batch loss")
plt.show()

In [None]:
# Validation loss curve (mean per-batch loss for each epoch).
fig, ax = plt.subplots()
ax.plot(e, loss_valid)
ax.set(title="Validation loss per epoch", xlabel="Epoch", ylabel="Mean batch loss")
plt.show()