In [8]:
import jax.numpy as jnp
from jax import grad, jit, random
import numpy as np
import tensorflow as tf


In [9]:
# Data Preparation:

# Example data
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()
X_train, X_test = X_train / 255.0, X_test / 255.0
X_train, X_test = X_train.reshape(-1, 784), X_test.reshape(-1, 784)
y_train, y_test = jnp.array(y_train), jnp.array(y_test)

# One-hot encode the labels
def one_hot(x, k):
    return jnp.array(x[:, None] == jnp.arange(k), dtype=jnp.float32)

y_train = one_hot(y_train, 10)
y_test = one_hot(y_test, 10)



In [10]:
# Model Building:

def relu(x):
    return jnp.maximum(0, x)

def predict(params, X):
    W1, b1, W2, b2, W3, b3 = params
    X = relu(jnp.dot(X, W1) + b1)
    X = relu(jnp.dot(X, W2) + b2)
    return jnp.dot(X, W3) + b3

def loss(params, X, y):
    preds = predict(params, X)
    return -jnp.mean(jnp.sum(y * preds, axis=1))


In [11]:
# Model Initialization:

key = random.PRNGKey(0)
input_size, hidden1, hidden2, output_size = 784, 128, 64, 10

W1 = random.normal(key, (input_size, hidden1))
b1 = jnp.zeros(hidden1)
W2 = random.normal(key, (hidden1, hidden2))
b2 = jnp.zeros(hidden2)
W3 = random.normal(key, (hidden2, output_size))
b3 = jnp.zeros(output_size)

params = [W1, b1, W2, b2, W3, b3]


In [12]:
# Model Compilation and Training

learning_rate = 0.001

@jit
def update(params, X, y):
    grads = grad(loss)(params, X, y)
    return [(param - learning_rate * grad) for param, grad in zip(params, grads)]

for epoch in range(5):
    params = update(params, X_train, y_train)
    l = loss(params, X_train, y_train)
    print(f'Epoch {epoch+1}, Loss: {l}')


Epoch 1, Loss: -21.011432647705078
Epoch 2, Loss: -50.31199264526367
Epoch 3, Loss: -79.86878204345703
Epoch 4, Loss: -109.75804901123047
Epoch 5, Loss: -140.04869079589844


In [13]:
# Model Evaluation:

def accuracy(params, X, y):
    predictions = predict(params, X)
    return jnp.mean(jnp.argmax(predictions, axis=1) == jnp.argmax(y, axis=1))

acc = accuracy(params, X_test, y_test)
print(f'Test accuracy: {acc}')


Test accuracy: 0.13539999723434448


In [14]:
# Model Prediction:

predictions = predict(params, X_test)
print(jnp.argmax(predictions, axis=1))


[9 0 7 ... 2 7 6]


## Improving the Model

In [15]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

## Hyperparameters
Let’s get a few bookkeeping items out of the way.

In [16]:
# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_layer_params(m, n, key, scale=1e-2):
  w_key, b_key = random.split(key)
  return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
  keys = random.split(key, len(sizes))
  return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

layer_sizes = [784, 512, 512, 10]
step_size = 0.01
num_epochs = 10
batch_size = 128
n_targets = 10
params = init_network_params(layer_sizes, random.key(0))

### Auto-batching predictions
Let us first define our prediction function. Note that we’re defining this for a single image example. We’re going to use JAX’s `vmap` function to automatically handle mini-batches, with no performance penalty.

In [17]:
from jax.scipy.special import logsumexp

def relu(x):
  return jnp.maximum(0, x)

def predict(params, image):
  # per-example predictions
  activations = image
  for w, b in params[:-1]:
    outputs = jnp.dot(w, activations) + b
    activations = relu(outputs)
  
  final_w, final_b = params[-1]
  logits = jnp.dot(final_w, activations) + final_b
  return logits - logsumexp(logits)

Let’s check that our prediction function only works on single images.

In [18]:
# This works on single examples
random_flattened_image = random.normal(random.key(1), (28 * 28,))
preds = predict(params, random_flattened_image)
print(preds.shape)

(10,)


In [19]:
# Doesn't work with a batch
random_flattened_images = random.normal(random.key(1), (10, 28 * 28))
try:
  preds = predict(params, random_flattened_images)
except TypeError:
  print('Invalid shapes!')

Invalid shapes!


In [20]:
# Let's upgrade it to handle batches using `vmap`

# Make a batched version of the `predict` function
batched_predict = vmap(predict, in_axes=(None, 0))

# `batched_predict` has the same call signature as `predict`
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape)

(10, 10)


At this point, we have all the ingredients we need to define our neural network and train it. We’ve built an auto-batched version of `predict`, which we should be able to use in a loss function. We should be able to use `grad` to take the derivative of the loss with respect to the neural network parameters. Last, we should be able to use `jit` to speed up everything.

## Utility and loss functions

In [21]:
def one_hot(x, k, dtype=jnp.float32):
  """Create a one-hot encoding of x of size k."""
  return jnp.array(x[:, None] == jnp.arange(k), dtype)
  
def accuracy(params, images, targets):
  target_class = jnp.argmax(targets, axis=1)
  predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
  return jnp.mean(predicted_class == target_class)

def loss(params, images, targets):
  preds = batched_predict(params, images)
  return -jnp.mean(preds * targets)

@jit
def update(params, x, y):
  grads = grad(loss)(params, x, y)
  return [(w - step_size * dw, b - step_size * db)
          for (w, b), (dw, db) in zip(params, grads)]

## Data Loading with tensorflow/datasets

JAX is laser-focused on program transformations and accelerator-backed NumPy, so we don’t include data loading or munging in the JAX library. There are already a lot of great data loaders out there, so let’s just use them instead of reinventing anything. We’ll use the `tensorflow/datasets` data loader

```bash
pip install tensorflow
pip install tensorflow-datasets
```

In [23]:
import tensorflow as tf
# Ensure TF does not see GPU and grab all GPU memory.
tf.config.set_visible_devices([], device_type='GPU')

import tensorflow_datasets as tfds

data_dir = '/tmp/tfds'

# Fetch full datasets for evaluation
# tfds.load returns tf.Tensors (or tf.data.Datasets if batch_size != -1)
# You can convert them to NumPy arrays (or iterables of NumPy arrays) with tfds.dataset_as_numpy
mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)
mnist_data = tfds.as_numpy(mnist_data)
train_data, test_data = mnist_data['train'], mnist_data['test']
num_labels = info.features['label'].num_classes
h, w, c = info.features['image'].shape
num_pixels = h * w * c

# Full train set
train_images, train_labels = train_data['image'], train_data['label']
train_images = jnp.reshape(train_images, (len(train_images), num_pixels))
train_labels = one_hot(train_labels, num_labels)

# Full test set
test_images, test_labels = test_data['image'], test_data['label']
test_images = jnp.reshape(test_images, (len(test_images), num_pixels))
test_labels = one_hot(test_labels, num_labels)

Downloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to /tmp/tfds\mnist\3.0.1...


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]

Generating splits...:   0%|          | 0/2 [00:00<?, ? splits/s]

Generating train examples...: 0 examples [00:00, ? examples/s]

Shuffling \tmp\tfds\mnist\incomplete.SW3VQQ_3.0.1\mnist-train.tfrecord*...:   0%|          | 0/60000 [00:00<?,…

Generating test examples...: 0 examples [00:00, ? examples/s]

Shuffling \tmp\tfds\mnist\incomplete.SW3VQQ_3.0.1\mnist-test.tfrecord*...:   0%|          | 0/10000 [00:00<?, …

Dataset mnist downloaded and prepared to /tmp/tfds\mnist\3.0.1. Subsequent calls will reuse this data.


In [24]:
print('Train:', train_images.shape, train_labels.shape)
print('Test:', test_images.shape, test_labels.shape)

Train: (60000, 784) (60000, 10)
Test: (10000, 784) (10000, 10)


## Training Loop

In [25]:
import time

def get_train_batches():
  # as_supervised=True gives us the (image, label) as a tuple instead of a dict
  ds = tfds.load(name='mnist', split='train', as_supervised=True, data_dir=data_dir)
  # You can build up an arbitrary tf.data input pipeline
  ds = ds.batch(batch_size).prefetch(1)
  # tfds.dataset_as_numpy converts the tf.data.Dataset into an iterable of NumPy arrays
  return tfds.as_numpy(ds)

for epoch in range(num_epochs):
  start_time = time.time()
  for x, y in get_train_batches():
    x = jnp.reshape(x, (len(x), num_pixels))
    y = one_hot(y, num_labels)
    params = update(params, x, y)
  epoch_time = time.time() - start_time

  train_acc = accuracy(params, train_images, train_labels)
  test_acc = accuracy(params, test_images, test_labels)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))

Epoch 0 in 3.05 sec
Training set accuracy 0.9253333210945129
Test set accuracy 0.9268999695777893
Epoch 1 in 2.98 sec
Training set accuracy 0.9427500367164612
Test set accuracy 0.941100001335144
Epoch 2 in 3.07 sec
Training set accuracy 0.9531833529472351
Test set accuracy 0.9511999487876892
Epoch 3 in 3.10 sec
Training set accuracy 0.9600666761398315
Test set accuracy 0.9553999900817871
Epoch 4 in 2.92 sec
Training set accuracy 0.965149998664856
Test set accuracy 0.960099995136261
Epoch 5 in 2.79 sec
Training set accuracy 0.9690666794776917
Test set accuracy 0.9628999829292297
Epoch 6 in 2.87 sec
Training set accuracy 0.9725666642189026
Test set accuracy 0.9652999639511108
Epoch 7 in 2.82 sec
Training set accuracy 0.9754999876022339
Test set accuracy 0.9666999578475952
Epoch 8 in 2.84 sec
Training set accuracy 0.9781333208084106
Test set accuracy 0.9680999517440796
Epoch 9 in 2.96 sec
Training set accuracy 0.9802833199501038
Test set accuracy 0.9692999720573425


We’ve now used most of the JAX API: `grad` for derivatives, `jit` for speedups and `vmap` for auto-vectorization. We used NumPy to specify all of our computation, and borrowed the great data loaders from `tensorflow/datasets`, and ran the whole thing on the GPU.

### [Return to Main Page](../README.md)