<a href="https://colab.research.google.com/github/LoPA607/SOC-2024/blob/main/U_Net.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [10]:
import torch
import torch.nn as nn

In [11]:
# Double convolutional layer
def double_conv(in_c, out_c):
  """Build a block of two 3x3 convolutions, each followed by an in-place ReLU.

  The convolutions are created without padding, so every convolution
  shrinks the spatial size of its input by 2 pixels per dimension.

  Args:
    in_c: number of channels expected by the first convolution.
    out_c: number of channels produced by both convolutions.

  Returns:
    An ``nn.Sequential`` holding Conv2d -> ReLU -> Conv2d -> ReLU.
  """
  layers = [
      nn.Conv2d(in_c, out_c, kernel_size=3),
      nn.ReLU(inplace=True),
      nn.Conv2d(out_c, out_c, kernel_size=3),
      nn.ReLU(inplace=True),
  ]
  return nn.Sequential(*layers)

In [21]:
def crop_img(tensor, target_tensor):
  """Center-crop ``tensor`` to the spatial size of ``target_tensor``.

  Dimensions 2 and 3 (height and width of the 4-D tensors handled by the
  2-D conv layers in this file) are cropped; dimensions 0 and 1 are kept
  as they are. Height and width are handled independently, so non-square
  inputs work, and the result always has exactly the target size even
  when the size difference is odd (the extra pixel is then dropped from
  the bottom/right side).

  Args:
    tensor: tensor to crop; must be at least as large as ``target_tensor``
      in dimensions 2 and 3.
    target_tensor: tensor whose size in dimensions 2 and 3 is the crop size.

  Returns:
    A view of ``tensor`` whose dimensions 2 and 3 match ``target_tensor``.

  Raises:
    ValueError: if ``target_tensor`` is larger than ``tensor`` in
      dimension 2 or 3.
  """
  target_h, target_w = target_tensor.size()[2:4]
  src_h, src_w = tensor.size()[2:4]
  if target_h > src_h or target_w > src_w:
    raise ValueError(
        f"cannot crop size ({src_h}, {src_w}) to larger size "
        f"({target_h}, {target_w})"
    )
  top = (src_h - target_h) // 2
  left = (src_w - target_w) // 2
  # Slice by offset + target length rather than symmetric margins so an odd
  # size difference cannot leave one pixel too many.
  return tensor[:, :, top:top + target_h, left:left + target_w]

In [25]:
class UNet(nn.Module):
  """U-Net style encoder-decoder built from unpadded double convolutions.

  The encoder applies five ``double_conv`` stages (64, 128, 256, 512, 1024
  channels) separated by 2x2 max pooling. The decoder upsamples four times
  with 2x2 transposed convolutions; after each upsampling the matching
  encoder feature map is center-cropped with ``crop_img`` and concatenated
  along the channel dimension before another ``double_conv``. A final 1x1
  convolution maps the 64 decoder channels to ``out_channels``.

  Because the convolutions are unpadded, the output is spatially smaller
  than the input (a 1x1x572x572 input yields a 1x2x388x388 output with the
  default arguments).

  Args:
    in_channels: number of channels of the input image. Defaults to 1.
    out_channels: number of channels of the output map. Defaults to 2.
    verbose: when True (the default, which keeps the original behaviour),
      ``forward`` prints the size of the first encoder feature map and of
      the output. Pass False to silence this, e.g. inside a training loop.
  """

  def __init__(self, in_channels: int = 1, out_channels: int = 2,
               verbose: bool = True):
    super().__init__()
    self.verbose = verbose

    # Layers are created in the same order as before so that parameter
    # initialisation consumes the random number generator identically.
    self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)
    self.down_conv_1 = double_conv(in_channels, 64)
    self.down_conv_2 = double_conv(64, 128)
    self.down_conv_3 = double_conv(128, 256)
    self.down_conv_4 = double_conv(256, 512)
    self.down_conv_5 = double_conv(512, 1024)

    # Each up_conv takes twice the channels produced by its up_trans: the
    # upsampled map is concatenated with the cropped encoder skip map.
    self.up_trans_1 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=2, stride=2)
    self.up_conv_1 = double_conv(1024, 512)

    self.up_trans_2 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=2, stride=2)
    self.up_conv_2 = double_conv(512, 256)

    self.up_trans_3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2)
    self.up_conv_3 = double_conv(256, 128)

    self.up_trans_4 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2)
    self.up_conv_4 = double_conv(128, 64)

    self.out = nn.Conv2d(in_channels=64, out_channels=out_channels, kernel_size=1)

  def forward(self, image: torch.Tensor) -> torch.Tensor:
    """Run the encoder and decoder on ``image`` and return the output map.

    Args:
      image: input batch with ``in_channels`` channels in dimension 1.

    Returns:
      Tensor with ``out_channels`` channels in dimension 1.
    """
    # Encoder: keep the output of every stage that feeds a skip connection.
    skip_1 = self.down_conv_1(image)
    if self.verbose:
      print(skip_1.size())
    skip_2 = self.down_conv_2(self.max_pool_2x2(skip_1))
    skip_3 = self.down_conv_3(self.max_pool_2x2(skip_2))
    skip_4 = self.down_conv_4(self.max_pool_2x2(skip_3))
    x = self.down_conv_5(self.max_pool_2x2(skip_4))

    # Decoder: upsample, crop the skip map to the upsampled size,
    # concatenate along channels, then refine. Deepest skip first.
    decoder_stages = (
        (self.up_trans_1, self.up_conv_1, skip_4),
        (self.up_trans_2, self.up_conv_2, skip_3),
        (self.up_trans_3, self.up_conv_3, skip_2),
        (self.up_trans_4, self.up_conv_4, skip_1),
    )
    for upsample, refine, skip in decoder_stages:
      x = upsample(x)
      x = refine(torch.cat([x, crop_img(skip, x)], 1))

    x = self.out(x)
    if self.verbose:
      print(x.size())
    return x





In [26]:
if __name__ == "__main__":
  # Smoke test: push one random single-channel 572x572 image through a
  # freshly initialised network and print the resulting tensor.
  sample = torch.rand((1, 1, 572, 572))
  net = UNet()
  prediction = net(sample)
  print(prediction)

torch.Size([1, 64, 568, 568])
torch.Size([1, 2, 388, 388])
tensor([[[[-1.1598e-03,  1.2682e-03,  1.3640e-03,  ...,  1.1839e-03,
            3.1728e-03,  4.0019e-03],
          [-1.1055e-03,  4.2946e-03,  3.3119e-03,  ...,  9.9182e-05,
           -1.1561e-03,  2.8131e-03],
          [ 2.0560e-03, -7.4171e-06,  1.3258e-03,  ..., -1.5969e-03,
            1.2280e-03,  2.0878e-03],
          ...,
          [ 5.0114e-05, -1.4733e-03,  3.8401e-04,  ...,  6.2406e-03,
            3.9680e-03,  3.1526e-04],
          [ 5.4194e-03,  4.8651e-03,  2.3115e-03,  ...,  2.9118e-03,
            3.0722e-03,  2.4935e-03],
          [-1.9208e-03, -2.1064e-03,  8.4147e-05,  ..., -8.1116e-05,
            7.7536e-03,  3.4603e-03]],

         [[ 8.4500e-02,  8.4862e-02,  9.0486e-02,  ...,  8.7925e-02,
            8.8599e-02,  9.2674e-02],
          [ 9.0252e-02,  8.4636e-02,  8.3281e-02,  ...,  8.7208e-02,
            8.1785e-02,  8.6667e-02],
          [ 8.1650e-02,  9.1257e-02,  8.5416e-02,  ...,  8.3140e-02,