In [10]:
import torch
import torch.nn as nn

In [11]:
class DoubleConv(nn.Module):
    """Two (3x3 conv -> ReLU) pairs applied back to back.

    Each conv uses stride 1 and padding 1, so height and width are preserved;
    only the channel count changes: ``in_channels -> out_channels -> out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        # First pair maps in->out channels, second pair maps out->out.
        for source_channels in (in_channels, out_channels):
            layers.append(
                nn.Conv2d(source_channels, out_channels, kernel_size=3, stride=1, padding=1)
            )
            layers.append(nn.ReLU(inplace=True))
        # Indices 0..3 inside ``self.conv`` match the original Sequential layout,
        # so state_dict keys (conv.0.*, conv.2.*) are unchanged.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        activated = self.conv(x)
        return activated

In [12]:
class DownSample(nn.Module):
    """Encoder stage: a DoubleConv followed by a 2x2 max-pool.

    ``forward`` returns a pair ``(features, pooled)``: the full-resolution
    DoubleConv output (used as the skip connection) and its max-pooled version
    (fed to the next, deeper stage).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = DoubleConv(in_channels, out_channels)
        # stride defaults to kernel_size, so this halves H and W (flooring odd sizes).
        self.pool = nn.MaxPool2d(kernel_size=2)

    def forward(self, x):
        features = self.conv(x)
        return features, self.pool(features)

In [13]:
class UpSample(nn.Module):
    """Decoder stage: 2x transposed-conv upsample, concat with a skip tensor, DoubleConv.

    Args:
        in_channels: channels of the incoming (deeper) feature map. The
            transposed conv halves this to ``in_channels // 2``; the skip tensor
            is expected to supply the remaining channels so the concatenation
            has ``in_channels`` channels, as the DoubleConv requires.
        out_channels: channels produced by the trailing DoubleConv.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, align it spatially to ``x2``, concatenate, and convolve.

        Args:
            x1: deeper feature map, layout (..., C, H, W).
            x2: skip-connection feature map from the encoder, layout (..., C_skip, H2, W2).

        Returns:
            The DoubleConv output, with ``x2``'s spatial size (H2, W2).
        """
        x1 = self.up(x1)
        # If an encoder input had an odd height/width, MaxPool2d floored it and the
        # 2x upsample comes back smaller than the skip tensor, which would make
        # torch.cat fail. Pad x1 (or crop it, for a negative difference) to x2's
        # spatial size. When the sizes already match, this is skipped.
        diff_h = x2.size(-2) - x1.size(-2)
        diff_w = x2.size(-1) - x1.size(-1)
        if diff_h or diff_w:
            x1 = nn.functional.pad(
                x1,
                [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
            )
        # dim=-3 is the channel axis for both batched (N, C, H, W) and unbatched
        # (C, H, W) inputs; for 4-D tensors it is the same as dim=1.
        x = torch.cat([x1, x2], dim=-3)
        return self.conv(x)

In [14]:
class UNet(nn.Module):
    """U-Net: four encoder stages, a bottleneck, and four decoder stages with skips.

    Channel widths double at each encoder stage starting from ``base_channels``
    (c, 2c, 4c, 8c), with a 16c bottleneck; the decoder mirrors this back to c,
    and a final 1x1 conv maps c channels to ``num_classes``.

    Input height/width divisible by 16 (four 2x poolings) keep every skip
    connection exactly aligned with its upsampled counterpart.

    Args:
        in_channels: channels of the input image.
        num_classes: channels of the output map.
        base_channels: width of the first encoder stage. The default of 64
            reproduces the original hard-coded layout (64 ... 1024).
    """

    def __init__(self, in_channels, num_classes, base_channels=64):
        super().__init__()
        c = base_channels
        # Submodule names and registration order are unchanged so state_dict
        # keys and parameter-init RNG order match the original layout.
        self.downsample1 = DownSample(in_channels, c)
        self.downsample2 = DownSample(c, c * 2)
        self.downsample3 = DownSample(c * 2, c * 4)
        self.downsample4 = DownSample(c * 4, c * 8)

        self.bottleneck = DoubleConv(c * 8, c * 16)

        self.upsample1 = UpSample(c * 16, c * 8)
        self.upsample2 = UpSample(c * 8, c * 4)
        self.upsample3 = UpSample(c * 4, c * 2)
        self.upsample4 = UpSample(c * 2, c)

        self.out = nn.Conv2d(c, num_classes, kernel_size=1)

    def forward(self, x):
        """Return per-pixel class scores (no softmax/sigmoid is applied)."""
        # Encoder: keep each pre-pool activation as the skip for its decoder stage.
        down1, p1 = self.downsample1(x)
        down2, p2 = self.downsample2(p1)
        down3, p3 = self.downsample3(p2)
        down4, p4 = self.downsample4(p3)

        b = self.bottleneck(p4)

        # Decoder: consume the skips deepest-first.
        up1 = self.upsample1(b, down4)
        up2 = self.upsample2(up1, down3)
        up3 = self.upsample3(up2, down2)
        up4 = self.upsample4(up3, down1)

        return self.out(up4)

In [15]:
# Smoke test: the output keeps the input's spatial size with num_classes channels.
torch.manual_seed(0)  # reproducible random input and weight init
input_image = torch.rand((1, 3, 512, 512))
model = UNet(3, 10)
# No gradients are needed for a shape check; skip building the autograd graph.
with torch.no_grad():
    output = model(input_image)
assert output.shape == (1, 10, 512, 512), output.shape
output.shape

torch.Size([1, 10, 512, 512])