<a href="https://colab.research.google.com/github/IANGECHUKI176/deeplearning/blob/main/pytorch/convnets/resnext.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

original [paper](https://arxiv.org/pdf/1611.05431.pdf)

ResNeXt is a simple, highly modularized network architecture for image classification. The
network is constructed by repeating a building block that aggregates a set of transformations
with the same topology. The simple design results in a homogeneous, multi-branch architecture
that has only a few hyper-parameters to set. This strategy exposes a new dimension, which is
referred to as “cardinality” (the size of the set of transformations), as an essential factor in
addition to the dimensions of depth and width.

We can think of cardinality as the number of separate, parallel conv branches inside a block; taken
together, these branches have the same complexity as the single block obtained by combining them.

Blog:
#### Citation ####

PyTorch Code: https://github.com/Mayurji/Image-Classification-PyTorch/blob/main/ResNeXt.py

@article{Xie2016,
  title={Aggregated Residual Transformations for Deep Neural Networks},
  author={Saining Xie and Ross Girshick and Piotr Dollár and Zhuowen Tu and Kaiming He},
  journal={arXiv preprint arXiv:1611.05431},
  year={2016}
}


In [None]:
import torch
import torch.nn.functional as F
import torch.nn as nn
from torchsummary import summary

In [None]:
class Block(nn.Module):
    """Residual bottleneck unit whose 3x3 convolution is grouped.

    The main branch is 1x1 conv -> grouped 3x3 conv -> 1x1 conv, each followed
    by batch norm (with ReLU after the first two). The result is added to a
    shortcut of the input and passed through a final ReLU.

    Args:
        in_planes: Number of channels of the input tensor.
        cardinality: Number of groups used by the 3x3 convolution.
        bottleneck_width: Number of channels handled by each group.
        stride: Stride of the 3x3 convolution and of the projection shortcut.

    The block emits ``expansion * cardinality * bottleneck_width`` channels.
    """

    expansion = 2

    def __init__(self, in_planes, cardinality=32, bottleneck_width=4, stride=1):
        super().__init__()
        inner = cardinality * bottleneck_width
        outer = self.expansion * inner

        self.conv1 = nn.Conv2d(in_planes, inner, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(inner)
        # groups=cardinality splits the `inner` channels into `cardinality`
        # groups of `bottleneck_width` channels, each convolved independently.
        self.conv2 = nn.Conv2d(
            inner,
            inner,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=cardinality,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(inner)
        self.conv3 = nn.Conv2d(inner, outer, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(outer)

        # A projection is only required when the identity path would not match
        # the main branch in spatial size or channel count.
        needs_projection = stride != 1 or in_planes != outer
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, outer, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(outer),
            )
        else:
            # An empty Sequential acts as the identity mapping.
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Apply the bottleneck branch, add the shortcut and return the ReLU."""
        branch = F.relu(self.bn1(self.conv1(x)))
        branch = F.relu(self.bn2(self.conv2(branch)))
        branch = self.bn3(self.conv3(branch))
        residual = self.shortcut(x)
        return F.relu(branch + residual)

In [None]:
# blk = Block(3)
# summary(blk,(3,224,224))

In [None]:
class ResNeXt(nn.Module):
    """Three-stage ResNeXt classifier assembled from ``Block`` units.

    Layout: a 1x1 convolution stem (3 -> 64 channels) with batch norm and
    ReLU, three stages of ``Block`` units (strides 1, 2, 2), global average
    pooling and a linear classifier.

    Args:
        num_blocks: Per-stage block counts. Only indices 0, 1 and 2 are read;
            the fourth stage is disabled, so further entries are ignored.
        cardinality: Number of groups in every block's 3x3 convolution.
        bottleneck_width: Channels per group in the first stage. The value is
            doubled after each stage, so ``self.bottleneck_width`` ends up at
            eight times this argument once construction finishes.
        num_classes: Size of the classifier output.
    """

    def __init__(self, num_blocks, cardinality, bottleneck_width, num_classes=10):
        super(ResNeXt, self).__init__()
        self.cardinality = cardinality
        self.bottleneck_width = bottleneck_width

        # Channel count flowing into the next block; updated by _make_layers.
        self.in_planes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        self.layer1 = self._make_layers(num_blocks[0], 1)
        self.layer2 = self._make_layers(num_blocks[1], 2)
        self.layer3 = self._make_layers(num_blocks[2], 2)
        # self.layer4 = self._make_layers(num_blocks[3], 2)

        # self.in_planes now holds the channel count produced by the last
        # stage, which is exactly what the classifier consumes.
        self.linear = nn.Linear(self.in_planes, num_classes)

    def _make_layers(self, num_blocks, stride):
        """Build one stage and advance the channel/width bookkeeping.

        The first block uses ``stride``; every following block uses stride 1.
        """
        block_strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in block_strides:
            layers.append(
                Block(self.in_planes, self.cardinality, self.bottleneck_width, block_stride)
            )
            self.in_planes = Block.expansion * self.cardinality * self.bottleneck_width
        # Double the per-group width for the next stage.
        self.bottleneck_width *= 2
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the class logits for a batch of 3-channel images."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        # out = self.layer4(out)
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        return self.linear(out)

In [None]:
# blk1 = ResNeXt(num_blocks= [3, 4, 6, 3], cardinality=32, bottleneck_width=4)
# summary(blk1,(3,224,224))

In [None]:
def resnext50():
    """Build a ResNeXt with the ResNeXt-50 (32x4d) configuration.

    ``ResNeXt`` only builds three stages, so the last depth entry is unused.
    """
    stage_depths = [3, 4, 6, 3]
    return ResNeXt(num_blocks=stage_depths, cardinality=32, bottleneck_width=4)

def resnext101():
    """Build a ResNeXt with the ResNeXt-101 (32x4d) configuration.

    ``ResNeXt`` only builds three stages, so the last depth entry is unused.
    """
    stage_depths = [3, 4, 23, 3]
    return ResNeXt(num_blocks=stage_depths, cardinality=32, bottleneck_width=4)

def resnext152():
    """Build a ResNeXt with the ResNeXt-152 (32x4d) configuration.

    ``ResNeXt`` only builds three stages, so the last depth entry is unused.
    """
    # NOTE(review): torchvision's ResNet-152 uses stage depths [3, 8, 36, 3];
    # confirm whether a second-stage depth of 4 is intended here.
    stage_depths = [3, 4, 36, 3]
    return ResNeXt(num_blocks=stage_depths, cardinality=32, bottleneck_width=4)

# Alternative implementation of the same architecture: [github](https://github.com/YeonwooSung/PyTorch_CNN_Architectures/blob/master/models/resnext.py)

In [None]:
# Number of groups each grouped convolution splits its channels into.
CARDINALITY = 32
# Channels per group when out_channels == BASEWIDTH; the bottleneck block
# scales it as (DEPTH * out_channels) // BASEWIDTH.
DEPTH = 4
# Reference out_channels at which each group is exactly DEPTH channels wide.
BASEWIDTH = 64
# From the ResNeXt paper: "The grouped convolutional layer in Fig. 3(c)
# performs 32 groups of convolutions whose input and output channels are
# 4-dimensional. The grouped convolutional layer concatenates them as the
# outputs of the layer."
class ResNextBottleNeckC(nn.Module):
    """ResNeXt bottleneck built from grouped convolutions.

    The main branch is 1x1 conv -> 3x3 conv -> 1x1 conv with batch norm after
    each convolution and ReLU after the first two. Its output is added to a
    shortcut of the input and passed through ReLU.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Base width of the block; it emits ``out_channels * 4``
            channels.
        stride: Stride of the 3x3 convolution and of the projection shortcut.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()

        # From the paper: "the input/output width of the template is fixed as
        # 256-d (Fig. 3), and all widths are doubled each time when the
        # feature map is subsampled (see Table 1)."
        channels_per_group = (DEPTH * out_channels) // BASEWIDTH
        width = CARDINALITY * channels_per_group
        expanded = out_channels * 4

        # NOTE(review): the first 1x1 convolution is grouped as well, so
        # in_channels must be divisible by CARDINALITY. The paper text quoted
        # in this file only describes the 3x3 layer as grouped - confirm this
        # is intended.
        layers = [
            nn.Conv2d(in_channels, width, kernel_size=1, groups=CARDINALITY, bias=False),
            nn.BatchNorm2d(width),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                width,
                width,
                kernel_size=3,
                stride=stride,
                groups=CARDINALITY,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(width),
            nn.ReLU(inplace=True),
            nn.Conv2d(width, expanded, kernel_size=1, bias=False),
            nn.BatchNorm2d(expanded),
        ]
        self.split_transforms = nn.Sequential(*layers)

        # Project the input only when the identity path would not match the
        # main branch in spatial size or channel count.
        if stride != 1 or in_channels != expanded:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, expanded, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(expanded),
            )
        else:
            # An empty Sequential acts as the identity mapping.
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ReLU(main branch + shortcut) for the input ``x``."""
        return F.relu(self.split_transforms(x) + self.shortcut(x))

In [None]:
class ResNext(nn.Module):
    """Four-stage ResNeXt classifier built from a pluggable block type.

    Layout: a 3x3 convolution stem (3 -> 64 channels) with batch norm and
    ReLU, four stages of ``block`` units with base widths 64, 128, 256 and 512
    (strides 1, 2, 2, 2), global average pooling and a linear classifier.

    Args:
        block: Callable invoked as ``block(in_channels, out_channels, stride)``
            that returns a module emitting ``out_channels * 4`` channels.
        num_blocks: Number of blocks in each of the four stages.
        n_classes: Size of the classifier output.
    """

    def __init__(self, block, num_blocks, n_classes=10):
        super(ResNext, self).__init__()
        # Channel count flowing into the next block; updated by _make_layers.
        self.in_channels = 64

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        self.conv2 = self._make_layers(block, num_blocks[0], 64, 1)
        self.conv3 = self._make_layers(block, num_blocks[1], 128, 2)
        self.conv4 = self._make_layers(block, num_blocks[2], 256, 2)
        self.conv5 = self._make_layers(block, num_blocks[3], 512, 2)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        # The last stage emits 512 * 4 channels.
        self.linear = nn.Linear(512 * 4, n_classes)

    def _make_layers(self, block, num_blocks, out_channels, stride):
        """Build one stage of ``block`` units.

        The first block uses ``stride``; every following block uses stride 1.
        ``self.in_channels`` is advanced to ``out_channels * 4`` as blocks are
        added.
        """
        block_strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in block_strides:
            layers.append(block(self.in_channels, out_channels, block_stride))
            self.in_channels = out_channels * 4
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the class logits for a batch of 3-channel images."""
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.conv5(out)
        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        # The logits must be returned; otherwise calling the model yields None.
        return out

In [None]:
def resnext50():
    """Build a ResNeXt-50 (32x4d) network from ``ResNextBottleNeckC`` blocks."""
    stage_depths = [3, 4, 6, 3]
    return ResNext(ResNextBottleNeckC, stage_depths)

def resnext101():
    """Build a ResNeXt-101 (32x4d) network from ``ResNextBottleNeckC`` blocks."""
    stage_depths = [3, 4, 23, 3]
    return ResNext(ResNextBottleNeckC, stage_depths)

def resnext152():
    """Build a ResNeXt-152 (32x4d) network from ``ResNextBottleNeckC`` blocks."""
    # NOTE(review): torchvision's ResNet-152 uses stage depths [3, 8, 36, 3];
    # confirm whether a second-stage depth of 4 is intended here.
    stage_depths = [3, 4, 36, 3]
    return ResNext(ResNextBottleNeckC, stage_depths)