# Model File

This notebook builds a ResNet-50 CNN, following the architecture described here: https://medium.com/@sharma.tanish096/detailed-explanation-of-residual-network-resnet50-cnn-model-106e0ab9fa9e

In [67]:
import torch 
from torch import nn

torch.__version__

'2.2.2'

In [68]:
import torchvision

torchvision.__version__

'0.17.2'

In [69]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
device

'cpu'

In [70]:
from importnb import Notebook
with Notebook():
    import dataloader

In [71]:
# Pull the training dataset and the batch size out of the dataloader notebook.
train_data, BATCH_SIZE = dataloader.train_data, dataloader.BATCH_SIZE


In [72]:
# Index the dataset once: every train_data[0] goes through the dataset's __getitem__
# again, so indexing it twice would load (and transform) the same sample twice.
sample = train_data[0]
img, label = sample[0], sample[1]
img.shape

torch.Size([3, 224, 224])

For an explanation of the hyperparameters used below (kernel size, stride, padding, etc.), see the following website:

https://poloclub.github.io/cnn-explainer/

That site demonstrates a TinyVGG CNN rather than ResNet-50, but its explanation of the hyperparameters applies here as well.

In [73]:
import torch
from torch import nn

class ResidualBlock(nn.Module):
    """
    Bottleneck residual block, the building unit of ResNet-50.

    The main path is three convolutions, each followed by batch norm:
      1. 1x1 conv, in_channels -> mid_channels: shrinks the channel count (and applies `stride`).
      2. 3x3 conv, mid_channels -> mid_channels: extracts spatial features (padding=1 keeps H and W).
      3. 1x1 conv, mid_channels -> out_channels: expands the channel count back up.
    ReLU follows the first two; the third is added to the shortcut before the final ReLU.

    The shortcut (skip connection) carries the block input around the main path, so the
    block only has to learn a residual on top of it. It also gives gradients a direct path
    back to earlier layers. This is how ResNet mitigates the vanishing-gradient problem,
    where gradients shrink as they are propagated back through many layers and the early
    layers stop learning.

    Args:
        in_channels: number of channels of the input tensor.
        mid_channels: bottleneck width used by the 1x1 and 3x3 convolutions.
        out_channels: number of channels of the output tensor.
        stride: stride of the first 1x1 convolution and of the projection shortcut;
            stride=2 halves height and width (rounding up).
    """

    def __init__(self, in_channels, mid_channels, out_channels, stride=1):
        super().__init__()

        self.conv_block_1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1, stride=stride, bias=False)
        self.bn_block_1 = nn.BatchNorm2d(mid_channels)
        self.conv_block_2 = nn.Conv2d(mid_channels, mid_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn_block_2 = nn.BatchNorm2d(mid_channels)
        self.conv_block_3 = nn.Conv2d(mid_channels, out_channels, kernel_size=1, stride=1, bias=False)
        self.bn_block_3 = nn.BatchNorm2d(out_channels)

        # Shortcut connection: identity by default (an empty nn.Sequential returns its input).
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            # The spatial size or the channel count changes, so the input cannot be added as-is:
            # a strided 1x1 conv projects it to the output shape, then it is batch-normalized
            # like the main path.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        # Compute the shortcut from the block input (identity or 1x1 projection + BN).
        identity = self.shortcut(x)
        out = torch.relu(self.bn_block_1(self.conv_block_1(x)))    # 1x1 reduce + BN + ReLU
        out = torch.relu(self.bn_block_2(self.conv_block_2(out)))  # 3x3 conv + BN + ReLU
        out = self.bn_block_3(self.conv_block_3(out))              # 1x1 expand + BN (ReLU after the add)
        out += identity                                            # residual connection
        return torch.relu(out)

In [74]:
from torch import nn
import torch

class RestNet(nn.Module):
    """
    ResNet image classifier built from bottleneck ResidualBlocks (ResNet-50 layout by default).

    Layout:
    - A stem: 7x7 stride-2 convolution, then a stride-2 max-pool.
    - Four residual stages.
    - Global average pooling and a linear layer.

    The stem divides height and width by 4, and each of the last three stages halves them
    again. A 224x224 input therefore reaches the classifier as 2048 feature maps of size 7x7.

    Args:
        input_shape: number of input channels (3 for the 3-channel images used here).
        output_shape: number of classes, i.e. the length of the returned logits vector.
        blocks_per_stage: number of residual blocks stacked in each of the four stages.
            The default (3, 4, 6, 3) is ResNet-50. (3, 4, 23, 3) gives the ResNet-101
            layout and (3, 8, 36, 3) the ResNet-152 layout.

    Raises:
        ValueError: if blocks_per_stage does not have exactly four entries, or any entry is < 1.
    """

    def __init__(self, input_shape, output_shape, *, blocks_per_stage=(3, 4, 6, 3)):
        super().__init__()

        blocks_per_stage = tuple(blocks_per_stage)
        if len(blocks_per_stage) != 4:
            raise ValueError(
                f"blocks_per_stage must have 4 entries, got {len(blocks_per_stage)}"
            )
        n2, n3, n4, n5 = blocks_per_stage

        # Stem: [B, C, H, W] -> [B, 64, ~H/4, ~W/4]
        self.conv_block_1 = nn.Sequential(
            nn.Conv2d(input_shape, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        # Four residual stages: (channels in, bottleneck width, channels out).
        # Stage 2 keeps the spatial size; stages 3-5 halve it in their first block.
        # Earlier stages tend to pick up simple patterns (edges, textures); later
        # stages build more abstract features for the classifier head.
        self.conv_block_2 = self._make_stage(64, 64, 256, num_blocks=n2, stride=1)
        self.conv_block_3 = self._make_stage(256, 128, 512, num_blocks=n3, stride=2)
        self.conv_block_4 = self._make_stage(512, 256, 1024, num_blocks=n4, stride=2)
        self.conv_block_5 = self._make_stage(1024, 512, 2048, num_blocks=n5, stride=2)

        # Classifier head: average each of the 2048 feature maps down to a single value,
        # then map the 2048-d vector to class logits.
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(2048, output_shape),
        )

    def _make_stage(self, in_channels, mid_channels, out_channels, num_blocks, stride):
        """
        Stack `num_blocks` ResidualBlocks into one nn.Sequential stage.

        Only the first block changes the channel count (in_channels -> out_channels) and
        applies `stride`. The remaining blocks map out_channels -> out_channels with stride 1.

        Raises:
            ValueError: if num_blocks < 1.
        """
        if num_blocks < 1:
            raise ValueError(f"num_blocks must be at least 1, got {num_blocks}")

        blocks = [ResidualBlock(in_channels, mid_channels, out_channels, stride)]
        blocks.extend(
            ResidualBlock(out_channels, mid_channels, out_channels, 1)
            for _ in range(num_blocks - 1)
        )
        return nn.Sequential(*blocks)

    def forward(self, x):
        # x: [batch, channels, height, width]; shapes below are for a 224x224 input.
        x = self.conv_block_1(x)   # [B, 64, 56, 56]
        x = self.conv_block_2(x)   # [B, 256, 56, 56]
        x = self.conv_block_3(x)   # [B, 512, 28, 28]
        x = self.conv_block_4(x)   # [B, 1024, 14, 14]
        x = self.conv_block_5(x)   # [B, 2048, 7, 7]
        return self.classifier(x)  # [B, output_shape]

In [75]:
# Build the network with 3 input channels (matching img.shape[0]) and 1000 output logits,
# then move it to the selected device (Module.to returns the module itself).
# NOTE(review): 1000 is the ImageNet-1k class count; confirm it matches the number of
# classes in train_data.
model_0 = RestNet(input_shape=3, output_shape=1000)
model_0 = model_0.to(device)
model_0

RestNet(
  (conv_block_1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
  (conv_block_2): Sequential(
    (0): ResidualBlock(
      (conv_block_1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_block_1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_block_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_block_2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_block_3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_block_3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential(
        (0): Conv2d(64, 256,

In [76]:
# Inspect the freshly initialized parameters and buffers, keyed by module path
# (e.g. 'conv_block_1.0.weight').
model_0.state_dict()

OrderedDict([('conv_block_1.0.weight',
              tensor([[[[ 0.0158, -0.0677,  0.0734,  ...,  0.0285,  0.0069, -0.0811],
                        [ 0.0667, -0.0516, -0.0312,  ...,  0.0067, -0.0015,  0.0808],
                        [ 0.0436, -0.0539, -0.0522,  ..., -0.0075, -0.0044,  0.0127],
                        ...,
                        [ 0.0492, -0.0535,  0.0123,  ...,  0.0118, -0.0011, -0.0553],
                        [-0.0386, -0.0417, -0.0310,  ...,  0.0383,  0.0132, -0.0735],
                        [ 0.0022,  0.0684,  0.0397,  ...,  0.0342,  0.0218, -0.0561]],
              
                       [[ 0.0283,  0.0718, -0.0378,  ...,  0.0595, -0.0140,  0.0109],
                        [ 0.0537,  0.0643, -0.0800,  ..., -0.0317,  0.0602, -0.0436],
                        [ 0.0250,  0.0188,  0.0433,  ...,  0.0251,  0.0118, -0.0218],
                        ...,
                        [-0.0216, -0.0413,  0.0517,  ..., -0.0382, -0.0738, -0.0036],
                        [-0

In [77]:
img.unsqueeze(0).shape # [Batch Size, Color Channels, Height, Width] -- PyTorch images are NCHW

torch.Size([1, 3, 224, 224])

In [78]:
# Sanity-check forward pass on a single image. The input is moved to the model's device
# (otherwise this fails when device == "cuda"), and inference_mode avoids building an
# autograd graph for a pass we never backpropagate through.
# NOTE(review): model_0 is still in training mode here, so BatchNorm uses batch statistics
# and updates its running averages; call model_0.eval() first if that is not intended.
with torch.inference_mode():
    logits = model_0(img.unsqueeze(0).to(device))
logits

tensor([[-6.6449e-01, -3.4246e-01, -7.2592e-01,  5.1968e-01, -4.5169e-01,
          1.0855e-01,  3.0631e-01,  2.8014e-01,  6.8969e-01, -6.8110e-02,
         -1.1489e-01,  1.8599e+00, -4.4219e-01,  1.4373e+00, -1.7206e-01,
         -3.7997e-01, -6.3005e-03,  8.0300e-01,  3.4506e-01, -1.5771e-01,
         -9.0610e-01, -9.4695e-01,  5.1342e-01, -8.9621e-01, -8.5747e-02,
         -2.4313e-01,  1.0420e+00,  2.2185e-01,  3.6755e-01,  6.6655e-01,
         -8.1951e-01, -2.9355e-01, -8.1922e-01,  2.3883e-01, -7.7371e-01,
          2.0954e-01,  1.0212e-01,  4.2349e-01,  8.4449e-02, -8.7156e-01,
         -7.0995e-02, -1.4831e-01, -5.1019e-03,  3.6441e-01, -8.5877e-01,
         -3.2055e-01,  1.1758e+00, -1.9160e-01,  4.8010e-03, -1.4954e-01,
         -1.0939e-01, -9.5061e-01,  1.3404e-02, -3.0256e-01, -2.1446e-01,
         -1.9064e-02, -5.0037e-01,  9.2502e-01,  4.9291e-01, -4.1279e-01,
         -3.2937e-01,  1.8872e-03,  2.3324e-01, -5.0908e-02,  1.5436e+00,
         -3.4537e-01,  3.1201e-01,  4.

In [79]:
try:
  import torchinfo
except:
  !pip install torchinfo
  import torchinfo

from torchinfo import summary
summary(model_0, input_size=[1, 3, 224, 224])

Layer (type:depth-idx)                   Output Shape              Param #
RestNet                                  [1, 1000]                 --
├─Sequential: 1-1                        [1, 64, 56, 56]           --
│    └─Conv2d: 2-1                       [1, 64, 112, 112]         9,408
│    └─BatchNorm2d: 2-2                  [1, 64, 112, 112]         128
│    └─ReLU: 2-3                         [1, 64, 112, 112]         --
│    └─MaxPool2d: 2-4                    [1, 64, 56, 56]           --
├─Sequential: 1-2                        [1, 256, 56, 56]          --
│    └─ResidualBlock: 2-5                [1, 256, 56, 56]          --
│    │    └─Sequential: 3-1              [1, 256, 56, 56]          16,896
│    │    └─Conv2d: 3-2                  [1, 64, 56, 56]           4,096
│    │    └─BatchNorm2d: 3-3             [1, 64, 56, 56]           128
│    │    └─Conv2d: 3-4                  [1, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-5             [1, 64, 56, 56]           12