In [None]:
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Train a JAX/Flax model locally and use `jax2tf` to convert it to a SavedModel

In [1]:
import logging
import os

import tensorflow as tf
import tensorflow_datasets as tfds
from absl import flags
from jax.experimental.jax2tf.examples.mnist_lib import load_mnist, FlaxMNIST
from jax.experimental.jax2tf.examples.saved_model_lib import (
    convert_and_save_model
)

In [2]:
TRAIN_BATCH_SIZE = 128
TEST_BATCH_SIZE = 16
NUM_EPOCHS = 2

PROJECT_ID = !(gcloud config get-value core/project)
PROJECT_ID = PROJECT_ID[0]

REGION = "us-central1"
os.environ['REGION'] = REGION

BUCKET_NAME = PROJECT_ID
os.environ['BUCKET_NAME'] = BUCKET_NAME
# Use a regional bucket in the above region you have rights to.
# Create if needed:
# !gsutil mb -l ${REGION} gs://${BUCKET_NAME}

MODEL_NAME = "jax_model_local"
SAVEDMODEL_DIR = f"gs://{BUCKET_NAME}/models/{MODEL_NAME}/output"

# Block TF from the GPU to let JAX use it all
tf.config.set_visible_devices([], 'GPU')

logger = logging.getLogger()

# need to initialize flags somehow to avoid errors in load_mnist
flags.FLAGS(['e'])

flax_mnist = FlaxMNIST()

In [3]:
# Batched MNIST iterables for the training and evaluation splits; batch
# sizes come from the configuration cell.
train_ds = load_mnist(tfds.Split.TRAIN, batch_size=TRAIN_BATCH_SIZE)
test_ds = load_mnist(tfds.Split.TEST, batch_size=TEST_BATCH_SIZE)

In [4]:
# Derive the serving input signature from real data: take the first element
# of the first training batch and re-add a leading axis of size 1, so the
# spec presumably describes a batch containing a single image.
image, _ = next(iter(train_ds))
single_image_batch = tf.expand_dims(image[0], axis=0)
input_signature = tf.TensorSpec.from_tensor(single_image_batch)

In [5]:
# Show INFO-level progress messages only while training runs.
# `logger` is the root logger, so resetting it to NOTSET afterwards would make
# it process *every* message rather than go back to its earlier level.
# Remember the level in effect and restore exactly that, even if training
# raises.
previous_log_level = logger.level
logger.setLevel(logging.INFO)
try:
    predict_fn, params = flax_mnist.train(
        train_ds=train_ds,
        test_ds=test_ds,
        num_epochs=NUM_EPOCHS,
    )
finally:
    logger.setLevel(previous_log_level)

INFO:absl:Starting the local TPU driver.
INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://
INFO:absl:Unable to initialize backend 'tpu': Invalid argument: TpuPlatform is not available.
INFO:root:mnist_flax: Epoch 0 in 5.15 sec
INFO:root:mnist_flax: Training set accuracy 88.47%
INFO:root:mnist_flax: Test set accuracy 89.08%
INFO:root:mnist_flax: Epoch 1 in 1.10 sec
INFO:root:mnist_flax: Training set accuracy 90.76%
INFO:root:mnist_flax: Test set accuracy 91.27%


In [6]:
# Convert the trained JAX predictor and write it to SAVEDMODEL_DIR, exposing
# a single input signature.
serving_signatures = [input_signature]
convert_and_save_model(
    model_dir=SAVEDMODEL_DIR,
    jax_fn=predict_fn,
    params=params,
    input_signatures=serving_signatures,
)





INFO:tensorflow:Assets written to: gs://dsparing-sandbox-bucket/models/jax_model_local/output/assets


INFO:tensorflow:Assets written to: gs://dsparing-sandbox-bucket/models/jax_model_local/output/assets


## Test prediction with both `predict_fn` and the SavedModel

In [7]:
# Pull one test example (batch size 1) to feed to both prediction paths;
# the label is not needed.
single_example_ds = load_mnist(tfds.Split.TEST, batch_size=1)
image_to_predict, _ = next(iter(single_example_ds))

In [8]:
# Reference result: run the in-memory JAX predictor on the sample image.
jax_prediction = predict_fn(params, image_to_predict)
jax_prediction

DeviceArray([[-7.755705  , -3.88061   , -4.1301394 , -5.72552   ,
              -5.2567806 , -5.15688   , -0.06752199, -8.690741  ,
              -4.3864617 , -6.6458564 ]], dtype=float32)

In [9]:
# Reload the exported model from SAVEDMODEL_DIR and run the same sample
# through its default serving signature, to compare with the JAX output.
loaded_model = tf.saved_model.load(SAVEDMODEL_DIR)
serving_fn = loaded_model.signatures["serving_default"]
serving_fn(image_to_predict)

{'output_0': <tf.Tensor: shape=(1, 10), dtype=float32, numpy=
 array([[-7.755705  , -3.8806124 , -4.130141  , -5.7255197 , -5.2567816 ,
         -5.1568775 , -0.06752188, -8.690744  , -4.3864636 , -6.64586   ]],
       dtype=float32)>}