# NiN (Network in Network)

In [None]:
import torch
from torch import nn
from d2l import torch as d2l

In [None]:
def nin_block(in_channels, out_channels, kernel_size, strides, padding):
    """Build one NiN block as an ``nn.Sequential``.

    The block is a convolution configured by ``kernel_size``, ``strides``
    and ``padding``, followed by two 1x1 convolutions. Every convolution is
    followed by a ReLU, and both 1x1 convolutions keep ``out_channels``
    channels.

    Args:
        in_channels: Channel count of the input feature map.
        out_channels: Channel count produced by every convolution in the block.
        kernel_size: Kernel size of the first convolution.
        strides: Stride of the first convolution.
        padding: Padding of the first convolution.

    Returns:
        An ``nn.Sequential`` with six layers: Conv2d, ReLU, Conv2d(1x1), ReLU,
        Conv2d(1x1), ReLU.
    """
    layers = [
        nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=strides,
            padding=padding,
        ),
        nn.ReLU(),
    ]
    # Two per-pixel (1x1) convolution stages that keep the channel count fixed.
    for _ in range(2):
        layers.append(nn.Conv2d(out_channels, out_channels, kernel_size=1))
        layers.append(nn.ReLU())
    return nn.Sequential(*layers)

In [None]:
# Full NiN model: three (NiN block + 3x3/stride-2 max-pool) stages, dropout,
# then a final NiN block with 10 output channels whose feature maps are
# averaged down to 1x1 and flattened to one value per channel.
net = nn.Sequential(
    # Stage 1: single-channel input, 11x11 convolution with stride 4.
    nin_block(1, 96, kernel_size=11, strides=4, padding=0),
    nn.MaxPool2d(kernel_size=3, stride=2),
    # Stage 2: 5x5 convolution.
    nin_block(96, 256, kernel_size=5, strides=1, padding=2),
    nn.MaxPool2d(kernel_size=3, stride=2),
    # Stage 3: 3x3 convolution.
    nin_block(256, 384, kernel_size=3, strides=1, padding=1),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nn.Dropout(p=0.5),
    # Head: 10 output channels, reduced by global average pooling.
    nin_block(384, 10, kernel_size=3, strides=1, padding=1),
    nn.AdaptiveAvgPool2d(output_size=(1, 1)),
    nn.Flatten(),
)

In [None]:
# Sanity check: feed one random single-channel 224x224 input through the
# network one top-level layer at a time and print each layer's output shape.
X = torch.rand((1, 1, 224, 224))
for blk in net:
    # Each layer consumes the previous layer's output.
    X = blk(X)
    print(f"{blk.__class__.__name__} {X.shape}")

In [None]:
# Training hyperparameters.
lr = 0.03
num_epochs = 10
batch_size = 128

# Fashion-MNIST data iterators; images are resized to 224x224.
train_iter, test_iter = d2l.load_data_fashion_mnist(
    batch_size=batch_size,
    resize=224,
)

In [None]:
# Train the network on the device returned by d2l.try_gpu().
d2l.train_ch6(
    net,
    train_iter,
    test_iter,
    num_epochs,
    lr,
    d2l.try_gpu(),
)