In [2]:
import torch
from torch import nn

In [3]:
model = nn.Sequential(
    nn.Conv2d(3, 64, 3,bias=False),
    nn.BatchNorm2d(64),
    nn.Conv2d(64, 64, 3,bias=False),
    nn.MaxPool2d(2, 2)
)

In [4]:
model[-1]

MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)

In [58]:
sample = torch.zeros((2, 3, 32, 32))
sample.shape

torch.Size([2, 3, 32, 32])

In [59]:
model(sample).shape

torch.Size([2, 64, 28, 28])

In [44]:
h = (h_w[0] + (2 * pad[0]) - (dilation * (kernel_size[0] - 1)) - 1)// stride[0] + 1

In [45]:
model[0].in_channels, model[0].out_channels, model[0].kernel_size[0]

(3, 64, 3)

In [46]:
model[0].padding[0], model[0].stride[0]

(0, 1)

In [25]:
nn.MaxPool2d, nn.AvgPool2d

(torch.nn.modules.pooling.MaxPool2d, torch.nn.modules.pooling.AvgPool2d)

In [47]:
def calc(h_w, layer):
    return (h_w + (2 * layer.padding[0]) - (1 * (layer.kernel_size[0] - 1)) - 1)// layer.stride[0] + 1

In [48]:
calc(32, model[0])

30

In [49]:
h_w = 32

for layer in model:
    h_w = calc(h_w, layer)
    print(h_w)

30
28


In [52]:
for layer in model:
    print(type(layer).__name__)

Conv2d
Conv2d


In [32]:


class FashionMNIST_CNN(nn.Module):
    def __init__(self, channels=1, img_dim=28, outneurons=10, last_hidden_neurons=40, first_layer_norm=False,
        weight_init="kaiming_uniform", bias=True, dropout=0.0, batchnorm=True):

        super(FashionMNIST_CNN, self).__init__()

        self.channels = channels
        self.img_dim = img_dim
        self.in_features = channels * img_dim * img_dim
        self.num_classes = outneurons
        self.last_hidden_neurons = last_hidden_neurons
        self.dropout_p = dropout
        self.batchnorm = batchnorm
        self.first_layer_norm = first_layer_norm

        # uniform(-1/sqrt(in_features), 1/sqrt(in_features))
        weights = {
            "normal": nn.init.normal_,
            "xavier": nn.init.xavier_normal_,
            "xavier_uniform": nn.init.xavier_uniform_,
            "kaiming": nn.init.kaiming_normal_,
            "kaiming_uniform": nn.init.kaiming_uniform_,
        }

        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        self.scaleInputs = nn.BatchNorm2d(channels)
        self.dropout_l = nn.Dropout(self.dropout_p)


        self.conv = nn.Sequential(
            nn.Conv2d(self.channels, 128, 3, bias=bias),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            
            nn.AvgPool2d(2, 2),
            
            nn.Conv2d(128, 128, 3, bias=bias),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            
            nn.Conv2d(128, 128, 3, bias=bias),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            
            nn.Conv2d(128, 128, 3, bias=bias),
            nn.BatchNorm2d(128),
            nn.ReLU(),

            nn.Conv2d(128, 64, 2, bias=bias),
            nn.BatchNorm2d(64),
            nn.ReLU(),

            nn.Conv2d(64, 64, 2, bias=bias),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        
        in_features_fc, last_conv_out_feature = self.conv_params()

        self.linear = nn.Sequential(
            nn.Flatten(),
            nn.Dropout1d(self.dropout_p),
            nn.Linear(last_conv_out_feature * in_features_fc * in_features_fc, 300, bias=bias),
            nn.BatchNorm1d(300),
            nn.ReLU(),
            nn.Dropout1d(self.dropout_p),
            nn.Linear(300, last_hidden_neurons, bias=bias)
        )

        self.bn = nn.BatchNorm1d(last_hidden_neurons)

        self.output = nn.Linear(last_hidden_neurons, outneurons, bias=bias)


        if weight_init:
            self.__weight_init(weights[weight_init], bias)


    def forward(self, x):

        x = self._train(x)
        if self.batchnorm: x = self.bn(x)
        x = self.relu(x)
        x = self.output(x)

        return x


    def _train(self, x):

        if self.first_layer_norm: x = self.scaleInputs(x)

        x = self.conv(x)
        x = self.linear(x)

        return x

    def output_last_layer(self, x):

        x = self._train(x)
        out = x.clone().detach()
        if self.batchnorm: x = self.bn(x)
        x = self.relu(x)
        x = self.output(x)

        return out, x


    def _sum_weights(self):

        total_weights = 0
        for _, p in self.named_parameters():
            total_weights += p.sum()

        return total_weights.item()

    def _sum_abs_weights(self):

        total_weights = 0
        for _, p in self.named_parameters():
            total_weights += p.abs().sum()

        return total_weights.item()

    def _l1_regularization(self, alpha=1e-3):

        total_weights = 0
        for _, p in self.named_parameters():
            total_weights += p.abs().sum()

        return alpha * total_weights

    def _l2_regularization(self, lambd=1e-3):

        total_weights = 0
        for _, p in self.named_parameters():
            total_weights += p.pow(2).sum()

        return lambd * total_weights

    def _elastic_regularization(self, lambd=1e-3, alpha=1e-3):
        return self._l2_regularization(lambd) + self._l1_regularization(alpha)
    
    
    def __weight_init(self, fn, bias):

        for m in self.modules():

            if (
                isinstance(m, nn.Linear)
                or isinstance(m, nn.Conv2d)
                or isinstance(m, nn.Conv1d)
               ):
                fn(m.weight)

                if bias:
                    nn.init.constant_(m.bias, 0)

            if isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
                
    
    def _calc1(self, h_w, layer):
        return (h_w + (2 * layer.padding[0]) - (1 * (layer.kernel_size[0] - 1)) - 1)// layer.stride[0] + 1
    
    def _calc2(self, h_w, layer):
        return (h_w + (2 * layer.padding) - (1 * (layer.kernel_size - 1)) - 1)// layer.stride + 1

    def conv_params(self):

        h_w = self.img_dim
        last_conv_out_feature = 0

        for m in self.modules():

            if isinstance(m, nn.Conv2d):
                last_conv_out_feature = m.out_channels
                h_w = self._calc1(h_w, m)
            if isinstance(m, nn.MaxPool2d) or isinstance(m, nn.AvgPool2d):
                h_w = self._calc2(h_w, m)

        return h_w, last_conv_out_feature


In [33]:
model = FashionMNIST_CNN()

model.conv_params()

(5, 64)

In [34]:
model.conv(torch.zeros((2, 1, 28, 28))).shape

torch.Size([2, 64, 5, 5])

In [35]:
model(torch.zeros((2, 1, 28, 28))).shape

torch.Size([2, 10])