In [1]:
import torch
import torchvision.transforms.functional
from torch import nn

In [2]:
class DoubleConvolution(nn.Module):
    """Two successive 3x3 convolutions, each followed by a ReLU.

    Both convolutions use ``padding=1``, so the spatial size of the input is
    preserved; only the channel count changes (``inChannels`` ->
    ``outChannels`` -> ``outChannels``).

    Args:
        inChannels: Number of channels of the incoming feature map.
        outChannels: Number of channels produced by both convolutions.
    """

    def __init__(self, inChannels: int, outChannels: int):
        # `super()` must be *called* to obtain the bound proxy; the bare
        # `super.__init__()` raises TypeError and leaves nn.Module
        # uninitialised, so sub-modules could not be registered.
        super().__init__()

        self.first = nn.Conv2d(inChannels, outChannels, kernel_size=3, padding=1)
        self.act1 = nn.ReLU()

        self.second = nn.Conv2d(outChannels, outChannels, kernel_size=3, padding=1)
        self.act2 = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply conv -> ReLU -> conv -> ReLU and return the result."""
        x = self.first(x)
        x = self.act1(x)
        x = self.second(x)
        return self.act2(x)


In [3]:
class DownSample(nn.Module):
    """Down-sampling step: 2x2 max pooling."""

    def __init__(self):
        # Call `super()` to get the bound proxy; `super.__init__()` raises
        # TypeError before nn.Module is set up.
        super().__init__()

        self.pool = nn.MaxPool2d(2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` max-pooled with a 2x2 window."""
        return self.pool(x)

In [5]:
class UpSample(nn.Module):
    """Up-sampling step: a 2x2 transposed convolution with stride 2.

    With ``kernel_size=2`` and ``stride=2`` the spatial size of the input is
    doubled while the channel count goes from ``inChannels`` to
    ``outChannels``.

    Args:
        inChannels: Number of channels of the incoming feature map.
        outChannels: Number of channels of the up-sampled feature map.
    """

    def __init__(self, inChannels: int, outChannels: int):
        # Call `super()` to get the bound proxy; `super.__init__()` raises
        # TypeError before nn.Module is set up.
        super().__init__()

        self.up = nn.ConvTranspose2d(inChannels, outChannels, kernel_size=2, stride=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` passed through the transposed convolution."""
        return self.up(x)

In [7]:
class CropAndConcat(nn.Module):
    """Merge a skip-connection feature map into the current feature map.

    The skip tensor is centre-cropped to the height and width of ``x``
    (dimensions 2 and 3) and then concatenated to ``x`` along dimension 1.
    """

    def forward(self, x: torch.Tensor, contracting_x: torch.Tensor):
        """Return ``x`` concatenated with the centre-cropped ``contracting_x``.

        Args:
            x: Current feature map; its dims 2 and 3 give the crop size.
            contracting_x: Feature map saved earlier, cropped to match ``x``.
        """
        height, width = x.shape[2], x.shape[3]
        cropped_skip = torchvision.transforms.functional.center_crop(
            contracting_x, [height, width]
        )
        return torch.cat([x, cropped_skip], dim=1)

In [None]:
class UNet(nn.Module):
    """U-Net: a four-level encoder, a bottleneck, and a four-level decoder.

    Encoder stages widen the channels ``inChannels -> 64 -> 128 -> 256 -> 512``
    and each is followed by a 2x2 max pool. The bottleneck widens to 1024
    channels. Every decoder stage up-samples (halving the channels), joins the
    result with the encoder feature map of the same depth, and applies a
    double convolution. A final 1x1 convolution maps 64 channels to
    ``outChannels``.

    Args:
        inChannels: Number of channels of the input image.
        outChannels: Number of channels of the output map.
    """

    def __init__(self, inChannels: int, outChannels: int):
        # Call `super()` to get the bound proxy; `super.__init__()` raises
        # TypeError before nn.Module is set up.
        super().__init__()

        # Each stage must emit `o` channels (the original passed the literal 0).
        self.downConv = nn.ModuleList([DoubleConvolution(i, o) for i, o in
                                       [(inChannels, 64), (64, 128), (128, 256), (256, 512)]])

        self.downSample = nn.ModuleList([DownSample() for _ in range(4)])

        self.middleConv = DoubleConvolution(512, 1024)

        self.upSample = nn.ModuleList([UpSample(i, o) for i, o in
                                       [(1024, 512), (512, 256), (256, 128), (128, 64)]])

        # Input width is doubled relative to the up-sample output because the
        # skip connection is concatenated on the channel axis first.
        self.upConv = nn.ModuleList([DoubleConvolution(i, o) for i, o in
                                     [(1024, 512), (512, 256), (256, 128), (128, 64)]])

        self.concat = nn.ModuleList([CropAndConcat() for _ in range(4)])

        self.finalConv = nn.Conv2d(64, outChannels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the encoder, bottleneck and decoder, and return the final map.

        Because of the four 2x2 poolings, inputs whose height or width is not
        a multiple of 16 produce an output that is spatially smaller than the
        input; the skip connections are centre-cropped to compensate.
        """
        passThrough = []

        # Contracting path: keep each pre-pooling feature map for the decoder.
        for conv, pool in zip(self.downConv, self.downSample):
            x = conv(x)
            passThrough.append(x)
            x = pool(x)

        x = self.middleConv(x)

        # Expanding path: skips are consumed deepest-first, so pop from the
        # end. (Indexing with `-i` picked the shallowest skip at i == 0.)
        for up, concat, conv in zip(self.upSample, self.concat, self.upConv):
            x = up(x)
            x = concat(x, passThrough.pop())
            x = conv(x)

        return self.finalConv(x)