In [1]:
# Codeblock 1
import torch
import torch.nn as nn
from torchinfo import summary

In [2]:
# Codeblock 2
# Number of groups used by the grouped 3x3 convolution inside every Block.
CARDINALITY  = 32              #(1)
# Channel widths: [0] is the input of the first convolution, [1] is its output,
# and [2]..[5] are the output widths of the conv2..conv5 stages of ResNeXt.
NUM_CHANNELS = [3, 64, 256, 512, 1024, 2048]  #(2)
# Number of Blocks stacked in the conv2, conv3, conv4 and conv5 stages, in that order.
NUM_BLOCKS   = [3, 4, 6, 3]    #(3)
# Number of outputs of the final fully connected layer.
NUM_CLASSES  = 1000            #(4)

In [7]:
# Codeblock 3a
class Block(nn.Module):
    """Bottleneck block: 1x1 reduce -> grouped 3x3 -> 1x1 expand, added to a shortcut.

    Args:
        in_channels: Number of channels of the input feature map.
        add_channel: If True the block outputs ``in_channels * channel_multiplier``
            channels; otherwise the output keeps ``in_channels`` channels.
        channel_multiplier: Expansion factor, used only when ``add_channel`` is True.
        downsample: If True the 3x3 convolution and the projection shortcut use
            stride 2; otherwise stride 1.
        cardinality: Number of groups of the 3x3 convolution. ``None`` (the default)
            falls back to the module-level ``CARDINALITY`` constant, which is what
            the block always used before this parameter existed.

    Raises:
        ValueError: If ``cardinality`` is not positive or the bottleneck width
            (``out_channels // 2``) is not divisible by it.
    """

    def __init__(self,
                 in_channels,
                 add_channel=False,
                 channel_multiplier=2,
                 downsample=False,
                 cardinality=None):
        super().__init__()

        # Resolve the group count lazily so callers that pass it explicitly do not
        # depend on the notebook-level constant.
        if cardinality is None:
            cardinality = CARDINALITY

        self.add_channel = add_channel
        self.channel_multiplier = channel_multiplier
        self.downsample = downsample

        # Output width grows only when requested; the bottleneck is half of it.
        out_channels = in_channels * channel_multiplier if add_channel else in_channels
        mid_channels = out_channels // 2
        stride = 2 if downsample else 1

        # nn.Conv2d would reject this as well, but only with a generic message.
        if cardinality < 1 or mid_channels % cardinality != 0:
            raise ValueError(
                f"bottleneck width {mid_channels} (out_channels={out_channels}) "
                f"must be divisible by a positive cardinality, got {cardinality}"
            )

# Codeblock 3b
        # The identity shortcut cannot be used when the shape changes, so a strided
        # 1x1 convolution + batch norm maps the input to the output shape instead.
        if self.add_channel or self.downsample:
            self.projection = nn.Conv2d(in_channels=in_channels,
                                        out_channels=out_channels,
                                        kernel_size=1,
                                        stride=stride,
                                        padding=0,
                                        bias=False)
            nn.init.kaiming_normal_(self.projection.weight, nonlinearity='relu')
            self.bn_proj = nn.BatchNorm2d(num_features=out_channels)

        # 1x1 convolution: in_channels -> bottleneck width.
        self.conv0 = nn.Conv2d(in_channels=in_channels,
                               out_channels=mid_channels,
                               kernel_size=1,
                               stride=1,
                               padding=0,
                               bias=False)
        nn.init.kaiming_normal_(self.conv0.weight, nonlinearity='relu')
        self.bn0 = nn.BatchNorm2d(num_features=mid_channels)

        # Grouped 3x3 convolution; this is the only layer that applies the stride.
        self.conv1 = nn.Conv2d(in_channels=mid_channels,
                               out_channels=mid_channels,
                               kernel_size=3,
                               stride=stride,
                               padding=1,
                               bias=False,
                               groups=cardinality)
        nn.init.kaiming_normal_(self.conv1.weight, nonlinearity='relu')
        self.bn1 = nn.BatchNorm2d(num_features=mid_channels)

        # 1x1 convolution: bottleneck width -> out_channels.
        self.conv2 = nn.Conv2d(in_channels=mid_channels,
                               out_channels=out_channels,
                               kernel_size=1,
                               stride=1,
                               padding=0,
                               bias=False)
        nn.init.kaiming_normal_(self.conv2.weight, nonlinearity='relu')
        self.bn2 = nn.BatchNorm2d(num_features=out_channels)

        self.relu = nn.ReLU()

# Codeblock 3c
    def forward(self, x):
        """Run the bottleneck path and add the shortcut.

        Uncomment the print lines to trace the tensor shape after every step.
        """
        #print(f'original\t\t: {x.size()}')

        # Shortcut branch: projected when the output shape differs from the input.
        if self.add_channel or self.downsample:
            residual = self.bn_proj(self.projection(x))
            #print(f'after projection\t: {residual.size()}')
        else:
            residual = x
            #print(f'no projection\t\t: {residual.size()}')

        x = self.relu(self.bn0(self.conv0(x)))
        #print(f'after conv0-bn0-relu\t: {x.size()}')

        x = self.relu(self.bn1(self.conv1(x)))
        #print(f'after conv1-bn1-relu\t: {x.size()}')

        # No activation here: ReLU is applied after the shortcut is added.
        x = self.bn2(self.conv2(x))
        #print(f'after conv2-bn2\t\t: {x.size()}')

        x = self.relu(x + residual)
        #print(f'after summation\t\t: {x.size()}')

        return x

In [4]:
# Codeblock 4
# Shape check for a Block that widens and downsamples at once: add_channel=True with
# the default channel_multiplier=2 turns 256 channels into 512, and downsample=True
# makes the 3x3 convolution and the projection shortcut use stride 2.
# NOTE: the per-step trace shown in the output is printed only when the print lines
# in Block.forward are uncommented.
block = Block(in_channels=256, add_channel=True, downsample=True)
x = torch.randn(1, 256, 56, 56)

out = block(x)

original		: torch.Size([1, 256, 56, 56])
after projection	: torch.Size([1, 512, 28, 28])
after conv0-bn0-relu	: torch.Size([1, 256, 56, 56])
after conv1-bn1-relu	: torch.Size([1, 256, 28, 28])
after conv2-bn2		: torch.Size([1, 512, 28, 28])
after summation		: torch.Size([1, 512, 28, 28])


In [5]:
# Codeblock 5
# Shape check for a Block that keeps its input shape: with add_channel=False and
# downsample=False no projection layer is built and the input itself is the shortcut.
# NOTE: the per-step trace shown in the output is printed only when the print lines
# in Block.forward are uncommented.
block = Block(in_channels=512, add_channel=False, downsample=False)
x = torch.randn(1, 512, 28, 28)

out = block(x)

original		: torch.Size([1, 512, 28, 28])
no projection		: torch.Size([1, 512, 28, 28])
after conv0-bn0-relu	: torch.Size([1, 256, 28, 28])
after conv1-bn1-relu	: torch.Size([1, 256, 28, 28])
after conv2-bn2		: torch.Size([1, 512, 28, 28])
after summation		: torch.Size([1, 512, 28, 28])


In [6]:
# Codeblock 6
# Shape check for a Block that only widens: channel_multiplier=4 turns 64 channels
# into 256 while downsample=False keeps stride 1, so the spatial size is unchanged.
# This is the configuration ResNeXt uses for the first block of its conv2 stage.
# NOTE: the per-step trace shown in the output is printed only when the print lines
# in Block.forward are uncommented.
block = Block(in_channels=64, add_channel=True, channel_multiplier=4, downsample=False)
x = torch.randn(1, 64, 56, 56)

out = block(x)

original		: torch.Size([1, 64, 56, 56])
after projection	: torch.Size([1, 256, 56, 56])
after conv0-bn0-relu	: torch.Size([1, 128, 56, 56])
after conv1-bn1-relu	: torch.Size([1, 128, 56, 56])
after conv2-bn2		: torch.Size([1, 256, 56, 56])
after summation		: torch.Size([1, 256, 56, 56])


In [10]:
# Codeblock 7a
class ResNeXt(nn.Module):
    """Full network: convolutional stem, four stages of Blocks, pooling and classifier.

    Layer widths come from NUM_CHANNELS, stage depths from NUM_BLOCKS and the
    classifier size from NUM_CLASSES; the constructor takes no arguments.
    """

    @staticmethod
    def _stage(first_block, width, depth):
        """Return a ModuleList of `first_block` followed by `depth - 1` default Blocks of `width` channels."""
        followers = [Block(in_channels=width) for _ in range(depth - 1)]
        return nn.ModuleList([first_block] + followers)

    def __init__(self):
        super().__init__()

        # conv1 stage: 7x7 stride-2 convolution, batch norm, ReLU, 3x3 stride-2 max pool.
        stem_in, stem_out = NUM_CHANNELS[0], NUM_CHANNELS[1]
        self.resnext_conv1 = nn.Conv2d(stem_in, stem_out,
                                       kernel_size=7, stride=2, padding=3, bias=False)
        nn.init.kaiming_normal_(self.resnext_conv1.weight, nonlinearity='relu')
        self.resnext_bn1 = nn.BatchNorm2d(stem_out)
        self.relu = nn.ReLU()
        self.resnext_maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # conv2 stage: the first block widens x4 without downsampling
        # (the max pool above has already reduced the resolution).
        self.resnext_conv2 = self._stage(
            Block(in_channels=NUM_CHANNELS[1],
                  add_channel=True,
                  channel_multiplier=4,
                  downsample=False),
            NUM_CHANNELS[2], NUM_BLOCKS[0])

        # conv3..conv5 stages: the first block doubles the width (default multiplier)
        # and downsamples; the remaining blocks keep the shape.
        self.resnext_conv3 = self._stage(
            Block(in_channels=NUM_CHANNELS[2], add_channel=True, downsample=True),
            NUM_CHANNELS[3], NUM_BLOCKS[1])

        self.resnext_conv4 = self._stage(
            Block(in_channels=NUM_CHANNELS[3], add_channel=True, downsample=True),
            NUM_CHANNELS[4], NUM_BLOCKS[2])

        self.resnext_conv5 = self._stage(
            Block(in_channels=NUM_CHANNELS[4], add_channel=True, downsample=True),
            NUM_CHANNELS[5], NUM_BLOCKS[3])

        # Head: global average pool to 1x1, then a linear classifier.
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.fc = nn.Linear(NUM_CHANNELS[5], NUM_CLASSES)

# Codeblock 7b
    def forward(self, x):
        """Map an image batch to class scores.

        Uncomment the print lines to trace the tensor shape after every step.
        """
        #print(f'original\t\t: {x.size()}')

        x = self.resnext_conv1(x)
        x = self.resnext_bn1(x)
        x = self.relu(x)
        #print(f'after resnext_conv1\t: {x.size()}')

        x = self.resnext_maxpool1(x)
        #print(f'after resnext_maxpool1\t: {x.size()}')

        stages = (
            ('resnext_conv2', self.resnext_conv2),
            ('resnext_conv3', self.resnext_conv3),
            ('resnext_conv4', self.resnext_conv4),
            ('resnext_conv5', self.resnext_conv5),
        )
        for name, stage in stages:
            for i, block in enumerate(stage):
                x = block(x)
                #print(f'after {name} #{i}\t: {x.size()}')

        x = self.avgpool(x)
        #print(f'after avgpool\t\t: {x.size()}')

        x = torch.flatten(x, start_dim=1)
        #print(f'after flatten\t\t: {x.size()}')

        x = self.fc(x)
        #print(f'after fc\t\t: {x.size()}')

        return x

In [9]:
# Codeblock 8
# End-to-end shape check: push one random 3-channel 224x224 image through the network.
# NOTE: the per-step trace shown in the output is printed only when the print lines
# in ResNeXt.forward are uncommented.
resnext = ResNeXt()
x = torch.randn(1, 3, 224, 224)

out = resnext(x)

original		: torch.Size([1, 3, 224, 224])
after resnext_conv1	: torch.Size([1, 64, 112, 112])
after resnext_maxpool1	: torch.Size([1, 64, 56, 56])
after resnext_conv2 #0	: torch.Size([1, 256, 56, 56])
after resnext_conv2 #1	: torch.Size([1, 256, 56, 56])
after resnext_conv2 #2	: torch.Size([1, 256, 56, 56])
after resnext_conv3 #0	: torch.Size([1, 512, 28, 28])
after resnext_conv3 #1	: torch.Size([1, 512, 28, 28])
after resnext_conv3 #2	: torch.Size([1, 512, 28, 28])
after resnext_conv3 #3	: torch.Size([1, 512, 28, 28])
after resnext_conv4 #0	: torch.Size([1, 1024, 14, 14])
after resnext_conv4 #1	: torch.Size([1, 1024, 14, 14])
after resnext_conv4 #2	: torch.Size([1, 1024, 14, 14])
after resnext_conv4 #3	: torch.Size([1, 1024, 14, 14])
after resnext_conv4 #4	: torch.Size([1, 1024, 14, 14])
after resnext_conv4 #5	: torch.Size([1, 1024, 14, 14])
after resnext_conv5 #0	: torch.Size([1, 2048, 7, 7])
after resnext_conv5 #1	: torch.Size([1, 2048, 7, 7])
after resnext_conv5 #2	: torch.Size([1, 2048, 7, 7])

In [11]:
# Codeblock 9
# Build a fresh model and let torchinfo list every layer with its output shape and
# parameter count for a single 3x224x224 input.
resnext = ResNeXt()
summary(resnext, input_size=(1, 3, 224, 224))

Layer (type:depth-idx)                   Output Shape              Param #
ResNeXt                                  [1, 1000]                 --
├─Conv2d: 1-1                            [1, 64, 112, 112]         9,408
├─BatchNorm2d: 1-2                       [1, 64, 112, 112]         128
├─ReLU: 1-3                              [1, 64, 112, 112]         --
├─MaxPool2d: 1-4                         [1, 64, 56, 56]           --
├─ModuleList: 1-5                        --                        --
│    └─Block: 2-1                        [1, 256, 56, 56]          --
│    │    └─Conv2d: 3-1                  [1, 256, 56, 56]          16,384
│    │    └─BatchNorm2d: 3-2             [1, 256, 56, 56]          512
│    │    └─Conv2d: 3-3                  [1, 128, 56, 56]          8,192
│    │    └─BatchNorm2d: 3-4             [1, 128, 56, 56]          256
│    │    └─ReLU: 3-5                    [1, 128, 56, 56]          --
│    │    └─Conv2d: 3-6                  [1, 128, 56, 56]          4,608