In [None]:
import torch
from torch import nn
import torch.nn.functional as F
from d2l import torch as d2l

## 加载数据

In [None]:
# Fashion-MNIST train/test loaders. Images are upsampled to 224x224, the
# input size the ResNet model below expects by default.
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size,
                                                    resize=224)

## 定义模型

In [None]:
class ResidualBlock(nn.Module):
    """Basic two-convolution residual block (the ResNet-18/34 building block).

    Main path: 3x3 conv (given stride) -> BatchNorm -> ReLU -> 3x3 conv -> BatchNorm.
    Shortcut: identity, or a strided 1x1 convolution when ``use_1x1_conv`` is True.
    The two paths are summed and passed through a final ReLU.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by the block.
        stride: stride of the first 3x3 conv and of the 1x1 shortcut.
            Defaults to 1 (spatial size preserved).
        use_1x1_conv: project the shortcut with a 1x1 conv. Required whenever
            ``in_channels != out_channels`` or ``stride != 1``, because the
            identity shortcut would not match the main path's shape.
            Defaults to False.

    Raises:
        ValueError: if an identity shortcut is requested but the channel count
            or stride would make it incompatible with the main path.
    """

    def __init__(self, in_channels, out_channels, stride=1, use_1x1_conv=False):
        super().__init__()
        # Fail fast: without a projection, a channel mismatch either errors in
        # forward() or, when in_channels == 1, silently broadcasts X over Y.
        if not use_1x1_conv and (in_channels != out_channels or stride != 1):
            raise ValueError(
                "use_1x1_conv=True is required when in_channels != out_channels "
                f"or stride != 1 (got {in_channels}->{out_channels}, stride={stride})"
            )
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, stride, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels)
        )
        # Projection shortcut; None means the identity shortcut is used.
        if use_1x1_conv:
            self.conv3 = nn.Conv2d(in_channels, out_channels, 1, stride)
        else:
            self.conv3 = None

    def forward(self, X):
        """Return ReLU(main_path(X) + shortcut(X))."""
        Y = self.conv2(self.conv1(X))
        # Compare against None explicitly rather than relying on nn.Module
        # truthiness (container modules define __len__ and can be falsy).
        shortcut = self.conv3(X) if self.conv3 is not None else X
        return F.relu(Y + shortcut)

In [None]:
class ResNet(nn.Module):
    """ResNet-18-style classifier: a conv stem plus four stages of two residual blocks.

    Shape comments below are per sample and assume the defaults
    (``in_channels=1``, ``image_size=224``).

    Args:
        in_channels: number of input image channels (1 for Fashion-MNIST).
        num_classes: number of output logits.
        image_size: height/width that ``forward`` reshapes its input to.
            Defaults to 224, matching the previously hard-coded value.
    """

    def __init__(self, in_channels=1, num_classes=10, image_size=224):
        super().__init__()
        self.in_channels = in_channels
        self.image_size = image_size
        # Stem: 7x7 stride-2 conv, then 3x3 stride-2 max-pool.
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3),  # [64, 112, 112]
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # [64, 56, 56]
        )
        self.block2 = nn.Sequential(
            ResidualBlock(64, 64, 1, use_1x1_conv=False),  # [64, 56, 56]
            ResidualBlock(64, 64, 1, use_1x1_conv=False),  # [64, 56, 56]
        )
        # Stages 3-5 each halve the resolution and double the channels;
        # the first block of each needs a 1x1 projection shortcut.
        self.block3 = nn.Sequential(
            ResidualBlock(64, 128, 2, use_1x1_conv=True),  # [128, 28, 28]
            ResidualBlock(128, 128, 1, use_1x1_conv=False),  # [128, 28, 28]
        )
        self.block4 = nn.Sequential(
            ResidualBlock(128, 256, 2, use_1x1_conv=True),  # [256, 14, 14]
            ResidualBlock(256, 256, 1, use_1x1_conv=False),  # [256, 14, 14]
        )
        self.block5 = nn.Sequential(
            ResidualBlock(256, 512, 2, use_1x1_conv=True),  # [512, 7, 7]
            ResidualBlock(512, 512, 1, use_1x1_conv=False),  # [512, 7, 7]
        )
        # Head: global average pooling makes the head independent of spatial size.
        self.block6 = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),  # [512, 1, 1]
            nn.Flatten(),  # [512]
            nn.Linear(512, num_classes)  # [num_classes]
        )

    def forward(self, X):
        """Return logits of shape (batch, num_classes).

        ``X`` is reshaped to (-1, in_channels, image_size, image_size), so its
        element count must be a multiple of ``in_channels * image_size ** 2``.
        """
        X = X.reshape(-1, self.in_channels, self.image_size, self.image_size)
        for block in (self.block1, self.block2, self.block3,
                      self.block4, self.block5, self.block6):
            X = block(X)
        return X

## 训练

In [None]:
def evaluate_accuracy(model, test_iter, device):
    """Return the classification accuracy of ``model`` over ``test_iter``.

    The model is moved to ``device`` and evaluated in eval mode under
    ``torch.no_grad()``. Its previous train/eval mode is restored before
    returning, so a surrounding training loop keeps training in train mode.

    Args:
        model: module mapping a batch ``X`` to logits of shape (batch, classes),
            or directly to predicted labels of shape (batch,).
        test_iter: iterable of ``(X, y)`` batches.
        device: device that the model and each batch are moved to.

    Returns:
        Fraction of correctly classified samples, a float in [0, 1].

    Raises:
        ValueError: if ``test_iter`` yields no samples.
    """
    model.to(device)
    was_training = model.training
    model.eval()
    correct, total = 0, 0
    try:
        # No gradients are needed for evaluation; skipping graph construction
        # saves memory and time.
        with torch.no_grad():
            for X, y in test_iter:
                X, y = X.to(device), y.to(device)
                y_hat = model(X)
                # Logits -> predicted class (same rule as d2l.accuracy).
                if y_hat.dim() > 1 and y_hat.shape[1] > 1:
                    y_hat = y_hat.argmax(dim=1)
                correct += int((y_hat.to(y.dtype) == y).sum())
                total += y.numel()
    finally:
        model.train(was_training)
    if total == 0:
        raise ValueError("test_iter yielded no samples")
    return correct / total

In [None]:
def train(model, train_iter, test_iter, loss_fn, optimizer, num_epochs, device):
    """Train ``model`` with ``optimizer`` and report per-epoch progress.

    Each epoch makes one pass over ``train_iter`` in training mode, then
    evaluates on ``test_iter`` with ``evaluate_accuracy``. Train loss, train
    accuracy and test accuracy are plotted with a ``d2l.Animator``, and
    throughput and accuracies are printed.

    Args:
        model: network to train; it is moved to ``device``.
        train_iter: iterable of ``(X, y)`` training batches.
        test_iter: iterable of ``(X, y)`` evaluation batches.
        loss_fn: callable ``(y_hat, y) -> scalar loss``. The logged loss is
            ``loss * batch_size``, which assumes a batch-mean reduction
            (the default of ``nn.CrossEntropyLoss``).
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of passes over ``train_iter``; must be >= 1.
        device: device used for both training and evaluation.

    Raises:
        ValueError: if ``num_epochs < 1``. Otherwise nothing would run and the
            final summary would reference undefined results.
    """
    if num_epochs < 1:
        raise ValueError(f"num_epochs must be >= 1, got {num_epochs}")
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0, 1],
                            legend=['train_loss', 'train_acc', 'test_acc'])
    metrics = d2l.Accumulator(3)  # [summed loss, correct predictions, samples]
    timer = d2l.Timer()
    model.to(device)
    for epoch in range(num_epochs):
        # Re-enable training mode every epoch: the evaluation at the end of the
        # previous epoch may have left the model in eval mode. BatchNorm would
        # then normalize with running statistics instead of batch statistics.
        model.train()
        metrics.reset()
        timer.start()
        for X, y in train_iter:
            X, y = X.to(device), y.to(device)
            y_hat = model(X)
            loss = loss_fn(y_hat, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            with torch.no_grad():
                metrics.add(float(loss) * y.numel(), d2l.accuracy(y_hat, y), y.numel())
        timer.stop()
        train_loss = metrics[0] / metrics[2]
        train_acc = metrics[1] / metrics[2]
        print(f'speed: {metrics[2] / timer.times[-1]:.1f} samples/sec')
        # Evaluate on the training device. Evaluating on a different device
        # would move the model there and break the next epoch's batches.
        test_acc = evaluate_accuracy(model, test_iter, device)
        print(f'train_acc: {train_acc:.3f}, test_acc: {test_acc:.3f}')
        animator.add(epoch + 1, (train_loss, train_acc, test_acc))
    print(f'speed: {metrics[2] / timer.avg():.1f} samples/sec on {device}.')
    print(f'train_acc: {metrics[1] / metrics[2]:.3f}, test_acc: {test_acc:.3f}.')

In [None]:
# Hyperparameters.
# NOTE(review): plain SGD (no momentum) at lr=0.001 may converge slowly for
# this network; consider a larger learning rate or momentum.
lr = 0.001
num_epochs = 10

# Seed torch's global RNG so that weight initialisation (and, by default,
# DataLoader shuffling) is reproducible across Restart & Run All.
# GPU kernels may still be non-deterministic.
torch.manual_seed(0)

device = d2l.try_gpu()
model = ResNet()
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr)

train(model, train_iter, test_iter, loss_fn, optimizer, num_epochs, device)

In [None]:
# Inspect the trained 7x7 kernels of the stem convolution.
# .detach() is the supported way to get a tensor without autograd tracking;
# .data is a legacy attribute that bypasses autograd's safety checks.
model.block1[0].weight.detach()