In [1]:
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Upload JAX/Flax-trained SavedModel to Vertex AI with TF Serving custom prediction container

Vertex AI Prediction supports two kinds of serving containers:
- pre-built containers, which cannot be customized, for example with extra args ("Do not specify any other subfields of containerSpec" [source](https://cloud.google.com/vertex-ai/docs/predictions/pre-built-containers#using_a_pre-built_container));
- custom containers.

Because we need to pass the `--xla_cpu_compilation_enabled` arg, we have to use a custom container. However, we can simply use the stock `tensorflow/serving` image from Docker Hub as our custom container.

In [2]:
import json
import os

import requests
import tensorflow_datasets as tfds
from absl import flags
from google.cloud import aiplatform
from jax.experimental.jax2tf.examples import mnist_lib

Below, `MODEL_BASE_PATH/model/MODEL_NAME/MODEL_VERSION` should point to a model created in [training-prebuilt.ipynb](training-prebuilt.ipynb). To serve a different model, change these variables so that the path points to any directory containing a SavedModel.

In [3]:
PROJECT_ID = !(gcloud config get-value core/project)
PROJECT_ID = PROJECT_ID[0]

REGION = "us-central1"
os.environ['REGION'] = REGION

BUCKET_NAME = PROJECT_ID
os.environ['BUCKET_NAME'] = BUCKET_NAME
# Use a regional bucket in the above region you have rights to.
# Create if needed:
# !gsutil mb -l ${REGION} gs://${BUCKET_NAME}

MODEL_BASE_PATH = f"gs://{BUCKET_NAME}/model"
MODEL_NAME = "jax_model_prebuilt"
MODEL_VERSION = 1

USE_GPU_SERVING = False
if USE_GPU_SERVING:
    TFSERVING_TAG = "latest-gpu"
else:
    TFSERVING_TAG = "latest"

IMAGE_TAG = 'latest'
SERVING_IMAGE_NAME = 'jax_vertex_serving'
SERVING_IMAGE_URI = f"gcr.io/{PROJECT_ID}/{SERVING_IMAGE_NAME}:{IMAGE_TAG}"
os.environ['SERVING_IMAGE_URI'] = SERVING_IMAGE_URI

os.environ["MODEL_BASE_PATH"] = MODEL_BASE_PATH
os.environ["MODEL_NAME"] = MODEL_NAME
os.environ["MODEL_VERSION"] = str(MODEL_VERSION)
os.environ["TFSERVING_TAG"] = TFSERVING_TAG

Check that `MODEL_BASE_PATH/model/MODEL_NAME/MODEL_VERSION` actually contains a SavedModel:

In [4]:
!gsutil ls $MODEL_BASE_PATH/model/$MODEL_NAME/$MODEL_VERSION

gs://dsparing-sandbox/savedmodels/jax_model_prebuilt/model/model/1/
gs://dsparing-sandbox/savedmodels/jax_model_prebuilt/model/model/1/saved_model.pb
gs://dsparing-sandbox/savedmodels/jax_model_prebuilt/model/model/1/assets/
gs://dsparing-sandbox/savedmodels/jax_model_prebuilt/model/model/1/variables/


## Create TFServing container

NOTE: serving on GPU does not work yet; the example below uses CPU.

In [5]:
# Local Docker build context for the serving image; exported for the shell cells below.
SERVING_FOLDER = "serving"
os.environ.update(SERVING_FOLDER=SERVING_FOLDER)

In [6]:
%%bash
# Write a minimal Dockerfile that just re-uses the stock TF Serving image.
# The heredoc delimiter is unquoted, so $TFSERVING_TAG is expanded by the shell.
mkdir -p "$SERVING_FOLDER"
cat <<EOF > "$SERVING_FOLDER/Dockerfile"
FROM tensorflow/serving:$TFSERVING_TAG

EOF

In [7]:
!cd $SERVING_FOLDER && \
    gcloud builds submit --tag $SERVING_IMAGE_URI

Creating temporary tarball archive of 1 file(s) totalling 32 bytes before compression.
Uploading tarball of [.] to [gs://dsparing-sandbox_cloudbuild/source/1624944806.096853-0e413b57b65a40618358358be0abdea7.tgz]
Created [https://cloudbuild.googleapis.com/v1/projects/dsparing-sandbox/locations/global/builds/f7a90bce-6d5e-4424-a53a-77a69d623958].
Logs are available at [https://console.cloud.google.com/cloud-build/builds/f7a90bce-6d5e-4424-a53a-77a69d623958?project=654544512569].
----------------------------- REMOTE BUILD OUTPUT ------------------------------
starting build "f7a90bce-6d5e-4424-a53a-77a69d623958"

FETCHSOURCE
Fetching storage object: gs://dsparing-sandbox_cloudbuild/source/1624944806.096853-0e413b57b65a40618358358be0abdea7.tgz#1624944806297629
Copying gs://dsparing-sandbox_cloudbuild/source/1624944806.096853-0e413b57b65a40618358358be0abdea7.tgz#1624944806297629...
/ [1 files][  199.0 B/  199.0 B]                                                
Operation completed over 1 ob

## Try serving locally

Get an image from the MNIST test split to predict.

In [8]:
# mnist_lib reads absl flags; parse an argv containing only a program name so
# that load_mnist does not fail on unparsed flags.
flags.FLAGS([''])

# Take the first single-image batch of the MNIST test split.
test_batches = mnist_lib.load_mnist(tfds.Split.TEST, batch_size=1)
image_to_predict, _ = next(iter(test_batches))

instances = image_to_predict.numpy().tolist()
image_json = json.dumps(instances)

INFO:absl:Load dataset info from /home/jupyter/tensorflow_datasets/mnist/3.0.1
INFO:absl:Field info.citation from disk and from code do not match. Keeping the one from code.
INFO:absl:Field info.splits from disk and from code do not match. Keeping the one from code.
INFO:absl:Field info.module_name from disk and from code do not match. Keeping the one from code.
INFO:absl:Reusing dataset mnist (/home/jupyter/tensorflow_datasets/mnist/3.0.1)
INFO:absl:Constructing tf.data.Dataset mnist for split test, from /home/jupyter/tensorflow_datasets/mnist/3.0.1


Start the container locally. (On GPU it would need `--runtime=nvidia`, but GPU prediction does not work yet.)

In [9]:
%%bash
# Remove a container left over from an earlier run so that `--name` does not clash.
docker rm -f serving_jax >/dev/null 2>&1 || true
docker pull "$SERVING_IMAGE_URI"
# The image's entrypoint serves ${MODEL_BASE_PATH}/${MODEL_NAME}, so pass the
# "model" folder checked above: the server then loads versions from
# $MODEL_BASE_PATH/model/$MODEL_NAME/<version>.
# Extra arguments after the image name are appended to tensorflow_model_server's flags.
docker run -d -p 8501:8501 \
    --name serving_jax \
    --env MODEL_NAME="$MODEL_NAME" \
    --env MODEL_BASE_PATH="$MODEL_BASE_PATH/model" \
    "$SERVING_IMAGE_URI" \
    --xla_cpu_compilation_enabled=true

latest: Pulling from dsparing-sandbox/jax_vertex_serving
Digest: sha256:6651f4839e1124dbde75ee531825112af0a6b8ef082c88ab14ca53eb69a2e4bb
Status: Image is up to date for gcr.io/dsparing-sandbox/jax_vertex_serving:latest
gcr.io/dsparing-sandbox/jax_vertex_serving:latest
c25aa20a7c112c9df780d8ca9cfcf9cf0875e058f7225c438fc6233a4bf2b2ba


In [10]:
%%bash
# Give TF Serving time to load the SavedModel from Cloud Storage, then show its log.
sleep 20 && docker logs serving_jax

2021-06-29 05:33:52.199259: I tensorflow_serving/model_servers/server.cc:89] Building single TensorFlow model file config:  model_name: model model_base_path: gs://dsparing-sandbox/savedmodels/jax_model_prebuilt/model/model
2021-06-29 05:33:52.199824: I tensorflow_serving/model_servers/server_core.cc:465] Adding/updating models.
2021-06-29 05:33:52.199850: I tensorflow_serving/model_servers/server_core.cc:591]  (Re-)adding model: model
2021-06-29 05:33:53.773127: I tensorflow_serving/core/basic_manager.cc:740] Successfully reserved resources to load servable {name: model version: 1}
2021-06-29 05:33:53.773172: I tensorflow_serving/core/loader_harness.cc:66] Approving load for servable version {name: model version: 1}
2021-06-29 05:33:53.773205: I tensorflow_serving/core/loader_harness.cc:74] Loading servable version {name: model version: 1}
2021-06-29 05:33:53.917876: I external/org_tensorflow/tensorflow/cc/saved_model/reader.cc:38] Reading SavedModel from: gs://dsparing-sandbox/savedm

The `sleep` gives TF Serving some time to load the model from Cloud Storage and become ready to accept requests. Verify in the log above that it is indeed ready (look for a line like `"Entering the event loop ..."`).

Send a prediction request to the local server.

In [11]:
# TF Serving's REST predict API takes {"instances": [...]} and answers {"predictions": [...]}.
data = json.dumps({"instances": image_to_predict.numpy().tolist()})
json_response = requests.post(
    f"http://localhost:8501/v1/models/{MODEL_NAME}:predict",
    data=data,
    # Bound the wait so a hung server does not block the notebook forever.
    timeout=60,
)
# Surface HTTP errors (e.g. model not loaded yet) with their status code instead of
# failing later with an opaque KeyError on 'predictions'.
json_response.raise_for_status()
predictions = json.loads(json_response.text)['predictions']
print(predictions)

[[-10.0790615, -0.0246343613, -5.78542423, -6.52261972, -9.82075691, -8.41873264, -7.13855028, -8.35187149, -3.99987411, -8.9720192]]


In [12]:
!docker rm -f serving_jax

serving_jax


## Upload model to Vertex AI using custom container

In [22]:
# Register the SavedModel with Vertex AI, served by the TF Serving custom container.
#
# Vertex AI copies the contents of `artifact_uri` to a Cloud Storage directory
# exposed to the container as AIP_STORAGE_URI, and documents $(AIP_STORAGE_URI)
# expansion in the container's command/args. We therefore call
# tensorflow_model_server directly, bypassing the image's entrypoint script, which
# builds the model path from MODEL_BASE_PATH/MODEL_NAME environment variables.
# This lets us pass the model location and the XLA flag explicitly.
model = aiplatform.Model.upload(
    display_name=MODEL_NAME,
    project=PROJECT_ID,
    location=REGION,
    serving_container_image_uri=SERVING_IMAGE_URI,
    # Versioned model directory checked earlier: <artifact_uri>/<MODEL_VERSION>/saved_model.pb.
    artifact_uri=f"{MODEL_BASE_PATH}/model/{MODEL_NAME}",
    # NOTE(review): binary path inside the tensorflow/serving image — confirm for other base images.
    serving_container_command=["/usr/bin/tensorflow_model_server"],
    serving_container_args=[
        "--rest_api_port=8501",
        f"--model_name={MODEL_NAME}",
        "--model_base_path=$(AIP_STORAGE_URI)",
        "--xla_cpu_compilation_enabled=true",
    ],
    serving_container_predict_route=f"/v1/models/{MODEL_NAME}:predict",
    serving_container_health_route=f"/v1/models/{MODEL_NAME}",
    # Must match --rest_api_port above.
    serving_container_ports=[8501],
)

print(model.display_name)
print(model.resource_name)

INFO:google.cloud.aiplatform.models:Creating Model
INFO:google.cloud.aiplatform.models:Create Model backing LRO: projects/654544512569/locations/us-central1/models/996570951137099776/operations/4201544266889035776
INFO:google.cloud.aiplatform.models:Model created. Resource name: projects/654544512569/locations/us-central1/models/996570951137099776
INFO:google.cloud.aiplatform.models:To use this Model in another session:
INFO:google.cloud.aiplatform.models:model = aiplatform.Model('projects/654544512569/locations/us-central1/models/996570951137099776')
jax_model_prebuilt
projects/654544512569/locations/us-central1/models/996570951137099776


Note the model resource name (a unique identifier for your model); you will need it for prediction later.