How to build a U-Net, functional-style. The idea is to store the encoder outputs in a tuple:
- Encoder layer: transform the last element of the tuple, then duplicate it and downsample one copy
- Bridge layer: transform the last element of the tuple
- Decoder layer: upsample the last element of the tuple, add it to the second-to-last element, and transform the result

In [2]:
from torch import nn
import torch

from pytorch_toolz.functools import Reduce, Parallel
from pytorch_toolz.itertools import Slice, Chain
from pytorch_toolz.operator import ItemGetter, Apply

In [34]:
def conv_block(in_ch, out_ch):
    """Return a BatchNorm -> ReLU -> 3x3 convolution unit.

    Args:
        in_ch: Number of channels of the incoming feature map.
        out_ch: Number of channels produced by the convolution.

    Returns:
        An ``nn.Sequential`` holding the three layers in that order.
    """
    layers = [
        nn.BatchNorm2d(in_ch),
        nn.ReLU(),
        # 'same' padding keeps the spatial size unchanged for the 3x3 kernel.
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding='same'),
    ]
    return nn.Sequential(*layers)


def downsampling(in_ch, out_ch):
    """Return a 2x2, stride-2 convolution mapping ``in_ch`` to ``out_ch`` channels.

    With kernel size and stride both 2, the layer halves each (even) spatial
    dimension.
    """
    halving = dict(kernel_size=2, stride=2)
    return nn.Conv2d(in_ch, out_ch, **halving)


def upsampling(in_ch, out_ch):
    """Return a 2x2, stride-2 transposed convolution from ``in_ch`` to ``out_ch``.

    With kernel size and stride both 2, the layer doubles each spatial
    dimension.
    """
    doubling = dict(kernel_size=2, stride=2)
    return nn.ConvTranspose2d(in_ch, out_ch, **doubling)


def res_block(in_ch, out_ch):
    """Return a residual block: two conv blocks summed with the untouched input.

    The transform branch squeezes to ``in_ch // 4`` channels and then expands
    to ``out_ch``. Its result is combined with the identity branch via
    ``torch.add``, so ``out_ch`` must be addable to ``in_ch`` channels
    (every call site in this file passes ``out_ch == in_ch``).

    Args:
        in_ch: Channels of the incoming feature map.
        out_ch: Channels produced by the transform branch.
    """
    squeezed_ch = in_ch // 4
    transform_branch = nn.Sequential(
        conv_block(in_ch, squeezed_ch),
        conv_block(squeezed_ch, out_ch),
    )
    skip_branch = nn.Identity()

    # NOTE(review): Parallel presumably feeds the same input to both branches
    # and Reduce folds the resulting pair with torch.add -- confirm against
    # pytorch_toolz.
    both_branches = Parallel(transform_branch, skip_branch)
    return nn.Sequential(both_branches, Reduce(torch.add))


def encoder_layer(in_ch, out_ch):
    """Return one encoder stage operating on a tuple of feature maps.

    The stage transforms the last element of the tuple with three residual
    blocks, keeps that result, and appends a downsampled copy of it.

    Args:
        in_ch: Channels of the last tuple element (kept at this width).
        out_ch: Channels of the appended, downsampled copy.

    Returns:
        An ``nn.Sequential`` implementing the stage.
    """
    return nn.Sequential(
        Parallel(
            # NOTE(review): presumably passes through every element except
            # the last one -- confirm Slice semantics in pytorch_toolz.
            Slice(-1),
            nn.Sequential(
                ItemGetter(-1),
                # Build three *distinct* residual blocks. `[block] * 3` would
                # repeat a single module instance, so all three positions
                # would share the same weights.
                *(res_block(in_ch, in_ch) for _ in range(3)),
                Parallel(
                    nn.Identity(),
                    downsampling(in_ch, out_ch)
                )
            ),
        ),
        # NOTE(review): presumably flattens the two branch outputs back into
        # one flat tuple -- confirm Chain semantics in pytorch_toolz.
        Chain()
    )


def bridge_block(ch):
    """Return a bridge stage: transform the last element of the tuple in place.

    Args:
        ch: Channel count of the last tuple element (unchanged by the stage).

    Returns:
        An ``nn.Sequential`` implementing the stage.
    """
    untouched = Slice(-1)
    transformed = nn.Sequential(
        ItemGetter(-1),
        res_block(ch, ch),
        # Single-branch Parallel; presumably re-wraps the tensor in a 1-tuple
        # so Chain can merge it with the untouched elements -- TODO confirm.
        Parallel(nn.Identity()),
    )
    return nn.Sequential(Parallel(untouched, transformed), Chain())


def decoder_layer(in_ch, out_ch):
    """Return one decoder stage operating on a tuple of feature maps.

    The stage upsamples the last element of the tuple, adds it to the
    second-to-last element, and transforms the sum with three residual
    blocks. The two consumed elements are replaced by that single result.

    Args:
        in_ch: Channels of the last tuple element (the one being upsampled).
        out_ch: Channels of the second-to-last element and of the result.

    Returns:
        An ``nn.Sequential`` implementing the stage.
    """
    return nn.Sequential(
        Parallel(
            # NOTE(review): presumably passes through every element except
            # the last two -- confirm Slice semantics in pytorch_toolz.
            Slice(-2),
            nn.Sequential(
                Parallel(
                    nn.Sequential(
                        ItemGetter(-1),
                        upsampling(in_ch, out_ch)
                    ),
                    ItemGetter(-2)
                ),
                Reduce(torch.add),
                # Build three *distinct* residual blocks. `[block] * 3` would
                # repeat a single module instance, so all three positions
                # would share the same weights.
                *(res_block(out_ch, out_ch) for _ in range(3)),
                Parallel(nn.Identity())
            ),
        ),
        Chain(),
    )

# The network is a flat pipeline over a tuple of feature maps: encoder stages
# grow the tuple, bridge stages rewrite its last element, and decoder stages
# shrink it again.
_unet_stages = [
    Parallel(nn.Identity()),  # presumably wraps the input tensor in a 1-tuple
    encoder_layer(8, 16),
    encoder_layer(16, 32),
    bridge_block(32),
    bridge_block(32),
    decoder_layer(32, 16),
    decoder_layer(16, 8),
    ItemGetter(0),  # pick the first element as the network output
]
unet = nn.Sequential(*_unet_stages)

Basic test, to check that shapes match.

In [35]:
# Push one random 8-channel 16x16 image through the network and show the
# output shape.
_probe = torch.randn((1, 8, 16, 16))
unet(_probe).shape

torch.Size([1, 8, 16, 16])