In [1]:
import torch
from torch import nn

In [2]:
class DoubleConv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by an in-place ReLU.

    Both convolutions use ``padding=1``, so the spatial size of the input is
    preserved; only the channel count changes.

    Args:
        in_channels: number of channels consumed by the first convolution.
        out_channels: number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # (input width, output width) of the first and the second convolution.
        for c_in, c_out in ((in_channels, out_channels), (out_channels, out_channels)):
            stages.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
            stages.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through conv -> ReLU -> conv -> ReLU and return the result."""
        features = self.conv(x)
        return features

In [3]:
class UNET(nn.Module):
    """U-Net style encoder/decoder assembled from ``DoubleConv`` stages.

    The encoder applies one ``DoubleConv`` per entry of ``nf``, each followed
    by a 2x2 max-pool. A bottleneck ``DoubleConv`` doubles the last width.
    The decoder then walks ``nf`` in reverse: a stride-2 transposed
    convolution upsamples, the matching encoder output (skip connection) is
    concatenated on the channel axis, and a ``DoubleConv`` fuses the two.
    A final 1x1 convolution maps ``nf[0]`` channels to ``out_channels``.

    Args:
        in_channels: channel count of the input fed to the first encoder stage.
        out_channels: channel count produced by the final 1x1 convolution.
        nf: feature widths, one per encoder level (any non-empty sequence).

    Raises:
        ValueError: if ``nf`` is empty.
    """

    def __init__(self, in_channels, out_channels, nf=(64, 128, 256, 512)):
        super().__init__()
        # Work on a private copy: no shared mutable default, and any sequence
        # type (list, tuple, ...) is accepted.
        nf = list(nf)
        if not nf:
            raise ValueError("nf must contain at least one feature width")

        self.encoder = nn.ModuleList()
        for f in nf:
            self.encoder.append(DoubleConv(in_channels, f))
            in_channels = f

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottle_neck = DoubleConv(nf[-1], nf[-1] * 2)

        # The decoder stores modules in pairs: [upsample, fuse, upsample, fuse, ...].
        self.decoder = nn.ModuleList()
        for f in reversed(nf):
            self.decoder.append(nn.ConvTranspose2d(f * 2, f, kernel_size=2, stride=2))
            self.decoder.append(DoubleConv(f * 2, f))

        self.final_conv = nn.Conv2d(nf[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Encode, bottleneck, decode with skip connections, then project.

        Args:
            x: batched 4-D input whose channel axis (dim 1) has ``in_channels``
                entries.

        Returns:
            The output of the final 1x1 convolution (``out_channels`` channels).
        """
        skips = []
        for enc in self.encoder:
            x = enc(x)
            skips.append(x)
            x = self.pool(x)

        x = self.bottle_neck(x)

        # Deepest skip first, matching the order of the decoder pairs.
        for level, skip in enumerate(reversed(skips)):
            upsample = self.decoder[2 * level]
            fuse = self.decoder[2 * level + 1]

            x = upsample(x)
            # Max-pooling floors odd sizes, so doubling can land one pixel
            # short of the skip tensor; resize so the concat below is valid.
            if x.shape[-2:] != skip.shape[-2:]:
                x = nn.functional.interpolate(x, size=skip.shape[-2:], mode="nearest")
            x = torch.cat((skip, x), dim=1)
            x = fuse(x)

        return self.final_conv(x)

In [4]:
# Smoke test: push one random batch through the network and confirm that the
# prediction comes back with the same shape as the input.
torch.manual_seed(0)  # fixed seed so both the initial weights and the batch are reproducible
model = UNET(1, 1)
x = torch.randn((3, 1, 160, 160))
pred = model(x)
print(x.shape)
print(pred.shape)
# in_channels == out_channels == 1 here, so the two shapes must match exactly.
assert pred.shape == x.shape, f"expected {tuple(x.shape)}, got {tuple(pred.shape)}"

torch.Size([3, 1, 160, 160])
torch.Size([3, 1, 160, 160])
