In [1]:
import jax.numpy as jnp
from jax import random, jit, vmap, grad

In [2]:
def random_layer_params(m, n, key, scale=1e-2):
    """Draw random weights and biases for one dense layer.

    Args:
        m: Number of inputs to the layer.
        n: Number of outputs of the layer.
        key: JAX PRNG key; it is split into one sub-key for the weights and
            one for the biases.
        scale: Multiplier applied to the standard-normal draws.

    Returns:
        A ``(weights, biases)`` tuple with shapes ``(n, m)`` and ``(n,)``.
    """
    weight_key, bias_key = random.split(key)
    weights = scale * random.normal(weight_key, (n, m))
    biases = scale * random.normal(bias_key, (n,))
    return weights, biases


def init_network_params(sizes, key):
    """Initialise every dense layer of a fully-connected network.

    Args:
        sizes: Sequence of layer widths; consecutive pairs define one layer.
        key: JAX PRNG key that is split into per-layer sub-keys.

    Returns:
        A list with one ``(weights, biases)`` tuple per consecutive pair in
        ``sizes``, as produced by ``random_layer_params``.
    """
    # len(sizes) keys are split although only len(sizes) - 1 layers exist;
    # zip() stops at the shortest input, so the last key is simply unused.
    layer_keys = random.split(key, len(sizes))
    layers = []
    for fan_in, fan_out, layer_key in zip(sizes[:-1], sizes[1:], layer_keys):
        layers.append(random_layer_params(fan_in, fan_out, layer_key))
    return layers


# Network and training configuration used by the rest of the notebook.
# Layer widths: 784 inputs (28 * 28 flattened pixels), two hidden layers of
# 512 units, and 10 outputs.
layer_sizes = [784, 512, 512, 10]
# Gradient-descent step size read by `update`.
step_size = 0.01
# Number of passes over the training batches in the training loop.
num_epochs = 10
# Mini-batch size passed to `ds.batch` in `get_train_batches`.
batch_size = 128
# NOTE(review): not referenced elsewhere in the visible notebook; the data
# loading code derives `num_labels` from the dataset info instead.
n_targets = 10
# Initial parameters: a list of (weights, biases) tuples, seeded with key 0.
params = init_network_params(layer_sizes, random.key(0))

In [3]:
from jax.scipy.special import logsumexp


def relu(x):
    """Rectified linear unit: element-wise maximum of ``x`` and zero."""
    rectified = jnp.maximum(0, x)
    return rectified


def predict(params, image):
    """Run a single example through the network and return log-probabilities.

    Args:
        params: List of ``(weights, biases)`` tuples, one per layer.
        image: Input vector for one example (not a batch); each layer
            applies ``jnp.dot(weights, activations) + biases``.

    Returns:
        The final-layer logits minus their log-sum-exp, i.e. a log-softmax
        over the output units.
    """
    hidden_layers, output_layer = params[:-1], params[-1]

    # Every layer except the last is followed by a ReLU.
    hidden = image
    for weight, bias in hidden_layers:
        hidden = relu(jnp.dot(weight, hidden) + bias)

    # The last layer is linear; subtracting logsumexp normalises the logits.
    out_weight, out_bias = output_layer
    scores = jnp.dot(out_weight, hidden) + out_bias
    return scores - logsumexp(scores)

In [4]:
# Sanity check: push one random 28 * 28 flattened "image" through the
# untrained network and print the resulting vector of log-probabilities.
random_flattened_image = random.normal(random.key(1), (28 * 28,))
preds = predict(params, random_flattened_image)
print(preds)

[-2.2827816 -2.298299  -2.292807  -2.3304513 -2.3079185 -2.3170114
 -2.3065283 -2.296568  -2.300571  -2.2937381]


In [5]:
# Build a batch of 10 random flattened images (the key value random.key(1)
# is the same one used for the single-image example).
random_flattened_images = random.normal(
    random.key(1),
    (
        10,
        28 * 28,
    ),
)
# `predict` is written for a single example, so calling it on a batch is
# expected to fail; the exception is caught and printed to demonstrate that.
try:
    preds = predict(params, random_flattened_images)
except Exception as e:
    print(type(e))
    print(e)

<class 'TypeError'>
dot_general requires contracting dimensions to have the same shape, got (784,) and (10,).


In [6]:
# Lift the single-example `predict` to batches with vmap:
# in_axes=(None, 0) leaves `params` unmapped (shared across the batch) and
# maps over axis 0 of the second argument (the images).
batched_predict = vmap(predict, in_axes=(None, 0))
batched_preds = batched_predict(params, random_flattened_images)
# One row of network outputs per input image.
print(batched_preds.shape)

(10, 10)


# Utility & Loss Functions

In [7]:
def one_hot(x, k, dtype=jnp.float32):
    """Create a one-hot encoding of x of size k.

    Args:
        x: Integer class label(s). May be a scalar, a Python sequence, or an
            array of any rank; it is converted with ``jnp.asarray``.
        k: Number of classes (length of the trailing one-hot axis).
        dtype: dtype of the returned array.

    Returns:
        An array of shape ``x.shape + (k,)`` holding 1 where the label equals
        the class index and 0 elsewhere. Labels outside ``[0, k)`` match no
        class index and therefore produce an all-zero row.
    """
    labels = jnp.asarray(x)
    # Add a trailing axis with `...` (not `:`) so that scalars and N-D label
    # arrays broadcast against the class indices as well as 1-D arrays do.
    return jnp.array(labels[..., None] == jnp.arange(k), dtype)


def accuracy(params, images, targets):
    """Fraction of examples whose predicted class matches the target class.

    Args:
        params: Network parameters passed to ``batched_predict``.
        images: Batch of inputs, one example per row.
        targets: Per-example target scores; the arg-max along axis 1 is
            taken as the true class.

    Returns:
        The mean of the element-wise "prediction == target" comparison.
    """
    predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
    target_class = jnp.argmax(targets, axis=1)
    hits = predicted_class == target_class
    return jnp.mean(hits)


def loss(params, images, targets):
    """Negative mean of the target-weighted network outputs.

    Args:
        params: Network parameters passed to ``batched_predict``.
        images: Batch of inputs, one example per row.
        targets: Array multiplied element-wise with the predictions.

    Returns:
        ``-mean(preds * targets)``. The mean runs over every element of the
        product (no axis is given), not only over the batch dimension.
    """
    log_probs = batched_predict(params, images)
    weighted = log_probs * targets
    return -jnp.mean(weighted)


@jit
def update(params, x, y):
    """Apply one gradient-descent step and return the new parameters.

    Args:
        params: List of ``(weights, biases)`` tuples.
        x: Batch of inputs forwarded to ``loss``.
        y: Batch of targets forwarded to ``loss``.

    Returns:
        A new list of ``(weights, biases)`` tuples, each moved against its
        gradient by the module-level ``step_size``.
    """
    # grad differentiates `loss` with respect to its first argument, so the
    # gradients have the same list-of-tuples layout as `params`.
    gradients = grad(loss)(params, x, y)
    stepped = []
    for (weight, bias), (d_weight, d_bias) in zip(params, gradients):
        stepped.append((weight - step_size * d_weight, bias - step_size * d_bias))
    return stepped

# Data Loading with TensorFlow Datasets

In [8]:
import tensorflow as tf

# Hide all GPUs from TensorFlow so that it does not grab all GPU memory.
# This runs before tensorflow_datasets is imported below.
tf.config.set_visible_devices([], device_type="GPU")

import tensorflow_datasets as tfds

# Directory passed to tfds.load as `data_dir` for the MNIST files.
data_dir = "/tmp/tfds"

# Fetch full datasets for evaluation.
# tfds.load returns tf.Tensors (or tf.data.Datasets if batch_size != -1).
# They are converted to NumPy arrays (or iterables of NumPy arrays) with
# tfds.as_numpy, as done right after the load below.
mnist_data, info = tfds.load(
    name="mnist", batch_size=-1, data_dir=data_dir, with_info=True
)
mnist_data = tfds.as_numpy(mnist_data)
train_data, test_data = mnist_data["train"], mnist_data["test"]
# Number of classes and flattened image size, both read from the dataset info.
num_labels = info.features["label"].num_classes
h, w, c = info.features["image"].shape
num_pixels = h * w * c

# Full train set: flatten every image to a num_pixels vector and one-hot
# encode the labels.
train_images, train_labels = train_data["image"], train_data["label"]
train_images = jnp.reshape(train_images, (len(train_images), num_pixels))
train_labels = one_hot(train_labels, num_labels)

# Full test set: same flattening and one-hot encoding as the train set.
test_images, test_labels = test_data["image"], test_data["label"]
test_images = jnp.reshape(test_images, (len(test_images), num_pixels))
test_labels = one_hot(test_labels, num_labels)

2024-06-10 12:06:14.243932: W external/local_tsl/tsl/platform/cloud/google_auth_provider.cc:184] All attempts to get a Google authentication bearer token failed, returning an empty token. Retrieving token from files failed with "NOT_FOUND: Could not locate the credentials file.". Retrieving token from GCE failed with "FAILED_PRECONDITION: Error executing an HTTP request: libcurl code 6 meaning 'Couldn't resolve host name', error details: Could not resolve host: metadata.google.internal".


[1mDownloading and preparing dataset 11.06 MiB (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /tmp/tfds/mnist/3.0.1...[0m


Dl Completed...:   0%|          | 0/5 [00:00<?, ? file/s]

[1mDataset mnist downloaded and prepared to /tmp/tfds/mnist/3.0.1. Subsequent calls will reuse this data.[0m


In [9]:
import numpy as np

# Illustration of the broadcasting comparison used inside `one_hot`:
# a column of labels compared against a row of candidate values yields a
# boolean matrix with one row per label.
x = np.array([1, 2, 3])
x[:, None] == np.array([1, 2, 3, 4])

array([[ True, False, False, False],
       [False,  True, False, False],
       [False, False,  True, False]])

In [10]:
# Confirm the shapes of the flattened images and one-hot labels.
print("Train:", train_images.shape, train_labels.shape)
print("Test:", test_images.shape, test_labels.shape)

Train: (60000, 784) (60000, 10)
Test: (10000, 784) (10000, 10)


# Training Loop

In [11]:
import time


def get_train_batches():
    """Return an iterable of ``(images, labels)`` NumPy batches for training.

    The MNIST "train" split is loaded from ``data_dir`` on every call,
    batched with the module-level ``batch_size`` and prefetched one batch
    ahead.
    """
    # as_supervised=True yields (image, label) tuples rather than dicts.
    train_split = tfds.load(
        name="mnist", split="train", as_supervised=True, data_dir=data_dir
    )
    # Any tf.data input pipeline could be built here; this one only batches
    # and prefetches.
    batched = train_split.batch(batch_size)
    prefetched = batched.prefetch(1)
    # tfds.as_numpy turns the tf.data.Dataset into an iterable of NumPy arrays.
    return tfds.as_numpy(prefetched)


# Mini-batch gradient descent over the training set for `num_epochs` epochs.
# NOTE: the loop rebinds the global `params`, so re-running this cell
# continues training from the current weights instead of starting over.
for epoch in range(num_epochs):
    start_time = time.time()
    # A fresh batch iterable is built at the start of every epoch.
    for x, y in get_train_batches():
        # Flatten each image to a num_pixels vector and one-hot the labels so
        # they match what `update` / `loss` expect.
        x = jnp.reshape(x, (len(x), num_pixels))
        y = one_hot(y, num_labels)
        params = update(params, x, y)
    # Wall-clock time of the update loop only; the accuracy evaluation below
    # is not included.
    epoch_time = time.time() - start_time

    # Evaluate on the full train and test sets after every epoch.
    train_acc = accuracy(params, train_images, train_labels)
    test_acc = accuracy(params, test_images, test_labels)
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    print("Training set accuracy {}".format(train_acc))
    print("Test set accuracy {}".format(test_acc))

2024-06-10 12:06:31.623071: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 0 in 9.42 sec
Training set accuracy 0.9252833724021912
Test set accuracy 0.9266999959945679


2024-06-10 12:06:34.366034: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 1 in 0.86 sec
Training set accuracy 0.9428499937057495
Test set accuracy 0.9411999583244324


2024-06-10 12:06:35.280952: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 2 in 0.89 sec
Training set accuracy 0.953166663646698
Test set accuracy 0.9513999819755554


2024-06-10 12:06:36.181634: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 3 in 0.87 sec
Training set accuracy 0.959933340549469
Test set accuracy 0.9557999968528748


2024-06-10 12:06:37.060244: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 4 in 0.85 sec
Training set accuracy 0.9650999903678894
Test set accuracy 0.960599958896637


2024-06-10 12:06:37.956795: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 5 in 0.87 sec
Training set accuracy 0.9691833257675171
Test set accuracy 0.9629999995231628


2024-06-10 12:06:38.844868: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 6 in 0.86 sec
Training set accuracy 0.9725666642189026
Test set accuracy 0.9651999473571777


2024-06-10 12:06:39.726152: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 7 in 0.85 sec
Training set accuracy 0.9754666686058044
Test set accuracy 0.9666999578475952


2024-06-10 12:06:40.646448: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 8 in 0.89 sec
Training set accuracy 0.9782000184059143
Test set accuracy 0.9680999517440796
Epoch 9 in 0.86 sec
Training set accuracy 0.9803500175476074
Test set accuracy 0.9693999886512756


2024-06-10 12:06:41.536175: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
