In [1]:
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Upload JAX/Flax-trained SavedModel to Vertex AI with TF Serving custom prediction container

Vertex AI Prediction supports two kinds of serving containers: pre-built containers, which allow no additional customization such as args ("Do not specify any other subfields of containerSpec", [source](https://cloud.google.com/vertex-ai/docs/predictions/pre-built-containers#using_a_pre-built_container)), and custom containers. Because we need the `--xla_cpu_compilation_enabled` arg, we can only use a custom container. However, we can simply use `tensorflow/serving` from Docker Hub as our custom container.

In [2]:
import json
import os

import requests
import tensorflow_datasets as tfds
from absl import flags
from google.cloud import aiplatform
from jax.experimental.jax2tf.examples import mnist_lib

Below, `BASE_OUTPUT_DIR/model/MODEL_NAME/MODEL_VERSION` should point to a model created in [training-prebuilt.ipynb](training-prebuilt.ipynb). Change it to any path containing the correct SavedModel structure.

In [3]:
PROJECT_ID = !(gcloud config get-value project)
PROJECT_ID = PROJECT_ID[0]

REGION = "us-central1"

BUCKET_NAME = PROJECT_ID
# Use a regional bucket in the above region you have rights to.
# Create if needed:
# !gsutil mb -l $REGION gs://$BUCKET_NAME

BASE_OUTPUT_DIR = f"gs://{BUCKET_NAME}"
MODEL_NAME = "jax_model_prebuilt"
MODEL_VERSION = 1

SERVING_BATCH_SIZE = 3

Check that `BASE_OUTPUT_DIR/model/MODEL_NAME/MODEL_VERSION` actually contains a SavedModel.

In [4]:
!gsutil ls $BASE_OUTPUT_DIR/model/$MODEL_NAME/$MODEL_VERSION

gs://dsparing-sandbox/model/jax_model_prebuilt/1/
gs://dsparing-sandbox/model/jax_model_prebuilt/1/saved_model.pb
gs://dsparing-sandbox/model/jax_model_prebuilt/1/assets/
gs://dsparing-sandbox/model/jax_model_prebuilt/1/variables/


## Create TFServing container

We will simply use the default TensorFlow Serving container image. We still need to build this container ourselves, because the serving image is expected to be hosted in Container Registry or Artifact Registry; in effect, we just copy the image from Docker Hub.

NOTE: Serving does not work on GPU yet; the example below is CPU-only.

In [5]:
# Local build folder, image name, and the registry URI the image is stored under.
SERVING_FOLDER = "serving"
SERVING_IMAGE_NAME = "tensorflow-serving"
SERVING_IMAGE_URI = "gcr.io/{}/{}".format(PROJECT_ID, SERVING_IMAGE_NAME)

# Tag of the upstream tensorflow/serving image: GPU or CPU flavour.
USE_GPU_SERVING = False
TFSERVING_TAG = "latest-gpu" if USE_GPU_SERVING else "latest"

In [6]:
# Export both settings so the %%bash cell can read them as shell variables.
os.environ.update(
    SERVING_FOLDER=SERVING_FOLDER,
    TFSERVING_TAG=TFSERVING_TAG,
)

In [7]:
%%bash
# Create the build folder and write a one-line Dockerfile that re-exports the
# upstream tensorflow/serving image with the selected tag.
mkdir -p $SERVING_FOLDER
printf 'FROM tensorflow/serving:%s\n' "$TFSERVING_TAG" > $SERVING_FOLDER/Dockerfile

We only build the container if it is not already available.

In [8]:
!cd $SERVING_FOLDER && docker pull $SERVING_IMAGE_URI || gcloud builds submit --tag $SERVING_IMAGE_URI

Using default tag: latest
latest: Pulling from dsparing-sandbox/tensorflow-serving
Digest: sha256:6651f4839e1124dbde75ee531825112af0a6b8ef082c88ab14ca53eb69a2e4bb
Status: Image is up to date for gcr.io/dsparing-sandbox/tensorflow-serving:latest
gcr.io/dsparing-sandbox/tensorflow-serving:latest


## Try serving locally

Get a batch of images to predict

In [9]:
# absl flags have to be parsed once up front, otherwise load_mnist raises errors
flags.FLAGS([""])

# Only the first test batch is needed; its second element is discarded.
test_batches = mnist_lib.load_mnist(tfds.Split.TEST, batch_size=SERVING_BATCH_SIZE)
images_to_predict, _ = next(iter(test_batches))

# Convert to nested Python lists so the batch can be serialized as JSON.
instances = images_to_predict.numpy().tolist()
image_json = json.dumps(instances)

INFO:absl:Load dataset info from /home/jupyter/tensorflow_datasets/mnist/3.0.1
INFO:absl:Field info.citation from disk and from code do not match. Keeping the one from code.
INFO:absl:Field info.splits from disk and from code do not match. Keeping the one from code.
INFO:absl:Field info.module_name from disk and from code do not match. Keeping the one from code.
INFO:absl:Reusing dataset mnist (/home/jupyter/tensorflow_datasets/mnist/3.0.1)
INFO:absl:Constructing tf.data.Dataset mnist for split test, from /home/jupyter/tensorflow_datasets/mnist/3.0.1


Start up the container. (For GPU, it should use `--runtime=nvidia`, but for now GPU prediction does not work.)

In [10]:
!docker pull $SERVING_IMAGE_URI
!docker run -d -p 8501:8501 \
    --name serving_jax \
    $SERVING_IMAGE_URI \
    --xla_cpu_compilation_enabled=true \
    --model_name=$MODEL_NAME \
    --model_base_path=$BASE_OUTPUT_DIR/model/$MODEL_NAME

Using default tag: latest
latest: Pulling from dsparing-sandbox/tensorflow-serving
Digest: sha256:6651f4839e1124dbde75ee531825112af0a6b8ef082c88ab14ca53eb69a2e4bb
Status: Image is up to date for gcr.io/dsparing-sandbox/tensorflow-serving:latest
gcr.io/dsparing-sandbox/tensorflow-serving:latest
27abdc67652dbc1e146d4a919c65e1f38583d65c51ebbd2c44ca5f9a050415e6


In [11]:
# Wait 20 seconds, then print the log of the serving_jax container.
!sleep 20 && docker logs serving_jax

2021-07-01 15:05:17.078661: I tensorflow_serving/model_servers/server.cc:89] Building single TensorFlow model file config:  model_name: jax_model_prebuilt model_base_path: gs://dsparing-sandbox/model/jax_model_prebuilt
2021-07-01 15:05:17.078916: I tensorflow_serving/model_servers/server_core.cc:465] Adding/updating models.
2021-07-01 15:05:17.078941: I tensorflow_serving/model_servers/server_core.cc:591]  (Re-)adding model: jax_model_prebuilt
2021-07-01 15:05:18.515660: I tensorflow_serving/core/basic_manager.cc:740] Successfully reserved resources to load servable {name: jax_model_prebuilt version: 1}
2021-07-01 15:05:18.515699: I tensorflow_serving/core/loader_harness.cc:66] Approving load for servable version {name: jax_model_prebuilt version: 1}
2021-07-01 15:05:18.515741: I tensorflow_serving/core/loader_harness.cc:74] Loading servable version {name: jax_model_prebuilt version: 1}
2021-07-01 15:05:18.728470: I external/org_tensorflow/tensorflow/cc/saved_model/reader.cc:38] Readin

With the `sleep`, we give TF Serving some time to load the model from Cloud Storage and become ready to accept requests. Verify in the log above that it is indeed ready. (It will probably print something like `"Entering the event loop ..."`.)

Send prediction

In [12]:
# The request body carries the image batch under the "instances" key.
data = json.dumps({"instances": instances})
json_response = requests.post(
    f"http://localhost:8501/v1/models/{MODEL_NAME}:predict",
    data=data,
    # Fail instead of hanging forever if the container is not reachable.
    timeout=60,
)
# Surface HTTP errors (e.g. model not loaded yet) directly, rather than as an
# opaque KeyError on "predictions" below.
json_response.raise_for_status()
predictions = json_response.json()["predictions"]
print(predictions)

[[-9.59468746, -6.71643162, -9.19317722, -6.08347225, -9.86418915, -7.86642075, -13.7691307, -0.586410522, -3.49557447, -0.893427849], [-8.17506, -8.27988815, -5.86172485, -2.22976971, -11.3345842, -2.34807348, -9.04989243, -7.64320087, -0.234498978, -6.24677086], [-13.4298983, -25.2465477, -9.84352875, -18.5516396, -10.8125267, -11.114028, -0.000108718872, -25.0840569, -10.8588352, -16.1272583]]


In [13]:
!docker rm -f serving_jax

serving_jax


## Upload model to Vertex AI using custom container

In [14]:
# Point the SDK at the project and region configured at the top of the
# notebook, so that changing REGION actually changes where the model is
# uploaded instead of silently falling back to the SDK defaults.
aiplatform.init(project=PROJECT_ID, location=REGION)

model = aiplatform.Model.upload(
    display_name=MODEL_NAME,
    serving_container_image_uri=SERVING_IMAGE_URI,
    # Parent folder of the numbered version directories. Built with an
    # f-string because os.path.join is meant for local paths, not gs:// URIs.
    artifact_uri=f"{BASE_OUTPUT_DIR}/model/{MODEL_NAME}",
    serving_container_predict_route=f"/v1/models/{MODEL_NAME}:predict",
    serving_container_health_route=f"/v1/models/{MODEL_NAME}",
    serving_container_args=[
        "--xla_cpu_compilation_enabled=true",
        f"--model_name={MODEL_NAME}",
        # Passed through literally (no Python interpolation); the
        # $(AIP_STORAGE_URI) reference is resolved on the serving side.
        "--model_base_path=$(AIP_STORAGE_URI)",
    ],
    serving_container_ports=[8501],
)

print(model.display_name, model.resource_name)

INFO:google.cloud.aiplatform.models:Creating Model
INFO:google.cloud.aiplatform.models:Create Model backing LRO: projects/654544512569/locations/us-central1/models/2423649083060125696/operations/6377785641513517056
INFO:google.cloud.aiplatform.models:Model created. Resource name: projects/654544512569/locations/us-central1/models/2423649083060125696
INFO:google.cloud.aiplatform.models:To use this Model in another session:
INFO:google.cloud.aiplatform.models:model = aiplatform.Model('projects/654544512569/locations/us-central1/models/2423649083060125696')
jax_model_prebuilt projects/654544512569/locations/us-central1/models/2423649083060125696


You need to note the model resource name (as a unique identifier for your model) for prediction later.