# MNIST Classification in JAX

This notebook implements a solution to the MNIST classification problem, using JAX to model and train a LeNet-300-100 feed-forward neural network for handwritten digit recognition.

The code is based on the [JAX](), [Haiku](), and [Optax]() example MNIST classification programs.

In [2]:
import time  # For measuring time

import numpy.random as npr  # Random number generation

import jax  # Main JAX module
import jax.numpy as jnp  # JAX's version of NumPy
import haiku as hk  # Neural network library built on JAX
import optax  # Optimisers for gradient descent

import datasets  # Local copy of datasets.py from the JAX examples (not the Hugging Face `datasets` package)


## Hyperparameters

A standard set of hyperparameters is defined explicitly below, each with a brief comment explaining its purpose.
A pseudo-random number generator is also initialised with a fixed seed so that results are reproducible.

In [3]:
learning_rate = 1e-2  # Learning rate for the Adam optimiser.
batch_size = 256  # Mini-batch size (reduce if it fails to fit on GPU).
input_size = 28 * 28  # Size of the input vector to the model.
num_epochs = 10  # Number of training epochs.
validation_split = 0.2  # Fraction of the training data to use for validation.

# Random number generator sequence.
key_seq = hk.PRNGSequence(1729)


No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
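JAX has no global random state: randomness is driven by explicit keys, which are split to produce fresh ones. The sketch below illustrates the idea behind a fixed-seed key sequence using plain `jax.random`. `key_stream` is a hypothetical helper written only to mimic the behaviour of `hk.PRNGSequence`; it is not Haiku's actual implementation.

```python
import jax


def key_stream(seed):
    """Hypothetical helper: yields a fresh subkey on each call, starting from a fixed seed."""
    key = jax.random.PRNGKey(seed)
    while True:
        key, subkey = jax.random.split(key)  # keep `key` for future splits, hand out `subkey`
        yield subkey


stream_a = key_stream(1729)
stream_b = key_stream(1729)

x1 = jax.random.normal(next(stream_a), (3,))
x2 = jax.random.normal(next(stream_b), (3,))  # same seed, same position -> same draw
y1 = jax.random.normal(next(stream_a), (3,))  # next key in the same stream -> new draw

print(bool((x1 == x2).all()))  # True
print(bool((x1 == y1).all()))  # False
```

Two streams created from the same seed produce identical draws, while successive keys from one stream produce different ones, which is what makes the parameter initialisation in this notebook repeatable.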


## Dataset

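The MNIST training and test data are loaded using the `datasets` module copied from the JAX example programs, which returns flattened images scaled to [0, 1] together with one-hot encoded labels.
The last `validation_split` fraction of the training data is then held out as a validation set, and a generator is defined that yields shuffled mini-batches of the remaining training data indefinitely.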

In [4]:
# Download and parse the raw MNIST dataset.
train_images, train_labels, test_images, test_labels = datasets.mnist()

# Calculate the number of samples in final train and validation sets
train_samples = train_images.shape[0]
val_samples = int(train_samples * validation_split)
train_samples -= val_samples

# Split training data into validation and final training data
train_images, val_images = train_images[:train_samples], train_images[train_samples:]
train_labels, val_labels = train_labels[:train_samples], train_labels[train_samples:]

# Determine the number of batches in the training set
num_complete_batches, leftover = divmod(train_samples, batch_size)
num_batches = num_complete_batches + bool(leftover)


# Set up a data stream for the training set
def data_stream():
    rng = npr.RandomState(0)
    while True:
        perm = rng.permutation(train_samples)
        for i in range(num_batches):
            batch_idx = perm[i * batch_size : (i + 1) * batch_size]
            yield train_images[batch_idx], train_labels[batch_idx]


batches = data_stream()


downloaded https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz to ./data
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz to ./data
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz to ./data
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz to ./data
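As a sanity check on the split and batching logic, the arithmetic can be reproduced with NumPy alone, assuming the standard MNIST training set size of 60,000 images. The sketch also confirms that one pass of the permutation-based batching used in `data_stream` visits every training index exactly once, with a smaller final batch.

```python
import numpy as np

num_train_total = 60_000  # size of the standard MNIST training set
validation_split = 0.2
batch_size = 256

# Same arithmetic as the dataset cell
val_samples = int(num_train_total * validation_split)
train_samples = num_train_total - val_samples
num_complete_batches, leftover = divmod(train_samples, batch_size)
num_batches = num_complete_batches + bool(leftover)

# One epoch of the permutation-based batching used by data_stream
rng = np.random.RandomState(0)
perm = rng.permutation(train_samples)
batches = [perm[i * batch_size : (i + 1) * batch_size] for i in range(num_batches)]

print(val_samples, train_samples)  # 12000 48000
print(num_complete_batches, leftover, num_batches)  # 187 128 188
print(len(batches[-1]))  # 128
```

So each epoch consists of 187 full batches of 256 images plus one final batch of 128.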


## Model Architecture

In this notebook, the LeNet-300-100 architecture is used to train a neural network to classify handwritten digits.
This is a feed-forward neural network with two hidden layers of 300 and 100 neurons respectively.
ReLU activations are applied after each hidden layer, and a log-softmax is applied to the 10 output units, giving a log-probability for each digit class.

In [5]:
def forward_pass(x):
    model = hk.Sequential(
        [
            hk.Linear(300),  # 300 hidden units
            jax.nn.relu,
            hk.Linear(100),  # 100 hidden units
            jax.nn.relu,
            hk.Linear(10),  # 10 output units
            jax.nn.log_softmax,
        ]
    )
    return model(x)


# Transform into pure init/apply functions; apply needs no RNG key since the model is deterministic (no dropout)
network = hk.without_apply_rng(hk.transform(forward_pass))
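To make the architecture concrete, the sketch below writes out the same three dense layers with plain `jax.numpy` instead of Haiku. `init_params` and `forward` are hypothetical helpers, and the scaled normal initialisation is an assumption that does not exactly match Haiku's default `hk.Linear` initialiser. Counting weights and biases gives the size of the network: (784 × 300 + 300) + (300 × 100 + 100) + (100 × 10 + 10) = 266,610 parameters.

```python
import jax
import jax.numpy as jnp

LAYER_SIZES = [28 * 28, 300, 100, 10]  # input, hidden 1, hidden 2, output


def init_params(key, sizes=LAYER_SIZES):
    """Hypothetical helper: one (weights, bias) pair per dense layer."""
    params = []
    keys = jax.random.split(key, len(sizes) - 1)
    for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:]):
        # Scaled normal weights (assumption; Haiku's default initialiser differs in detail)
        w = jax.random.normal(k, (n_in, n_out)) / jnp.sqrt(n_in)
        b = jnp.zeros(n_out)
        params.append((w, b))
    return params


def forward(params, x):
    """Dense -> ReLU -> Dense -> ReLU -> Dense -> log-softmax."""
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return jax.nn.log_softmax(x @ w + b)


params = init_params(jax.random.PRNGKey(0))
num_params = sum(w.size + b.size for w, b in params)
log_probs = forward(params, jnp.zeros((2, 28 * 28)))

print(num_params)  # 266610
print(log_probs.shape)  # (2, 10)
```

Each row of `log_probs` exponentiates to a probability distribution over the 10 digits, which is what the loss function below relies on.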


## Loss Functions

As is standard for classification tasks, the network's log-softmax output gives (log-)probabilities for each target class, and we use the cross-entropy loss between these predictions and the one-hot labels to measure the network's performance.
An accuracy function, the fraction of samples whose highest-scoring class matches the label, is also defined here for use in the training loop.

In [6]:
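def cross_entropy_loss(params, x, y):
    log_probs = network.apply(params, x)  # log-softmax outputs, shape (batch, 10)
    # y is one-hot, so the sum picks out the log-probability of the true class
    return -jnp.mean(jnp.sum(log_probs * y, axis=-1))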


def accuracy(params, x, y):
    predictions = network.apply(params, x)
    return jnp.mean(jnp.argmax(predictions, axis=-1) == jnp.argmax(y, axis=-1))
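The formulas in `cross_entropy_loss` and `accuracy` can be checked by hand on a tiny batch. The sketch below uses two hypothetical examples with three classes: the first is classified correctly with loss log(1 + 2e⁻²), the second incorrectly with loss log(2 + e). It also shows that multiplying by one-hot labels and summing is equivalent to indexing the log-probability of the true class directly.

```python
import jax
import jax.numpy as jnp

# Hypothetical batch: 2 examples, 3 classes
logits = jnp.array([[2.0, 0.0, 0.0],
                    [0.0, 1.0, 0.0]])
labels = jnp.array([0, 2])
one_hot = jax.nn.one_hot(labels, 3)
log_probs = jax.nn.log_softmax(logits)

# Same formula as cross_entropy_loss: batch mean of -sum_k y_k * log p_k
loss_one_hot = -jnp.mean(jnp.sum(log_probs * one_hot, axis=-1))

# Equivalent: select the true-class log-probability directly
loss_indexed = -jnp.mean(jnp.take_along_axis(log_probs, labels[:, None], axis=-1))

# Same formula as accuracy: example 1 predicts class 0 (correct), example 2 predicts class 1 (wrong)
acc = jnp.mean(jnp.argmax(log_probs, axis=-1) == jnp.argmax(one_hot, axis=-1))

print(float(acc))  # 0.5
```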



## Optimiser

To optimise the parameters of the network, we use the Adam optimiser from the `optax` library.

In [7]:
# Create the Adam optimiser; optax returns its init and update functions as a pair
opt_init, opt_update = optax.adam(learning_rate)

# Initialise the model's parameters and the optimiser's state
params = network.init(next(key_seq), jnp.zeros([1, input_size]))
opt_state = opt_init(params)
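For reference, a single Adam update can be written out by hand. The sketch below follows the standard Adam rule with bias-corrected first and second moment estimates, using what we understand to be `optax.adam`'s defaults (`b1=0.9`, `b2=0.999`, `eps=1e-8`). `adam_step` is a hypothetical helper for illustration, not optax's implementation. On the very first step the bias-corrected moments reduce to g and g², so every parameter moves by almost exactly the learning rate against the sign of its gradient, whatever the gradient's magnitude.

```python
import jax
import jax.numpy as jnp


def adam_step(params, grads, m, v, t, lr=1e-2, b1=0.9, b2=0.999, eps=1e-8):
    """Hypothetical helper: one Adam update on a pytree; t is the 1-based step count."""
    tree_map = jax.tree_util.tree_map
    m = tree_map(lambda m_, g: b1 * m_ + (1 - b1) * g, m, grads)       # first moment
    v = tree_map(lambda v_, g: b2 * v_ + (1 - b2) * g**2, v, grads)    # second moment
    m_hat = tree_map(lambda m_: m_ / (1 - b1**t), m)                   # bias correction
    v_hat = tree_map(lambda v_: v_ / (1 - b2**t), v)
    params = tree_map(lambda p, mh, vh: p - lr * mh / (jnp.sqrt(vh) + eps),
                      params, m_hat, v_hat)
    return params, m, v


params = {"w": jnp.array([1.0, -2.0, 3.0])}
grads = {"w": jnp.array([0.5, -40.0, 1e-3])}  # gradients of very different sizes
zeros = jax.tree_util.tree_map(jnp.zeros_like, params)

new_params, m, v = adam_step(params, grads, zeros, zeros, t=1)
step = new_params["w"] - params["w"]
print(step)  # each entry is close to -0.01 * sign(gradient)
```

This per-parameter normalisation is one reason a learning rate of `1e-2` behaves very differently with Adam than it would with plain SGD.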


## Optimisation

Finally, the network parameters are optimised on the training data, with training and validation accuracy reported after each epoch. The trained network is then evaluated on the held-out test set to estimate how well it generalises to data it has not seen before.

In [8]:

print("Starting training...")
for epoch in range(num_epochs):
    start_time = time.time()  # Record epoch start time
    
    # Iterate over the training batches and optimise the parameters
    for _ in range(num_batches):
        x, y = next(batches)

        # Calculate the loss and gradients
        loss, grad = jax.value_and_grad(cross_entropy_loss)(params, x, y)
        
        # Update the parameters and optimiser state
        updates, opt_state = opt_update(grad, opt_state, params)
        params = optax.apply_updates(params, updates)
        
    epoch_time = time.time() - start_time  # Elapsed training time for this epoch

    # Calculate training and validation accuracy
    train_acc = accuracy(params, train_images, train_labels) * 100
    val_acc = accuracy(params, val_images, val_labels) * 100

    print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
    print(f"Training set accuracy {train_acc:0.2f}%")
    print(f"Validation set accuracy {val_acc:0.2f}%")

print("\nStarting testing...")
test_acc = accuracy(params, test_images, test_labels) * 100
print(f"Test set accuracy {test_acc:0.2f}%")

Starting training...
Epoch 0 in 3.69 sec
Training set accuracy 96.86%
Validation set accuracy 96.14%
Epoch 1 in 2.60 sec
Training set accuracy 97.68%
Validation set accuracy 96.66%
Epoch 2 in 2.53 sec
Training set accuracy 97.51%
Validation set accuracy 96.22%
Epoch 3 in 2.51 sec
Training set accuracy 98.56%
Validation set accuracy 97.13%
Epoch 4 in 2.49 sec
Training set accuracy 98.73%
Validation set accuracy 97.21%
Epoch 5 in 2.53 sec
Training set accuracy 98.99%
Validation set accuracy 97.14%
Epoch 6 in 2.51 sec
Training set accuracy 98.45%
Validation set accuracy 96.48%
Epoch 7 in 2.52 sec
Training set accuracy 98.81%
Validation set accuracy 97.02%
Epoch 8 in 2.54 sec
Training set accuracy 98.91%
Validation set accuracy 97.18%
Epoch 9 in 2.58 sec
Training set accuracy 98.99%
Validation set accuracy 97.14%

Starting testing...
Test set accuracy 97.43%
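The training loop above calls `jax.value_and_grad` eagerly on every batch. A common refinement is to wrap the whole update in `jax.jit`, so that the forward pass, gradient computation and parameter update are compiled into a single function. The self-contained sketch below shows the pattern on a hypothetical softmax-regression problem with synthetic data, and swaps in plain gradient descent for Adam to avoid extra dependencies. In this notebook, the body of `update` would instead call `opt_update` and `optax.apply_updates` and return the new optimiser state as well.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.5  # assumed step size for this toy problem only


def loss_fn(params, x, y):
    """Softmax cross-entropy for a linear model (stand-in for the MLP)."""
    log_probs = jax.nn.log_softmax(x @ params["w"] + params["b"])
    return -jnp.mean(jnp.sum(log_probs * y, axis=-1))


@jax.jit
def update(params, x, y):
    """Compiled step: forward pass, gradients and parameter update in one call."""
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Plain gradient descent here; the notebook would use opt_update / optax.apply_updates
    params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return params, loss


# Hypothetical synthetic data: 3 classes defined by a random linear rule
key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (256, 20))
y = jax.nn.one_hot(jnp.argmax(x @ jax.random.normal(key_w, (20, 3)), axis=-1), 3)

params = {"w": jnp.zeros((20, 3)), "b": jnp.zeros(3)}
losses = []
for _ in range(50):
    params, loss = update(params, x, y)  # first call compiles, later calls reuse it
    losses.append(float(loss))

print(f"loss: {losses[0]:.4f} -> {losses[-1]:.4f}")
```

Note that `jax.jit` compiles separately for each distinct input shape, so with the settings used here the smaller final batch of each epoch (128 images) would trigger one additional compilation.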
