<a href="https://colab.research.google.com/github/DimpleB0501/Assignment_cifar10_resnet/blob/main/backup/ResNet50architecture_implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn

In [None]:
class block(nn.Module):
  """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions plus a skip connection.

  Each convolution is followed by batch normalisation. A ReLU is applied after
  the first two conv/BN pairs and once more after the skip connection has been
  added, so the block output is ``relu(f(x) + shortcut(x))``.

  Args:
    in_channels: number of channels of the incoming tensor.
    out_channels: width of the two inner convolutions; the block emits
      ``out_channels * self.expansion`` channels.
    identity_downsample: optional module applied to the skip path. It is needed
      whenever the main path changes the spatial size (``stride != 1``) or the
      channel count, so that both tensors can be added.
    stride: stride of the 3x3 convolution.
  """

  def __init__(self, in_channels, out_channels, identity_downsample=None, stride=1):
    super(block, self).__init__()
    # The last 1x1 convolution widens the tensor to 4x the inner width.
    self.expansion = 4

    self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
    self.bn1 = nn.BatchNorm2d(out_channels)

    self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1)
    self.bn2 = nn.BatchNorm2d(out_channels)

    self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, stride=1, padding=0)
    self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)

    self.relu = nn.ReLU()
    self.identity_downsample = identity_downsample
    self.stride = stride

  def forward(self, x):
    """Return ``relu(f(x) + shortcut(x))`` for the input tensor ``x``."""
    # No copy is needed: ``x`` is rebound below and never modified in place.
    identity = x

    # The non-linearity between the convolutions was missing; without it the
    # three conv/BN pairs form a purely affine stack at inference time.
    x = self.relu(self.bn1(self.conv1(x)))
    x = self.relu(self.bn2(self.conv2(x)))

    # No ReLU here: the activation is applied after the residual addition.
    x = self.bn3(self.conv3(x))

    if self.identity_downsample is not None:
      identity = self.identity_downsample(identity)

    x = x + identity
    return self.relu(x)

In [None]:
class ResNet(nn.Module):
  """Three-stage residual network with a 7x7 stem and a linear classifier.

  Args:
    block: residual block class; it is called as
      ``block(in_channels, out_channels, identity_downsample, stride)``.
    layers: number of blocks per stage; entries 0, 1 and 2 are used.
    image_channels: channel count of the input images (3 for RGB).
    num_classes: size of the classifier output (10 for CIFAR-10).

  Only three stages (inner widths 64, 128, 256) are built, so the classifier
  receives ``256 * 4`` features.
  """

  def __init__(self, block, layers, image_channels, num_classes):
    super().__init__()
    # Running channel count; _make_layer reads and updates it stage by stage.
    self.in_channels = 64

    # Stem: strided 7x7 convolution, batch norm, ReLU, then max pooling.
    self.conv1 = nn.Conv2d(image_channels, 64, kernel_size=7, stride=2, padding=3)
    self.bn1 = nn.BatchNorm2d(64)
    self.relu = nn.ReLU()
    self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    # Residual stages; the second and third use a stride of 2.
    self.layer1 = self._make_layer(block, layers[0], out_channels=64, stride=1)
    self.layer2 = self._make_layer(block, layers[1], out_channels=128, stride=2)
    self.layer3 = self._make_layer(block, layers[2], out_channels=256, stride=2)

    # Head: global average pooling followed by the classifier.
    self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
    self.fc = nn.Linear(256 * 4, num_classes)  # last stage width times expansion (4)

  def forward(self, x):
    """Run the stem, the three stages and the classifier head on ``x``."""
    pipeline = (
      self.conv1, self.bn1, self.relu, self.maxpool,
      self.layer1, self.layer2, self.layer3,
      self.avgpool,
    )
    for stage in pipeline:
      x = stage(x)

    # Collapse everything except the batch dimension before the linear layer.
    flat = x.reshape(x.size(0), -1)
    return self.fc(flat)

  def _make_layer(self, block, num_residual_blocks, out_channels, stride):
    """Build one stage of ``num_residual_blocks`` blocks as an ``nn.Sequential``.

    Only the first block receives ``stride`` and, when required, a projection
    on its skip path. ``self.in_channels`` is updated to the stage's output
    width (``out_channels * 4``).
    """
    expanded = out_channels * 4

    # A projection is needed when the first block changes the spatial size
    # or the channel count, so the skip tensor matches the main path.
    shortcut = None
    if not (stride == 1 and self.in_channels == expanded):
      shortcut = nn.Sequential(
        nn.Conv2d(self.in_channels, expanded, kernel_size=1, stride=stride),
        nn.BatchNorm2d(expanded),
      )

    stage = [block(self.in_channels, out_channels, shortcut, stride)]
    self.in_channels = expanded

    # The remaining blocks keep the shape, so they use the defaults.
    stage.extend(
      block(self.in_channels, out_channels)
      for _ in range(num_residual_blocks - 1)
    )
    return nn.Sequential(*stage)

In [None]:
def ResNet50(img_channels, num_classes=10):
  """Build the network from ``block`` with three stages of three blocks each.

  Args:
    img_channels: channel count of the input images.
    num_classes: size of the classifier output; defaults to 10.
  """
  blocks_per_stage = [3, 3, 3]
  model = ResNet(block, blocks_per_stage, img_channels, num_classes)
  return model

In [None]:
from functools import reduce
def pytorch_count_params(model):
  """Count the trainable parameters of a PyTorch model.

  Args:
    model: any object exposing ``parameters()`` (e.g. an ``nn.Module``).

  Returns:
    The total number of scalar elements across all parameters that have
    ``requires_grad`` set; 0 when the model has none.
  """
  # numel() also handles 0-dim (scalar) parameters, whose size() is empty and
  # would make a reduce() without an initial value raise. Frozen parameters
  # are skipped so the result matches the documented "trainable" count.
  return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [None]:
from torchsummary import summary

def test():
  """Build the network and print its trainable-parameter count.

  The model is placed on the GPU when one is available and on the CPU
  otherwise, so the check also runs on machines without CUDA.
  """
  # Hard-coding "cuda" raises on CPU-only machines; the count does not depend
  # on where the model lives.
  device = "cuda" if torch.cuda.is_available() else "cpu"
  net = ResNet50(img_channels=3, num_classes=10).to(device)
  print("Total number of trainable parameters in ResNet50: ", pytorch_count_params(net))

In [None]:
# Run the check: builds the network and prints its parameter count.
test()

Total number of trainable parameters in ResNet50:  4931850
