In [19]:
import torch.nn as nn

from collections import OrderedDict

class ResidualAdd(nn.Module):
    """Wrap ``block`` with a residual (skip) connection followed by a ReLU.

    Computes ``relu(shortcut(x) + block(x))``. The shortcut is the identity when
    the block keeps both the channel count and the stride at 1; otherwise it is a
    1x1 convolution (using the block's stride) followed by batch norm, so that the
    two branches have matching shapes and can be summed.

    Args:
        block: Module producing the residual branch. It is expected to map
            ``in_channels`` -> ``out_channels`` and to downsample by ``stride``.
        in_channels: Channel count of the input ``x``.
        out_channels: Channel count produced by ``block``.
        stride: Stride applied by ``block`` (int or per-axis tuple); reused by
            the projection shortcut.
        layer_fns: Mapping providing at least ``'Conv'`` and ``'BatchNorm'``
            layer factories (e.g. ``nn.Conv2d``/``nn.BatchNorm2d`` or their 3D
            counterparts) used to build the projection shortcut.
        name: Prefix used to name the shortcut's sub-layers.
    """

    def __init__(self, block, in_channels, out_channels, stride, layer_fns, name):
        super().__init__()

        self.block = block
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.layer_fns = layer_fns

        # A projection is needed whenever the residual branch changes the channel
        # count OR the spatial size. Checking channels alone would leave a
        # strided, channel-preserving block with an identity shortcut whose shape
        # no longer matches the block output.
        strides = stride if isinstance(stride, (tuple, list)) else (stride,)
        needs_projection = in_channels != out_channels or any(s != 1 for s in strides)

        if needs_projection:
            self.downsample = nn.Sequential(OrderedDict([
                (
                    name + '_residual_conv',
                    layer_fns['Conv'](
                        in_channels=in_channels,
                        out_channels=out_channels,
                        kernel_size=1,
                        stride=stride,
                    ),
                ),
                (
                    name + '_residual_bn',
                    layer_fns['BatchNorm'](out_channels),
                ),
            ]))
        else:
            # Identity shortcut. Stored explicitly (as a plain attribute, not a
            # submodule) so forward() does not need to probe with hasattr().
            self.downsample = None

        # In-place is safe: it is applied to the freshly allocated sum in forward().
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(shortcut(x) + block(x))``."""
        residual = self.block(x)
        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(shortcut + residual)

In [20]:
import torch

import torch.nn as nn

from collections import OrderedDict
from torchinfo import summary
from torchvision import ops as vision_ops

class ResNext(nn.Module):
    """ResNeXt-style network built from grouped-convolution bottlenecks.

    Architecture: a strided 7x7 conv stem plus max-pool, then one stage per
    entry of ``blocks`` (each a chain of residual bottlenecks), global average
    pooling, and a linear output layer. The factories in ``layer_fns`` choose
    the dimensionality: the defaults build a 2D network, while passing the
    ``*3d`` equivalents builds a 3D one.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Number of outputs of the final linear layer.
        blocks: One ``(features, repeats, stride)`` tuple per stage. A stage
            stacks ``repeats`` bottlenecks with inner width ``features`` and
            ``expansion * features`` output channels; only the first bottleneck
            of a stage applies ``stride``. A lighter one-repeat-per-stage
            configuration is ``[(128, 1, 1), (256, 1, 2), (512, 1, 2), (1024, 1, 2)]``.
        layer_fns: Layer factories keyed ``'ConvNormActivation'``, ``'Conv'``,
            ``'MaxPool'``, ``'AvgPool'`` and ``'BatchNorm'``. A ``'ReLU'`` entry
            is accepted but not used by this class.
        cardinality: Number of groups in every bottleneck's 3x3 convolution;
            each ``features`` value must be divisible by it.

    Raises:
        ValueError: If ``blocks`` is empty.
    """

    # Each bottleneck outputs ``expansion * features`` channels.
    expansion = 2

    def __init__(
        self,
        in_channels=3,
        out_channels=1,
        # A tuple of (features, repeats, stride)
        blocks=[(128, 3, 1), (256, 4, 2), (512, 6, 2), (1024, 3, 2)],
        layer_fns=dict(
            ConvNormActivation=vision_ops.Conv2dNormActivation,
            Conv=nn.Conv2d,
            MaxPool=nn.MaxPool2d,
            AvgPool=nn.AdaptiveAvgPool2d,
            BatchNorm=nn.BatchNorm2d,
            ReLU=nn.ReLU,
        ),
        cardinality=32,
    ):
        super().__init__()

        if not blocks:
            raise ValueError('`blocks` must describe at least one stage.')

        # Copy the configuration: the defaults above are single shared objects,
        # and storing them by reference would let one instance mutate the
        # defaults seen by every later instance.
        self.blocks = list(blocks)
        self.layer_fns = dict(layer_fns)
        self.cardinality = cardinality

        stem_channels = 64
        final_channels = self.expansion * self.blocks[-1][0]

        # Stem: 7x7 convolution (stride 2) followed by a 3x3 max-pool (stride 2).
        self.block_0 = self.layer_fns['ConvNormActivation'](
            in_channels=in_channels,
            out_channels=stem_channels,
            kernel_size=7,
            padding=3,
            stride=2,
            bias=False,
        )
        self.pool_0 = self.layer_fns['MaxPool'](kernel_size=3, stride=2, padding=1)

        # Stages are registered as block_1 ... block_N; forward() looks them up
        # by the same names.
        block_input = stem_channels
        for block_id, (features, repeats, stride) in enumerate(self.blocks, start=1):
            name = f'block_{block_id}'
            setattr(
                self,
                name,
                self._repeated_block(block_input, features, repeats, stride, name=name),
            )
            block_input = self.expansion * features

        # Adaptive pooling to size 1 collapses all spatial dimensions.
        self.global_pool = self.layer_fns['AvgPool'](1)
        self.output = nn.Linear(
            in_features=final_channels,
            out_features=out_channels,
        )

    def forward(self, x):
        """Run stem, all stages, global pooling and the linear head."""
        x = self.pool_0(self.block_0(x))

        for block_id in range(1, len(self.blocks) + 1):
            x = getattr(self, f'block_{block_id}')(x)

        x = self.global_pool(x)
        # Flatten everything except the batch dimension for the linear layer.
        return self.output(x.flatten(1))

    def _repeated_block(self, in_channels, features, repeats, stride, name, start=1, bias=False):
        """Build one stage: ``repeats`` bottlenecks chained in an ``nn.Sequential``.

        Only the first bottleneck receives ``in_channels`` and ``stride``; the
        following ones take ``expansion * features`` channels and use stride 1.

        Args:
            in_channels: Channels entering the stage.
            features: Inner width of each bottleneck.
            repeats: Number of bottlenecks in the stage.
            stride: Stride of the first bottleneck.
            name: Stage name; bottlenecks are named ``f'{name}_{index}'``.
            start: Index of the first bottleneck, used only for naming.
            bias: Forwarded to every convolution inside the bottlenecks.
        """
        layers = []
        for offset in range(repeats):
            repeat_name = f'{name}_{start + offset}'
            is_first = offset == 0
            layers.append((
                repeat_name,
                self._block(
                    in_channels if is_first else self.expansion * features,
                    features,
                    stride=stride if is_first else 1,
                    name=repeat_name,
                    bias=bias,
                ),
            ))

        return nn.Sequential(OrderedDict(layers))

    def _block(self, in_channels, features, stride, name, bias=False):
        """Build one bottleneck (1x1 -> grouped 3x3 -> 1x1) with a residual add.

        Args:
            in_channels: Channels entering the bottleneck.
            features: Inner width; the output has ``expansion * features`` channels.
            stride: Stride of the grouped 3x3 convolution (and of the shortcut).
            name: Prefix for the sub-layer names.
            bias: Whether the convolutions use a bias term.
        """
        out_features = self.expansion * features
        return ResidualAdd(
            nn.Sequential(OrderedDict([
                (
                    name + "_conv1",
                    self.layer_fns['ConvNormActivation'](
                        in_channels=in_channels,
                        out_channels=features,
                        kernel_size=1,
                        bias=bias,
                    ),
                ),
                (
                    name + "_conv2",
                    self.layer_fns['ConvNormActivation'](
                        in_channels=features,
                        out_channels=features,
                        kernel_size=3,
                        padding=1,
                        stride=stride,
                        groups=self.cardinality,
                        bias=bias,
                    ),
                ),
                # Conv + BatchNorm only: the activation is applied by ResidualAdd
                # after the skip connection is added.
                (
                    name + "_conv3",
                    self.layer_fns['Conv'](
                        in_channels=features,
                        out_channels=out_features,
                        kernel_size=1,
                        bias=bias,
                    ),
                ),
                (
                    name + "_bn",
                    self.layer_fns['BatchNorm'](num_features=out_features),
                ),
            ])),
            in_channels=in_channels,
            out_channels=out_features,
            stride=stride,
            layer_fns=self.layer_fns,
            name=name,
        )

# 3D versions of every layer factory, so the same ResNext definition runs on
# volumetric input laid out as (batch, channels, depth, height, width).
layer_fns_3d = {
    'ConvNormActivation': vision_ops.Conv3dNormActivation,
    'Conv': nn.Conv3d,
    'MaxPool': nn.MaxPool3d,
    'AvgPool': nn.AdaptiveAvgPool3d,
    'BatchNorm': nn.BatchNorm3d,
    'ReLU': nn.ReLU,
}

model = ResNext(in_channels=1, layer_fns=layer_fns_3d)

# NOTE(review): the stride-2 stem and pool also shrink the depth axis
# (5 -> 3 -> 2 in the summary below); confirm this is intended for 5-slice input.
summary(
    model=model,
    input_size=(2, 1, 5, 224, 224),
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
)

Layer (type (var_name))                                      Input Shape          Output Shape         Param #              Trainable
ResNext (ResNext)                                            [2, 1, 5, 224, 224]  [2, 1]               --                   True
├─Conv3dNormActivation (block_0)                             [2, 1, 5, 224, 224]  [2, 64, 3, 112, 112] --                   True
│    └─Conv3d (0)                                            [2, 1, 5, 224, 224]  [2, 64, 3, 112, 112] 21,952               True
│    └─BatchNorm3d (1)                                       [2, 64, 3, 112, 112] [2, 64, 3, 112, 112] 128                  True
│    └─ReLU (2)                                              [2, 64, 3, 112, 112] [2, 64, 3, 112, 112] --                   --
├─MaxPool3d (pool_0)                                         [2, 64, 3, 112, 112] [2, 64, 2, 56, 56]   --                   --
├─Sequential (block_1)                                       [2, 64, 2, 56, 56]   [2, 256, 2, 56