In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class down_block(nn.Module):
    """U-Net encoder stage: two unpadded 3x3 Conv-BN-ReLU layers, then a 2x2 max-pool.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both convolutions.

    ``forward`` returns a pair ``(features, pooled)``: the pre-pool feature map
    (the UNet in this notebook passes it to the decoder as a skip connection)
    and its 2x-downsampled version (fed to the next encoder stage).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        # First conv changes the width to out_channels; the second keeps it.
        for conv_in in (in_channels, out_channels):
            layers.extend([
                nn.Conv2d(conv_in, out_channels, kernel_size=3),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*layers)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        features = self.conv(x)
        pooled = self.pool(features)
        return features, pooled

class up_conv(nn.Module):
    """U-Net decoder stage: upsample, merge with the encoder skip, then refine.

    ``x`` is upsampled 2x by a transposed convolution mapping
    ``in_channels -> out_channels``. The skip tensor ``y`` is center-cropped to
    the upsampled spatial size, because the 3x3 convolutions in this model use no
    padding, so encoder feature maps are larger than their decoder counterparts.
    The two tensors are concatenated along the channel axis and passed through two
    unpadded 3x3 Conv-BN-ReLU layers.

    Args:
        in_channels: channels of ``x`` (the deeper feature map).
        out_channels: channels produced by the upsampling and by this stage.
            ``y`` is expected to carry ``out_channels`` channels as well, so the
            concatenation has ``2 * out_channels`` channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
        )
        # Input width is out_channels (upsampled x) + out_channels (skip y);
        # it was previously 2 * in_channels, which mismatches whenever in != out.
        self.conv = nn.Sequential(
            nn.Conv2d(2 * out_channels, out_channels, kernel_size=3),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _center_crop(t, size):
        """Return the central ``(H, W) = size`` window of the last two dims of ``t``.

        Raises:
            ValueError: if ``t`` is spatially smaller than ``size``.
        """
        h, w = size
        top = (t.shape[-2] - h) // 2
        left = (t.shape[-1] - w) // 2
        if top < 0 or left < 0:
            raise ValueError(
                f"cannot crop tensor of spatial size {tuple(t.shape[-2:])} to {(h, w)}"
            )
        return t[..., top:top + h, left:left + w]

    def forward(self, x, y):
        """Upsample ``x``, crop ``y`` to match, concatenate on channels, and refine.

        Args:
            x: deeper feature map with ``in_channels`` channels.
            y: encoder skip feature map with ``out_channels`` channels and spatial
                size at least that of the upsampled ``x``.

        Returns:
            Feature map with ``out_channels`` channels.
        """
        x = self.up(x)
        if y.shape[-2:] != x.shape[-2:]:
            y = self._center_crop(y, x.shape[-2:])
        x = torch.cat([x, y], dim=1)
        return self.conv(x)



    


In [None]:

class UNet(nn.Module):
    """U-Net with four encoder stages, a bottleneck, and four decoder stages.

    All 3x3 convolutions are unpadded, so the output map is spatially smaller
    than the input.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the final 1x1 convolution's output.
        features: channel widths; ``features[0:4]`` are the four encoder stages
            and ``features[4]`` is the bottleneck. Must have at least 5 entries;
            any extra entries are ignored.

    Raises:
        ValueError: if ``features`` has fewer than 5 entries.
    """

    def __init__(self, in_channels=3, out_channels=2, features=(64, 128, 256, 512, 1024)):
        super().__init__()
        # Fail early with a clear message instead of an IndexError at features[4].
        if len(features) < 5:
            raise ValueError(
                "UNet needs 5 feature widths (4 encoder stages + bottleneck), "
                f"got {len(features)}: {list(features)}"
            )
        self.inc = down_block(in_channels, features[0])
        self.down1 = down_block(features[0], features[1])
        self.down2 = down_block(features[1], features[2])
        self.down3 = down_block(features[2], features[3])
        self.bottom = nn.Sequential(
            nn.Conv2d(features[3], features[4], kernel_size=3),
            nn.BatchNorm2d(features[4]),
            nn.ReLU(inplace=True),
            nn.Conv2d(features[4], features[4], kernel_size=3),
            nn.BatchNorm2d(features[4]),
            nn.ReLU(inplace=True),
        )
        self.up1 = up_conv(features[4], features[3])
        self.up2 = up_conv(features[3], features[2])
        self.up3 = up_conv(features[2], features[1])
        self.up4 = up_conv(features[1], features[0])
        self.outc = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        # Encoder: keep each stage's pre-pool map as the skip for the decoder.
        conv1, x1 = self.inc(x)
        conv2, x2 = self.down1(x1)
        conv3, x3 = self.down2(x2)
        conv4, x4 = self.down3(x3)
        conv5 = self.bottom(x4)
        # Decoder: consume skips in reverse order (deepest first).
        x = self.up1(conv5, conv4)
        x = self.up2(x, conv3)
        x = self.up3(x, conv2)
        x = self.up4(x, conv1)
        return self.outc(x)

In [None]:
from torchsummary import summary

# `features` lists the four encoder widths *and* the bottleneck width, so the
# classic configuration needs five entries (a 4-entry list fails in UNet).
model = UNet(in_channels=3, out_channels=1, features=[64, 128, 256, 512, 1024])

# torchsummary builds its dummy input on `device` (default "cuda"); keep the
# model and that input on the same device so this cell runs with or without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
summary(model, (3, 572, 572), device=device)