# Serving Models via TF Serving
## Learning Objectives
1. Learn how to export Keras models in SavedModel format
2. Learn how to load and use a SavedModel
3. Learn how to customize signatures using TensorFlow
4. Learn how to deploy models to Vertex AI
5. Learn how to use a deployed model in online and batch prediction

In this lab, you will learn how to serve models after training. <br>

Serving machine learning models requires infrastructure. Vertex AI makes this simple by providing autoscaling services that reduce setup and maintenance effort.

To use Vertex AI, we will export a Keras model in the SavedModel format and deploy it to Vertex AI. Along the way, you will learn about serving signatures, how to customize them, and how to get predictions from a deployed model.

## Setup

In [None]:
import os
import warnings

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
warnings.filterwarnings("ignore")

# Set `PATH` to include the directory containing saved_model_cli
PATH = %env PATH
%env PATH=/home/jupyter/.local/bin:{PATH}

In [None]:
import base64
import json
import shutil
from datetime import datetime

import keras
import keras_hub
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests
import tensorflow as tf
from google.cloud import aiplatform
from oauth2client.client import GoogleCredentials

## Build and Train a Model
Model training is not a focus of this lab, so let's create a simple MobileNet-based model and use transfer learning to train it quickly.

In [None]:
PROJECT = !(gcloud config get-value core/project)
PROJECT = PROJECT[0]
BUCKET = PROJECT + "-flowers"
FILE_DIR = f"gs://{BUCKET}/data"

CLASSES = ["daisy", "dandelion", "roses", "sunflowers", "tulips"]

IMG_HEIGHT = 224
IMG_WIDTH = 224
IMG_CHANNELS = 3

BATCH_SIZE = 32

If you haven't run the [create_tfrecords_at_scale.ipynb](https://github.com/GoogleCloudPlatform/asl-ml-immersion/blob/master/notebooks/image_models/solutions/create_tfrecords_at_scale.ipynb) notebook, uncomment the cell below to create the bucket and copy the data from the `gs://asl-public` bucket.

In [None]:
# !gsutil mb gs://{BUCKET}
# !gsutil cp gs://asl-public/data/flowers/tfrecords/* {FILE_DIR}

This dataset contains images of flowers serialized as TFRecords. We use the `tf.data` API to read and parse the data.

In [None]:
!gsutil ls {FILE_DIR}

In [None]:
TRAIN_PATTERN = FILE_DIR + "/train*"
EVAL_PATTERN = FILE_DIR + "/eval*"


def parse_example(example):
    feature_description = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "label": tf.io.FixedLenFeature([], tf.int64),
    }
    example = tf.io.parse_single_example(example, feature_description)
    example["image"] = tf.io.decode_jpeg(example["image"], channels=3)
    example["image"] = tf.image.resize(
        example["image"], [IMG_HEIGHT, IMG_WIDTH]
    )
    return example["image"], example["label"]


train_ds = (
    tf.data.TFRecordDataset(tf.io.gfile.glob(TRAIN_PATTERN))
    .map(parse_example)
    .batch(BATCH_SIZE)
)
eval_ds = (
    tf.data.TFRecordDataset(tf.io.gfile.glob(EVAL_PATTERN))
    .map(parse_example)
    .batch(10)
)

In [None]:
backbone = keras_hub.models.Backbone.from_preset(
    "mobilenet_v3_large_100_imagenet_21k",
)

# Use the Keras 3 namespace throughout, since the keras_hub backbone is a Keras 3 object.
transfer_model = keras.Sequential(
    [
        keras.Input(
            shape=(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS), name="mobilenet_input"
        ),
        keras.layers.Rescaling(scale=1.0 / 255),
        backbone,
        keras.layers.GlobalMaxPooling2D(),
        keras.layers.Dropout(rate=0.2),
        keras.layers.Dense(
            len(CLASSES),
            activation="softmax",
            kernel_regularizer=keras.regularizers.l2(0.0001),
        ),
    ]
)

transfer_model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

In [None]:
transfer_model.fit(
    train_ds,
    epochs=5,
    validation_data=eval_ds,
)

## Export & load SavedModel
Now we have a trained model. In this section, we will export the model in SavedModel format.

Also, we will look at how to extend our model with an additional serving function. Additional serving functions allow you to provide preprocessing and/or postprocessing logic to a model's prediction.

### Export in SavedModel format
A SavedModel contains a complete model program, including the trained parameters (that is, `tf.Variable` objects) and the computation graph. We don't need the original model-building code to run a model exported as a SavedModel, which makes the format useful for sharing or for deploying with LiteRT, TensorFlow.js, or TensorFlow Serving.

You can save and load a model in the SavedModel format using the following APIs:

- Keras Model API: Keras supports SavedModel export via the [keras.Model.export](https://keras.io/api/models/model_saving_apis/export/) method with the `format="tf_saved_model"` option. It also supports the `"onnx"` format.
- Keras ExportArchive: the [ExportArchive](https://keras.io/api/models/model_saving_apis/export/#exportarchive-class) class provides finer-grained control, such as configuring multiple serving endpoints and their signatures.

First, let's export our trained model to a SavedModel with a Keras API.

In [None]:
shutil.rmtree("export", ignore_errors=True)
os.mkdir("export")
# For a normal model serialization you can use .save() with .keras extension
transfer_model.save("export/flowers_model.keras")

transfer_model.export(
    "export/saved_model", format="tf_saved_model", verbose=False
)

Let's take a look at the directory.

In [None]:
!ls export/saved_model

We can see multiple files in the directory.

- `saved_model.pb` is the SavedModel main file which contains the actual TensorFlow program, or model, and a set of named signatures, each identifying a function that accepts tensor inputs and produces tensor outputs.
- `keras_metadata.pb`, if present, holds Keras-specific metadata. It is written only by the legacy `tf.keras` (Keras 2) `Model.save()` in SavedModel format, so `model.export()` does not create it. Recent TensorFlow versions also write a `fingerprint.pb` file with hashes that identify the SavedModel.
- `variables` directory contains all the variables of the model.
- `assets` directory contains arbitrary files, called assets, that the SavedModel needs, for example a vocabulary file used to initialize a lookup table. Upon loading, the assets and the serialized functions that depend on them refer to the correct file paths inside the SavedModel directory.
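To check an export directory like this programmatically, here is a minimal standard-library sketch. The helper `describe_saved_model_dir` and its lists of required and optional entries are assumptions made for illustration (treating `saved_model.pb` and `variables/` as mandatory is a simplification), not part of the TensorFlow API.

```python
import os

# Entries a SavedModel export directory normally contains. Treating
# saved_model.pb and variables/ as required is a simplification for this sketch.
REQUIRED = ("saved_model.pb", "variables")
OPTIONAL = ("assets", "fingerprint.pb", "keras_metadata.pb")


def describe_saved_model_dir(path):
    """Map each known SavedModel entry name to whether it exists under `path`.

    Raises FileNotFoundError if a required entry is missing.
    """
    present = {
        name: os.path.exists(os.path.join(path, name))
        for name in REQUIRED + OPTIONAL
    }
    missing = [name for name in REQUIRED if not present[name]]
    if missing:
        raise FileNotFoundError(f"{path} is missing {missing}")
    return present
```

Calling `describe_saved_model_dir("export/saved_model")` after the export cell should report `saved_model.pb` and `variables` as present. It only checks names, not whether the files are valid.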


### Investigate a SavedModel with `saved_model_cli` command

If you installed TensorFlow through a pre-built TensorFlow binary, then the SavedModel CLI is already installed on your system at pathname `bin/saved_model_cli`.

The SavedModel CLI supports the following two commands on a SavedModel:

- `show`, which shows the computations available from a SavedModel.
- `run`, which runs a computation from a SavedModel.

Here let's investigate the SavedModel with `saved_model_cli show` command specifying the default tag (`serve`) and signature (`serving_default`).

In [None]:
!saved_model_cli show --dir export/saved_model --tag_set serve --signature_def serving_default

Now we can see the signature's description:

- The input is a batch (`-1` means any batch size) of 224x224 images with 3 channels, named `'mobilenet_input'`.
- The output is a batch of 5 float values (the probabilities of the 5 flower classes), named `'output_*'`.
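The `-1` in these shapes acts as a wildcard for the batch dimension. The hypothetical helper below (not a TensorFlow function) spells out that matching rule, and shows why a single unbatched `(224, 224, 3)` image does not fit the signature without a leading batch axis.

```python
# Shapes as reported by `saved_model_cli show` for serving_default;
# -1 marks a dimension (here, the batch size) that may vary per call.
EXPECTED_INPUT = (-1, 224, 224, 3)
EXPECTED_OUTPUT = (-1, 5)


def shape_matches(shape, expected):
    """Return True if `shape` fits `expected`, treating -1 as 'any size'."""
    if len(shape) != len(expected):
        return False  # rank must match exactly, so a missing batch axis fails
    return all(e == -1 or s == e for s, e in zip(shape, expected))


print(shape_matches((8, 224, 224, 3), EXPECTED_INPUT))  # batch of 8 images
print(shape_matches((224, 224, 3), EXPECTED_INPUT))     # single image, no batch axis
```

This is why the prediction loop in the next section wraps each image with `keras.ops.expand_dims(img, axis=0)` before calling the model.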

### Load and predict

Once exported as SavedModel, we can load and use the model in a program.

In a Python program, you can simply call `keras.models.load_model` and load the `.keras` file instead of the SavedModel.

Also, at prediction time (for example, behind a web API) we cannot always expect the model to receive preprocessed TFRecord data or batched `(224, 224, 3)` tensors like the ones we used during training.<br>
Let's say the model receives file paths to image data. Then we need to add preprocessing operations that turn the image paths into tensors before calling the model.

In [None]:
filenames = [
    "gs://asl-public/data/flowers/jpegs/10172567486_2748826a8b.jpg",
    "gs://asl-public/data/flowers/jpegs/10386503264_e05387e1f7_m.jpg",
    "gs://asl-public/data/flowers/jpegs/10391248763_1d16681106_n.jpg",
    "gs://asl-public/data/flowers/jpegs/10712722853_5632165b04.jpg",
    "gs://asl-public/data/flowers/jpegs/10778387133_9141024b10.jpg",
    "gs://asl-public/data/flowers/jpegs/112334842_3ecf7585dd.jpg",
]


def preprocess(img_bytes):
    img = tf.image.decode_jpeg(img_bytes, channels=IMG_CHANNELS)
    img = keras.ops.image.resize(img, [IMG_HEIGHT, IMG_WIDTH])
    return img


def read_from_jpegfile(filename):
    img = tf.io.read_file(filename)
    img = preprocess(img)
    return img


serving_model = keras.models.load_model("export/flowers_model.keras")
input_images = [read_from_jpegfile(f) for f in filenames]

f, ax = plt.subplots(1, 6, figsize=(15, 15))
for idx, img in enumerate(input_images):
    # Resizing typically yields float32 pixels in [0, 255]; cast to uint8 so
    # imshow displays them correctly instead of clipping floats to [0, 1].
    ax[idx].imshow(img.numpy().astype("uint8"))
    batch_image = keras.ops.expand_dims(img, axis=0)  # add the batch dimension
    batch_pred = serving_model.predict(batch_image, verbose=0)
    pred = batch_pred[0]
    pred_label_index = int(np.argmax(pred))
    pred_label = CLASSES[pred_label_index]
    prob = pred[pred_label_index]
    ax[idx].set_title(f"{pred_label} ({prob:.2f})")

### Defining Additional Serving Function

So do we need to write this boilerplate preprocessing code each time?<br>
No. By adding a custom serving signature, we can bake these preprocessing and postprocessing steps into the SavedModel itself.

Once exported as a SavedModel, the same graph can be executed in various environments, such as on edge devices, in C++ code, or in JavaScript programs.

Let's assume we deploy this model to a web server and provide prediction via a web API.

What kind of signature would be easiest for API clients to use? Instead of asking them to send tensors of image contents, we can simply ask for a GCS path to a JPEG file, for example. <br>
And instead of returning a tensor of 5 probabilities, we can send back easy-to-understand information extracted from the probabilities.

In [None]:
def postprocess(pred):
    top_prob = keras.ops.max(pred, axis=[1])
    pred_label_index = keras.ops.argmax(pred, axis=1)
    pred_label = keras.ops.take(
        keras.ops.convert_to_tensor(CLASSES), pred_label_index
    )

    # custom output
    return {
        "probability": top_prob,
        "flower_type_int": pred_label_index,
        "flower_type_str": pred_label,
    }


def predict_from_filename(filenames):

    # custom pre-process
    input_images = keras.ops.map(read_from_jpegfile, filenames)

    # model
    batch_pred = transfer_model(input_images)  # direct call (not model.predict()) so it can be traced into the graph

    # custom post-process
    processed = postprocess(batch_pred)
    return processed
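Since `postprocess` only takes a per-row maximum, an argmax, and a lookup into `CLASSES`, its logic can be checked without TensorFlow. Below is a minimal NumPy mirror, assuming predictions arrive as a `(batch, 5)` array of probabilities; `postprocess_np` is a hypothetical helper for testing and is not exported with the model.

```python
import numpy as np

CLASSES = ["daisy", "dandelion", "roses", "sunflowers", "tulips"]  # same order as training


def postprocess_np(pred):
    """NumPy version of `postprocess` for a (batch, num_classes) array."""
    pred = np.asarray(pred)
    idx = np.argmax(pred, axis=1)  # top-1 class index per row
    return {
        "probability": pred.max(axis=1),  # top-1 probability per row
        "flower_type_int": idx,
        "flower_type_str": np.asarray(CLASSES)[idx],  # index -> class name
    }
```

Feeding it the output of `transfer_model.predict(...)` on a few images should give the same labels as the exported `serving_default` signature.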

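Additionally, let's define another serving function that receives raw JPEG bytes. When clients call it over REST, they send the bytes Base64-encoded, and the serving runtime decodes them before they reach the function.<br>
This is useful when we want to send raw image data in online prediction.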

In [None]:
def predict_from_b64(img_bytes):

    # custom pre-process
    input_images = keras.ops.map(preprocess, img_bytes)

    # model
    batch_pred = transfer_model(input_images)

    # custom post-process
    processed = postprocess(batch_pred)
    return processed

Here, let's export these functions in the SavedModel format under different signature keys using `add_endpoint()`.
- `serving_default` is the default signature key of a SavedModel. We bind it to the `predict_from_filename` function.
- We add a new signature key, `predict_base64`, backed by the `predict_from_b64` function, which takes image byte strings.

In [None]:
export_archive = keras.export.ExportArchive()

export_archive.track(transfer_model)

export_archive.add_endpoint(
    name="serving_default",
    fn=predict_from_filename,
    # this function receives a 1-D batch of filename strings.
    input_signature=[keras.InputSpec(shape=(None,), dtype="string")],
)

export_archive.add_endpoint(
    name="predict_base64",
    fn=predict_from_b64,
    # this function receives a 1-D batch of image byte strings.
    input_signature=[keras.InputSpec(shape=(None,), dtype="string")],
)

export_archive.write_out("export/flowers_model_with_signature", verbose=False)

Let's take a look at the new SavedModel metadata.

In [None]:
!saved_model_cli show --dir export/flowers_model_with_signature --tag_set serve --signature_def serving_default

Now we can see that our new `serving_default` signature takes file paths and returns a dictionary with three keys (`flower_type_int`, `flower_type_str`, and `probability`).

Let's try to load and use this SavedModel.<br>
In Keras, you can load an exported SavedModel with the `TFSMLayer` layer. (In TensorFlow, you can use `tf.saved_model.load` to do the same.)

Notice that we no longer need to write any preprocessing or postprocessing code.

In [None]:
serving_fn = keras.layers.TFSMLayer(
    "export/flowers_model_with_signature", call_endpoint="serving_default"
)

pred = serving_fn(keras.ops.convert_to_tensor(filenames))

# print custom outputs
for k in pred.keys():
    print(f"{k:15}: {pred[k].numpy()}")

These outputs look more useful for API clients than a vector of probabilities.

Let's check the `predict_base64` signature as well.

In [None]:
!saved_model_cli show --dir export/flowers_model_with_signature --tag_set serve --signature_def predict_base64

## Vertex AI Prediction

Now our model is ready for deployment!

In this notebook, we deploy our model to the scalable Vertex AI service.
Vertex AI supports both Batch Prediction and Online Prediction. 

First, let's upload the SavedModel to Vertex AI.

### Upload model to Vertex AI Prediction service

In [None]:
TIMESTAMP = datetime.now().strftime("%Y%m%d%H%M%S")

REGION = "us-central1"
MODEL_DISPLAYNAME = f"flower_classifier-{TIMESTAMP}"

print(f"MODEL_DISPLAYNAME: {MODEL_DISPLAYNAME}")

# from https://cloud.google.com/vertex-ai/docs/predictions/pre-built-containers
SERVING_CONTAINER_IMAGE_URI = (
    "us-docker.pkg.dev/vertex-ai-restricted/prediction/tf_opt-gpu.2-17:latest"
)

First, we upload the SavedModel to a GCS bucket.

In [None]:
!gsutil cp -R export/flowers_model_with_signature gs://{BUCKET}/{MODEL_DISPLAYNAME}

We can use the Vertex AI Python SDK to upload models.

Here we specify `display_name`; `artifact_uri`, which is the GCS path of the SavedModel; and `serving_container_image_uri`, which is the container image our model runs on (a prebuilt container in this case, but you can use a custom container if needed).

For more details, refer to [the SDK documentation](https://cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform.Model#google_cloud_aiplatform_Model_upload).

In [None]:
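# Pin the SDK to this project and region so the model, jobs, and endpoint
# are created in REGION (used again later to build the REST URL).
aiplatform.init(project=PROJECT, location=REGION)

uploaded_model = aiplatform.Model.upload(
    display_name=MODEL_DISPLAYNAME,
    artifact_uri=f"gs://{BUCKET}/{MODEL_DISPLAYNAME}",
    serving_container_image_uri=SERVING_CONTAINER_IMAGE_URI,
)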

After uploading, you can check your model in the Cloud console under Vertex AI -> Models.

### Batch Prediction

In batch prediction, we pass a large dataset to the model and get predictions for all of it in one asynchronous job.<br>

#### Create a prediction file
The [Batch Prediction](https://cloud.google.com/vertex-ai/docs/predictions/batch-predictions) service accepts input in JSON Lines, TFRecord, CSV, or file-list format.

Here we create a simple dataset containing many image file paths in [JSON Lines](https://jsonlines.org/) format.

In [None]:
files = !gsutil ls -r gs://asl-public/data/flowers/jpegs/*.jpg
print(len(files))

In [None]:
JSON_FILE = "batch_prediction.jsonl"

with open(JSON_FILE, "w") as f:
    for file in files:
        f.write(json.dumps({"filenames": file}) + "\n")

!head {JSON_FILE}
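Before uploading, it can help to validate the file: every line must be a standalone JSON object, and here each object's key must be `filenames`, matching the argument name of `predict_from_filename` behind the `serving_default` signature. The helper below is a minimal standard-library sketch under those assumptions; `validate_jsonl` is not a Vertex AI API.

```python
import json


def validate_jsonl(lines, key="filenames"):
    """Check that each non-empty line is a JSON object whose `key` is a gs:// path.

    Returns the number of valid instances; raises ValueError (json.JSONDecodeError
    is a subclass) on the first bad line.
    """
    count = 0
    for lineno, line in enumerate(lines, start=1):
        if not line.strip():
            continue  # tolerate a trailing blank line
        record = json.loads(line)
        if not isinstance(record, dict) or not str(record.get(key, "")).startswith("gs://"):
            raise ValueError(f"line {lineno}: expected {{'{key}': 'gs://...'}}, got {line!r}")
        count += 1
    return count
```

For example, `with open(JSON_FILE) as fp: print(validate_jsonl(fp))` should print the same count as `len(files)` above.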

In [None]:
!gsutil cp {JSON_FILE} {FILE_DIR}

#### Send a Batch Prediction Job
Let's start a batch prediction job with the [`Model.batch_predict()`](https://cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform.Model#google_cloud_aiplatform_Model_batch_predict) method.

Note that we can specify machine type and accelerator as needed.<br>
This is very useful when we want to process a large amount of data in a limited time.

In [None]:
TIMESTAMP = datetime.now().strftime("%Y%m%d%H%M%S")

JOB_DISPLAY_NAME = "flower_classification_batch"
MACHINE_TYPE = "n1-standard-4"

batch_pred_job = uploaded_model.batch_predict(
    job_display_name=JOB_DISPLAY_NAME,
    gcs_source=f"{FILE_DIR}/{JSON_FILE}",
    gcs_destination_prefix=f"{FILE_DIR}/batch_prediction_result/{TIMESTAMP}",
    machine_type=MACHINE_TYPE,
    accelerator_type="NVIDIA_TESLA_T4",
    accelerator_count=1,
    sync=False,
)

**Note: this job takes around 20 minutes. Wait for it to finish, or move on to the Online Prediction section and come back to the next cell later. You can check the job status on the Vertex AI -> Batch Predictions page.**

In [None]:
if batch_pred_job.output_info:
    output_dir = batch_pred_job.output_info.gcs_output_directory
    results = !gsutil cat {output_dir}/prediction.results*
    for r in results[:5]:
        r = json.loads(r)
        print(f"filename       : {r['instance']['filenames']}")
        for k in r["prediction"].keys():
            print(f"{k:15}: {r['prediction'][k]}")
        print("*" * 30)
else:
    print("This job is still running.")
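Once the job has finished, the result lines can also be aggregated rather than just printed. This sketch assumes each line has the `instance`/`prediction` layout printed by the cell above; `summarize_results` and `count_labels` are illustrative helpers, not part of the Vertex AI SDK.

```python
import json
from collections import Counter


def summarize_results(lines):
    """Yield (filename, predicted label, probability) for each result line."""
    for line in lines:
        if not line.strip():
            continue  # skip blank lines
        record = json.loads(line)
        prediction = record["prediction"]
        yield (
            record["instance"]["filenames"],
            prediction["flower_type_str"],
            prediction["probability"],
        )


def count_labels(lines):
    """Count how many images were assigned to each flower class."""
    return Counter(label for _, label, _ in summarize_results(lines))
```

In the notebook, `count_labels(results)` gives a per-class tally for the lines fetched with `gsutil cat`.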

### Online Prediction

With Online Prediction, you create a dedicated endpoint for your model and use it as a web API.

Let's create an endpoint and deploy the model to it with the [`aiplatform.Model.deploy`](https://cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform.Model#google_cloud_aiplatform_Model_deploy) method. Here you can also specify the machine type and accelerators.

**The command below takes around 10 minutes.**

In [None]:
endpoint = uploaded_model.deploy(
    machine_type=MACHINE_TYPE,
    accelerator_type="NVIDIA_TESLA_T4",
    accelerator_count=1,
)

After deployment, we can simply call the endpoint and retrieve the results. <br>
You can check the endpoint details on the Vertex AI -> Endpoints page.

Here we stick with the [Python SDK](https://cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform.Endpoint#google_cloud_aiplatform_Endpoint_predict), but you can call the endpoint from any environment that can send authenticated HTTPS requests.

In [None]:
instances = [{"filenames": f} for f in filenames]

In [None]:
pred = endpoint.predict(instances=instances)

# print custom outputs
for p in pred.predictions:
    for k in p.keys():
        print(f"{k:15}: {p[k]}")
    print("*" * 30)
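Because the endpoint is an ordinary HTTPS API, the same call can be made without the SDK. Below is a standard-library sketch that builds (but does not send) a request to the REST `:predict` method, assuming you already have an access token (for example from `gcloud auth print-access-token`) and the numeric endpoint ID (`endpoint.name` in the SDK). `build_predict_request` is a hypothetical helper.

```python
import json
import urllib.request


def build_predict_request(region, project, endpoint_id, instances, token):
    """Build (but do not send) a POST to the Vertex AI REST `:predict` method."""
    url = (
        f"https://{region}-aiplatform.googleapis.com/v1/projects/{project}"
        f"/locations/{region}/endpoints/{endpoint_id}:predict"
    )
    body = json.dumps({"instances": instances}).encode("utf-8")
    return urllib.request.Request(
        url,
        data=body,
        headers={
            "Authorization": f"Bearer {token}",  # OAuth 2.0 access token
            "Content-Type": "application/json",
        },
        method="POST",
    )
```

To send it, pass the request to `urllib.request.urlopen()` and parse the JSON body; the predictions should appear under its `predictions` key, as with `pred.predictions` above.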

### Online Prediction with raw images

Next, let's call the `predict_base64` signature and pass raw image data. <br>
To submit raw image data via the API, Base64-encode it and wrap it in a JSON object with `b64` as the key, as follows:

```json
{ "b64": <base64 encoded string> }
```

In the request below, this object is the value of `img_bytes`, the input name of the `predict_base64` signature.

In [None]:
# Download a sample image from GCS.
!gsutil cp gs://asl-public/data/flowers/jpegs/10172567486_2748826a8b.jpg sample.jpg


def b64encode(filename):
    with open(filename, "rb") as ifp:
        img_bytes = ifp.read()
        return base64.b64encode(img_bytes).decode()


data = {
    "signature_name": "predict_base64",
    "instances": [{"img_bytes": {"b64": b64encode("./sample.jpg")}}],
}
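Because the serving runtime decodes the `b64` value back into raw bytes before calling `predict_from_b64`, a quick local sanity check is to round-trip the payload. The standard-library sketch below assumes the `img_bytes` key used above; `make_b64_instance` and `decode_b64_instance` are hypothetical helpers.

```python
import base64

JPEG_SOI = b"\xff\xd8"  # JPEG files begin with the start-of-image marker FF D8


def make_b64_instance(img_bytes, key="img_bytes"):
    """Wrap raw bytes in the {"b64": ...} object expected for binary inputs."""
    return {key: {"b64": base64.b64encode(img_bytes).decode("ascii")}}


def decode_b64_instance(instance, key="img_bytes"):
    """Undo make_b64_instance and return the original bytes."""
    return base64.b64decode(instance[key]["b64"])
```

For example, `decode_b64_instance(data["instances"][0])` should equal the bytes of `sample.jpg` and start with `JPEG_SOI`.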

Since the Python SDK's `endpoint.predict` supports only the `serving_default` signature, let's call the API directly with the general-purpose `requests` module and use the `rawPredict` API, which lets us specify other signatures.

To do so, we need an OAuth 2.0 access token, passed in the request's `Authorization` header.

In [None]:
token = (
    GoogleCredentials.get_application_default().get_access_token().access_token
)
headers = {"Authorization": "Bearer " + token}

The endpoint URL is `https://<region>-aiplatform.googleapis.com/v1/projects/<project id>/locations/<region>/endpoints/<endpoint id>:rawPredict`. <br>
Let's build the URL accordingly and send the encoded raw image to the API.

In [None]:
api = (
    f"https://{REGION}-aiplatform.googleapis.com/v1/projects/{PROJECT}"
    f"/locations/{REGION}/endpoints/{endpoint.name}:rawPredict"
)

response = requests.post(api, json=data, headers=headers)
response.raise_for_status()  # surface HTTP errors (e.g. 401/404) instead of parsing an error body
response.json()

We now have the result from the Vertex AI online prediction service.

## Summary
We learned how to:
- export and load a SavedModel
- customize serving functions to control a SavedModel's behavior
- deploy a SavedModel to Vertex AI
- use a deployed model for both batch and online prediction


Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.