## ResNeXt on CIFAR-10

In [None]:
import torch.nn as nn
import torch
import torch.nn.functional as F


# Building Block for ResNeXt (type (a))
class AggregatedBlock(nn.Module):
    """Residual block made of ``cardinality`` parallel paths that are summed.

    Every path is a full bottleneck: 1x1 conv -> BN -> ReLU -> 3x3 conv
    (carries ``stride``) -> BN -> ReLU -> 1x1 conv -> BN, mapping
    ``in_channels`` to ``out_channels`` through ``out_channels // cardinality``
    intermediate channels.  The path outputs are added together, the shortcut
    is added on top, and a final ReLU is applied.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the block; must be a
            multiple of ``cardinality``.
        stride: Stride of the 3x3 convolution in every path and of the
            shortcut projection.

    Raises:
        AssertionError: If ``out_channels`` is not divisible by ``cardinality``.
    """

    # Number of parallel paths.
    cardinality = 8

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        path_width, leftover = divmod(out_channels, self.cardinality)
        assert leftover == 0

        def build_path():
            # One complete bottleneck path; every path gets its own weights.
            stages = [
                nn.Conv2d(in_channels, path_width, kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(path_width),
                nn.ReLU(),
                nn.Conv2d(path_width, path_width, kernel_size=3, stride=stride, padding=1, bias=False),
                nn.BatchNorm2d(path_width),
                nn.ReLU(),
                nn.Conv2d(path_width, out_channels, kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(out_channels),
            ]
            return nn.Sequential(*stages)

        self.layer = nn.ModuleList([build_path() for _ in range(self.cardinality)])

        # The identity shortcut is only valid when the block keeps both the
        # channel count and the spatial size; otherwise project with a 1x1 conv.
        if stride == 1 and in_channels == out_channels:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        """Sum all path outputs, add the shortcut and apply ReLU."""
        aggregated = sum(path(x) for path in self.layer)
        shortcut = x if self.downsample is None else self.downsample(x)
        return F.relu(aggregated + shortcut)


# Building Block for ResNeXt (type (b))
class InceptionBlock(nn.Module):
    """Residual block whose parallel branches are concatenated, then mixed.

    Each of the ``cardinality`` branches applies 1x1 conv -> BN -> ReLU ->
    3x3 conv (carries ``stride``) -> BN -> ReLU and emits
    ``out_channels // cardinality`` channels.  The branch outputs are
    concatenated along the channel axis (giving ``out_channels`` channels),
    passed through a shared 1x1 conv + BN ``tail``, added to the shortcut and
    rectified.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the block; must be a
            multiple of ``cardinality``.
        stride: Stride of the 3x3 convolution in every branch and of the
            shortcut projection.

    Raises:
        AssertionError: If ``out_channels`` is not divisible by ``cardinality``.
    """

    # Number of parallel branches.
    cardinality = 8

    @staticmethod
    def _make_branch(in_channels, width, stride):
        """Build one branch: 1x1 reduction followed by a (strided) 3x3 conv."""
        return nn.Sequential(
            nn.Conv2d(in_channels, width, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(width),
            nn.ReLU(),
            nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(width),
            nn.ReLU(),
        )

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        branch_width, leftover = divmod(out_channels, self.cardinality)
        assert leftover == 0

        self.layer = nn.ModuleList(
            [self._make_branch(in_channels, branch_width, stride) for _ in range(self.cardinality)]
        )

        # Shared 1x1 conv applied to the concatenated branch outputs.
        mixing_conv = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, bias=False)
        self.tail = nn.Sequential(mixing_conv, nn.BatchNorm2d(out_channels))

        # Project the shortcut only when the block changes channels or resolution.
        if stride == 1 and in_channels == out_channels:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        """Concatenate branch outputs, mix them, add the shortcut, apply ReLU."""
        stacked = torch.cat([branch(x) for branch in self.layer], dim=1)
        mixed = self.tail(stacked)
        shortcut = x if self.downsample is None else self.downsample(x)
        return F.relu(mixed + shortcut)


# Building Block for ResNeXt (type (c))
class GroupConvBlock(nn.Module):
    """Residual bottleneck block using a single grouped 3x3 convolution.

    The main path is 1x1 conv -> BN -> ReLU -> grouped 3x3 conv
    (``groups=cardinality``, carries ``stride``) -> BN -> ReLU -> 1x1 conv ->
    BN, all at ``out_channels`` width.  The result is added to the shortcut
    and passed through ReLU.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the block; must be a
            multiple of ``cardinality``.
        stride: Stride of the grouped 3x3 convolution and of the shortcut
            projection.

    Raises:
        AssertionError: If ``out_channels`` is not divisible by ``cardinality``.
    """

    # Number of groups in the 3x3 convolution.
    cardinality = 8

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        leftover = out_channels % self.cardinality
        assert leftover == 0

        # Build the three convolutions in path order, then chain them.
        entry_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False)
        entry_bn = nn.BatchNorm2d(out_channels)
        grouped_conv = nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=self.cardinality,
            bias=False,
        )
        grouped_bn = nn.BatchNorm2d(out_channels)
        exit_conv = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, bias=False)
        exit_bn = nn.BatchNorm2d(out_channels)

        self.layer = nn.Sequential(
            entry_conv,
            entry_bn,
            nn.ReLU(),
            grouped_conv,
            grouped_bn,
            nn.ReLU(),
            exit_conv,
            exit_bn,
        )

        # Project the shortcut only when the block changes channels or resolution.
        if stride == 1 and in_channels == out_channels:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        """Run the bottleneck path, add the shortcut and apply ReLU."""
        transformed = self.layer(x)
        shortcut = x if self.downsample is None else self.downsample(x)
        return F.relu(transformed + shortcut)


# This is an illustration of how groupconv can be implemented
class GroupConvLayer(nn.Module):
    """Grouped convolution spelled out as one ordinary ``nn.Conv2d`` per group.

    The input channels are cut into ``groups`` consecutive slices of
    ``in_channels // groups`` channels.  Slice ``i`` is convolved by its own
    ``nn.Conv2d`` producing ``out_channels // groups`` channels, and the
    per-group outputs are concatenated along the channel axis.

    Args:
        in_channels: Total number of input channels; must be divisible by
            ``groups``.
        out_channels: Total number of output channels; must be divisible by
            ``groups``.
        kernel_size: Kernel size forwarded to every per-group ``nn.Conv2d``.
        stride: Stride forwarded to every per-group ``nn.Conv2d``.
        groups: Number of channel groups.
        bias: Whether every per-group convolution has a bias term.
        padding: Padding forwarded to every per-group ``nn.Conv2d``.  The
            default of 0 matches ``nn.Conv2d``'s own default.

    Raises:
        AssertionError: If ``in_channels`` or ``out_channels`` is not
            divisible by ``groups``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, groups=1, bias=False, padding=0):
        super(GroupConvLayer, self).__init__()

        assert in_channels % groups == 0 and out_channels % groups == 0

        self.groups = groups
        # Number of input channels handled by each group.
        self.width = in_channels // groups
        self.layer = nn.ModuleList([
            nn.Conv2d(
                in_channels=self.width,
                out_channels=out_channels // groups,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=bias,
            )
            for _ in range(groups)
        ])

    def forward(self, x):
        """Convolve each channel slice with its own conv and concatenate."""
        w = self.width
        # The sub-convolutions live in ``self.layer``; group ``i`` only sees
        # input channels [i * w, (i + 1) * w).
        out = [conv(x[:, i * w: (i + 1) * w]) for i, conv in enumerate(self.layer)]

        return torch.cat(out, dim=1)


class ResNeXt(nn.Module):
    """Three-stage ResNeXt classifier for small RGB images.

    Layout: a 3x3 stem (3 -> 16 channels), then three stages of residual
    blocks at 16, 32 and 64 channels.  The second and third stages start with
    a stride-2 block, so spatial resolution is divided by 4 overall.  The
    result is average-pooled with a fixed 8x8 window and fed to a linear
    classifier, i.e. the network expects 32x32 inputs (8x8 final feature map).

    Args:
        topology_type: Which block implementation to use: ``'a'``
            (``AggregatedBlock``), ``'b'`` (``InceptionBlock``) or ``'c'``
            (``GroupConvBlock``).
        num_blocks: Indexable with three ints, the number of blocks in the
            16-, 32- and 64-channel stages.  The 32- and 64-channel stages
            always contain at least their initial stride-2 block.
        num_classes: Size of the classifier output.  Defaults to 100; pass
            ``num_classes=10`` for a 10-class dataset such as CIFAR-10.

    Raises:
        AssertionError: If ``topology_type`` is not one of ``'a'``, ``'b'``,
            ``'c'``.
    """

    # Not read anywhere in this class: each block type carries its own
    # ``cardinality``.  Kept so existing attribute access keeps working.
    cardinality = 32

    def __init__(self, topology_type, num_blocks, num_classes=100):
        assert topology_type in ['a', 'b', 'c']

        super(ResNeXt, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(16)
        )

        self.topology_type = topology_type

        block_type = {
            'a': AggregatedBlock,
            'b': InceptionBlock,
            'c': GroupConvBlock,
        }[topology_type]

        # Stage 1: 16 channels, full resolution.
        self.layer2 = nn.ModuleList(
            [block_type(16, 16) for _ in range(num_blocks[0])]
        )

        # Stage 2: the first block widens to 32 channels and halves resolution.
        stage = [block_type(16, 32, 2)]
        stage += [block_type(32, 32) for _ in range(num_blocks[1] - 1)]
        self.layer3 = nn.ModuleList(stage)

        # Stage 3: the first block widens to 64 channels and halves resolution.
        stage = [block_type(32, 64, 2)]
        stage += [block_type(64, 64) for _ in range(num_blocks[2] - 1)]
        self.layer4 = nn.ModuleList(stage)

        # Fixed 8x8 window: collapses an 8x8 feature map to 1x1.
        self.avgpool = nn.AvgPool2d(8)
        self.fc = nn.Linear(64, num_classes)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return class scores of shape ``(batch, num_classes)`` for ``x``."""
        out = self.layer1(x)
        out = self.relu(out)

        for layer in self.layer2:
            out = layer(out)

        for layer in self.layer3:
            out = layer(out)

        for layer in self.layer4:
            out = layer(out)

        out = self.avgpool(out)
        # Flatten everything except the batch dimension.  A bare
        # ``torch.squeeze`` would also drop the batch dimension when the
        # batch size is 1 and yield a 1-D output.
        out = torch.flatten(out, 1)
        out = self.fc(out)

        return out