# Predict with a JAX/Flax-trained SavedModel on a Vertex AI custom container with TF Serving

Vertex AI Prediction supports two kinds of serving containers: pre-built and custom. Pre-built containers cannot be customized with extra settings such as args ("Do not specify any other subfields of containerSpec", [source](https://cloud.google.com/vertex-ai/docs/predictions/pre-built-containers#using_a_pre-built_container)). Because we need to pass the `--xla_cpu_compilation_enabled` arg to TF Serving, a custom container is the only option.

In [1]:
import json
import os

import requests
import tensorflow_datasets as tfds
from absl import flags
from google.cloud import aiplatform
from jax.experimental.jax2tf.examples.mnist_lib import load_mnist

## Create a TF Serving container with the SavedModel baked in

NOTE: Serving on GPU does not work yet.

In [2]:
PROJECT_ID = !(gcloud config get-value core/project)
PROJECT_ID = PROJECT_ID[0]
SAVEDMODEL_DIR = ("gs://dsparing-sandbox-bucket/models/"
                  "jax_model_prebuilt/output/model")
SERVING_MODELNAME = "jax_model"
MODEL_VERSION = 1

USE_GPU = False
if USE_GPU:
    IMAGE_TAG = "latest-gpu"
    TFSERVING_TAG = "latest-gpu"
else:
    IMAGE_TAG = "latest-cpu"
    TFSERVING_TAG = "latest"

IMAGE_NAME = "jax_tfserving_image"
IMAGE_URI = f"gcr.io/{PROJECT_ID}/{IMAGE_NAME}:{IMAGE_TAG}"
print(IMAGE_URI)

os.environ["SERVING_MODELNAME"] = SERVING_MODELNAME
os.environ["SAVEDMODEL_DIR"] = SAVEDMODEL_DIR
os.environ["MODEL_VERSION"] = str(MODEL_VERSION)
os.environ["IMAGE_URI"] = IMAGE_URI
os.environ["IMAGE_TAG"] = IMAGE_TAG
os.environ["TFSERVING_TAG"] = TFSERVING_TAG


gcr.io/dsparing-sandbox/jax_tfserving_image:latest-cpu


Check that `SAVEDMODEL_DIR` actually contains a SavedModel:

In [3]:
%%bash
gsutil ls $SAVEDMODEL_DIR

gs://dsparing-sandbox-bucket/models/jax_model_prebuilt/output/model/
gs://dsparing-sandbox-bucket/models/jax_model_prebuilt/output/model/saved_model.pb
gs://dsparing-sandbox-bucket/models/jax_model_prebuilt/output/model/assets/
gs://dsparing-sandbox-bucket/models/jax_model_prebuilt/output/model/variables/


In [4]:
%%bash
# TF Serving expects each model version in its own numbered subdirectory.
mkdir -p $SERVING_MODELNAME/$MODEL_VERSION
gsutil -q cp -r $SAVEDMODEL_DIR/* $SERVING_MODELNAME/$MODEL_VERSION/
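
Before baking the directory into the image, it can help to confirm that the local copy has the layout TF Serving looks for. The following is a minimal sketch, not part of the original notebook: `find_servable_versions` is a hypothetical helper, and it only checks for `saved_model.pb` inside integer-named subdirectories.

```python
from pathlib import Path


def find_servable_versions(model_dir):
    """Return the sorted numeric version subdirectories containing saved_model.pb.

    Hypothetical helper: it mirrors the <model_name>/<version>/saved_model.pb
    layout used in this notebook and ignores anything else in the directory.
    """
    versions = []
    for child in Path(model_dir).iterdir():
        # Only subdirectories with integer names are treated as versions.
        if child.is_dir() and child.name.isdigit():
            if (child / "saved_model.pb").is_file():
                versions.append(int(child.name))
    return sorted(versions)
```

In this notebook, `find_servable_versions(SERVING_MODELNAME)` would be expected to return `[1]` after the copy above succeeds; an empty list suggests the copy went to the wrong place.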

The next cells follow the [TF Serving tutorial](https://www.tensorflow.org/tfx/serving/docker#creating_your_own_serving_image) on creating your own serving image:

In [5]:
%%bash
docker run -d --name serving_base tensorflow/serving:$TFSERVING_TAG

4d92b24d393e3559bdf3fcb669145ed049430837892f6562100cbcd5c5bef563


In [6]:
%%bash
docker cp $SERVING_MODELNAME serving_base:/models/$SERVING_MODELNAME
docker commit --change "ENV MODEL_NAME $SERVING_MODELNAME" serving_base $IMAGE_URI

sha256:65a3cabbd278f35a19f0be65447ee808fbc23eeb996f62aa7e00818009a61245


In [7]:
%%bash
docker rm -f serving_base

serving_base


## Try serving locally

Get an image to predict:

In [8]:
# load_mnist reads absl flags, which must be parsed before use to avoid errors.
# Parsing a dummy argv (program name only) initializes them.
flags.FLAGS(['e'])

image_to_predict, _ = next(iter(load_mnist(tfds.Split.TEST, batch_size=1)))
instances = image_to_predict.numpy().tolist()
image_json = json.dumps(instances)

INFO:absl:Load dataset info from /home/jupyter/tensorflow_datasets/mnist/3.0.1
INFO:absl:Field info.citation from disk and from code do not match. Keeping the one from code.
INFO:absl:Field info.splits from disk and from code do not match. Keeping the one from code.
INFO:absl:Field info.module_name from disk and from code do not match. Keeping the one from code.
INFO:absl:Reusing dataset mnist (/home/jupyter/tensorflow_datasets/mnist/3.0.1)
INFO:absl:Constructing tf.data.Dataset mnist for split test, from /home/jupyter/tensorflow_datasets/mnist/3.0.1
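
The `instances` list built above is the batch in plain nested-list form, which is what TF Serving's REST API accepts under the `"instances"` key (its row format). The sketch below shows the shape of such a request body using a stand-in batch. `make_predict_payload` and `fake_batch` are hypothetical names, and the `(1, 28, 28, 1)` shape is an assumption about what `load_mnist` yields.

```python
import json


def make_predict_payload(instances):
    """Serialize a batch of examples into TF Serving's row ("instances") format."""
    return json.dumps({"instances": instances})


# Stand-in for image_to_predict.numpy().tolist(): a single all-zero image.
# The (1, 28, 28, 1) shape is an assumption, not taken from the notebook output.
fake_batch = [[[[0.0] for _ in range(28)] for _ in range(28)]]

payload = make_predict_payload(fake_batch)
decoded = json.loads(payload)
print(len(decoded["instances"]))     # number of examples in the batch
print(len(decoded["instances"][0]))  # number of rows in the first image
```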


Start up the container:

In [9]:
%%bash
docker run -d -p 8501:8501 -e MODEL_NAME=$SERVING_MODELNAME --name serving_jax $IMAGE_URI --xla_cpu_compilation_enabled=true

d580fb0664eecbd54ecc4e23282b51940f3df052634d03d80bf1ca16484274a5


In [10]:
%%bash
sleep 2
docker logs serving_jax

2021-06-26 20:23:36.608829: I tensorflow_serving/model_servers/server.cc:89] Building single TensorFlow model file config:  model_name: jax_model model_base_path: /models/jax_model
2021-06-26 20:23:36.609255: I tensorflow_serving/model_servers/server_core.cc:465] Adding/updating models.
2021-06-26 20:23:36.609286: I tensorflow_serving/model_servers/server_core.cc:591]  (Re-)adding model: jax_model
2021-06-26 20:23:36.709703: I tensorflow_serving/core/basic_manager.cc:740] Successfully reserved resources to load servable {name: jax_model version: 1}
2021-06-26 20:23:36.709743: I tensorflow_serving/core/loader_harness.cc:66] Approving load for servable version {name: jax_model version: 1}
2021-06-26 20:23:36.709757: I tensorflow_serving/core/loader_harness.cc:74] Loading servable version {name: jax_model version: 1}
2021-06-26 20:23:36.709812: I external/org_tensorflow/tensorflow/cc/saved_model/reader.cc:38] Reading SavedModel from: /models/jax_model/1
2021-06-26 20:23:36.711456: I exter
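
The fixed `sleep 2` above may be too short on a slow machine. As an alternative, one could poll TF Serving's model status endpoint (`GET /v1/models/<name>`, the same route used later as the health route) until a version reports `AVAILABLE`. This is a rough sketch under that assumption; `is_model_available` and `wait_until_available` are hypothetical helpers, and the status fetcher is passed in as a callable so the logic does not depend on a running server.

```python
import time


def is_model_available(status):
    """True if any version in a TF Serving model-status response is AVAILABLE."""
    return any(
        version.get("state") == "AVAILABLE"
        for version in status.get("model_version_status", [])
    )


def wait_until_available(fetch_status, timeout_s=30.0, interval_s=0.5):
    """Poll fetch_status() until the model is AVAILABLE or the timeout expires.

    fetch_status is any zero-argument callable returning the parsed JSON status.
    Returns True on success and False on timeout.
    """
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        try:
            if is_model_available(fetch_status()):
                return True
        except Exception:
            # The server may not be accepting connections yet; retry.
            pass
        time.sleep(interval_s)
    return False
```

With the container from this notebook, a call might look like `wait_until_available(lambda: requests.get(f"http://localhost:8501/v1/models/{SERVING_MODELNAME}").json())`.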

Send a prediction request:

In [11]:
data = json.dumps({"instances": instances})
json_response = requests.post(
    f"http://localhost:8501/v1/models/{SERVING_MODELNAME}:predict",
    data=data,
)
json_response.raise_for_status()  # fail loudly on HTTP errors
predictions = json_response.json()["predictions"]
print(predictions)

[[-12.7403393, -3.94832897, -6.22885, -0.0298643112, -13.4114065, -7.25624228, -17.2837029, -8.350564, -5.00295305, -7.60113287]]
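
The ten values are all negative and sum to roughly 1 after exponentiation, which suggests the model outputs log-probabilities over the digit classes (an inference from the numbers, not something the notebook states). Under that assumption, the predicted digit is the index of the largest value, as in this short sketch using the output above:

```python
import math

# Values copied from the prediction printed above.
log_probs = [-12.7403393, -3.94832897, -6.22885, -0.0298643112, -13.4114065,
             -7.25624228, -17.2837029, -8.350564, -5.00295305, -7.60113287]

# Index of the largest log-probability, i.e. the most likely class.
predicted_digit = max(range(len(log_probs)), key=log_probs.__getitem__)

# Probability of that class, assuming the outputs are log-probabilities.
confidence = math.exp(log_probs[predicted_digit])

print(predicted_digit)  # → 3
```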


In [12]:
%%bash
docker rm -f serving_jax

serving_jax


## Push image to Container Registry

In [13]:
%%bash
docker push $IMAGE_URI

The push refers to repository [gcr.io/dsparing-sandbox/jax_tfserving_image]
33caea58d11f: Preparing
bb4423850a27: Preparing
b60ba33781cd: Preparing
547f89523b17: Preparing
bd91f28d5f3c: Preparing
8cafc6d2db45: Preparing
a5d4bacb0351: Preparing
5153e1acaabc: Preparing
8cafc6d2db45: Waiting
a5d4bacb0351: Waiting
5153e1acaabc: Waiting
bd91f28d5f3c: Layer already exists
b60ba33781cd: Layer already exists
bb4423850a27: Layer already exists
547f89523b17: Layer already exists
5153e1acaabc: Layer already exists
8cafc6d2db45: Layer already exists
a5d4bacb0351: Layer already exists
33caea58d11f: Pushed
latest-cpu: digest: sha256:53cf8d3dbf24827136bba0b077edbf46ba0591d28c5f690b7c60bc43d051466b size: 1991


## Upload prediction container to Vertex AI

In [14]:
REGION = "us-central1"

MACHINE_TYPE = "n1-standard-2"

if USE_GPU:
    MODEL_DISPLAYNAME = f"{SERVING_MODELNAME}-gpu"
    ACCELERATOR_TYPE = "NVIDIA_TESLA_T4"
    ACCELERATOR_COUNT = 1
else:
    MODEL_DISPLAYNAME = f"{SERVING_MODELNAME}-cpu"
    ACCELERATOR_TYPE = None
    ACCELERATOR_COUNT = None

In [15]:
aiplatform.init(project=PROJECT_ID, location=REGION)

In [16]:
model = aiplatform.Model.upload(
    display_name=MODEL_DISPLAYNAME,
    serving_container_image_uri=IMAGE_URI,
    serving_container_predict_route=f"/v1/models/{SERVING_MODELNAME}:predict",
    serving_container_health_route=f"/v1/models/{SERVING_MODELNAME}",
    serving_container_args=['--xla_cpu_compilation_enabled=true'],
    serving_container_ports=[8501],
)

print(model.display_name)
print(model.resource_name)

INFO:google.cloud.aiplatform.models:Creating Model
INFO:google.cloud.aiplatform.models:Create Model backing LRO: projects/654544512569/locations/us-central1/models/5315522993785405440/operations/8645506383568961536
INFO:google.cloud.aiplatform.models:Model created. Resource name: projects/654544512569/locations/us-central1/models/5315522993785405440
INFO:google.cloud.aiplatform.models:To use this Model in another session:
INFO:google.cloud.aiplatform.models:model = aiplatform.Model('projects/654544512569/locations/us-central1/models/5315522993785405440')
jax_model-cpu
projects/654544512569/locations/us-central1/models/5315522993785405440


In [17]:
endpoint = model.deploy(
    machine_type=MACHINE_TYPE,
    accelerator_type=ACCELERATOR_TYPE,
    accelerator_count=ACCELERATOR_COUNT,
)

INFO:google.cloud.aiplatform.models:Creating Endpoint
INFO:google.cloud.aiplatform.models:Create Endpoint backing LRO: projects/654544512569/locations/us-central1/endpoints/2382175504460414976/operations/7073750113616658432
INFO:google.cloud.aiplatform.models:Endpoint created. Resource name: projects/654544512569/locations/us-central1/endpoints/2382175504460414976
INFO:google.cloud.aiplatform.models:To use this Endpoint in another session:
INFO:google.cloud.aiplatform.models:endpoint = aiplatform.Endpoint('projects/654544512569/locations/us-central1/endpoints/2382175504460414976')
INFO:google.cloud.aiplatform.models:Deploying model to Endpoint : projects/654544512569/locations/us-central1/endpoints/2382175504460414976
INFO:google.cloud.aiplatform.models:Deploy Endpoint model backing LRO: projects/654544512569/locations/us-central1/endpoints/2382175504460414976/operations/5712537126243926016
INFO:google.cloud.aiplatform.models:Endpoint model deployed. Resource name: projects/65454451256

## Get Online Predictions from Vertex AI

In [19]:
prediction = endpoint.predict(instances)
print(prediction.predictions)

[[-18.9266491, -21.4380112, -11.6358833, -1.71661377e-05, -25.4557629, -12.2982159, -26.5184898, -27.9774284, -12.2741051, -19.7286148]]
