In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, SubsetRandomSampler

from torchvision.datasets import MNIST
from torchvision.transforms import Compose, Normalize, ToTensor

In [3]:
def get_mnist_data_loaders(path, batch_size, valid_batch_size):

    # MNIST specific transforms
    mnist_xforms = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])

    # Training data loader
    train_dataset = MNIST(root=path, train=True, download=True, transform=mnist_xforms)

    tbs = len(train_dataset) if batch_size == 0 else batch_size
    train_loader = DataLoader(train_dataset, batch_size=tbs, shuffle=True)

    # Validation data loader
    valid_dataset = MNIST(root=path, train=False, download=True, transform=mnist_xforms)

    vbs = len(valid_dataset) if valid_batch_size == 0 else valid_batch_size
    valid_loader = DataLoader(valid_dataset, batch_size=vbs, shuffle=True)

    return train_loader, valid_loader

In [64]:
class NeuralNetwork(nn.Module):
    def __init__(self, layer_sizes):
        super(NeuralNetwork, self).__init__()

        first_layer = nn.Flatten()
        middle_layers = [
            nn.Sequential(nn.Linear(nlminus1, nl), nn.ReLU())
            for nl, nlminus1 in zip(layer_sizes[1:-1], layer_sizes)
        ]
        last_layer = nn.Linear(layer_sizes[-2], layer_sizes[-1])

        all_layers = [first_layer] + middle_layers + [last_layer]

        self.layers = nn.Sequential(*all_layers)

    def forward(self, X):
        return self.layers(X)

In [65]:
def train_one_epoch(dataloader, model, loss_fn, optimizer):

    model.train()

    size = len(dataloader.dataset)

    for batch, (X, Y) in enumerate(dataloader):

        X, Y = X.to(device), Y.to(device)

        output = model(X)

        loss = loss_fn(output, Y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

In [66]:
def compute_validation_accuracy(dataloader, model, loss_fn):

    model.eval()

    size = len(dataloader.dataset)

    num_batches = len(dataloader)

    valid_loss, correct = 0, 0

    with torch.no_grad():

        for X, Y in dataloader:
            X, Y = X.to(device), Y.to(device)
            pred = model(X)
            valid_loss += loss_fn(pred, Y).item()
            correct += (pred.argmax(1) == Y).type(torch.float).sum().item()

        valid_loss /= num_batches
        correct /= size

        print(
            f"Validation Metrics:\n\tAccuracy: {(100*correct):>0.1f}%\n\tAvg loss: {valid_loss:>8f}"
        )

In [67]:
# Configuration parameters
data_path = "../data"
seed = 0

torch.manual_seed(seed)

# Hyperparameters
batch_size = 64
valid_batch_size = 0
learning_rate = 1e-3
num_epochs = 5

# Training device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using '{device}' device.")

Using 'cuda' device.


In [68]:
# Get data loaders
train_loader, valid_loader = get_mnist_data_loaders(
    data_path, batch_size, valid_batch_size
)
batch_X, batch_Y = next(iter(train_loader))

In [69]:
# Neural network model
nx = batch_X.shape[1:].numel()
ny = int(torch.unique(batch_Y).shape[0])
layer_sizes = (nx, 512, 50, ny)

model = NeuralNetwork(layer_sizes).to(device)
print(model)

NeuralNetwork(
  (layers): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Sequential(
      (0): Linear(in_features=784, out_features=512, bias=True)
      (1): ReLU()
    )
    (2): Sequential(
      (0): Linear(in_features=512, out_features=50, bias=True)
      (1): ReLU()
    )
    (3): Linear(in_features=50, out_features=10, bias=True)
  )
)


In [70]:
# Training utilities
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

In [71]:
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}\n-------------------------------")
    train_one_epoch(train_loader, model, loss_fn, optimizer)
    compute_validation_accuracy(valid_loader, model, loss_fn)
print("Done!")

Epoch 1
-------------------------------
loss: 2.314172  [    0/60000]
loss: 2.286030  [ 6400/60000]
loss: 2.252837  [12800/60000]
loss: 2.236573  [19200/60000]
loss: 2.161557  [25600/60000]
loss: 2.119941  [32000/60000]
loss: 2.101425  [38400/60000]
loss: 2.105149  [44800/60000]
loss: 2.042984  [51200/60000]
loss: 1.967875  [57600/60000]
Validation Metrics:
	Accuracy: 60.7%
	Avg loss: 1.923344
Epoch 2
-------------------------------
loss: 1.859424  [    0/60000]
loss: 1.821604  [ 6400/60000]
loss: 1.785439  [12800/60000]
loss: 1.656407  [19200/60000]
loss: 1.643545  [25600/60000]
loss: 1.617386  [32000/60000]
loss: 1.445982  [38400/60000]
loss: 1.412978  [44800/60000]
loss: 1.513695  [51200/60000]
loss: 1.152074  [57600/60000]
Validation Metrics:
	Accuracy: 75.7%
	Avg loss: 1.223937
Epoch 3
-------------------------------
loss: 1.302889  [    0/60000]
loss: 1.163493  [ 6400/60000]
loss: 1.218116  [12800/60000]
loss: 1.058661  [19200/60000]
loss: 1.004611  [25600/60000]
loss: 0.890238  

In [74]:
for param in model.parameters():
    print(param.shape)

torch.Size([512, 784])
torch.Size([512])
torch.Size([50, 512])
torch.Size([50])
torch.Size([10, 50])
torch.Size([10])
