<a href="https://colab.research.google.com/github/addo561/learning-pytorch/blob/main/Paper-implementations/Resnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install labml-nn

Collecting labml==0.4.168 (from labml-nn)
  Using cached labml-0.4.168-py3-none-any.whl.metadata (7.5 kB)
Using cached labml-0.4.168-py3-none-any.whl (130 kB)
Installing collected packages: labml
  Attempting uninstall: labml
    Found existing installation: labml 0.5.3
    Uninstalling labml-0.5.3:
      Successfully uninstalled labml-0.5.3
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
labml-app 0.5.14 requires labml>=0.5.2, but you have labml 0.4.168 which is incompatible.[0m[31m
[0mSuccessfully installed labml-0.4.168


In [None]:
import torch
import torch.nn  as nn
from labml_helpers.module import Module

# Linear projection for shortcut connections

In [None]:
# Computes the W_s * x term of F(x) + W_s * x: a linear projection of x, used when x and the feature map F(x) do not have the same shape
class shortcut_projection(Module):
  """Projection applied on the shortcut path: a 1x1 convolution followed by batch norm.

  The convolution maps `in_channels` to `out_channels` and uses the given
  `stride`, so the shortcut output can be added to a feature map whose
  channel count and/or spatial size differs from the block input.
  """

  def __init__(self, in_channels: int,out_channels: int,stride: int):
    super().__init__()
    # 1x1 kernel: mixes channels only, the stride handles spatial down-sampling.
    self.conv = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=1,
        stride=stride,
    )
    self.bn = nn.BatchNorm2d(out_channels)

  def forward(self,x: torch.Tensor):
    """Return `bn(conv(x))`."""
    projected = self.conv(x)
    return self.bn(projected)


# Residual Block

In [None]:
class ResidualBlock(Module):
  """Basic residual block: two 3x3 conv + batch-norm layers plus a shortcut.

  Computes `act2(bn2(conv2(act1(bn1(conv1(x))))) + shortcut(x))`.

  Args:
    in_channels: number of channels of the input `x`.
    out_channels: number of channels produced by the block.
    stride: stride of the first convolution (the second always uses 1).
  """

  def __init__(self,in_channels: int,out_channels: int,stride: int):
    super().__init__()
    #1st conv block: changes the channel count and applies the stride
    self.conv1 = nn.Conv2d(in_channels,out_channels,kernel_size=3,stride=stride,padding=1)
    self.bn1 = nn.BatchNorm2d(out_channels)
    self.act1 = nn.ReLU()

    #2nd conv block with stride 1. It consumes the output of conv1, so its
    #input has out_channels channels (not in_channels).
    self.conv2 = nn.Conv2d(out_channels,out_channels,kernel_size=3,stride=1,padding=1)
    self.bn2 =  nn.BatchNorm2d(out_channels)

    #Shortcut connection should be a projection if the stride length is not 1 or if the number of channels change
    if stride!= 1 or in_channels != out_channels:
      self.shortcut =  shortcut_projection(in_channels,out_channels,stride)
    else:
      self.shortcut = nn.Identity()

    self.act2 = nn.ReLU()

  def  forward(self,x: torch.Tensor):
    """Apply the block to `x` and return the activated sum with the shortcut."""
    # Take the shortcut from the block input before x is overwritten below.
    shortcut  = self.shortcut(x)
    x = self.act1(self.bn1(self.conv1(x)))
    # conv2 must be applied to x; bn2 normalises its output.
    x = self.bn2(self.conv2(x))

    return self.act2(x +  shortcut)



In [None]:
class BottleneckResidualBlock(Module):
  """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut.

  The first 1x1 convolution maps `in_channels` to `bottleneck_channels`, the
  3x3 convolution works at that width (and carries the `stride`), and the
  last 1x1 convolution maps to `out_channels`. Each convolution is followed
  by batch norm; ReLU follows the first two and the final sum.

  Args:
    in_channels: number of channels of the input `x`.
    bottleneck_channels: channel width used by the 3x3 convolution.
    out_channels: number of channels produced by the block.
    stride: stride of the 3x3 convolution.
  """

  def __init__(self, in_channels:  int,bottleneck_channels: int,out_channels: int,stride: int):
    super().__init__()
    width = bottleneck_channels

    # 1x1 convolution: in_channels -> bottleneck width
    self.conv1 = nn.Conv2d(in_channels, width, kernel_size=1, stride=1)
    self.bn1 = nn.BatchNorm2d(width)
    self.act1 = nn.ReLU()

    # 3x3 convolution at the bottleneck width; the only layer that uses `stride`
    self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1)
    self.bn2 = nn.BatchNorm2d(width)
    self.act2 = nn.ReLU()

    # 1x1 convolution: bottleneck width -> out_channels
    self.conv3 = nn.Conv2d(width, out_channels, kernel_size=1, stride=1)
    self.bn3 = nn.BatchNorm2d(out_channels)

    # The identity shortcut is only valid when neither the spatial size
    # (stride 1) nor the channel count changes; otherwise project.
    shapes_match = stride == 1 and in_channels == out_channels
    if shapes_match:
      self.shortcut = nn.Identity()
    else:
      self.shortcut = shortcut_projection(in_channels, out_channels, stride=stride)
    self.act3 = nn.ReLU()

  def forward(self,x: torch.Tensor):
    """Apply the block to `x` and return the activated sum with the shortcut."""
    residual = self.shortcut(x)

    out = self.conv1(x)
    out = self.act1(self.bn1(out))

    out = self.conv2(out)
    out = self.act2(self.bn2(out))

    out = self.conv3(out)
    out = self.bn3(out)

    return self.act3(out + residual)





# ResNet model

In [None]:
from typing import Optional,List
# ResNet is just a stack of bottleneck or residual blocks
class  ResnetBase(Module):
  """ResNet feature extractor: a stem convolution followed by stacked residual blocks.

  Args:
    n_blocks: number of blocks in each stage.
    n_channels: output channel count of each stage; must have the same
      length as `n_blocks`.
    bottlenecks: bottleneck width of each stage. When `None`, stages are
      built from `ResidualBlock`; otherwise from `BottleneckResidualBlock`.
    img_channels: number of channels of the input image.
    first_kernel_size: kernel size of the stem convolution.
  """

  def __init__(self,n_blocks: List[int],n_channels: List[int],
               bottlenecks:Optional[List[int]]=None,img_channels: int=3,first_kernel_size =  7):
    super().__init__()
    assert len(n_blocks)==len(n_channels)
    assert bottlenecks is None or len(bottlenecks) == len(n_channels)

    # Stem: stride-2 convolution, padded by half the kernel size, then batch norm.
    stem_width = n_channels[0]
    self.conv = nn.Conv2d(
        img_channels,
        stem_width,
        kernel_size=first_kernel_size,
        stride=2,
        padding=first_kernel_size // 2,
    )
    self.bn = nn.BatchNorm2d(stem_width)

    def make_block(in_ch, stage, out_ch, stride):
      # Pick the block type once, based on whether bottleneck widths were given.
      if bottlenecks is None:
        return ResidualBlock(in_ch, out_ch, stride=stride)
      return BottleneckResidualBlock(in_ch, bottlenecks[stage], out_ch, stride=stride)

    layers = []
    in_ch = stem_width
    for stage, width in enumerate(n_channels):
      # Stride 2 is used only while no block exists yet, i.e. for the very
      # first block of the network; every later stage starts with stride 1.
      first_stride = 1 if layers else 2
      # First block of the stage performs the channel change in_ch -> width.
      layers.append(make_block(in_ch, stage, width, first_stride))
      in_ch = width
      # Remaining blocks of the stage keep the channel count and use stride 1.
      layers.extend(make_block(width, stage, width, 1) for _ in range(n_blocks[stage] - 1))

    self.blocks = nn.Sequential(*layers)

  def forward(self,x:torch.Tensor):
    """Return one value per channel: the mean of the final feature map over all spatial positions."""
    features = self.blocks(self.bn(self.conv(x)))
    batch, channels = features.shape[0], features.shape[1]
    # Collapse every trailing (spatial) dimension into one, then average over it.
    flat = features.view(batch, channels, -1)
    return flat.mean(dim=-1)




In [None]:
from labml import experiment
from labml.configs import option
from labml_nn.experiments.cifar10 import CIFAR10Configs

In [None]:
class Configs(CIFAR10Configs):
  """Experiment configuration: CIFAR10Configs extended with the ResNet hyper-parameters.

  The values below are defaults; `_resnet` forwards them to `ResnetBase`,
  and `main` overrides some of them through `experiment.configs`.
  """
  # Number of blocks in each stage (passed to ResnetBase as `n_blocks`).
  n_blocks: List[int]  = [3,3,3]
  # Output channel count of each stage (passed to ResnetBase as `n_channels`).
  n_channels: List[int] =  [16,32,64]
  # Bottleneck width of each stage; None makes ResnetBase use plain ResidualBlock.
  bottlenecks:  Optional[List[int]] = None
  # Kernel size of the first (stem) convolution of ResnetBase.
  first_kernel_size: int = 3


In [None]:
@option(Configs.model)
def  _resnet(c:Configs):
  """Build the model for `Configs.model`: a ResnetBase followed by a 10-way linear layer.

  The linear layer takes `c.n_channels[-1]` input features, which is the
  per-channel output size of the last ResnetBase stage. The assembled model
  is moved to `c.device` before being returned.
  """
  backbone = ResnetBase(
      c.n_blocks,
      c.n_channels,
      c.bottlenecks,
      img_channels=3,
      first_kernel_size=c.first_kernel_size,
  )
  head = nn.Linear(c.n_channels[-1], 10)
  return nn.Sequential(backbone, head).to(c.device)

In [None]:
def main():
  """Create the 'resnet' experiment, apply the config overrides, and run it."""
  experiment.create(name='resnet',comment='cifar10')
  conf =  Configs()

  # Values that replace the defaults declared on Configs (and its base class).
  overrides = {
      # Non-None bottlenecks switch the model to bottleneck blocks.
      'bottlenecks':[8,16,16],
      'n_blocks':[6,6,6],

      'optimizer.optimizer':'Adam',
      'optimizer.learning_rate':2.5e-4,

      'epochs':500,
      'train_batch_size':256,

      'train_dataset':'cifar10_train_augmented',
      'valid_dataset':'cifar10_valid_no_augment',
  }
  experiment.configs(conf, overrides)

  # Register the model with the experiment before starting it.
  experiment.add_pytorch_models({'model': conf.model})

  with experiment.start():
    conf.run()


# Run the experiment only when this file is executed directly, not when imported.
if __name__ =='__main__':
  main()