# 4.NiN
## 4.1 加载包

In [1]:
import time
import torch
from torch import nn, optim

import sys
sys.path.append("..") 
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Fix the global torch RNG early so weight initialisation, dropout and
# (presumably) the d2l DataLoader shuffling repeat across Restart & Run All.
# cuDNN kernels may still be non-deterministic on GPU.
SEED = 0
torch.manual_seed(SEED)

print(torch.__version__)
print(device)

1.7.0
cuda


## 4.2 定义 NiN块

In [2]:
def nin_block(in_channels, out_channels, kernel_size, stride, padding):
    """Build one NiN block: a spatial convolution followed by two 1x1 convolutions.

    The first convolution maps ``in_channels`` -> ``out_channels`` using the
    caller-supplied ``kernel_size``, ``stride`` and ``padding``. The two
    following 1x1 convolutions keep ``out_channels``, applying the same
    channel-mixing linear map at every pixel. Each convolution is followed by
    a ReLU.

    Returns:
        nn.Sequential with six children, in order:
        Conv2d, ReLU, Conv2d(1x1), ReLU, Conv2d(1x1), ReLU.
    """
    layers = [nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
              nn.ReLU()]
    for _ in range(2):
        layers += [nn.Conv2d(out_channels, out_channels, kernel_size=1), nn.ReLU()]
    return nn.Sequential(*layers)

## 4.3 定义 NiN网络

In [3]:
# NiN for single-channel images and 10 classes. No fully connected head: the
# last NiN block emits one channel per class, which is globally averaged.
net = nn.Sequential(
    nin_block(1, 96, kernel_size=11, stride=4, padding=0),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nin_block(96, 256, kernel_size=5, stride=1, padding=2),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nin_block(256, 384, kernel_size=3, stride=1, padding=1),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nn.Dropout(0.5),
    # The number of label classes is 10
    nin_block(384, 10, kernel_size=3, stride=1, padding=1),
    # Global average pooling. For 224x224 inputs the previous layer outputs
    # 5x5 maps, so this equals the former AvgPool2d(kernel_size=5), but it
    # stays correct if the input resize is changed (e.g. to save memory).
    nn.AdaptiveAvgPool2d((1, 1)),
    # Turn the 4-D output (batch, 10, 1, 1) into 2-D (batch, 10)
    nn.Flatten())

### Ex：观察样本经过各层后输出形状的变化

In [1]:
# Push a dummy 1x1x224x224 input through the network and print the output
# shape after every top-level child, to see how the feature maps shrink.
X = torch.rand(1, 1, 224, 224)
with torch.no_grad():  # shape inspection only; no autograd graph needed
    for name, blk in net.named_children():
        X = blk(X)
        print(name, 'output shape: ', X.shape)

## 4.4 获取数据和训练模型
> 调用数据读取函数 `d2l.load_data_fashion_mnist`、准确率计算函数 `d2l.evaluate_accuracy` 及训练函数 `d2l.train_ch5`；如需修改，可到 d2l 的 util 包中修改。

In [5]:
# Hyper-parameters for this run.
batch_size = 128
lr, num_epochs = 0.002, 5

# resize=224 is passed to the d2l loader (images presumably resized to
# 224x224). On an "out of memory" error, lower batch_size or the resize value.
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)

optimizer = optim.Adam(net.parameters(), lr=lr)
d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)

training on  cuda
epoch 1, loss 1.5668, train acc 0.423, test acc 0.637, time 52.1 sec
epoch 2, loss 0.5598, train acc 0.802, test acc 0.834, time 51.7 sec
epoch 3, loss 0.4320, train acc 0.839, test acc 0.834, time 51.6 sec
epoch 4, loss 0.3900, train acc 0.854, test acc 0.855, time 51.7 sec
epoch 5, loss 0.3639, train acc 0.865, test acc 0.870, time 52.1 sec
