[**ResNeXt Blocks (Xie et al., 2017):**](https://openaccess.thecvf.com/content_cvpr_2017/papers/Xie_Aggregated_Residual_Transformations_CVPR_2017_paper.pdf) is an advancement for sparser connections, extending ResNet with grouped convolutions to enhance efficiency and performance.

<div>
<img src="./imgs/resnext.png" style="width: 800px;">
</div>

ResNeXt is an extension of ResNet that introduces grouped convolutions through a hyperparameter called cardinality, allowing the model to learn more diverse features efficiently.

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import utils

In [2]:
class ResNeXtBlock(nn.Module):
    def __init__(self, in_channels, bottleneck_channels, out_channels, stride=1, cardinality=32):
        """ ResNeXt Block with grouped convolutions. """
        super().__init__()

        # 1x1 convolution to reduce channels
        self.conv1 = nn.Conv2d(in_channels, bottleneck_channels, kernel_size=1, stride=1)
        self.bn1 = nn.BatchNorm2d(bottleneck_channels)
        self.relu = nn.ReLU(inplace=True)

        # 3x3 grouped convolution with cardinality
        self.conv2 = nn.Conv2d(
            bottleneck_channels, bottleneck_channels, kernel_size=3, stride=stride, padding=1, groups=cardinality
        )
        self.bn2 = nn.BatchNorm2d(bottleneck_channels)

        # 1x1 convolution to restore channels
        self.conv3 = nn.Conv2d(bottleneck_channels, out_channels, kernel_size=1, stride=1)
        self.bn3 = nn.BatchNorm2d(out_channels)
        
        # Shortcut connection
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels)
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        # Main path
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # Add shortcut
        out += self.shortcut(x)
        out = self.relu(out)
        
        return out

In [3]:
class ResNeXt(nn.Module):
    def __init__(self, block, layers, cardinality=32, num_classes=1000):
        super().__init__()
        self.cardinality = cardinality

        # Channel configurations for the stages
        self.channels = [256, 512, 1024, 2048]

        # Stem
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Four stages
        self.layer1 = self.make_stage(block, 64, self.channels[0], layers[0], stride=1)
        self.layer2 = self.make_stage(block, self.channels[0], self.channels[1], layers[1], stride=2)
        self.layer3 = self.make_stage(block, self.channels[1], self.channels[2], layers[2], stride=2)
        self.layer4 = self.make_stage(block, self.channels[2], self.channels[3], layers[3], stride=2)

        # Global average pooling and fully connected layer
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(self.channels[3], num_classes)

    def make_stage(self, block, in_channels, out_channels, num_blocks, stride):
        """Helper function to create a stage of ResNeXt blocks."""
        layers = []
        bottleneck_channels = out_channels // 4  # Bottleneck ratio of 4
        # First block with specified stride
        layers.append(block(in_channels, bottleneck_channels, out_channels, stride, self.cardinality))
        # Subsequent blocks with stride 1
        for _ in range(1, num_blocks):
            layers.append(block(out_channels, bottleneck_channels, out_channels, 1, self.cardinality))
        return nn.Sequential(*layers)

    def forward(self, x):
        # Stem
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)

        # Stages
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        
        # Pooling and classification
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

In [4]:
def ResNeXt50(cardinality=32, num_classes=10):
    """ResNeXt-50: 4 stages with 3, 4, 6, 3 blocks."""
    return ResNeXt(ResNeXtBlock, [3, 4, 6, 3], cardinality, num_classes)

def ResNeXt101(cardinality=32, num_classes=10):
    """ResNeXt-101: 4 stages with 3, 4, 23, 3 blocks."""
    return ResNeXt(ResNeXtBlock, [3, 4, 23, 3], cardinality, num_classes)

In [6]:
data = utils.CIFAR10DataLoader(batch_size=64, resize=(224, 224))
train_loader = data.get_train_loader()
test_loader = data.get_test_loader()

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ResNeXt50(num_classes=10)
model.apply(utils.init_kaiming).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

epochs = 10
for epoch in range(epochs):
    train_loss, train_acc = utils.train_step(train_loader, model, criterion, optimizer, device)
    test_loss, test_acc = utils.eval_step(test_loader, model, criterion, device)
    print(f"Epoch {epoch + 1:>{len(str(epochs))}}/{epochs} | "
          f"Train Loss: {train_loss:.4f} | "
          f"Test Loss: {test_loss:.4f} | "
          f"Test Acc: {test_acc:.4f}")

Epoch  1/10 | Train Loss: 2.2122 | Test Loss: 2.2726 | Test Acc: 0.4027
Epoch  2/10 | Train Loss: 1.4719 | Test Loss: 1.4267 | Test Acc: 0.5245
Epoch  3/10 | Train Loss: 1.2268 | Test Loss: 1.3134 | Test Acc: 0.5940
Epoch  4/10 | Train Loss: 1.0395 | Test Loss: 1.1424 | Test Acc: 0.6478
Epoch  5/10 | Train Loss: 0.8877 | Test Loss: 1.1420 | Test Acc: 0.6802
Epoch  6/10 | Train Loss: 0.7581 | Test Loss: 1.0671 | Test Acc: 0.7154
Epoch  7/10 | Train Loss: 0.6502 | Test Loss: 0.8493 | Test Acc: 0.7369
Epoch  8/10 | Train Loss: 0.5499 | Test Loss: 1.0137 | Test Acc: 0.7229
Epoch  9/10 | Train Loss: 0.4665 | Test Loss: 0.8088 | Test Acc: 0.7280
Epoch 10/10 | Train Loss: 0.3924 | Test Loss: 0.9457 | Test Acc: 0.7550
