In [1]:
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Deploy Vertex AI model for online prediction

In [2]:
import json

import tensorflow_datasets as tfds
from absl import flags
from google.cloud import aiplatform
from jax.experimental.jax2tf.examples import mnist_lib

## Fetch model

On Vertex AI, only a model's "resource name" (which ends in a numerical ID) is unique; its "display name" is not. Therefore, while you can look up a model by `display_name`, if several models share that name you need the `resource_name` of the one you want (it is returned and printed when you upload the model).

In [3]:
# Specify either MODEL_RESOURCENAME (always unique) or a MODEL_DISPLAYNAME
# that matches exactly one model. MODEL_RESOURCENAME wins when both are set.
MODEL_RESOURCENAME = None
MODEL_DISPLAYNAME = "jax_model_customcontainer"

# Read by the deployment cell to choose between CPU-only and GPU serving.
USE_GPU_SERVING = False

if MODEL_RESOURCENAME:
    model = aiplatform.Model(MODEL_RESOURCENAME)
else:
    # Quote the value so display names containing spaces or punctuation are
    # matched as a single string literal by the list filter.
    models = aiplatform.Model.list(filter=f'display_name="{MODEL_DISPLAYNAME}"')
    if not models:
        # Without this check, models[0] below fails with a bare IndexError.
        raise ValueError(
            f"no model found with display_name=={MODEL_DISPLAYNAME}; "
            "check the name or set MODEL_RESOURCENAME"
        )
    if len(models) > 1:
        # Display names are not unique: list the candidates so the user can
        # copy the right resource name into MODEL_RESOURCENAME.
        for candidate in models:
            print(candidate.resource_name, candidate.display_name)
        raise ValueError(
            f"multiple models with display_name=={MODEL_DISPLAYNAME} "
            "(see above), please use a resource_name"
        )
    model = models[0]

print(model.display_name, model.resource_name)

jax_model_prebuilt projects/654544512569/locations/us-central1/models/996570951137099776


## Deploy model for Online Prediction on Vertex AI

In [4]:
# Serving hardware: an n1-standard-2 machine, plus a single NVIDIA T4 GPU
# only when USE_GPU_SERVING is enabled (None means "no accelerator").
MACHINE_TYPE = "n1-standard-2"

ACCELERATOR_TYPE, ACCELERATOR_COUNT = (
    ("NVIDIA_TESLA_T4", 1) if USE_GPU_SERVING else (None, None)
)

# Creates a new endpoint and deploys the model to it (see the log output).
endpoint = model.deploy(
    machine_type=MACHINE_TYPE,
    accelerator_type=ACCELERATOR_TYPE,
    accelerator_count=ACCELERATOR_COUNT,
)

INFO:google.cloud.aiplatform.models:Creating Endpoint
INFO:google.cloud.aiplatform.models:Create Endpoint backing LRO: projects/654544512569/locations/us-central1/endpoints/4429061535100305408/operations/2436977637889933312
INFO:google.cloud.aiplatform.models:Endpoint created. Resource name: projects/654544512569/locations/us-central1/endpoints/4429061535100305408
INFO:google.cloud.aiplatform.models:To use this Endpoint in another session:
INFO:google.cloud.aiplatform.models:endpoint = aiplatform.Endpoint('projects/654544512569/locations/us-central1/endpoints/4429061535100305408')
INFO:google.cloud.aiplatform.models:Deploying model to Endpoint : projects/654544512569/locations/us-central1/endpoints/4429061535100305408
INFO:google.cloud.aiplatform.models:Deploy Endpoint model backing LRO: projects/654544512569/locations/us-central1/endpoints/4429061535100305408/operations/7201786043647918080
INFO:google.cloud.aiplatform.models:Endpoint model deployed. Resource name: projects/65454451256

## Get Online Predictions from Vertex AI

In [5]:
# mnist_lib presumably reads absl flags inside load_mnist; parsing an empty
# argv marks the flags as parsed so that access does not error in a notebook.
# NOTE(review): confirm which flag load_mnist depends on.
flags.FLAGS([''])

# Pull the first batch (a single image, batch_size=1) from the MNIST test split.
test_batches = mnist_lib.load_mnist(tfds.Split.TEST, batch_size=1)
image_to_predict, _ = next(iter(test_batches))

# Vertex AI expects plain JSON-serialisable instances, so convert the tensor
# to nested Python lists; image_json is the same payload as a JSON string.
instances = image_to_predict.numpy().tolist()
image_json = json.dumps(instances)

INFO:absl:Load dataset info from /home/jupyter/tensorflow_datasets/mnist/3.0.1
INFO:absl:Field info.citation from disk and from code do not match. Keeping the one from code.
INFO:absl:Field info.splits from disk and from code do not match. Keeping the one from code.
INFO:absl:Field info.module_name from disk and from code do not match. Keeping the one from code.
INFO:absl:Reusing dataset mnist (/home/jupyter/tensorflow_datasets/mnist/3.0.1)
INFO:absl:Constructing tf.data.Dataset mnist for split test, from /home/jupyter/tensorflow_datasets/mnist/3.0.1


In [6]:
# Send the instances to the deployed endpoint. `predictions` holds one list of
# 10 scores per instance (values look like log-probabilities — TODO confirm).
prediction = endpoint.predict(instances=instances)
print(prediction.predictions)

[[-0.0142641068, -22.4597549, -10.4813251, -11.0796452, -14.5759716, -4.32176, -7.92723083, -13.9459896, -7.69770765, -10.5191851]]


## Cleanup

In [7]:
# Tear down serving resources: first remove every model deployed to the
# endpoint, then delete the endpoint itself. The model resource is not deleted.
endpoint.undeploy_all()
endpoint.delete()

INFO:google.cloud.aiplatform.models:Undeploying Endpoint model: projects/654544512569/locations/us-central1/endpoints/4429061535100305408
INFO:google.cloud.aiplatform.models:Undeploy Endpoint model backing LRO: projects/654544512569/locations/us-central1/endpoints/4429061535100305408/operations/1572286509434798080
INFO:google.cloud.aiplatform.models:Endpoint model undeployed. Resource name: projects/654544512569/locations/us-central1/endpoints/4429061535100305408
INFO:google.cloud.aiplatform.base:Deleting Endpoint : projects/654544512569/locations/us-central1/endpoints/4429061535100305408
INFO:google.cloud.aiplatform.base:Delete Endpoint  backing LRO: projects/654544512569/locations/us-central1/operations/948537961043984384
INFO:google.cloud.aiplatform.base:Endpoint deleted. . Resource name: projects/654544512569/locations/us-central1/endpoints/4429061535100305408
