In [1]:
from collections import deque
from statistics import mean
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import torch

In [2]:
# Run on the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [3]:
# Per-image preprocessing: convert to a tensor, then normalise the single
# channel with mean 0.5 and standard deviation 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# MNIST train/test splits rooted at the current directory (downloaded if absent).
train_set = MNIST(root='.', train=True, transform=transform, download=True)
test_set = MNIST(root='.', train=False, transform=transform, download=True)

# Training batches: 64 shuffled samples each, loaded by 4 worker processes.
train_loader = DataLoader(
    dataset=train_set,
    batch_size=64,
    shuffle=True,
    num_workers=4,
)
# Evaluation loader: one unshuffled batch that holds the entire test split.
test_loader = DataLoader(
    dataset=test_set,
    batch_size=len(test_set),
    num_workers=4,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to .\MNIST\raw\train-images-idx3-ubyte.gz


9920512it [00:08, 1147085.15it/s]                             


Extracting .\MNIST\raw\train-images-idx3-ubyte.gz to .\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to .\MNIST\raw\train-labels-idx1-ubyte.gz


32768it [00:00, 58034.02it/s]                           


Extracting .\MNIST\raw\train-labels-idx1-ubyte.gz to .\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to .\MNIST\raw\t10k-images-idx3-ubyte.gz


1654784it [00:02, 718584.01it/s]                             


Extracting .\MNIST\raw\t10k-images-idx3-ubyte.gz to .\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to .\MNIST\raw\t10k-labels-idx1-ubyte.gz


8192it [00:00, 22122.69it/s]            


Extracting .\MNIST\raw\t10k-labels-idx1-ubyte.gz to .\MNIST\raw
Processing...
Done!


In [4]:
# Multi-layer perceptron over 784-feature inputs: 784 -> 256 -> 128 -> 10.
# Each hidden layer is followed by ReLU and dropout with p=0.5; the final
# layer returns raw class scores (no softmax is applied here).
model = nn.Sequential(*[
    nn.Linear(in_features=784, out_features=256),
    nn.ReLU(),
    nn.Dropout(p=0.5),
    nn.Linear(in_features=256, out_features=128),
    nn.ReLU(),
    nn.Dropout(p=0.5),
    nn.Linear(in_features=128, out_features=10),
])

In [5]:
# Loss on raw class scores versus integer class labels.
criterion = nn.CrossEntropyLoss()

# Two optimisers are built over the same parameter set; the training loop
# chooses between them by epoch.
adam = optim.Adam(model.parameters(), lr=0.001)
sgd = optim.SGD(model.parameters(), momentum=0.9, lr=0.0001)

In [6]:
# Number of passes over the training set, and how many optimisation steps
# separate two consecutive evaluation/log lines.
num_epochs, print_every = 30, 400

# `c` counts optimisation steps across all epochs. `best_loss` starts at a
# large sentinel; it is not read in the visible cells -- presumably reserved
# for best-model tracking (TODO confirm).
c, best_loss = 0, 1e9

# Rolling windows holding the most recent `print_every` per-batch training
# losses and accuracies; older entries are discarded automatically.
train_losses = deque(maxlen=print_every)
train_accs = deque(maxlen=print_every)

In [7]:
# Metric log: one independent, initially empty list per tracked series. The
# training loop appends one value to each list at every evaluation step.
history = {
    metric: []
    for metric in ('test_accs', 'test_losses', 'train_accs', 'train_losses')
}

In [8]:
# Move the model's parameters and buffers onto `device`. Kept as a bare
# expression so the notebook echoes the module structure as the cell output.
model.to(device)

Sequential(
  (0): Linear(in_features=784, out_features=256, bias=True)
  (1): ReLU()
  (2): Dropout(p=0.5, inplace=False)
  (3): Linear(in_features=256, out_features=128, bias=True)
  (4): ReLU()
  (5): Dropout(p=0.5, inplace=False)
  (6): Linear(in_features=128, out_features=10, bias=True)
)

In [9]:
# Train for `num_epochs` epochs. Every `print_every` optimisation steps the
# model is evaluated on the full test split and one line of metrics is printed
# and appended to `history`. Training metrics are averages over the rolling
# `train_losses` / `train_accs` windows (measured in train mode, i.e. with
# dropout active).

# `test_loader` yields the whole test split as a single unshuffled batch with a
# deterministic transform, so that batch is identical at every evaluation.
# Fetch it once and keep it on the device rather than re-creating the loader
# iterator (and its worker processes) at each logging step.
test_images, test_labels = next(iter(test_loader))
test_images = test_images.view(-1, 28 * 28).to(device)
test_labels = test_labels.to(device)

for e in range(num_epochs):
    # Adam drives the first 15 epochs; SGD with momentum takes over afterwards.
    optimizer = adam if e < 15 else sgd

    for i, (images, labels) in enumerate(train_loader):
        # Flatten each 28x28 image into a 784-element vector.
        images = images.view(-1, 28 * 28).to(device)
        labels = labels.to(device)

        output = model(images)
        loss = criterion(output, labels)
        acc = torch.mean((torch.argmax(output, dim=1) == labels).float())
        train_losses.append(loss.item())
        train_accs.append(acc.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (c % print_every) == 0:
            # Evaluate with dropout disabled and without building a graph.
            # Dedicated names keep the training batch variables intact.
            model.eval()
            with torch.no_grad():
                test_output = model(test_images)
                test_loss = criterion(test_output, test_labels)
                test_acc = torch.mean(
                    (torch.argmax(test_output, dim=1) == test_labels).float())
            model.train()

            # Average the rolling windows once and reuse for print + history.
            trn_loss = mean(train_losses)
            trn_acc = mean(train_accs)

            print(
                '[%03d]' % (e+1),
                '(%03d/%03d)' % (i, len(train_loader)),
                'Trn loss: %.5f' % trn_loss,
                'Tst loss: %.5f' % test_loss.item(),
                'Trn acc: %.5f' % trn_acc,
                'Tst acc: %.5f' % test_acc.item(),
                )

            history['test_accs'].append(test_acc.item())
            history['test_losses'].append(test_loss.item())
            history['train_accs'].append(trn_acc)
            history['train_losses'].append(trn_loss)

        # Global step counter; persists across epochs.
        c += 1

[001] (000/938) Trn loss: 2.29181 Tst loss: 2.28203 Trn acc: 0.14062 Tst acc: 0.11380
[001] (400/938) Trn loss: 0.79403 Tst loss: 0.29588 Trn acc: 0.74223 Tst acc: 0.91660
[001] (800/938) Trn loss: 0.46224 Tst loss: 0.24396 Trn acc: 0.86129 Tst acc: 0.92450
[002] (262/938) Trn loss: 0.40439 Tst loss: 0.21838 Trn acc: 0.88070 Tst acc: 0.93290
[002] (662/938) Trn loss: 0.37395 Tst loss: 0.20320 Trn acc: 0.88969 Tst acc: 0.93790
[003] (124/938) Trn loss: 0.34651 Tst loss: 0.19861 Trn acc: 0.89676 Tst acc: 0.93850
[003] (524/938) Trn loss: 0.32980 Tst loss: 0.18528 Trn acc: 0.90363 Tst acc: 0.94400
[003] (924/938) Trn loss: 0.32681 Tst loss: 0.16970 Trn acc: 0.90363 Tst acc: 0.94840
[004] (386/938) Trn loss: 0.29573 Tst loss: 0.16426 Trn acc: 0.91332 Tst acc: 0.94970
[004] (786/938) Trn loss: 0.29387 Tst loss: 0.16176 Trn acc: 0.91266 Tst acc: 0.94890
[005] (248/938) Trn loss: 0.28285 Tst loss: 0.14430 Trn acc: 0.91703 Tst acc: 0.95530
[005] (648/938) Trn loss: 0.28035 Tst loss: 0.15156 Tr