In [4]:
import torch
from torch import nn


def conv_block(in_f, out_f, *args, **kwargs):
    """Build a Conv2d -> BatchNorm2d -> ReLU stack.

    Args:
        in_f: Number of input channels for the convolution.
        out_f: Number of output channels; also the BatchNorm feature count.
        *args, **kwargs: Forwarded unchanged to ``nn.Conv2d``
            (e.g. ``kernel_size``, ``stride``, ``padding``).

    Returns:
        An ``nn.Sequential`` whose children are indexed 0 (conv), 1 (norm), 2 (ReLU).
    """
    layers = [
        nn.Conv2d(in_f, out_f, *args, **kwargs),
        nn.BatchNorm2d(out_f),
        nn.ReLU(),
    ]
    return nn.Sequential(*layers)


In [5]:
class MyCNNClassifier(nn.Module):
    """Small CNN classifier: two conv blocks followed by a two-layer MLP head.

    The encoder uses 3x3 convolutions with stride 1 and padding 1, so it keeps
    the input's spatial resolution and only changes the channel count
    (``in_c -> 32 -> 64``). The decoder therefore needs the spatial size up
    front to size its first ``Linear`` layer.

    Args:
        in_c: Number of input channels.
        n_classes: Number of output scores (logits) per sample.
        img_size: Height and width of the (square) input images. The default of
            28 reproduces the original ``64*28*28`` decoder input.
        hidden_size: Width of the hidden decoder layer (default 1024, as before).
    """

    def __init__(self, in_c, n_classes, img_size: int = 28, hidden_size: int = 1024) -> None:
        super().__init__()
        self.encoder = nn.Sequential(
            conv_block(in_c, 32, kernel_size=3, stride=1, padding=1),
            conv_block(32, 64, kernel_size=3, stride=1, padding=1),
        )
        # The encoder preserves H and W (3x3 kernel, stride 1, padding 1), so the
        # flattened feature vector has 64 * img_size * img_size entries.
        # NOTE(review): Sigmoid between the two Linear layers is unusual (ReLU is
        # the common choice); kept as-is to preserve the model's behaviour.
        self.decoder = nn.Sequential(
            nn.Linear(64 * img_size * img_size, hidden_size),
            nn.Sigmoid(),
            nn.Linear(hidden_size, n_classes),
        )

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Return raw (unnormalised) class scores of shape ``(batch, n_classes)``.

        ``X`` is expected as ``(batch, in_c, img_size, img_size)``.
        """
        features = self.encoder(X)
        # flatten(start_dim=1) rather than view(batch, -1): view raises on
        # non-contiguous feature maps (e.g. channels_last) and on an empty
        # batch, where the -1 dimension is ambiguous.
        features = torch.flatten(features, start_dim=1)
        return self.decoder(features)



In [12]:
# One input channel, ten output classes; display the module tree below.
model = MyCNNClassifier(in_c=1, n_classes=10)
model

MyCNNClassifier(
  (encoder): Sequential(
    (0): Sequential(
      (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (1): Sequential(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
  )
  (decoder): Sequential(
    (0): Linear(in_features=50176, out_features=1024, bias=True)
    (1): Sigmoid()
    (2): Linear(in_features=1024, out_features=10, bias=True)
  )
)

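## Model summary

Almost all of the ~51.4M parameters sit in the decoder's first `Linear` layer. The encoder keeps the 28×28 resolution (3×3 kernels, stride 1, padding 1), so that layer maps 64·28·28 = 50,176 features to 1,024 units: 50,176 × 1,024 weights + 1,024 biases = 51,381,248 parameters.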

In [14]:
from torchinfo import summary
# input_size is (batch, channels, height, width); 28x28 matches the decoder's 64*28*28 input.
summary(model, input_size=(1, 1, 28, 28))

Layer (type:depth-idx)                   Output Shape              Param #
MyCNNClassifier                          [1, 10]                   --
├─Sequential: 1-1                        [1, 64, 28, 28]           --
│    └─Sequential: 2-1                   [1, 32, 28, 28]           --
│    │    └─Conv2d: 3-1                  [1, 32, 28, 28]           320
│    │    └─BatchNorm2d: 3-2             [1, 32, 28, 28]           64
│    │    └─ReLU: 3-3                    [1, 32, 28, 28]           --
│    └─Sequential: 2-2                   [1, 64, 28, 28]           --
│    │    └─Conv2d: 3-4                  [1, 64, 28, 28]           18,496
│    │    └─BatchNorm2d: 3-5             [1, 64, 28, 28]           128
│    │    └─ReLU: 3-6                    [1, 64, 28, 28]           --
├─Sequential: 1-2                        [1, 10]                   --
│    └─Linear: 2-3                       [1, 1024]                 51,381,248
│    └─Sigmoid: 2-4                      [1, 1024]                 --
│