In [1]:
import torch
import torch.nn as nn
import torchmetrics


In [2]:
from functools import partial
import torch.nn.functional as F

class ResidualUnit(nn.Module):
    """Basic residual block: two 3x3 conv + batch-norm layers plus a shortcut.

    ``forward(x)`` returns ``relu(main_layers(x) + skip_connection(x))``, where
    ``main_layers`` is conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN.

    The shortcut is the identity when the main branch preserves the input
    shape. Otherwise it is a 1x1 conv (same stride) followed by BN, so both
    branches produce tensors that can be added element-wise.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the block.
        stride: Stride of the first 3x3 conv and of the projection shortcut.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # 3x3 "same" convolution. The conv bias is omitted because every conv
        # is immediately followed by BatchNorm, which has its own shift term.
        DefaultConv2d = partial(
            nn.Conv2d, stride=1, padding=1, kernel_size=3, bias=False
        )
        self.main_layers = nn.Sequential(
            DefaultConv2d(in_channels, out_channels, stride=stride),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            DefaultConv2d(out_channels, out_channels),
            nn.BatchNorm2d(out_channels),
        )
        # A projection is required whenever the main branch changes the
        # spatial size (stride != 1) OR the channel count. Checking only the
        # stride would make e.g. ResidualUnit(64, 128, stride=1) fail in
        # forward() with a channel mismatch.
        if stride != 1 or in_channels != out_channels:
            self.skip_connection = nn.Sequential(
                DefaultConv2d(in_channels, out_channels, kernel_size=1,
                              stride=stride, padding=0),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.skip_connection = nn.Identity()

    def forward(self, inputs):
        """Apply the block to a 4-D (N, in_channels, H, W) batch.

        The output has ``out_channels`` channels, and each spatial dimension
        is downsampled by ``stride``.
        """
        return F.relu(self.main_layers(inputs) + self.skip_connection(inputs))

In [3]:
class ResNet34(nn.Module):
    """ResNet-34 image classifier built from ``ResidualUnit`` blocks.

    Stem: 7x7/2 conv -> BN -> ReLU -> 3x3/2 max-pool (64 channels). It is
    followed by 16 residual units in four stages of 3, 4, 6 and 3 units with
    64, 128, 256 and 512 filters. The first unit of each stage after the first
    uses stride 2. A global average pool, a flatten and a linear layer then
    produce the class logits.

    Args:
        in_channels: Number of channels in the input images (default 3).
        num_classes: Number of output logits (default 10).
    """

    def __init__(self, in_channels=3, num_classes=10):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=7,
                      stride=2, padding=3, bias=False),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        ]

        prev_filters = 64
        for filters in [64] * 3 + [128] * 4 + [256] * 6 + [512] * 3:
            # Downsample (stride 2) only on the first unit of a new stage,
            # i.e. when the filter count changes.
            stride = 1 if filters == prev_filters else 2
            layers.append(ResidualUnit(prev_filters, filters, stride))
            prev_filters = filters
        layers += [
            nn.AdaptiveAvgPool2d(output_size=1),  # (N, C, H, W) -> (N, C, 1, 1)
            nn.Flatten(),                         # -> (N, C)
            # C is known here (prev_filters, i.e. 512), so a regular Linear is
            # used instead of LazyLinear. Its parameters exist right after
            # construction, so an optimizer can be created without a dry run.
            nn.Linear(prev_filters, num_classes),
        ]
        self.resnet = nn.Sequential(*layers)

    def forward(self, inputs):
        """Return class logits of shape (N, num_classes) for an image batch."""
        return self.resnet(inputs)