In [41]:
import torch
import torch.nn as nn

In [56]:
### We use padding to keep the size of the input and output the same after convolution
### We use batch normalization to speed up the training process
def double_conv(in_c, out_c):
    conv = nn.Sequential(
        nn.Conv2d(in_c, out_c, kernel_size=3, padding=1, stride=1),
        nn.BatchNorm2d(out_c),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_c, out_c, kernel_size=3, padding=1, stride=1),
        nn.BatchNorm2d(out_c),
        nn.ReLU(inplace=True)
    )
    return conv

In [62]:
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        ### 1st part encoder
        self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.down_conv_1 = double_conv(3, 64)
        self.down_conv_2 = double_conv(64, 128)
        self.down_conv_3 = double_conv(128, 256)
        self.down_conv_4 = double_conv(256, 512)
        self.down_conv_5 = double_conv(512, 1024)
        
        ### 2nd part decoder
        self.up_trans_1 = nn.ConvTranspose2d(
            in_channels=1024, 
            out_channels=512, 
            kernel_size=2, 
            stride=2)
        
        self.up_conv_1 = double_conv(1024, 512)
        
        self.up_trans_2 = nn.ConvTranspose2d(
            in_channels=512, 
            out_channels=256, 
            kernel_size=2, 
            stride=2)
        
        self.up_conv_2 = double_conv(512, 256)
        
        self.up_trans_3 = nn.ConvTranspose2d(
            in_channels=256, 
            out_channels=128, 
            kernel_size=2, 
            stride=2)
        
        self.up_conv_3 = double_conv(256, 128)
        
        self.up_trans_4 = nn.ConvTranspose2d(
            in_channels=128, 
            out_channels=64, 
            kernel_size=2, 
            stride=2)
        
        self.up_conv_4 = double_conv(128, 64)
        
        self.out = nn.Conv2d(
            in_channels=64,
            out_channels=2,
            kernel_size=1
        )
        
    def forward(self, image):
        # encoder
        x1 = self.down_conv_1(image) # we will need x1, x3, x5, x7 in the decoder part. We don't need x9 in the decoder part because there is no maxpooling.
        print("Output after 1st convolution:", x1.size())
        x2 = self.max_pool_2x2(x1)
        x3 = self.down_conv_2(x2) #
        x4 = self.max_pool_2x2(x3)
        x5 = self.down_conv_3(x4) #
        x6 = self.max_pool_2x2(x5)
        x7 = self.down_conv_4(x6) #
        x8 = self.max_pool_2x2(x7)
        x9 = self.down_conv_5(x8)
        print("Output after 5th convolution:", x9.size())
        
        #This is decoder part
        x = self.up_trans_1(x9)
        print("Output of 1st up-convolution:", x.size())
        print("Output after 4th convolution:", x7.size())
        x = self.up_conv_1(torch.cat([x, x7], 1))
        print(x.size())
        x = self.up_trans_2(x)
        print("Output of 2nd up-convolution:", x.size())
        x = self.up_conv_2(torch.cat([x, x5], 1))
        print(x.size())
        x = self.up_trans_3(x)
        x = self.up_conv_3(torch.cat([x, x3], 1))
        x = self.up_trans_4(x)
        x = self.up_conv_4(torch.cat([x, x1], 1))
        print(x.size())
        x = self.out(x)
        return x
        
        
        
        

In [63]:
if __name__ == '__main__':
        image = torch.rand((1, 3, 512, 512))
        model = UNet()
        print(model(image))

Output after 1st convolution: torch.Size([1, 64, 512, 512])
Output after 5th convolution: torch.Size([1, 1024, 32, 32])
Output of 1st up-convolution: torch.Size([1, 512, 64, 64])
Output after 4th convolution: torch.Size([1, 512, 64, 64])
torch.Size([1, 512, 64, 64])
Output of 2nd up-convolution: torch.Size([1, 256, 128, 128])
torch.Size([1, 256, 128, 128])
torch.Size([1, 64, 512, 512])
tensor([[[[-3.1493e-01,  2.3824e-01,  1.0029e-01,  ..., -3.1960e-01,
           -1.4870e-03, -2.0888e-01],
          [-4.6253e-01, -6.8784e-01, -6.3150e-01,  ..., -3.0999e-01,
           -1.1709e+00, -5.9235e-01],
          [-4.0943e-01, -3.4427e-01, -6.8272e-01,  ..., -8.8512e-01,
           -6.6690e-01, -7.8568e-01],
          ...,
          [-2.3316e-01, -8.7829e-01, -6.7628e-01,  ..., -5.8302e-01,
           -2.8997e-01, -4.4420e-01],
          [-2.3645e-01, -3.1326e-01, -3.2855e-01,  ..., -1.9221e-01,
           -5.6292e-02, -2.9878e-01],
          [-1.6000e-01, -3.4401e-01, -3.3705e-01,  ..., -5.71