In [None]:
from torch import nn
from common import summary, load_data_from_mnist, train


In [None]:
# AlexNet-style stack: five convolutions with three max-pooling steps, then a
# three-layer fully connected head ending in 10 outputs.
net = nn.Sequential(
    # Stage 1: a wide 11x11 window moved with stride 4, producing 96 channels.
    nn.Conv2d(in_channels=1, out_channels=96, kernel_size=11, stride=4, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2),
    # Stage 2: a 5x5 window; padding=2 leaves height and width unchanged
    # while the channel count grows to 256.
    nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, padding=2),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2),
    # Stage 3: three back-to-back 3x3 convolutions (256 -> 384 -> 384 -> 256);
    # padding=1 keeps the spatial size fixed until the final pooling step.
    nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nn.Flatten(),
    # Classifier head. For a 224x224 input the last pooling layer emits
    # 256 channels of 5x5, i.e. 256 * 5 * 5 = 6400 flattened features.
    # Dropout (p=0.5) follows each hidden layer to mitigate overfitting.
    nn.Linear(in_features=6400, out_features=4096),
    nn.ReLU(),
    nn.Dropout(p=0.5),
    nn.Linear(in_features=4096, out_features=4096),
    nn.ReLU(),
    nn.Dropout(p=0.5),
    nn.Linear(in_features=4096, out_features=10),
)

# Report the network for a (1, 1, 224, 224) input -- presumably
# (batch, channels, height, width); confirm against common.summary.
summary(net, (1, 1, 224, 224))


In [None]:
# Load MNIST with resize=224, the same side length the network summary above
# uses, then unpack the result into the training helper's positional arguments.
mnist_data = load_data_from_mnist(resize=224)
# NOTE(review): 1 and 0.05 are presumably the epoch count and learning rate,
# and "ch6" a run/output tag -- confirm against common.train.
train(net, *mnist_data, 1, 0.05, "ch6")
