In [1]:
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Predict with JAX/Flax-trained SavedModel on Vertex AI custom container with TF Serving

Vertex AI Prediction supports pre-built containers, with no additional customization such as args ("Do not specify any other subfields of containerSpec" [source](https://cloud.google.com/vertex-ai/docs/predictions/pre-built-containers#using_a_pre-built_container)), and custom containers. As we need the `--xla_cpu_compilation_enabled` arg, we can only use custom containers.

In [2]:
import json
import os

import requests
import tensorflow_datasets as tfds
from absl import flags
from google.cloud import aiplatform
from jax.experimental.jax2tf.examples.mnist_lib import load_mnist

## Create TFServing container with SavedModel baked in

NOTE: serving does not work for GPU yet.

Below, `SAVEDMODELDIR` should point to a model created in [training-prebuilt.ipynb](training-prebuilt.ipynb). Change it to any directory containing a SavedModel.

In [5]:
PROJECT_ID = !(gcloud config get-value core/project)
PROJECT_ID = PROJECT_ID[0]

REGION = "us-central1"
os.environ['REGION'] = REGION

BUCKET_NAME = "dsparing-sandbox-bucket"
os.environ['BUCKET_NAME'] = BUCKET_NAME
# Use a regional bucket in the above region you have rights to.
# Create if needed:
# !gsutil mb -l ${REGION} gs://${BUCKET_NAME}

MODEL_NAME = "jax_model_prebuilt"

SAVEDMODEL_DIR = (f"gs://{BUCKET_NAME}/models/"
                  f"{MODEL_NAME}/output/model")

SERVING_MODELNAME = "jax_model"
MODEL_VERSION = 1

USE_GPU = False
if USE_GPU:
    IMAGE_TAG = "latest-gpu"
    TFSERVING_TAG = "latest-gpu"
else:
    IMAGE_TAG = "latest-cpu"
    TFSERVING_TAG = "latest"

IMAGE_NAME = "jax_tfserving_image"
IMAGE_URI = f"gcr.io/{PROJECT_ID}/{IMAGE_NAME}:{IMAGE_TAG}"
print(IMAGE_URI)

os.environ["SERVING_MODELNAME"] = SERVING_MODELNAME
os.environ["SAVEDMODEL_DIR"] = SAVEDMODEL_DIR
os.environ["MODEL_VERSION"] = str(MODEL_VERSION)
os.environ["IMAGE_URI"] = IMAGE_URI
os.environ["IMAGE_TAG"] = IMAGE_TAG
os.environ["TFSERVING_TAG"] = TFSERVING_TAG

gcr.io/dsparing-sandbox/jax_tfserving_image:latest-cpu


Check that `SAVEDMODEL_DIR` actually contains a SavedModel

In [21]:
%%bash
gsutil ls $SAVEDMODEL_DIR

gs://dsparing-sandbox-bucket/models/jax_model_prebuilt/output/model/
gs://dsparing-sandbox-bucket/models/jax_model_prebuilt/output/model/saved_model.pb
gs://dsparing-sandbox-bucket/models/jax_model_prebuilt/output/model/1/
gs://dsparing-sandbox-bucket/models/jax_model_prebuilt/output/model/assets/
gs://dsparing-sandbox-bucket/models/jax_model_prebuilt/output/model/variables/


In [22]:
%%bash
gsutil cp -r $SAVEDMODEL_DIR/* $SAVEDMODEL_DIR/1/

Copying gs://dsparing-sandbox-bucket/models/jax_model_prebuilt/output/model/saved_model.pb...
Copying gs://dsparing-sandbox-bucket/models/jax_model_prebuilt/output/model/1/saved_model.pb...
/ [2 files][ 86.3 KiB/ 86.3 KiB]                                                
==> NOTE: You are performing a sequence of gsutil operations that may
run significantly faster if you instead use gsutil -m cp ... Please
see the -m section under "gsutil help options" for further information
about when gsutil -m can be advantageous.

Copying gs://dsparing-sandbox-bucket/models/jax_model_prebuilt/output/model/variables/variables.data-00000-of-00001...
Copying gs://dsparing-sandbox-bucket/models/jax_model_prebuilt/output/model/variables/variables.index...
- [4 files][  3.2 MiB/  3.2 MiB]                                                
Operation completed over 4 objects/3.2 MiB.                                      


In [23]:
%%bash
gsutil ls $SAVEDMODEL_DIR/*

gs://dsparing-sandbox-bucket/models/jax_model_prebuilt/output/model/
gs://dsparing-sandbox-bucket/models/jax_model_prebuilt/output/model/saved_model.pb

gs://dsparing-sandbox-bucket/models/jax_model_prebuilt/output/model/1/:
gs://dsparing-sandbox-bucket/models/jax_model_prebuilt/output/model/1/saved_model.pb
gs://dsparing-sandbox-bucket/models/jax_model_prebuilt/output/model/1/1/
gs://dsparing-sandbox-bucket/models/jax_model_prebuilt/output/model/1/variables/

gs://dsparing-sandbox-bucket/models/jax_model_prebuilt/output/model/assets/:
gs://dsparing-sandbox-bucket/models/jax_model_prebuilt/output/model/assets/

gs://dsparing-sandbox-bucket/models/jax_model_prebuilt/output/model/variables/:
gs://dsparing-sandbox-bucket/models/jax_model_prebuilt/output/model/variables/
gs://dsparing-sandbox-bucket/models/jax_model_prebuilt/output/model/variables/variables.data-00000-of-00001
gs://dsparing-sandbox-bucket/models/jax_model_prebuilt/output/model/variables/variables.index


In [7]:
%%bash
mkdir -p $SERVING_MODELNAME/$MODEL_VERSION # tf serving wants model versions in numbered directories.
gsutil -q cp -r $SAVEDMODEL_DIR/* $SERVING_MODELNAME/$MODEL_VERSION/

Below we follow the [TF Serving tutorial](https://www.tensorflow.org/tfx/serving/docker#creating_your_own_serving_image) for creating your own serving image with the model baked in. We spin up a TF Serving container, copy the model inside the container, and commit this change together with the `MODEL_NAME` environment variable. After the commit, the base TF Serving container can be removed.

In [8]:
%%bash
docker run -d --name serving_base tensorflow/serving:$TFSERVING_TAG

d962ca3b1475655da79a303881a3666165ef29a38e74c99635b872322b158060


In [9]:
%%bash
docker cp $SERVING_MODELNAME serving_base:/models/$SERVING_MODELNAME
docker commit --change "ENV MODEL_NAME $SERVING_MODELNAME" serving_base $IMAGE_URI

sha256:0149676bf3240ab13fad895cd1b5b5874209db084c9412a77b78273be6b07f7a


In [10]:
%%bash
docker rm -f serving_base

serving_base


## Try serving locally

Get image to predict

In [11]:
# need to initialize flags somehow to avoid errors in load_mnist
flags.FLAGS(['e'])

image_to_predict, _ = next(iter(load_mnist(tfds.Split.TEST, batch_size=1)))
instances = image_to_predict.numpy().tolist()
image_json = json.dumps(instances)

INFO:absl:Load dataset info from /home/jupyter/tensorflow_datasets/mnist/3.0.1
INFO:absl:Field info.citation from disk and from code do not match. Keeping the one from code.
INFO:absl:Field info.splits from disk and from code do not match. Keeping the one from code.
INFO:absl:Field info.module_name from disk and from code do not match. Keeping the one from code.
INFO:absl:Reusing dataset mnist (/home/jupyter/tensorflow_datasets/mnist/3.0.1)
INFO:absl:Constructing tf.data.Dataset mnist for split test, from /home/jupyter/tensorflow_datasets/mnist/3.0.1


Start up container

In [26]:
%%bash
docker run -d -p 8501:8501 -e MODEL_NAME=$SERVING_MODELNAME --name serving_jax \
    tensorflow/serving:$TFSERVING_TAG --xla_cpu_compilation_enabled=true \
        --model_base_path="gs://dsparing-sandbox-bucket/models/jax_model_prebuilt/output/model/"

938c6cadafb0566de144f7b8880e1112201d6994c763b85bf9f26f1911a847b3


In [28]:
%%bash
sleep 2
docker logs serving_jax

2021-06-27 01:59:21.819077: I tensorflow_serving/model_servers/server.cc:89] Building single TensorFlow model file config:  model_name: jax_model model_base_path: gs://dsparing-sandbox-bucket/models/jax_model_prebuilt/output/model/
2021-06-27 01:59:21.819343: I tensorflow_serving/model_servers/server_core.cc:465] Adding/updating models.
2021-06-27 01:59:21.819378: I tensorflow_serving/model_servers/server_core.cc:591]  (Re-)adding model: jax_model
2021-06-27 01:59:23.395493: I tensorflow_serving/core/basic_manager.cc:740] Successfully reserved resources to load servable {name: jax_model version: 1}
2021-06-27 01:59:23.395525: I tensorflow_serving/core/loader_harness.cc:66] Approving load for servable version {name: jax_model version: 1}
2021-06-27 01:59:23.395557: I tensorflow_serving/core/loader_harness.cc:74] Loading servable version {name: jax_model version: 1}
2021-06-27 01:59:23.534478: I external/org_tensorflow/tensorflow/cc/saved_model/reader.cc:38] Reading SavedModel from: gs:/

Send prediction

In [29]:
data = json.dumps({"instances": image_to_predict.numpy().tolist()})
json_response = requests.post(
    f'http://localhost:8501/v1/models/{SERVING_MODELNAME}:predict',
    data=data,
)
predictions = json.loads(json_response.text)['predictions']
print(predictions)

[[-14.4188423, -0.00523996353, -8.44263649, -7.61883926, -10.9973564, -9.40729, -8.70840836, -10.2604904, -5.51396084, -8.56981277]]


In [25]:
%%bash
docker rm -f serving_jax

serving_jax


## Push image to Container Registry

In [None]:
%%bash
docker push $IMAGE_URI

## Upload prediction container to Vertex AI

In [None]:
REGION = "us-central1"

MACHINE_TYPE = "n1-standard-2"

if USE_GPU:
    MODEL_DISPLAYNAME = f"{SERVING_MODELNAME}-gpu"
    ACCELERATOR_TYPE = "NVIDIA_TESLA_T4"
    ACCELERATOR_COUNT = 1
else:
    MODEL_DISPLAYNAME = f"{SERVING_MODELNAME}-cpu"
    ACCELERATOR_TYPE = None
    ACCELERATOR_COUNT = None

In [None]:
model = aiplatform.Model.upload(
    display_name=MODEL_DISPLAYNAME,
    serving_container_image_uri=IMAGE_URI,
    serving_container_predict_route=f"/v1/models/{SERVING_MODELNAME}:predict",
    serving_container_health_route=f"/v1/models/{SERVING_MODELNAME}",
    serving_container_args=['--xla_cpu_compilation_enabled=true'],
    serving_container_ports=[8501],
)

print(model.display_name)
print(model.resource_name)

In [None]:
endpoint = model.deploy(
    machine_type=MACHINE_TYPE,
    accelerator_type=ACCELERATOR_TYPE,
    accelerator_count=ACCELERATOR_COUNT,
)

## Get Online Predictions from Vertex AI

In [None]:
prediction = endpoint.predict(instances)
print(prediction.predictions)