In [34]:
import torch
from torch import nn
from torch.nn import functional as F


class Residual(nn.Module):
    """Two-convolution residual block.

    Main branch: 3x3 conv -> BatchNorm -> ReLU -> 3x3 conv -> BatchNorm.
    The shortcut (the input itself, or a 1x1 convolution of it when
    ``use_conv1`` is true) is added to the main branch, and a final ReLU is
    applied to the sum.

    Args:
        num_channels: Number of output channels of every convolution in the
            block. Input channels are inferred lazily on the first forward.
        use_conv1: If true, the shortcut goes through a 1x1 convolution with
            stride ``strides`` so that its shape matches the main branch.
            If false, the input is added unchanged, so it must already have
            ``num_channels`` channels and ``strides`` must be 1; otherwise
            the addition fails with a shape mismatch.
        strides: Stride of the first 3x3 convolution (and of the 1x1 shortcut
            convolution when present).
    """

    def __init__(self, num_channels, use_conv1=True, strides=1):
        super().__init__()
        self.use_conv1 = use_conv1
        self.conv1 = nn.LazyConv2d(num_channels, kernel_size=3, padding=1, stride=strides)
        self.bn1 = nn.LazyBatchNorm2d()
        self.relu1 = nn.ReLU()
        self.conv2 = nn.LazyConv2d(num_channels, kernel_size=3, padding=1)
        self.bn2 = nn.LazyBatchNorm2d()
        # Applied to the sum of main branch and shortcut (see forward).
        self.relu2 = nn.ReLU()
        if self.use_conv1:
            self.conv3 = nn.LazyConv2d(num_channels, kernel_size=1, stride=strides)
        else:
            self.conv3 = None

    def forward(self, X):
        """Return ``relu(main_branch(X) + shortcut(X))``."""
        Y = self.relu1(self.bn1(self.conv1(X)))
        # No activation here: the non-linearity is applied after the sum.
        Y = self.bn2(self.conv2(Y))
        shortcut = X if self.conv3 is None else self.conv3(X)
        # Out-of-place add. The previous code applied ReLU first and then
        # did `Y += shortcut` in place, which overwrote the ReLU output that
        # autograd needs and made `.backward()` raise a RuntimeError.
        Y = Y + shortcut
        return self.relu2(Y)


class ResNet(nn.Module):
    """Residual network: stem -> stack of ``Residual`` blocks -> classifier head.

    Args:
        num_residuals: Channel plan for the residual stack, one entry per
            ``Residual`` block, e.g. ``(64, 64, 128, 128, 256, 256, 512, 512)``.
            Any value that is not a list or tuple falls back to that default
            plan (this argument used to be ignored entirely).
        num_classes: Size of the final linear layer's output.
    """

    # Channel plan used when the caller does not supply a list/tuple.
    _DEFAULT_CHANNELS = (64, 64, 128, 128, 256, 256, 512, 512)
    # Number of channels produced by the stem convolution.
    _STEM_CHANNELS = 64

    def __init__(self, num_residuals, num_classes):
        super().__init__()
        # The attribute name (including its spelling) is kept as-is because
        # it is part of the module's attribute path and state_dict keys.
        self.fist_blk = nn.Sequential(
            nn.LazyConv2d(self._STEM_CHANNELS, kernel_size=7, stride=2, padding=3),
            nn.LazyBatchNorm2d(),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        if isinstance(num_residuals, (list, tuple)):
            out_channels = tuple(num_residuals)
        else:
            out_channels = self._DEFAULT_CHANNELS
        self.residual_blk = self.residual_creation(
            out_channels=out_channels, in_channels=self._STEM_CHANNELS
        )
        self.output_blk = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.LazyLinear(num_classes)
        )

    def forward(self, X):
        """Run the stem, the residual stack and the classifier head in order."""
        X = self.fist_blk(X)
        X = self.residual_blk(X)
        X = self.output_blk(X)
        return X

    @staticmethod
    def residual_creation(out_channels, in_channels=None):
        """Build the residual stack as an ``nn.Sequential``.

        One ``Residual`` block is created per entry of ``out_channels``.
        Blocks at even positions other than 0 use stride 2 together with a
        1x1 shortcut convolution. Any other block whose channel count differs
        from the previous block's also gets a 1x1 shortcut convolution
        (stride 1), because an identity shortcut could not be added to a
        tensor with a different number of channels. All remaining blocks use
        the identity shortcut.

        Args:
            out_channels: Iterable of output channel counts, one per block.
            in_channels: Channel count of the tensor fed to the first block,
                or ``None`` if unknown (the first block then uses the
                identity shortcut).

        Returns:
            ``nn.Sequential`` containing the blocks in order (empty when
            ``out_channels`` is empty).
        """
        blk = []
        prev_channels = in_channels
        for i, num_channel in enumerate(out_channels):
            downsample = i != 0 and i % 2 == 0
            channels_changed = prev_channels is not None and num_channel != prev_channels
            if downsample:
                blk.append(Residual(num_channel, use_conv1=True, strides=2))
            elif channels_changed:
                blk.append(Residual(num_channel, use_conv1=True))
            else:
                blk.append(Residual(num_channel, use_conv1=False))
            prev_channels = num_channel
        return nn.Sequential(*blk)


class RestNet18(ResNet):
    """``ResNet`` preset with eight residual blocks of widths 64, 64, 128, 128, 256, 256, 512, 512.

    Args:
        num_classes: Number of outputs of the final linear layer
            (defaults to 10, the value that was previously hard-coded).
    """

    def __init__(self, num_classes=10):
        super().__init__((64, 64, 128, 128, 256, 256, 512, 512), num_classes)

In [35]:
# Build the network. It is made of Lazy* layers, so input-dependent parameter
# shapes are only inferred on the first forward pass.
model = RestNet18()

In [36]:
# Smoke test: one random single-channel 96x96 input. The first call also
# materializes the lazy layers; the result has one row of 10 values.
X = torch.rand(1, 1, 96, 96)
model(X)

tensor([[-0.7926,  0.7759, -1.5305, -0.0628, -0.1914,  0.9099, -0.6234,  0.3404,
         -0.6080, -0.1081]], grad_fn=<AddmmBackward0>)

In [37]:
# Show the module tree; after the forward pass above the lazy layers are
# displayed with their inferred input sizes.
model

RestNet18(
  (fist_blk): Sequential(
    (0): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
  (residual_blk): Sequential(
    (0): Residual(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu1): ReLU()
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu2): ReLU()
    )
    (1): Residual(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu1): ReLU()
      (conv2): Conv2d(64, 64, kern