This notebook can be used in Google Colab

It implements a CIFAR10 training using JAX Deep Learning framework

In [None]:
# This notebook is inspired from JAX mnist example
# https://jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html

In [None]:
from pathlib import Path
from typing import Tuple

import tensorflow as tf
import tensorflow_datasets as tfds

# Ensure TF does not see GPU and grab all GPU memory.
tf.config.set_visible_devices([], device_type="GPU")


def get_data_from_tfds(
    name: str,
    data_dir: Path,
) -> Tuple[tf.Tensor | tf.data.Dataset, tf.Tensor | tf.data.Dataset]:
    """Fetch full datasets for evaluation.

    Args:
        name: name of the dataset for tfds.load() method
        data_dir: path to save the data
    Returns:
        tfds.load returns tf.Tensors (or tf.data.Datasets if batch_size != -1)
    """
    cifar10_data, _ = tfds.load(
        name=name,
        batch_size=-1,
        data_dir=data_dir,
        with_info=True,
    )
    cifar10_data = tfds.as_numpy(cifar10_data)
    train_data, test_data = cifar10_data["train"], cifar10_data["test"]
    return train_data, test_data


In [None]:
import jax.numpy as jnp
from jax.example_libraries import stax, optimizers

train_data, test_data = get_data_from_tfds(name="cifar10", data_dir=Path('/tmp/tfds'))

X_train, Y_train = train_data['image'], train_data['label']
X_test, Y_test = test_data['image'], test_data['label']

X_train, X_test, Y_train, Y_test = jnp.array(X_train, dtype=jnp.float32),\
                                   jnp.array(X_test, dtype=jnp.float32),\
                                   jnp.array(Y_train, dtype=jnp.float32),\
                                   jnp.array(Y_test, dtype=jnp.float32)
classes =  jnp.unique(Y_train)
conv_init, conv_apply = stax.serial(
    stax.Conv(32, (3,3), padding="SAME"),
    stax.Relu,
    stax.MaxPool(window_shape=(2, 2), strides=(2, 2)),
    stax.Conv(16, (3,3), padding="SAME"),
    stax.Relu,
    stax.MaxPool(window_shape=(2, 2), strides=(2, 2)),
    stax.Flatten,
    stax.Dense(64),
    stax.Relu,
    stax.Dense(len(classes)),
    stax.Softmax
)

In [None]:
# Init weights and verify the shapes
import jax
rng = jax.random.PRNGKey(123)

weights = conv_init(rng, (18,32,32,3))[1]

for w in weights:
    if w:
        w, b = w
        print("Weights : {}, Biases : {}".format(w.shape, b.shape))

Weights : (3, 3, 3, 32), Biases : (1, 1, 1, 32)
Weights : (3, 3, 32, 16), Biases : (1, 1, 1, 16)
Weights : (1024, 64), Biases : (64,)
Weights : (64, 10), Biases : (10,)


In [None]:
# Make a prediction and verify output shape
preds = conv_apply(weights, X_train[:5])
preds.shape

(5, 10)

In [39]:
from typing import Callable

def CrossEntropyLoss(
    weights: list,
    input_data: jax.Array,
    targets: jax.Array,
) -> jax.Array:
    """Implement of cross entropy loss.

    Args:
        weights: list from _, _, opt_get_weights = optimizers.adam(lr), opt_get_weights(opt_state)
        input_data: data to predict
        targets: groundtruth targets in one hot encoding

    Returns:
        loss value
    """
    preds = conv_apply(weights, input_data)
    log_preds = jnp.log(preds + tf.keras.backend.epsilon())
    return -jnp.mean(targets * log_preds)

In [42]:
from jax import value_and_grad
from tqdm import tqdm

def TrainModelInBatches(
    X: jax.Array,
    Y: jax.Array,
    epochs: int,
    opt_state: jax.example_libraries.optimizers.OptimizerState,
    opt_update: Callable,
    opt_get_weights: Callable,
    batch_size: int,
) -> jax.example_libraries.optimizers.OptimizerState:
    """Train Jax model in batches.

    Args:
        X: training input
        Y: groundtruth in one hot encoding
        epochs: number of epochs
        opt_state: from opt_init(weights)
        opt_update: from _, opt_update, _ = optimizers.adam(lr)
        opt_get_weights: from _, _, opt_get_weights = optimizers.adam(lr)
        batch_size: batch size for training

    Returns:
        updated opt_state
    """

    for i in range(epochs):
        batches = jnp.arange((X.shape[0] // batch_size) + 1)
        progress_bar = tqdm(batches, position=0, leave=True)

        losses = []
        for batch in batches:
            if batch != batches[-1]:
                start, end = int(batch * batch_size), int(batch * batch_size + batch_size)
            else:
                start, end = int(batch * batch_size), None

            X_batch, Y_batch = X[start:end], Y[start:end]

            loss, gradients = value_and_grad(CrossEntropyLoss)(
                opt_get_weights(opt_state),
                X_batch,
                Y_batch,
            )

            ## Update Weights
            opt_state = opt_update(i, gradients, opt_state)

            losses.append(loss)

            progress_bar.set_description(f"Epoch {i+1}/{epochs}")
            progress_bar.set_postfix(train_loss=jnp.round(jnp.array(losses).mean(), decimals=3))
            progress_bar.update()

    return opt_state

In [43]:
seed = jax.random.PRNGKey(42)
learning_rate = jnp.array(1/1e4)
epochs = 2

batch_size=256

opt_init, opt_update, opt_get_weights = optimizers.adam(learning_rate)
opt_state = opt_init(weights)
one_hot_targets = jax.nn.one_hot(Y_train, num_classes=len(classes))


final_opt_state = TrainModelInBatches(
    X=X_train,
    Y=one_hot_targets,
    epochs=epochs,
    opt_state=opt_state,
    opt_update=opt_update,
    opt_get_weights=opt_get_weights,
    batch_size=batch_size,
)

  0%|          | 0/196 [11:01<?, ?it/s]
Epoch 1/2:   2%|▏         | 3/196 [05:14<5:36:48, 104.71s/it, train_loss=1.3820001]
Epoch 1/2: 100%|██████████| 196/196 [03:12<00:00,  1.02it/s, train_loss=0.95400006]
Epoch 2/2: 100%|██████████| 196/196 [02:56<00:00,  1.11it/s, train_loss=0.26000002]
