In [18]:
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Predict with JAX/Flax-trained SavedModel on Vertex AI custom container with TF Serving

Vertex AI Prediction supports two kinds of serving containers: pre-built containers, which do not allow additional customization such as args ("Do not specify any other subfields of containerSpec" [source](https://cloud.google.com/vertex-ai/docs/predictions/pre-built-containers#using_a_pre-built_container)), and custom containers. Because we need to pass the `--xla_cpu_compilation_enabled` arg, we have to use a custom container.

In [1]:
import json
import os

import requests
import tensorflow_datasets as tfds
from absl import flags
from google.cloud import aiplatform
from jax.experimental.jax2tf.examples.mnist_lib import load_mnist

## Create TFServing container with SavedModel baked in

NOTE: serving does not work for GPU yet.

Below, `SAVEDMODEL_DIR` should point to a model created in [training-prebuilt.ipynb](training-prebuilt.ipynb). You can change it to any directory containing a SavedModel.

In [2]:
PROJECT_ID = !(gcloud config get-value core/project)
PROJECT_ID = PROJECT_ID[0]

REGION = "us-central1"
os.environ['REGION'] = REGION

BUCKET_NAME = PROJECT_ID
os.environ['BUCKET_NAME'] = BUCKET_NAME
# Use a regional bucket in the above region you have rights to.
# Create if needed:
# !gsutil mb -l ${REGION} gs://${BUCKET_NAME}

MODEL_NAME = "jax_model_prebuilt"

SAVEDMODEL_DIR = ("gs://{BUCKET_NAME}/models/"
                  f"{MODEL_NAME}/output/model")

SERVING_MODELNAME = "jax_model"
MODEL_VERSION = 1

USE_GPU = False
if USE_GPU:
    IMAGE_TAG = "latest-gpu"
    TFSERVING_TAG = "latest-gpu"
else:
    IMAGE_TAG = "latest-cpu"
    TFSERVING_TAG = "latest"

IMAGE_NAME = "jax_tfserving_image"
IMAGE_URI = f"gcr.io/{PROJECT_ID}/{IMAGE_NAME}:{IMAGE_TAG}"
print(IMAGE_URI)

os.environ["SERVING_MODELNAME"] = SERVING_MODELNAME
os.environ["SAVEDMODEL_DIR"] = SAVEDMODEL_DIR
os.environ["MODEL_VERSION"] = str(MODEL_VERSION)
os.environ["IMAGE_URI"] = IMAGE_URI
os.environ["IMAGE_TAG"] = IMAGE_TAG
os.environ["TFSERVING_TAG"] = TFSERVING_TAG

gcr.io/dsparing-sandbox/jax_tfserving_image:latest-cpu


Check that `SAVEDMODEL_DIR` actually contains a SavedModel

In [3]:
%%bash
# List the SavedModel directory; the listing should include saved_model.pb
# and variables/. Quoted so the URI is passed as a single argument.
gsutil ls "${SAVEDMODEL_DIR}"

gs://dsparing-sandbox-bucket/models/jax_model_prebuilt/output/model/
gs://dsparing-sandbox-bucket/models/jax_model_prebuilt/output/model/saved_model.pb
gs://dsparing-sandbox-bucket/models/jax_model_prebuilt/output/model/assets/
gs://dsparing-sandbox-bucket/models/jax_model_prebuilt/output/model/variables/


In [4]:
%%bash
# TF Serving wants model versions in numbered directories:
# <model name>/<version>/.
mkdir -p "${SERVING_MODELNAME}/${MODEL_VERSION}"
# Quote the wildcard so that gsutil, not the local shell, expands it.
gsutil -q cp -r "${SAVEDMODEL_DIR}/*" "${SERVING_MODELNAME}/${MODEL_VERSION}/"

Below we follow the [TF Serving tutorial](https://www.tensorflow.org/tfx/serving/docker#creating_your_own_serving_image) for creating your own serving image with the model baked in. We spin up a TF Serving container, copy the model inside the container, and commit this change together with the `MODEL_NAME` environment variable. After the commit, the base TF Serving container can be removed.

In [5]:
%%bash
# Start a throwaway TF Serving container in the background; the model is
# copied into it and the result committed as a new image in the next cell.
docker run --detach --name serving_base "tensorflow/serving:${TFSERVING_TAG}"

37b012feee8e6a6b16923d27dcabe4491d8c97210ff52c097d38d36fe8fbb1eb


Unable to find image 'tensorflow/serving:latest' locally
latest: Pulling from tensorflow/serving
01bf7da0a88c: Pulling fs layer
f3b4a5f15c7a: Pulling fs layer
57ffbe87baa1: Pulling fs layer
e72e6208e893: Pulling fs layer
6ea3f464ef73: Pulling fs layer
01e9bf86544b: Pulling fs layer
68f6bba3dc50: Pulling fs layer
6ea3f464ef73: Waiting
01e9bf86544b: Waiting
68f6bba3dc50: Waiting
e72e6208e893: Waiting
57ffbe87baa1: Verifying Checksum
57ffbe87baa1: Download complete
f3b4a5f15c7a: Verifying Checksum
f3b4a5f15c7a: Download complete
01bf7da0a88c: Verifying Checksum
01bf7da0a88c: Download complete
e72e6208e893: Download complete
01e9bf86544b: Download complete
68f6bba3dc50: Verifying Checksum
68f6bba3dc50: Download complete
01bf7da0a88c: Pull complete
6ea3f464ef73: Verifying Checksum
6ea3f464ef73: Download complete
f3b4a5f15c7a: Pull complete
57ffbe87baa1: Pull complete
e72e6208e893: Pull complete
6ea3f464ef73: Pull complete
01e9bf86544b: Pull complete
68f6bba3dc50: Pull complete
Digest: sha25

In [6]:
%%bash
# Copy the local model directory (containing the numbered version folder)
# into the base container under /models/<model name>.
docker cp "${SERVING_MODELNAME}" "serving_base:/models/${SERVING_MODELNAME}"
# Commit the container as a new image, recording MODEL_NAME in its environment.
docker commit --change "ENV MODEL_NAME ${SERVING_MODELNAME}" serving_base "${IMAGE_URI}"

sha256:8d3310d6efe401cb35b065f2d44329b85b529ba050065fa5f0b1d1f00ab659ac


In [7]:
%%bash
# The base container is no longer needed once the image has been committed.
docker rm --force serving_base

serving_base


## Try serving locally

Get image to predict

In [8]:
# absl flags have to be parsed once (a dummy argv is enough here), otherwise
# load_mnist raises errors.
flags.FLAGS(['e'])

# Pull a single batch of size 1 from the MNIST test split; the label is unused.
test_batches = load_mnist(tfds.Split.TEST, batch_size=1)
image_to_predict, _ = next(iter(test_batches))

# Plain nested lists are what the JSON prediction requests below expect.
instances = image_to_predict.numpy().tolist()
image_json = json.dumps(instances)

INFO:absl:Load dataset info from /home/jupyter/tensorflow_datasets/mnist/3.0.1
INFO:absl:Field info.citation from disk and from code do not match. Keeping the one from code.
INFO:absl:Field info.splits from disk and from code do not match. Keeping the one from code.
INFO:absl:Field info.module_name from disk and from code do not match. Keeping the one from code.
INFO:absl:Reusing dataset mnist (/home/jupyter/tensorflow_datasets/mnist/3.0.1)
INFO:absl:Constructing tf.data.Dataset mnist for split test, from /home/jupyter/tensorflow_datasets/mnist/3.0.1


Start up container

In [9]:
%%bash
# Run the freshly built image in the background, publishing port 8501, which
# the prediction request below targets.
docker run --detach --publish 8501:8501 --env MODEL_NAME="${SERVING_MODELNAME}" \
    --name serving_jax "${IMAGE_URI}" --xla_cpu_compilation_enabled=true

86618ca15a6ddf2df432024af8ff5a2c85f3cfde782c4e00ed7e87a0f0026b9c


In [10]:
%%bash
# Wait briefly before dumping the container logs — presumably to give the
# server time to start loading the model (the delay is fixed, not a readiness check).
sleep 2
docker logs serving_jax

2021-06-26 23:00:36.960212: I tensorflow_serving/model_servers/server.cc:89] Building single TensorFlow model file config:  model_name: jax_model model_base_path: /models/jax_model
2021-06-26 23:00:36.960474: I tensorflow_serving/model_servers/server_core.cc:465] Adding/updating models.
2021-06-26 23:00:36.960496: I tensorflow_serving/model_servers/server_core.cc:591]  (Re-)adding model: jax_model
2021-06-26 23:00:37.060946: I tensorflow_serving/core/basic_manager.cc:740] Successfully reserved resources to load servable {name: jax_model version: 1}
2021-06-26 23:00:37.060992: I tensorflow_serving/core/loader_harness.cc:66] Approving load for servable version {name: jax_model version: 1}
2021-06-26 23:00:37.061009: I tensorflow_serving/core/loader_harness.cc:74] Loading servable version {name: jax_model version: 1}
2021-06-26 23:00:37.061067: I external/org_tensorflow/tensorflow/cc/saved_model/reader.cc:38] Reading SavedModel from: /models/jax_model/1
2021-06-26 23:00:37.062541: I exter

Send prediction

In [11]:
# Request body for the local predict route; `instances` is the list form of
# the test image prepared earlier.
data = json.dumps({"instances": instances})
json_response = requests.post(
    f'http://localhost:8501/v1/models/{SERVING_MODELNAME}:predict',
    data=data,
    timeout=30,  # do not hang forever if the container is not responding
)
# Surface HTTP errors directly instead of a KeyError on 'predictions' below.
json_response.raise_for_status()
predictions = json_response.json()['predictions']
print(predictions)

[[-5.36204624, -4.91872883, -4.52080679, -0.384578705, -12.0657101, -1.39947963, -5.91683674, -9.53136826, -3.0643661, -8.50711155]]


In [12]:
%%bash
# Stop and remove the local test container.
docker rm --force serving_jax

serving_jax


## Push image to Container Registry

In [13]:
%%bash
# Push the committed serving image to the registry path encoded in IMAGE_URI.
docker push "${IMAGE_URI}"

The push refers to repository [gcr.io/dsparing-sandbox/jax_tfserving_image]
5a3a97ca4978: Preparing
bb4423850a27: Preparing
b60ba33781cd: Preparing
547f89523b17: Preparing
bd91f28d5f3c: Preparing
8cafc6d2db45: Preparing
a5d4bacb0351: Preparing
5153e1acaabc: Preparing
a5d4bacb0351: Waiting
5153e1acaabc: Waiting
8cafc6d2db45: Waiting
bd91f28d5f3c: Layer already exists
547f89523b17: Layer already exists
b60ba33781cd: Layer already exists
bb4423850a27: Layer already exists
8cafc6d2db45: Layer already exists
a5d4bacb0351: Layer already exists
5153e1acaabc: Layer already exists
5a3a97ca4978: Pushed
latest-cpu: digest: sha256:5469c3038606bbcb54d7f2bb384b33bb03e3c3d33da32fd5c1a81864fd6222d9 size: 1991


## Upload prediction container to Vertex AI

In [14]:
REGION = "us-central1"
MACHINE_TYPE = "n1-standard-2"

# Display name and accelerator settings depend on the USE_GPU switch; the CPU
# configuration requests no accelerator at all.
MODEL_DISPLAYNAME, ACCELERATOR_TYPE, ACCELERATOR_COUNT = (
    (f"{SERVING_MODELNAME}-gpu", "NVIDIA_TESLA_T4", 1)
    if USE_GPU
    else (f"{SERVING_MODELNAME}-cpu", None, None)
)

In [15]:
# Pin the SDK to the project and region configured in this notebook, so the
# model is uploaded to the same project the image was pushed to instead of
# whatever the ambient defaults happen to be.
aiplatform.init(project=PROJECT_ID, location=REGION)

# Register the custom container as a Vertex AI Model. The predict route is the
# same REST path used for the local test, on the same port 8501; the container
# arg matches the one passed to the local `docker run`.
model = aiplatform.Model.upload(
    display_name=MODEL_DISPLAYNAME,
    serving_container_image_uri=IMAGE_URI,
    serving_container_predict_route=f"/v1/models/{SERVING_MODELNAME}:predict",
    serving_container_health_route=f"/v1/models/{SERVING_MODELNAME}",
    serving_container_args=['--xla_cpu_compilation_enabled=true'],
    serving_container_ports=[8501],
)

print(model.display_name)
print(model.resource_name)

INFO:google.cloud.aiplatform.models:Creating Model
INFO:google.cloud.aiplatform.models:Create Model backing LRO: projects/654544512569/locations/us-central1/models/501174992126345216/operations/4723293320594325504
INFO:google.cloud.aiplatform.models:Model created. Resource name: projects/654544512569/locations/us-central1/models/501174992126345216
INFO:google.cloud.aiplatform.models:To use this Model in another session:
INFO:google.cloud.aiplatform.models:model = aiplatform.Model('projects/654544512569/locations/us-central1/models/501174992126345216')
jax_model-cpu
projects/654544512569/locations/us-central1/models/501174992126345216


In [16]:
# Deploy the uploaded model; the accelerator settings are None for the CPU
# configuration.
deployment_config = {
    "machine_type": MACHINE_TYPE,
    "accelerator_type": ACCELERATOR_TYPE,
    "accelerator_count": ACCELERATOR_COUNT,
}
endpoint = model.deploy(**deployment_config)

INFO:google.cloud.aiplatform.models:Creating Endpoint
INFO:google.cloud.aiplatform.models:Create Endpoint backing LRO: projects/654544512569/locations/us-central1/endpoints/366814671212118016/operations/573226263972413440
INFO:google.cloud.aiplatform.models:Endpoint created. Resource name: projects/654544512569/locations/us-central1/endpoints/366814671212118016
INFO:google.cloud.aiplatform.models:To use this Endpoint in another session:
INFO:google.cloud.aiplatform.models:endpoint = aiplatform.Endpoint('projects/654544512569/locations/us-central1/endpoints/366814671212118016')
INFO:google.cloud.aiplatform.models:Deploying model to Endpoint : projects/654544512569/locations/us-central1/endpoints/366814671212118016
INFO:google.cloud.aiplatform.models:Deploy Endpoint model backing LRO: projects/654544512569/locations/us-central1/endpoints/366814671212118016/operations/3536594818782199808
INFO:google.cloud.aiplatform.models:Endpoint model deployed. Resource name: projects/654544512569/loca

## Get Online Predictions from Vertex AI

In [17]:
# Send the same test image used for the local check to the deployed endpoint.
prediction = endpoint.predict(instances=instances)
print(prediction.predictions)

[[-5.36204624, -4.91872883, -4.52080679, -0.384578705, -12.0657101, -1.39947963, -5.91683674, -9.53136826, -3.0643661, -8.50711155]]
