# Train JAX/Flax model locally and use `jax2tf` to convert to SavedModel

In [1]:
import logging
import tensorflow as tf
import tensorflow_datasets as tfds

from absl import flags
from jax.experimental.jax2tf.examples.mnist_lib import load_mnist, FlaxMNIST
from jax.experimental.jax2tf.examples.saved_model_lib import (
    convert_and_save_model
)

In [2]:
# absl flags have to be initialized up front to avoid errors in load_mnist;
# a single placeholder argument is enough.
flags.FLAGS(['e'])

# Root logger; its level is adjusted around the training cell.
logger = logging.getLogger()

# --- Export destination -----------------------------------------------------
# Use a regional bucket you have rights to.
# Create if needed:
# REGION=us-central1
# !gsutil mb -l ${REGION} gs://${BUCKET}
BUCKET_NAME = "dsparing-sandbox-bucket"
MODEL_NAME = "jax_model_local"
SAVEDMODEL_DIR = f"gs://{BUCKET_NAME}/models/{MODEL_NAME}/output"

# --- Input pipeline ---------------------------------------------------------
# Batch sizes passed to load_mnist for the train and test splits.
test_batch_size = 16
train_batch_size = 128

# --- Model ------------------------------------------------------------------
flax_mnist = FlaxMNIST()

In [3]:
# Load the MNIST train and test splits, each with its own batch size.
train_split, test_split = tfds.Split.TRAIN, tfds.Split.TEST
train_ds = load_mnist(train_split, train_batch_size)
test_ds = load_mnist(test_split, test_batch_size)

In [4]:
# Train the Flax MNIST model, showing per-epoch progress (emitted at INFO level)
# only while training runs.
#
# `logger` is the root logger. Setting it to NOTSET afterwards would not "reset"
# it: on the root logger NOTSET means every record (including DEBUG) is
# processed. Instead, remember the level that was active before and restore it,
# even if training raises or is interrupted.
previous_log_level = logger.level
logger.setLevel(logging.INFO)
try:
    # Produces the prediction function and the trained parameters used by the
    # export and prediction cells.
    predict_fn, params = flax_mnist.train(
        train_ds=train_ds,
        test_ds=test_ds,
        num_epochs=10,
    )
finally:
    logger.setLevel(previous_log_level)

INFO:absl:Starting the local TPU driver.
INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://
INFO:absl:Unable to initialize backend 'tpu': Invalid argument: TpuPlatform is not available.
INFO:root:mnist_flax: Epoch 0 in 5.34 sec
INFO:root:mnist_flax: Training set accuracy 88.38%
INFO:root:mnist_flax: Test set accuracy 88.95%
INFO:root:mnist_flax: Epoch 1 in 1.09 sec
INFO:root:mnist_flax: Training set accuracy 90.72%
INFO:root:mnist_flax: Test set accuracy 91.27%
INFO:root:mnist_flax: Epoch 2 in 1.05 sec
INFO:root:mnist_flax: Training set accuracy 92.14%
INFO:root:mnist_flax: Test set accuracy 92.68%
INFO:root:mnist_flax: Epoch 3 in 1.05 sec
INFO:root:mnist_flax: Training set accuracy 93.37%
INFO:root:mnist_flax: Test set accuracy 93.66%
INFO:root:mnist_flax: Epoch 4 in 1.07 sec
INFO:root:mnist_flax: Training set accuracy 94.10%
INFO:root:mnist_flax: Test set accuracy 94.28%
INFO:root:mnist_flax: Epoch 5 in 1.05 sec
IN

In [5]:
# Take the first training batch and derive the serving input spec from one
# example, with a leading axis of size 1 re-added via expand_dims.
first_batch = next(iter(train_ds))
image, _ = first_batch
single_example = tf.expand_dims(image[0], axis=0)
input_signature = tf.TensorSpec.from_tensor(single_example)

In [6]:
# Convert the trained JAX prediction function and its parameters and write the
# result to SAVEDMODEL_DIR, using the single input signature built above.
serving_signatures = [input_signature]
convert_and_save_model(
    jax_fn=predict_fn,
    params=params,
    model_dir=SAVEDMODEL_DIR,
    input_signatures=serving_signatures,
)





INFO:tensorflow:Assets written to: gs://dsparing-sandbox-bucket/models/jax_model_local/output/assets


INFO:tensorflow:Assets written to: gs://dsparing-sandbox-bucket/models/jax_model_local/output/assets


### Make sure we can actually predict with both `predict_fn` and the SavedModel

In [7]:
# Fetch one test example (batch size 1) to feed to both prediction paths.
single_example_ds = load_mnist(tfds.Split.TEST, 1)
image_to_predict, _ = next(iter(single_example_ds))

In [8]:
# Prediction straight from the JAX function; displayed as the cell output.
jax_prediction = predict_fn(params, image_to_predict)
jax_prediction

DeviceArray([[-9.8502626e+00, -2.1723980e+01, -1.0447656e+01,
              -1.3074039e+01, -1.1490358e-02, -5.5715027e+00,
              -1.0479521e+01, -7.7336907e+00, -8.7529469e+00,
              -4.9744759e+00]], dtype=float32)

In [9]:
# Reload the exported model from SAVEDMODEL_DIR and run the same example through
# its default serving signature; the result is displayed as the cell output.
loaded_model = tf.saved_model.load(SAVEDMODEL_DIR)
serving_fn = loaded_model.signatures["serving_default"]
serving_fn(image_to_predict)

{'output_0': <tf.Tensor: shape=(1, 10), dtype=float32, numpy=
 array([[-9.8502626e+00, -2.1723980e+01, -1.0447656e+01, -1.3074039e+01,
         -1.1490475e-02, -5.5715027e+00, -1.0479521e+01, -7.7336907e+00,
         -8.7529469e+00, -4.9744759e+00]], dtype=float32)>}