In [18]:
from d2l import torch as d2l
import torch
import torch.nn as nn
import torch.nn.functional as F

In [22]:
class Inception(nn.Module):
    def __init__(self, num_inputs, c1, c2, c3, c4):
        super().__init__()
        self.num_inputs = num_inputs
        self.c1 = c1
        self.c2 = c2
        self.c3 = c3
        self.c4 = c4
        
        self.p1 = nn.Conv2d(self.num_inputs, self.c1, kernel_size=1)
        self.p2 = nn.Sequential(
            nn.Conv2d(self.num_inputs, self.c2[0], kernel_size=1), nn.ReLU(),
            nn.Conv2d(self.c2[0], self.c2[1], kernel_size=3, padding=1), nn.ReLU()
        )
        self.p3 = nn.Sequential(
            nn.Conv2d(self.num_inputs, self.c3[0], kernel_size=1), nn.ReLU(),
            nn.Conv2d(self.c3[0], self.c3[1], kernel_size=5, padding=2), nn.ReLU()
        )
        self.p4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, padding=1),
            nn.Conv2d(self.num_inputs, self.c4, kernel_size=1), nn.ReLU(),
        )
        
    def forward(self, X):
        X = X.view((-1, 1, 224, 224))
        
        op1 = F.relu(self.p1(X))
        op2 = F.relu(self.p2(X))
        op3 = F.relu(self.p3(X))
        op4 = F.relu(self.p4(X))
        
        return torch.cat([op1, op2, op3, op4], dim=1)

In [26]:
X = torch.rand((1, 1, 224, 224))

for layer in Inception(1, 1, (1, 1), (1, 1), 1):
    X = layer(X)
    print(f'{layer.__class__.__name__}: {X.shape}')

TypeError: 'Inception' object is not iterable

In [35]:
Inception(1, 1, (1, 1), (1, 1), 1)

Inception(
  (p1): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))
  (p2): Sequential(
    (0): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))
    (1): ReLU()
    (2): Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
  )
  (p3): Sequential(
    (0): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))
    (1): ReLU()
    (2): Conv2d(1, 1, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (3): ReLU()
  )
  (p4): Sequential(
    (0): MaxPool2d(kernel_size=3, stride=3, padding=1, dilation=1, ceil_mode=False)
    (1): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))
    (2): ReLU()
  )
)