In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:
class Block(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):

        super(Block, self).__init__()
        
        self.stride1 = stride
        self.inChannels = in_channels
        self.outChannels = out_channels
        
        self.conv1 = nn.Conv2d(in_channels, out_channels, 1, stride =stride, padding =0, bias= False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu1 = nn.ReLU()
        
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride = 1, padding =1, bias= False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        
        self.conv3 = nn.Conv2d(out_channels, out_channels * 4, 1, stride = 1, padding =0, bias= False)
        self.bn3 = nn.BatchNorm2d(out_channels * 4)
        self.relu2 = nn.ReLU()
        
        
        
        
        self.convSkip = nn.Conv2d(in_channels, out_channels * 4, 1, stride =stride, padding = 0, bias= False)
        self.bn4 = nn.BatchNorm2d(out_channels * 4)
        

    def forward(self, x):

        identity = x.clone()
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.conv3(x)
        x = self.bn3(x)
        
        if self.stride1 != 1 or self.inChannels != self.outChannels :
            identity = self.convSkip(identity)
            identity = self.bn4(identity)
        
        x += identity
        x = self.relu(x)
        return x

In [4]:
class ResNet(nn.Module):
    def __init__(self, Block, Architecture_number =50, classes=10):

        super(ResNet, self).__init__()
        
        if Architecture_number == 50 :
            identity_layers_stage2 = 3
            identity_layers_stage3 = 4
            identity_layers_stage4 = 6
            identity_layers_stage5 = 3
        
        elif Architecture_number == 101 :
            identity_layers_stage2 = 3
            identity_layers_stage3 = 4
            identity_layers_stage4 = 23
            identity_layers_stage5 = 3
        
        elif Architecture_number == 152 :
            identity_layers_stage2 = 3
            identity_layers_stage3 = 8
            identity_layers_stage4 = 36
            identity_layers_stage5 = 3
        
        else :
            raise Exception("Architecture_number must be one of these numbers: 50 , 101 , 152 ")
        
                
        
        
        # Stage 1
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        # Stage 2
        self.stage2 = self.stage(Block, channels = [64 ,64], num_blocks = identity_layers_stage2, stride = 2)
        
        # Stage 3
        self.stage3 = self.stage(Block, channels = [256 ,128], num_blocks = identity_layers_stage3, stride = 2)
        
        # Stage 4
        self.stage4 = self.stage(Block, channels = [512 ,256], num_blocks = identity_layers_stage4, stride = 2)
        
        # Stage 5
        self.stage5 = self.stage(Block, channels = [1024 ,512], num_blocks = identity_layers_stage5, stride = 2)
        
        # final stage
        self.avgpool = nn.AvgPool2d(kernel_size=4, stride=1)
        self.fc = nn.Linear(2048, classes)


    def stage(self, Block, channels, num_blocks, stride):
        
        first_block = Block(channels[0], channels[1], stride)
        other_blocks = [Block(channels[1] * 4, channels[1]) for _ in range(1, num_blocks)]
        
        return nn.Sequential(first_block, *other_blocks)
        
    
    
    
    
    def forward(self, x):


        x = self.conv1(x)

        x = self.bn1(x)

        x = self.relu(x)

        x = self.maxpool(x)


        x = self.stage2(x)

        x = self.stage3(x)
        
        x = self.stage4(x)
        
        x = self.stage5(x)


        x = self.avgpool(x)


        x = x.view(-1, self.fc.in_features)

        x = self.fc(x)


        return x

In [5]:
# define number of layers of ResNet [50 , 101,  152]
Architecture_number = 50
model = ResNet(Block, Architecture_number = Architecture_number, classes=10)

In [6]:
print(model)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (stage2): Sequential(
    (0): Block(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu1): ReLU()
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu2): ReLU()
      (convSkip): Conv2d(64, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
     