In [1]:
import torch
import torch.nn as N
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision.datasets import OxfordIIITPet
from torchvision.transforms import Compose, InterpolationMode, PILToTensor, Resize, ToTensor

In [2]:
# ---- Reproducibility: seed torch's global RNG before any weights are created
# ---- or any DataLoader shuffles.
torch.manual_seed(42)

# ---- Optimisation hyper-parameters ----
batch_size, epochs, learning_rate = 32, 5, 1e-3

# ---- U-Net architecture ----
filters = [32, 64, 128, 256]  # encoder stage widths, shallow -> deep; the bridge doubles the last one
drop_rate = 0.15              # NOTE(review): not referenced by any cell shown here

# ---- Input ----
img_depth = 3                 # presumably the number of input image channels (RGB)

In [3]:
# OxfordIIITPet applies `transforms(image, target)` to the image and its mask
# together, so `transforms` must be a callable over the pair. A plain list is
# not callable and fails as soon as the first sample is accessed.
image_transform = Compose([ToTensor(), Resize((256, 256))])
# The segmentation target holds discrete labels. Keep them unscaled (PILToTensor
# rather than ToTensor, which divides uint8 data by 255). Resize with
# nearest-neighbour so interpolation never invents in-between label values.
# NOTE(review): result is an integer tensor of shape (1, 256, 256); map the
# label ids to whatever the loss expects downstream.
mask_transform = Compose([
    PILToTensor(),
    Resize((256, 256), interpolation=InterpolationMode.NEAREST),
])


def transforms(image, target):
    """Jointly transform one (image, segmentation mask) pair.

    Returns ``(image_transform(image), mask_transform(target))``.
    """
    return image_transform(image), mask_transform(target)

In [4]:
# Both splits share the same root, target type and joint transform; only the
# split differs. The train split fetches the archives into the root folder and
# the test split reuses them, hence download=False there.
training_data = OxfordIIITPet(
    root='../content',
    split='trainval',
    target_types='segmentation',
    transforms=transforms,
    download=True,
)
test_data = OxfordIIITPet(
    root='../content',
    split='test',
    target_types='segmentation',
    transforms=transforms,
    download=False,
)

Downloading https://thor.robots.ox.ac.uk/pets/images.tar.gz to ../content/oxford-iiit-pet/images.tar.gz


100%|██████████| 791918971/791918971 [00:06<00:00, 119972984.99it/s]


Extracting ../content/oxford-iiit-pet/images.tar.gz to ../content/oxford-iiit-pet
Downloading https://thor.robots.ox.ac.uk/pets/annotations.tar.gz to ../content/oxford-iiit-pet/annotations.tar.gz


100%|██████████| 19173078/19173078 [00:00<00:00, 104804867.07it/s]


Extracting ../content/oxford-iiit-pet/annotations.tar.gz to ../content/oxford-iiit-pet


In [5]:
# Training: reshuffle every epoch and drop the ragged final batch so every
# optimisation step sees a full batch of `batch_size` samples.
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True, drop_last=True)
# Evaluation: fixed order and keep every sample. drop_last=True would silently
# exclude up to batch_size - 1 test images from any reported metric, and
# shuffling gives no benefit when only measuring.
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False, drop_last=False)

In [6]:
class DownwardBlock(N.Module):
    """Encoder stage body: two stacked (Conv2d -> BatchNorm2d -> ReLU) layers.

    The convolutions are created without bias because each one is immediately
    followed by a BatchNorm, which has its own learnable shift. With the default
    stride 1 and 'same' padding the spatial size is preserved; only the channel
    count changes from ``in_channels`` to ``out_channels``.

    Args:
        in_channels: channels of the incoming activation.
        out_channels: channels produced by both convolutions.
        device: device on which the parameters are created.
        stride, kernel_size, padding: forwarded to both convolutions.
    """

    def __init__(self, in_channels, out_channels, device, stride= 1, kernel_size= 3, padding= 'same'):
        super().__init__()

        def conv_bn_relu(c_in, c_out):
            # One conv/BN/ReLU triple; called twice below, in order, so the
            # Sequential indices (0..5) and RNG consumption stay the same.
            return [
                N.Conv2d(c_in, c_out, kernel_size, stride, padding, device=device, bias=False),
                N.BatchNorm2d(c_out, device=device),
                N.ReLU(inplace=True),
            ]

        self.enc_block = N.Sequential(
            *conv_bn_relu(in_channels, out_channels),
            *conv_bn_relu(out_channels, out_channels),
        )

    def forward(self, input):
        """Apply both conv/BN/ReLU layers to ``input``."""
        return self.enc_block(input)

In [7]:
class DownwardHalf(N.Module):
    """Encoder (contracting path) of the U-Net.

    Holds one DownwardBlock per entry of ``filters``. Block ``i`` maps
    ``filters[i-1]`` channels to ``filters[i]``; the first block takes
    ``in_channels``. After every block the activation is stored as a skip
    connection and then halved spatially with a 2x2 max-pool.

    Args:
        filters: output channel count of each encoder stage, shallow to deep.
        device: device on which the parameters are created.
        in_channels: channels of the network input. Defaults to 3, so existing
            callers are unaffected; pass a different value for other inputs.
    """

    def __init__(self, filters, device, in_channels=3):
        super().__init__()
        self.down_blocks = N.ModuleList()
        self.mpool = N.MaxPool2d(kernel_size=2, stride=2)

        # Pair each stage's input width with its output width:
        # (in_channels, f0), (f0, f1), (f1, f2), ...
        stage_inputs = [in_channels] + list(filters[:-1])
        for c_in, c_out in zip(stage_inputs, filters):
            self.down_blocks.append(DownwardBlock(c_in, c_out, device=device))

    def forward(self, input):
        """Run the encoder.

        Returns:
            A tuple ``(pooled, skip_conn)``. ``pooled`` is the max-pooled output
            of the deepest block. ``skip_conn`` lists every block's pre-pool
            output, ordered shallow to deep.
        """
        skip_conn = []
        for block in self.down_blocks:
            input = block(input)
            skip_conn.append(input)
            input = self.mpool(input)

        return input, skip_conn

In [8]:
class UpwardBlock(N.Module):
    """One decoder (expanding-path) stage of the U-Net.

    Upsamples ``input`` 2x with a transposed convolution that also halves the
    channel count. It then concatenates the matching encoder skip connection
    along the channel axis and applies two (Conv2d -> BatchNorm2d -> ReLU) layers.

    Args:
        in_channels: channels of ``input``. The skip connection passed to
            ``forward`` must carry ``in_channels // 2`` channels, so that the
            concatenation has ``in_channels`` channels again (UpwardHalf wires
            ``2 * f`` -> ``f`` blocks against ``f``-channel skips).
        out_channels: channels produced by the block.
        device: device on which the parameters are created.
        stride, kernel_size, padding: forwarded to both decoder convolutions.
            PyTorch only supports ``padding='same'`` with stride 1.
    """

    def __init__(self, in_channels, out_channels, device, stride= 1, kernel_size= 3, padding='same'):
        super().__init__()

        # Halve the channels while doubling H and W, so cat(upsampled, skip)
        # has exactly in_channels channels -- the width the first conv expects.
        self.tconv = N.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2, device=device)
        self.dec_block = N.Sequential(
            N.Conv2d(in_channels, out_channels, kernel_size, stride, padding, device=device, bias=False),
            N.BatchNorm2d(out_channels, device=device),
            N.ReLU(inplace=True),

            N.Conv2d(out_channels, out_channels, kernel_size, stride, padding, device=device, bias=False),
            N.BatchNorm2d(out_channels, device=device),
            N.ReLU(inplace=True)
        )

    def forward(self, input, skip_conn):
        """Upsample ``input``, fuse it with ``skip_conn`` and refine.

        Args:
            input: activation from the deeper stage (``in_channels`` channels).
            skip_conn: encoder activation at the target resolution
                (``in_channels // 2`` channels).
        """
        tconv_output = self.tconv(input)

        # When a spatial size upstream was odd, max-pooling floored it, so the 2x
        # upsample can come out one pixel short of the skip. Pad it to match
        # so torch.cat does not fail.
        diff_h = skip_conn.size(-2) - tconv_output.size(-2)
        diff_w = skip_conn.size(-1) - tconv_output.size(-1)
        if diff_h or diff_w:
            tconv_output = N.functional.pad(
                tconv_output,
                [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
            )

        concat_output = torch.cat([tconv_output, skip_conn], dim=1)
        output = self.dec_block(concat_output)

        return output

In [9]:
class UpwardHalf(N.Module):
    """Decoder (expanding path) of the U-Net.

    Holds one UpwardBlock per entry of ``filters``, ordered deepest first. The
    block built for width ``f`` maps ``2 * f`` channels down to ``f``.

    Args:
        filters: encoder stage widths, shallow to deep (same list as the encoder).
        device: device on which the parameters are created.
    """

    def __init__(self, filters, device):
        super().__init__()
        self.up_blocks = N.ModuleList()
        for width in reversed(filters):
            self.up_blocks.append(UpwardBlock(width * 2, width, device))

    def forward(self, input, skip_conn):
        """Run every decoder block in order.

        Block ``i`` receives ``skip_conn[len(self.up_blocks) - 1 - i]``. The
        ``skip_conn`` list is presumably the shallow-to-deep list produced by the
        encoder, so skips are consumed deepest first.
        """
        last_index = len(self.up_blocks) - 1
        for offset, block in enumerate(self.up_blocks):
            input = block(input, skip_conn[last_index - offset])

        return input

In [10]:
class UNet(N.Module):
    """U-Net that maps an image batch to a single-channel output map.

    The pipeline is: encoder (DownwardHalf), then a bridge of two conv/BN/ReLU
    layers that doubles ``filters[-1]``, then the decoder (UpwardHalf), then a
    3x3 conv down to one channel. No activation is applied to the result. The
    output has the input's H x W when both are divisible by
    ``2 ** len(filters)``.

    Args:
        filters: encoder stage widths, shallow to deep.
        device: device on which the parameters are created.
    """

    def __init__(self, filters, device):
        super().__init__()
        self.down_half = DownwardHalf(filters, device)
        self.up_half = UpwardHalf(filters, device)
        # The decoder finishes at the shallowest level. Its last block is
        # UpwardBlock(filters[0] * 2, filters[0]), so the head must take
        # filters[0] channels, not filters[-1].
        self.final_conv = N.Conv2d(filters[0], 1, 3, 1, 'same', device=device)
        self.bridge_block = N.Sequential(
            N.Conv2d(filters[-1], filters[-1] * 2, 3, 1, 'same', device=device, bias=False),
            N.BatchNorm2d(filters[-1] * 2, device=device),
            N.ReLU(inplace=True),

            N.Conv2d(filters[-1] * 2, filters[-1] * 2, 3, 1, 'same', device=device, bias=False),
            N.BatchNorm2d(filters[-1] * 2, device=device),
            N.ReLU(inplace=True)
        )

    def forward(self, input):
        """Encode, bridge, decode with skip connections, then project to 1 channel."""
        down_half_output, skip_conn = self.down_half(input)
        bridge_block_output = self.bridge_block(down_half_output)
        up_half_output = self.up_half(bridge_block_output, skip_conn)
        output = self.final_conv(up_half_output)

        return output