In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [2]:
# Define basic block ---- used for resnet18, resnet34
class BasicBlock(nn.Module):
    """Two-layer 3x3 residual block (the ResNet-18 / ResNet-34 building block).

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by both 3x3 convolutions.
        stride: Stride of the first 3x3 convolution (and of the projection
            shortcut when ``downsample`` is set).
        downsample: When truthy, the skip connection is a 1x1 strided
            convolution followed by batch norm so that its shape matches the
            main branch. When falsy the input is added unchanged, so the
            caller must ensure ``in_channels == out_channels`` and
            ``stride == 1``.
    """

    def __init__(self, in_channels, out_channels, stride, downsample):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.downsample = downsample

        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride, padding=1)
        # Each convolution gets its own BatchNorm: sharing one instance would
        # mix the running statistics and affine parameters of different
        # activations.
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

        # The projection must be created here (not in forward) so that its
        # weights are registered as parameters, trained by the optimizer and
        # moved together with the module by .to()/.cuda().
        if downsample:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        """Apply the residual block to ``x`` and return the activated sum."""
        identity = self.shortcut(x)

        output = self.conv1(x)
        output = self.bn1(output)
        output = self.relu(output)

        output = self.conv2(output)
        output = self.bn2(output)

        output = output + identity
        return self.relu(output)

In [5]:
# Define Bottleneck Block (used for resnet50 and higher)
class BottleneckBlock(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block (ResNet-50 and deeper).

    Computes ``relu(bn3(conv3(relu(bn2(conv2(relu(bn1(conv1(x)))))))) +
    shortcut(x))``.

    Args:
        in_channels: Number of channels of the input feature map.
        mid_channels: Number of channels inside the bottleneck (output of the
            first 1x1 and of the 3x3 convolution).
        out_channels: Number of channels produced by the final 1x1
            convolution.
        stride: Stride of the first 1x1 convolution (and of the projection
            shortcut when ``downsample`` is set).
        downsample: When truthy, the skip connection is a 1x1 strided
            convolution followed by batch norm so that its shape matches the
            main branch. When falsy the input is added unchanged, so the
            caller must ensure ``in_channels == out_channels`` and
            ``stride == 1``.
    """

    def __init__(self, in_channels, mid_channels, out_channels, stride, downsample):
        super().__init__()
        self.in_channels = in_channels
        self.mid_channels = mid_channels
        self.out_channels = out_channels
        self.stride = stride
        self.downsample = downsample

        self.conv1 = nn.Conv2d(in_channels, mid_channels, 1, stride=stride)
        self.conv2 = nn.Conv2d(mid_channels, mid_channels, 3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(mid_channels, out_channels, 1, stride=1)

        # One BatchNorm per convolution: sharing an instance between layers
        # would mix the running statistics and affine parameters of
        # different activations.
        self.bn1 = nn.BatchNorm2d(mid_channels)
        self.bn2 = nn.BatchNorm2d(mid_channels)
        self.bn3 = nn.BatchNorm2d(out_channels)

        self.relu = nn.ReLU()

        # The projection must be created here (not in forward) so that its
        # weights are registered as parameters, trained by the optimizer,
        # moved together with the module by .to()/.cuda(), and stay fixed
        # between forward calls.
        if downsample:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        """Apply the bottleneck block to ``x`` and return the activated sum."""
        identity = self.shortcut(x)

        output = self.conv1(x)
        output = self.bn1(output)
        output = self.relu(output)

        output = self.conv2(output)
        output = self.bn2(output)
        output = self.relu(output)

        output = self.conv3(output)
        output = self.bn3(output)

        output = output + identity
        return self.relu(output)

In [None]:
class ResNet50(nn.Module):
    """ResNet-50 classifier assembled from ``BottleneckBlock`` stages.

    Layout: 7x7 stride-2 stem convolution, 3x3 stride-2 max pooling, then
    3 + 4 + 6 + 3 bottleneck blocks (256, 512, 1024 and 2048 output
    channels), global average pooling, a linear classifier and a softmax.

    Args:
        num_classes: Size of the output layer. Defaults to 1000.

    The forward pass takes a batch of 3-channel images and returns a
    ``(batch, num_classes)`` tensor of class probabilities (softmax is
    applied inside the model, so do not pair it with a loss that applies
    softmax again, such as ``nn.CrossEntropyLoss``).
    """

    def __init__(self, num_classes=1000):
        super().__init__()

        # Stem.
        self.conv1 = nn.Conv2d(3, 64, 7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        # padding=1 gives the reference 56x56 map for a 224x224 input
        # (without it the map is 55x55).
        self.pool1 = nn.MaxPool2d(3, stride=2, padding=1)

        # Stage 1: 3 blocks, 256 channels (projection only to widen channels).
        self.layer1 = BottleneckBlock(64, 64, 256, 1, True)
        self.layer2 = BottleneckBlock(256, 64, 256, 1, False)
        self.layer3 = BottleneckBlock(256, 64, 256, 1, False)
        # Stage 2: 4 blocks, 512 channels, first block halves the resolution.
        self.layer4 = BottleneckBlock(256, 128, 512, 2, True)
        self.layer5 = BottleneckBlock(512, 128, 512, 1, False)
        self.layer6 = BottleneckBlock(512, 128, 512, 1, False)
        self.layer7 = BottleneckBlock(512, 128, 512, 1, False)
        # Stage 3: 6 blocks, 1024 channels.
        self.layer8 = BottleneckBlock(512, 256, 1024, 2, True)
        self.layer9 = BottleneckBlock(1024, 256, 1024, 1, False)
        self.layer10 = BottleneckBlock(1024, 256, 1024, 1, False)
        self.layer11 = BottleneckBlock(1024, 256, 1024, 1, False)
        self.layer12 = BottleneckBlock(1024, 256, 1024, 1, False)
        self.layer13 = BottleneckBlock(1024, 256, 1024, 1, False)
        # Stage 4: 3 blocks, 2048 channels.
        self.layer14 = BottleneckBlock(1024, 512, 2048, 2, True)
        self.layer15 = BottleneckBlock(2048, 512, 2048, 1, False)
        self.layer16 = BottleneckBlock(2048, 512, 2048, 1, False)

        # Global average pooling; the adaptive variant yields a 1x1 map for
        # any input resolution (identical to a 7x7 average pool on a 7x7 map).
        self.pool2 = nn.AdaptiveAvgPool2d((1, 1))

        self.fc1 = nn.Linear(2048, num_classes)

        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return class probabilities of shape ``(batch, num_classes)``."""
        # Stem.
        output = self.conv1(x)
        output = self.bn1(output)
        output = self.relu(output)
        output = self.pool1(output)

        # Residual stages.
        output = self.layer1(output)
        output = self.layer2(output)
        output = self.layer3(output)
        output = self.layer4(output)
        output = self.layer5(output)
        output = self.layer6(output)
        output = self.layer7(output)
        output = self.layer8(output)
        output = self.layer9(output)
        output = self.layer10(output)
        output = self.layer11(output)
        output = self.layer12(output)
        output = self.layer13(output)
        output = self.layer14(output)
        output = self.layer15(output)
        output = self.layer16(output)

        # Classifier head. The pooled map is (batch, 2048, 1, 1); it has to
        # be flattened to (batch, 2048) before the linear layer.
        output = self.pool2(output)
        output = torch.flatten(output, 1)
        output = self.fc1(output)
        output = self.softmax(output)

        return output
        