# Demo Notebook: Train, Shard, and Export an MNIST Model with Orbax Export


In [1]:
import os
import re
import jax
import absl.flags as flags

In [2]:
# Force JAX onto the CPU with `num_cpu_devices` virtual devices for this demo.
# This cell must run before any JAX computation: XLA_FLAGS is only honored when
# the backend is first initialized.
num_cpu_devices = 4
# Strip any pre-existing device-count override, keeping the user's other flags.
xla_flags = re.sub(
    r'--xla_force_host_platform_device_count=\S+',
    '',
    os.getenv('XLA_FLAGS', ''),
).split()
os.environ['XLA_FLAGS'] = ' '.join(
    [f'--xla_force_host_platform_device_count={num_cpu_devices}', *xla_flags]
)
jax.config.update('jax_platforms', 'cpu')
# NOTE(review): presumably lets JAX start on hosts that have TPUs attached but
# unused; this flag may only exist in some JAX builds — confirm before relying on it.
flags.FLAGS.jax_allow_unused_tpus = True
jax.devices()  # expected to list `num_cpu_devices` CPU devices

## Create and Train an MNIST Model

In [3]:
from flax import linen as nn
from flax.training import train_state
import jax.numpy as jnp
import ml_collections
import numpy as np
import optax
import tensorflow_datasets as tfds

In [4]:
# Load the MNIST train and test datasets into memory.
def get_datasets():
  """Loads the full MNIST train and test splits into memory.

  Returns:
    A `(train_ds, test_ds)` pair of dicts produced by `tfds.as_numpy`, where the
    'image' entry has been converted to float32 and divided by 255.
  """
  builder = tfds.builder('mnist')
  builder.download_and_prepare()

  def _load_split(split):
    # batch_size=-1 materializes the whole split as a single batch.
    split_ds = tfds.as_numpy(builder.as_dataset(split=split, batch_size=-1))
    split_ds['image'] = jnp.float32(split_ds['image']) / 255.0
    return split_ds

  return _load_split('train'), _load_split('test')

In [5]:
# Define the model.

class Mnist(nn.Module):
  """A small CNN classifier for 28x28 single-channel MNIST digits.

  Two conv -> ReLU -> 2x2 average-pool stages (32, then 64 filters) feed a
  256-unit hidden dense layer and a 10-way output layer. The model returns raw
  logits (no softmax is applied).
  """

  @nn.compact
  def __call__(self, x):
    # Flax auto-names submodules in creation order (Conv_0, Conv_1, Dense_0,
    # Dense_1), so this loop produces the same parameter tree as writing the
    # two convolutional stages out by hand.
    for num_filters in (32, 64):
      x = nn.Conv(features=num_filters, kernel_size=(3, 3))(x)
      x = nn.relu(x)
      x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten all but the batch axis
    hidden = nn.relu(nn.Dense(features=256)(x))
    return nn.Dense(features=10)(hidden)

# Define train step.

@jax.jit
def apply_model(state, images, labels):
  """Runs one forward/backward pass over a batch.

  Args:
    state: `TrainState` whose `apply_fn` and `params` define the model.
    images: Batch of input images.
    labels: Integer class labels for `images`.

  Returns:
    `(grads, loss, accuracy)`: gradients of the mean softmax cross-entropy with
    respect to `state.params`, that mean loss, and the batch accuracy.
  """

  def compute_loss(params):
    logits = state.apply_fn({'params': params}, images)
    targets = jax.nn.one_hot(labels, 10)  # 10 digit classes
    per_example = optax.softmax_cross_entropy(logits=logits, labels=targets)
    return jnp.mean(per_example), logits

  loss_and_grad = jax.value_and_grad(compute_loss, has_aux=True)
  (loss, logits), grads = loss_and_grad(state.params)
  predictions = jnp.argmax(logits, -1)
  accuracy = jnp.mean(predictions == labels)
  return grads, loss, accuracy


@jax.jit
def update_model(state, grads):
  """Applies `grads` via the state's optimizer and returns the updated state."""
  new_state = state.apply_gradients(grads=grads)
  return new_state

# Create train state.

def create_train_state(rng, config):
  """Initializes an `Mnist` model and wraps it in a fresh `TrainState`.

  Args:
    rng: PRNG key used to initialize the parameters.
    config: Config providing `learning_rate` and `momentum` for SGD.

  Returns:
    A `train_state.TrainState` holding the initial parameters and optimizer.
  """
  model = Mnist()
  # One dummy 28x28x1 image is enough for Flax to infer every parameter shape.
  dummy_batch = jnp.ones([1, 28, 28, 1])
  initial_params = model.init(rng, dummy_batch)['params']
  optimizer = optax.sgd(config.learning_rate, config.momentum)
  return train_state.TrainState.create(
      apply_fn=model.apply, params=initial_params, tx=optimizer
  )

# Define train loop.

def train_epoch(state, train_ds, batch_size, rng):
  """Trains for a single epoch over shuffled mini-batches.

  The examples are shuffled with `rng`. Trailing examples that do not fill a
  complete batch are skipped for this epoch.

  Args:
    state: Current `TrainState`.
    train_ds: Dict with 'image' and 'label' arrays indexed by example.
    batch_size: Examples per step; must be in [1, len(train_ds['image'])].
    rng: JAX PRNG key used to shuffle the examples.

  Returns:
    `(state, train_loss, train_accuracy)`: the updated state, plus the mean loss
    and mean accuracy over the epoch's batches.

  Raises:
    ValueError: If `batch_size` is not positive or exceeds the dataset size.
      Such a size would otherwise yield zero steps and NaN metrics, or
      divide by zero.
  """
  train_ds_size = len(train_ds['image'])
  if not 0 < batch_size <= train_ds_size:
    raise ValueError(
        f'batch_size must be in [1, {train_ds_size}], got {batch_size}.'
    )
  steps_per_epoch = train_ds_size // batch_size

  perms = jax.random.permutation(rng, train_ds_size)
  perms = perms[: steps_per_epoch * batch_size]  # skip incomplete batch
  perms = perms.reshape((steps_per_epoch, batch_size))

  epoch_loss = []
  epoch_accuracy = []

  for perm in perms:
    batch_images = train_ds['image'][perm, ...]
    batch_labels = train_ds['label'][perm, ...]
    grads, loss, accuracy = apply_model(state, batch_images, batch_labels)
    state = update_model(state, grads)
    epoch_loss.append(loss)
    epoch_accuracy.append(accuracy)
  train_loss = np.mean(epoch_loss)
  train_accuracy = np.mean(epoch_accuracy)
  return state, train_loss, train_accuracy

# Define model training and evaluation loop.
def train_and_evaluate(
    config: ml_collections.ConfigDict
) -> train_state.TrainState:
  """Trains the model and reports test metrics after every epoch.

  Args:
    config: Hyperparameters; reads `learning_rate`, `momentum`, `batch_size`
      and `num_epochs`.

  Returns:
    The final train state (its `.params` holds the trained weights).
  """
  train_ds, test_ds = get_datasets()
  rng, init_rng = jax.random.split(jax.random.key(0))
  state = create_train_state(init_rng, config)

  for epoch in range(1, config.num_epochs + 1):
    rng, input_rng = jax.random.split(rng)
    state, train_loss, train_accuracy = train_epoch(
        state, train_ds, config.batch_size, input_rng
    )
    # Score the whole test set in one batch. apply_model also returns gradients;
    # they are discarded here.
    _, test_loss, test_accuracy = apply_model(
        state, test_ds['image'], test_ds['label']
    )
    metrics = (
        epoch,
        train_loss,
        train_accuracy * 100,
        test_loss,
        test_accuracy * 100,
    )
    print(
        'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f,'
        ' test_accuracy: %.2f'
        % metrics
    )
  return state

In [6]:
# Create the configuration of hyperparameters, feel free to tune them.
def get_config():
  """Returns the default training hyperparameters as a `ConfigDict`."""
  return ml_collections.ConfigDict({
      'learning_rate': 0.1,
      'momentum': 0.9,
      'batch_size': 128,
      'num_epochs': 1,
  })



In [None]:
# Run model training and evaluation. `state.params` holds the trained weights,
# which the sharding, checkpointing and export cells below reuse.
state = train_and_evaluate(get_config())

# Export the Model (new!)


In [7]:
import frozendict
from jax.experimental import mesh_utils

# Build a 2-D device mesh with a 'data' axis and a 'model' axis (2 devices
# each). The 2 x 2 shape requires the 4 CPU devices forced at the top of the
# notebook.
ici_mesh = frozendict.frozendict(data=2, model=2)
devices = mesh_utils.create_device_mesh(tuple(ici_mesh.values()))
mesh = jax.sharding.Mesh(devices, axis_names=tuple(ici_mesh.keys()))
mesh

## Shard the Model

In [8]:
# Create the parameter sharding spec.
import jax.sharding as jsharding

def create_params_sharding_spec(mesh, p, axis_name='model'):
  """Builds a `NamedSharding` that splits a parameter's last axis over a mesh axis.

  Args:
    mesh: The `jax.sharding.Mesh` to place the parameter on.
    p: A parameter (anything exposing `.shape`, e.g. an array or
      `jax.ShapeDtypeStruct`).
    axis_name: Mesh axis to shard the last dimension over. Defaults to 'model'.

  Returns:
    A `NamedSharding` with spec `(None, ..., None, axis_name)` when the last
    dimension divides evenly by the mesh axis size. Otherwise, including for
    scalars, it returns a fully replicated sharding, because `jax.jit` rejects
    input shardings that do not evenly divide the array dimension.
  """
  ndim = len(p.shape)
  axis_size = mesh.shape[axis_name]
  if ndim > 0 and p.shape[-1] % axis_size == 0:
    specs = (None,) * (ndim - 1) + (axis_name,)
  else:
    specs = ()
  return jsharding.NamedSharding(mesh, jsharding.PartitionSpec(*specs))

# One sharding per parameter leaf, in the same {'params': ...} layout that
# model.apply receives as its variables.
params_sharding_spec = jax.tree_util.tree_map(
    lambda leaf: create_params_sharding_spec(mesh, leaf),
    {'params': state.params},
)


In [9]:
# Display the sharding assigned to each parameter leaf.
params_sharding_spec

In [10]:
# Shard the input tensor in two dimensions at once: dimension 0 (batch) over
# 'data' and dimension 1 (image height for NHWC inputs) over 'model'.
# Sharding an input in two dimensions like this is not supported by DTensor.
inputs_sharding_spec = jax.sharding.NamedSharding(
    mesh,
    jax.sharding.PartitionSpec('data', 'model', None, None),
)


In [11]:
# Display the sharding applied to model inputs.
inputs_sharding_spec

In [12]:
# A fresh model instance; its parameters are supplied at call time via the
# variables dict, so no trained state lives on the module itself.
model = Mnist()

# Compile model.apply with explicit shardings. Parameters and inputs use the
# specs defined above, and the output is replicated (no axis is partitioned).
replicated_output = jax.sharding.NamedSharding(mesh, jsharding.PartitionSpec(None))
model_apply_fn = jax.jit(
    model.apply,
    in_shardings=(params_sharding_spec, inputs_sharding_spec),
    out_shardings=replicated_output,
)
del replicated_output

In [13]:
# Start from a clean export directory. Unlike `!rm -r`, this is silent when
# /tmp/mnist does not exist yet, e.g. on the first Restart & Run All.
import shutil
shutil.rmtree('/tmp/mnist', ignore_errors=True)

## Write the checkpoint to disk.

In [14]:
# Write the checkpoint.
import orbax.checkpoint as ocp
# Save the trained parameters with Orbax under /tmp/mnist/ckpt. The tree layout
# {'params': ...} matches the variables dict passed to model.apply.
# NOTE(review): the destination presumably must not already exist, which is
# likely why /tmp/mnist is cleared in an earlier cell. Confirm against the
# Orbax Checkpointer docs.
ckpter = ocp.Checkpointer(ocp.StandardCheckpointHandler())
ckpter.save("/tmp/mnist/ckpt", {'params': state.params})


In [15]:
# List the files written for the checkpoint.
!ls /tmp/mnist/ckpt

## Start using Orbax Export!

In [16]:
from orbax.export import constants
from orbax.export import jax_module
from orbax.export import export_manager
from orbax.export import serving_config as osc

In [17]:
# Abstract spec of the parameters: the same {'params': ...} tree structure, but
# each leaf carries only its shape and dtype (no values).
params_arg_spec = jax.tree_util.tree_map(
    lambda leaf: jax.ShapeDtypeStruct(leaf.shape, leaf.dtype),
    {'params': state.params},
)


In [18]:
# Wrap the sharded apply function in an Orbax Export JAX module. The function
# name doubles as the serving signature key below.
model_function_name = 'mnist_forward_fn'

orbax_module = jax_module.JaxModule(
    params=params_arg_spec,
    apply_fn={model_function_name: model_apply_fn},
    # ORBAX_MODEL is a newer export version option.
    export_version=constants.ExportModelType.ORBAX_MODEL,
    # NOTE(review): "ckpt" is presumably resolved relative to the export
    # directory passed to em.save (/tmp/mnist), where the checkpoint was
    # written as /tmp/mnist/ckpt. Confirm against the Orbax Export docs.
    jax2obm_kwargs={constants.CHECKPOINT_PATH: "ckpt"},
)

## Define pre- and post-processing functions.

In [19]:
# Define TF pre- and post-processing functions for serving.
import tensorflow as tf

# The data preprocessing function for resizing images
def process_image(x: tf.Tensor) -> tf.Tensor:
  """Turns a batch of RGB images into 28x28 grayscale model inputs.

  Args:
    x: uint8 tensor of shape (batch, height, width, 3).

  Returns:
    A float32 tensor of shape (batch, 28, 28, 1), divided by 255 to match the
    scaling applied to the training images.
  """
  return tf.image.resize(tf.image.rgb_to_grayscale(x), [28, 28]) / 255.0



In [20]:
# Define the spec of input arg (e.g., for the TF Preprocessing function):
# a batch of exactly 100 uint8 RGB images of any height and width.
# NOTE(review): the model input's batch dimension is sharded over the 'data'
# mesh axis (size 2), so this batch size presumably must stay divisible by it.
input_args_spec = [tf.TensorSpec((100, None, None, 3), tf.uint8)] # Set the batch size to 100.


In [21]:
# The post-processing function for selecting the most probable class.
def select_digit(x: tf.Tensor) -> tf.Tensor:
  """Returns the index of the highest-scoring class in each row of `x`."""
  predicted_class = tf.argmax(x, axis=1)
  return predicted_class

## Export the model.


In [22]:
# Serving config: raw uint8 images go through the TF preprocessor, then the
# exported model function, then the TF postprocessor that picks the digit.
serving_config = osc.ServingConfig(
    signature_key=model_function_name,
    input_signature=input_args_spec,
    tf_preprocessor=process_image,
    tf_postprocessor=select_digit,
)

In [23]:
# The export manager combines the JAX module with its serving configs.
em = export_manager.ExportManager(
    module=orbax_module,
    serving_configs=[serving_config],
)

In [24]:
# Write the exported model to disk under /tmp/mnist.
em.save("/tmp/mnist")

In [25]:
# List the contents of the export directory.
!ls /tmp/mnist

## Load the Model Using Orbax Model Runner
Note: the runner exposes a Python API, but it is implemented in C++ under the hood.

In [26]:
# Load the model using Orbax Model Runner (new!)
from .learning.infra.mira.experimental.orbax_model.python import orbax_model_runner

In [27]:
# Runner is in C++ with a Python API. Load the model exported to /tmp/mnist
# and display the runner object.
runner = orbax_model_runner.ModelRunner(model_path="/tmp/mnist")
runner

In [28]:
# Devices visible to JAX in this Python process, for comparison with the
# runner's device count.
jax.devices()

In [29]:
# Report the platform and number of devices the runner will execute on.
print(f"Stable HLO will run on {runner.ifrt_platform_name()} platform with {runner.ifrt_device_count()} devices.")

## Run Inference (uncomment all the code below for interactive drawing and model prediction)

In [30]:
# # Create prediction data.

# from IPython.display import HTML, display
# from google.colab.output import eval_js
# from base64 import b64decode


# canvas_html = """
# <canvas width=%d height=%d></canvas>
# <button>Finish</button>
# <script>
# var canvas = document.querySelector('canvas')
# var ctx = canvas.getContext('2d')
# ctx.lineWidth = %d
# ctx.strokeStyle = 'rgba(255,255,0,1)';
# ctx.fillStyle = 'rgba(0,0,0,0.9)';
# ctx.fillRect(0, 0, ctx.canvas.width, ctx.canvas.height)
# var button = document.querySelector('button')
# var mouse = {x: 0, y: 0}
# canvas.addEventListener('mousemove', function(e) {
#   mouse.x = e.pageX - this.offsetLeft
#   mouse.y = e.pageY - this.offsetTop
# })
# canvas.onmousedown = ()=>{
#   ctx.beginPath()
#   ctx.moveTo(mouse.x, mouse.y)
#   canvas.addEventListener('mousemove', onPaint)
# }
# canvas.onmouseup = ()=>{
#   canvas.removeEventListener('mousemove', onPaint)
# }
# var onPaint = ()=>{
#   ctx.lineTo(mouse.x, mouse.y)
#   ctx.stroke()
# }
# var data = new Promise(resolve=>{
#   button.onclick = ()=>{
#     resolve(canvas.toDataURL('image/png'))
#   }
# })
# </script>
# """

# def draw(w=300, h=300, line_width=20):
#   display(HTML(canvas_html % (w, h, line_width)))
#   data = eval_js("data")
#   binary = b64decode(data.split(',')[1])
#   return binary

# image_data = draw()


In [31]:
# image_data[:100]

In [32]:
# original_image = tf.image.decode_image(image_data, 3)
# original_image

In [33]:
# # Run the model
# def create_model_inputs(image, batch_size):
#   inputs = [np.asarray(image) for _ in range(batch_size)]
#   batched_inputs = np.stack(inputs, axis=0)
#   return batched_inputs

# predicted_labels = runner.run(create_model_inputs(original_image, batch_size=100))

In [34]:
# # Plot result!

# from matplotlib import pyplot as plt
# import numpy as np
# from PIL import Image

# def plot_image(image, title=""):
#   """Plots images from image tensors.

#   Args:
#     image: 3D image tensor. [height, width, channels].
#     title: Title to display in the plot.
#   """
#   image = np.asarray(image)
#   image = tf.clip_by_value(image, 0, 255)
#   image = Image.fromarray(tf.cast(image, tf.uint8).numpy())
#   plt.imshow(image)
#   plt.axis("off")
#   plt.title(title)

# plot_image(original_image, title=f"predicted_label={predicted_labels[0]}")