### Setting up the environment


In [None]:
!pip install keras_core

In [None]:
import os

os.environ["KERAS_BACKEND"] = "jax"

import jax
import numpy as np
import tensorflow as tf
import keras_core as keras

from jax.experimental import mesh_utils
from jax.sharding import Mesh
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P


### Model Architecture

Define the model structure using Keras.

You can define any model here; it doesn't have to be a plain CNN.


In [None]:
def get_model():
    """Build a small convnet classifier for 28x28 single-channel images.

    Architecture: a Rescaling(1/255) input stage, three
    Conv2D -> BatchNormalization -> ReLU stages, global average pooling,
    a 256-unit ReLU dense layer with 50% dropout, and a final 10-unit dense
    layer with no activation (so the model outputs logits).
    """
    # Keyword arguments for each Conv2D stage, in the order they are applied.
    conv_stages = [
        dict(filters=12, kernel_size=3, padding="same", use_bias=False),
        dict(filters=24, kernel_size=6, use_bias=False, strides=2),
        dict(filters=32, kernel_size=6, padding="same", strides=2, name="large_k"),
    ]

    inputs = keras.Input(shape=(28, 28, 1))
    # Multiply raw pixel values by 1/255 inside the model itself.
    features = keras.layers.Rescaling(1.0 / 255.0)(inputs)
    for conv_kwargs in conv_stages:
        features = keras.layers.Conv2D(**conv_kwargs)(features)
        features = keras.layers.BatchNormalization(scale=False, center=True)(features)
        features = keras.layers.ReLU()(features)

    features = keras.layers.GlobalAveragePooling2D()(features)
    features = keras.layers.Dense(256, activation="relu")(features)
    features = keras.layers.Dropout(0.5)(features)
    logits = keras.layers.Dense(10)(features)
    return keras.Model(inputs, logits)

### Data Preprocessing

Loading and processing the MNIST dataset.


In [None]:
def get_datasets():
    """Load MNIST and wrap both splits as unbatched `tf.data` datasets.

    Each dataset element is an `(image, label)` pair. Images are cast to
    float32 and given a trailing channel axis, giving shape (28, 28, 1).
    Pixel values are deliberately NOT rescaled here. `get_model` begins with
    a `Rescaling(1.0 / 255.0)` layer, and rescaling in both places would
    divide the inputs by 255 twice.

    Returns:
        `(train_data, eval_data)` as `tf.data.Dataset` objects.
    """
    # Load the data and split it between train and test sets
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

    # Cast to float32 only; the value range is left unchanged on purpose
    # (the model's Rescaling layer maps it to [0, 1]).
    x_train = x_train.astype("float32")
    x_test = x_test.astype("float32")
    # Make sure images have shape (28, 28, 1)
    x_train = np.expand_dims(x_train, -1)
    x_test = np.expand_dims(x_test, -1)
    print("x_train shape:", x_train.shape)
    print(x_train.shape[0], "train samples")
    print(x_test.shape[0], "test samples")

    # Create TF Datasets
    train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    eval_data = tf.data.Dataset.from_tensor_slices((x_test, y_test))
    return train_data, eval_data

### Training Configuration


In [None]:
# Config
num_epochs = 20
# Global batch size (deliberately large). It is sharded along the batch axis
# across devices later on, so it should be divisible by the device count.
# NOTE(review): 8096 may have been intended as 8192 (a power of two); kept as-is.
batch_size = 8096 * 2

train_data, eval_data = get_datasets()
# drop_remainder=True keeps every batch the same shape, so the jitted
# train_step only needs to be traced/compiled once.
train_data = train_data.batch(batch_size, drop_remainder=True)

model = get_model()
optimizer = keras.optimizers.Adam(1e-3)
# The model's last Dense layer has no activation, so it emits logits.
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Make sure model and optimizer state exist before the stateless training
# code reads `model.*_variables` and `optimizer.variables`.
# `build` takes an input *shape*, so pass the batch shape rather than the
# image tensor itself (which would otherwise be recorded as the build config).
(one_batch, one_batch_labels) = next(iter(train_data))
model.build(tuple(one_batch.shape))
optimizer.build(model.trainable_variables)


### Loss & Gradient Computation

#### JAX computation is purely stateless

In JAX, everything must be a stateless function -- so our loss computation function must be stateless as well. That means that all Keras variables (e.g. weight tensors) must be passed as function inputs, and any variable that has been updated during the forward pass must be returned as function output. The function must have no side effects.

During the forward pass, the non-trainable variables of a Keras model might get updated. These variables could be, for instance, RNG seed state variables or BatchNormalization statistics. We're going to need to return those. So we need something like this:


In [None]:
# Loss function to be differentiated. It is pure: all model state comes in as
# arguments and any updated state goes out as a return value.

def compute_loss(trainable_variables, non_trainable_variables, x, y):
    """Run a stateless forward pass and score it.

    Uses `model.stateless_call`, which takes the variable values explicitly
    and returns the predictions together with the updated non-trainable
    variables instead of mutating the model.

    Returns:
        `(loss_value, updated_non_trainable_variables)`. The loss comes first,
        as required by `jax.value_and_grad(..., has_aux=True)`.
    """
    predictions, new_non_trainable = model.stateless_call(
        trainable_variables, non_trainable_variables, x
    )
    return loss(y, predictions), new_non_trainable


Once you have such a function, you can get the gradient function by passing `has_aux=True` to `jax.value_and_grad`: it tells JAX that the loss computation function returns more outputs than just the loss. Note that the loss must always be the first output.

In [None]:
# Gradient function: differentiates compute_loss with respect to its first
# argument (trainable_variables). has_aux=True means compute_loss returns
# (loss, aux); only the loss is differentiated, and the aux output (the
# updated non-trainable variables) is passed through unchanged.
compute_gradients = jax.value_and_grad(compute_loss, argnums=0, has_aux=True)

### Training step definition


By default, JAX operations run eagerly, just like in TensorFlow eager mode and PyTorch eager mode. And just like TensorFlow eager mode and PyTorch eager mode, it's pretty slow -- eager mode is better used as a debugging environment, not as a way to do any actual work. So let's make our `train_step` fast by compiling it.

When you have a stateless JAX function, you can compile it to XLA via the `@jax.jit` decorator. It will get traced during its first execution, and in subsequent executions you will be executing the traced graph (this is just like `@tf.function(jit_compile=True)`). Let's try it:

In [None]:
# Training step. Keras provides a pure functional optimizer update via
# optimizer.stateless_apply, so the whole step can be jit-compiled.
@jax.jit
def train_step(train_state, x, y):
    """Apply one optimizer update on the batch `(x, y)`.

    Args:
        train_state: `(trainable_variables, non_trainable_variables,
            optimizer_variables)`.
        x, y: Inputs and integer labels for one batch.

    Returns:
        `(loss_value, new_train_state)`, where `new_train_state` has the same
        three-part layout as `train_state`.
    """
    params, model_state, opt_state = train_state
    (loss_value, model_state), grads = compute_gradients(params, model_state, x, y)
    params, opt_state = optimizer.stateless_apply(opt_state, grads, params)
    new_train_state = (params, model_state, opt_state)
    return loss_value, new_train_state



### Replicate model and optimizer


In [None]:
# Replicate the model and optimizer variable on all devices
def get_replicated_train_state(devices):
    """Place a full copy of the model and optimizer state on every device.

    Args:
        devices: Device array to build the mesh from (in this notebook, the
            output of `mesh_utils.create_device_mesh((num_devices,))`).

    Returns:
        `(trainable_variables, non_trainable_variables, optimizer_variables)`,
        each replicated across `devices`. This is the `train_state` layout
        consumed by `train_step`.
    """
    # One-axis mesh; the axis name is arbitrary since nothing is sharded on it.
    # The trailing comma makes `axis_names` a one-element tuple. A bare
    # `("_")` is just the string "_" and only works by accident.
    var_mesh = Mesh(devices, axis_names=("_",))
    # In NamedSharding, axes not mentioned are replicated (all axes here).
    var_replication = NamedSharding(var_mesh, P())

    trainable_variables = jax.device_put(model.trainable_variables, var_replication)
    non_trainable_variables = jax.device_put(
        model.non_trainable_variables, var_replication
    )
    optimizer_variables = jax.device_put(optimizer.variables, var_replication)

    return (trainable_variables, non_trainable_variables, optimizer_variables)


In [None]:
# Discover the devices attached to this process and arrange them into a
# 1-D device mesh with a single axis of length num_devices.
local_devices = jax.local_devices()
num_devices = len(local_devices)
print(f"Running on {num_devices} devices: {local_devices}")
devices = mesh_utils.create_device_mesh((num_devices,))

In [None]:
# Data is split along its leading (batch) dimension across the mesh; the
# remaining dimensions are not partitioned.
data_mesh = Mesh(devices, axis_names=("batch",))
data_sharding = NamedSharding(data_mesh, P("batch"))

# Place one batch of images with that sharding and visualize the split.
# visualize_array_sharding renders 1-D/2-D arrays, so each image is flattened
# into a single row first.
x, y = next(iter(train_data))
sharded_x = jax.device_put(x.numpy(), data_sharding)
print("Data sharding")
flat_view = jax.numpy.reshape(sharded_x, [-1, 28 * 28])
jax.debug.visualize_array_sharding(flat_view)


### Training Loop


The training state consists of the following parts:

``` python
trainable_variables = model.trainable_variables
non_trainable_variables = model.non_trainable_variables
optimizer_variables = optimizer.variables
state = trainable_variables, non_trainable_variables, optimizer_variables
```

In [None]:
# Initial training state: model and optimizer variables replicated on every device.
train_state = get_replicated_train_state(devices=devices)

In [None]:
import time


def _run_epoch(state):
    """Run `train_step` once over every batch of `train_data`.

    Args:
        state: The `(trainable, non_trainable, optimizer)` variable tuple.

    Returns:
        `(last_loss, new_state, elapsed_seconds)`. `last_loss` is the loss of
        the final batch, or None if `train_data` yielded no batches.
    """
    last_loss = None
    start = time.perf_counter()
    for batch_x, batch_y in train_data:
        # Split the images across devices along the batch axis.
        sharded_batch_x = jax.device_put(batch_x.numpy(), data_sharding)
        last_loss, state = train_step(state, sharded_batch_x, batch_y.numpy())
    # JAX dispatches work asynchronously. Wait for the last step to finish,
    # otherwise its runtime would be billed to whatever runs next.
    jax.block_until_ready((last_loss, state))
    elapsed = time.perf_counter() - start
    return last_loss, state, elapsed


print(f"Batch Size: {batch_size}")

# Warm-up epoch: the first call traces and compiles train_step. It is still a
# real training epoch (the state is updated), so the model ends up trained
# for num_epochs + 1 passes in total.
loss_value, train_state, warmup_time = _run_epoch(train_state)
print("Warmup time:", warmup_time)

# Training loop
total_time = 0.0
for epoch in range(num_epochs):
    loss_value, train_state, epoch_time = _run_epoch(train_state)
    total_time += epoch_time
    print("Epoch", epoch, "loss:", loss_value)

# Guard against num_epochs == 0 (only the warm-up epoch ran).
average_time_per_epoch = total_time / num_epochs if num_epochs else 0.0
total_time_excl_warmup = total_time

print("Average time per epoch:", average_time_per_epoch)
print("Total time excluding warm-up:", total_time_excl_warmup)


### Post-Processing

A key thing to notice here is that the loop is entirely stateless -- the variables attached to the model (`model.weights`) are never updated during the loop. Their new values are only stored in the state tuple. That means that at some point, before saving the model, you should attach the new variable values back to the model.

Just call `variable.assign(new_value)` on each model variable you want to update:


In [None]:
# Copy the trained values held in `train_state` back onto the Keras model so
# the model object (and anything saved from it) reflects training.
# The optimizer variables are unpacked but not written back.
trainable_variables, non_trainable_variables, optimizer_variables = train_state
for model_vars, new_values in (
    (model.trainable_variables, trainable_variables),
    (model.non_trainable_variables, non_trainable_variables),
):
    for variable, value in zip(model_vars, new_values):
        variable.assign(value)

### Saving the hard work

In [None]:
# Serialize the model with its current weights in the native `.keras` format.
model_path = "model.keras"
model.save(model_path)

In [None]:
!ls