# Predict with JAX/Flax-trained SavedModel on Vertex AI custom container with TF Serving

Vertex AI Prediction supports two kinds of serving containers: pre-built containers, which allow no additional customization such as args ("Do not specify any other subfields of containerSpec" [source](https://cloud.google.com/vertex-ai/docs/predictions/pre-built-containers#using_a_pre-built_container)), and custom containers. Because we need to pass the `--xla_cpu_compilation_enabled` arg to TF Serving, we can only use a custom container.

### Create TFServing container with SavedModel baked in

NOTE: serving does not work for GPU yet.

In [1]:
import os

PROJECT_ID = !(gcloud config get-value core/project)
PROJECT_ID = PROJECT_ID[0]
SAVEDMODEL_DIR = "gs://dsparing-sandbox-bucket/models/jax_model_prebuilt/output/model"
SERVING_MODELNAME = "jax_model"
MODEL_VERSION = 1

USE_GPU = False
if USE_GPU:
    IMAGE_TAG='latest-gpu'
    TFSERVING_TAG = "latest-gpu"
else:
    IMAGE_TAG='latest-cpu'
    TFSERVING_TAG = "latest"    

IMAGE_NAME='jax_tfserving_image'
IMAGE_URI = f'gcr.io/{PROJECT_ID}/{IMAGE_NAME}:{IMAGE_TAG}'
print(IMAGE_URI)

os.environ["SERVING_MODELNAME"] = SERVING_MODELNAME
os.environ["SAVEDMODEL_DIR"] = SAVEDMODEL_DIR
os.environ["MODEL_VERSION"] = str(MODEL_VERSION)
os.environ["IMAGE_URI"] = IMAGE_URI
os.environ["IMAGE_TAG"] = IMAGE_TAG
os.environ["TFSERVING_TAG"] = TFSERVING_TAG


gcr.io/dsparing-sandbox/jax_tfserving_image:latest-cpu


Check that `SAVEDMODEL_DIR` actually contains a SavedModel

In [2]:
%%bash
# List the export directory; a SavedModel shows saved_model.pb plus variables/.
# The :? guard stops an unset/empty variable from turning this into a plain
# `gsutil ls`, which would succeed without checking anything.
gsutil ls "${SAVEDMODEL_DIR:?SAVEDMODEL_DIR is not set}"

gs://dsparing-sandbox-bucket/models/jax_model_prebuilt/output/model/
gs://dsparing-sandbox-bucket/models/jax_model_prebuilt/output/model/saved_model.pb
gs://dsparing-sandbox-bucket/models/jax_model_prebuilt/output/model/assets/
gs://dsparing-sandbox-bucket/models/jax_model_prebuilt/output/model/variables/


In [3]:
%%bash
# Stop at the first failure and treat unset variables as errors, so a missing
# setting cannot redirect the copy to an unintended path.
set -euo pipefail

model_version_dir="${SERVING_MODELNAME}/${MODEL_VERSION}"
mkdir -p "${model_version_dir}" # tf serving wants model versions in numbered directories.
# The wildcard is quoted so that gsutil, not the local shell, expands it.
gsutil -q cp -r "${SAVEDMODEL_DIR}/*" "${model_version_dir}/"

The next cells follow the [TF Serving tutorial](https://www.tensorflow.org/tfx/serving/docker#creating_your_own_serving_image) for creating your own serving image:

In [4]:
%%bash
# Remove any container left over from an earlier run; otherwise `docker run`
# fails because the name serving_base is already in use.
docker rm -f serving_base >/dev/null 2>&1 || true
# Start a plain TF Serving container that the model is copied into next.
docker run -d --name serving_base "tensorflow/serving:${TFSERVING_TAG}"

0c7bdc7de8fe1aa98b9de311be141164cd623db4fbd5eb17208d80e55e9adbe3


In [5]:
%%bash
# Abort on the first failing command: without this, bash would carry on to
# `docker commit` after a failed `docker cp` and produce an image with no model.
set -euo pipefail

# Copy the local model directory (with its numbered version subdirectory) into the container.
docker cp "${SERVING_MODELNAME}" "serving_base:/models/${SERVING_MODELNAME}"
# Snapshot the container as the serving image and record the model name via MODEL_NAME.
docker commit --change "ENV MODEL_NAME ${SERVING_MODELNAME}" serving_base "${IMAGE_URI}"

sha256:879ed17c03e9116d7925befb890bec81d7e0ca57cf17fd531aafa0e2d859d4a2


In [6]:
%%bash
# The image has been committed, so the temporary base container can go.
docker rm --force serving_base

serving_base


### Optional: Try serving locally

Get image to predict 

In [7]:
import json
import requests
import tensorflow_datasets as tfds

from absl import flags
from jax.experimental.jax2tf.examples.mnist_lib import load_mnist

In [8]:
# absl flags must be parsed before load_mnist is called, otherwise it errors;
# parsing a dummy argv holding only a program name is enough.
flags.FLAGS(["e"])

# Take the first batch (batch size 1) of the MNIST test split and drop its label.
mnist_test_batches = load_mnist(tfds.Split.TEST, 1)
image_to_predict, _ = next(iter(mnist_test_batches))

# JSON-encode the image as nested lists.
image_as_lists = image_to_predict.numpy().tolist()
image_json = json.dumps(image_as_lists)

Start up container

In [9]:
%%bash
# Remove any container left over from an earlier run; otherwise `docker run`
# fails because the name serving_jax is already in use.
docker rm -f serving_jax >/dev/null 2>&1 || true
# Publish port 8501 and pass the XLA CPU compilation flag through to the model server.
docker run -d -p 8501:8501 -e MODEL_NAME="${SERVING_MODELNAME}" --name serving_jax "${IMAGE_URI}" --xla_cpu_compilation_enabled=true

62e139c00ab0ab36d54dedf068cc0b180d86697957ccb40466d726a0ab6aaac1


In [10]:
%%bash
# Poll the model status endpoint (for up to ~30 s) until TF Serving reports the
# model as AVAILABLE. A fixed sleep can be too short, and then the prediction
# request that follows races the model load.
for _ in $(seq 1 30); do
    if curl -sf "http://localhost:8501/v1/models/${SERVING_MODELNAME}" | grep -q '"AVAILABLE"'; then
        break
    fi
    sleep 1
done
# Show the server log either way; it is the first place to look if loading failed.
docker logs serving_jax

2021-06-26 18:36:03.355282: I tensorflow_serving/model_servers/server.cc:89] Building single TensorFlow model file config:  model_name: jax_model model_base_path: /models/jax_model
2021-06-26 18:36:03.355546: I tensorflow_serving/model_servers/server_core.cc:465] Adding/updating models.
2021-06-26 18:36:03.355576: I tensorflow_serving/model_servers/server_core.cc:591]  (Re-)adding model: jax_model
2021-06-26 18:36:03.455981: I tensorflow_serving/core/basic_manager.cc:740] Successfully reserved resources to load servable {name: jax_model version: 1}
2021-06-26 18:36:03.456017: I tensorflow_serving/core/loader_harness.cc:66] Approving load for servable version {name: jax_model version: 1}
2021-06-26 18:36:03.456030: I tensorflow_serving/core/loader_harness.cc:74] Loading servable version {name: jax_model version: 1}
2021-06-26 18:36:03.456078: I external/org_tensorflow/tensorflow/cc/saved_model/reader.cc:38] Reading SavedModel from: /models/jax_model/1
2021-06-26 18:36:03.457659: I exter

Send prediction

In [11]:
# Build the request body: the image batch goes under "instances".
data = json.dumps({"instances": image_to_predict.numpy().tolist()})

# A timeout stops the cell from hanging forever if the container is not responding.
json_response = requests.post(
    f'http://localhost:8501/v1/models/{SERVING_MODELNAME}:predict',
    data=data,
    timeout=30,
)

# Report the server's own error text instead of failing later with an opaque
# KeyError on the missing 'predictions' key.
if not json_response.ok:
    raise RuntimeError(
        f"Prediction request failed with HTTP {json_response.status_code}: {json_response.text}"
    )

predictions = json_response.json()['predictions']
print(predictions)

[[-14.7594662, -0.0016374588, -10.8501253, -7.16693, -10.4939137, -11.1298847, -12.9479427, -10.0422325, -7.21408, -10.8261595]]


In [12]:
%%bash
# The local smoke test is done; stop and delete the test container.
docker rm --force serving_jax

serving_jax


### Push image to registry

In [13]:
%%bash
# Push the committed serving image to the registry named in IMAGE_URI.
# Quoted so the URI always reaches docker as a single argument.
docker push "${IMAGE_URI}"

The push refers to repository [gcr.io/dsparing-sandbox/jax_tfserving_image]
74af849d2512: Preparing
bb4423850a27: Preparing
b60ba33781cd: Preparing
547f89523b17: Preparing
bd91f28d5f3c: Preparing
8cafc6d2db45: Preparing
a5d4bacb0351: Preparing
5153e1acaabc: Preparing
8cafc6d2db45: Waiting
5153e1acaabc: Waiting
a5d4bacb0351: Waiting
bb4423850a27: Layer already exists
547f89523b17: Layer already exists
bd91f28d5f3c: Layer already exists
b60ba33781cd: Layer already exists
8cafc6d2db45: Layer already exists
a5d4bacb0351: Layer already exists
5153e1acaabc: Layer already exists
74af849d2512: Pushed
latest-cpu: digest: sha256:b7e799de9dd2f27f1bae320334dc87f5f3a5c095fdd1f0b2e838631e310f8776 size: 1991


### Upload prediction container to Vertex AI

In [14]:
from google.cloud import aiplatform

# Where the model is uploaded and the machine shape it is deployed on.
REGION = "us-central1"
MACHINE_TYPE = "n1-standard-2"

# The display name carries a -gpu / -cpu suffix matching USE_GPU.
MODEL_DISPLAYNAME = f"{SERVING_MODELNAME}-gpu" if USE_GPU else f"{SERVING_MODELNAME}-cpu"

# One T4 accelerator for GPU serving; no accelerator settings for CPU serving.
ACCELERATOR_TYPE, ACCELERATOR_COUNT = ("NVIDIA_TESLA_T4", 1) if USE_GPU else (None, None)

In [15]:
# Set the default project and region for the Vertex AI SDK calls that follow.
aiplatform.init(
    location=REGION,
    project=PROJECT_ID,
)

In [16]:
# Describe the custom serving container. The routes are TF Serving's REST paths
# for this model, 8501 is the REST port used in the local test, and the args
# pass the same XLA flag that the local `docker run` used.
upload_options = {
    "display_name": MODEL_DISPLAYNAME,
    "serving_container_image_uri": IMAGE_URI,
    "serving_container_predict_route": f"/v1/models/{SERVING_MODELNAME}:predict",
    "serving_container_health_route": f"/v1/models/{SERVING_MODELNAME}",
    "serving_container_args": ["--xla_cpu_compilation_enabled=true"],
    "serving_container_ports": [8501],
}
model = aiplatform.Model.upload(**upload_options)

# Show the uploaded model's display name and resource name, one per line.
print(model.display_name, model.resource_name, sep="\n")

jax_model-cpu
projects/654544512569/locations/us-central1/models/4069151796910620672


In [17]:
# Deploy the uploaded model with the machine and accelerator settings chosen above.
deploy_options = {
    "machine_type": MACHINE_TYPE,
    "accelerator_type": ACCELERATOR_TYPE,
    "accelerator_count": ACCELERATOR_COUNT,
}
endpoint = model.deploy(**deploy_options)

### Get predictions

In [18]:
import json
import requests
import tensorflow_datasets as tfds

from absl import flags
from jax.experimental.jax2tf.examples.mnist_lib import load_mnist

In [19]:
# absl flags must be parsed before load_mnist is called, otherwise it errors;
# parsing a dummy argv holding only a program name is enough.
flags.FLAGS(["e"])

# Take the first batch (batch size 1) of the MNIST test split and drop its label.
mnist_test_batches = load_mnist(tfds.Split.TEST, 1)
image_to_predict, _ = next(iter(mnist_test_batches))
print(image_to_predict.shape)

# JSON-encode the image as nested lists.
image_as_lists = image_to_predict.numpy().tolist()
image_json = json.dumps(image_as_lists)

(1, 28, 28, 1)


In [20]:
# Convert the image batch to plain nested lists and send it to the deployed endpoint.
instances = image_to_predict.numpy().tolist()
prediction = endpoint.predict(instances=instances)

# Print the predictions returned by the endpoint.
print(prediction.predictions)

[[-16.5650826, -28.286602, -15.1550636, -10.3459358, -9.69814873, -13.5413094, -23.0040817, -8.21421432, -10.2531109, -0.000401496887]]
