In [2]:
from numpy import floor
from torch import nn

class ConvolutionalModel(nn.Module):
    """Small CNN: two (5x5 conv -> ReLU -> max-pool) stages, two ReLU FC layers, a linear classifier.

    Inputs are assumed square: ``in_width`` is used for both spatial dimensions when
    sizing ``fc3``, so ``forward`` expects tensors of shape
    (N, in_channels, in_width, in_width).

    Args:
        in_channels: number of channels of the input images.
        in_width: spatial size (height == width) of the input images.
        conv1_channels: output channels of the first 5x5 convolution.
        pool1_width: kernel size of the first max-pool (stride 2, no padding).
        conv2_channels: output channels of the second 5x5 convolution.
        pool2_width: kernel size of the second max-pool (stride 2, no padding).
        fc3_width: width of the first fully connected hidden layer.
        fc4_width: width of the second fully connected hidden layer.
        class_count: number of output logits.

    Raises:
        ValueError: if a pooling stage would produce an empty (width < 1) feature map.
    """

    def __init__(self, in_channels, in_width, conv1_channels, pool1_width, conv2_channels, pool2_width, fc3_width, fc4_width, class_count):
        super().__init__()
        # Single source of truth for the pooling stride: used both to build the
        # pooling layers and to compute the resulting feature-map widths.
        pool_stride = 2

        # Shape comments are per sample: channels x width x width.
                                                                                                                # in_channels x in_width x in_width
        # 5x5 kernel with padding 2 and stride 1 keeps the spatial width unchanged.
        self.conv1 = nn.Conv2d(in_channels, conv1_channels, kernel_size=5, stride=1, padding=2, bias=True)      # conv1_channels x in_width x in_width
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(pool1_width, stride=pool_stride)                                              # conv1_channels x w2 x w2

        w2 = self._pool_out_width(in_width, pool1_width, pool_stride, stage="pool1")

        self.conv2 = nn.Conv2d(conv1_channels, conv2_channels, kernel_size=5, stride=1, padding=2, bias=True)   # conv2_channels x w2 x w2
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(pool2_width, stride=pool_stride)                                              # conv2_channels x w3 x w3

        w3 = self._pool_out_width(w2, pool2_width, pool_stride, stage="pool2")

        self.flatten3 = nn.Flatten()                                                                            # (conv2_channels * w3 * w3)
        self.fc3 = nn.Linear(conv2_channels * w3 * w3, fc3_width)                                               # fc3_width
        self.relu3 = nn.ReLU()

        self.fc4 = nn.Linear(fc3_width, fc4_width)                                                              # fc4_width
        self.relu4 = nn.ReLU()

        self.fc_logits = nn.Linear(fc4_width, class_count)                                                      # class_count

        self.reset_parameters()

    @staticmethod
    def _pool_out_width(width, kernel, stride, stage="pool"):
        """Return the output width of an unpadded max-pool: floor((width - kernel) / stride) + 1.

        Uses exact integer arithmetic (``//`` floors like ``math.floor``), so the result is
        a plain ``int`` suitable for ``nn.Linear`` sizes.

        Raises:
            ValueError: if the resulting width is < 1 (the input is too small for the kernel).
        """
        out_width = (width - kernel) // stride + 1
        if out_width < 1:
            raise ValueError(
                f"{stage}: input width {width} is too small for pool kernel {kernel} "
                f"with stride {stride} (output width would be {out_width})"
            )
        return out_width

    def reset_parameters(self):
        """Re-initialise all learnable parameters.

        Conv and hidden Linear layers get Kaiming-normal weights (fan_in mode, ReLU gain)
        and zero biases; the output layer ``fc_logits`` uses PyTorch's default
        ``nn.Linear`` initialisation via its own ``reset_parameters``.
        """
        for m in self.modules():
            # The classifier is not followed by a ReLU, so it keeps the default init.
            if m is self.fc_logits:
                continue
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
                nn.init.constant_(m.bias, 0)
        self.fc_logits.reset_parameters()

    def forward(self, x):
        """Map a batch of images (N, in_channels, in_width, in_width) to logits (N, class_count)."""
        h = self.conv1(x)
        h = self.relu1(h)
        h = self.pool1(h)

        h = self.conv2(h)
        h = self.relu2(h)
        h = self.pool2(h)

        h = self.flatten3(h)
        h = self.fc3(h)
        h = self.relu3(h)

        h = self.fc4(h)
        h = self.relu4(h)

        logits = self.fc_logits(h)
        return logits

In [3]:
import torch

data_channels = 3
data_width = 32
data_classes = 10

# Seed so the random weight init and the random input are reproducible on Restart & Run All.
torch.manual_seed(0)

CifarCnn = ConvolutionalModel(data_channels, data_width, 16, 3, 32, 3, 256, 128, data_classes)

# Smoke test: one random image of the configured size must yield one row of class logits.
inp = torch.randn(1, data_channels, data_width, data_width)
out = CifarCnn(inp)
assert out.shape == (1, data_classes), out.shape
out

tensor([[ 2.3714, -2.5116, -0.8904, -1.0978, -1.0536, -0.9353,  0.3170, -2.0879,
          0.4633,  0.7313]], grad_fn=<AddmmBackward0>)
