U-Net with ResNet building blocks. ResNet blocks provide identity (skip) mappings, which make it possible to build deeper models while mitigating the vanishing/exploding gradient problem. The training error of ResNet models can be reduced further than that of plain models, i.e. the saturation point is shifted towards lower values of training error. ResNet blocks also help produce a smoother loss curve. Combining ResNet blocks with the U-Net architecture brings the advantageous properties of both models together in the final model.

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [6]:
class DoubleConv(nn.Module):
    """Residual block: two 3x3 conv + batch-norm layers with a 1x1-conv shortcut.

    The input is sent through two branches that both map ``in_channels`` to
    ``out_channels``; the branch outputs are summed and passed through ReLU.
    All convolutions keep the spatial size (3x3 with ``padding=1``, or 1x1).

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by both branches.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Shortcut branch: a 1x1 convolution so the shortcut has the same
        # channel count as the residual branch and the two can be added.
        self.skip = nn.Conv2d(in_channels, out_channels, kernel_size=1)

        # Residual branch: conv -> BN -> ReLU -> conv -> BN. The final ReLU is
        # applied in forward(), after the addition with the shortcut.
        residual_layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
        ]
        self.double_conv = nn.Sequential(*residual_layers)

    def forward(self, x):
        """Return ``relu(skip(x) + double_conv(x))``."""
        shortcut = self.skip(x)
        residual = self.double_conv(x)
        # The sum is a freshly allocated tensor, so the in-place ReLU does not
        # touch either branch's own output or the caller's input.
        return F.relu_(shortcut + residual)

class UpConcatConv(nn.Module):
    """Decoder stage: upsample, concatenate with an encoder skip, then convolve.

    ``x1`` (the lower-resolution map) is bilinearly upsampled by 2 and reduced
    from ``in_channels`` to ``out_channels`` with a 1x1 convolution. It is then
    concatenated channel-wise with the skip feature map ``x2`` and passed
    through ``DoubleConv(in_channels, out_channels)``.

    Args:
        in_channels: channels of ``x1``. This is also the channel count the
            concatenated tensor must have, so ``x2`` has to carry
            ``in_channels - out_channels`` channels.
        out_channels: channels after the 1x1 reduction and of the stage output.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.upsample = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            # 1x1 convolution reducing in_channels to out_channels
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
        )
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Fuse the upsampled ``x1`` with the skip ``x2``.

        Args:
            x1: lower-resolution feature map with ``in_channels`` channels.
            x2: skip feature map from the contracting path.

        Returns:
            Feature map with ``out_channels`` channels and the spatial size
            of ``x2``.
        """
        x1 = self.upsample(x1)

        # A 2x upsample only restores the skip's size when that size was even
        # before it was pooled; pooling an odd-sized map floors, leaving x1 one
        # pixel short. Align x1 to x2 so the concatenation cannot fail on a
        # spatial mismatch. When the sizes already agree this is a no-op.
        diff_h = x2.size(-2) - x1.size(-2)
        diff_w = x2.size(-1) - x1.size(-1)
        if diff_h != 0 or diff_w != 0:
            # F.pad takes (left, right, top, bottom) for the last two dims;
            # zero-pads for positive values and crops for negative ones.
            x1 = F.pad(
                x1,
                [diff_w // 2, diff_w - diff_w // 2,
                 diff_h // 2, diff_h - diff_h // 2],
            )

        x = torch.cat((x2, x1), dim=1)
        return self.conv(x)

In [7]:
class UNet(nn.Module):
    """U-Net built from ``DoubleConv`` encoder stages and ``UpConcatConv`` decoder stages.

    The contracting path has five ``DoubleConv`` stages (64, 128, 256, 512 and
    1024 channels) separated by four 2x2 max-poolings. The expanding path has
    four ``UpConcatConv`` stages, each fused with the encoder output of the
    matching resolution, followed by a 1x1 convolution to ``out_channels``.

    Args:
        in_channels: number of channels of the input.
        out_channels: number of channels of the output map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        c1, c2, c3, c4, c5 = 64, 128, 256, 512, 1024

        # Contracting path: residual block, then halve the resolution.
        self.dconv1 = DoubleConv(in_channels, c1)
        self.pool1 = nn.MaxPool2d(2)
        self.dconv2 = DoubleConv(c1, c2)
        self.pool2 = nn.MaxPool2d(2)
        self.dconv3 = DoubleConv(c2, c3)
        self.pool3 = nn.MaxPool2d(2)
        self.dconv4 = DoubleConv(c3, c4)
        self.pool4 = nn.MaxPool2d(2)

        # Bottleneck at the lowest resolution.
        self.dconv5 = DoubleConv(c4, c5)

        # Expanding path: each stage halves the channel count.
        self.up1 = UpConcatConv(c5, c4)
        self.up2 = UpConcatConv(c4, c3)
        self.up3 = UpConcatConv(c3, c2)
        self.up4 = UpConcatConv(c2, c1)

        # Final 1x1 convolution mapping to the requested output channels.
        self.outconv = nn.Conv2d(c1, out_channels, kernel_size=1)

    def forward(self, x):
        """Run the encoder, bottleneck and decoder; return the ``outconv`` result."""
        # Encoder: keep every stage output for the skip connections.
        enc1 = self.dconv1(x)
        enc2 = self.dconv2(self.pool1(enc1))
        enc3 = self.dconv3(self.pool2(enc2))
        enc4 = self.dconv4(self.pool3(enc3))
        bottleneck = self.dconv5(self.pool4(enc4))

        # Decoder: fuse with encoder outputs from deepest to shallowest.
        decoded = bottleneck
        stages = (
            (self.up1, enc4),
            (self.up2, enc3),
            (self.up3, enc2),
            (self.up4, enc1),
        )
        for up_stage, skip in stages:
            decoded = up_stage(decoded, skip)

        return self.outconv(decoded)

In [8]:
# Dummy input drawn from a standard normal: batch of 10, 3 channels, 128x128.
inp = torch.randn(10, 3, 128, 128)

In [9]:
# 3 input channels (matching the dummy input) and 2 output channels.
model = UNet(3, 2)

In [10]:
# Single forward pass through the network as a smoke test.
out = model(inp)

In [11]:
# Inspect the output shape; the notebook displays it below.
out.shape

torch.Size([10, 2, 128, 128])