In [86]:
import torch

import torch.nn as nn

from collections import OrderedDict
from torchinfo import summary
from functools import reduce
from itertools import accumulate

class ResNext2d(nn.Module):
    """Stack of grouped-convolution bottleneck stages with a linear head.

    The network is a 7x7 stride-2 convolution, a 3x3 stride-2 max pool, one
    stage per entry of ``blocks``, a global average pool and a linear layer.
    Each stage repeats a bottleneck of three convolutions (1x1, grouped 3x3
    with 32 groups, 1x1 expanding to ``2 * features`` channels).

    NOTE: no normalization, activation or skip connection is applied
    anywhere in the network as written; the stages are plain sequential
    convolutions.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of features produced by the linear head.
        blocks: Iterable of ``(features, repeats, stride)`` tuples, one per
            stage. ``features`` must be divisible by 32 because the 3x3
            convolution uses ``groups=32``. ``stride`` is applied by the
            first repeat of the stage only.

    Raises:
        ValueError: If ``blocks`` is empty.
    """

    def __init__(
        self,
        in_channels=3,
        out_channels=1,
        # Immutable default so the configuration is never shared/mutated
        # across instances.
        blocks=((128, 3, 1), (256, 4, 2), (512, 6, 2), (1024, 3, 2)),
    ):
        super().__init__()

        # 1. Keep a private copy of the stage configuration
        self.blocks = list(blocks)
        if not self.blocks:
            raise ValueError("blocks must contain at least one stage")

        # 2. Channel counts at the backbone's entry and exit
        block_0_output = 64
        last_block_features = self.blocks[-1][0]

        # 3. First convolution block
        self.block_0 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=block_0_output,
            kernel_size=7,
            padding=3,
            stride=2,
            bias=False,
        )

        # 4. Pooling layer after the first convolution block
        self.pool_0 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # 5. Create one stage per configuration entry, registered as
        #    block_1, block_2, ...
        block_input = block_0_output
        for block_id, (features, repeats, stride) in enumerate(self.blocks, start=1):
            name = f'block_{block_id}'
            setattr(
                self,
                name,
                ResNext2d._repeated_block(
                    block_input,
                    features,
                    repeats,
                    stride,
                    name=name,
                ),
            )
            # Every bottleneck ends with a 1x1 conv that doubles the width.
            block_input = 2 * features

        # 6. Global average pooling to merge the spatial dimensions
        self.global_pool = nn.AdaptiveAvgPool2d(1)

        # 7. Output layer
        self.output = nn.Linear(
            in_features=last_block_features * 2,
            out_features=out_channels,
        )

    def forward(self, x):
        """Run the network on ``x`` and return the linear head's output."""
        # 1. Apply first block
        x = self.block_0(x)
        x = self.pool_0(x)

        # 2. Apply the stages in the order they were registered
        for block_id in range(1, len(self.blocks) + 1):
            x = getattr(self, f'block_{block_id}')(x)

        # 3. Apply the global average pooling layer
        x = self.global_pool(x)

        # 4. Output: drop the pooled spatial axes before the linear layer
        return self.output(x.flatten(1))

    @staticmethod
    def _repeated_block(in_channels, features, repeats, stride, name, start=1, bias=False):
        """Build one stage made of ``repeats`` consecutive bottlenecks.

        Args:
            in_channels: Channels entering the first bottleneck.
            features: Inner width of each bottleneck; each one outputs
                ``2 * features`` channels.
            repeats: Number of bottlenecks in the stage.
            stride: Stride of the first bottleneck; later ones use stride 1.
            name: Prefix for the sub-module names (``{name}_{repeat_id}``).
            start: Unused; kept so existing call sites keep working.
            bias: Whether the convolutions carry a bias term.

        Returns:
            An ``nn.Sequential`` holding the bottlenecks in order.
        """
        blocks = [
            (
                f'{name}_{repeat_id}',
                ResNext2d._block(
                    # Only the first repeat sees the stage input and stride.
                    in_channels if repeat_id == 1 else 2 * features,
                    features,
                    stride=stride if repeat_id == 1 else 1,
                    name=f'{name}_{repeat_id}',
                    bias=bias,
                ),
            )
            for repeat_id in range(1, repeats + 1)
        ]

        return nn.Sequential(OrderedDict(blocks))

    @staticmethod
    def _block(in_channels, features, stride, name, bias=False):
        """Build one bottleneck: 1x1 conv, grouped 3x3 conv, 1x1 conv.

        Args:
            in_channels: Channels entering the bottleneck.
            features: Width of the first two convolutions; the last one
                outputs ``2 * features`` channels.
            stride: Stride of the grouped 3x3 convolution.
            name: Prefix for the layer names (``_conv1`` .. ``_conv3``).
            bias: Whether the convolutions carry a bias term.

        Returns:
            An ``nn.Sequential`` with the three named convolutions.
        """
        return nn.Sequential(
            OrderedDict([
                (
                    name + "_conv1",
                    nn.Conv2d(
                        in_channels=in_channels,
                        out_channels=features,
                        kernel_size=1,
                        bias=bias,
                    ),
                ),
                (
                    name + "_conv2",
                    nn.Conv2d(
                        in_channels=features,
                        out_channels=features,
                        kernel_size=3,
                        padding=1,
                        stride=stride,
                        groups=32,
                        bias=bias,
                    ),
                ),
                (
                    name + "_conv3",
                    nn.Conv2d(
                        in_channels=features,
                        out_channels=features * 2,
                        kernel_size=1,
                        bias=bias,
                    ),
                ),
                # NOTE: normalization / activation layers are intentionally
                # not part of the block yet.
            ])
        )

# Build the 2D network for single-channel input with the default stages.
model = ResNext2d(in_channels=1)

# Reporting options for the torchinfo layer table.
report_options = dict(
    input_size=(2, 1, 224, 224),
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
)

# Kept as the cell's final expression so the notebook renders the report.
summary(model=model, **report_options)
# print(model)

Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
ResNext2d (ResNext2d)                    [2, 1, 224, 224]     [2, 1]               --                   True
├─Conv2d (block_0)                       [2, 1, 224, 224]     [2, 64, 112, 112]    3,136                True
├─MaxPool2d (pool_0)                     [2, 64, 112, 112]    [2, 64, 56, 56]      --                   --
├─Sequential (block_1)                   [2, 64, 56, 56]      [2, 256, 56, 56]     --                   True
│    └─Sequential (block_1_1)            [2, 64, 56, 56]      [2, 256, 56, 56]     --                   True
│    │    └─Conv2d (block_1_1_conv1)     [2, 64, 56, 56]      [2, 128, 56, 56]     8,192                True
│    │    └─Conv2d (block_1_1_conv2)     [2, 128, 56, 56]     [2, 128, 56, 56]     4,608                True
│    │    └─Conv2d (block_1_1_conv3)     [2, 128, 56, 56]     [2, 256, 56, 56]     32,768               True
│    └─Sequentia

In [96]:
import torch

import torch.nn as nn

from collections import OrderedDict
from torchinfo import summary
from functools import reduce
from itertools import accumulate

class ResNext(nn.Module):
    """Dimension-agnostic stack of grouped-convolution bottleneck stages.

    The network is a kernel-7 stride-2 convolution, a kernel-3 stride-2 max
    pool, one stage per entry of ``blocks``, a global average pool and a
    linear layer. Each stage repeats a bottleneck of three convolutions
    (kernel 1, grouped kernel 3 with 32 groups, kernel 1 expanding to
    ``2 * features`` channels). Layer types are supplied by the caller so
    the same definition can be built from 1D, 2D or 3D layers.

    NOTE: no normalization, activation or skip connection is applied
    anywhere in the network as written; the stages are plain sequential
    convolutions.

    Args:
        in_channels: Number of channels of the input.
        out_channels: Number of features produced by the linear head.
        blocks: Iterable of ``(features, repeats, stride)`` tuples, one per
            stage. ``features`` must be divisible by 32 because the kernel-3
            convolution uses ``groups=32``. ``stride`` is applied by the
            first repeat of the stage only.
        conv_fn: Callable building a convolution layer (e.g. ``nn.Conv3d``).
        mp_fn: Callable building the max-pool layer (e.g. ``nn.MaxPool3d``).
        gap_fn: Callable building the global average pool from an output
            size (e.g. ``nn.AdaptiveAvgPool3d``). When ``None`` it is
            inferred from ``conv_fn`` and falls back to
            ``nn.AdaptiveAvgPool2d`` if ``conv_fn`` is not a recognised
            convolution class.

    Raises:
        ValueError: If ``blocks`` is empty.
    """

    def __init__(
        self,
        in_channels=3,
        out_channels=1,
        # Immutable default so the configuration is never shared/mutated
        # across instances.
        blocks=((128, 3, 1), (256, 4, 2), (512, 6, 2), (1024, 3, 2)),
        conv_fn=nn.Conv2d,
        mp_fn=nn.MaxPool2d,
        gap_fn=None,
    ):
        super().__init__()

        # 1. Keep a private copy of the configuration and layer factories
        self.blocks = list(blocks)
        if not self.blocks:
            raise ValueError("blocks must contain at least one stage")
        self.conv_fn = conv_fn
        self.mp_fn = mp_fn
        self.gap_fn = (
            gap_fn if gap_fn is not None else self._matching_global_pool(conv_fn)
        )

        # 2. Channel counts at the backbone's entry and exit
        block_0_output = 64
        last_block_features = self.blocks[-1][0]

        # 3. First convolution block
        self.block_0 = self.conv_fn(
            in_channels=in_channels,
            out_channels=block_0_output,
            kernel_size=7,
            padding=3,
            stride=2,
            bias=False,
        )

        # 4. Pooling layer after the first convolution block
        self.pool_0 = self.mp_fn(kernel_size=3, stride=2, padding=1)

        # 5. Create one stage per configuration entry, registered as
        #    block_1, block_2, ...
        block_input = block_0_output
        for block_id, (features, repeats, stride) in enumerate(self.blocks, start=1):
            name = f'block_{block_id}'
            setattr(
                self,
                name,
                self._repeated_block(
                    block_input,
                    features,
                    repeats,
                    stride,
                    name=name,
                ),
            )
            # Every bottleneck ends with a kernel-1 conv that doubles the width.
            block_input = 2 * features

        # 6. Global average pooling over ALL spatial axes produced by
        #    conv_fn; pooling fewer axes would leave more than
        #    2 * features values per sample and break the linear head.
        self.global_pool = self.gap_fn(1)

        # 7. Output layer
        self.output = nn.Linear(
            in_features=last_block_features * 2,
            out_features=out_channels,
        )

    @staticmethod
    def _matching_global_pool(conv_fn):
        """Return the adaptive average pool class matching ``conv_fn``."""
        # Only classes can be inspected; partials/lambdas use the fallback.
        if isinstance(conv_fn, type):
            candidates = (
                (nn.Conv1d, nn.AdaptiveAvgPool1d),
                (nn.Conv2d, nn.AdaptiveAvgPool2d),
                (nn.Conv3d, nn.AdaptiveAvgPool3d),
            )
            for conv_cls, pool_cls in candidates:
                if issubclass(conv_fn, conv_cls):
                    return pool_cls
        return nn.AdaptiveAvgPool2d

    def forward(self, x):
        """Run the network on ``x`` and return the linear head's output."""
        # 1. Apply first block
        x = self.block_0(x)
        x = self.pool_0(x)

        # 2. Apply the stages in the order they were registered
        for block_id in range(1, len(self.blocks) + 1):
            x = getattr(self, f'block_{block_id}')(x)

        # 3. Apply the global average pooling layer
        x = self.global_pool(x)

        # 4. Output: drop the pooled spatial axes before the linear layer
        return self.output(x.flatten(1))

    def _repeated_block(self, in_channels, features, repeats, stride, name, start=1, bias=False):
        """Build one stage made of ``repeats`` consecutive bottlenecks.

        Args:
            in_channels: Channels entering the first bottleneck.
            features: Inner width of each bottleneck; each one outputs
                ``2 * features`` channels.
            repeats: Number of bottlenecks in the stage.
            stride: Stride of the first bottleneck; later ones use stride 1.
            name: Prefix for the sub-module names (``{name}_{repeat_id}``).
            start: Unused; kept so existing call sites keep working.
            bias: Whether the convolutions carry a bias term.

        Returns:
            An ``nn.Sequential`` holding the bottlenecks in order.
        """
        blocks = [
            (
                f'{name}_{repeat_id}',
                self._block(
                    # Only the first repeat sees the stage input and stride.
                    in_channels if repeat_id == 1 else 2 * features,
                    features,
                    stride=stride if repeat_id == 1 else 1,
                    name=f'{name}_{repeat_id}',
                    bias=bias,
                ),
            )
            for repeat_id in range(1, repeats + 1)
        ]

        return nn.Sequential(OrderedDict(blocks))

    def _block(self, in_channels, features, stride, name, bias=False):
        """Build one bottleneck from three ``conv_fn`` layers.

        Args:
            in_channels: Channels entering the bottleneck.
            features: Width of the first two convolutions; the last one
                outputs ``2 * features`` channels.
            stride: Stride of the grouped kernel-3 convolution.
            name: Prefix for the layer names (``_conv1`` .. ``_conv3``).
            bias: Whether the convolutions carry a bias term.

        Returns:
            An ``nn.Sequential`` with the three named convolutions.
        """
        return nn.Sequential(
            OrderedDict([
                (
                    name + "_conv1",
                    self.conv_fn(
                        in_channels=in_channels,
                        out_channels=features,
                        kernel_size=1,
                        bias=bias,
                    ),
                ),
                (
                    name + "_conv2",
                    self.conv_fn(
                        in_channels=features,
                        out_channels=features,
                        kernel_size=3,
                        padding=1,
                        stride=stride,
                        groups=32,
                        bias=bias,
                    ),
                ),
                (
                    name + "_conv3",
                    self.conv_fn(
                        in_channels=features,
                        out_channels=features * 2,
                        kernel_size=1,
                        bias=bias,
                    ),
                ),
                # NOTE: normalization / activation layers are intentionally
                # not part of the block yet.
            ])
        )

# Build the network from 3D layers for single-channel volumetric input.
layer_factories = dict(conv_fn=nn.Conv3d, mp_fn=nn.MaxPool3d)
model = ResNext(in_channels=1, **layer_factories)

# Reporting options for the torchinfo layer table.
report_options = dict(
    input_size=(2, 1, 5, 224, 224),
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
)

# Kept as the cell's final expression so the notebook renders the report.
summary(model=model, **report_options)

Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
ResNext (ResNext)                        [2, 1, 5, 224, 224]  [2, 1]               --                   True
├─Conv3d (block_0)                       [2, 1, 5, 224, 224]  [2, 64, 3, 112, 112] 21,952               True
├─MaxPool3d (pool_0)                     [2, 64, 3, 112, 112] [2, 64, 2, 56, 56]   --                   --
├─Sequential (block_1)                   [2, 64, 2, 56, 56]   [2, 256, 2, 56, 56]  --                   True
│    └─Sequential (block_1_1)            [2, 64, 2, 56, 56]   [2, 256, 2, 56, 56]  --                   True
│    │    └─Conv3d (block_1_1_conv1)     [2, 64, 2, 56, 56]   [2, 128, 2, 56, 56]  8,192                True
│    │    └─Conv3d (block_1_1_conv2)     [2, 128, 2, 56, 56]  [2, 128, 2, 56, 56]  13,824               True
│    │    └─Conv3d (block_1_1_conv3)     [2, 128, 2, 56, 56]  [2, 256, 2, 56, 56]  32,768               True
│    └─Sequentia