In [88]:
# Source: https://github.com/mateuszbuda/brain-segmentation-pytorch/blob/master/unet.py

import torch

import torch.nn as nn

from collections import OrderedDict
from torchinfo import summary
from functools import reduce
from itertools import accumulate
from torchview import draw_graph

class UNet2d(nn.Module):
    """2D U-Net: a convolutional encoder/decoder with skip connections.

    The network has ``len(blocks)`` resolution levels.

    - Encoder: at each level a double-conv stage produces ``blocks[i]``
      channels, then 2x2 max pooling halves the spatial size.
    - Bottleneck: a double-conv stage widens the deepest level to
      ``blocks[-1] * 2`` channels.
    - Decoder: each level upsamples with a stride-2 transposed convolution to
      ``blocks[i]`` channels. It concatenates the matching encoder output
      along the channel dimension and fuses the result with another
      double-conv stage.
    - Output: a final 1x1 convolution maps ``blocks[0]`` channels to
      ``out_channels``.

    Submodules are registered as ``encoder{k}``, ``pool{k}``, ``upconv{k}``
    and ``decoder{k}`` (``k`` starting at 1), plus ``bottleneck`` and
    ``output``.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the output layer.
        blocks: Feature width of each level, shallowest first. Any non-empty
            sequence of positive ints is accepted. Widths do not need to
            double from one level to the next.

    Raises:
        ValueError: If ``blocks`` is empty.

    Note:
        Input height and width should be divisible by ``2 ** len(blocks)``.
        Otherwise max pooling floors an odd size, and the skip concatenation
        in ``forward`` fails on mismatched spatial shapes.
    """

    def __init__(self, in_channels=3, out_channels=1, blocks=(32, 64, 128, 256)):
        # The class is named UNet2d; the former super(UNet, self) raised NameError.
        super().__init__()

        # Private copy: avoids a shared mutable default and caller-side aliasing.
        blocks = list(blocks)
        if not blocks:
            raise ValueError("blocks must contain at least one feature width")
        self.blocks = blocks

        # Encoder Modules: each level consumes the previous level's width.
        in_features = in_channels
        for level, features in enumerate(blocks, start=1):
            setattr(
                self, f"encoder{level}",
                self._block(in_features, features, name=f"enc{level}"),
            )
            in_features = features

        # Pooling Layers: one parameter-free 2x2 max pool per level.
        for level in range(1, len(blocks) + 1):
            setattr(self, f"pool{level}", nn.MaxPool2d(kernel_size=2, stride=2))

        # Bottleneck Module
        self.bottleneck = self._block(blocks[-1], blocks[-1] * 2, name="bottleneck")

        # Upconv Layers: level k upsamples the output of the level below it.
        # That is decoder k+1, whose width is the next entry of `blocks`, or
        # the bottleneck (blocks[-1] * 2) for the deepest level. Hard-coding
        # `features * 2` only worked when every level doubled its width.
        up_in_widths = blocks[1:] + [blocks[-1] * 2]
        for level, (up_in, features) in enumerate(zip(up_in_widths, blocks), start=1):
            setattr(
                self, f"upconv{level}",
                nn.ConvTranspose2d(up_in, features, kernel_size=2, stride=2),
            )

        # Decoder Modules: input is the skip connection (`features` channels)
        # concatenated with the upconv output (`features` channels).
        for level, features in enumerate(blocks, start=1):
            setattr(
                self, f"decoder{level}",
                self._block(features * 2, features, name=f"dec{level}"),
            )

        # Output Layer: 1x1 conv projecting to the requested channel count.
        self.output = nn.Conv2d(
            in_channels=blocks[0], out_channels=out_channels, kernel_size=1,
        )

    def forward(self, x):
        """Run the U-Net.

        Args:
            x: Tensor of shape ``(N, in_channels, H, W)``.

        Returns:
            Tensor of shape ``(N, out_channels, H, W)``, provided ``H`` and
            ``W`` are divisible by ``2 ** len(self.blocks)``.
        """
        levels = len(self.blocks)

        # 1. Encoder leg: keep each level's pre-pooling output for the skips.
        skips = []
        for level in range(1, levels + 1):
            x = getattr(self, f"encoder{level}")(x)
            skips.append(x)
            x = getattr(self, f"pool{level}")(x)

        # 2. Bottleneck
        x = self.bottleneck(x)

        # 3. Decoder leg: deepest level first, consuming skips in reverse.
        for level in range(levels, 0, -1):
            x = getattr(self, f"upconv{level}")(x)
            x = torch.cat((skips[level - 1], x), dim=1)
            x = getattr(self, f"decoder{level}")(x)

        # 4. Output projection
        return self.output(x)

    @staticmethod
    def _block(in_channels, features, name):
        """Return a ``(Conv3x3 -> BatchNorm -> ReLU) x 2`` stage.

        Each layer name is prefixed with ``name`` (e.g. ``enc1conv1``), so the
        names show up in ``state_dict`` keys and model summaries.
        ``padding=1`` keeps the spatial size unchanged. The convolutions have
        no bias because a BatchNorm follows each one directly.
        """
        return nn.Sequential(
            OrderedDict(
                [
                    (
                        name + "conv1",
                        nn.Conv2d(
                            in_channels=in_channels,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm1", nn.BatchNorm2d(num_features=features)),
                    (name + "relu1", nn.ReLU(inplace=True)),
                    (
                        name + "conv2",
                        nn.Conv2d(
                            in_channels=features,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm2", nn.BatchNorm2d(num_features=features)),
                    (name + "relu2", nn.ReLU(inplace=True)),
                ]
            )
        )

# Build a UNet for single-channel input, then print a per-layer summary
# (output shape, parameter count, trainability) for a batch of sixteen
# 224x224 inputs laid out as (batch, channels, height, width).
model = UNet2d(in_channels=1)

summary(
    model,
    input_size=(16, 1, 224, 224),  # larger alternative: (4, 1, 512, 512)
    col_names=["output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
)

# Optional: visualise the computation graph with torchview instead.
# graph = draw_graph(model, input_size=(1, 1, 224, 224))  # option: expand_nested=True
# graph.visual_graph

Layer (type (var_name))                  Output Shape         Param #              Trainable
UNet (UNet)                              [16, 1, 224, 224]    --                   True
├─Sequential (encoder1)                  [16, 32, 224, 224]   --                   True
│    └─Conv2d (enc1conv1)                [16, 32, 224, 224]   288                  True
│    └─BatchNorm2d (enc1norm1)           [16, 32, 224, 224]   64                   True
│    └─ReLU (enc1relu1)                  [16, 32, 224, 224]   --                   --
│    └─Conv2d (enc1conv2)                [16, 32, 224, 224]   9,216                True
│    └─BatchNorm2d (enc1norm2)           [16, 32, 224, 224]   64                   True
│    └─ReLU (enc1relu2)                  [16, 32, 224, 224]   --                   --
├─MaxPool2d (pool1)                      [16, 32, 112, 112]   --                   --
├─Sequential (encoder2)                  [16, 64, 112, 112]   --                   True
│    └─Conv2d (enc2conv1)        