In [None]:
import torch
import torch.nn as nn
import d2l.torch as d2l

## 定义模型

In [None]:
class NiNBlock(nn.Module):
    """Network-in-Network block: a spatial convolution followed by two 1x1 convolutions.

    Each of the three convolutions is followed by a ReLU activation. The 1x1
    convolutions keep the channel count at ``out_features`` and mix channels
    independently at every spatial position.

    Args:
        in_features: Number of channels of the input feature map.
        out_features: Number of channels produced by every convolution in the block.
        kernel_size: Kernel size of the first (spatial) convolution.
        stride: Stride of the first convolution.
        padding: Zero padding of the first convolution.
    """

    def __init__(self, in_features, out_features, kernel_size, stride, padding):
        super().__init__()
        # Spatial convolution that changes the channel count and the resolution.
        layers = [
            nn.Conv2d(in_features, out_features, kernel_size, stride, padding),
            nn.ReLU(),
        ]
        # Two per-pixel (1x1) convolutions, each followed by its own ReLU.
        for _ in range(2):
            layers.append(nn.Conv2d(out_features, out_features, kernel_size=1))
            layers.append(nn.ReLU())
        self.block = nn.Sequential(*layers)

    def forward(self, X):
        """Apply the block to ``X`` and return the resulting feature map."""
        return self.block(X)

In [None]:
class NiN(nn.Module):
    """Network-in-Network classifier built from four ``NiNBlock`` stages.

    The first three stages are each followed by a 3x3 max pooling with stride 2.
    The last stage maps to ``num_classes`` channels, which are reduced to one
    score per class by global average pooling; there is no fully connected layer.

    Args:
        num_classes: Number of output classes (channels of the last stage).
        in_channels: Number of channels of the input images.
        image_size: Height and width the input is reshaped to in ``forward``.

    With the default arguments the network is identical to the original
    hard-coded variant (1x224x224 inputs, 10 classes).
    """

    def __init__(self, num_classes=10, in_channels=1, image_size=224):
        super().__init__()
        # Kept so that ``forward`` can restore the expected input layout.
        self.in_channels = in_channels
        self.image_size = image_size

        # Shape comments below are per-sample [C, H, W] for the default
        # 1x224x224 input.
        self.block1 = nn.Sequential(
            NiNBlock(in_channels, 96, 11, 4, 1),  # [96, 54, 54]
            nn.MaxPool2d(3, 2),  # [96, 26, 26]
        )
        self.block2 = nn.Sequential(
            NiNBlock(96, 256, 5, 1, 2),  # [256, 26, 26]
            nn.MaxPool2d(3, 2),  # [256, 12, 12]
        )
        self.block3 = nn.Sequential(
            NiNBlock(256, 384, 3, 1, 1),  # [384, 12, 12]
            nn.MaxPool2d(3, 2),  # [384, 5, 5]
        )
        self.block4 = nn.Sequential(
            NiNBlock(384, num_classes, 3, 1, 1),  # [num_classes, 5, 5]
            nn.AdaptiveAvgPool2d((1, 1)),  # [num_classes, 1, 1]
            nn.Flatten(),  # [num_classes]
        )

    def forward(self, X):
        """Return per-class scores of shape ``(batch, num_classes)``.

        ``X`` is first reshaped to ``(-1, in_channels, image_size, image_size)``,
        so any tensor whose element count is a multiple of one image is
        accepted; the batch dimension is inferred.
        """
        X = X.reshape(-1, self.in_channels, self.image_size, self.image_size)
        X = self.block1(X)
        X = self.block2(X)
        X = self.block3(X)
        y = self.block4(X)
        return y