<a href="https://colab.research.google.com/github/anirudh-g/Model_Implementations_from_scratch/blob/main/ResNets_from_Scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
!pip install -Uqq torchinfo
!pip install -Uqq timm

In [6]:
import torch
from torch import nn
from torchinfo import summary
import timm

In [66]:
class BasicBlock(nn.Module):

  expansion=1 # expansion is 1 as there is no expansion factor is basic block

  def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, downsample=None):

    super().__init__()

    self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, stride=1, bias=False) # 3x3 Conv Layer
    self.bn1 = nn.BatchNorm2d(num_features = out_channels)
    self.relu = nn.ReLU(inplace=True)
    self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1, stride=1, bias=False)
    self.bn2 = nn.BatchNorm2d(num_features=out_channels)
    self.downsample = downsample

  def forward(self, x):

    identity = x
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)

    x = self.conv2(x)
    x = self.bn2(x)

    if self.downsample is not None:
      identity = self.downsample(x)

    x += identity

    return (self.relu(x))


In [67]:
class BottleNeckBlock(nn.Module):

  expansion = 4

  def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, downsample=None):

    super().__init__()

    base_width = 64

    width = int(out_channels * (base_width / 64.)) * 1

    self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=width, kernel_size=1, stride=stride, padding=1, bias=False)
    self.bn1 = nn.BatchNorm2d(num_features=width)
    self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, kernel_size=3, stride=stride, padding=1, bias=False)
    self.bn2 = nn.BatchNorm2d(num_features = width)
    self.conv3 = nn.Conv2d(in_channels=width, out_channels=width * self.expansion , kernel_size=1, stride=stride, padding=1, bias=False)
    self.bn3 = nn.BatchNorm2d(num_features = width * self.expansion)
    self.relu = nn.ReLU(inplace=True)
    self.downsample = downsample

  def forward(self, x):

    identity = x

    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.conv2(x)
    x = self.bn2(x)
    x = self.relu(x)
    x = self.conv3(x)
    x = self.bn3(x)
    if self.downsample is not None:
      identity = self.downsample(x)
    x+= identity

    return (self.relu(x))

In [68]:
class ResNet(nn.Module):

  def __init__(self, block, layers, num_classes):

    super().__init__()

    self.in_channels = 64

    # resnet stem
    self.conv1 = nn.Conv2d(in_channels=3, out_channels=self.in_channels, kernel_size=7, stride=2, padding=3, bias=False)
    self.bn1 = nn.BatchNorm2d(num_features = self.in_channels)
    self.relu = nn.ReLU(inplace=True)
    self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    #res-blocks
    self.layer1 = self._make_layer(block, 64, layers[0])
    self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
    self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
    self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

    #classifier block
    self.adppool = nn.AdaptiveAvgPool2d((2,2))
    self.classifier = nn.Linear(in_features=512 * block.expansion, out_features = num_classes)


  def _make_layer(self, block, out_channels, blocks, stride=1):

    downsample = None


    if stride!=1 or self.in_channels != out_channels * block.expansion:

      downsample = nn.Sequential(
          nn.Conv2d(in_channels=self.in_channels, out_channels=out_channels * block.expansion, kernel_size=1, stride=stride, bias=False),
          nn.BatchNorm2d(num_features=out_channels * block.expansion)
    )

    layers=[]

    layers.append(block(self.in_channels, out_channels, stride, downsample))

    self.in_channels = out_channels * block.expansion

    for i in range(1, blocks):
      layers.append(block(self.in_channels, out_channels))

    return nn.Sequential(*layers)



  def forward(self, x):

    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.maxpool(x)
    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)
    x = self.adppool(x)
    x = torch.flatten(x, 1)

    return self.classifier(x)

In [70]:
resnet34 = ResNet(block=BasicBlock, layers=[3,4,6,3], num_classes=10)

In [71]:
summary(resnet34, input_shape=(1, 3, 224, 224))

Layer (type:depth-idx)                   Param #
ResNet                                   --
├─Conv2d: 1-1                            9,408
├─BatchNorm2d: 1-2                       128
├─ReLU: 1-3                              --
├─MaxPool2d: 1-4                         --
├─Sequential: 1-5                        --
│    └─BasicBlock: 2-1                   --
│    │    └─Conv2d: 3-1                  36,864
│    │    └─BatchNorm2d: 3-2             128
│    │    └─ReLU: 3-3                    --
│    │    └─Conv2d: 3-4                  36,864
│    │    └─BatchNorm2d: 3-5             128
│    └─BasicBlock: 2-2                   --
│    │    └─Conv2d: 3-6                  36,864
│    │    └─BatchNorm2d: 3-7             128
│    │    └─ReLU: 3-8                    --
│    │    └─Conv2d: 3-9                  36,864
│    │    └─BatchNorm2d: 3-10            128
│    └─BasicBlock: 2-3                   --
│    │    └─Conv2d: 3-11                 36,864
│    │    └─BatchNorm2d: 3-12            12

In [74]:
resnet50 = ResNet(BottleNeckBlock, layers=[3,4,23,3], num_classes=10)

In [75]:
summary(resnet50, input_shape=(1,3,224,224))

Layer (type:depth-idx)                   Param #
ResNet                                   --
├─Conv2d: 1-1                            9,408
├─BatchNorm2d: 1-2                       128
├─ReLU: 1-3                              --
├─MaxPool2d: 1-4                         --
├─Sequential: 1-5                        --
│    └─BottleNeckBlock: 2-1              --
│    │    └─Conv2d: 3-1                  4,096
│    │    └─BatchNorm2d: 3-2             128
│    │    └─Conv2d: 3-3                  36,864
│    │    └─BatchNorm2d: 3-4             128
│    │    └─Conv2d: 3-5                  16,384
│    │    └─BatchNorm2d: 3-6             512
│    │    └─ReLU: 3-7                    --
│    └─BottleNeckBlock: 2-2              --
│    │    └─Conv2d: 3-8                  16,384
│    │    └─BatchNorm2d: 3-9             128
│    │    └─Conv2d: 3-10                 36,864
│    │    └─BatchNorm2d: 3-11            128
│    │    └─Conv2d: 3-12                 16,384
│    │    └─BatchNorm2d: 3-13          

References -

- https://dhruvs.space/posts/understanding-resnets/
- https://jarvislabs.ai/blogs/resnet/
- https://github.com/FrancescoSaverioZuppichini/ResNet