# Train a JAX/Flax model locally and use `jax2tf` to convert it to a SavedModel

Runs on a TensorFlow 2.5 [Vertex AI Notebook](https://cloud.google.com/vertex-ai/docs/general/notebooks).

The JAX-specific installation command (with the `-f` flag) comes from the [official repository's install instructions](https://github.com/google/jax#installation).

In [1]:
import logging
import os

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from absl import flags
from jax.experimental.jax2tf.examples.mnist_lib import load_mnist, FlaxMNIST
from jax.experimental.jax2tf.examples.saved_model_lib import (
    convert_and_save_model
)

In [2]:
# Root logger (getLogger() with no name).
logger = logging.getLogger()

# load_mnist fails unless absl flags have been parsed. Parse a placeholder
# argv that contains only a program name, so every flag keeps its default.
flags.FLAGS(['e'])

train_batch_size = 128
test_batch_size = 16

# Use a regional bucket you have rights to. The default can be overridden by
# setting the BUCKET_NAME environment variable before starting the kernel
# (an empty value falls back to the default).
# To create the bucket if needed, run in a code cell:
# REGION = "us-central1"
# !gsutil mb -l {REGION} gs://{BUCKET_NAME}
BUCKET_NAME = os.environ.get("BUCKET_NAME") or "dsparing-sandbox-bucket"

MODEL_NAME = "jax_model"
MODEL_DIR = f"gs://{BUCKET_NAME}/models/{MODEL_NAME}"

flax_mnist = FlaxMNIST()

In [3]:
# Build the train and test pipelines, each with its own batch size
# (train first, then test).
train_ds, test_ds = [
    load_mnist(split, batch_size)
    for split, batch_size in (
        (tfds.Split.TRAIN, train_batch_size),
        (tfds.Split.TEST, test_batch_size),
    )
]

In [4]:
# Show INFO-level training progress, then put the root logger back to the
# level it had before, even if training raises.
# Why not setLevel(logging.NOTSET) afterwards: on the *root* logger, NOTSET
# lets every level through instead of quieting it, so we restore the saved
# level instead.
previous_log_level = logger.level
logger.setLevel(logging.INFO)
try:
    # train() returns the prediction function and the trained params.
    predict_fn, params = flax_mnist.train(
        train_ds=train_ds,
        test_ds=test_ds,
        num_epochs=2,
    )
finally:
    logger.setLevel(previous_log_level)

INFO:absl:Starting the local TPU driver.
INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://
INFO:absl:Unable to initialize backend 'tpu': Invalid argument: TpuPlatform is not available.
INFO:root:mnist_flax: Epoch 0 in 5.09 sec
INFO:root:mnist_flax: Training set accuracy 88.48%
INFO:root:mnist_flax: Test set accuracy 89.25%
INFO:root:mnist_flax: Epoch 1 in 1.08 sec
INFO:root:mnist_flax: Training set accuracy 91.01%
INFO:root:mnist_flax: Test set accuracy 91.54%


In [5]:
# Derive the serving input spec from real data: take the first image of the
# first training batch and give it a leading batch axis of size 1.
image, _ = next(iter(train_ds))
batched_example = tf.expand_dims(image[0], axis=0)
input_signature = tf.TensorSpec.from_tensor(batched_example)

In [6]:
# Convert predict_fn (with its trained params) via jax2tf and write it as a
# SavedModel to MODEL_DIR, using the single input signature built above.
convert_and_save_model(
    params=params,
    jax_fn=predict_fn,
    input_signatures=[input_signature],
    model_dir=MODEL_DIR,
)





INFO:tensorflow:Assets written to: gs://dsparing-sandbox-bucket/models/jax_model/assets


INFO:tensorflow:Assets written to: gs://dsparing-sandbox-bucket/models/jax_model/assets


### Verify that both `predict_fn` and the SavedModel can make predictions

In [7]:
# A batch of exactly one test image, used to compare predict_fn with the SavedModel.
single_example_test_ds = load_mnist(tfds.Split.TEST, 1)
image_to_predict, _ = next(iter(single_example_test_ds))

In [8]:
# Predictions from the in-memory JAX model.
jax_predictions = predict_fn(params, image_to_predict)
jax_predictions

DeviceArray([[ -6.6423545 , -12.732087  ,  -5.5349283 ,  -7.608638  ,
               -4.5714254 ,  -2.2557116 ,  -2.1221883 , -12.284796  ,
               -0.27703872,  -6.638275  ]], dtype=float32)

In [9]:
# Reload the exported SavedModel and run its default serving signature on the
# same image. The signature returns a dict whose "output_0" entry holds the
# model's predictions.
loaded_model = tf.saved_model.load(MODEL_DIR)
saved_model_outputs = loaded_model.signatures["serving_default"](image_to_predict)

# Check numerically that the exported model reproduces the JAX predictions,
# rather than relying on comparing printed arrays by eye.
np.testing.assert_allclose(
    saved_model_outputs["output_0"].numpy(),
    np.asarray(predict_fn(params, image_to_predict)),
    rtol=1e-5,
    atol=1e-5,
)

saved_model_outputs

{'output_0': <tf.Tensor: shape=(1, 10), dtype=float32, numpy=
 array([[ -6.6423545 , -12.732087  ,  -5.5349283 ,  -7.608638  ,
          -4.5714254 ,  -2.2557116 ,  -2.1221883 , -12.284796  ,
          -0.27703872,  -6.638275  ]], dtype=float32)>}