Before running, install required packages:

In [2]:
! pip install numpy torch torchvision pytorch-ignite



---

In [3]:
import numpy as np
import torch
from torch import optim, nn
from torch.utils.data import DataLoader, TensorDataset
from torchvision import models, datasets, transforms
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss

# Setup

Dataset MNIST will be loaded further down.

In [4]:
# Set up hyperparameters.
lr = 0.001
batch_size = 128
num_epochs = 3

In [5]:
# Set up logging.
print_every = 1  # batches

In [6]:
# Set up device.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Dataset & Preprocessing

In [7]:
def load_data(train):
    # Download and transform dataset.
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.repeat(3, 1, 1)),  # grayscale to RGB
    ])
    dataset = datasets.MNIST("./data", train=train, download=True, transform=transform)

    # Wrap in data loader.
    if use_cuda:
        kwargs = {"pin_memory": True, "num_workers": 1}
    else:
        kwargs = {}
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=train, **kwargs)
    return loader

In [8]:
train_loader = load_data(train=True)
val_loader = None
test_loader = load_data(train=False)

100%|██████████| 9.91M/9.91M [00:00<00:00, 21.7MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 645kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.86MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 8.37MB/s]


# Model

In [9]:
# Set up model, loss, optimizer.
model = models.alexnet(pretrained=False)
model = model.to(device)
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)



# Training

In [10]:
# Set up pytorch-ignite trainer and evaluator.
trainer = create_supervised_trainer(
    model,
    optimizer,
    loss_func,
    device=device,
)
metrics = {
    "accuracy": Accuracy(),
    "loss": Loss(loss_func),
}
evaluator = create_supervised_evaluator(
    model, metrics=metrics, device=device
)

In [11]:
@trainer.on(Events.ITERATION_COMPLETED(every=print_every))
def log_batch(trainer):
    batch = (trainer.state.iteration - 1) % trainer.state.epoch_length + 1
    print(
        f"Epoch {trainer.state.epoch} / {num_epochs}, "
        f"batch {batch} / {trainer.state.epoch_length}: "
        f"loss: {trainer.state.output:.3f}"
    )

In [12]:
@trainer.on(Events.EPOCH_COMPLETED)
def log_epoch(trainer):
    print(f"Epoch {trainer.state.epoch} / {num_epochs} average results: ")

    def log_results(name, metrics, epoch):
        print(
            f"{name + ':':6} loss: {metrics['loss']:.3f}, "
            f"accuracy: {metrics['accuracy']:.3f}"
        )

    # Train data.
    evaluator.run(train_loader)
    log_results("train", evaluator.state.metrics, trainer.state.epoch)

    # Val data.
    if val_loader:
        evaluator.run(val_loader)
        log_results("val", evaluator.state.metrics, trainer.state.epoch)

    # Test data.
    if test_loader:
        evaluator.run(test_loader)
        log_results("test", evaluator.state.metrics, trainer.state.epoch)

    print()
    print("-" * 80)
    print()

In [None]:
# Start training.
trainer.run(train_loader, max_epochs=num_epochs)

Epoch 1 / 3, batch 1 / 469: loss: 6.906
Epoch 1 / 3, batch 2 / 469: loss: 3.103
Epoch 1 / 3, batch 3 / 469: loss: 144.111
Epoch 1 / 3, batch 4 / 469: loss: 10.242
Epoch 1 / 3, batch 5 / 469: loss: 4.851
Epoch 1 / 3, batch 6 / 469: loss: 6.040
Epoch 1 / 3, batch 7 / 469: loss: 5.752
Epoch 1 / 3, batch 8 / 469: loss: 4.952
Epoch 1 / 3, batch 9 / 469: loss: 4.026
Epoch 1 / 3, batch 10 / 469: loss: 3.494
Epoch 1 / 3, batch 11 / 469: loss: 2.796
Epoch 1 / 3, batch 12 / 469: loss: 3.784
Epoch 1 / 3, batch 13 / 469: loss: 3.929
Epoch 1 / 3, batch 14 / 469: loss: 4.265
Epoch 1 / 3, batch 15 / 469: loss: 3.481
Epoch 1 / 3, batch 16 / 469: loss: 3.072
Epoch 1 / 3, batch 17 / 469: loss: 3.290
Epoch 1 / 3, batch 18 / 469: loss: 3.420
Epoch 1 / 3, batch 19 / 469: loss: 2.909
Epoch 1 / 3, batch 20 / 469: loss: 2.879
Epoch 1 / 3, batch 21 / 469: loss: 2.937
Epoch 1 / 3, batch 22 / 469: loss: 2.914
Epoch 1 / 3, batch 23 / 469: loss: 2.823
Epoch 1 / 3, batch 24 / 469: loss: 2.618
Epoch 1 / 3, batch 25 