In [2]:
import torch
from torch import nn
from torch import optim
from torchvision import datasets, transforms
from torch.utils.data import random_split, DataLoader

In [3]:
# Define model: a fully connected network that takes flattened 28 * 28 inputs
# and produces 10 outputs, with two 64-unit hidden layers and ReLU activations.
in_features = 28 * 28
hidden_units = 64
num_classes = 10

# Layers are created in forward order and passed positionally, so the
# Sequential indexes them 0..4.
layers = [
    nn.Linear(in_features, hidden_units),
    nn.ReLU(),
    nn.Linear(hidden_units, hidden_units),
    nn.ReLU(),
    nn.Linear(hidden_units, num_classes),
]
model = nn.Sequential(*layers)

In [4]:
# Define optimizer
# model.parameters() returns a one-shot generator that the optimizer consumes
# while building its parameter groups. Materialise it as a list so `params`
# still holds the parameters if it is used again later in the notebook.
params = list(model.parameters())
optimizer = optim.Adam(params, lr=0.01)

In [5]:
# Define loss: cross-entropy, averaged over the samples in a batch
# (the reduction is spelled out here; 'mean' is also the library default).
criterion = nn.CrossEntropyLoss(reduction='mean')

In [6]:
# Train-Validation split
VAL_SIZE = 5000    # number of samples held out for validation
BATCH_SIZE = 32
SPLIT_SEED = 42    # fixed so re-running this cell reproduces the same split

train_data = datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor())

# Derive the training size from the dataset rather than hard-coding it, so the
# two lengths always sum to len(train_data) as random_split requires.
# The seeded generator keeps the validation set identical across re-runs;
# otherwise re-executing this cell would move validation samples into the
# training set of an already-trained model.
train, val = random_split(
    train_data,
    [len(train_data) - VAL_SIZE, VAL_SIZE],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Reshuffle the training samples every epoch so batches differ between epochs;
# the validation order does not affect the metrics, so it is left fixed.
train_loader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val, batch_size=BATCH_SIZE)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=9912422.0), HTML(value='')))


Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=28881.0), HTML(value='')))


Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=1648877.0), HTML(value='')))


Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=4542.0), HTML(value='')))


Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw

Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [9]:
# Show the dataset's summary repr (a cell's last bare expression is displayed).
train_data

Dataset MNIST
    Number of datapoints: 60000
    Root location: data
    Split: Train
    StandardTransform
Transform: ToTensor()

In [10]:
# Training loop
epochs = 5


def run_epoch(model, loader, criterion, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy_percent)``.

    When ``optimizer`` is given, the model is put in train mode and its
    parameters are updated after every batch. When it is ``None``, the model
    is put in eval mode and the whole pass runs with gradient tracking
    disabled.

    The loss is averaged per sample: each batch's loss is weighted by the
    batch size, so a smaller final batch does not skew the epoch average.
    This assumes ``criterion`` returns the mean loss over the batch.

    Returns ``(nan, nan)`` if ``loader`` yields no samples.
    """
    training = optimizer is not None
    model.train(training)

    loss_sum = 0.0
    num_correct = 0
    num_samples = 0

    with torch.set_grad_enabled(training):
        for x, y in loader:
            # x: b * 1 * 28 * 28, flattened to one vector per sample
            b = x.size(0)
            x = x.view(b, -1)

            # 1. Forward
            preds = model(x)

            # 2. Compute the objective function
            loss = criterion(preds, y)

            if training:
                # 3. Clear the gradients left over from the previous step
                optimizer.zero_grad()

                # 4. Accumulate the partial derivatives of loss with respect to parameters
                loss.backward()

                # 5. Step in the opposite direction of the gradient
                optimizer.step()

            # Keep plain Python numbers: .item() avoids accumulating tensors
            # and makes the accuracy a true float division on every torch version.
            loss_sum += loss.item() * b
            _, predictions = preds.max(1)
            num_correct += (predictions == y).sum().item()
            num_samples += b

    if num_samples == 0:
        return float('nan'), float('nan')
    return loss_sum / num_samples, 100.0 * num_correct / num_samples


for epoch in range(1, epochs + 1):
    print('Epoch:', epoch, end=' ')

    train_loss, train_acc = run_epoch(model, train_loader, criterion, optimizer)
    print(f'Train loss: {train_loss:.2f}', end=' ')
    print(f'Train acc: {train_acc:.2f}%', end=' ')

    # No optimizer: evaluation only, no parameter updates
    val_loss, val_acc = run_epoch(model, val_loader, criterion)
    print(f'Validation loss: {val_loss:.2f}', end=' ')
    print(f'Validation acc: {val_acc:.2f}%')

Epoch: 1 tensor([[-1.0908e+01, -5.2467e+00,  3.5425e-01, -4.6639e+00,  1.6124e+01,
         -5.2294e+00, -5.3333e+00, -9.7227e-01, -2.6683e+00, -1.0710e+00],
        [-2.8809e+00,  1.6954e+00,  1.1926e+01, -2.2654e+00, -3.9202e+00,
         -1.4521e+01, -3.6704e+00,  9.7389e-01, -5.7954e+00, -1.4228e+01],
        [-2.3685e+00,  4.1080e-01, -3.4701e+00, -1.7183e+00, -5.8935e+00,
         -5.0799e+00, -3.5619e+00, -2.8604e+00,  7.3207e+00, -3.6513e-01],
        [ 2.3654e+00, -1.5694e+01, -6.1429e+00,  6.7116e-02, -2.3574e+01,
         -3.9887e-01, -4.9516e+00, -1.6807e+01,  1.4823e+01, -9.1699e+00],
        [-6.7101e+00,  6.6780e+00,  1.8104e+01, -4.0345e+00, -4.2122e+00,
         -2.7313e+01, -7.8385e+00,  4.2480e+00, -1.1022e+01, -2.3492e+01],
        [ 1.7862e+01, -4.6204e+01, -1.1500e+01, -4.9595e+00, -4.3792e+00,
         -6.6435e+00, -6.1859e-01, -4.1079e+00, -1.6227e+01,  1.6638e+00],
        [-1.6309e+00,  2.6866e+00,  4.4385e-01, -2.4863e+00, -7.9679e-01,
         -1.7295e+00,  