In [8]:
import torch
import torch.nn as nn

In [9]:
# Double convolution block with padding
def double_conv(in_c, out_c):
    """Build two (3x3 conv -> ReLU) stages that preserve height and width.

    Both convolutions use ``padding=1`` with a 3x3 kernel, so the spatial size
    of the feature map is unchanged; only the channel count changes.

    Args:
        in_c: number of channels of the incoming feature map.
        out_c: number of channels produced by both convolutions.

    Returns:
        ``nn.Sequential`` laid out as Conv2d -> ReLU -> Conv2d -> ReLU.
    """
    layers = []
    # First stage maps in_c -> out_c, second stage maps out_c -> out_c.
    for stage_in in (in_c, out_c):
        layers.append(nn.Conv2d(stage_in, out_c, kernel_size=3, padding=1))
        layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers)

In [10]:
class UNet(nn.Module):
    """U-Net variant built from padded 3x3 convolutions.

    Every convolution uses ``padding=1``. Before each decoder concatenation, the
    upsampled map is zero-padded to the spatial size of its encoder skip
    connection, rather than cropping the skip. As a result, the output has the
    same height and width as the input, even when those sizes are not
    divisible by 16.

    Args:
        in_channels: channels of the input image (default 1).
        out_channels: channels of the output map produced by the final 1x1
            convolution (default 2).
    """

    def __init__(self, in_channels=1, out_channels=2):
        super().__init__()

        self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: channel width doubles at every level, 64 -> 1024.
        self.down_conv_1 = double_conv(in_channels, 64)
        self.down_conv_2 = double_conv(64, 128)
        self.down_conv_3 = double_conv(128, 256)
        self.down_conv_4 = double_conv(256, 512)
        self.down_conv_5 = double_conv(512, 1024)

        # Decoder: each transposed conv halves the channels and doubles H/W.
        # The following double_conv receives that map concatenated with the
        # matching encoder skip, hence twice the channels on its input.
        self.up_trans_1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.up_conv_1 = double_conv(1024, 512)

        self.up_trans_2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.up_conv_2 = double_conv(512, 256)

        self.up_trans_3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.up_conv_3 = double_conv(256, 128)

        self.up_trans_4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.up_conv_4 = double_conv(128, 64)

        self.out = nn.Conv2d(64, out_channels, kernel_size=1)

    @staticmethod
    def _pad_to(x, ref):
        """Zero-pad ``x`` in its last two dims so they match ``ref``'s.

        Padding is split as evenly as possible; an odd extra row/column goes
        to the bottom/right. Returns ``x`` unchanged when the spatial sizes
        already agree. The upsampled map is never larger than its skip here,
        because max-pooling floors odd sizes (2 * floor(s / 2) <= s).
        """
        diff_y = ref.size(2) - x.size(2)
        diff_x = ref.size(3) - x.size(3)
        if diff_y == 0 and diff_x == 0:
            return x
        # F.pad takes (left, right, top, bottom) for the last two dimensions.
        return nn.functional.pad(
            x, [diff_x // 2, diff_x - diff_x // 2, diff_y // 2, diff_y - diff_y // 2]
        )

    def forward(self, image):
        """Run the encoder-decoder on ``image``.

        Args:
            image: 4-D tensor (batch, in_channels, height, width).

        Returns:
            Tensor (batch, out_channels, height, width); spatial size matches
            ``image`` because of the padded convolutions and ``_pad_to``.
        """
        # Encoder: keep each level's pre-pool features as skip connections.
        skip1 = self.down_conv_1(image)
        skip2 = self.down_conv_2(self.max_pool_2x2(skip1))
        skip3 = self.down_conv_3(self.max_pool_2x2(skip2))
        skip4 = self.down_conv_4(self.max_pool_2x2(skip3))
        x = self.down_conv_5(self.max_pool_2x2(skip4))

        # Decoder: upsample, pad to the skip's H/W, concatenate on channels, refine.
        decoder = (
            (self.up_trans_1, self.up_conv_1, skip4),
            (self.up_trans_2, self.up_conv_2, skip3),
            (self.up_trans_3, self.up_conv_3, skip2),
            (self.up_trans_4, self.up_conv_4, skip1),
        )
        for up_trans, up_conv, skip in decoder:
            x = self._pad_to(up_trans(x), skip)
            x = up_conv(torch.cat([x, skip], 1))

        return self.out(x)

In [11]:
# Smoke test: the padded U-Net must return an output with the input's height/width,
# even though 696 and 520 are not divisible by 16.
torch.manual_seed(0)  # reproducible random input and weight initialisation
image = torch.rand(1, 1, 696, 520)  # (batch, channels, height, width)
model = UNet()
model.eval()
with torch.no_grad():  # inference only: avoid building a large autograd graph
    output = model(image)
expected_shape = (image.shape[0], 2, *image.shape[2:])
assert tuple(output.shape) == expected_shape, output.shape
print(output.size())

torch.Size([1, 2, 696, 520])
