# In [1]:
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn

# In [None]:
class ConvBlock(nn.Module):
    """Residual block computing ``BN(conv(PReLU(BN(conv(x))))) + shortcut(x)``.

    The first convolution maps ``n_in`` to ``n_out`` channels and applies
    ``stride``; the second keeps ``n_out`` channels at stride 1, so the block
    downsamples at most once. When the block changes the channel count or
    uses a non-unit stride, the shortcut is a strided 1x1 convolution followed
    by batch norm so both branches have the same shape; otherwise it is the
    identity.

    Args:
        n_in: Number of input channels.
        n_out: Number of output channels.
        kernel_size: Convolution kernel size (int or ``(h, w)``).
        stride: Stride of the first convolution and of the projection shortcut.
        padding: Convolution padding. ``None`` (default) selects
            ``dilation * (kernel_size - 1) // 2`` per dimension, which keeps the
            spatial size for odd kernels so the residual sum lines up.
        padding_mode: Padding mode forwarded to ``nn.Conv2d``.
        dilation: Convolution dilation (int or ``(h, w)``).
        groups: Number of convolution groups; must divide ``n_in`` and ``n_out``.
        bias: Whether the convolutions carry a bias term.
    """

    def __init__(
        self,
        n_in: int,
        n_out: int,
        kernel_size: Union[int, Tuple[int, int]] = 3,
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Optional[Union[int, Tuple[int, int]]] = None,
        padding_mode: str = "reflect",
        dilation: Union[int, Tuple[int, int]] = 1,
        groups: int = 1,
        bias: bool = False,
    ) -> None:
        super().__init__()

        if padding is None:
            # Without this, a 3x3 conv shrinks H and W and the residual add fails.
            padding = self._same_padding(kernel_size, dilation)

        self.n_in = n_in
        self.n_out = n_out
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.padding_mode = padding_mode
        self.dilation = dilation
        self.groups = groups
        self.bias = bias

        self.conv = self._convolution_block()
        # Built but not used in forward(); kept so code reading `.pool` still works.
        self.pool = self._pool()
        self.shortcut = self._shortcut()

    @staticmethod
    def _same_padding(
        kernel_size: Union[int, Tuple[int, int]],
        dilation: Union[int, Tuple[int, int]],
    ) -> Union[int, Tuple[int, int]]:
        """Return per-dimension padding that preserves size at stride 1 (odd kernels)."""
        ks = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
        dil = (dilation, dilation) if isinstance(dilation, int) else tuple(dilation)
        pad_h, pad_w = (d * (k - 1) // 2 for k, d in zip(ks, dil))
        return pad_h if pad_h == pad_w else (pad_h, pad_w)

    def _convolution_block(self) -> nn.Sequential:
        """Build the main path: conv -> BN -> PReLU -> conv -> BN."""
        return nn.Sequential(
            nn.Conv2d(
                in_channels=self.n_in,
                out_channels=self.n_out,
                kernel_size=self.kernel_size,
                stride=self.stride,
                padding=self.padding,
                padding_mode=self.padding_mode,
                dilation=self.dilation,
                groups=self.groups,
                bias=self.bias,
            ),
            nn.BatchNorm2d(self.n_out),
            nn.PReLU(),
            # Consumes the first conv's n_out channels; stride 1 so the block
            # downsamples only once.
            nn.Conv2d(
                in_channels=self.n_out,
                out_channels=self.n_out,
                kernel_size=self.kernel_size,
                stride=1,
                padding=self.padding,
                padding_mode=self.padding_mode,
                dilation=self.dilation,
                groups=self.groups,
                bias=self.bias,
            ),
            nn.BatchNorm2d(self.n_out),
        )

    def _pool(self) -> nn.MaxPool2d:
        """Build a max-pool layer sharing the block's kernel/stride/padding/dilation."""
        return nn.MaxPool2d(
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
        )

    def _shortcut(self) -> nn.Module:
        """Build the skip branch so its output shape matches the main path."""
        if self.n_in == self.n_out and self.stride in (1, (1, 1)):
            return nn.Identity()
        # Strided 1x1 projection to n_out channels, as in standard ResNet blocks.
        return nn.Sequential(
            nn.Conv2d(
                in_channels=self.n_in,
                out_channels=self.n_out,
                kernel_size=1,
                stride=self.stride,
                bias=self.bias,
            ),
            nn.BatchNorm2d(self.n_out),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the shortcut branch plus the convolutional branch."""
        return self.shortcut(x) + self.conv(x)


class ResNet(nn.Module):
    """Convolutional network built from a stem, residual ``ConvBlock`` stages and a head.

    Layout (as constructed in ``__init__``):

    * ``pre_block``: 3x3 conv (``in_channels`` -> 64, padding 1) + PReLU.
    * ``blocks``: four stages of two ``ConvBlock`` each, with widths 64, 128,
      256 and 512; the first block of the last three stages uses ``stride=2``.
    * ``final_block``: 3x3 conv (512 -> 64, padding 1) + BatchNorm.
    """

    def __init__(self, in_channels: int = 1) -> None:
        """Create the layers.

        Args:
            in_channels: Channels of the input tensor fed to ``pre_block``.
                Defaults to 1, the value previously hard-coded.
        """
        super().__init__()

        self.pre_block = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=64,
                kernel_size=3,
                stride=1,
                padding=1,
            ),
            nn.PReLU(),
        )

        # Each widening block must accept the previous stage's channel count;
        # the old list fed 64 channels into a block declared with 128 inputs
        # (and likewise for 256 and 512).
        self.blocks = nn.Sequential(
            ConvBlock(64, 64),
            ConvBlock(64, 64),
            ConvBlock(64, 128, stride=2),
            ConvBlock(128, 128),
            ConvBlock(128, 256, stride=2),
            ConvBlock(256, 256),
            ConvBlock(256, 512, stride=2),
            ConvBlock(512, 512),
        )

        self.final_block = nn.Sequential(
            nn.Conv2d(
                in_channels=512,
                out_channels=64,
                kernel_size=3,
                stride=1,
                padding=1,
            ),
            nn.BatchNorm2d(64),
        )