In [2]:
!pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio===0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
!pip install torchvision

Looking in links: https://download.pytorch.org/whl/torch_stable.html
Collecting torch==1.8.1+cu102
[?25l  Downloading https://download.pytorch.org/whl/cu102/torch-1.8.1%2Bcu102-cp37-cp37m-linux_x86_64.whl (804.1MB)
[K     |████████████████████████████████| 804.1MB 19kB/s 
[?25hCollecting torchvision==0.9.1+cu102
[?25l  Downloading https://download.pytorch.org/whl/cu102/torchvision-0.9.1%2Bcu102-cp37-cp37m-linux_x86_64.whl (17.3MB)
[K     |████████████████████████████████| 17.3MB 231kB/s 
[?25hCollecting torchaudio===0.8.1
[?25l  Downloading https://files.pythonhosted.org/packages/aa/55/01ad9244bcd595e39cea5ce30726a7fe02fd963d07daeb136bfe7e23f0a5/torchaudio-0.8.1-cp37-cp37m-manylinux1_x86_64.whl (1.9MB)
[K     |████████████████████████████████| 1.9MB 7.6MB/s 
Installing collected packages: torch, torchvision, torchaudio
  Found existing installation: torch 1.8.1+cu101
    Uninstalling torch-1.8.1+cu101:
      Successfully uninstalled torch-1.8.1+cu101
  Found existing installati

In [4]:
import torch

In [5]:
import torch.nn as nn

In [25]:
def double_conv(in_c , out_c):
  conv = nn.Sequential(
      nn.Conv2d(in_c , out_c , kernel_size = 3),
      nn.ReLU(inplace = True),
      nn.Conv2d(out_c , out_c , kernel_size = 3),
      nn.ReLU(inplace = True),
  ) 
  return conv

In [45]:
def crop_image( tensor , target_tensor):
  target_size = target_tensor.size()[2]
  tensor_size = tensor.size()[2]
  delta = tensor_size - target_size
  delta = (delta)//2

  return tensor[:, :, delta:tensor_size-delta , delta:tensor_size-delta]

In [64]:
 class UNet(nn.Module):
   def __init__(self):  
     super(UNet, self).__init__()

     self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)
     self.down_conv_1 = double_conv(1, 64)
     self.down_conv_2 = double_conv(64, 128)
     self.down_conv_3 = double_conv(128, 256)
     self.down_conv_4 = double_conv(256, 512)
     self.down_conv_5 = double_conv(512, 1024)  

     self.up_trans_1 = nn.ConvTranspose2d(in_channels = 1024 ,  
                                          out_channels =512
                                          , kernel_size = 2,
                                          stride = 2)
     self.up_conv_1 = double_conv(1024, 512)

     self.up_trans_2 = nn.ConvTranspose2d(in_channels = 512 ,  
                                          out_channels =256
                                          , kernel_size = 2,
                                          stride = 2)
     self.up_conv_2 = double_conv(512,256)
           
     self.up_trans_3 = nn.ConvTranspose2d(in_channels = 256 ,  
                                          out_channels =128
                                          , kernel_size = 2,
                                          stride = 2)
     self.up_conv_3 = double_conv(256, 128)
           

     self.up_trans_4 = nn.ConvTranspose2d(in_channels = 128 ,  
                                          out_channels =64
                                          , kernel_size = 2,
                                          stride = 2)
     self.up_conv_4 = double_conv(128, 64)

     self.out = nn.Conv2d(
         in_channels = 64,
         out_channels = 2,
         kernel_size =1 
     )


           
           
    
   def forward(self, image):

      #encoder part
      x1 = self.down_conv_1(image)#
      x2 = self.max_pool_2x2(x1)
      x3 = self.down_conv_2(x2)#
      x4 = self.max_pool_2x2(x3)
      x5 = self.down_conv_3(x4)#
      x6 = self.max_pool_2x2(x5)
      x7 = self.down_conv_4(x6)#
      x8 = self.max_pool_2x2(x7)
      x9 = self.down_conv_5(x8)

      

      #decoder part
      x = self.up_trans_1(x9)
      y = crop_image(x7,x)
      x = self.up_conv_1(torch.cat([x ,y] ,1))

      x = self.up_trans_2(x)
      y = crop_image(x5,x)
      x = self.up_conv_2(torch.cat([x ,y] ,1))

      x = self.up_trans_3(x)
      y = crop_image(x3,x)
      x = self.up_conv_3(torch.cat([x ,y] ,1))

      x = self.up_trans_4(x)
      y = crop_image(x1,x)
      x = self.up_conv_4(torch.cat([x ,y] ,1))

      x = self.out(x)
      return x


      print(x.size())
      

In [65]:
if __name__ == "__main__":
  image = torch.rand((1, 1, 572, 572))
  model = UNet()
  print(model(image))

tensor([[[[-0.0211, -0.0273, -0.0232,  ..., -0.0152, -0.0232, -0.0261],
          [-0.0237, -0.0231, -0.0224,  ..., -0.0226, -0.0254, -0.0258],
          [-0.0212, -0.0228, -0.0217,  ..., -0.0206, -0.0225, -0.0248],
          ...,
          [-0.0197, -0.0242, -0.0220,  ..., -0.0255, -0.0227, -0.0247],
          [-0.0209, -0.0258, -0.0258,  ..., -0.0241, -0.0279, -0.0269],
          [-0.0263, -0.0259, -0.0237,  ..., -0.0256, -0.0264, -0.0275]],

         [[-0.0462, -0.0413, -0.0438,  ..., -0.0453, -0.0449, -0.0439],
          [-0.0421, -0.0477, -0.0414,  ..., -0.0396, -0.0425, -0.0385],
          [-0.0408, -0.0472, -0.0475,  ..., -0.0378, -0.0380, -0.0451],
          ...,
          [-0.0432, -0.0425, -0.0441,  ..., -0.0414, -0.0423, -0.0446],
          [-0.0387, -0.0433, -0.0436,  ..., -0.0467, -0.0455, -0.0482],
          [-0.0451, -0.0410, -0.0415,  ..., -0.0438, -0.0470, -0.0408]]]],
       grad_fn=<ThnnConv2DBackward>)
