In [2]:
import torch
from typing import List

In [8]:
class AllCNN(torch.nn.Module):
    def __init__(self, base_dim: int, exponents: List[int], activation_name: str, num_class: int):
        super().__init__()
        self.first = torch.nn.Sequential(torch.nn.Conv2d(3, base_dim, 3, padding=1, bias=False),
                                         torch.nn.BatchNorm2d(base_dim),
                                         torch.nn.ReLU())
        self.blocks = torch.nn.ModuleList([
            torch.nn.Sequential(torch.nn.Conv2d(base_dim*2**i, base_dim*2**i, 3, padding=1, bias=False),
                                torch.nn.BatchNorm2d(base_dim*2**i),
                                torch.nn.ReLU(),
                                torch.nn.Conv2d(base_dim*2**i, base_dim*2**(i+1), 3, padding=1, bias=False, stride=2),
                                torch.nn.BatchNorm2d(base_dim*2**(i+1)),
                                torch.nn.ReLU())
                        for i in exponents])
        last_exp = exponents[-1] + 1
        self.final_layer = torch.nn.Sequential(
            torch.nn.Conv2d(base_dim*2**last_exp, base_dim*2**last_exp, 3, padding=1, bias=False),
            torch.nn.BatchNorm2d(base_dim*2**last_exp),
            torch.nn.ReLU(),
            torch.nn.Conv2d(base_dim*2**last_exp, base_dim*2**last_exp, 1, padding=0, bias=False),
            torch.nn.BatchNorm2d(base_dim*2**last_exp),
            torch.nn.ReLU(),
            torch.nn.Conv2d(base_dim*2**last_exp, num_class, 1, padding=0, bias=False))
        self.avg = torch.nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        x = self.first(x)
        for block in self.blocks:
            x = block(x)
        x = self.final_layer(x)
        x = self.avg(x)
        return x

In [17]:
model = AllCNN(base_dim=32, exponents=[], activation_name='relu', num_class=10)

IndexError: list index out of range

In [16]:
model(torch.randn(1, 3, 32, 32)).squeeze([-1, -2]).shape

torch.Size([1, 10])

In [18]:
model

AllCNN(
  (first): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (blocks): ModuleList(
    (0): Sequential(
      (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
    )
  )
  (final_layer): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (4): BatchNorm2d(64, eps=1e