This notebook runs on a TensorFlow 2.5 [Vertex AI Notebooks](https://cloud.google.com/vertex-ai/docs/general/notebooks) instance.

In [None]:
%%bash
# Install jax plus a CUDA 11.0 build of jaxlib (0.1.67) from the JAX release index
# given by -f, then install flax.
# --user installs into the user site-packages; you will likely need to restart the
# kernel afterwards so later imports pick up the newly installed versions.
# NOTE(review): `jax` and `flax` are unpinned, so `--upgrade` pulls their latest
# releases, which may require a newer jaxlib than 0.1.67 — pin mutually compatible
# versions for a reproducible setup.
pip3 install --user --upgrade jax jaxlib==0.1.67+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html
pip3 install --user --upgrade flax

The pinned CUDA build of `jaxlib` and the `-f` flag come from the [official JAX installation instructions](https://github.com/google/jax#installation). Restart the kernel after installing so the new packages are picked up.

# Train a JAX/Flax model locally and use `jax2tf` to convert it to a TensorFlow SavedModel

In [2]:
import logging
import os

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

from absl import flags
from jax.experimental.jax2tf.examples.mnist_lib import load_mnist, FlaxMNIST
from jax.experimental.jax2tf.examples.saved_model_lib import convert_and_save_model

In [3]:
# load_mnist errors out unless absl flags have been parsed. A notebook has no real
# command line, so parse a dummy argv holding only a program name; every flag keeps
# its default value.
flags.FLAGS(["jax2tf_mnist_notebook"])

train_batch_size = 128
test_batch_size = 16

# GCS bucket the SavedModel is exported to. Use a regional bucket you have write
# access to; set the BUCKET_NAME environment variable to override the default.
BUCKET_NAME = os.environ.get("BUCKET_NAME", "dsparing-sandbox-bucket")
REGION = "us-central1"
# Create the bucket if needed (IPython expands {REGION} and {BUCKET_NAME}):
# !gsutil mb -l {REGION} gs://{BUCKET_NAME}

MODEL_NAME = "jax_model"
MODEL_DIR = f"gs://{BUCKET_NAME}/models/{MODEL_NAME}"

flax_mnist = FlaxMNIST()

In [4]:
# Build the MNIST input pipelines — training first, then test — using the
# batch sizes configured above.
train_ds, test_ds = (
    load_mnist(split, batch_size)
    for split, batch_size in [
        (tfds.Split.TRAIN, train_batch_size),
        (tfds.Split.TEST, test_batch_size),
    ]
)

In [5]:
# Raise the root logger to INFO so the per-epoch timing/accuracy lines logged
# during training appear in the cell output.
logger = logging.getLogger()
logger.setLevel("INFO")

# Train for 2 epochs; yields the prediction function and the trained parameters
# that the conversion and prediction cells below use.
predict_fn, params = flax_mnist.train(train_ds=train_ds, test_ds=test_ds, num_epochs=2)

INFO:absl:Starting the local TPU driver.
INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://
INFO:absl:Unable to initialize backend 'tpu': Invalid argument: TpuPlatform is not available.
INFO:root:mnist_flax: Epoch 0 in 6.14 sec
INFO:root:mnist_flax: Training set accuracy 88.43%
INFO:root:mnist_flax: Test set accuracy 89.03%
INFO:root:mnist_flax: Epoch 1 in 1.11 sec
INFO:root:mnist_flax: Training set accuracy 91.01%
INFO:root:mnist_flax: Test set accuracy 91.37%


In [6]:
# Derive the serving input signature from one training example: take the first
# image of a batch and re-add a leading batch axis of size 1.
# NOTE(review): the exported signature's batch dimension is therefore fixed to 1.
image, _ = next(iter(train_ds))
single_image_batch = tf.expand_dims(image[0], axis=0)
input_signature = tf.TensorSpec.from_tensor(single_image_batch)

In [7]:
# Convert predict_fn to TensorFlow via jax2tf and export it, together with the
# trained params, as a SavedModel under MODEL_DIR. The single input signature
# determines the serving input's shape and dtype.
convert_and_save_model(
    model_dir=MODEL_DIR,
    jax_fn=predict_fn,
    params=params,
    input_signatures=[input_signature],
)





INFO:tensorflow:Assets written to: gs://dsparing-sandbox-bucket/models/jax_model/assets


INFO:tensorflow:Assets written to: gs://dsparing-sandbox-bucket/models/jax_model/assets


### Verify that both `predict_fn` and the exported SavedModel produce predictions, and that they agree

In [8]:
# Pull a single test example (a batch of size 1) to run through both predictors.
single_example_test_ds = load_mnist(tfds.Split.TEST, 1)
image_to_predict, _ = next(iter(single_example_test_ds))

In [9]:
# Logits from the original JAX prediction function.
jax_logits = predict_fn(params, image_to_predict)
jax_logits

DeviceArray([[ -9.073951  ,  -2.959507  ,  -0.18073097,  -2.9901378 ,
              -12.367456  ,  -5.4054875 ,  -6.7608185 ,  -9.742961  ,
               -2.8584464 , -10.4303465 ]], dtype=float32)

In [10]:
# Reload the exported SavedModel and run its default serving signature on the
# same image that predict_fn is evaluated on above.
loaded_model = tf.saved_model.load(MODEL_DIR)
saved_model_outputs = loaded_model.signatures["serving_default"](image_to_predict)

# Assert (rather than eyeball) that the SavedModel reproduces the JAX logits.
# The small tolerance allows for minor numeric differences between the JAX and
# TF runtimes.
np.testing.assert_allclose(
    saved_model_outputs["output_0"],
    predict_fn(params, image_to_predict),
    rtol=1e-4,
    atol=1e-4,
)
saved_model_outputs

{'output_0': <tf.Tensor: shape=(1, 10), dtype=float32, numpy=
 array([[ -9.073951  ,  -2.959507  ,  -0.18073097,  -2.9901378 ,
         -12.367456  ,  -5.4054875 ,  -6.7608185 ,  -9.742961  ,
          -2.8584464 , -10.4303465 ]], dtype=float32)>}