In [1]:
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Batch prediction

Batch prediction is commonly used when you need thousands to millions of predictions at once. In this notebook we create a Vertex AI batch prediction job: we upload the prediction requests as a JSONL file (one JSON record per line) to Google Cloud Storage (GCS), and then use the Python API to start the job.

In [2]:
import json
import os

import tensorflow as tf
import tensorflow_datasets as tfds
from absl import flags
from google.cloud import aiplatform
from jax.experimental.jax2tf.examples import mnist_lib

In [3]:
# Resolve the active gcloud project. The `!` shell capture yields a list-like
# of output lines; the following statement keeps only the first line.
PROJECT_ID = !(gcloud config get-value project)
PROJECT_ID = PROJECT_ID[0]

# Region referenced by the (commented-out) bucket creation command below.
REGION = "us-central1"

# The bucket is assumed to be named after the project.
BUCKET_NAME = PROJECT_ID
# Use a regional bucket in the above region you have rights to.
# Create if needed:
# !gsutil mb -l $REGION gs://$BUCKET_NAME

# Leave MODEL_RESOURCENAME as None to look the model up by MODEL_DISPLAYNAME;
# set it to a full model resource name to select that model directly.
MODEL_RESOURCENAME = None
MODEL_DISPLAYNAME = "jax_model_customcontainer"

# NOTE(review): not referenced elsewhere in the visible part of this notebook.
USE_GPU_SERVING = False

# Batch size used when loading MNIST test images, i.e. the number of records
# written to the batch prediction input file.
SERVING_BATCH_SIZE = 3

## Fetch model

On Vertex AI, only a model's "resource name" (which ends in a numerical ID) is unique; its "display name" is not. Therefore, while you can look up models by `display_name`, if several models share the same display name you need the `resource_name` of the one you want (it is returned and printed when you upload the model).

In [4]:
# Resolve the model to use for batch prediction.
# Specify either a unique MODEL_DISPLAYNAME or a MODEL_RESOURCENAME (set in
# the configuration cell); the resource name takes precedence when given.
if MODEL_RESOURCENAME:
    model = aiplatform.Model(MODEL_RESOURCENAME)
else:
    models = aiplatform.Model.list(filter=f"display_name={MODEL_DISPLAYNAME}")
    if not models:
        # Without this guard, models[0] below fails with a bare IndexError
        # that does not say which display name was searched for.
        raise LookupError(
            f"no model found with display_name=={MODEL_DISPLAYNAME}, "
            "please upload the model first, or use a resource_name"
        )
    if len(models) > 1:
        # List every candidate so the user can pick a resource name.
        # A separate loop variable keeps `model` from being bound to an
        # arbitrary candidate when the lookup is ambiguous.
        for candidate in models:
            print(candidate.resource_name, candidate.display_name)
        raise Exception(
            f"multiple models with display_name=={MODEL_DISPLAYNAME} "
            "(see above), please delete all but one, or use a resource_name"
        )
    model = models[0]

print(model.display_name, model.resource_name)

# Store both identifiers so later cells can use either one, regardless of
# which of the two was used for the lookup.
MODEL_DISPLAYNAME = model.display_name
MODEL_RESOURCENAME = model.resource_name

jax_model_customcontainer projects/654544512569/locations/us-central1/models/653734429503520768


In [5]:
# load_mnist reads absl flags, so they have to be parsed (here with an empty
# argv) before it is called, otherwise it errors out.
flags.FLAGS([""])

# Take only the first test batch; its labels are not needed for prediction.
test_batches = mnist_lib.load_mnist(tfds.Split.TEST, batch_size=SERVING_BATCH_SIZE)
images_to_predict, _ = next(iter(test_batches))

# One JSON record per line, each of the form {"inputs": <nested list>}.
with open("inputs.jsonl", "w") as jsonl_file:
    jsonl_file.writelines(
        json.dumps({"inputs": image_tensor.numpy().tolist()}) + "\n"
        for image_tensor in images_to_predict
    )

INFO:absl:Load dataset info from /home/jupyter/tensorflow_datasets/mnist/3.0.1
INFO:absl:Field info.citation from disk and from code do not match. Keeping the one from code.
INFO:absl:Field info.splits from disk and from code do not match. Keeping the one from code.
INFO:absl:Field info.module_name from disk and from code do not match. Keeping the one from code.
INFO:absl:Reusing dataset mnist (/home/jupyter/tensorflow_datasets/mnist/3.0.1)
INFO:absl:Constructing tf.data.Dataset mnist for split test, from /home/jupyter/tensorflow_datasets/mnist/3.0.1


In [6]:
!gsutil cp inputs.jsonl \
    gs://$BUCKET_NAME/batchpred/$MODEL_DISPLAYNAME/inputs.jsonl

Copying file://inputs.jsonl [Content-Type=application/octet-stream]...
/ [1 files][ 21.4 KiB/ 21.4 KiB]                                                
Operation completed over 1 objects/21.4 KiB.                                     


In [7]:
# Machine type of the workers that serve the batch prediction job.
MACHINE_TYPE = "n1-standard-2"

# Launch a batch prediction job that reads the uploaded JSONL requests and
# writes its results to GCS.
batch_prediction_job = aiplatform.Model(MODEL_RESOURCENAME).batch_predict(
    job_display_name=f"{MODEL_DISPLAYNAME}_batchprediction",
    gcs_source=f"gs://{BUCKET_NAME}/batchpred/{MODEL_DISPLAYNAME}/inputs.jsonl",
    # Results must go under ".../outputs": that is the folder the later cell
    # lists (`gsutil ls .../outputs`) to find the newest prediction directory.
    gcs_destination_prefix=(
        f"gs://{BUCKET_NAME}/batchpred/{MODEL_DISPLAYNAME}/outputs"
    ),
    machine_type=MACHINE_TYPE,
)

# Block until the job has finished before inspecting its state and outputs.
batch_prediction_job.wait()

print(batch_prediction_job.display_name)
print(batch_prediction_job.resource_name)
print(batch_prediction_job.state)

INFO:google.cloud.aiplatform.jobs:Creating BatchPredictionJob
INFO:google.cloud.aiplatform.jobs:BatchPredictionJob created. Resource name: projects/654544512569/locations/us-central1/batchPredictionJobs/4116523155881721856
INFO:google.cloud.aiplatform.jobs:To use this BatchPredictionJob in another session:
INFO:google.cloud.aiplatform.jobs:bpj = aiplatform.BatchPredictionJob('projects/654544512569/locations/us-central1/batchPredictionJobs/4116523155881721856')
INFO:google.cloud.aiplatform.jobs:View Batch Prediction Job:
https://console.cloud.google.com/ai/platform/locations/us-central1/batch-predictions/4116523155881721856?project=654544512569
INFO:google.cloud.aiplatform.jobs:BatchPredictionJob projects/654544512569/locations/us-central1/batchPredictionJobs/4116523155881721856 current state:
JobState.JOB_STATE_RUNNING
INFO:google.cloud.aiplatform.jobs:BatchPredictionJob projects/654544512569/locations/us-central1/batchPredictionJobs/4116523155881721856 current state:
JobState.JOB_STAT

In [8]:
latest_prediction_dir = !gsutil ls gs://$BUCKET_NAME/batchpred/$MODEL_DISPLAYNAME/outputs | grep prediction-$MODEL_DISPLAYNAME- | tail -1
latest_prediction_dir = os.path.dirname(latest_prediction_dir[0])

If all goes well, the error file below is empty:

In [9]:
# empty if no errors
!gsutil cat $latest_prediction_dir/prediction.errors_stats-*

In [10]:
!gsutil ls $latest_prediction_dir/prediction.results-*

gs://dsparing-sandbox/batchpred/jax_model_customcontainer/outputs/prediction-jax_model_customcontainer-2021_07_01T08_40_52_145Z/prediction.results-00000-of-00001


There might be multiple output files. For now, we only inspect the first one:

In [11]:
!gsutil cat $latest_prediction_dir/prediction.results-00000-of-00001 | jq -c '.["prediction"]'

[1;39m[[0;39m-11.9759874[0m[1;39m,[0;39m-14.1695709[0m[1;39m,[0;39m-11.7008791[0m[1;39m,[0;39m-7.97484255[0m[1;39m,[0;39m-5.06094456[0m[1;39m,[0;39m-9.23107433[0m[1;39m,[0;39m-13.9212132[0m[1;39m,[0;39m-3.24065161[0m[1;39m,[0;39m-6.47828436[0m[1;39m,[0;39m-0.0486364365[0m[1;39m[1;39m][0m
[1;39m[[0;39m-10.7667494[0m[1;39m,[0;39m-16.7841396[0m[1;39m,[0;39m-12.4569149[0m[1;39m,[0;39m-5.81106091[0m[1;39m,[0;39m-12.0771446[0m[1;39m,[0;39m-8.92639[0m[1;39m,[0;39m-17.6782799[0m[1;39m,[0;39m-0.0128059387[0m[1;39m,[0;39m-9.55736[0m[1;39m,[0;39m-4.65686941[0m[1;39m[1;39m][0m
[1;39m[[0;39m-10.7817202[0m[1;39m,[0;39m-13.3755789[0m[1;39m,[0;39m-8.78866291[0m[1;39m,[0;39m-11.231391[0m[1;39m,[0;39m-0.0163898468[0m[1;39m,[0;39m-7.05897188[0m[1;39m,[0;39m-7.22624874[0m[1;39m,[0;39m-5.87213516[0m[1;39m,[0;39m-5.62200928[0m[1;39m,[0;39m-4.82241917[0m[1;39m[1;39m][0m
