In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchkeras import summary

# 2-Conv

In [6]:
class ConV2(nn.Module):
    """Two stacked ``Conv2d(3x3) -> BatchNorm2d -> ReLU`` stages.

    Both convolutions use stride 1 and padding 1 with reflect padding, so the
    spatial size (H, W) is unchanged; only the channel count moves from
    ``in_channels`` to ``out_channels`` (the first conv does the change, the
    second keeps ``out_channels``).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        # First stage maps in->out channels, second stage out->out.
        for stage_in in (in_channels, out_channels):
            layers += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='reflect'),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ]
        # Kept as a single Sequential named `conv2` so state_dict keys are conv2.0 ... conv2.5.
        self.conv2 = nn.Sequential(*layers)

    def forward(self, x):
        """Run both conv stages; H and W are preserved."""
        features = self.conv2(x)
        return features

# DownSample

In [7]:
class DownSample(nn.Module):
    """Encoder step: halve H/W with a strided conv, then refine with ``ConV2``.

    The 3x3 stride-2 conv (padding 1, reflect) keeps ``in_channels`` and maps a
    side of length n to (n - 1) // 2 + 1, i.e. ceil(n / 2). The following
    ``ConV2`` then changes the channel count to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Downsampling is done by a learned stride-2 conv rather than a pooling layer.
        shrink = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=1, padding_mode='reflect')
        refine = ConV2(in_channels, out_channels)
        # Same Sequential layout as before: index 0 = strided conv, index 1 = ConV2.
        self.downsample = nn.Sequential(shrink, refine)

    def forward(self, x):
        """Downsample ``x`` by 2 in H and W and map it to ``out_channels``."""
        reduced = self.downsample(x)
        return reduced

# UpSample

In [None]:
class UpSample(nn.Module):
    """Decoder step: upsample, align with the skip tensor, concatenate, ConV2.

    ``x1`` (the deeper feature map, ``in_channels`` channels) is upsampled 2x
    with nearest-neighbour interpolation, padded or cropped so its H/W equal
    those of the skip tensor ``x2``, concatenated with ``x2`` along the channel
    dimension, and passed through a ``ConV2`` that outputs ``out_channels``.

    Args:
        in_channels: channel count of ``x1``.
        out_channels: channel count of the result.
        skip_channels: channel count of ``x2``. Defaults to ``in_channels // 2``,
            which matches how the UNet in this notebook wires its skips.
    """

    def __init__(self, in_channels, out_channels, skip_channels=None):
        super().__init__()
        if skip_channels is None:
            skip_channels = in_channels // 2
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        # After concatenation the channel count is in_channels + skip_channels.
        self.conv = ConV2(in_channels + skip_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, match it to ``x2``'s H/W, concat, and convolve.

        Args:
            x1: deeper feature map, BCHW with ``in_channels`` channels.
            x2: skip feature map, BCHW with ``skip_channels`` channels.
        """
        x1 = self.upsample(x1)
        # F.pad lists padding for the last dim first: [left, right, top, bottom].
        # Any odd pixel goes to the right/bottom edge. A negative difference
        # (upsampled x1 larger than x2) makes F.pad crop instead of pad.
        diff_h = x2.size(2) - x1.size(2)
        diff_w = x2.size(3) - x1.size(3)
        x1 = F.pad(x1, pad=[diff_w // 2, diff_w - diff_w // 2,
                            diff_h // 2, diff_h - diff_h // 2])
        x = torch.cat([x1, x2], dim=1)
        return self.conv(x)

# OutConv

In [9]:
class OutConv(nn.Module):
    """1x1 convolution that remaps channels and leaves H/W untouched."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.out = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(1, 1))

    def forward(self, x):
        """Project ``x`` from ``in_channels`` to ``out_channels``."""
        projected = self.out(x)
        return projected

# U-Net

In [10]:
class UNet(nn.Module):
    """Three-level U-Net: encoder, 1x1 bottleneck conv, decoder with skip connections.

    Channel widths go 64 -> 128 -> 256 -> 512 on the way down and back to 64
    on the way up; a final 1x1 conv maps to ``out_channels``.

    Args:
        in_channels: channel count of the input tensor.
        out_channels: channel count of the output tensor.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Encoder (construction order kept so parameter initialisation is unchanged).
        self.con = ConV2(in_channels, 64)
        self.down1 = DownSample(64, 128)
        self.down2 = DownSample(128, 256)
        self.down3 = DownSample(256, 512)
        # Bottleneck: 1x1 conv at 512 channels.
        self.midout = OutConv(512, 512)
        # Decoder: each stage is concatenated with a skip that has half its input channels.
        self.up1 = UpSample(512, 256)
        self.up2 = UpSample(256, 128)
        self.up3 = UpSample(128, 64)
        # Head.
        self.out = OutConv(64, out_channels)

    def forward(self, x0):
        """Encode ``x0``, decode with skips, and project to ``out_channels``."""
        enc1 = self.con(x0)
        enc2 = self.down1(enc1)
        enc3 = self.down2(enc2)
        decoded = self.midout(self.down3(enc3))
        # Pair each decoder stage with the encoder output at its resolution.
        for stage, skip in ((self.up1, enc3), (self.up2, enc2), (self.up3, enc1)):
            decoded = stage(decoded, skip)
        return self.out(decoded)

In [11]:
# Smoke test: UNet(3, 3) should map a (B, 3, H, W) input to a (B, 3, H, W) output.
torch.manual_seed(0)  # reproducible weights and input on Restart & Run All
unet = UNet(3,3)
x = torch.rand(size=(2,3,256,256))
with torch.no_grad():  # shape check only; no autograd graph needed
    y = unet(x)
assert y.shape == x.shape, f"unexpected output shape {tuple(y.shape)}"
y.shape

torch.Size([2, 3, 256, 256])


In [12]:
unet = UNet(3,3)
# 572 -> 286 -> 143 -> 72 through the encoder; upsampling 72 gives 144 against
# the 143-pixel skip, so this input size also exercises UpSample's alignment.
sample_input = torch.rand(size=(2,3,572,572))
summary(unet, input_data=sample_input)

--------------------------------------------------------------------------
Layer (type)                            Output Shape              Param #
Conv2d-1                          [-1, 64, 572, 572]                1,792
BatchNorm2d-2                     [-1, 64, 572, 572]                  128
ReLU-3                            [-1, 64, 572, 572]                    0
Conv2d-4                          [-1, 64, 572, 572]               36,928
BatchNorm2d-5                     [-1, 64, 572, 572]                  128
ReLU-6                            [-1, 64, 572, 572]                    0
Conv2d-7                          [-1, 64, 286, 286]               36,928
Conv2d-8                         [-1, 128, 286, 286]               73,856
BatchNorm2d-9                    [-1, 128, 286, 286]                  256
ReLU-10                          [-1, 128, 286, 286]                    0
Conv2d-11                        [-1, 128, 286, 286]              147,584
BatchNorm2d-12                   [-1,

