In [1]:
import torch 

In [2]:
import torch.nn as nn

In [3]:
# Every U-Net stage applies two unpadded 3x3 convolutions, each followed by a ReLU.
# The encoder stages produce 64, 128, 256, 512 and 1024 channels; the decoder reuses
# the same helper on the concatenated (upsampled + skip) feature maps.
def double_conv(in_c, out_c):
    """Build two consecutive 3x3 convolution + ReLU layers.

    The first convolution maps ``in_c`` -> ``out_c`` channels and the second maps
    ``out_c`` -> ``out_c``. No padding is used, so each convolution trims one pixel
    from every border (height and width shrink by 4 in total).

    Args:
        in_c: number of input channels.
        out_c: number of output channels of both convolutions.

    Returns:
        nn.Sequential of [Conv2d, ReLU, Conv2d, ReLU].
    """
    return nn.Sequential(
        nn.Conv2d(in_c, out_c, kernel_size=3),
        nn.ReLU(inplace=True),
        # The second conv consumes the first one's output, so its input width is out_c.
        nn.Conv2d(out_c, out_c, kernel_size=3),
        nn.ReLU(inplace=True),
    )

# Skip connections: because U-Net uses unpadded convolutions, each encoder feature map
# is larger than the decoder map it is concatenated with, so the encoder map is
# center-cropped to the decoder map's size, as in the original U-Net by Olaf
# Ronneberger et al. Padding the convolutions would be an alternative that avoids the
# crop, but the original model does not do this.
def crop_img(tensor, target_tensor):
    """Center-crop ``tensor`` spatially to the height and width of ``target_tensor``.

    Args:
        tensor: tensor whose last two dimensions are (height, width); in this
            notebook a (batch, channels, height, width) feature map.
        target_tensor: tensor whose last two dimensions give the desired size.

    Returns:
        A view of ``tensor`` with all leading dimensions unchanged and the
        target's height and width.

    Raises:
        ValueError: if the target is larger than ``tensor`` in either dimension.
    """
    target_h, target_w = target_tensor.size()[-2:]
    tensor_h, tensor_w = tensor.size()[-2:]
    if target_h > tensor_h or target_w > tensor_w:
        raise ValueError(
            f"cannot crop spatial size {(tensor_h, tensor_w)} "
            f"to larger size {(target_h, target_w)}"
        )
    # Deriving the end index from the start (instead of size - delta) keeps the
    # output exactly target-sized even when the size difference is odd.
    top = (tensor_h - target_h) // 2
    left = (tensor_w - target_w) // 2
    return tensor[..., top:top + target_h, left:left + target_w]
    

In [4]:
class UNet(nn.Module):
    """U-Net of Ronneberger et al. (2015) built from unpadded ("valid") 3x3 convolutions.

    Args:
        in_channels: number of channels of the input image (default 1, grayscale,
            as in the original paper).
        out_channels: number of output maps produced by the final 1x1 convolution
            (default 2, as in the original paper).

    Only the layers are created here. The ``forward`` pass is written as a separate
    function in a later cell and has to be bound to this class
    (``UNet.forward = forward``) before the model can be called.
    """

    def __init__(self, in_channels=1, out_channels=2):
        super().__init__()
        # Contracting (downsampling) path: a 2x2 max-pool with stride 2 between stages
        # and two 3x3 convolutions per stage, doubling channels 64 -> ... -> 1024.
        self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.down_conv_1 = double_conv(in_channels, 64)
        self.down_conv_2 = double_conv(64, 128)
        self.down_conv_3 = double_conv(128, 256)
        self.down_conv_4 = double_conv(256, 512)
        self.down_conv_5 = double_conv(512, 1024)

        # Expansive (upsampling) path. Each 2x2 transposed convolution
        # ("up-convolution") doubles the spatial size and halves the channels. The
        # following double_conv sees the concatenation with the cropped skip
        # connection, hence twice the channels at its input. Transposed convolution
        # follows the original paper; bilinear interpolation is a common alternative.
        self.up_trans_1 = nn.ConvTranspose2d(
            in_channels=1024, out_channels=512, kernel_size=2, stride=2
        )
        self.up_conv_1 = double_conv(1024, 512)

        self.up_trans_2 = nn.ConvTranspose2d(
            in_channels=512, out_channels=256, kernel_size=2, stride=2
        )
        self.up_conv_2 = double_conv(512, 256)

        self.up_trans_3 = nn.ConvTranspose2d(
            in_channels=256, out_channels=128, kernel_size=2, stride=2
        )
        self.up_conv_3 = double_conv(256, 128)

        self.up_trans_4 = nn.ConvTranspose2d(
            in_channels=128, out_channels=64, kernel_size=2, stride=2
        )
        self.up_conv_4 = double_conv(128, 64)

        # 1x1 convolution mapping the final 64 feature channels to the output maps.
        self.out = nn.Conv2d(in_channels=64, out_channels=out_channels, kernel_size=1)

In [5]:
def forward(self, image):
    """U-Net forward pass: contracting path, bottleneck, expansive path, 1x1 head.

    Written as a plain function in its own cell and bound to ``UNet`` right below.

    Args:
        self: a ``UNet`` instance providing the layers created in ``__init__``.
        image: input batch, presumably (batch, channels, height, width) with
            channels matching ``down_conv_1`` — TODO confirm for other datasets.

    Returns:
        Output of ``self.out``. Because every convolution is unpadded, its spatial
        size is smaller than the input's (572x572 -> 388x388 in the paper's setup).
    """
    # Contracting path (encoder): a double conv per stage, 2x2 max-pool between
    # stages. Each stage's output before pooling is kept as a skip connection.
    skip_1 = self.down_conv_1(image)
    skip_2 = self.down_conv_2(self.max_pool_2x2(skip_1))
    skip_3 = self.down_conv_3(self.max_pool_2x2(skip_2))
    skip_4 = self.down_conv_4(self.max_pool_2x2(skip_3))
    x = self.down_conv_5(self.max_pool_2x2(skip_4))  # bottleneck

    # Expansive path (decoder): each step upsamples with a 2x2 up-convolution that
    # halves the channels, concatenates the center-cropped skip connection of the
    # same depth (cropping is needed because unpadded convolutions lose border
    # pixels), then applies two 3x3 convolutions, each followed by a ReLU.
    decoder_stages = (
        (self.up_trans_1, self.up_conv_1, skip_4),
        (self.up_trans_2, self.up_conv_2, skip_3),
        (self.up_trans_3, self.up_conv_3, skip_2),
        (self.up_trans_4, self.up_conv_4, skip_1),
    )
    for up_trans, up_conv, skip in decoder_stages:
        x = up_trans(x)
        x = up_conv(torch.cat([x, crop_img(skip, x)], dim=1))

    return self.out(x)


# nn.Module.__call__ dispatches to self.forward. This function was defined outside the
# class body, so it must be attached explicitly; otherwise model(image) raises
# NotImplementedError.
UNet.forward = forward

if __name__ == "__main__":
    # Smoke test with the 572x572 single-channel input used in the paper.
    image = torch.rand((1, 1, 572, 572))
    model = UNet()
    with torch.no_grad():
        output = model(image)
    print(output.size())  # torch.Size([1, 2, 388, 388]) with the default configuration

In [6]:
# The encoder (contracting path) can be swapped for a different pretrained backbone if preferred.