# Nested U-Net (UNet++)

In [1]:
import torch
from torch import nn

In [2]:
class VGGBlock(nn.Module):
    """Two stacked ``Conv2d(3x3) -> BatchNorm2d -> ReLU`` stages.

    The first stage maps ``in_channels`` to ``middle_channels`` and the second
    maps ``middle_channels`` to ``out_channels``. Both convolutions use a
    3x3 kernel with ``padding=1`` (stride 1), so height and width of the
    input are preserved.
    """

    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        # A single parameter-free ReLU is reused after both stages.
        # Submodules are created in this exact order so parameter names and
        # random initialisation match the original layout.
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_channels, middle_channels,
                               kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(middle_channels)
        self.conv2 = nn.Conv2d(middle_channels, out_channels,
                               kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Run ``x`` through both conv/norm/ReLU stages and return the result."""
        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2))
        for conv, norm in stages:
            # The in-place ReLU acts on the fresh batch-norm output tensor,
            # never on the caller's input.
            x = self.relu(norm(conv(x)))
        return x
    
# Shared parameter-free resampling layers used between nested U-Net levels:
# ``pool`` halves height/width (2x2 max pooling, stride 2) and ``up`` doubles
# them with bilinear interpolation (corner pixels aligned).
pool = nn.MaxPool2d(kernel_size=2, stride=2)
up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

In [3]:
# Shape walk-through of a 3-level nested U-Net (UNet++) on one random
# 3-channel 32x32 input. Node ``xI_J`` sits at depth I (I applications of
# ``pool`` below the top) and is the J-th node along that depth's skip path.
# VGGBlock keeps height/width, so only ``pool``/``up`` change spatial size.
# All blocks are freshly (randomly) initialised: only the shapes matter here.
input_channels = 3
input = torch.rand(1, input_channels, 32, 32)  # (1, 3, 32, 32)

nb_filter = [3, 6, 12]  # output channels of the nodes at depth 0, 1, 2

# Column 0 (encoder backbone), depths 0 and 1.
conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
x0_0 = conv0_0(input)  # (1, 3, 32, 32)
conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
x1_0 = conv1_0(pool(x0_0))  # (1, 6, 16, 16)

# Nested node: same-depth features concatenated on the channel axis with
# the upsampled output of the node one level deeper.
conv0_1 = VGGBlock(nb_filter[0] + nb_filter[1], nb_filter[0], nb_filter[0])
x0_1 = conv0_1(torch.cat((x0_0, up(x1_0)), dim=1))  # (1, 3, 32, 32)

conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
x2_0 = conv2_0(pool(x1_0))  # (1, 12, 8, 8)
conv1_1 = VGGBlock(nb_filter[1] + nb_filter[2], nb_filter[1], nb_filter[1])
x1_1 = conv1_1(torch.cat((x1_0, up(x2_0)), dim=1))  # (1, 6, 16, 16)

# x0_2 receives every earlier depth-0 node (x0_0 and x0_1) plus the
# upsampled x1_1, hence 2 * nb_filter[0] + nb_filter[1] input channels.
conv0_2 = VGGBlock(2 * nb_filter[0] + nb_filter[1],
                   nb_filter[0], nb_filter[0])
x0_2 = conv0_2(torch.cat((x0_0, x0_1, up(x1_1)), dim=1))  # (1, 3, 32, 32)

# Deep supervision: a separate 1x1 conv head on each nested depth-0 node.
num_classes = 2
final1 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
output1 = final1(x0_1)  # (1, 2, 32, 32)
final2 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
output2 = final2(x0_2)  # (1, 2, 32, 32)

