In [1]:
import torch
from torch import nn

In [33]:
# Random single-image batch used to probe layer output shapes:
# batch of 1, 1 channel, 572 x 572 pixels, values drawn from a standard normal.
image = torch.randn(1, 1, 572, 572)


In [23]:
def double_convs(in_channels,out_channels,kernel_size=3):
  """Build a block of two convolutions, each followed by a ReLU.

  Both convolutions use ``nn.Conv2d`` defaults apart from the kernel size, so
  no padding is applied and every convolution shrinks each spatial dimension
  by ``kernel_size - 1``.

  Args:
    in_channels: Number of channels of the tensor fed to the first convolution.
    out_channels: Number of channels produced by both convolutions.
    kernel_size: Kernel size shared by both convolutions (default 3).

  Returns:
    An ``nn.Sequential`` laid out as Conv2d -> ReLU -> Conv2d -> ReLU.
  """
  first_conv = nn.Conv2d(in_channels, out_channels, kernel_size)
  second_conv = nn.Conv2d(out_channels, out_channels, kernel_size)
  # inplace=True makes each ReLU overwrite its input tensor instead of
  # allocating a new one, which saves memory.
  stages = [
      first_conv,
      nn.ReLU(inplace=True),
      second_conv,
      nn.ReLU(inplace=True),
  ]
  return nn.Sequential(*stages)

In [24]:
# Sanity-check a single double-convolution block (1 -> 64 channels) on the
# sample image. With the default 3x3 kernels and no padding, each of the two
# convolutions trims 2 pixels from every spatial dimension.
model_instance = double_convs(in_channels=1, out_channels=64)
pass_image = model_instance(image)
pass_image.shape

torch.Size([1, 64, 568, 568])

In [53]:
def crop_image(tensor, target_tensor):
  """Center-crop ``tensor`` to the spatial size of ``target_tensor``.

  Dimensions 2 and 3 (height and width) are cropped independently; dimensions
  0 and 1 are left untouched. The result always has exactly the target height
  and width: when the size difference is odd, the extra row/column is removed
  from the bottom/right edge.

  Args:
    tensor: Tensor indexed as (dim0, dim1, height, width) that will be cropped.
    target_tensor: Tensor whose dimensions 2 and 3 give the desired output
      height and width.

  Returns:
    The centered slice of ``tensor`` with height and width equal to those of
    ``target_tensor``.

  Raises:
    ValueError: If ``target_tensor`` is larger than ``tensor`` along height or
      width, in which case no crop can produce the requested size.
  """
  target_h, target_w = target_tensor.size()[2], target_tensor.size()[3]
  tensor_h, tensor_w = tensor.size()[2], tensor.size()[3]

  if target_h > tensor_h or target_w > tensor_w:
    raise ValueError(
        f"cannot crop a {tensor_h}x{tensor_w} map to the larger size "
        f"{target_h}x{target_w}"
    )

  # Offset of the crop window from the top/left edge,
  # e.g. cropping 64 -> 56 gives (64 - 56) // 2 = 4, i.e. rows/cols 4:60.
  top = (tensor_h - target_h) // 2
  left = (tensor_w - target_w) // 2

  # Slice by start + target length (not size - offset) so an odd size
  # difference still yields exactly the target size.
  return tensor[:, :, top:top + target_h, left:left + target_w]

In [84]:
class UNet(nn.Module):
  """U-Net style encoder/decoder built from ``double_convs`` blocks.

  Encoder: five ``double_convs`` stages (64, 128, 256, 512, 1024 channels)
  separated by 2x2 max pooling.
  Decoder: four stages, each of which upsamples with a stride-2 transposed
  convolution, concatenates the matching encoder feature map (cropped with
  ``crop_image``) along the channel dimension, and applies another
  ``double_convs`` block.
  Head: a 1x1 convolution mapping the 64 feature channels to ``num_classes``.

  Args:
    in_channels: Number of channels of the input image. Defaults to 1.
    num_classes: Number of channels of the output map. Defaults to 2.
    verbose: When True, ``forward`` prints the shapes of selected intermediate
      tensors on every call. Defaults to True so existing callers keep the
      shape trace; pass False (e.g. for training loops) to silence it.
  """

  def __init__(self, in_channels:int = 1, num_classes:int = 2, verbose:bool = True):
    super().__init__()

    self.verbose = verbose

    # Building blocks. Attribute names and creation order are kept stable so
    # that state_dict keys and seeded weight initialisation do not change.

    self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2,stride=2)
    self.down_conv_1 = double_convs(in_channels,64)
    self.down_conv_2 = double_convs(64,128)
    self.down_conv_3 = double_convs(128,256)
    self.down_conv_4 = double_convs(256,512)
    self.down_conv_5 = double_convs(512,1024)

    # Each transposed convolution halves the channel count; the following
    # double_convs block receives twice as many channels again because the
    # skip connection is concatenated onto the upsampled map.
    self.up_trans_1 = nn.ConvTranspose2d(in_channels=1024,
                                         out_channels=512,
                                         kernel_size=2,
                                         stride=2)
    self.up_conv_1 = double_convs(1024,512)

    self.up_trans_2 = nn.ConvTranspose2d(in_channels=512,
                                         out_channels=256,
                                         kernel_size=2,
                                         stride=2)
    self.up_conv_2 = double_convs(512,256)

    self.up_trans_3 = nn.ConvTranspose2d(in_channels=256,
                                         out_channels=128,
                                         kernel_size=2,
                                         stride=2)
    self.up_conv_3 = double_convs(256,128)

    self.up_trans_4 = nn.ConvTranspose2d(in_channels=128,
                                         out_channels=64,
                                         kernel_size=2,
                                         stride=2)
    self.up_conv_4 = double_convs(128,64)

    self.out = nn.Conv2d(64,num_classes,kernel_size=1)

  def _trace(self, message) -> None:
    """Print ``message`` only when shape tracing is enabled."""
    if self.verbose:
      print(message)

  def _decoder_stage(self, x, skip, up_trans, up_conv, trace_skip=False):
    """Upsample ``x``, concatenate the cropped ``skip`` map and apply ``up_conv``."""
    upsampled = up_trans(x)
    # The encoder map is spatially larger than the upsampled one, so it is
    # cropped to the upsampled size before concatenation.
    cropped_skip = crop_image(skip, upsampled)
    if trace_skip:
      self._trace(cropped_skip.shape)
    fused = up_conv(torch.cat((upsampled, cropped_skip), 1))
    self._trace(fused.shape)
    return fused

  def forward(self, x:torch.Tensor) -> torch.Tensor:
    """Run the network on ``x`` and return the ``num_classes``-channel output map.

    Args:
      x: Input batch with ``in_channels`` channels in dimension 1.

    Returns:
      The output of the final 1x1 convolution.
    """
    # Encoder: double convolution, then 2x2 max pooling before the next stage.
    # The un-pooled outputs X1, X3, X5, X7 are kept for the skip connections.
    X1 = self.down_conv_1(x)                        # in_channels -> 64
    X3 = self.down_conv_2(self.max_pool_2x2(X1))    # 64 -> 128
    X5 = self.down_conv_3(self.max_pool_2x2(X3))    # 128 -> 256
    X7 = self.down_conv_4(self.max_pool_2x2(X5))    # 256 -> 512
    X9 = self.down_conv_5(self.max_pool_2x2(X7))    # 512 -> 1024 (bottleneck)

    self._trace(f"The shape of the first Encoder part{X9.shape}") # check if it matches the paper

    # Decoder: the second half of the network, deepest skip connection first.
    X11 = self._decoder_stage(X9, X7, self.up_trans_1, self.up_conv_1, trace_skip=True)
    X13 = self._decoder_stage(X11, X5, self.up_trans_2, self.up_conv_2)
    X15 = self._decoder_stage(X13, X3, self.up_trans_3, self.up_conv_3)
    X17 = self._decoder_stage(X15, X1, self.up_trans_4, self.up_conv_4)

    X18 = self.out(X17)
    self._trace(X18.shape)
    return X18


In [85]:
# Build the full network with its default constructor arguments and push the
# sample `image` through it; the returned tensor is displayed as the cell output.
model_1 = UNet()
model_1(image)

The shape of the first Encoder parttorch.Size([1, 1024, 28, 28])
torch.Size([1, 512, 56, 56])
torch.Size([1, 512, 52, 52])
torch.Size([1, 256, 100, 100])
torch.Size([1, 128, 196, 196])
torch.Size([1, 64, 388, 388])
torch.Size([1, 2, 388, 388])


tensor([[[[-0.0569, -0.0674, -0.0583,  ..., -0.0594, -0.0571, -0.0603],
          [-0.0565, -0.0660, -0.0585,  ..., -0.0562, -0.0554, -0.0630],
          [-0.0544, -0.0633, -0.0549,  ..., -0.0714, -0.0559, -0.0638],
          ...,
          [-0.0544, -0.0631, -0.0604,  ..., -0.0684, -0.0556, -0.0586],
          [-0.0641, -0.0624, -0.0514,  ..., -0.0622, -0.0572, -0.0575],
          [-0.0550, -0.0554, -0.0590,  ..., -0.0520, -0.0639, -0.0597]],

         [[-0.0008,  0.0084, -0.0034,  ..., -0.0079, -0.0073, -0.0103],
          [ 0.0006, -0.0114, -0.0009,  ..., -0.0126, -0.0146, -0.0061],
          [-0.0042,  0.0017, -0.0089,  ...,  0.0070, -0.0154, -0.0119],
          ...,
          [-0.0045, -0.0038, -0.0161,  ..., -0.0076, -0.0054, -0.0118],
          [-0.0038, -0.0103, -0.0012,  ..., -0.0124, -0.0190, -0.0150],
          [-0.0069, -0.0110, -0.0040,  ..., -0.0132, -0.0059, -0.0089]]]],
       grad_fn=<ConvolutionBackward0>)