In [1]:
import os.path

import torch
from torch import nn, optim

import numpy as np

% matplotlib inline
import matplotlib.pyplot as plt

from cifar10_cnn import *
from utils import *

In [2]:
DISABLE_CUDA = False

if not DISABLE_CUDA and torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [3]:
dataset = CIFAR10(
    os.path.join(*'data/cifar-10-batches-py'.split('/')),
    batch_size=32, val_size=.2
)

In [4]:
def cifar10_model():
    feature_model = nn.Sequential( # 3, 32, 32
        OrderedDict([
            ('conv1', nn.Conv2d(3, 64, kernel_size=3, stride=1, bias=False)), # 64, 30, 30
            ('conv1_bn', nn.BatchNorm2d(64)),
            ('conv1_relu', nn.ReLU()),
            ('conv2', nn.Conv2d(64, 64, kernel_size=3, stride=1, bias=False)), # 64, 28, 28
            ('conv2_bn', nn.BatchNorm2d(64)),
            ('conv2_relu', nn.ReLU()),
            ('maxpool_1', nn.MaxPool2d(2)), # 64, 14, 14
            ('conv3', nn.Conv2d(64, 128, kernel_size=3, stride=1, bias=False)), # 128, 12, 12
            ('conv3_bn', nn.BatchNorm2d(128)),
            ('conv3_relu', nn.ReLU()),
            ('conv4', nn.Conv2d(128, 128, kernel_size=3, stride=1, bias=False)), # 128, 10, 10
            ('conv4_bn', nn.BatchNorm2d(128)),
            ('conv4_relu', nn.ReLU()),
            ('maxpool_2', nn.MaxPool2d(2)), # 128, 5, 5
            ('conv5', nn.Conv2d(128, 256, kernel_size=3, stride=1, bias=False)), # 256, 3, 3
            ('conv5_bn', nn.BatchNorm2d(256)),
            ('conv5_relu', nn.ReLU()),
        ])
    )

    classifier_model = nn.Sequential(
        OrderedDict([
            ('dense1', nn.Linear(256 * 3 * 3, 256, bias=False)),
            ('dense1_bn', nn.BatchNorm1d(256)),
            ('dense1_relu', nn.ReLU()),
            ('dense1_dropout', nn.Dropout()),
            ('dense2', nn.Linear(256, 128, bias=False)),
            ('dense2_bn', nn.BatchNorm1d(128)),
            ('dense2_relu', nn.ReLU()),
            ('dense2_dropout', nn.Dropout()),
            ('output', nn.Linear(128, 10)),
        ])
    )

    model = nn.Sequential(
        OrderedDict([
            ('features', feature_model),
            ('flatten', Flatten()),
            ('classifier', classifier_model)
        ])
    )
    
    return model

In [5]:
N_EPOCHS = 100

model = cifar10_model()
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(1, N_EPOCHS + 1):
    model.train()
    train_loss = 0
    train_accuracy = 0
    
    for X, y, label in dataset.train_loader:
        X = X.to(device); y = y.to(device)
        
        model.zero_grad()
        optimizer.zero_grad()
        
        pred = model(X)
        loss = criterion(pred, y)
        
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        train_accuracy += (torch.argmax(pred, 1) == y).sum()
        
    train_accuracy = train_accuracy.item() / dataset.train_size
        
        
    model.eval()
    val_loss = 0
    val_accuracy = 0
    
    with torch.no_grad():
        for X, y, label in dataset.val_loader:
            X = X.to(device); y = y.to(device)

            pred = model(X)
            loss = criterion(pred, y)

            val_loss += loss.item()
            val_accuracy += (torch.argmax(pred, 1) == y).sum()
        
    val_accuracy = val_accuracy.item() / dataset.val_size
    
    print('Epoch %.2d: train_loss = %.3f, train_accuracy = %.3f, val_loss = %.3f, val_accuracy = %.3f' % (
        epoch, train_loss, train_accuracy, val_loss, val_accuracy
    ))

Epoch 01: train_loss = 1798.179, train_accuracy = 0.486, val_loss = 346.392, val_accuracy = 0.609
Epoch 02: train_loss = 1272.278, train_accuracy = 0.651, val_loss = 257.747, val_accuracy = 0.711
Epoch 03: train_loss = 1060.212, train_accuracy = 0.714, val_loss = 223.918, val_accuracy = 0.748
Epoch 04: train_loss = 906.066, train_accuracy = 0.759, val_loss = 207.283, val_accuracy = 0.774
Epoch 05: train_loss = 789.055, train_accuracy = 0.789, val_loss = 190.934, val_accuracy = 0.787
Epoch 06: train_loss = 692.603, train_accuracy = 0.816, val_loss = 188.157, val_accuracy = 0.801
Epoch 07: train_loss = 605.175, train_accuracy = 0.840, val_loss = 175.501, val_accuracy = 0.811
Epoch 08: train_loss = 521.653, train_accuracy = 0.863, val_loss = 181.060, val_accuracy = 0.810
Epoch 09: train_loss = 446.894, train_accuracy = 0.882, val_loss = 162.419, val_accuracy = 0.830
Epoch 10: train_loss = 392.699, train_accuracy = 0.896, val_loss = 192.046, val_accuracy = 0.808
Epoch 11: train_loss = 345.

Epoch 86: train_loss = 38.346, train_accuracy = 0.990, val_loss = 300.958, val_accuracy = 0.827
Epoch 87: train_loss = 37.926, train_accuracy = 0.990, val_loss = 303.059, val_accuracy = 0.821
Epoch 88: train_loss = 43.105, train_accuracy = 0.989, val_loss = 313.564, val_accuracy = 0.817
Epoch 89: train_loss = 32.538, train_accuracy = 0.991, val_loss = 331.268, val_accuracy = 0.816
Epoch 90: train_loss = 37.427, train_accuracy = 0.991, val_loss = 332.148, val_accuracy = 0.815
Epoch 91: train_loss = 41.260, train_accuracy = 0.989, val_loss = 319.325, val_accuracy = 0.817
Epoch 92: train_loss = 36.647, train_accuracy = 0.991, val_loss = 323.648, val_accuracy = 0.816
Epoch 93: train_loss = 30.290, train_accuracy = 0.992, val_loss = 334.919, val_accuracy = 0.821
Epoch 94: train_loss = 39.571, train_accuracy = 0.990, val_loss = 333.779, val_accuracy = 0.818
Epoch 95: train_loss = 36.334, train_accuracy = 0.991, val_loss = 317.946, val_accuracy = 0.819
Epoch 96: train_loss = 35.873, train_acc

In [6]:
model_filename = os.path.join('models', 'cifar10.pt')
torch.save(model.state_dict(), model_filename)

In [7]:
model.eval()
test_accuracy = 0

with torch.no_grad():
    for X, y, label in dataset.test_loader:
        X = X.to(device); y = y.to(device)

        pred = model(X)
        test_accuracy += (torch.argmax(pred, 1) == y).sum()

test_accuracy = test_accuracy.item() / dataset.test_size

print('Test accuracy: %.3f' % test_accuracy)
print('Test error rate: %.3f' % (1 - test_accuracy))

Test accuracy: 0.823
Test error rate: 0.177
