<a href="https://colab.research.google.com/github/AriPathak/ResUnet/blob/main/UNET.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF

In [None]:
class Rescale(object):
    """Resize a PIL image, preserving aspect ratio for an int target size.

    If ``output_size`` is an int, the shorter side becomes ``output_size`` and the
    longer side is scaled proportionally (truncated to an int). If it is a tuple,
    it is taken as ``(height, width)`` and the image is resized to exactly that.

    Args:
        output_size: int or ``(height, width)`` tuple, see above.
        interpolation: optional ``InterpolationMode`` forwarded to ``TF.resize``.
            ``None`` (default) keeps ``TF.resize``'s own default. Pass
            ``TF.InterpolationMode.NEAREST`` for label masks so class ids are
            never blended at region boundaries.
    """

    def __init__(self, output_size, interpolation=None):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size
        self.interpolation = interpolation

    def _resize(self, image, size):
        # Only forward `interpolation` when the caller chose one explicitly.
        if self.interpolation is None:
            return TF.resize(image, size)
        return TF.resize(image, size, interpolation=self.interpolation)

    def __call__(self, sample):
        image = sample
        if isinstance(self.output_size, tuple):
            # Explicit (height, width): no aspect-ratio computation needed.
            return self._resize(image, self.output_size)

        # PIL reports size as (width, height).
        width, height = image.size
        if width > height:
            new_h, new_w = self.output_size, self.output_size * width / height
        else:
            new_h, new_w = self.output_size * height / width, self.output_size

        # TF.resize expects (height, width).
        return self._resize(image, (int(new_h), int(new_w)))

In [None]:
num_classes = 3  # number of per-pixel classes in the Oxford-IIIT Pet segmentation masks
img_dim = 128

# Images: resize the shorter side to img_dim, center-crop, convert to a float tensor.
image_transform = transforms.Compose([
    Rescale(img_dim),
    transforms.CenterCrop(img_dim),
    transforms.ToTensor(),
])

# Masks: same geometry as the images, but NEAREST interpolation so class ids are
# never blended at region boundaries (bilinear resizing mixes neighbouring labels).
# Assumes mask values start at 1; subtracting 1 gives 0-based class indices.
target_transform = transforms.Compose([
    transforms.Resize(img_dim, interpolation=transforms.InterpolationMode.NEAREST),
    transforms.CenterCrop(img_dim),
    transforms.Lambda(lambda mask: torch.from_numpy(np.array(mask) - 1).long()),
])

trainset = torchvision.datasets.OxfordIIITPet(root="./", split="trainval", download=True,
                                              target_types="segmentation",
                                              transform=image_transform,
                                              target_transform=target_transform)
testset = torchvision.datasets.OxfordIIITPet(root="./", split="test", download=True,
                                             target_types="segmentation",
                                             transform=image_transform,
                                             target_transform=target_transform)

Downloading https://thor.robots.ox.ac.uk/datasets/pets/images.tar.gz to oxford-iiit-pet/images.tar.gz


100%|██████████| 791918971/791918971 [00:38<00:00, 20610736.67it/s]


Extracting oxford-iiit-pet/images.tar.gz to oxford-iiit-pet
Downloading https://thor.robots.ox.ac.uk/datasets/pets/annotations.tar.gz to oxford-iiit-pet/annotations.tar.gz


100%|██████████| 19173078/19173078 [00:01<00:00, 9919986.26it/s] 


Extracting oxford-iiit-pet/annotations.tar.gz to oxford-iiit-pet


In [None]:
from torchvision.models import resnext50_32x4d

class ConvRelu(nn.Module):
    """A 2-D convolution immediately followed by an in-place ReLU.

    Args:
        in_channels: number of input feature maps.
        out_channels: number of output feature maps.
        kernel: convolution kernel size.
        padding: zero padding applied to each spatial side.
    """

    def __init__(self, in_channels, out_channels, kernel, padding):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel, padding=padding)
        activation = nn.ReLU(inplace=True)
        # Kept as a Sequential named `convrelu` so state_dict keys
        # ("convrelu.0.weight", "convrelu.0.bias") match saved checkpoints.
        self.convrelu = nn.Sequential(conv, activation)

    def forward(self, x):
        return self.convrelu(x)

class DecoderBlock(nn.Module):
    """Decoder stage used by ``Unet`` that doubles the spatial resolution.

    A 1x1 conv + ReLU squeezes channels to ``in_channels // 4``. A stride-2
    transposed convolution (kernel 4, padding 1) then doubles H and W. A final
    1x1 conv + ReLU projects to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        bottleneck = in_channels // 4

        self.conv1 = ConvRelu(in_channels, bottleneck, 1, 0)
        # Output size: (H - 1) * 2 - 2 * 1 + 4 = 2H, i.e. exactly doubled.
        self.deconv = nn.ConvTranspose2d(bottleneck, bottleneck, kernel_size=4,
                                         stride=2, padding=1, output_padding=0)
        self.conv2 = ConvRelu(bottleneck, out_channels, 1, 0)

    def forward(self, x):
        return self.conv2(self.deconv(self.conv1(x)))

class ResUnetDecoderBlock(nn.Module):
    """Decoder stage used by ``ResUNet`` that doubles the spatial resolution.

    A 3x3 conv + ReLU at the input width comes first. A stride-2 transposed
    convolution then maps to ``out_channels`` while doubling H and W. A final
    1x1 conv + ReLU follows.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.conv1 = ConvRelu(in_channels, in_channels, 3, 1)
        # Output size: (H - 1) * 2 - 2 * 1 + 3 + 1 = 2H; output_padding=1 makes it exact.
        self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3,
                                         stride=2, padding=1, output_padding=1)
        self.conv2 = ConvRelu(out_channels, out_channels, 1, 0)

    def forward(self, x):
        upsampled = self.deconv(self.conv1(x))
        return self.conv2(upsampled)

In [None]:
class ResBlock(nn.Module):
    """Two 3x3 conv + BatchNorm layers combined with an additive shortcut.

    The stride is applied by the second convolution, so the block downsamples by
    ``stride``. When the stride or the channel count changes, the shortcut is a
    1x1 conv + BatchNorm projection; otherwise it is the identity.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
        stride: spatial stride of the block (default 1).
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride, padding=1)
        self.norm1 = nn.BatchNorm2d(out_channels)
        self.norm2 = nn.BatchNorm2d(out_channels)

        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()  # empty Sequential == identity

    def forward(self, x):
        branch = F.relu(self.norm1(self.conv1(x)))
        branch = self.norm2(self.conv2(branch))
        # NOTE(review): the shortcut goes through ReLU before the addition. A standard
        # ResNet block applies no activation there. Kept as-is so trained weights
        # behave identically.
        skip = F.relu(self.shortcut(x))
        branch += skip
        return F.relu(branch)


class ResUNet(nn.Module):
    """U-Net-style segmentation network built from ``ResBlock`` encoder stages.

    Encoder resolutions relative to the input:
      - e0 and e1: full size
      - e2: 1/2
      - e3: 1/4
      - e4: 1/8

    Each ``ResUnetDecoderBlock`` doubles the resolution. Its output is added (not
    concatenated) to the encoder feature map of the same size, so input height
    and width should be divisible by 8 for the additions to line up.

    Args:
        num_classes: number of output channels / segmentation classes.
        num_channels: number of input image channels.

    ``forward`` returns per-pixel class probabilities (softmax over dim 1).
    """

    def __init__(self, num_classes, num_channels):
        super().__init__()
        self.e0 = nn.Conv2d(num_channels, 32, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        # e1 keeps stride 1 so its output stays at full resolution: decoder2
        # upsamples back to full size and its output is added to e1's output.
        self.e1 = ResBlock(32, 64, stride=1)
        self.e2 = ResBlock(64, 128, stride=2)
        self.e3 = ResBlock(128, 256, stride=2)
        self.e4 = ResBlock(256, 512, stride=2)

        self.decoder4 = ResUnetDecoderBlock(512, 256)
        self.decoder3 = ResUnetDecoderBlock(256, 128)
        self.decoder2 = ResUnetDecoderBlock(128, 64)

        self.last_conv0 = nn.Conv2d(64, num_classes, (1, 1), padding=0)

    def forward(self, x):
        stem = F.relu(self.bn1(self.e0(x)))
        skip1 = self.e1(stem)
        skip2 = self.e2(skip1)
        skip3 = self.e3(skip2)
        bottom = self.e4(skip3)

        up = self.decoder4(bottom) + skip3
        up = self.decoder3(up) + skip2
        up = self.decoder2(up) + skip1
        logits = self.last_conv0(up)
        return F.softmax(logits, dim=1)



In [None]:
class Unet(nn.Module):
    """Segmentation network with a ResNeXt-50 (32x4d) encoder and ``DecoderBlock`` decoders.

    Four decoders upsample step by step. The output of decoder4/3/2 is summed
    with the encoder feature map e3/e2/e1; ``decoder1``'s output goes straight to
    the final 1x1 conv.

    Args:
        n_classes: number of output channels / segmentation classes.
        pretrained: load torchvision's pretrained encoder weights (default True,
            the original behaviour). Pass False to build the architecture without
            downloading weights, e.g. before loading a saved state_dict.

    ``forward`` returns per-pixel class probabilities (softmax over dim 1).
    """

    def __init__(self, n_classes, pretrained=True):
        super().__init__()

        # NOTE(review): newer torchvision deprecates `pretrained=` in favour of
        # `weights=`; kept here for compatibility with older versions.
        self.base_model = resnext50_32x4d(pretrained=pretrained)
        self.base_layers = list(self.base_model.children())
        # Output channels of the four ResNeXt stages (bottleneck expansion 4).
        filters = [4*64, 4*128, 4*256, 4*512]

        # Child index 3 is skipped on purpose. It is presumably the stem max-pool in
        # torchvision's ResNet layout; skipping it keeps each encoder level aligned
        # with a decoder level.
        self.encoder0 = nn.Sequential(*self.base_layers[:3])
        self.encoder1 = nn.Sequential(*self.base_layers[4])
        self.encoder2 = nn.Sequential(*self.base_layers[5])
        self.encoder3 = nn.Sequential(*self.base_layers[6])
        self.encoder4 = nn.Sequential(*self.base_layers[7])

        self.decoder4 = DecoderBlock(filters[3], filters[2])
        self.decoder3 = DecoderBlock(filters[2], filters[1])
        self.decoder2 = DecoderBlock(filters[1], filters[0])
        self.decoder1 = DecoderBlock(filters[0], filters[0])

        # Input width tied to decoder1's output channels instead of a literal 256.
        self.last_conv1 = nn.Conv2d(filters[0], n_classes, 1, padding=0)

    def forward(self, x):
        x = self.encoder0(x)
        e1 = self.encoder1(x)
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)
        e4 = self.encoder4(e3)

        # Additive skip connections: each decoder output must match the encoder map's shape.
        d4 = self.decoder4(e4) + e3
        d3 = self.decoder3(d4) + e2
        d2 = self.decoder2(d3) + e1
        d1 = self.decoder1(d2)

        out = self.last_conv1(d1)
        return F.softmax(out, dim=1)






In [None]:
class mIoULoss(nn.Module):
    """Soft (differentiable) mean-IoU loss: ``1 - mean over (n, c) of IoU[n, c]``.

    ``inputs`` must already be per-class probabilities shaped N x C x H x W; no
    softmax is applied here. The models in this notebook apply softmax in
    ``forward``. ``target`` holds integer class indices shaped N x H x W.

    Args:
        weight: unused; kept for interface compatibility.
        size_average: unused; kept for interface compatibility.
        n_classes: number of classes C.
        eps: smoothing added to numerator and denominator. A class absent from
            both prediction and target then scores IoU 1 instead of 0/0 = NaN.
    """

    def __init__(self, weight=None, size_average=True, n_classes=2, eps=1e-6):
        super(mIoULoss, self).__init__()
        self.classes = n_classes
        self.eps = eps

    def to_one_hot(self, tensor):
        """Convert an N x H x W index tensor into an N x C x H x W float one-hot tensor."""
        n, h, w = tensor.size()
        # Allocate directly on the index tensor's device (no CPU allocation + copy).
        one_hot = torch.zeros(n, self.classes, h, w, device=tensor.device)
        return one_hot.scatter_(1, tensor.reshape(n, 1, h, w), 1)

    def forward(self, inputs, target):
        N = inputs.size(0)

        target_oneHot = self.to_one_hot(target)

        # Soft intersection and union, summed over pixels: N x C x H x W -> N x C.
        inter = (inputs * target_oneHot).view(N, self.classes, -1).sum(2)
        union = (inputs + target_oneHot - inputs * target_oneHot).view(N, self.classes, -1).sum(2)

        iou = (inter + self.eps) / (union + self.eps)

        # Average over classes and batch.
        return 1 - iou.mean()

In [None]:
# Use the first GPU when available; fall back to CPU so the notebook still runs without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [None]:
import time
from tqdm.notebook import tqdm
from torch.utils.data import Dataset, DataLoader

def training_loop(model, num_epochs, batch_size, trainset, criterion, optimizer, scheduler):
    """Train ``model`` on ``trainset`` and return the mean training loss per epoch.

    Args:
        model: network mapping an image batch to per-pixel predictions.
        num_epochs: number of passes over ``trainset``.
        batch_size: mini-batch size of the shuffled ``DataLoader``.
        trainset: dataset yielding ``(image, label)`` pairs.
        criterion: loss called as ``criterion(predictions, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: optional LR scheduler stepped once per epoch (falsy to disable).

    Returns:
        list[float]: average training loss of each epoch.
    """
    train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
    train_loss = []
    # Move the model to the same notebook-wide `device` the batches are sent to.
    model.to(device)
    for epoch in range(1, num_epochs + 1):
        model.train()
        running_loss = 0.0
        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch:02d}"):
            images = images.to(device)
            labels = labels.to(device)
            # Drop only a singleton channel dim (N x 1 x H x W -> N x H x W). A bare
            # .squeeze() would also drop the batch dim of a one-sample batch.
            if labels.dim() == 4 and labels.size(1) == 1:
                labels = labels.squeeze(1)

            optimizer.zero_grad()
            predictions = model(images)
            loss = criterion(predictions, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        if scheduler:
            scheduler.step()
        train_loss.append(running_loss / max(len(train_loader), 1))
        # Only the loss is reported; no IoU metric is computed in this loop.
        print("Epoch: {0:02d} | Training Loss: {1:.5f}".format(epoch, train_loss[-1]))
    return train_loss


In [None]:
import torch.optim as optim

learning_rate = 1e-3
n_epochs = 20
batch_size = 16
seed = 0

# Seed the global torch RNG so weight init and DataLoader shuffling are reproducible.
torch.manual_seed(seed)
u_net = ResUNet(num_classes=num_classes, num_channels=3)  # 3 input image channels

criterion = mIoULoss(n_classes=num_classes)
optimizer = optim.Adam(u_net.parameters(), lr=learning_rate)
# Multiply the learning rate by 0.1 every 10 epochs.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
train_loss_history = training_loop(u_net, n_epochs, batch_size, trainset, criterion, optimizer, scheduler)

In [None]:
checkpoint_path = "UNET_RESBACKBONE_OxfordIII.pth"
# Persist only the model's state_dict (parameters and buffers), not the module object.
torch.save(u_net.state_dict(), checkpoint_path)

In [None]:
import matplotlib.gridspec as gs


num_samples = 3
# Held-out test images, drawn without replacement, so the panels show generalization.
sample_indices = np.random.choice(len(testset), num_samples, replace=False)

fig, axes = plt.subplots(num_samples, 3, figsize=(10, 3.5 * num_samples), squeeze=False)
for col, title in enumerate(("Image", "Ground truth", "Prediction")):
    axes[0, col].set_title(title)

# eval(): BatchNorm uses (and does not update) its running statistics.
# no_grad(): no autograd graph is built for inference.
u_net.eval()
with torch.no_grad():
    for row, i in enumerate(sample_indices):
        image, label = testset[i]
        probs = u_net(image.unsqueeze(dim=0).to(device)).squeeze(dim=0)  # C x H x W
        pred = torch.argmax(probs, dim=0).cpu()  # dim 0 is the class channel
        axes[row, 0].imshow(image.permute(1, 2, 0))
        axes[row, 1].imshow(label.squeeze())
        axes[row, 2].imshow(pred)
        for ax in axes[row]:
            ax.axis("off")
plt.show()