In [1]:
import torch 
import torch.nn as nn 

In [2]:
def Double_Conv(input_channels, output_channels):
    
    conv = nn.Sequential(
           nn.Conv2d(in_channels=input_channels, out_channels=output_channels, kernel_size=3),
           nn.ReLU(inplace=True),
           nn.Conv2d(in_channels=output_channels, out_channels=output_channels, kernel_size=3),
           nn.ReLU(inplace=True))
    return conv
    

In [3]:
def CropImage(orginal_tensor, target_tensor):
    
    """Assuming the the orginal dimension 
       of the image is alwyas bigger than
       that of target dimension
       
       Orginal_tensor will come from encoder 
       part, which has to be cropped
       
       """
    
    orginal_dim = orginal_tensor.shape[2]  
    target_dim  = target_tensor.shape[2]
    
    delta = (orginal_dim - target_dim) // 2
    
    return orginal_tensor[:, :, delta:orginal_dim-delta, delta:orginal_dim-delta]   

In [10]:
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        
        
       
        self.MaxPool = nn.MaxPool2d(kernel_size=2) 
        self.conv1 = Double_Conv(input_channels=1,    output_channels=64)
        self.conv2 = Double_Conv(input_channels=64,   output_channels=128)
        self.conv3 = Double_Conv(input_channels=128,  output_channels=256)
        self.conv4 = Double_Conv(input_channels=256,  output_channels=512)
        self.conv5 = Double_Conv(input_channels=512,  output_channels=1024)
       
        
        self.ConvTransp1 = nn.ConvTranspose2d(1024,512, 2,2) 
        self.ConvTransp2 = nn.ConvTranspose2d(512, 256, 2,2)
        self.ConvTransp3 = nn.ConvTranspose2d(256, 128, 2,2)
        self.ConvTransp4 = nn.ConvTranspose2d(128, 64,  2,2)
        
        
        self.conv6  = Double_Conv(input_channels=1024,  output_channels=512)
        self.conv7  = Double_Conv(input_channels=512,   output_channels=256)
        self.conv8  = Double_Conv(input_channels=256,   output_channels=128)
        self.conv9  = Double_Conv(input_channels=128,   output_channels=64)
        self.conv10 = nn.Conv2d(in_channels=64, out_channels=2, kernel_size=1)
        
        
    def forward(self, x):
        #----------------------------------Encoder-----------------------------------------#
        x1 = self.conv1(x)
        x2 = self.MaxPool(x1)
        
        x3 = self.conv2(x2)
        x4 = self.MaxPool(x3)
        
        x5 = self.conv3(x4)
        x6 = self.MaxPool(x5)
        
        x7 = self.conv4(x6)
        x8 = self.MaxPool(x7)
        
        x9 = self.conv5(x8)
        
        
        #----------------------------------Decoder-----------------------------------------#
        
        x10 = self.ConvTransp1(x9)
        y1  = CropImage(x7, x10)
        x11 = torch.concat((x10, y1), dim=1)
        
        x12 = self.conv6(x11)
        x13 = self.ConvTransp2(x12)
        y2  = CropImage(x5, x13)
        x14 = torch.concat((x13, y2), dim=1)
        
        x15 = self.conv7(x14)
        x16 = self.ConvTransp3(x15)
        y3  = CropImage(x3, x16)
        x17 = torch.concat((x16, y3), dim=1)
        
        x18 = self.conv8(x17)
        x19 = self.ConvTransp4(x18)
        y4  = CropImage(x1, x19)
        x20 = torch.concat((x19, y4), dim=1)
        
        x21 = self.conv9(x20)
        x22 = self.conv10(x21)
                
        return print(x.shape, x22.shape)

In [11]:
x = torch.rand([1,1,572,572])
model_test = UNet()
model_test(x)

torch.Size([1, 1, 572, 572]) torch.Size([1, 2, 388, 388])
