In [2]:
import numpy as np

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim

import PIL
from PIL import Image
import matplotlib.pyplot as plt

import torchvision.transforms as transforms
import torchvision.models as models
import torchvision.utils as tv
import torchvision.datasets as dset
import torch.utils.data as data

In [4]:
# utility blocks for building a UNet model

# Convolutional Block used in downscaling 
def conv_block(inp_dim, out_dim):
    """Build a two-convolution feature block that keeps the spatial size.

    Layer order: Conv3x3 -> BatchNorm -> LeakyReLU(0.2) -> Conv3x3 -> BatchNorm.
    Both convolutions use stride 1 and padding 1, so height and width pass
    through unchanged; only the channel count moves from ``inp_dim`` to
    ``out_dim``.  The block ends on a BatchNorm, with no trailing activation.

    Args:
        inp_dim: number of channels of the incoming feature map.
        out_dim: number of channels produced by both convolutions.

    Returns:
        An ``nn.Sequential`` holding the five layers in the order above.
    """
    # Created in the same order as they are applied.
    first_conv = nn.Conv2d(inp_dim, out_dim, 3, 1, 1)
    first_norm = nn.BatchNorm2d(out_dim)
    activation = nn.LeakyReLU(negative_slope=0.2, inplace=True)
    second_conv = nn.Conv2d(out_dim, out_dim, 3, 1, 1)
    second_norm = nn.BatchNorm2d(out_dim)

    layers = [first_conv, first_norm, activation, second_conv, second_norm]
    return nn.Sequential(*layers)

# Deconvolutional Block used in upscaling 
def trans_conv_block(inp_dim, out_dim):
    """Build the learned upsampling block used on the decoder path.

    Layer order: ConvTranspose3x3 -> BatchNorm -> LeakyReLU(0.2).
    With stride 2, padding 1 and output_padding 1 the transposed convolution
    exactly doubles height and width: (H - 1) * 2 - 2 + 3 + 1 == 2 * H.

    Args:
        inp_dim: number of channels of the incoming feature map.
        out_dim: number of channels of the upsampled feature map.

    Returns:
        An ``nn.Sequential`` holding the three layers in the order above.
    """
    upsample = nn.ConvTranspose2d(
        in_channels=inp_dim,
        out_channels=out_dim,
        kernel_size=3,
        stride=2,
        padding=1,
        output_padding=1,
    )
    normalise = nn.BatchNorm2d(out_dim)
    activation = nn.LeakyReLU(negative_slope=0.2, inplace=True)
    return nn.Sequential(upsample, normalise, activation)

# maxpool
def maxpool():
    """Return a fresh 2x2 max-pooling layer (stride 2, no padding).

    Each call creates a new ``nn.MaxPool2d`` instance; applied to a feature
    map it halves height and width (odd sizes are floored).
    """
    return nn.MaxPool2d(2, stride=2, padding=0)

![U-Net architecture diagram](../imgs/unet.png)

In [None]:
# Now putting the complete architecture all together


class UNet(nn.Module):
    """U-Net style encoder/decoder with four down/up levels and skip connections.

    The encoder applies ``conv_block`` followed by 2x2 max-pooling four times,
    doubling the channel count at each level (``num_filter`` up to
    ``num_filter*8``); the bridge widens to ``num_filter*16``.  The decoder
    mirrors this: each level upsamples with ``trans_conv_block``, concatenates
    the matching encoder feature map along the channel axis and fuses the
    result with ``conv_block``.  A final 3x3 convolution followed by ``Tanh``
    maps to ``out_dim`` channels, so outputs lie in [-1, 1].

    Args:
        inp_dim: number of channels of the input image.
        out_dim: number of channels of the predicted image.
        num_filter: channel width of the first encoder level.
    """

    def __init__(self, inp_dim, out_dim, num_filter):
        """Create the encoder, bridge, decoder and output layers."""
        super(UNet, self).__init__()

        self.inp_dim = inp_dim
        self.out_dim = out_dim
        self.num_filter = num_filter

        # downscaling: channels double per level, each pool halves H and W
        self.down1 = conv_block(self.inp_dim, self.num_filter)
        self.pool1 = maxpool()
        self.down2 = conv_block(self.num_filter, self.num_filter*2)
        self.pool2 = maxpool()
        self.down3 = conv_block(self.num_filter*2, self.num_filter*4)
        self.pool3 = maxpool()
        self.down4 = conv_block(self.num_filter*4, self.num_filter*8)
        self.pool4 = maxpool()

        self.bridge = conv_block(self.num_filter*8, self.num_filter*16)

        # upscaling: every ``up*`` block takes twice the channels of its
        # ``trans*`` output because the skip connection is concatenated first
        self.trans1 = trans_conv_block(self.num_filter*16, self.num_filter*8)
        self.up1 = conv_block(self.num_filter*16, self.num_filter*8)
        self.trans2 = trans_conv_block(self.num_filter*8, self.num_filter*4)
        self.up2 = conv_block(self.num_filter*8, self.num_filter*4)
        # These two were registered as ``tarns3``/``tarns4`` although forward()
        # looks them up as ``trans3``/``trans4``.
        self.trans3 = trans_conv_block(self.num_filter*4, self.num_filter*2)
        self.up3 = conv_block(self.num_filter*4, self.num_filter*2)
        self.trans4 = trans_conv_block(self.num_filter*2, self.num_filter*1)
        self.up4 = conv_block(self.num_filter*2, self.num_filter*1)

        self.out = nn.Sequential(
            nn.Conv2d(self.num_filter, self.out_dim, 3, 1, 1),
            nn.Tanh(),
        )

    def forward(self, input):
        """Run the network on a batch of images.

        Args:
            input: tensor laid out as (batch, inp_dim, height, width).  Height
                and width must both be multiples of 16, otherwise the four
                floor-halving pools and four doubling upsamplings produce
                feature maps that no longer match the skip connections.

        Returns:
            Tensor of shape (batch, out_dim, height, width) with values in
            [-1, 1].

        Raises:
            ValueError: if height or width is not a multiple of 16.
        """
        height, width = input.size(-2), input.size(-1)
        if height % 16 != 0 or width % 16 != 0:
            raise ValueError(
                "UNet input height and width must be multiples of 16, "
                "got {}x{}".format(height, width)
            )

        # encoder: keep every pre-pool feature map for the skip connections
        down1 = self.down1(input)
        pool1 = self.pool1(down1)
        down2 = self.down2(pool1)
        pool2 = self.pool2(down2)
        down3 = self.down3(pool2)
        pool3 = self.pool3(down3)
        down4 = self.down4(pool3)
        pool4 = self.pool4(down4)

        bridge = self.bridge(pool4)

        # decoder: upsample, concatenate the encoder map of equal resolution
        # along the channel axis (dim=1), then fuse
        trans1 = self.trans1(bridge)
        concat1 = torch.cat([trans1, down4], dim=1)
        up1 = self.up1(concat1)
        trans2 = self.trans2(up1)
        concat2 = torch.cat([trans2, down3], dim=1)
        up2 = self.up2(concat2)
        trans3 = self.trans3(up2)
        concat3 = torch.cat([trans3, down2], dim=1)
        up3 = self.up3(concat3)
        trans4 = self.trans4(up3)
        concat4 = torch.cat([trans4, down1], dim=1)
        up4 = self.up4(concat4)

        out = self.out(up4)
        return out


In [None]:
# Build a U-Net with 3 input channels, 3 output channels and 64 base filters.
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook does not crash on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(3, 3, 64).to(device)
model


In [None]:
batch_size = 1
# Must be a multiple of 16: the U-Net halves the resolution four times and
# doubles it four times, and its skip connections only line up when every
# intermediate size is even.  The previous value of 200 shrinks to
# 200 -> 100 -> 50 -> 25 -> 12, so the first decoder concat (24 vs 25) fails.
img_size = 208
lr = 0.0002
epoch = 100

# Shared preprocessing for images and masks: resize the shorter edge to
# img_size, centre-crop to (img_size, 2 * img_size) and convert to a float
# tensor in [0, 1].  ``Resize`` replaces the deprecated/removed ``Scale``.
paired_transform = transforms.Compose([
    transforms.Resize(size=img_size),
    transforms.CenterCrop(size=(img_size, img_size * 2)),
    transforms.ToTensor(),
])

# preparing dataset from folders containing images and masks
image_data = dset.ImageFolder(root="drive/SemanticDataset/train/", transform=paired_transform)
label_data = dset.ImageFolder(root="drive/SemanticDataset/label/", transform=paired_transform)

# Images and masks are paired purely by position, so a size mismatch would
# silently pair the wrong files (and zip() would drop the surplus).
assert len(image_data) == len(label_data), (
    "image/label folders differ in size: {} vs {}".format(len(image_data), len(label_data))
)

In [None]:
# Data loaders for the images and for their masks.
# Shuffling stays off on both so that the n-th image batch and the n-th label
# batch come from the same dataset position when the two loaders are zipped.
image_batch = data.DataLoader(
    image_data,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)
label_batch = data.DataLoader(
    label_data,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)

In [None]:
# Loss, optimiser and a per-step record of the loss values.
loss_func = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# Kept separate from ``loss`` (the current step's loss tensor); the original
# list named ``loss`` was overwritten on the first iteration.
loss_history = []

# Send every batch to whichever device the model already lives on.
device = next(model.parameters()).device

# NOTE(review): the network ends in Tanh (outputs in [-1, 1]) while the
# ToTensor targets lie in [0, 1]; consider normalising the targets or using a
# Sigmoid output.  Left unchanged here.

# training model
for i in range(epoch):
    # Each ImageFolder batch is a (tensor, class_index) pair; only the image
    # tensors matter here, the folder-derived class indices are discarded.
    for step, ((image, _image_cls), (label, _label_cls)) in enumerate(zip(image_batch, label_batch)):
        optimizer.zero_grad()

        # Plain tensors are enough: gradients are only needed for the model
        # parameters, not for the input image.
        x = image.to(device)
        y = label.to(device)

        # Call the module itself (not .forward) so registered hooks run.
        out = model(x)
        loss = loss_func(out, y)

        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        loss_history.append(loss_value)

        if step % 100 == 0:
            print("Epoch: {} | Loss: {:.6f}".format(i, loss_value))
