# ResNet

In [None]:
import torch
import torch.nn as nn

In [None]:
class IdentityBlock(nn.Module):
  """Bottleneck residual block whose shortcut is the unmodified input.

  Computes ``relu(bn3(conv3(relu(bn2(conv2(relu(bn1(conv1(x)))))))) + x)``.
  Both outer convolutions are 1x1 and the middle one pads by
  ``mid_kernel_size // 2``, so spatial size is preserved for odd kernel sizes.

  Args:
    in_channels: Channels of the input tensor. Should equal ``filters[2]``
      so the residual addition ``out += identity`` lines up.
    mid_kernel_size: Kernel size of the middle convolution.
    filters: ``(filter1, filter2, filter3)`` output channels of the three
      convolutions.
  """

  def __init__(self, in_channels, mid_kernel_size, filters):
    super(IdentityBlock, self).__init__()
    filter1, filter2, filter3 = filters

    # 1x1 reduction.
    self.conv1 = nn.Conv2d(in_channels, filter1, kernel_size=1, padding=0)
    self.bn1 = nn.BatchNorm2d(filter1)
    self.relu1 = nn.ReLU(inplace=True)

    # kxk spatial conv; padding k//2 keeps H and W unchanged for odd k.
    self.conv2 = nn.Conv2d(filter1, filter2, kernel_size=mid_kernel_size, padding=mid_kernel_size//2)
    self.bn2 = nn.BatchNorm2d(filter2)
    self.relu2 = nn.ReLU(inplace=True)

    # 1x1 expansion back to filter3 channels; no activation until after the residual add.
    self.conv3 = nn.Conv2d(filter2, filter3, kernel_size=1, padding=0)
    self.bn3 = nn.BatchNorm2d(filter3)
    self.relu_out = nn.ReLU(inplace=True)

  def forward(self, x):
    """Run the bottleneck path on ``x`` and add ``x`` back before the final ReLU."""
    identity = x

    out = self.conv1(x)
    out = self.bn1(out)
    out = self.relu1(out)

    out = self.conv2(out)
    out = self.bn2(out)
    out = self.relu2(out)

    # conv3 must consume the output of the previous stage, not the block input.
    out = self.conv3(out)
    out = self.bn3(out)

    out += identity

    out = self.relu_out(out)
    return out

In [None]:
import torch.nn.functional as F

class ConvolutionalBlock(nn.Module):
  """Bottleneck residual block with a projection (1x1 conv + BN) shortcut.

  Used where the channel count changes and, optionally, the spatial size
  is reduced. The main path and the shortcut are downsampled by the same
  stride exactly once, so their shapes agree at the residual addition.

  Args:
    in_channels: Channels of the input tensor.
    mid_kernel_size: Kernel size of the middle convolution (padded by
      ``mid_kernel_size // 2``).
    filters: ``(filter1, filter2, filter3)`` output channels of the three
      convolutions; ``filter3`` is also the shortcut's output channels.
    strides: Stride applied to the first 1x1 conv and to the shortcut
      conv. Defaults to 2 (halve H and W); pass 1 to keep spatial size.
  """

  def __init__(self, in_channels, mid_kernel_size, filters, strides=2):
    super(ConvolutionalBlock, self).__init__()
    filter1, filter2, filter3 = filters

    # Downsampling happens here, once, on the main path.
    self.conv1 = nn.Conv2d(in_channels, filter1, kernel_size=1, stride=strides)
    self.bn1 = nn.BatchNorm2d(filter1)

    self.conv2 = nn.Conv2d(filter1, filter2, kernel_size=mid_kernel_size, padding=mid_kernel_size//2)
    self.bn2 = nn.BatchNorm2d(filter2)

    # Stride 1: a second stride here would shrink the main path twice while
    # the shortcut shrinks once, making `out += shortcut` fail.
    self.conv3 = nn.Conv2d(filter2, filter3, kernel_size=1, stride=1)
    self.bn3 = nn.BatchNorm2d(filter3)

    # Projection shortcut matches both the channel count and the stride of the main path.
    self.shortcut_conv = nn.Conv2d(in_channels, filter3, kernel_size=1, stride=strides)
    self.shortcut_bn = nn.BatchNorm2d(filter3)

  def forward(self, x):
    """Return ``relu(main_path(x) + projection(x))``."""
    shortcut = self.shortcut_bn(self.shortcut_conv(x))

    out = F.relu(self.bn1(self.conv1(x)))
    out = F.relu(self.bn2(self.conv2(out)))
    out = self.bn3(self.conv3(out))

    out += shortcut
    out = F.relu(out)
    return out

In [None]:
class ResNet(nn.Module):
  """ResNet-50-layout image classifier built from ConvolutionalBlock/IdentityBlock.

  Layout: a 7x7 stride-2 stem conv and 3x3 stride-2 max-pool, then four
  bottleneck stages with 3, 4, 6 and 3 blocks (output channels 256, 512,
  1024, 2048), global average pooling, and a two-layer head
  (2048 -> 500 -> num_classes) with dropout before each linear layer.

  Args:
    num_classes: Number of output classes. Defaults to 1000.
    in_channels: Channels of the input image. Defaults to 3.

  ``forward`` returns a ``(batch, num_classes)`` tensor of softmax
  probabilities.

  NOTE(review): because ``forward`` applies softmax, its output should not
  be fed to ``nn.CrossEntropyLoss`` (which expects raw logits and applies
  log-softmax itself). Confirm which loss the training code uses.
  """

  def __init__(self, num_classes=1000, in_channels=3):
    super(ResNet, self).__init__()

    self.stage0 = nn.Sequential(
        nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3),  # (224,224,C) -> (112,112,64)
        nn.BatchNorm2d(64),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # -> (56,56,64)
    )

    # Stage 1: 3 blocks, -> 256 channels; strides=1 is intended to keep spatial size.
    self.stage1 = nn.Sequential(
        ConvolutionalBlock(64, 3, (64, 64, 256), strides=1),
        IdentityBlock(256, 3, (64, 64, 256)),
        IdentityBlock(256, 3, (64, 64, 256))
    )

    # Stage 2: 4 blocks, -> 512 channels.
    self.stage2 = nn.Sequential(
        ConvolutionalBlock(256, 3, (128, 128, 512), strides=2),
        IdentityBlock(512, 3, (128, 128, 512)),
        IdentityBlock(512, 3, (128, 128, 512)),
        IdentityBlock(512, 3, (128, 128, 512))
    )

    # Stage 3: 6 blocks, -> 1024 channels.
    self.stage3 = nn.Sequential(
        ConvolutionalBlock(512, 3, (256, 256, 1024), strides=2),
        IdentityBlock(1024, 3, (256, 256, 1024)),
        IdentityBlock(1024, 3, (256, 256, 1024)),
        IdentityBlock(1024, 3, (256, 256, 1024)),
        IdentityBlock(1024, 3, (256, 256, 1024)),
        IdentityBlock(1024, 3, (256, 256, 1024))
    )

    # Stage 4: 3 blocks, -> 2048 channels.
    self.stage4 = nn.Sequential(
        ConvolutionalBlock(1024, 3, (512, 512, 2048), strides=2),
        IdentityBlock(2048, 3, (512, 512, 2048)),
        IdentityBlock(2048, 3, (512, 512, 2048))
    )

    # Global average pool makes the head independent of the input resolution.
    self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
    self.dropout1 = nn.Dropout(0.5)
    self.fc1 = nn.Linear(2048, 500)
    self.dropout2 = nn.Dropout(0.5)
    self.fc2 = nn.Linear(500, num_classes)

  def forward(self, x):
    """Map an image batch to per-class softmax probabilities of shape (batch, num_classes)."""
    x = self.stage0(x)
    x = self.stage1(x)
    x = self.stage2(x)
    x = self.stage3(x)
    x = self.stage4(x)
    x = self.avgpool(x)
    x = torch.flatten(x, 1)  # (batch, 2048, 1, 1) -> (batch, 2048)
    x = self.dropout1(x)
    x = F.relu(self.fc1(x))
    x = self.dropout2(x)
    x = self.fc2(x)
    x = F.softmax(x, dim=1)
    return x

### 사전학습 모델 로드

In [None]:
import torchvision.models as models

In [None]:
# Load ImageNet-pretrained reference models with torchvision's `weights=` API.
# `pretrained=True` is deprecated since torchvision 0.13; the IMAGENET1K_V1
# enums below select the same checkpoints that `pretrained=True` loaded.
# (Use `*_Weights.DEFAULT` instead to get torchvision's current best weights.)
vgg_model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)

resnet_model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

# aux_logits=True keeps the auxiliary classifier head; in training mode the
# model then returns an InceptionOutputs(logits, aux_logits) pair.
inception_model = models.inception_v3(weights=models.Inception_V3_Weights.IMAGENET1K_V1, aux_logits=True)

mobilenet_model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.DEFAULT)