In [33]:
%pip install jax jaxlib tensorflow tensorflow_datasets

import jax
import jax.numpy as jnp
from jax import jit


  pid, fd = os.forkpty()


Note: you may need to restart the kernel to use updated packages.


In [34]:
def layer(x, w, b):
    return jax.nn.relu(jnp.dot(x,w) + b)

def output_layer(x, w, b):
    return jax.nn.softmax(jnp.dot(x,w) + b)

def mlp(x, params):
    x = x.reshape(-1, 784)
    for layer_params in params[:-1]:
        x = layer(x, layer_params["weight"], layer_params["bias"])

    return output_layer(x, params[-1]["weight"], params[-1]["bias"])

key = jax.random.PRNGKey(0)

random = lambda x: jax.random.normal(key, x)

params = [{"weight": random((784, 128)), "bias": random((128,))}, 
          {"weight": random((128, 64)), "bias": random((64,))}, 
          {"weight": random((64, 10)), "bias": random((10,))}]

@jit
def cross_entropy_loss(params, x, y):
    logits = mlp(x, params)
    labels = jax.nn.one_hot(y, 10)
    return -jnp.mean(jnp.sum(labels * jnp.log(logits + 1e-8), axis=1))

In [35]:
import tensorflow_datasets as tfds
import tensorflow as tf

def prepare_data():
    train_ds = tfds.load('mnist', split='train', as_supervised=True)
    test_ds = tfds.load('mnist', split='test', as_supervised=True)

    def normalize(images, labels):
        images = tf.cast(images, tf.float32) / 255.0
        return images, labels

    train_ds = train_ds.map(normalize).batch(32).prefetch(1)
    test_ds = test_ds.map(normalize).batch(32)

    return tfds.as_numpy(train_ds), tfds.as_numpy(test_ds)

In [36]:
EPOCHS = 30
LEARNING_RATE = 0.01

train_ds, test_ds = prepare_data()

for epoch in range(EPOCHS):
    for x, y in train_ds:
        loss_val = cross_entropy_loss(params, x, y)
        grads = jax.grad(cross_entropy_loss)(params, x, y)
        params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)

    for x, y in test_ds:
        logits = mlp(x, params)
        preds = jax.nn.one_hot(jnp.argmax(logits, axis=1), 10)

        acc = jnp.mean(jnp.argmax(preds, axis=1) == y)

    print(f"Epoch {epoch + 1} loss: {loss_val}, accuracy: {acc}")


Epoch 1 loss: 5.756463527679443, accuracy: 0.625
Epoch 2 loss: 5.180817127227783, accuracy: 0.625
Epoch 3 loss: 3.484708786010742, accuracy: 0.8125
Epoch 4 loss: 3.4538283348083496, accuracy: 0.8125
Epoch 5 loss: 2.8782315254211426, accuracy: 0.8125
Epoch 6 loss: 2.878230094909668, accuracy: 0.8125
Epoch 7 loss: 2.8782315254211426, accuracy: 0.8125
Epoch 8 loss: 2.3025851249694824, accuracy: 0.8125
Epoch 9 loss: 2.8782315254211426, accuracy: 0.8125
Epoch 10 loss: 2.8782315254211426, accuracy: 0.8125
Epoch 11 loss: 2.546276092529297, accuracy: 0.8125
Epoch 12 loss: 2.8782315254211426, accuracy: 0.875
Epoch 13 loss: 2.8782315254211426, accuracy: 0.8125
Epoch 14 loss: 3.4538779258728027, accuracy: 0.875
Epoch 15 loss: 4.029524326324463, accuracy: 0.875
Epoch 16 loss: 3.453864336013794, accuracy: 0.875
Epoch 17 loss: 3.4538779258728027, accuracy: 0.875
Epoch 18 loss: 2.3025851249694824, accuracy: 0.875
Epoch 19 loss: 2.8782315254211426, accuracy: 0.8125
Epoch 20 loss: 2.8782315254211426, a