<a href="https://colab.research.google.com/github/Ankan1998/paper-implementation/blob/main/U_Net_from_Scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## U-Net Architecture: Paper Implementation from Scratch

In [2]:
import torch
import torch.nn as nn

In [12]:
def double_down_conv(in_c, out_c):
    """Return a block of two unpadded 3x3 convolutions, each followed by ReLU.

    The first convolution maps ``in_c`` channels to ``out_c`` channels and the
    second keeps ``out_c`` channels. No padding is passed to ``nn.Conv2d``, so
    each convolution trims 2 pixels from the height and the width (4 in total
    for the whole block).

    Args:
        in_c: number of channels of the incoming feature map.
        out_c: number of channels produced by both convolutions.

    Returns:
        An ``nn.Sequential`` laid out as Conv2d, ReLU, Conv2d, ReLU.
    """
    layers = []
    # The first pass consumes the block input; the second consumes the output
    # of the first convolution, hence the (in_c, out_c) sequence of inputs.
    for conv_in in (in_c, out_c):
        layers.append(nn.Conv2d(conv_in, out_c, kernel_size=3))
        layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers)

def crop_tensor(target, original):
    """Center-crop ``original`` to the spatial size of ``target``.

    Both arguments are indexed as 4-D tensors whose dimensions 2 and 3 are the
    spatial ones (height and width). Only the spatial size of ``target`` is
    read; its batch and channel dimensions are ignored.

    Height and width are handled independently, so non-square tensors work.
    When the size difference along an axis is odd, the extra pixel is dropped
    from the bottom/right edge so the result always matches ``target``
    exactly.

    Args:
        target: tensor whose height and width define the crop size.
        original: tensor to crop; must be at least as large as ``target``
            along both spatial dimensions.

    Returns:
        A view of ``original`` with shape
        ``(original.size(0), original.size(1), target.size(2), target.size(3))``.

    Raises:
        ValueError: if ``target`` is larger than ``original`` along either
            spatial dimension.
    """
    target_h, target_w = target.size(2), target.size(3)
    original_h, original_w = original.size(2), original.size(3)

    if target_h > original_h or target_w > original_w:
        raise ValueError(
            "cannot crop a tensor of spatial size "
            f"({original_h}, {original_w}) to the larger size "
            f"({target_h}, {target_w})"
        )

    top = (original_h - target_h) // 2
    left = (original_w - target_w) // 2
    # Slice by start + target extent (not by "size - offset") so that an odd
    # size difference still produces exactly the target extent.
    return original[:, :, top:top + target_h, left:left + target_w]


class UNet(nn.Module):
    """U-Net with unpadded convolutions and cropped skip connections.

    The encoder applies five double-convolution blocks (64, 128, 256, 512 and
    1024 channels) with a 2x2 max-pool between consecutive blocks. The decoder
    upsamples four times with stride-2 transposed convolutions; after each
    upsampling the matching encoder feature map is center-cropped, concatenated
    along the channel dimension and fused by another double-convolution block.
    A final 1x1 convolution produces ``out_channels`` output maps.

    Because the convolutions are unpadded, the output is spatially smaller
    than the input (a 572x572 input yields a 388x388 output).

    Args:
        in_channels: number of channels of the input image (default 1).
        out_channels: number of output maps produced by the final 1x1
            convolution (default 2).
        verbose: when true (the default, matching the original behaviour),
            ``forward`` prints the size of the last decoder feature map and of
            the output on every call. Pass ``verbose=False`` to silence this,
            e.g. inside a training loop.
    """

    def __init__(self, in_channels=1, out_channels=2, *, verbose=True):
        super().__init__()
        self.verbose = verbose

        # NOTE: attribute names and creation order are kept as-is so that
        # state_dict keys and seeded parameter initialisation do not change.

        # Encoder (contracting path).
        self.mpool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dconv1 = double_down_conv(in_channels, 64)
        self.dconv2 = double_down_conv(64, 128)
        self.dconv3 = double_down_conv(128, 256)
        self.dconv4 = double_down_conv(256, 512)
        self.dconv5 = double_down_conv(512, 1024)

        # Decoder (expanding path): each transposed convolution doubles the
        # spatial size and halves the channel count.
        self.tconv1 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=2, stride=2)
        self.tconv2 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=2, stride=2)
        self.tconv3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2)
        self.tconv4 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2)

        # Fusion blocks: the input channel count is doubled because the
        # upsampled map is concatenated with the cropped skip connection.
        self.double_u_conv1 = double_down_conv(1024, 512)
        self.double_u_conv2 = double_down_conv(512, 256)
        self.double_u_conv3 = double_down_conv(256, 128)
        self.double_u_conv4 = double_down_conv(128, 64)

        # 1x1 convolution mapping the 64 feature channels to the output maps.
        self.out = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, image):
        """Run the network on a batch of images.

        Args:
            image: tensor of shape ``(batch, in_channels, height, width)``.

        Returns:
            Tensor with ``out_channels`` channels; its spatial size is smaller
            than that of ``image`` because the convolutions are unpadded.
        """
        # Encoder: remember the pre-pooling output of each of the first four
        # blocks for the skip connections.
        skips = []
        x = image
        for encode in (self.dconv1, self.dconv2, self.dconv3, self.dconv4):
            x = encode(x)
            skips.append(x)
            x = self.mpool2d(x)
        x = self.dconv5(x)

        # Decoder: the deepest skip connection is consumed first.
        decoder_stages = (
            (self.tconv1, self.double_u_conv1),
            (self.tconv2, self.double_u_conv2),
            (self.tconv3, self.double_u_conv3),
            (self.tconv4, self.double_u_conv4),
        )
        for (upsample, fuse), skip in zip(decoder_stages, reversed(skips)):
            x = upsample(x)
            # The encoder map is larger than the upsampled one, so crop it to
            # the same spatial size before concatenating along channels.
            x = fuse(torch.cat([x, crop_tensor(x, skip)], 1))

        if self.verbose:
            print("x18", x.size())
        xfinal = self.out(x)
        if self.verbose:
            print("xfinal", xfinal.size())

        return xfinal
    


In [13]:
# Build a UNet instance; its parameters get PyTorch's default random initialisation.
model=UNet()

In [14]:
# Random input: a batch of one single-channel 572x572 image.
image=torch.rand((1,1,572,572))
# Forward pass; as the last expression of the cell, the returned tensor is displayed below.
model(image)

x18 torch.Size([1, 64, 388, 388])
xfinal torch.Size([1, 2, 388, 388])


tensor([[[[ 0.0810,  0.0788,  0.0853,  ...,  0.0762,  0.0843,  0.0795],
          [ 0.0862,  0.0845,  0.0850,  ...,  0.0782,  0.0839,  0.0836],
          [ 0.0859,  0.0769,  0.0758,  ...,  0.0855,  0.0804,  0.0826],
          ...,
          [ 0.0854,  0.0831,  0.0811,  ...,  0.0861,  0.0842,  0.0798],
          [ 0.0830,  0.0819,  0.0849,  ...,  0.0812,  0.0814,  0.0838],
          [ 0.0852,  0.0912,  0.0848,  ...,  0.0832,  0.0822,  0.0800]],

         [[-0.0956, -0.0923, -0.1020,  ..., -0.0932, -0.0985, -0.0962],
          [-0.1000, -0.0979, -0.0969,  ..., -0.0941, -0.0987, -0.0977],
          [-0.0983, -0.1003, -0.0958,  ..., -0.0958, -0.0903, -0.0984],
          ...,
          [-0.1000, -0.0981, -0.0942,  ..., -0.0974, -0.0957, -0.0994],
          [-0.0928, -0.1027, -0.0986,  ..., -0.0985, -0.0995, -0.0910],
          [-0.0948, -0.0948, -0.0979,  ..., -0.0923, -0.1052, -0.0935]]]],
       grad_fn=<MkldnnConvolutionBackward>)