In [None]:
import torch
from torch import nn
from torch import functional as F

In [None]:
def conv_block(in_f, out_f, *args, **kwargs):
  """Build one convolutional stage: Conv2d -> BatchNorm2d -> ReLU.

  Args:
    in_f: input channel count for the Conv2d.
    out_f: output channel count for the Conv2d (also BatchNorm2d features).
    *args, **kwargs: forwarded unchanged to nn.Conv2d (e.g. kernel_size,
      padding).

  Returns:
    nn.Sequential holding the three layers in that order.
  """
  layers = [
      nn.Conv2d(in_f, out_f, *args, **kwargs),
      nn.BatchNorm2d(out_f),
      nn.ReLU(),
  ]
  return nn.Sequential(*layers)
  
def dec_block(in_f, out_f):
  """Build one fully connected stage: Linear -> Sigmoid.

  Args:
    in_f: input feature count for the Linear layer.
    out_f: output feature count for the Linear layer.

  Returns:
    nn.Sequential holding the Linear layer followed by a Sigmoid.
  """
  linear = nn.Linear(in_f, out_f)
  activation = nn.Sigmoid()
  return nn.Sequential(linear, activation)
  
class MyClassifier(nn.Module):
  """CNN classifier: conv encoder -> fully connected decoder -> linear head.

  The encoder is a chain of ``conv_block`` stages (kernel_size=3, padding=1,
  default stride 1), so it changes the channel count but keeps the spatial
  size of the input. The encoder output is flattened and passed through a
  chain of ``dec_block`` (Linear + Sigmoid) stages, then a final Linear layer.

  Args:
    in_c: number of input channels.
    enc_sizes: output channel count of each conv stage, in order.
    dec_sizes: output feature count of each hidden Linear stage, in order.
    n_classes: number of outputs of the final Linear layer.
    input_size: spatial size of the input image, either an int (square
      image) or a ``(height, width)`` pair. Defaults to 28.
  """

  def __init__(self, in_c, enc_sizes, dec_sizes, n_classes, input_size=28):
    super().__init__()
    if isinstance(input_size, int):
      height, width = input_size, input_size
    else:
      height, width = input_size

    self.enc_sizes = [in_c, *enc_sizes]
    # The flattened width is the channel count of the LAST encoder stage
    # (in_c if there are no conv stages) times the unchanged spatial size.
    # A hard-coded 32 here breaks as soon as the last stage isn't 32 channels.
    self.dec_sizes = [self.enc_sizes[-1] * height * width, *dec_sizes]

    # Consecutive (in, out) pairs: [a, b, c] -> (a, b), (b, c).
    conv_blocks = [conv_block(in_f, out_f, kernel_size=3, padding=1)
                   for in_f, out_f in zip(self.enc_sizes, self.enc_sizes[1:])]
    self.encoder = nn.Sequential(*conv_blocks)

    dec_blocks = [dec_block(in_f, out_f)
                  for in_f, out_f in zip(self.dec_sizes, self.dec_sizes[1:])]
    self.decoder = nn.Sequential(*dec_blocks)

    self.last = nn.Linear(self.dec_sizes[-1], n_classes)

  def forward(self, x):
    """Run the classifier.

    Args:
      x: batched image tensor, expected as (batch, in_c, height, width) with
        height/width matching ``input_size``.

    Returns:
      Tensor of shape (batch, n_classes) with the raw (un-normalized) outputs
      of the final Linear layer.
    """
    x = self.encoder(x)
    # Keep the batch dimension, flatten channels and spatial dims together.
    x = x.flatten(1)
    x = self.decoder(x)
    return self.last(x)


In [None]:
# 1 input channel, two conv stages (32 then 64 channels), two hidden
# Linear stages (1024 then 512 features), 10 output classes.
model = MyClassifier(
    in_c=1,
    enc_sizes=[32, 64],
    dec_sizes=[1024, 512],
    n_classes=10,
)
print(model)

[1, 32, 64]
[25088, 1024, 512]
1 32
32 64
MyClassifier(
  (encoder): Sequential(
    (0): Sequential(
      (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (1): Sequential(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
  )
  (decoder): Sequential(
    (0): Sequential(
      (0): Linear(in_features=25088, out_features=1024, bias=True)
      (1): Sigmoid()
    )
    (1): Sequential(
      (0): Linear(in_features=1024, out_features=512, bias=True)
      (1): Sigmoid()
    )
  )
  (last): Linear(in_features=512, out_features=10, bias=True)
)


In [None]:
# Scratch check of the pairing idiom used to build the layers above:
# zip(sizes, sizes[1:]) yields one (in, out) pair per consecutive layer.
l1 = [1, 32, 64]
l2 = l1[1:]

# zip() returns a lazy iterator; materialize it so the pairs are shown
# instead of an uninformative "<zip object at 0x...>" repr.
print(list(zip(l1, l2)))

<zip object at 0x7f6883630aa0>


In [None]:
# Walk the pairs one at a time, printing each (in, out) tuple on its own line.
for in_f, out_f in zip(l1, l2):
  print((in_f, out_f))

(1, 32)
(32, 64)
