# Predict with JAX/Flax-trained SavedModel on Vertex AI custom container with TF Serving

This notebook runs on a TF 2.5 [Vertex Notebook](https://cloud.google.com/vertex-ai/docs/general/notebooks).

Vertex AI Prediction supports two kinds of serving containers: pre-built and custom. Pre-built containers accept no additional customization such as args ("Do not specify any other subfields of containerSpec", [source](https://cloud.google.com/vertex-ai/docs/predictions/pre-built-containers#using_a_pre-built_container)). Because serving this model requires the `--xla_cpu_compilation_enabled` arg, a custom container is the only option.

### Create TF Serving container with SavedModel baked in

NOTE: Serving on GPU does not work yet.

In [2]:
import os

PROJECT_ID = !(gcloud config get-value core/project)
PROJECT_ID = PROJECT_ID[0]
SAVEDMODEL_DIR = "gs://dsparing-sandbox-bucket/models/jax_model"
SERVING_MODELNAME = "jax_model"
MODEL_VERSION = 1

USE_GPU = False
if USE_GPU:
    IMAGE_TAG = "latest-gpu"
    TFSERVING_TAG = "latest-gpu"
else:
    IMAGE_TAG = "latest-cpu"
    TFSERVING_TAG = "latest"

IMAGE_NAME = "jax_tfserving_image"
IMAGE_URI = f"gcr.io/{PROJECT_ID}/{IMAGE_NAME}:{IMAGE_TAG}"
print(IMAGE_URI)

os.environ["SERVING_MODELNAME"] = SERVING_MODELNAME
os.environ["SAVEDMODEL_DIR"] = SAVEDMODEL_DIR
os.environ["MODEL_VERSION"] = str(MODEL_VERSION)
os.environ["IMAGE_URI"] = IMAGE_URI
os.environ["IMAGE_TAG"] = IMAGE_TAG
os.environ["TFSERVING_TAG"] = TFSERVING_TAG


gcr.io/dsparing-sandbox/jax_tfserving_image:latest-cpu


Check that `SAVEDMODEL_DIR` actually contains a SavedModel:

In [3]:
%%bash
gsutil ls $SAVEDMODEL_DIR

gs://dsparing-sandbox-bucket/models/jax_model/
gs://dsparing-sandbox-bucket/models/jax_model/saved_model.pb
gs://dsparing-sandbox-bucket/models/jax_model/assets/
gs://dsparing-sandbox-bucket/models/jax_model/variables/


In [4]:
%%bash
# TF Serving expects each model version in its own numbered subdirectory.
mkdir -p $SERVING_MODELNAME/$MODEL_VERSION
gsutil -q cp -r $SAVEDMODEL_DIR/* $SERVING_MODELNAME/$MODEL_VERSION/
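
The numbered-directory layout can be checked before building the image. The following is a minimal sketch, not part of the original notebook: `find_model_versions` is a hypothetical helper, and it assumes each version directory holds a `saved_model.pb` as in the `gsutil ls` listing above. The demo uses a temporary directory rather than the real model.

```python
import tempfile
from pathlib import Path


def find_model_versions(model_base_path):
    """Return the sorted numeric version directories that contain a saved_model.pb."""
    versions = []
    for child in Path(model_base_path).iterdir():
        # Only directories with purely numeric names count as versions here.
        if child.is_dir() and child.name.isdigit():
            if (child / "saved_model.pb").is_file():
                versions.append(int(child.name))
    return sorted(versions)


# Demo on a throwaway directory that mimics jax_model/1/saved_model.pb.
with tempfile.TemporaryDirectory() as tmp:
    version_dir = Path(tmp) / "jax_model" / "1"
    version_dir.mkdir(parents=True)
    (version_dir / "saved_model.pb").touch()
    print(find_model_versions(Path(tmp) / "jax_model"))
```

In the notebook, calling `find_model_versions(SERVING_MODELNAME)` after the copy should list `MODEL_VERSION` if the copy succeeded.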

The next cells follow the [TF Serving tutorial](https://www.tensorflow.org/tfx/serving/docker#creating_your_own_serving_image) to bake the model into a serving image:

In [5]:
%%bash
docker run -d --name serving_base tensorflow/serving:$TFSERVING_TAG

71ec89fa52eb26147e52cd6781d2d4d35dd72c698fffd7792d6d62a38ec9827d


In [6]:
%%bash
docker cp $SERVING_MODELNAME serving_base:/models/$SERVING_MODELNAME
docker commit --change "ENV MODEL_NAME $SERVING_MODELNAME" serving_base $IMAGE_URI

sha256:a6c04272045bf4441eaf54c1aa420c1977cfa731e1282b1f3b39ff2d4ca169db


In [7]:
%%bash
docker rm -f serving_base

serving_base


### Optional: Try serving locally

To try the container locally before pushing it: get an image to predict, start up the container, and send a prediction request.
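
A local test could look roughly like the sketch below. It is not taken from the original notebook: the helper names (`local_docker_command`, `predict_url`, `predict_body`) are hypothetical, and it assumes that arguments placed after the image name are forwarded to the model server by the base image's entrypoint. The URL and the `{"instances": ...}` body follow the TF Serving REST predict format, with 8501 as the REST port.

```python
import json


def local_docker_command(image_uri, host_port=8501):
    """Build a `docker run` command that exposes the container's REST port."""
    return [
        "docker", "run", "-d",
        "-p", f"{host_port}:8501",
        image_uri,
        # Assumption: trailing args are passed through to the model server.
        "--xla_cpu_compilation_enabled=true",
    ]


def predict_url(model_name, host="localhost", port=8501):
    """REST predict route for the given model name."""
    return f"http://{host}:{port}/v1/models/{model_name}:predict"


def predict_body(instances):
    """JSON request body in the row ("instances") format."""
    return json.dumps({"instances": instances})


print(" ".join(local_docker_command("gcr.io/my-project/jax_tfserving_image:latest-cpu")))
print(predict_url("jax_model"))

# Sending the request needs a running container, so it is left commented out:
# import requests
# response = requests.post(predict_url("jax_model"), data=predict_body(instances))
# print(response.json()["predictions"])
```

The image URI in the demo is a placeholder; in the notebook, `IMAGE_URI` and `SERVING_MODELNAME` would be passed instead.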

### Push image to registry

In [8]:
%%bash
docker push $IMAGE_URI

The push refers to repository [gcr.io/dsparing-sandbox/jax_tfserving_image]
c607ec907455: Preparing
bb4423850a27: Preparing
b60ba33781cd: Preparing
547f89523b17: Preparing
bd91f28d5f3c: Preparing
8cafc6d2db45: Preparing
a5d4bacb0351: Preparing
5153e1acaabc: Preparing
8cafc6d2db45: Waiting
a5d4bacb0351: Waiting
5153e1acaabc: Waiting
bd91f28d5f3c: Layer already exists
b60ba33781cd: Layer already exists
bb4423850a27: Layer already exists
547f89523b17: Layer already exists
8cafc6d2db45: Layer already exists
a5d4bacb0351: Layer already exists
5153e1acaabc: Layer already exists
c607ec907455: Pushed
latest-cpu: digest: sha256:20d7802c6ea76f2660ca7456921d0e71867f46727aa42794e77d2c59c704e464 size: 1991


### Upload prediction container to Vertex AI

In [9]:
from google.cloud import aiplatform

REGION = "us-central1"

MACHINE_TYPE = "n1-standard-2"

if USE_GPU:
    MODEL_DISPLAYNAME = f"{SERVING_MODELNAME}-gpu"
    ACCELERATOR_TYPE = "NVIDIA_TESLA_T4"
    ACCELERATOR_COUNT = 1
else:
    MODEL_DISPLAYNAME = f"{SERVING_MODELNAME}-cpu"
    ACCELERATOR_TYPE = None
    ACCELERATOR_COUNT = None

In [10]:
aiplatform.init(project=PROJECT_ID, location=REGION)

In [11]:
model = aiplatform.Model.upload(
    display_name=MODEL_DISPLAYNAME,
    serving_container_image_uri=IMAGE_URI,
    # TF Serving REST routes for prediction and model status.
    serving_container_predict_route=f"/v1/models/{SERVING_MODELNAME}:predict",
    serving_container_health_route=f"/v1/models/{SERVING_MODELNAME}",
    serving_container_args=['--xla_cpu_compilation_enabled=true', '--runtime=nvidia'],
    # 8501 is the REST API port of the TF Serving image.
    serving_container_ports=[8501],
)

print(model.display_name)
print(model.resource_name)

INFO:google.cloud.aiplatform.models:Creating Model
INFO:google.cloud.aiplatform.models:Create Model backing LRO: projects/654544512569/locations/us-central1/models/8700118851242688512/operations/3500014066926092288
INFO:google.cloud.aiplatform.models:Model created. Resource name: projects/654544512569/locations/us-central1/models/8700118851242688512
INFO:google.cloud.aiplatform.models:To use this Model in another session:
INFO:google.cloud.aiplatform.models:model = aiplatform.Model('projects/654544512569/locations/us-central1/models/8700118851242688512')
jax_model-cpu
projects/654544512569/locations/us-central1/models/8700118851242688512


In [12]:
endpoint = model.deploy(
    machine_type=MACHINE_TYPE,
    accelerator_type=ACCELERATOR_TYPE,
    accelerator_count=ACCELERATOR_COUNT,
)

INFO:google.cloud.aiplatform.models:Creating Endpoint
INFO:google.cloud.aiplatform.models:Create Endpoint backing LRO: projects/654544512569/locations/us-central1/endpoints/7109266263339171840/operations/1215000205988986880
INFO:google.cloud.aiplatform.models:Endpoint created. Resource name: projects/654544512569/locations/us-central1/endpoints/7109266263339171840
INFO:google.cloud.aiplatform.models:To use this Endpoint in another session:
INFO:google.cloud.aiplatform.models:endpoint = aiplatform.Endpoint('projects/654544512569/locations/us-central1/endpoints/7109266263339171840')
INFO:google.cloud.aiplatform.models:Deploying model to Endpoint : projects/654544512569/locations/us-central1/endpoints/7109266263339171840
INFO:google.cloud.aiplatform.models:Deploy Endpoint model backing LRO: projects/654544512569/locations/us-central1/endpoints/7109266263339171840/operations/5826686224416374784
INFO:google.cloud.aiplatform.models:Endpoint model deployed. Resource name: projects/65454451256

### Get predictions

In [13]:
import json
import requests
import tensorflow_datasets as tfds

from absl import flags
from jax.experimental.jax2tf.examples.mnist_lib import load_mnist

In [14]:
flags.FLAGS(['e'])  # Parse a dummy argv so absl flags are initialized; load_mnist raises errors otherwise.

image_to_predict, _ = next(iter(load_mnist(tfds.Split.TEST, 1)))
print(image_to_predict.shape)
image_json = json.dumps(image_to_predict.numpy().tolist())

INFO:absl:Load dataset info from /home/jupyter/tensorflow_datasets/mnist/3.0.1
INFO:absl:Field info.citation from disk and from code do not match. Keeping the one from code.
INFO:absl:Field info.splits from disk and from code do not match. Keeping the one from code.
INFO:absl:Field info.module_name from disk and from code do not match. Keeping the one from code.
INFO:absl:Reusing dataset mnist (/home/jupyter/tensorflow_datasets/mnist/3.0.1)
INFO:absl:Constructing tf.data.Dataset mnist for split test, from /home/jupyter/tensorflow_datasets/mnist/3.0.1
(1, 28, 28, 1)


In [15]:
instances = image_to_predict.numpy().tolist()
prediction = endpoint.predict(instances)
print(prediction.predictions)

[[-10.2773552, -8.94001102, -8.18527, -6.2940197, -12.1361103, -4.93151569, -12.9197159, -11.7381096, -0.00995016098, -7.88032]]
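
The endpoint returns one score per digit class. The sketch below is an illustration added here, not part of the original notebook: it picks the highest-scoring class from the output above, using a hypothetical `predicted_class` helper. It treats the scores as log-probabilities, which is an assumption based on the fact that their exponentials sum to roughly 1.

```python
import math

# Scores copied from the endpoint output above.
predictions = [[-10.2773552, -8.94001102, -8.18527, -6.2940197, -12.1361103,
                -4.93151569, -12.9197159, -11.7381096, -0.00995016098, -7.88032]]


def predicted_class(scores):
    """Index of the highest score, i.e. the predicted digit."""
    return max(range(len(scores)), key=lambda i: scores[i])


for scores in predictions:
    digit = predicted_class(scores)
    # Only meaningful if the scores really are log-probabilities (assumption).
    probability = math.exp(scores[digit])
    print(f"predicted digit: {digit}, probability: {probability:.3f}")
```

For this image the highest score is at index 8, so the model predicts the digit 8.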
