# In [None]:
import torch
import torch.nn as nn
from torchsummary import summary

""" BasicBlock for ResNet18 """
class BasicBlock(nn.Module):
    """
    output = (channels, H, W) -> conv2d (3x3) -> (channels, H, W) -> conv2d (3x3) -> (channels, H, W) + (channels, H, W)
    """
    expansion = 1
    def __init__(self, filters, strides = 1, downsample = None):
        super(BasicBlock, self).__init__()
        
        in_channels, out_channels = filters
        
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, (3, 3), stride = strides, padding = (1, 1), bias = False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace = True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, (3, 3), padding = (1, 1), bias = False),
            nn.BatchNorm2d(out_channels)
        )
        self.relu = nn.ReLU(inplace = True)
        self.downsample = downsample
    
    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)
        
        out = self.relu(identity + out)
        return out

""" BottleneckBlock for ResNet50 & 152 """
class BottleneckBlock(nn.Module):
    """
    output = (channels * 4, H, W) -> conv2d (1x1) -> (channels, H, W) -> conv2d (3x3) -> (channels, H, W)
             -> conv2d (1x1) -> (channels * 4, H, W) + (channels * 4, H, W)
    """
    expansion = 4
    def __init__(self, filters, strides = 1, downsample = None):
        super(BottleneckBlock, self).__init__()
        
        in_channels, mid_channels = filters
        out_channels = mid_channels * self.expansion
        
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, (1, 1), stride = strides, padding = 0, bias = False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace = True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(mid_channels, mid_channels, (3, 3), stride = 1, padding = (1, 1), bias = False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace = True)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(mid_channels, out_channels, (1, 1), stride = 1, padding = 0, bias = False),
            nn.BatchNorm2d(out_channels)
        )
        self.relu = nn.ReLU(inplace = True)
        self.downsample = downsample

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)
        
        out = self.relu(identity + out)
        return out

""" ResNet Model """
class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes = 5):
        super(ResNet, self).__init__()
        
        self.in_channels = 64
        
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels = 3, out_channels = self.in_channels, kernel_size = (7, 7), stride = (2, 2), 
                      padding = (3, 3), bias = False),
            nn.BatchNorm2d(self.in_channels),
            nn.ReLU(inplace = True)
        )
        self.conv2 = nn.Sequential(
            nn.MaxPool2d(kernel_size = (3, 3), stride = (2, 2), padding = 1),
            self._make_layer(block, 64, layers[0])
        )
        self.conv3 = self._make_layer(block, 128, layers[1], strides = 2)
        self.conv4 = self._make_layer(block, 256, layers[2], strides = 2)
        self.conv5 = self._make_layer(block, 512, layers[3], strides = 2)
        self.avgpool = nn.AdaptiveAvgPool2d(output_size = (1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, strides = 1):
        downsample = None
        
        if strides != 1 or self.in_channels != out_channels * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels * block.expansion, (1, 1), stride = strides, bias = False),
                nn.BatchNorm2d(out_channels * block.expansion),
            )

        build_layer = []
        build_layer.append(block(filters = [self.in_channels, out_channels], strides = strides, downsample = downsample))
        
        self.in_channels = out_channels * block.expansion
        for _ in range(1, num_blocks):
            build_layer.append(block(filters = [self.in_channels, out_channels]))

        return nn.Sequential(*build_layer)
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.avgpool(x)
        
        x = x.view(x.shape[0], -1)
        out = self.fc(x)
        return out
    
""" Get(return) ResNet18 model """
def ResNet18():
    return ResNet(block = BasicBlock, layers = [2, 2, 2, 2])

""" Get(return) ResNet50 model """
def ResNet50():
    return ResNet(block = BottleneckBlock, layers = [3, 4, 6, 3])

""" Get(return) ResNet152 model """
def ResNet152():
    return ResNet(block = BottleneckBlock, layers = [3, 8, 36, 3])

if __name__ == '__main__':
    # Print a layer-by-layer summary of ResNet-50 for a 3x512x512 input.
    # To inspect the other variants, swap in ResNet18() or ResNet152()
    # (e.g. with a (3, 224, 224) input size).
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    model = ResNet50().to(device)
    summary(model, (3, 512, 512))