<a href="https://colab.research.google.com/github/HyeonGeunY/papers/blob/main/resnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports

In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.optim as optim
import os
from typing import List

In [6]:
class ShortcutBlock(nn.Module):
    """Projection shortcut: 1x1 convolution followed by batch normalization.

    Intended for residual blocks whose branch changes the channel count or the
    spatial resolution, so the skip path is projected to the same shape before
    the element-wise addition.

    Args:
        in_channels: channels of the block input.
        out_channels: channels the skip path must be projected to.
        stride: spatial stride; should match the stride of the residual branch.
    """

    def __init__(self, in_channels: int, out_channels: int, stride: int):
        super().__init__()
        # kernel_size is a required argument of nn.Conv2d. A 1x1 kernel with no
        # padding yields floor((H - 1) / stride) + 1 per spatial dim, the same as
        # a 3x3/padding=1 or 1x1/padding=0 conv with the same stride.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` to ``out_channels`` channels at the given stride."""
        return self.bn(self.conv(x))

In [8]:
class ResidualBlock(nn.Module):
    """Basic residual block: 3x3 conv -> BN -> ReLU -> 3x3 conv -> BN, plus skip, then ReLU.

    The first convolution applies ``stride``. When ``stride != 1`` or the
    channel count changes, the skip path is a ``ShortcutBlock`` projection;
    otherwise it is the identity.

    Args:
        in_channels: channels of the block input.
        out_channels: channels of the block output.
        stride: spatial stride of the first convolution (and of the projection).
    """

    def __init__(self, in_channels: int, out_channels: int, stride: int = 1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.act1 = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # The skip path only needs a projection when the output shape differs from the input.
        self.shortcut = nn.Identity()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = ShortcutBlock(in_channels=in_channels, out_channels=out_channels, stride=stride)

        self.act2 = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``ReLU(residual_branch(x) + shortcut(x))``."""
        shortcut = self.shortcut(x)
        out = self.act1(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.act2(out + shortcut)


In [10]:
class BottleNeckBlock(nn.Module):
    """Bottleneck residual block: 1x1 reduce -> 3x3 -> 1x1 expand, plus skip, then ReLU.

    Args:
        in_channels: channels of the block input.
        bottle_channels: reduced width used inside the block (by the 3x3 conv).
        out_channels: channels of the block output.
        stride: spatial stride, applied by the first 1x1 convolution and by the
            projection shortcut.
    """

    def __init__(self, in_channels: int, bottle_channels: int, out_channels: int, stride: int = 1):
        super().__init__()
        # A 1x1 conv must not be padded: padding=1 would grow each spatial dim
        # by 2 and make the residual addition fail on a shape mismatch.
        self.conv1 = nn.Conv2d(in_channels, bottle_channels, kernel_size=1, stride=stride)
        self.bn1 = nn.BatchNorm2d(bottle_channels)
        self.act1 = nn.ReLU()

        self.conv2 = nn.Conv2d(bottle_channels, bottle_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(bottle_channels)
        self.act2 = nn.ReLU()

        self.conv3 = nn.Conv2d(bottle_channels, out_channels, kernel_size=1, stride=1)
        # conv3 emits out_channels channels, so its BatchNorm must match that width.
        self.bn3 = nn.BatchNorm2d(out_channels)
        self.act3 = nn.ReLU()

        # Project the skip path only when the output shape differs from the input.
        self.shortcut = nn.Identity()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = ShortcutBlock(in_channels, out_channels, stride=stride)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``ReLU(bottleneck_branch(x) + shortcut(x))``."""
        shortcut = self.shortcut(x)
        x = self.act1(self.bn1(self.conv1(x)))
        x = self.act2(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))
        return self.act3(x + shortcut)

In [14]:
class ResNet50(nn.Module):
    """Bottleneck ResNet with configurable stage depths and strides.

    Stride sizes are configurable so the network can be fitted to small
    inputs such as MNIST (28 x 28).

    Args:
        num_blocks: number of bottleneck blocks in each of the four stages
            (ResNet-50 uses ``[3, 4, 6, 3]``).
        strides: five strides: ``strides[0]`` for the stem convolution, and
            ``strides[1:5]`` for the first block of stages 1-4.
        first_kernel_size: kernel size of the stem convolution.
        input_channels: channels of the input image.
        n_classes: number of output logits.

    Raises:
        ValueError: if ``num_blocks`` has fewer than 4 entries or ``strides``
            has fewer than 5.
    """

    def __init__(self, num_blocks: List[int], strides: List[int], first_kernel_size: int = 7, input_channels: int = 3, n_classes=10):
        super().__init__()
        if len(num_blocks) < 4:
            raise ValueError(f"num_blocks needs 4 entries (one per stage), got {len(num_blocks)}")
        if len(strides) < 5:
            raise ValueError(f"strides needs 5 entries (stem + 4 stages), got {len(strides)}")

        self.firstconv = nn.Conv2d(input_channels, 64, kernel_size=first_kernel_size, stride=strides[0], padding=first_kernel_size // 2)
        self.bn1 = nn.BatchNorm2d(64)
        self.act1 = nn.ReLU()
        # NOTE(review): MaxPool2d's stride defaults to kernel_size, so this pools
        # with stride 3 and no padding; confirm this downsampling is intended.
        self.maxpool1 = nn.MaxPool2d(kernel_size=3)

        self.layer1 = self._repeat_layers(64, 64, 256, stride=strides[1], n_count=num_blocks[0])
        self.layer2 = self._repeat_layers(256, 128, 512, stride=strides[2], n_count=num_blocks[1])
        self.layer3 = self._repeat_layers(512, 256, 1024, stride=strides[3], n_count=num_blocks[2])
        self.layer4 = self._repeat_layers(1024, 512, 2048, stride=strides[4], n_count=num_blocks[3])

        # Adaptive pooling always reduces to 1x1 regardless of the final feature-map
        # size. A fixed AvgPool2d(kernel_size=4) fails when the last map is smaller
        # than 4x4 (e.g. 2x2 for 28x28 or 32x32 inputs with strides [1, 1, 2, 2, 2]).
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(2048, n_classes)

    def _repeat_layers(self, in_channels, bottle_channels, out_channels, stride, n_count):
        """Build one stage of bottleneck blocks.

        Only the first block uses ``stride`` and maps ``in_channels`` to
        ``out_channels``; the following blocks use stride 1 and consume the
        previous block's ``out_channels``. At least one block is always built.
        """
        layers = []
        strides = [stride] + [1] * (n_count - 1)
        for s in strides:
            layers.append(BottleNeckBlock(in_channels, bottle_channels, out_channels, s))
            # Every later block receives the previous block's output width.
            in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape ``(N, n_classes)`` for an input batch ``x``."""
        x = self.act1(self.bn1(self.firstconv(x)))
        x = self.maxpool1(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avg_pool(x)
        x = torch.flatten(x, 1)  # flatten to 2-D: (N, 2048, 1, 1) -> (N, 2048)
        return self.fc(x)

def ResNet50_for_mnist_cifar(input_channels: int = 3, n_classes: int = 10):
    """Build a ResNet-50 layout ([3, 4, 6, 3] bottleneck blocks) for small images.

    The stem and first stage use stride 1; stages 2-4 each downsample by 2.

    Args:
        input_channels: image channels. Defaults to 3 as before; pass 1 for
            single-channel (grayscale) inputs such as MNIST.
        n_classes: number of output classes.
    """
    return ResNet50(
        num_blocks=[3, 4, 6, 3],
        strides=[1, 1, 2, 2, 2],
        input_channels=input_channels,
        n_classes=n_classes,
    )