In [1]:
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Train JAX/Flax model locally and use `jax2tf` to convert to SavedModel

In [2]:
import logging
import os

import tensorflow as tf
import tensorflow_datasets as tfds
from absl import flags
from jax.experimental.jax2tf.examples import mnist_lib
from jax.experimental.jax2tf.examples import saved_model_lib

In [3]:
TRAIN_BATCH_SIZE = 128
TEST_BATCH_SIZE = 16
NUM_EPOCHS = 2

PROJECT_ID = !(gcloud config get-value project)
PROJECT_ID = PROJECT_ID[0]

REGION = "us-central1"

BUCKET_NAME = PROJECT_ID
# Use a regional bucket in the above region you have rights to.
# Create if needed:
# !gsutil mb -l $REGION gs://$BUCKET_NAME

BASE_OUTPUT_DIR = f"gs://{BUCKET_NAME}"
MODEL_NAME = "jax_model_local"
MODEL_VERSION = 1

# Block TF from the GPU to let JAX use it all
tf.config.set_visible_devices([], 'GPU')

logger = logging.getLogger()

# need to initialize flags somehow to avoid errors in load_mnist
flags.FLAGS([''])

flax_mnist = mnist_lib.FlaxMNIST()

In [4]:
# Load the batched MNIST train and test splits.
train_ds = mnist_lib.load_mnist(tfds.Split.TRAIN, batch_size=TRAIN_BATCH_SIZE)
test_ds = mnist_lib.load_mnist(tfds.Split.TEST, batch_size=TEST_BATCH_SIZE)

In [5]:
# Take one training batch and use its first example, re-wrapped with a
# leading dimension of size 1, to derive the input signature for export.
image, _ = next(iter(train_ds))
single_example = tf.expand_dims(image[0], axis=0)
input_signature = tf.TensorSpec.from_tensor(single_example)

In [6]:
# Temporarily raise the root logger to INFO so the per-epoch training
# metrics are shown, then put the previous level back.
logger_level = logger.level
logger.setLevel(logging.INFO)
try:
    predict_fn, params = flax_mnist.train(
        train_ds=train_ds,
        test_ds=test_ds,
        num_epochs=NUM_EPOCHS,
    )
finally:
    # Restore the level even if training raises, so a failed run does not
    # leave the kernel's root logger stuck at INFO.
    logger.setLevel(logger_level)

INFO:absl:Starting the local TPU driver.
INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://
INFO:absl:Unable to initialize backend 'tpu': Invalid argument: TpuPlatform is not available.
INFO:root:mnist_flax: Epoch 0 in 5.23 sec
INFO:root:mnist_flax: Training set accuracy 88.51%
INFO:root:mnist_flax: Test set accuracy 89.00%
INFO:root:mnist_flax: Epoch 1 in 1.09 sec
INFO:root:mnist_flax: Training set accuracy 90.91%
INFO:root:mnist_flax: Test set accuracy 91.36%


In [7]:
# Convert the trained JAX predict function and its parameters to a
# TensorFlow SavedModel written under
# <BASE_OUTPUT_DIR>/model/<MODEL_NAME>/<MODEL_VERSION>.
saved_model_lib.convert_and_save_model(
    jax_fn=predict_fn,
    params=params,
    # MODEL_VERSION is an int, which os.path.join rejects with a TypeError,
    # and os.path.join would also use the local OS separator. A gs:// URI
    # always uses "/", so build the path explicitly.
    model_dir=f"{BASE_OUTPUT_DIR}/model/{MODEL_NAME}/{MODEL_VERSION}",
    input_signatures=[input_signature],
)





INFO:tensorflow:Assets written to: gs://dsparing-sandbox/model/jax_model_local/1/assets


INFO:tensorflow:Assets written to: gs://dsparing-sandbox/model/jax_model_local/1/assets


## Test prediction with both `predict_fn` and the SavedModel

In [8]:
# Pull a single test example (batch size 1) to run predictions on.
single_example_ds = mnist_lib.load_mnist(tfds.Split.TEST, batch_size=1)
image_to_predict, _ = next(iter(single_example_ds))

In [9]:
# Prediction from the in-memory JAX function with the trained parameters,
# shown so it can be compared with the SavedModel's output.
predict_fn(params, image_to_predict)

DeviceArray([[-6.58013596e-05, -3.16075134e+01, -1.48247795e+01,
              -1.14919910e+01, -1.93614635e+01, -9.88178444e+00,
              -1.52083292e+01, -1.50044022e+01, -1.30740175e+01,
              -1.35713797e+01]], dtype=float32)

In [10]:
!gsutil ls -l $BASE_OUTPUT_DIR/model/$MODEL_NAME/$MODEL_VERSION

         0  2021-06-29T20:39:31Z  gs://dsparing-sandbox/model/jax_model_local/1/
     53991  2021-06-29T22:59:15Z  gs://dsparing-sandbox/model/jax_model_local/1/saved_model.pb
                                 gs://dsparing-sandbox/model/jax_model_local/1/assets/
                                 gs://dsparing-sandbox/model/jax_model_local/1/variables/
TOTAL: 2 objects, 53991 bytes (52.73 KiB)


In [11]:
!saved_model_cli show \
    --dir $BASE_OUTPUT_DIR/model/$MODEL_NAME/$MODEL_VERSION \
    --all


MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['__saved_model_init_op']:
  The given SavedModel SignatureDef contains the following input(s):
  The given SavedModel SignatureDef contains the following output(s):
    outputs['__saved_model_init_op'] tensor_info:
        dtype: DT_INVALID
        shape: unknown_rank
        name: NoOp
  Method name is: 

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['inputs'] tensor_info:
        dtype: DT_FLOAT
        shape: (1, 28, 28, 1)
        name: serving_default_inputs:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['output_0'] tensor_info:
        dtype: DT_FLOAT
        shape: (1, 10)
        name: StatefulPartitionedCall:0
  Method name is: tensorflow/serving/predict

Defined Functions:
  Function Name: '__call__'
    Option #1
      Callable with:
        Argument #1
          inputs: TensorSpec(sh

In [12]:
# Load the SavedModel back from Cloud Storage and run its serving signature
# on the same example. The path is built from the notebook's configuration
# constants (ARTIFACT_URI is never defined here, so referencing it raises
# NameError on a fresh kernel).
loaded_model = tf.saved_model.load(
    f"{BASE_OUTPUT_DIR}/model/{MODEL_NAME}/{MODEL_VERSION}"
)
loaded_model.signatures["serving_default"](image_to_predict)

{'output_0': <tf.Tensor: shape=(1, 10), dtype=float32, numpy=
 array([[-6.5562963e-05, -3.1607523e+01, -1.4824780e+01, -1.1491998e+01,
         -1.9361465e+01, -9.8817873e+00, -1.5208329e+01, -1.5004404e+01,
         -1.3074018e+01, -1.3571379e+01]], dtype=float32)>}