In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from torch.utils.data.sampler import SubsetRandomSampler
import time
import copy
from random import shuffle
#import tqdm.notebook as tqdm
import sklearn
from sklearn.metrics import accuracy_score, cohen_kappa_score
from sklearn.metrics import classification_report
from PIL import Image
import cv2
import os
import shutil
from datetime import datetime 
import sys, os 
from glob import glob 
import imageio

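<h3>U-Net Architecture</h3>

The implementation below follows the original U-Net design. Every 3×3 convolution is unpadded ("valid"), so the output map is smaller than the input: a 572×572 input tile produces a 388×388 output.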

<img src="https://production-media.paperswithcode.com/methods/Screen_Shot_2020-07-07_at_9.08.00_PM_rpNArED.png" height="500" width="1000" alt="U-Net architecture diagram">

In [4]:
def DoubleConv2D(in_channels, out_channels, kernel_size=3, padding=0):
    """Build the U-Net "double convolution" unit: two (Conv2d -> ReLU) stages.

    With the defaults (3x3 kernels, ``padding=0``) both convolutions are
    unpadded ("valid"), as in the original U-Net, so the spatial size shrinks
    by 4 pixels overall (e.g. 572 -> 568). Passing ``padding=kernel_size // 2``
    with an odd kernel keeps the spatial size unchanged instead.

    Note: this is a factory function (it returns a module); the PascalCase
    name is kept for backward compatibility with existing callers.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by both convolutions.
        kernel_size: side length of the square convolution kernels.
        padding: zero-padding added on every side by each convolution.

    Returns:
        An ``nn.Sequential`` of [Conv2d, ReLU, Conv2d, ReLU] mapping
        ``(N, in_channels, H, W)`` to ``(N, out_channels, H', W')`` with
        ``H' = H - 2 * (kernel_size - 1 - 2 * padding)`` (same for W).
    """
    # The layer order (indices 0..3) is unchanged so state_dict keys such as
    # "0.weight" / "2.weight" stay compatible with previously saved models.
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=padding),
        nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, kernel_size, stride=1, padding=padding),
        nn.ReLU(),
    )

# Sanity check: a 572x572 single-channel image shrinks by 4 pixels through the
# two unpadded 3x3 convolutions. PyTorch tensors are laid out (N, C, H, W),
# so the expected output shape is (1, 64, 568, 568).
sample_image = torch.randn(1, 1, 572, 572)
double_conv = DoubleConv2D(1, 64)
with torch.no_grad():  # shape check only; no autograd graph needed
    conv_output = double_conv(sample_image)
assert conv_output.shape == (1, 64, 568, 568), conv_output.shape
conv_output.size()



torch.Size([1, 64, 568, 568])


In [16]:
def crop(input_tensor, target_tensor):
    """Center-crop ``input_tensor`` spatially to the size of ``target_tensor``.

    Used for the U-Net skip connections. The encoder feature map is larger than
    the upsampled decoder map, because unpadded convolutions lose border
    pixels, so its central region is cut out before concatenation.

    Height and width are handled independently, so non-square maps work. When
    the size difference is odd, the extra row/column is dropped from the
    bottom/right, so the result always matches the target size exactly.

    Args:
        input_tensor: tensor of shape ``(N, C, H_in, W_in)`` to crop.
        target_tensor: tensor whose last two dims ``(H_t, W_t)`` give the
            desired spatial size; only its shape is read.

    Returns:
        A view (slice) of ``input_tensor`` with shape ``(N, C, H_t, W_t)``.

    Raises:
        ValueError: if the target is spatially larger than the input.
    """
    in_h, in_w = input_tensor.shape[-2:]
    tgt_h, tgt_w = target_tensor.shape[-2:]
    if tgt_h > in_h or tgt_w > in_w:
        raise ValueError(
            f"cannot crop {tuple(input_tensor.shape)} to spatial size "
            f"{(tgt_h, tgt_w)}: target is larger than input"
        )
    # Slice by start + target length (not "size - delta") so an odd difference
    # cannot leave one extra row/column; floor division puts it bottom/right.
    top = (in_h - tgt_h) // 2
    left = (in_w - tgt_w) // 2
    return input_tensor[..., top:top + tgt_h, left:left + tgt_w]

# Sanity check: crop a 64x64 encoder map down to a 56x56 upsampled decoder map
# (the sizes at the deepest skip connection for a 572x572 input).
encoder_features = torch.rand(1, 512, 64, 64)
decoder_features = torch.rand(1, 512, 56, 56)
cropped = crop(encoder_features, decoder_features)
assert cropped.shape == decoder_features.shape, cropped.shape
cropped.size()

torch.Size([1, 512, 56, 56])


In [35]:
class Unet(nn.Module):
    """U-Net segmentation network built from unpadded ("valid") convolutions.

    Contracting path:
        Four ``DoubleConv2D`` stages (64 -> 128 -> 256 -> 512 channels), each
        followed by 2x2 max-pooling, then a 1024-channel bottleneck.

    Expanding path:
        Four stages, each of which:
        - applies a 2x2 transposed convolution that halves the channels and
          doubles H/W;
        - concatenates the result with the center-cropped encoder map of the
          same depth;
        - applies a ``DoubleConv2D``.

    A final 1x1 convolution maps the 64 channels to ``num_classes`` per-pixel
    scores. Because the 3x3 convolutions are unpadded, the output is spatially
    smaller than the input (a 572x572 input yields a 388x388 output).

    Args:
        in_channels: channels of the input image (default 1, as originally
            hard-coded).
        num_classes: channels of the output map (default 2, as originally
            hard-coded).
    """

    def __init__(self, in_channels=1, num_classes=2):
        super(Unet, self).__init__()
        # Attribute names and registration order are unchanged so that
        # state_dicts and parameter order of earlier models stay compatible.
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Contracting path.
        self.double_conv2d_1 = DoubleConv2D(in_channels, 64)
        self.double_conv2d_2 = DoubleConv2D(64, 128)
        self.double_conv2d_3 = DoubleConv2D(128, 256)
        self.double_conv2d_4 = DoubleConv2D(256, 512)
        self.double_conv2d_5 = DoubleConv2D(512, 1024)

        # Expanding path: each up_conv_* takes upsampled + skip channels,
        # hence its input is twice its output.
        self.up_trans_1 = nn.ConvTranspose2d(1024, 512, 2, stride=2, padding=0)
        self.up_conv_1 = DoubleConv2D(1024, 512)

        self.up_trans_2 = nn.ConvTranspose2d(512, 256, 2, stride=2, padding=0)
        self.up_conv_2 = DoubleConv2D(512, 256)

        self.up_trans_3 = nn.ConvTranspose2d(256, 128, 2, stride=2, padding=0)
        self.up_conv_3 = DoubleConv2D(256, 128)

        self.up_trans_4 = nn.ConvTranspose2d(128, 64, 2, stride=2, padding=0)
        self.up_conv_4 = DoubleConv2D(128, 64)

        # 1x1 convolution to per-pixel class scores.
        self.output = nn.Conv2d(64, num_classes, 1)

    def forward(self, X):
        """Run the network on a batch.

        Args:
            X: input batch of shape ``(N, in_channels, H, W)``.

        Returns:
            Tensor of shape ``(N, num_classes, H_out, W_out)``. It holds raw
            scores (no softmax is applied). H_out < H and W_out < W because
            the convolutions are unpadded.
        """
        # Contracting path: keep each pre-pool feature map for the skips.
        enc1 = self.double_conv2d_1(X)
        enc2 = self.double_conv2d_2(self.max_pool(enc1))
        enc3 = self.double_conv2d_3(self.max_pool(enc2))
        enc4 = self.double_conv2d_4(self.max_pool(enc3))
        bottleneck = self.double_conv2d_5(self.max_pool(enc4))

        # Expanding path: upsample, crop the matching encoder map to the same
        # size, concatenate on the channel dim, convolve. The concatenation
        # order (upsampled first, skip second) matches the trained weights.
        up1 = self.up_trans_1(bottleneck)
        dec1 = self.up_conv_1(torch.cat((up1, crop(enc4, up1)), 1))

        up2 = self.up_trans_2(dec1)
        dec2 = self.up_conv_2(torch.cat((up2, crop(enc3, up2)), 1))

        up3 = self.up_trans_3(dec2)
        dec3 = self.up_conv_3(torch.cat((up3, crop(enc2, up3)), 1))

        up4 = self.up_trans_4(dec3)
        dec4 = self.up_conv_4(torch.cat((up4, crop(enc1, up4)), 1))

        return self.output(dec4)
    
    
# Sanity check: a 572x572 single-channel tile maps to a 388x388 output with one
# channel per class. PyTorch layout is (N, C, H, W), so expect (1, 2, 388, 388).
a = torch.rand(1, 1, 572, 572)
b = Unet()
c = b(a)
assert c.shape == (1, 2, 388, 388), c.shape
print(a.size())
print(c.size())

torch.Size([1, 1, 572, 572])
torch.Size([1, 2, 388, 388])
