In [1]:
import pathlib

import torch

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from torchvision import datasets, transforms

In [2]:
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [3]:
# Local cache directory for the MNIST files (download=True below fetches them here).
ROOT = pathlib.Path.home() / ".data" / "mnist"

# Normalisation constants (mean 0.1307, std 0.3081) follow the official example:
# https://github.com/pytorch/examples/blob/master/mnist/main.py
# https://discuss.pytorch.org/t/normalization-in-the-mnist-example/457/2
tsfms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

# Same preprocessing for both splits; only the `train` flag differs.
train_ds = datasets.MNIST(root=str(ROOT), train=True, transform=tsfms, download=True)
test_ds = datasets.MNIST(root=str(ROOT), train=False, transform=tsfms, download=True)

# Hyperparameters

In [11]:
# Seed torch's global RNG so runs are repeatable: Net() weight initialisation and
# the shuffled training DataLoader (created without an explicit generator) both
# draw from it. GPU kernels may still introduce some non-determinism.
seed = 0
torch.manual_seed(seed)

batch_size = 300
num_epochs = 100

In [5]:
# Shuffle only the training split; the test split is read in a fixed order.
train_dl = DataLoader(train_ds, batch_size=batch_size, num_workers=2, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=batch_size)

# Net.forward returns raw (unnormalized) scores with no softmax, which is the
# input CrossEntropyLoss takes.
criterion = nn.CrossEntropyLoss()

In [6]:
class Net(nn.Module):
    """Small CNN classifier for 1x28x28 MNIST digits.

    Two stages of 5x5 conv (stride 1) + ReLU + 2x2 max-pool, followed by two
    fully connected layers. ``forward`` returns unnormalized logits of shape
    ``(batch, 10)`` (no softmax), suitable for ``nn.CrossEntropyLoss``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)   # 1x28x28  -> 20x24x24
        self.conv2 = nn.Conv2d(20, 50, 5, 1)  # 20x12x12 -> 50x8x8
        # After the second 2x2 pool a 28x28 input has become 50x4x4.
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        """Map a batch of images ``(batch, 1, 28, 28)`` to logits ``(batch, 10)``."""
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        # Flatten per sample with the batch dimension pinned. A wrong input size
        # now surfaces as a shape error in fc1 instead of silently regrouping
        # elements into a different number of rows.
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

In [7]:
# Instantiate the CNN and move its parameters to the selected device.
model = Net().to(device)

In [8]:
# Adam over every model parameter; no hyperparameters are passed, so the library defaults apply.
optimizer = torch.optim.Adam(params=model.parameters())

In [9]:
import logging
import sys

root = logging.getLogger()
root.setLevel(logging.DEBUG)

# Attach the stdout handler only once. Without this guard every re-run of the
# cell stacks another handler on the root logger and each message is printed
# once per execution.
handler_name = "notebook-stdout"
if not any(h.name == handler_name for h in root.handlers):
    handler = logging.StreamHandler(sys.stdout)
    handler.name = handler_name
    handler.setLevel(logging.DEBUG)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)
    root.addHandler(handler)

logger = logging.getLogger("training")

In [13]:
# Train for num_epochs passes over the shuffled training set.
model.train()
for epoch in range(num_epochs):
    logger.info(f"Starting epoch {epoch}")
    for x, y in train_dl:
        x = x.to(device)
        y = y.to(device)
        optimizer.zero_grad()
        unnormalized_logits = model(x)
        loss = criterion(unnormalized_logits, y)
        loss.backward()
        optimizer.step()
    logger.info(f"Finished epoch {epoch}")

# Evaluate on the test split. no_grad() stops autograd from recording a graph
# for every batch, which inference does not need.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for x, y in test_dl:
        x = x.to(device)
        y = y.to(device)
        # The class with the highest logit is the prediction.
        predicted = model(x).argmax(dim=1)
        total += y.size(0)
        # .item() keeps `correct` a plain Python int (no device tensor to unwrap).
        correct += (predicted == y).sum().item()

# Guard against an empty test loader instead of dividing by zero.
accuracy = correct / total if total else float("nan")
print(f"{correct} / {total} {accuracy}")

2019-02-09 08:38:01,003 - training - INFO - Starting epoch 0
2019-02-09 08:38:05,385 - training - INFO - Finished epoch 0
2019-02-09 08:38:05,386 - training - INFO - Starting epoch 1
2019-02-09 08:38:09,750 - training - INFO - Finished epoch 1
2019-02-09 08:38:09,751 - training - INFO - Starting epoch 2
2019-02-09 08:38:14,076 - training - INFO - Finished epoch 2
2019-02-09 08:38:14,078 - training - INFO - Starting epoch 3
2019-02-09 08:38:18,406 - training - INFO - Finished epoch 3
2019-02-09 08:38:18,407 - training - INFO - Starting epoch 4
2019-02-09 08:38:22,761 - training - INFO - Finished epoch 4
2019-02-09 08:38:22,762 - training - INFO - Starting epoch 5
2019-02-09 08:38:27,158 - training - INFO - Finished epoch 5
2019-02-09 08:38:27,159 - training - INFO - Starting epoch 6
2019-02-09 08:38:31,494 - training - INFO - Finished epoch 6
2019-02-09 08:38:31,495 - training - INFO - Starting epoch 7
2019-02-09 08:38:36,033 - training - INFO - Finished epoch 7
2019-02-09 08:38:36,034 