# MNIST Image Classification with JAX/Flax

This notebook demonstrates how to implement a simple linear image model on [MNIST](http://yann.lecun.com/exdb/mnist/) using [JAX](https://jax.readthedocs.io/) and [Flax](https://flax.readthedocs.io/).

## Learning Objectives
1. Know how to read and display image data
2. Know how to find incorrect predictions to analyze the model
3. Visualize what a linear model "sees" in an image

In [None]:
import os
import warnings

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
warnings.filterwarnings("ignore")


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf # For data loading
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax

print(jax.__version__)


## Exploring the data

The MNIST dataset is already included in TensorFlow through the Keras datasets module (`tf.keras.datasets.mnist`). Let's load it and get a sense of the data.

In [None]:
mnist = tf.keras.datasets.mnist.load_data()
(x_train, y_train), (x_test, y_test) = mnist

In [None]:
HEIGHT, WIDTH = x_train[0].shape
NCLASSES = len(np.unique(y_train))
print("Image height x width is", HEIGHT, "x", WIDTH)
print("There are", NCLASSES, "classes")


Each image is 28 x 28 pixels and represents a digit from 0 to 9. The images are grayscale, so each pixel is an integer from 0 (white) to 255 (black). Raw numbers can be hard to interpret, so we can plot the values to see the handwritten digit as an image.

In [None]:
IMGNO = 12
# Uncomment to see raw numerical values.
# print(x_test[IMGNO])
plt.imshow(x_test[IMGNO].reshape(HEIGHT, WIDTH))
print("The label for image number", IMGNO, "is", y_test[IMGNO])

## Define the model
Let's start with a very simple linear classifier. A linear classifier was one of the baseline methods reported on MNIST in 1998, where it reached about 88% accuracy (a 12% error rate). Quite groundbreaking at the time!


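We can build our linear classifier as a [Flax](https://flax.readthedocs.io/) module that flattens each image into a vector, applies a single `Dense` layer with one output per class, and finishes with a softmax activation so the outputs are class probabilities.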
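To make the module's math concrete, here is a minimal sketch of the same forward pass written directly in JAX, without Flax. It assumes, as `nn.Dense` does, a kernel of shape `(784, 10)` and a bias of shape `(10,)`; the helper name `linear_softmax` and the random parameter values are hypothetical and only for illustration.

```python
import jax
import jax.numpy as jnp

# Shapes mirroring the notebook: 28x28 images, 10 classes.
HEIGHT, WIDTH, NCLASSES = 28, 28, 10


def linear_softmax(params, images):
    """Flatten -> affine map -> softmax (hypothetical helper, mirrors LinearModel)."""
    x = images.reshape((images.shape[0], -1))   # (batch, 784)
    z = x @ params["kernel"] + params["bias"]   # (batch, 10) raw class scores
    return jax.nn.softmax(z, axis=-1)           # each row sums to 1


key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
params = {
    # Small random kernel and zero bias, chosen only for illustration.
    "kernel": 0.01 * jax.random.normal(key_w, (HEIGHT * WIDTH, NCLASSES)),
    "bias": jnp.zeros((NCLASSES,)),
}
images = jax.random.uniform(key_x, (4, HEIGHT, WIDTH))
probs = linear_softmax(params, images)
print(probs.shape)            # (4, 10)
print(probs.sum(axis=-1))     # all values close to 1.0
```

Each output row is a probability distribution over the 10 digits, and the predicted class is its `argmax`.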


In [None]:
class LinearModel(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(features=NCLASSES)(x)
        x = nn.softmax(x)
        return x

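def create_train_state(rng, learning_rate, momentum):
    """Initializes the model parameters and an Adam optimizer in a TrainState."""
    model = LinearModel()
    params = model.init(rng, jnp.ones([1, HEIGHT, WIDTH]))['params']
    # Adam's b1 (first-moment decay rate) plays the role of momentum.
    tx = optax.adam(learning_rate, b1=momentum)
    return train_state.TrainState.create(
        apply_fn=model.apply, params=params, tx=tx)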


## Write Input Functions

As usual, we need to specify input functions for training and evaluation. We'll scale each pixel value to a floating-point number between 0 and 1 as a way of normalizing the data.
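The same preprocessing can also be sketched in plain JAX, which may help if you want to skip `tf.data`. This is a minimal sketch assuming `uint8` inputs like MNIST's; `scale_jax` is a hypothetical helper, and `jax.nn.one_hot` stands in for `tf.keras.utils.to_categorical`.

```python
import numpy as np
import jax
import jax.numpy as jnp

NCLASSES = 10

# A fake 1x2x2 "image" batch with MNIST's dtype (uint8, values 0..255).
raw_images = np.array([[[0, 128], [255, 64]]], dtype=np.uint8)
raw_labels = np.array([3])


def scale_jax(images, labels, num_classes=NCLASSES):
    """Casts pixels to float32, maps them to [0, 1], and one-hot encodes labels."""
    images = jnp.asarray(images, dtype=jnp.float32) / 255.0
    labels = jax.nn.one_hot(labels, num_classes)
    return images, labels


images, labels = scale_jax(raw_images, raw_labels)
print(images)   # pixel values now lie in [0, 1]
print(labels)   # a single 1.0 in position 3
```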

**TODO 1**: Define the scale function below and build the dataset

In [None]:
BUFFER_SIZE = 5000
BATCH_SIZE = 100


def scale(image, label):
    """Casts pixels to float32 and rescales them from [0, 255] to [0, 1]."""
    image = tf.cast(image, tf.float32) / 255.0
    return image, label


def load_dataset(training=True):
    """Loads MNIST into a batched tf.data.Dataset."""
    (x_train, y_train), (x_test, y_test) = mnist
    x = x_train if training else x_test
    y = y_train if training else y_test
    # One-hot encode the classes
    y = tf.keras.utils.to_categorical(y, NCLASSES)
    dataset = tf.data.Dataset.from_tensor_slices((x, y)).map(scale)
    if training:
        # Shuffle individual examples (not whole batches), then repeat forever.
        dataset = dataset.shuffle(BUFFER_SIZE).repeat()
    dataset = dataset.batch(BATCH_SIZE)
    return dataset


In [None]:
def create_shape_test(training):
    dataset = load_dataset(training=training)
    data_iter = dataset.as_numpy_iterator()
    (images, labels) = next(data_iter)
    expected_image_shape = (BATCH_SIZE, HEIGHT, WIDTH)
    expected_label_ndim = 2
    assert images.shape == expected_image_shape
    assert labels.ndim == expected_label_ndim
    test_name = "training" if training else "eval"
    print("Test for", test_name, "passed!")


create_shape_test(True)
create_shape_test(False)


Time to train the model! The original MNIST linear classifier had an error rate of 12%. Let's use that to sanity-check that our model is learning.

In [None]:
@jax.jit
def train_step(state, batch):
    def loss_fn(params):
        # LinearModel ends in a softmax, so its outputs are probabilities.
        probs = state.apply_fn({'params': params}, batch['image'])
        # Categorical cross-entropy; the epsilon guards against log(0).
        loss = -jnp.mean(jnp.sum(batch['label'] * jnp.log(probs + 1e-7), axis=-1))
        return loss, probs

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, probs), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    accuracy = jnp.mean(jnp.argmax(probs, -1) == jnp.argmax(batch['label'], -1))
    return state, loss, accuracy

@jax.jit
def eval_step(state, batch):
    probs = state.apply_fn({'params': state.params}, batch['image'])
    loss = -jnp.mean(jnp.sum(batch['label'] * jnp.log(probs + 1e-7), axis=-1))
    accuracy = jnp.mean(jnp.argmax(probs, -1) == jnp.argmax(batch['label'], -1))
    return loss, accuracy

rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
state = create_train_state(init_rng, learning_rate=0.001, momentum=0.9)

train_ds = load_dataset(training=True).as_numpy_iterator()
val_ds_tf = load_dataset(training=False)

NUM_EPOCHS = 10
STEPS_PER_EPOCH = len(x_train) // BATCH_SIZE  # one full pass over the training images

history = {'accuracy': [], 'val_accuracy': [], 'loss': [], 'val_loss': []}

for epoch in range(NUM_EPOCHS):
    # Training
    train_loss = 0
    train_acc = 0
    
    for _ in range(STEPS_PER_EPOCH):
        batch = next(train_ds)
        batch_jax = {'image': batch[0], 'label': batch[1]}
        state, loss, acc = train_step(state, batch_jax)
        train_loss += loss
        train_acc += acc
    
    train_loss /= STEPS_PER_EPOCH
    train_acc /= STEPS_PER_EPOCH
    
    # Validation
    val_loss = 0
    val_acc = 0
    count = 0
    for batch in val_ds_tf.as_numpy_iterator():
        batch_jax = {'image': batch[0], 'label': batch[1]}
        l, a = eval_step(state, batch_jax)
        val_loss += l
        val_acc += a
        count += 1
    
    val_loss /= count
    val_acc /= count
    
    history['loss'].append(train_loss)
    history['accuracy'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_accuracy'].append(val_acc)
    
    print(f"Epoch {epoch + 1}, Loss: {train_loss:.4f}, Accuracy: {train_acc:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_acc:.4f}")


In [None]:
BENCHMARK_ERROR = 0.12
BENCHMARK_ACCURACY = 1 - BENCHMARK_ERROR

accuracy = history["accuracy"]
val_accuracy = history["val_accuracy"]
loss = history["loss"]
val_loss = history["val_loss"]

assert accuracy[-1] > BENCHMARK_ACCURACY
assert val_accuracy[-1] > BENCHMARK_ACCURACY
print("Test to beat benchmark accuracy passed!")

assert accuracy[0] < accuracy[1]
assert accuracy[1] < accuracy[-1]
assert val_accuracy[0] < val_accuracy[1]
assert val_accuracy[1] < val_accuracy[-1]
print("Test model accuracy is improving passed!")

assert loss[0] > loss[1]
assert loss[1] > loss[-1]
assert val_loss[0] > val_loss[1]
assert val_loss[1] > val_loss[-1]
print("Test loss is decreasing passed!")


## Evaluating Predictions

Were you able to get an accuracy of over 90%? Not bad for a linear estimator! Let's make some predictions and see if we can find where the model has trouble. Change the range of values below to find incorrect predictions, and plot the corresponding images. What would you have guessed for these images?

**TODO 2**: Change the range below to find an incorrect prediction

In [None]:
image_numbers = range(0, 10, 1)  # Change me, please.

def load_prediction_dataset():
    dataset = (x_test[image_numbers], y_test[image_numbers])
    dataset = tf.data.Dataset.from_tensor_slices(dataset)
    dataset = dataset.map(scale).batch(len(image_numbers))
    return dataset

# Get the single batch that holds all requested images
images, labels = next(load_prediction_dataset().as_numpy_iterator())

# Predict: the model returns one probability per class for each image
predicted_results = state.apply_fn({'params': state.params}, images)

for index, prediction in enumerate(predicted_results):
    predicted_value = np.argmax(prediction)
    actual_value = y_test[image_numbers[index]]
    if actual_value != predicted_value:
        print("image number: " + str(image_numbers[index]))
        print("the prediction was " + str(predicted_value))
        print("the actual label is " + str(actual_value))
        print("")
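Instead of changing the range by hand, you can also scan many predictions at once. The hypothetical helper `find_misclassified` below is a minimal NumPy sketch: it compares the `argmax` of each probability row with integer labels (like `y_test`) and returns the indices that disagree. The toy arrays stand in for real model outputs.

```python
import numpy as np


def find_misclassified(probs, true_labels):
    """Returns indices where argmax(probs) differs from the integer label."""
    predicted = np.argmax(probs, axis=-1)
    return np.flatnonzero(predicted != true_labels)


# Toy example: 4 predictions over 3 classes, with integer labels.
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1],
                  [0.3, 0.3, 0.4],
                  [0.6, 0.3, 0.1]])
true_labels = np.array([0, 1, 0, 2])
print(find_misclassified(probs, true_labels))  # [2 3]
```

Applied to the notebook's model, something like `find_misclassified(np.asarray(state.apply_fn({'params': state.params}, x_test.astype(np.float32) / 255.0)), y_test)` should list every misclassified test image; any of those indices can serve as `bad_image_number`.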


In [None]:
bad_image_number = 8
plt.imshow(x_test[bad_image_number].reshape(HEIGHT, WIDTH));

It's understandable why the poor computer would have some trouble. Some of these images are difficult for even humans to read. In fact, we can see what the computer thinks each digit looks like.

Each of the 10 neurons in the dense layer of our model has 785 parameters feeding into it: 1 weight for each of the 784 (28 x 28) pixels, plus 1 bias term. Because each image is flattened into a 784-value vector before it reaches the dense layer, each neuron's weights are flat too, but we can reshape them back into the original image dimensions to see what the computer sees.

**TODO 3**: Reshape the layer weights to be the shape of an input image and plot.

In [None]:
DIGIT = 0  # Change me to be an integer from 0 to 9.
WEIGHT_TYPE = 0  # 0 for kernel (per-pixel) weights, 1 for biases

# Flax stores the Dense layer's parameters in a nested dict under the
# auto-generated name 'Dense_0': {'kernel': shape (784, 10), 'bias': shape (10,)}.
# Column DIGIT of the kernel holds the 784 weights feeding that digit's output.
dense_layer_weights = state.params['Dense_0']

if WEIGHT_TYPE == 0:
    digit_weights = dense_layer_weights['kernel'][:, DIGIT]
    # Reshape the flat 784-vector back into a 28 x 28 image
    plt.imshow(digit_weights.reshape((HEIGHT, WIDTH)))
else:
    print("Bias for digit", DIGIT, "is", dense_layer_weights['bias'][DIGIT])


Did you recognize the digit the computer was trying to learn? Pretty trippy, isn't it! Even with a simple "brain", the computer can form an idea of what a digit should be. The human brain, however, uses [layers and layers of calculations for image recognition](https://www.salk.edu/news-release/brain-recognizes-eye-sees/). Ready for the next challenge? <a href="https://github.com/GoogleCloudPlatform/training-data-analyst/blob/master/courses/machine_learning/images/mnist_linear.ipynb">Click here</a> to supercharge our models with human-like vision.

## Bonus Exercise

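Want to push your understanding further? Instead of using Flax's built-in `nn.Dense` layer, try repeating the above exercise with your own custom layer. The linked [custom layers](https://www.tensorflow.org/tutorials/eager/custom_layers) tutorial is written for Keras, but the same idea applies to a Flax `nn.Module`.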

Copyright 2021 Google Inc.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.