In [5]:
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF


In [28]:
class DoubleConv(nn.Module):
  def __init__(self, in_channels, out_channels):
    super(DoubleConv, self).__init__()
    self.conv = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False), #kernelsize, stride, padding (input,height,width will be the same in the output)
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False), #kernelsize, stride, padding (input,height,width will be the same in the output)
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
        )

  def forward(self, x) :
    return self.conv(x)

class UNET (nn.Module):
  def __init__(self, in_channels = 3, out_channels = 1, features = [64, 128, 256, 512],): #for binary segmentation we use 1 in out_channels. If we want to segmentate more, we change this numbe
    super(UNET, self).__init__()
    self.ups = nn.ModuleList() # To store Conv Layers
    self.downs = nn.ModuleList()
    self.pool = nn.MaxPool2d(kernel_size=2 , stride = 2)
    # important to choose input divisible by 16 to be able to concatenate them

    # Down Part UNET
    for feature in features :
      self.downs.append(DoubleConv(in_channels, feature))
      in_channels = feature

    # Up part UNET
    for feature in reversed(features) :
      self.ups.append(
          nn.ConvTranspose2d(
              feature*2, feature, kernel_size= 2, stride = 2  #feature*2 is the connection path. We concatenate Decoder and encoder, kernel_size and stride will double hight and width
          )
      )
      self.ups.append(DoubleConv(feature*2, feature)) #because it has 3 conv layers each stage. It "ups" and then 2 convs

    # Bottleneck
    self.bottleneck = DoubleConv(features[-1], features[-1]*2) # EXPLAIN

    # Final conv to change # channels
    self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size = 1)

  def forward(self, x):
    connecting_paths = [] #explain x2
    for down in self.downs :
      x = down(x)
      connecting_paths.append(x)
      x = self.pool(x) # Decreasess the resolution (downsamples images) as it goes downwards

    # Bottleneck
    x = self.bottleneck(x)
    connecting_paths = connecting_paths[::-1] # reverse list to upsample

    for idx in range(0, len(self.ups), 2) : # up and double conv (single step)
      x = self.ups[idx](x)
      connecting_path = connecting_paths[idx//2] # explain x3

      if x.shape != connecting_path.shape : # in the paper they cropped it
        x = TF.resize(x, size = connecting_path.shape[2:]) # we resize it.

      cat_path = torch.cat((connecting_path, x), dim = 1)
      x = self.ups[idx+1](cat_path)
      print(cat_path.shape)

    return self.final_conv(x)

def test ():
  x = torch.randn((3, 1, 160 , 160))  # batch size, inputchannel,
  model = UNET(in_channels = 1, out_channels = 1)
  preds = model(x)
  print(preds.shape, x.shape)
  assert preds.shape == x.shape

if __name__ == "__main__" :
  test()




torch.Size([3, 1024, 20, 20])
torch.Size([3, 512, 40, 40])
torch.Size([3, 256, 80, 80])
torch.Size([3, 128, 160, 160])
torch.Size([3, 1, 160, 160]) torch.Size([3, 1, 160, 160])
