This notebook can be run in Google Colab.

It trains a small convolutional network on CIFAR-10 using the JAX deep learning framework.

In [1]:
# This notebook is adapted from the JAX MNIST example:
# https://jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html

In [3]:
from pathlib import Path
from typing import Tuple

import tensorflow as tf
import tensorflow_datasets as tfds

# TensorFlow is only used here for data loading: hide every GPU from it so that
# it does not grab all GPU memory for itself.
tf.config.set_visible_devices(devices=[], device_type="GPU")


def get_data_from_tfds(
    name: str,
    data_dir: Path,
) -> Tuple[dict, dict]:
    """Load the full train and test splits of a TFDS dataset as NumPy arrays.

    ``batch_size=-1`` makes ``tfds.load`` return each split as a single batch
    holding every example. ``tfds.as_numpy`` then converts each feature tensor
    to a NumPy array.

    Args:
        name: name of the dataset for ``tfds.load()`` (e.g. ``"cifar10"``).
        data_dir: directory where TFDS downloads / caches the data.

    Returns:
        ``(train_data, test_data)``: one dict per split, mapping feature names
        (e.g. ``"image"``, ``"label"``) to NumPy arrays.

    Raises:
        KeyError: if the dataset has no ``"train"`` or ``"test"`` split.
    """
    # No `split=` argument, so tfds.load returns a dict keyed by split name.
    # `with_info` is not requested: the DatasetInfo was never used.
    all_splits = tfds.load(
        name=name,
        batch_size=-1,
        data_dir=data_dir,
    )
    all_splits = tfds.as_numpy(all_splits)
    return all_splits["train"], all_splits["test"]


In [11]:
import jax.numpy as jnp
from jax.example_libraries import stax, optimizers

TFDS_DIR = Path('/tmp/tfds')
train_data, test_data = get_data_from_tfds(name="cifar10", data_dir=TFDS_DIR)

# Convert every split to float32 JAX arrays; labels remain class indices.
X_train = jnp.array(train_data['image'], dtype=jnp.float32)
X_test = jnp.array(test_data['image'], dtype=jnp.float32)
Y_train = jnp.array(train_data['label'], dtype=jnp.float32)
Y_test = jnp.array(test_data['label'], dtype=jnp.float32)

classes = jnp.unique(Y_train)

# Two (conv -> ReLU -> 2x2 max-pool) stages followed by a small dense head.
# For 32x32x3 inputs: 32x32x32 -> 16x16x32 -> 16x16x16 -> 8x8x16 -> 1024 features.
conv_init, conv_apply = stax.serial(
    stax.Conv(32, (3, 3), padding="SAME"), stax.Relu,
    stax.MaxPool(window_shape=(2, 2), strides=(2, 2)),
    stax.Conv(16, (3, 3), padding="SAME"), stax.Relu,
    stax.MaxPool(window_shape=(2, 2), strides=(2, 2)),
    stax.Flatten,
    stax.Dense(64), stax.Relu,
    # One output probability per class present in the training labels.
    stax.Dense(len(classes)), stax.Softmax,
)

In [16]:
# Initialise the network parameters, then print kernel/bias shapes per layer.
import jax
rng = jax.random.PRNGKey(123)

# The stax init function returns (output_shape, params); only params are kept.
output_shape, weights = conv_init(rng, (18, 32, 32, 3))

for layer_params in weights:
    # Layers without parameters (activations, pooling, flatten) have an empty,
    # falsy params entry -- nothing to print for them.
    if not layer_params:
        continue
    kernel, bias = layer_params
    print("Weights : {}, Biases : {}".format(kernel.shape, bias.shape))

Weights : (3, 3, 3, 32), Biases : (1, 1, 1, 32)
Weights : (3, 3, 32, 16), Biases : (1, 1, 1, 16)
Weights : (1024, 64), Biases : (64,)
Weights : (64, 10), Biases : (10,)


In [18]:
# Make a prediction and verify output shape
# Forward pass on a small sample; the result should be (batch, num_classes).
sample_images = X_train[:5]
preds = conv_apply(weights, sample_images)
preds.shape

(5, 10)

In [25]:
from typing import Callable

def CrossEntropyLoss(
    conv_apply: Callable,
    weights: list,
    input_data: jax.Array,
    targets: jax.Array,
    eps: float = 1e-7,
) -> jax.Array:
    """Compute the mean categorical cross-entropy of a model's predictions.

    The per-example loss ``-sum_c targets[c] * log(preds[c])`` is summed over
    the class axis and then averaged over the batch.

    Args:
        conv_apply: apply function from ``_, conv_apply = stax.serial(...)``.
            Assumed to return class probabilities of shape
            ``(batch, num_classes)``, e.g. a model ending in ``stax.Softmax``.
        weights: model parameters, e.g. from the stax init function or from
            ``opt_get_weights(opt_state)`` with
            ``_, _, opt_get_weights = optimizers.adam(lr)``.
        input_data: batch of inputs to predict on.
        targets: ground truth. Accepts either one-hot / soft labels with the
            same shape as the predictions, or class indices with one fewer
            dimension (e.g. ``(batch,)``), which are one-hot encoded here.
        eps: small constant added to the probabilities before ``log`` to avoid
            ``log(0)``. The default equals Keras' default fuzz factor, so TF
            is no longer needed inside the loss.

    Returns:
        Scalar loss value.
    """
    preds = conv_apply(weights, input_data)
    log_preds = jnp.log(preds + eps)
    targets = jnp.asarray(targets)
    if targets.ndim == log_preds.ndim - 1:
        # Integer class labels: expand to one-hot over the class axis.
        targets = jax.nn.one_hot(targets, log_preds.shape[-1], dtype=log_preds.dtype)
    # Sum over classes first, then average over the batch. Averaging over every
    # element instead would divide the loss (and its gradient) by num_classes.
    return -jnp.mean(jnp.sum(targets * log_preds, axis=-1))