In [13]:
%pip install jax jaxlib tensorflow tensorflow_datasets

import jax
import jax.numpy as jnp
from jax import jit


  pid, fd = os.forkpty()


Note: you may need to restart the kernel to use updated packages.


In [14]:
def layer(x, w, b):
    return jax.nn.relu(jnp.dot(x,w) + b)

def output_layer(x, w, b):
    return jax.nn.softmax(jnp.dot(x,w) + b)

def mlp(x, w1, b1, w2, b2, w3, b3):
    l1 = x.reshape(-1, 784)
    l2 = layer(l1, w1, b1)
    l3 = layer(l2, w2, b2)

    return output_layer(l3, w3, b3)

key = jax.random.PRNGKey(0)

w1 = jax.random.normal(key, (784, 128))
b1 = jax.random.normal(key, (128,))
w2 = jax.random.normal(key, (128, 64))
b2 = jax.random.normal(key, (64,))
w3 = jax.random.normal(key, (64, 10))
b3 = jax.random.normal(key, (10,))

params = (w1, b1, w2, b2, w3, b3)

@jit
def loss(params, x, y):
    logits = mlp(x, *params)
    labels = jax.nn.one_hot(y, 10)
    return -jnp.mean(jnp.sum(labels * jnp.log(logits + 1e-8), axis=1))

In [15]:
import tensorflow_datasets as tfds
import tensorflow as tf

def prepare_data():
    train_ds = tfds.load('mnist', split='train', as_supervised=True)
    test_ds = tfds.load('mnist', split='test', as_supervised=True)

    def normalize(images, labels):
        images = tf.cast(images, tf.float32) / 255.0
        return images, labels

    train_ds = train_ds.map(normalize).batch(32).prefetch(1)
    test_ds = test_ds.map(normalize).batch(32)

    return tfds.as_numpy(train_ds), tfds.as_numpy(test_ds)

In [16]:
EPOCHS = 30
LEARNING_RATE = 0.03

train_ds, test_ds = prepare_data()

for epoch in range(EPOCHS):
    for x, y in train_ds:
        loss_val = loss(params, x, y)
        grads = jax.grad(loss)(params, x, y)
        params = jax.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)

    for x, y in test_ds:
        logits = mlp(x, *params)
        preds = jax.nn.one_hot(jnp.argmax(logits, axis=1), 10)

        acc = jnp.mean(jnp.argmax(preds, axis=1) == y)

    print(f"Epoch {epoch + 1} loss: {loss_val}, accuracy: {acc}")


  params = jax.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)


Epoch 1 loss: 6.98144006729126, accuracy: 0.5
Epoch 2 loss: 6.3321099281311035, accuracy: 0.5
Epoch 3 loss: 6.907756328582764, accuracy: 0.5
Epoch 4 loss: 6.3321099281311035, accuracy: 0.5
Epoch 5 loss: 6.3321099281311035, accuracy: 0.5
Epoch 6 loss: 6.907756328582764, accuracy: 0.5
Epoch 7 loss: 6.3321099281311035, accuracy: 0.5
Epoch 8 loss: 6.3321099281311035, accuracy: 0.5
Epoch 9 loss: 5.756463527679443, accuracy: 0.5
Epoch 10 loss: 6.3321099281311035, accuracy: 0.5
Epoch 11 loss: 5.756463527679443, accuracy: 0.5
Epoch 12 loss: 5.180817127227783, accuracy: 0.625
Epoch 13 loss: 4.029524326324463, accuracy: 0.625
Epoch 14 loss: 3.4701433181762695, accuracy: 0.625
Epoch 15 loss: 4.029524326324463, accuracy: 0.625
Epoch 16 loss: 4.029524326324463, accuracy: 0.625
Epoch 17 loss: 4.605170726776123, accuracy: 0.5625
Epoch 18 loss: 4.029524326324463, accuracy: 0.625
Epoch 19 loss: 4.029514312744141, accuracy: 0.625
Epoch 20 loss: 4.605170726776123, accuracy: 0.625
Epoch 21 loss: 4.0295367