In [6]:
import torch
from torch import nn
from torch import optim
from torchvision import datasets, transforms
from torch.utils.data import random_split, DataLoader

In [7]:
# Define model: fully connected classifier that takes a flattened 28x28
# input (784 features) and emits 10 scores, with two 64-unit hidden layers.
model = nn.Sequential(*[
    nn.Linear(in_features=28 * 28, out_features=64),
    nn.ReLU(),
    nn.Linear(in_features=64, out_features=64),
    nn.ReLU(),
    nn.Linear(in_features=64, out_features=10),
])

In [8]:
# Define optimizer: Adam over every parameter of `model`, learning rate 0.01.
# NOTE: `params` is the iterator returned by `model.parameters()`.
params = model.parameters()
optimizer = optim.Adam(params=params, lr=1e-2)

In [9]:
# Define loss: cross-entropy between the model's class scores and the integer
# targets, averaged over the batch (the default reduction, spelled out here).
criterion = nn.CrossEntropyLoss(reduction='mean')

In [12]:
# Train-Validation split
# Downloads MNIST into ./data on first use; ToTensor converts each image to a tensor.
train_data = datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor())

# Hold out 5000 samples for validation and train on the rest. The split is
# driven by a dedicated, seeded generator so that Restart & Run All always
# produces the same partition (and validation numbers stay comparable).
train, val = random_split(
    train_data,
    [len(train_data) - 5000, 5000],
    generator=torch.Generator().manual_seed(0),
)

# Reshuffle the training samples every epoch so consecutive epochs do not see
# the batches in the same order. Validation order does not affect the metrics,
# so it is left unshuffled.
train_loader = DataLoader(train, batch_size=32, shuffle=True)
val_loader = DataLoader(val, batch_size=32)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=9912422.0), HTML(value='')))


Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=28881.0), HTML(value='')))


Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=1648877.0), HTML(value='')))


Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=4542.0), HTML(value='')))


Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw

Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [13]:
# Training loop
epochs = 5


def _run_epoch(model, loader, criterion, optimizer=None):
    """Make one full pass over ``loader`` and report loss and accuracy.

    Args:
        model: Network called on each batch after the inputs have been
            flattened to ``(batch, -1)``.
        loader: Iterable yielding ``(x, y)`` batches.
        criterion: Callable ``criterion(preds, y)`` returning a scalar loss
            tensor. The epoch loss assumes it averages over the batch (true
            for the default ``nn.CrossEntropyLoss``).
        optimizer: When given, the model is switched to training mode and
            updated after every batch. When ``None``, the model is switched
            to eval mode and the pass runs with gradients disabled.

    Returns:
        Tuple ``(mean_loss, accuracy_percent)`` of Python floats, both
        computed per sample over the whole pass.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is the same as eval()

    loss_sum = 0.0
    num_correct = 0
    num_samples = 0
    for x, y in loader:
        # x: b * 1 * 28 * 28 -> b * 784
        b = x.size(0)
        x = x.view(b, -1)

        # 1. Forward and 2. compute the objective function. Gradients are
        #    only tracked when an update is going to follow.
        with torch.set_grad_enabled(training):
            preds = model(x)
            loss = criterion(preds, y)

        if training:
            # 3. Clean the gradients left over from the previous batch
            optimizer.zero_grad()
            # 4. Accumulate the partial derivatives of loss with respect to parameters
            loss.backward()
            # 5. Step in the opposite direction of the gradient
            optimizer.step()

        # Weight the batch-mean loss by the batch size so that a smaller final
        # batch does not count as much as a full one in the epoch average.
        loss_sum += loss.item() * b
        # .item() keeps the counter a plain Python int rather than a tensor.
        num_correct += (preds.argmax(dim=1) == y).sum().item()
        num_samples += b

    if num_samples == 0:
        raise ValueError('loader yielded no samples')
    return loss_sum / num_samples, 100.0 * num_correct / num_samples


for epoch in range(1, epochs + 1):
    print('Epoch:', epoch, end=' ')

    train_loss, train_acc = _run_epoch(model, train_loader, criterion, optimizer)
    print(f'Train loss: {train_loss:.2f}', end=' ')
    print(f'Train acc: {train_acc:.2f}%', end=' ')

    val_loss, val_acc = _run_epoch(model, val_loader, criterion)
    print(f'Validation loss: {val_loss:.2f}', end=' ')
    print(f'Validation acc: {val_acc:.2f}%')

Epoch: 1 Train loss: 0.28 Train acc: 91.73% Validation loss: 0.20 Validation acc: 94.38%
Epoch: 2 Train loss: 0.19 Train acc: 94.97% Validation loss: 0.22 Validation acc: 94.56%
Epoch: 3 Train loss: 0.17 Train acc: 95.33% Validation loss: 0.19 Validation acc: 95.26%
Epoch: 4 Train loss: 0.16 Train acc: 95.81% Validation loss: 0.19 Validation acc: 95.26%
Epoch: 5 Train loss: 0.14 Train acc: 96.24% Validation loss: 0.17 Validation acc: 95.86%
