In [2]:
import torch
from torch import nn
from torchinfo import summary

In [18]:
class ResNextBlock(nn.Module):
    expansion = 2

    def __init__(self, in_channel, inner_channel, cardinality, stride=1, projection=None):
        super().__init__()

        self.residual = nn.Sequential(
            nn.Conv2d(in_channel, inner_channel, 1, bias=False),
            nn.BatchNorm2d(inner_channel),
            nn.ReLU(),

            nn.Conv2d(inner_channel, inner_channel, 3, stride=stride, padding=1, groups=cardinality, bias=False),
            nn.BatchNorm2d(inner_channel),
            nn.ReLU(),

            nn.Conv2d(inner_channel, inner_channel*self.expansion, 1, bias=False),
            nn.BatchNorm2d(inner_channel * self.expansion)
        )

        self.projection = projection
        self.relu = nn.ReLU()

    
    def forward(self, x):
        residual = self.residual(x)

        if self.projection is not None:
            shortcut = self.projection(x)
        else:
            shortcut = x
        
        output = self.relu(residual + shortcut)
        return output

In [24]:
class ResNext(nn.Module):
    def __init__(self, cardinality, block, num_block_list, num_class=1000):
        super().__init__()

        self.in_channel = 64
        self.cardinality = cardinality

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.conv2 = self.make_layer(block, num_block_list[0], inner_channel=128, stride=1)
        self.conv3 = self.make_layer(block, num_block_list[1], inner_channel=256, stride=2)
        self.conv4 = self.make_layer(block, num_block_list[2], inner_channel=512, stride=2)
        self.conv5 = self.make_layer(block, num_block_list[3], inner_channel=1024, stride=2)

        self.GlobalAvgPool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(1024*block.expansion, num_class)
        

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool1(x)

        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)

        x = self.GlobalAvgPool(x)
        x = torch.flatten(x, start_dim=1)
        x = self.fc(x)
        return x

    def make_layer(self, block, num_block, inner_channel, stride):
        if stride != 1 or self.in_channel != inner_channel*block.expansion:
            projection = nn.Sequential(
                nn.Conv2d(self.in_channel, inner_channel*block.expansion, 1, stride=stride),
                nn.BatchNorm2d(inner_channel*block.expansion)
            )
        else:
            projection = None
        
        layers = []
        layers += [block(self.in_channel, inner_channel, self.cardinality, stride=stride, projection=projection)]

        self.in_channel = inner_channel * block.expansion

        for _ in range(0, num_block-1):
            layers += [block(self.in_channel, inner_channel, self.cardinality)]

        return nn.Sequential(*layers)
        

model = ResNext(cardinality=32, block=ResNextBlock, num_block_list=[3, 4, 6, 3], num_class=1000)
summary(model, (2, 3, 224, 224))

Layer (type:depth-idx)                   Output Shape              Param #
ResNext                                  [2, 1000]                 --
├─Sequential: 1-1                        [2, 64, 112, 112]         --
│    └─Conv2d: 2-1                       [2, 64, 112, 112]         9,408
│    └─BatchNorm2d: 2-2                  [2, 64, 112, 112]         128
│    └─ReLU: 2-3                         [2, 64, 112, 112]         --
├─MaxPool2d: 1-2                         [2, 64, 56, 56]           --
├─Sequential: 1-3                        [2, 256, 56, 56]          --
│    └─ResNextBlock: 2-4                 [2, 256, 56, 56]          --
│    │    └─Sequential: 3-1              [2, 256, 56, 56]          46,592
│    │    └─Sequential: 3-2              [2, 256, 56, 56]          17,152
│    │    └─ReLU: 3-3                    [2, 256, 56, 56]          --
│    └─ResNextBlock: 2-5                 [2, 256, 56, 56]          --
│    │    └─Sequential: 3-4              [2, 256, 56, 56]          71,168