In [1]:
import os.path

import torch
from torch import nn, optim

import numpy as np

% matplotlib inline
import matplotlib.pyplot as plt

from mnist_cnn import *
from utils import *

In [2]:
# Set to True to force CPU execution even when a GPU is present.
DISABLE_CUDA = False

# Use CUDA only when it is both allowed and actually available; otherwise fall back to CPU.
# The availability probe is skipped entirely when CUDA is disabled (short-circuit `and`).
device = torch.device(
    'cuda' if (not DISABLE_CUDA and torch.cuda.is_available()) else 'cpu'
)

In [4]:
# MNIST is not defined in this notebook; it presumably comes from the wildcard
# imports (mnist_cnn / utils) -- TODO confirm which module provides it.
# The root path is assembled from its components so the separator matches the host OS.
# NOTE(review): val_size=.2 looks like a held-out fraction of the training data -- confirm
# against the MNIST helper.
dataset = MNIST(
    os.path.join(*'data/MNIST'.split('/')),
    val_size=.2,
    batch_size=32,
)

In [5]:
def mnist_model():
    """Assemble the MNIST CNN: three conv/BN/ReLU stages, a flatten step and a dense head.

    The result is an ``nn.Sequential`` with three named children -- ``features``,
    ``flatten`` and ``classifier`` -- whose own children keep the names
    ``conv1``/``conv1_bn``/``conv1_relu`` ... ``dense1``/``dense1_bn``/
    ``dense1_relu``/``dense1_dropout``/``output``.

    The shape notes below assume a 1 x 28 x 28 input.

    Returns:
        nn.Sequential: a freshly initialised model ending in a 10-unit linear layer.
    """
    # (name, in_channels, out_channels, kernel_size, stride, padding)
    conv_stages = (
        ('conv1', 1, 16, 5, 3, 2),   # 1, 28, 28 -> 16, 10, 10
        ('conv2', 16, 32, 3, 1, 0),  #           -> 32, 8, 8
        ('conv3', 32, 64, 3, 1, 0),  #           -> 64, 6, 6
    )

    features = nn.Sequential()
    for stage, n_in, n_out, kernel, stride, padding in conv_stages:
        # Convolutions carry no bias; presumably because the BatchNorm that
        # follows each one already provides a learnable shift.
        features.add_module(
            stage,
            nn.Conv2d(n_in, n_out, kernel_size=kernel, stride=stride,
                      padding=padding, bias=False),
        )
        features.add_module(stage + '_bn', nn.BatchNorm2d(n_out))
        features.add_module(stage + '_relu', nn.ReLU())

    classifier = nn.Sequential()
    # 64 * 6 * 6 matches the flattened output of the last conv stage.
    classifier.add_module('dense1', nn.Linear(64 * 6 * 6, 128, bias=False))
    classifier.add_module('dense1_bn', nn.BatchNorm1d(128))
    classifier.add_module('dense1_relu', nn.ReLU())
    classifier.add_module('dense1_dropout', nn.Dropout())
    classifier.add_module('output', nn.Linear(128, 10))

    net = nn.Sequential()
    net.add_module('features', features)
    # Flatten is not defined in this cell; presumably supplied by the wildcard
    # imports and expected to collapse each sample to a vector -- TODO confirm.
    net.add_module('flatten', Flatten())
    net.add_module('classifier', classifier)

    return net

In [6]:
N_EPOCHS = 25

# Fresh model moved to the selected device, trained with Adam on cross-entropy.
model = mnist_model()
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(1, N_EPOCHS + 1):
    # ---- training pass ----
    model.train()
    train_loss = 0     # running SUM of the per-batch criterion values (not an average)
    train_correct = 0  # plain Python int, so it stays valid even if the loader yields nothing

    # Loaders yield triples; the third element (`label`) is not used here.
    for X, y, label in dataset.train_loader:
        X = X.to(device); y = y.to(device)

        # The optimizer was built from model.parameters(), so clearing gradients
        # through it covers every parameter; a separate model.zero_grad() is redundant.
        optimizer.zero_grad()

        pred = model(X)
        loss = criterion(pred, y)

        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        # Count correct predictions as an int; loss.item() above already synchronises,
        # so the extra .item() adds no meaningful cost.
        train_correct += (torch.argmax(pred, 1) == y).sum().item()

    # NOTE(review): dataset.train_size is presumably the number of training samples -- confirm.
    train_accuracy = train_correct / dataset.train_size

    # ---- validation pass (no gradient tracking, eval-mode BatchNorm/Dropout) ----
    model.eval()
    val_loss = 0
    val_correct = 0

    with torch.no_grad():
        for X, y, label in dataset.val_loader:
            X = X.to(device); y = y.to(device)

            pred = model(X)
            loss = criterion(pred, y)

            val_loss += loss.item()
            val_correct += (torch.argmax(pred, 1) == y).sum().item()

    # NOTE(review): dataset.val_size is presumably the number of validation samples -- confirm.
    val_accuracy = val_correct / dataset.val_size

    # Losses are sums over batches, so train and val magnitudes are not directly
    # comparable when the two loaders have different numbers of batches.
    print('Epoch %.2d: train_loss = %.3f, train_accuracy = %.3f, val_loss = %.3f, val_accuracy = %.3f' % (
        epoch, train_loss, train_accuracy, val_loss, val_accuracy
    ))

Epoch 01: train_loss = 322.671, train_accuracy = 0.950, val_loss = 17.679, val_accuracy = 0.987
Epoch 02: train_loss = 104.895, train_accuracy = 0.980, val_loss = 14.042, val_accuracy = 0.989
Epoch 03: train_loss = 78.465, train_accuracy = 0.984, val_loss = 13.018, val_accuracy = 0.989
Epoch 04: train_loss = 62.800, train_accuracy = 0.988, val_loss = 12.603, val_accuracy = 0.991
Epoch 05: train_loss = 53.344, train_accuracy = 0.989, val_loss = 10.926, val_accuracy = 0.991
Epoch 06: train_loss = 48.640, train_accuracy = 0.990, val_loss = 12.144, val_accuracy = 0.991
Epoch 07: train_loss = 40.582, train_accuracy = 0.991, val_loss = 11.432, val_accuracy = 0.991
Epoch 08: train_loss = 36.592, train_accuracy = 0.992, val_loss = 10.457, val_accuracy = 0.992
Epoch 09: train_loss = 31.947, train_accuracy = 0.993, val_loss = 11.282, val_accuracy = 0.992
Epoch 10: train_loss = 31.277, train_accuracy = 0.994, val_loss = 10.243, val_accuracy = 0.992
Epoch 11: train_loss = 27.644, train_accuracy = 

In [7]:
model_filename = os.path.join('models', 'mnist.pt')
# torch.save does not create missing parent directories; make sure the target
# folder exists so a long training run is not lost to a missing-directory error.
os.makedirs(os.path.dirname(model_filename), exist_ok=True)
# Only the state dict (parameters and buffers) is saved, not the module object itself.
torch.save(model.state_dict(), model_filename)

In [8]:
# Evaluate the trained model on the test split: eval mode, no gradient tracking.
model.eval()
test_correct = 0  # plain Python int, so it stays valid even if the loader yields nothing

with torch.no_grad():
    # Loaders yield triples; the third element (`label`) is not used here.
    for X, y, label in dataset.test_loader:
        X = X.to(device); y = y.to(device)

        pred = model(X)
        # Predicted class = index of the largest output along dim 1.
        test_correct += (torch.argmax(pred, 1) == y).sum().item()

# NOTE(review): dataset.test_size is presumably the number of test samples -- confirm.
test_accuracy = test_correct / dataset.test_size

print('Test accuracy: %.3f' % test_accuracy)
print('Test error rate: %.3f' % (1 - test_accuracy))

Test accuracy: 0.993
Test error rate: 0.007
