## Imports

In [1]:
import torch
from torch import nn

  from .autonotebook import tqdm as notebook_tqdm


## Model

In [2]:
class block(nn.Module):
    """Bottleneck residual unit: 1x1 -> 3x3 -> 1x1 convolutions plus a skip path.

    The first 1x1 convolution maps ``in_channels`` to ``intermediate_channels``,
    the 3x3 convolution (which carries ``stride``) keeps that width, and the
    last 1x1 convolution widens to ``intermediate_channels * 4``. Each
    convolution is followed by batch normalisation; the skip connection is
    added before the final ReLU.

    Args:
        in_channels: Number of channels of the incoming tensor.
        intermediate_channels: Width of the two inner convolutions; the unit
            emits four times as many channels.
        stride: Stride of the 3x3 convolution.
        identity_downsample: Optional module applied to the input before it is
            added to the convolutional branch. When ``None`` the input is
            added unchanged, so it must already match the branch output.
    """

    def __init__(self, in_channels, intermediate_channels, stride=1, identity_downsample=None):
        super().__init__()

        expanded_channels = intermediate_channels * 4

        # 1x1 convolution: in_channels -> intermediate_channels.
        self.conv1 = nn.Conv2d(
            in_channels,
            intermediate_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(intermediate_channels)

        # 3x3 convolution: the only layer in the branch that uses `stride`.
        self.conv2 = nn.Conv2d(
            intermediate_channels,
            intermediate_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(intermediate_channels)

        # 1x1 convolution: intermediate_channels -> 4 * intermediate_channels.
        self.conv3 = nn.Conv2d(
            intermediate_channels,
            expanded_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        self.bn3 = nn.BatchNorm2d(expanded_channels)

        # One ReLU instance is reused for all three activations.
        self.relu = nn.ReLU()

        self.stride = stride
        # NOTE: the attribute name is misspelled, but it is part of the
        # module's state_dict keys, so it is kept as-is.
        self.idendity_downsample = identity_downsample

    def forward(self, x):
        """Run the three conv/BN stages and add the (optionally projected) input."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # Skip path: the raw input, or its projection when a module was given.
        if self.idendity_downsample is None:
            shortcut = x
        else:
            shortcut = self.idendity_downsample(x)

        out += shortcut
        return self.relu(out)

In [3]:
class ResNet(nn.Module):
    """Four-stage residual network: stem, residual stages, pooling, classifier.

    The stem is a 7x7 stride-2 convolution with batch norm and ReLU followed
    by a 3x3 stride-2 max pool. Four stages built by :meth:`make_stage` follow
    with intermediate widths 64, 128, 256 and 512 (strides 1, 2, 2, 2). The
    result is average-pooled to 1x1, flattened and passed to a linear layer
    that expects ``512 * 4`` features.

    Args:
        num_stages_block: Sequence whose first four entries give the number of
            blocks in stages 1-4.
        block: Callable building one residual unit; it is invoked as
            ``block(in_channels, intermediate_channels, stride=...,
            identity_downsample=...)``.
        img_channels: Number of channels of the input images.
        num_classes: Size of the classifier output.
    """

    def __init__(self, num_stages_block, block, img_channels, num_classes):
        super().__init__()

        # Running channel count; make_stage reads it and advances it.
        self.in_channels = 64

        # Stem.
        self.conv1 = nn.Conv2d(
            img_channels,
            self.in_channels,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU()
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # (intermediate width, stride) for stage1 .. stage4, built in order.
        stage_plan = ((64, 1), (128, 2), (256, 2), (512, 2))
        for index, (width, stage_stride) in enumerate(stage_plan):
            stage = self.make_stage(block, num_stages_block[index], width, stride=stage_stride)
            setattr(self, f"stage{index + 1}", stage)

        # Classification head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * 4, num_classes)

    def forward(self, x):
        """Return classifier outputs of shape ``(x.shape[0], num_classes)``."""
        stem = self.max_pool(self.relu(self.bn1(self.conv1(x))))

        features = self.stage1(stem)
        features = self.stage2(features)
        features = self.stage3(features)
        features = self.stage4(features)

        pooled = self.avgpool(features)
        # Collapse everything except the batch dimension.
        flat = pooled.reshape(pooled.shape[0], -1)
        return self.fc(flat)

    def make_stage(self, block, num_blocks, intermediate_channels, stride):
        """Build one stage as an ``nn.Sequential`` of ``num_blocks`` units.

        Only the first unit receives ``stride``. It also receives a 1x1
        convolution + batch norm projection for its skip path whenever the
        stride is not 1 or the current channel count differs from
        ``intermediate_channels * 4``. After the first unit is created,
        ``self.in_channels`` is set to ``intermediate_channels * 4`` and the
        remaining units are built with the block's default stride and no
        projection.
        """
        out_channels = intermediate_channels * 4

        shape_preserved = stride == 1 and self.in_channels == out_channels
        projection = None
        if not shape_preserved:
            projection = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels),
            )

        units = [
            block(
                self.in_channels,
                intermediate_channels,
                stride=stride,
                identity_downsample=projection,
            )
        ]

        # Every later unit (and the next stage) sees the widened channel count.
        self.in_channels = out_channels

        units.extend(
            block(self.in_channels, intermediate_channels)
            for _ in range(num_blocks - 1)
        )

        return nn.Sequential(*units)
        

## ResNet Types

In [4]:
def ResNet50(img_channel=3, num_classes=1000):
    """Build a ``ResNet`` of ``block`` units with 3, 4, 6 and 3 units per stage.

    Args:
        img_channel: Number of channels of the input images.
        num_classes: Size of the classifier output.
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet(
        num_stages_block=stage_depths,
        block=block,
        img_channels=img_channel,
        num_classes=num_classes,
    )


def ResNet101(img_channel=3, num_classes=1000):
    """Build a ``ResNet`` of ``block`` units with 3, 4, 23 and 3 units per stage.

    Args:
        img_channel: Number of channels of the input images.
        num_classes: Size of the classifier output.
    """
    stage_depths = [3, 4, 23, 3]
    return ResNet(
        num_stages_block=stage_depths,
        block=block,
        img_channels=img_channel,
        num_classes=num_classes,
    )


def ResNet152(img_channel=3, num_classes=1000):
    """Build a ``ResNet`` of ``block`` units with 3, 8, 36 and 3 units per stage.

    Args:
        img_channel: Number of channels of the input images.
        num_classes: Size of the classifier output.
    """
    stage_depths = [3, 8, 36, 3]
    return ResNet(
        num_stages_block=stage_depths,
        block=block,
        img_channels=img_channel,
        num_classes=num_classes,
    )

## Test

In [5]:
def test():
    """Smoke-test ``ResNet101``: check the output shape for a random batch.

    Builds the network on CUDA when available (CPU otherwise), feeds it a
    random batch of four 3x224x224 images and asserts the output has shape
    ``(4, 1000)``; the size is printed on success.

    Raises:
        AssertionError: If the output shape is not ``(BATCH_SIZE, 1000)``.
    """
    BATCH_SIZE = 4
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = ResNet101(img_channel=3, num_classes=1000).to(device)

    # Allocate the batch directly on the target device instead of copying it.
    images = torch.randn(BATCH_SIZE, 3, 224, 224, device=device)

    # Only the output shape is checked, so skip autograd bookkeeping; this
    # avoids keeping every intermediate activation alive for a backward pass.
    with torch.no_grad():
        y = net(images)

    assert y.size() == torch.Size([BATCH_SIZE, 1000])
    print(y.size())

In [6]:
# Run the smoke test; it prints the output size when the shape check passes.
test()

torch.Size([4, 1000])
