In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF


In [None]:
class DoubleConv(nn.Module):
  """Two stacked 3x3 convolutions, each followed by BatchNorm2d and ReLU.

  Both convolutions use stride 1 and padding 1, so the height and width of the
  input are preserved; only the channel count changes from ``in_channels`` to
  ``out_channels``.

  Args:
    in_channels: Channel count of the incoming tensor.
    out_channels: Channel count produced by both convolutions.
  """

  def __init__(self, in_channels, out_channels):
    super().__init__()
    stages = []
    # The first stage maps in_channels -> out_channels, the second keeps
    # out_channels -> out_channels.
    for stage_in in (in_channels, out_channels):
      # The convolution carries no bias term; the BatchNorm2d that follows
      # has its own learnable shift (affine is on by default).
      stages.append(
          nn.Conv2d(
              stage_in,
              out_channels,
              kernel_size=3,
              stride=1,
              padding=1,
              bias=False,
          )
      )
      stages.append(nn.BatchNorm2d(out_channels))
      stages.append(nn.ReLU(inplace=True))
    self.conv = nn.Sequential(*stages)

  def forward(self, x):
    """Run ``x`` through both conv/norm/activation stages and return the result."""
    pipeline = self.conv
    return pipeline(x)

class UNET(nn.Module):
  """U-Net style encoder/decoder assembled from ``DoubleConv`` blocks.

  Args:
    in_channels: Number of channels in the input tensor (default 3).
    out_channels: Number of channels produced by the final 1x1 convolution
      (default 1).
    features: Channel widths of the encoder levels, shallowest first. The
      decoder mirrors them in reverse order and the bottleneck uses twice the
      last width. Must be a non-empty sequence of ints.
  """

  # ``features`` defaults to a tuple (not a list) so that no mutable object is
  # shared between calls.
  def __init__(self, in_channels=3, out_channels=1, features=(64, 128, 256, 512)):
    super().__init__()
    # Snapshot the widths so later mutation of the caller's list cannot
    # affect this constructor and any iterable of ints is accepted.
    features = tuple(features)

    self.ups = nn.ModuleList()  # decoder: alternating [ConvTranspose2d, DoubleConv, ...]
    self.downs = nn.ModuleList()  # encoder: one DoubleConv per level
    self.pool = nn.MaxPool2d(kernel_size=2, stride=2)  # applied after every encoder level

    # Encoder: the channel count grows level by level.
    for feature in features:
      self.downs.append(DoubleConv(in_channels, feature))
      in_channels = feature  # output of this level feeds the next one

    # Decoder: each level is a transposed convolution (kernel 2, stride 2)
    # followed by a DoubleConv. The DoubleConv takes feature*2 channels
    # because its input is the skip connection concatenated with the
    # upsampled map, each carrying ``feature`` channels.
    for feature in reversed(features):
      self.ups.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
      self.ups.append(DoubleConv(feature * 2, feature))

    # Bottleneck joins the deepest encoder level to the first decoder level.
    self.bottleneck = DoubleConv(features[-1], features[-1] * 2)
    # 1x1 convolution mapping the shallowest decoder width to out_channels.
    self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

  def forward(self, x):
    """Run the encoder, bottleneck and decoder on ``x``.

    Args:
      x: Input tensor laid out as (batch, channels, height, width).

    Returns:
      The output of the final 1x1 convolution applied to the last decoder
      feature map.
    """
    skip_connections = []

    # Encoder: keep each level's pre-pooling output for the decoder.
    for down in self.downs:
      x = down(x)
      skip_connections.append(x)
      x = self.pool(x)

    x = self.bottleneck(x)

    # Decoder: consume the skips deepest-first. ``self.ups`` stores each level
    # as two consecutive entries, hence the 2*level indexing.
    for level, skip_connection in enumerate(reversed(skip_connections)):
      upsample = self.ups[2 * level]
      refine = self.ups[2 * level + 1]

      x = upsample(x)
      # Batch and channel sizes already agree here; only height/width can
      # differ, which happens when pooling rounded an odd size down.
      if x.shape[2:] != skip_connection.shape[2:]:
        x = TF.resize(x, size=skip_connection.shape[2:])
      x = refine(torch.cat((skip_connection, x), dim=1))

    return self.final_conv(x)

def test():
  """Smoke test: a single-channel UNET must return a tensor shaped like its input."""
  batch = torch.randn((3, 1, 160, 160))
  net = UNET(in_channels=1, out_channels=1)
  output = net(batch)
  # Print the prediction shape first, then the input shape, for eyeballing.
  for shape in (output.shape, batch.shape):
    print(shape)
  assert output.shape == batch.shape

# Run the shape check only when this file is executed directly, not on import.
if __name__ == "__main__":
  test()



torch.Size([3, 1, 160, 160])
torch.Size([3, 1, 160, 160])
