In [335]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

In [336]:
class ResidualBlock(nn.Module):
    """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions plus a skip connection.

    Maps an input with ``in_channels`` channels to an output with
    ``out_channels * expansion`` channels. Any spatial downsampling happens in the
    3x3 convolution through ``stride``.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: bottleneck width; the block outputs ``out_channels * expansion`` channels.
        stride: stride of the 3x3 convolution.
        identity_ds: optional module applied to the skip path so that its shape matches
            the main path. It is needed whenever ``stride != 1`` or
            ``in_channels != out_channels * expansion``. ``None`` means a plain identity skip.
    """

    # Channel multiplier applied by the final 1x1 convolution.
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1, identity_ds=None):
        super(ResidualBlock, self).__init__()
        self.convBlock = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, stride=1),
            nn.BatchNorm2d(out_channels * self.expansion),
        )
        self.relu = nn.ReLU()
        self.identity_ds = identity_ds
        self.stride = stride

    def forward(self, x):
        """Return ``relu(convBlock(x) + skip(x))``.

        The skip path is ``identity_ds(x)`` when a downsample module was given,
        otherwise ``x`` itself.
        """
        out = self.convBlock(x)
        # convBlock never writes to `x` (its in-place ReLUs act on BatchNorm outputs),
        # so the skip path can use `x` directly; no defensive clone is needed.
        identity = x if self.identity_ds is None else self.identity_ds(x)
        # Out-of-place add keeps the BatchNorm output untouched for autograd.
        return self.relu(out + identity)

In [337]:
class ResNet(nn.Module):
    """ResNet classifier built from four stages of bottleneck ``ResidualBlock`` units.

    Args:
        layers: number of residual blocks in each of the four stages, e.g. ``[3, 4, 6, 3]``.
            Only the first four entries are used.
        input_channels: number of channels of the input images.
        num_classes: length of the output logits vector.
    """

    def __init__(self, layers, input_channels, num_classes):
        super(ResNet, self).__init__()
        # Must match ResidualBlock's channel multiplier (its last 1x1 conv outputs out_channels * 4).
        self.expansion = 4
        self.input_channels = input_channels
        # Running channel count of the stage input; residual_layer() advances it after each stage.
        self.in_channels = 64
        self.layer0 = nn.Sequential(
            nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        self.layer1 = self.residual_layer(64, layers[0], 1)
        self.layer2 = self.residual_layer(128, layers[1], 2)
        self.layer3 = self.residual_layer(256, layers[2], 2)
        self.layer4 = self.residual_layer(512, layers[3], 2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # Derived from the last stage's width instead of a hard-coded 2048.
        self.linear = nn.Linear(512 * self.expansion, num_classes)

    def forward(self, x):
        """Run the stem, the four residual stages, global average pooling and the linear head."""
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.flatten(start_dim=1)
        x = self.linear(x)

        return x

    def residual_layer(self, out_channels, num_blocks, stride):
        """Build one stage of residual blocks.

        The first block may change resolution and width. It uses ``stride`` and, when
        the shapes differ, a 1x1 conv + BatchNorm projection on its skip path. The
        remaining ``num_blocks - 1`` blocks keep the shape unchanged.

        Side effect: sets ``self.in_channels`` to ``out_channels * self.expansion``.

        Returns:
            nn.Sequential containing the stage's blocks.
        """
        out_width = out_channels * self.expansion
        downsample = None
        if self.in_channels != out_width or stride != 1:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_width, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_width),
            )
        blocks = [ResidualBlock(self.in_channels, out_channels, stride, downsample)]
        self.in_channels = out_width
        # Build a fresh module per block: `[block] * n` would repeat one instance,
        # silently tying all of those blocks to a single set of weights.
        blocks.extend(
            ResidualBlock(self.in_channels, out_channels) for _ in range(num_blocks - 1)
        )
        return nn.Sequential(*blocks)

In [338]:
# Smoke test: ResNet with stage sizes [3, 4, 6, 3] on one random 3x224x224 input, 10 classes.
# DEVICE is defined here so the cell survives Restart & Run All (it was previously undefined).
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResNet([3, 4, 6, 3], 3, 10).to(DEVICE)
x = torch.randn(1, 3, 224, 224, device=DEVICE)
# Shape check only, so skip building the autograd graph.
with torch.no_grad():
    output = model(x)
assert output.shape == (1, 10), output.shape
print(output.shape)

torch.Size([1, 10])
