In [1]:
import torch
import torch.nn as nn
from model import Resnet, ResidualCNNLayer

In [2]:
# Input description: square inputs of side 32 with 3 channels.
image_size, input_channels = 32, 3
# Settings forwarded to Resnet below: channel widths, `blocks`, first-conv kernel size.
output_channels = [16, 32, 64]
blocks, input_kernel = 3, 3
# Size of the classifier output.
classes = 10

# Build the network from the settings above; every argument is passed by keyword.
model = Resnet(image_size=image_size, in_channels=input_channels, out_channels=output_channels,
               blocks=blocks, input_kernel_size=input_kernel, n_classes=classes)

In [3]:
# Display the model's repr to inspect the layer structure that Resnet built.
model

Resnet(
  (first_conv): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (residual_blocks): ModuleList(
    (0-2): 3 x ResidualCNNLayer(
      (layers): ModuleList(
        (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (3): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (residual_projection): Identity()
    )
    (3): ResidualCNNLayer(
      (layers): ModuleList(
        (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True

In [4]:
# Random batch of 3 inputs with shape (3, input_channels, image_size, image_size).
x = torch.rand((3, input_channels, image_size, image_size))
x.size()

torch.Size([3, 3, 32, 32])

In [6]:
# Forward the random batch through the full model and display the output shape.
out = model(x)
out.size()

torch.Size([3, 10])

In [7]:
# Stand-alone 3 -> 16 channel convolution with a 3x3 kernel and stride 1;
# padding=1 keeps the spatial size unchanged for this kernel/stride.
conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
conv1

Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

In [8]:
# Fresh random batch, then display the shape produced by conv1 on its own.
x = torch.rand((3, input_channels, image_size, image_size))
conv1(x).size()

torch.Size([3, 16, 32, 32])

In [9]:
# Build three residual layers so they can be applied one at a time by hand.
# NOTE(review): as recorded in this cell's output, it currently raises
#   TypeError: ResidualCNNLayer.__init__() got an unexpected keyword argument 'downsample'
# so these calls no longer match the imported ResidualCNNLayer constructor.
# Check model.py for the current signature and update all three calls.
# The meaning of the positional arguments (16 or 32, 3, 2) is not visible here - TODO confirm.
reslayer1 = ResidualCNNLayer(16, 3, 2, downsample=False)
reslayer2 = ResidualCNNLayer(16, 3, 2, downsample=True)
reslayer3 = ResidualCNNLayer(32, 3, 2, downsample=True)

TypeError: ResidualCNNLayer.__init__() got an unexpected keyword argument 'downsample'

In [33]:
# Push a random batch through conv1 and then each residual layer in turn,
# printing the shape after the first two; the final shape is left in `x`.
# NOTE(review): needs reslayer1..3 from the cell that constructs them; that cell's
# recorded output is a TypeError, so the output stored for this cell may be stale - confirm.
x = conv1(torch.rand(3, input_channels, image_size, image_size))
x = reslayer1(x)
print(x.size())
x = reslayer2(x)
print(x.size())
x = reslayer3(x)

torch.Size([3, 16, 32, 32])
torch.Size([3, 32, 16, 16])


In [34]:
# Shape of `x` after the last residual layer applied in the previous cell.
x.shape

torch.Size([3, 64, 8, 8])