In [1]:
import torch
from torch import nn

In [2]:
def unet_conv(ic, oc, ks=3, stride=1, act=nn.SiLU, norm=None, bias=True):
    """Pre-activation conv unit: optional norm -> optional activation -> Conv2d.

    Args:
        ic: input channel count (also the size passed to ``norm``).
        oc: output channel count of the convolution.
        ks: square kernel size; padding is ``ks // 2``, so an odd kernel with
            stride 1 keeps the spatial size unchanged.
        stride: convolution stride.
        act: activation class, instantiated with no arguments; falsy to skip.
        norm: normalisation class, called as ``norm(ic)``; falsy to skip.
        bias: whether the convolution learns a bias term.

    Returns:
        nn.Sequential holding the selected layers in the order above.
    """
    modules = [norm(ic)] if norm else []
    if act:
        modules.append(act())
    same_pad = ks // 2
    modules.append(
        nn.Conv2d(ic, oc, kernel_size=ks, stride=stride, padding=same_pad, bias=bias)
    )
    return nn.Sequential(*modules)

In [3]:
class ResBlock(nn.Module):
    """Residual block: two stacked pre-activation conv units plus a shortcut.

    ``forward`` returns ``conv(x) + idconv(x)``. ``conv`` is two ``unet_conv``
    units (norm -> act -> conv). ``idconv`` is an identity when the channel
    count is unchanged, otherwise a 1x1 convolution projecting ``ic`` to ``oc``
    channels so the two branches can be summed.

    Args:
        ic: input channels.
        oc: output channels; defaults to ``ic`` when omitted.
        ks: kernel size of both conv units.
        act: activation class forwarded to ``unet_conv``.
        norm: normalisation class forwarded to ``unet_conv``.
    """
    def __init__(self, ic, oc=None, ks=3, act=nn.SiLU, norm=nn.BatchNorm2d):
        super().__init__()
        # The signature advertises oc as optional; previously None reached
        # nn.Conv2d and raised. A channel-preserving block is the natural default.
        if oc is None:
            oc = ic
        self.conv = nn.Sequential(
            unet_conv(ic, oc, ks, act=act, norm=norm),
            unet_conv(oc, oc, ks, act=act, norm=norm),
        )
        # Projection shortcut only when the residual sum would otherwise mismatch channels.
        self.idconv = nn.Identity() if ic == oc else nn.Conv2d(ic, oc, kernel_size=1)

    def forward(self, x):
        return self.conv(x) + self.idconv(x)

In [4]:
class UNET(nn.Module):
    """U-Net assembled from ``ResBlock`` stages.

    The encoder applies one ``ResBlock`` per entry of ``nf``, storing each
    result as a skip connection before 2x2 max-pooling. A bottleneck
    ``ResBlock`` doubles the deepest width. The decoder mirrors the encoder.
    Each decoder stage:

    - upsamples with a stride-2 transposed conv;
    - concatenates the matching skip on the channel axis (skip first);
    - fuses the result with a ``ResBlock``.

    A final 1x1 conv maps to ``out_channels``.

    Args:
        in_channels: channels of the input batch, expected as (N, C, H, W).
        out_channels: channels of the output map.
        nf: encoder stage widths, shallow to deep (any non-empty sequence).

    Raises:
        ValueError: if ``nf`` is empty.
    """
    def __init__(self, in_channels, out_channels, nf=(64, 128, 256, 512)):
        super().__init__()
        nf = tuple(nf)  # tuple default avoids a shared mutable default; accepts lists too
        if not nf:
            raise ValueError("nf must contain at least one stage width")

        self.encoder = nn.ModuleList()
        ch = in_channels
        for f in nf:
            self.encoder.append(ResBlock(ch, f))
            ch = f

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottle_neck = ResBlock(nf[-1], nf[-1] * 2)

        # Layout kept as [up_0, block_0, up_1, block_1, ...] for state_dict compatibility.
        self.decoder = nn.ModuleList()
        ch = nf[-1] * 2  # width coming out of the bottleneck
        for f in reversed(nf):
            # Use the actual incoming width rather than assuming f * 2, so nf
            # does not have to double from stage to stage.
            self.decoder.append(nn.ConvTranspose2d(ch, f, kernel_size=2, stride=2))
            # After concatenation: f (skip) + f (upsampled) channels.
            self.decoder.append(ResBlock(f * 2, f))
            ch = f

        self.final_conv = nn.Conv2d(nf[0], out_channels, kernel_size=1)

    @staticmethod
    def _match_size(x, ref):
        """Zero-pad ``x`` on the bottom/right so its H, W equal those of ``ref``.

        MaxPool2d floors odd sizes, so an upsampled decoder map can be one
        pixel smaller than the skip it is concatenated with.
        """
        dh = ref.shape[-2] - x.shape[-2]
        dw = ref.shape[-1] - x.shape[-1]
        if dh or dw:
            x = nn.functional.pad(x, (0, dw, 0, dh))
        return x

    def forward(self, x):
        skips = []
        for enc in self.encoder:
            x = enc(x)
            skips.append(x)
            x = self.pool(x)

        x = self.bottle_neck(x)

        ups, blocks = self.decoder[0::2], self.decoder[1::2]
        for up, block, skip in zip(ups, blocks, reversed(skips)):
            x = self._match_size(up(x), skip)
            x = torch.cat((skip, x), dim=1)
            x = block(x)

        return self.final_conv(x)

In [5]:
# Smoke test: a random 1-channel batch should come back with the same batch
# and spatial size, and out_channels (= 1) channels.
torch.manual_seed(0)  # reproducible weights and input across Restart & Run All
model = UNET(1, 1)
x = torch.randn((3, 1, 160, 160))
pred = model(x)
print(x.shape)
print(pred.shape)
assert pred.shape == (x.shape[0], 1, *x.shape[2:]), pred.shape

torch.Size([3, 1, 160, 160])
torch.Size([3, 1, 160, 160])


In [6]:
# Display the module tree to check layer order and channel widths.
model

UNET(
  (encoder): ModuleList(
    (0): ResBlock(
      (conv): Sequential(
        (0): Sequential(
          (0): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1): SiLU()
          (2): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (1): Sequential(
          (0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1): SiLU()
          (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (idconv): Conv2d(1, 64, kernel_size=(1, 1), stride=(1, 1))
    )
    (1): ResBlock(
      (conv): Sequential(
        (0): Sequential(
          (0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1): SiLU()
          (2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (1): Sequential(
          (0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_r