In [3]:

import torch
import torch.nn as nn

In [110]:
class Residual_Block(nn.Module):
    def __init__(self, in_channels, intermediate_channels, stride, factor):
        super(Residual_Block, self).__init__()
        self.factor = factor
        self.identity_downsampling_layer = None
        self.final_out_channels = self.factor * intermediate_channels
        self.conv1 = nn.Conv2d(
            in_channels, 
            intermediate_channels, 
            kernel_size=1, 
            stride=1, 
            padding=0, 
            bias=False)
        self.batchNorm1 = nn.BatchNorm2d(intermediate_channels)
        
        self.conv2 = nn.Conv2d(
            intermediate_channels,
            intermediate_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias = False
        )
        self.batchNorm2 = nn.BatchNorm2d(intermediate_channels)
        
        self.conv3 = nn.Conv2d(
            intermediate_channels,
            out_channels=self.final_out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False
        )
        self.batchNorm3 = nn.BatchNorm2d(self.final_out_channels)
        self.relu = nn.ReLU()
        
        if in_channels != self.final_out_channels:
            print('down')
            self.identity_downsampling_layer = nn.Sequential(
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=self.final_out_channels,
                    kernel_size=1,
                    stride=stride,
                    bias=False
                ),
                nn.BatchNorm2d(self.final_out_channels)
            )
    def forward(self, x):
        identity = x.clone() # copying the tensor to the identity var
        # print('\nInside Res Block: ')
        # print(identity.shape)
        
        x = self.conv1(x)
        x = self.batchNorm1(x)
        x = self.relu(x)
        print('After 1 convolutions: ', x.shape)
        
        x = self.conv2(x)
        x = self.batchNorm2(x)
        x = self.relu(x)
        print('After 2 convolutions: ', x.shape)
        
        x = self.conv3(x)
        x = self.batchNorm3(x)
        x = self.relu(x)
        print('After 3 convolutions: ', x.shape)
        
        if self.identity_downsampling_layer:
            print('Here!')
            identity = self.identity_downsampling_layer(identity)
            print('After Downsampling: ', identity.shape)
        
        print(x.shape, identity.shape)
        assert x.shape == identity.shape
        x = x + identity
        x = self.relu(x)
        print('At end of Res Block: ', x.shape)
        return x

In [111]:
class Resnet(nn.Module):
    def __init__(self, Residual_Block, num_layers, intermediate_channels, num_classes, image_channels):
        super(Resnet, self).__init__()
        self.res_block_inchannel = 64
        self.conv1 = nn.Conv2d(image_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.batchNorm1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        self.layer1 = self._make_layer(
            Residual_Block, num_layers[0], intermediate_channels[0], stride=1,
        )
        
        self.layer2 = self._make_layer(
            Residual_Block, num_layers[1], intermediate_channels[1], stride=2,
        )
        
        self.layer3 = self._make_layer(
            Residual_Block, num_layers[2], intermediate_channels[2], stride=2,
        )
        
        self.layer4 = self._make_layer(
            Residual_Block, num_layers[3], intermediate_channels[3], stride=2,
        )
        
        self.avg_pool = nn.AdaptiveAvgPool2d((1,1))
        self.fc = nn.Linear(intermediate_channels[-1] * 4, num_classes)
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.batchNorm1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        print(x.shape) # shape: (56,56,64)
        
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        print(x.shape)
        x = self.avg_pool(x)
        print('After average pooling: ', x.shape)
        x = x.reshape(x.shape[0], -1)
        x = self.fc(x)
        return x

    def _make_layer(self, Residual_Block, num_layers, intermediate_channels, stride):
        layers = []
        
        layers.append(
            Residual_Block(self.res_block_inchannel, intermediate_channels, stride=stride, factor=4)
        )
        
        self.res_block_inchannel = intermediate_channels * 4
        # print(next_channel)
        
        for layer in range(num_layers - 1):
            layers.append(
                Residual_Block(self.res_block_inchannel, intermediate_channels, stride=1, factor=4)
            )
        
        return nn.Sequential(*layers)

In [114]:
def test():
    net = Resnet(Residual_Block,[3,4,6,3], [64, 128, 256, 512], 1000, 3)
    print(net)
    y = net(torch.randn(4, 3, 224, 224)).to("cuda")
    print(y.size())
    print(y)

test()

down
down
down
down
Resnet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (batchNorm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU()
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Residual_Block(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (batchNorm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (batchNorm2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (batchNorm3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
      (identity_downsampling_layer): Sequential(
        (0): Co