In [1]:
import torch
from torch import nn, optim, tensor
import torch.nn.functional as F
import torchvision.transforms.functional as TF

In [2]:
class DoubleConvolution(nn.Module):
    def __init__(self, n_in, n_out):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(n_in, n_out, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Conv2d(n_out, n_out, kernel_size=3, stride=1),
            nn.ReLU()
        )
        
    def forward(self, x):
        return self.net(x)

In [3]:
class Downsample(nn.Module):
    def __init__(self, scale=2):
        super().__init__()
        self.pool = nn.MaxPool2d(scale)
        
    def forward(self, x):
        return self.pool(x)

In [4]:
class Upsample(nn.Module):
    def __init__(self, n_in, n_out, scale=2):
        super().__init__()
        self.up_pool = nn.UpsamplingNearest2d(scale_factor=scale)
        self.conv = nn.Conv2d(n_in, n_out, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        
    def forward(self, x):
        return self.relu(self.conv(self.up_pool(x)))

In [5]:
class CropAndCat(nn.Module):
    def forward(self, x, proj_tns):
        proj = TF.center_crop(proj_tns, [x.shape[2], x.shape[3]])
        out = torch.cat([x, proj], dim=1)
        return out

In [30]:
channels = [64, 128, 256, 512, 1024]
class UNet(nn.Module):
    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.first = DoubleConvolution(input_channels, channels[0])
        self.down_double_convs = nn.ModuleList([DoubleConvolution(channels[i], channels[i+1]) for i in range(len(channels)-1)])
        self.down_samples = nn.ModuleList([Downsample() for i in range(4)])
        
        self.up_double_convs = nn.ModuleList([DoubleConvolution(channels[i], channels[i-1]) for i in range(-1, -len(channels), -1)])
        self.up_samples = nn.ModuleList([Upsample(channels[i], channels[i-1]) for i in range(-1, -len(channels), -1)])
        self.crop_and_cat = nn.ModuleList([CropAndCat() for i in range(4)])
        
        self.last = nn.Conv2d(input_channels, output_channels, kernel_size=1, stride=1)
        
    def forward(self, x):
        x = self.first(x)
        
        pass_through = []
        
        print(x.shape)
        pass_through.append(x)
        for i, mod in enumerate(self.down_double_convs):
            x = self.down_samples[i](x)
            x = mod(x)
            print(x.shape)
            pass_through.append(x)

        pass_through.pop() # get rid of the last projection since its not needed
        for i, mod in enumerate(self.up_double_convs):
            x = self.up_samples[i](x)
            
            proj = pass_through.pop()
            x = self.crop_and_cat[i](x, proj)
            x = mod(x)
        
        return x

In [31]:
a = torch.randn(2, 1, 572, 572)
unet = UNet(1, 1)

In [32]:
unet(a).shape

torch.Size([2, 64, 568, 568])
torch.Size([2, 128, 280, 280])
torch.Size([2, 256, 136, 136])
torch.Size([2, 512, 64, 64])
torch.Size([2, 1024, 28, 28])


torch.Size([2, 64, 388, 388])

In [10]:
unet

UNet(
  (first): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1))
  (down_double_convs): ModuleList(
    (0): DoubleConvolution(
      (net): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
        (1): ReLU()
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
        (3): ReLU()
      )
    )
    (1): DoubleConvolution(
      (net): Sequential(
        (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
        (1): ReLU()
        (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
        (3): ReLU()
      )
    )
    (2): DoubleConvolution(
      (net): Sequential(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
        (1): ReLU()
        (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
        (3): ReLU()
      )
    )
    (3): DoubleConvolution(
      (net): Sequential(
        (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1))
        (1): ReLU()
        (2): Conv2d(512, 512, kernel_size=(3, 3), stri