In [5]:
from torch import nn

import torch



class ResidualBlock(nn.Module):
    """Basic residual block: 3x3 conv -> BN -> ReLU -> 3x3 conv -> BN, plus a skip path, then ReLU.

    The skip path is the identity when the block keeps both the channel count and
    the stride at 1; otherwise it is a 1x1 convolution + batch norm that projects
    the input to the main path's shape so the two can be added.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by the block.
        stride: Stride of the first 3x3 convolution and of the 1x1 projection.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        step = (stride, stride)

        # Main path. Keep the construction order conv1 -> conv2 -> projection conv:
        # it determines how the global RNG is consumed by weight initialisation.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=step, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Skip path: project only when channels or resolution change.
        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=step, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()  # empty Sequential returns its input unchanged

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        hidden = self.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        residual = self.shortcut(x)
        # The sum is a fresh tensor, so the in-place ReLU never touches ``x``.
        return self.relu(hidden + residual)

class SimpleCNN(nn.Module):
    """A stack of ``ResidualBlock`` layers applied one after another.

    Args:
        input_channels: Channel count of the input feature map.
        hidden_channels: Output channel count of each residual block, in order.
            Repeated widths (e.g. ``[64, 64, 128]``) are allowed; every entry
            gets its own, independently-parameterised block.
        output_channels: Stored on the instance; no layer uses it yet.

    Every block is built with the default stride of 1, so the output keeps the
    input's spatial size. Its channel count is ``hidden_channels[-1]``; when
    ``hidden_channels`` is empty, the input is returned unchanged.
    """

    def __init__(self, input_channels, hidden_channels, output_channels):
        super(SimpleCNN, self).__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.output_channels = output_channels

        # Registered block names in application order. forward() walks this list
        # instead of re-deriving names from ``hidden_channels``. That keeps each
        # repeated width as its own block, and a one-shot iterable passed as
        # ``hidden_channels`` is consumed only once.
        self._block_names = []
        last_channel = input_channels
        for hidden_channel in hidden_channels:
            base_name = "resblock_{}".format(hidden_channel)
            # The first block of a width keeps the historical name, so existing
            # state_dict keys still match. Later repeats get a numeric suffix
            # instead of silently replacing the earlier block.
            name, repeat = base_name, 1
            while name in self._block_names:
                name = "{}_{}".format(base_name, repeat)
                repeat += 1
            setattr(self, name, ResidualBlock(last_channel, hidden_channel))
            self._block_names.append(name)
            last_channel = hidden_channel

    def forward(self, x):
        """Pass ``x`` through every residual block in construction order."""
        for name in self._block_names:
            x = getattr(self, name)(x)

        # Additional layers or operations (presumably a head that maps to
        # ``output_channels``) can be added here if needed.
        return x


# Example usage: build a three-block network for 13-channel inputs and report
# its total parameter count.
input_channels = 13
hidden_channels = [64, 128, 256]
output_channels = 1

model = SimpleCNN(input_channels, hidden_channels, output_channels)
# One example sample with a 32x32 spatial size (this cell does not feed it to the model).
input_data = torch.randn(1, input_channels, 32, 32)

total_params = sum(param.numel() for param in model.parameters())
total_params  # last expression: the notebook displays it as the cell output

1194752