In [None]:
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Get started with Vertex Explainable AI Example Based API - Custom training image classification model

<table align="left">

  <td>
    <a href="https://colab.research.google.com/github/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/official/explainable_ai/get_started_with_vertex_xai_example_based_images.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/colab-logo-32px.png" alt="Colab logo"> Run in Colab
    </a>
  </td>
  <td>
    <a href="https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/official/explainable_ai/get_started_with_vertex_xai_example_based_images.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/github-logo-32px.png" alt="GitHub logo">
      View on GitHub
    </a>
  </td>
  <td>
    <a href="https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/vertex-ai-samples/main/notebooks/official/explainable_ai/get_started_with_vertex_xai_example_based_images.ipynb">
      <img src="https://lh3.googleusercontent.com/UiNooY4LUgW_oTvpsNhPpQzsstV5W8F7rYgxgGBD85cWJoLmrOzhVs_ksK_vgx40SHs7jCqkTkCk=e14-rj-sc0xffffff-h130-w32" alt="Vertex AI logo">
      Open in Vertex AI Workbench
    </a>
  </td>
</table>

**_NOTE_**: This notebook has been tested in the following environment:

* Python version = 3.9

## Overview

This notebook demonstrates how to get example-based explanations for an image classification model. With example-based explanations, Vertex AI uses nearest neighbor search to return a list of examples (typically from the training set) that are most similar to the input.

Learn more about [Vertex Explainable AI](https://cloud.google.com/vertex-ai/docs/explainable-ai/overview).
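
To make the idea of nearest neighbor search concrete, here is a minimal sketch in plain Python. It performs an exact (brute-force) search over a few toy 2-D embeddings. The function names and the toy data are illustrative assumptions, not part of the Vertex AI API, and the managed service uses an approximate search rather than this exhaustive one.

```python
def squared_l2(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def nearest_neighbors(query, examples, k):
    """Return the k (example_id, distance) pairs closest to the query."""
    scored = [(example_id, squared_l2(query, vec)) for example_id, vec in examples.items()]
    scored.sort(key=lambda pair: pair[1])
    return scored[:k]


# Toy 2-D embeddings (hypothetical); real image embeddings have many more dimensions.
train_embeddings = {"0": [0.0, 0.0], "1": [1.0, 1.0], "2": [5.0, 5.0]}
neighbors = nearest_neighbors([0.9, 1.2], train_embeddings, k=2)
for example_id, distance in neighbors:
    print(example_id, round(distance, 3))
```

The closest example is returned first, which mirrors how the explanations later in this notebook list the most similar training images.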

### Objective

In this tutorial, you learn how to get Example-Based explanations from Vertex Explainable AI services.

This tutorial uses the following Google Cloud ML services and resources:

- `Vertex AI Model Registry`
- `Vertex Explainable AI`
- `Vertex AI Prediction`


The steps performed include:

- Prepare training data
- Fine-tune an image classification model to get embeddings
- Register the model in Vertex AI Model Registry
- Deploy the model to a Vertex AI Endpoint
- Request explanations using the Example-Based Explanation API
- Analyze the results

### Dataset

For this notebook, you use the [beans dataset](https://github.com/AI-Lab-Makerere/ibean/) downloaded through [TF Datasets](https://www.tensorflow.org/datasets/catalog/beans).

### Costs

This tutorial uses billable components of Google Cloud:

* Vertex AI
* Cloud Storage

Learn about [Vertex AI pricing](https://cloud.google.com/vertex-ai/pricing)
and [Cloud Storage pricing](https://cloud.google.com/storage/pricing),
and use the [Pricing Calculator](https://cloud.google.com/products/calculator/)
to generate a cost estimate based on your projected usage.

## Installation

Install the following packages required to execute this notebook.

In [None]:
# Install the packages.
# Note: the TensorFlow version has to be aligned with the Vertex AI pre-built serving container.

USER = ""
! pip3 install {USER} --upgrade numpy tensorflow_datasets tensorflow==2.11.0  -q --no-warn-conflicts
! pip3 install {USER} --upgrade google-cloud-aiplatform -q --no-warn-conflicts

### Colab only: Uncomment the following cell to restart the kernel.

In [None]:
# Automatically restart kernel after installs so that your environment can access the new packages
# import IPython

# app = IPython.Application.instance()
# app.kernel.do_shutdown(True)

## Before you begin

### Set up your Google Cloud project

**The following steps are required, regardless of your notebook environment.**

1. [Select or create a Google Cloud project](https://console.cloud.google.com/cloud-resource-manager). When you first create an account, you get a $300 free credit towards your compute/storage costs.

2. [Make sure that billing is enabled for your project](https://cloud.google.com/billing/docs/how-to/modify-project).

3. [Enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com).

4. If you are running this notebook locally, you need to install the [Cloud SDK](https://cloud.google.com/sdk).

#### Set your project ID

**If you don't know your project ID**, try the following:
* Run `gcloud config list`.
* Run `gcloud projects list`.
* See the support page: [Locate the project ID](https://support.google.com/googleapi/answer/7014113)

In [None]:
PROJECT_ID = "[your-project-id]"  # @param {type:"string"}

# Set the project id
! gcloud config set project {PROJECT_ID}

#### Region

You can also change the `REGION` variable used by Vertex AI. Learn more about [Vertex AI regions](https://cloud.google.com/vertex-ai/docs/general/locations).

In [None]:
REGION = "us-central1"  # @param {type: "string"}

### Authenticate your Google Cloud account

Depending on your Jupyter environment, you may have to manually authenticate. Follow the relevant instructions below.

**1. Vertex AI Workbench**
* Do nothing as you are already authenticated.

**2. Local JupyterLab instance, uncomment and run:**

In [None]:
# ! gcloud auth login

**3. Colab, uncomment and run:**

In [None]:
# from google.colab import auth
# auth.authenticate_user()

**4. Service account or other**
* See how to grant Cloud Storage permissions to your service account at https://cloud.google.com/storage/docs/gsutil/commands/iam#ch-examples.

### Create a Cloud Storage bucket

Create a storage bucket to store intermediate artifacts such as datasets.

In [None]:
BUCKET_NAME = f"your-bucket-name-{PROJECT_ID}-unique" # @param {type:"string"}
BUCKET_URI = f"gs://{BUCKET_NAME}"
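
Because the bucket name is derived from `PROJECT_ID`, leaving the `[your-project-id]` placeholder in place produces a name that Cloud Storage rejects. The sketch below is a partial, hypothetical pre-check (the helper name is an assumption, and it covers only the basic character and length rules for names without dots longer than 63 characters); Cloud Storage remains the authority on whether a name is accepted.

```python
import re

# Basic rule: 3-63 characters, lowercase letters, digits, dashes, underscores
# and dots, starting and ending with a letter or digit.
_BUCKET_NAME_RE = re.compile(r"[a-z0-9][a-z0-9._-]{1,61}[a-z0-9]")


def looks_like_valid_bucket_name(name):
    """Return True if the name passes the basic character and length rules."""
    return _BUCKET_NAME_RE.fullmatch(name) is not None


placeholder_ok = looks_like_valid_bucket_name("your-bucket-name-[your-project-id]-unique")
real_ok = looks_like_valid_bucket_name("your-bucket-name-my-project-123-unique")
print(placeholder_ok, real_ok)
```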

**Only if your bucket doesn't already exist**: Run the following cell to create your Cloud Storage bucket.

In [None]:
! gsutil mb -l $REGION -p $PROJECT_ID $BUCKET_URI

### Set up project template
Set the folder you use in this tutorial.

In [None]:
import os

TUTORIAL_DIR = os.path.join(os.getcwd(), "sdk_xai_example_based_tutorial")
DATA_DIR = os.path.join(TUTORIAL_DIR, "data")
MODEL_DIR = os.path.join(TUTORIAL_DIR, "model")

for path in TUTORIAL_DIR, DATA_DIR, MODEL_DIR:
    os.makedirs(path, exist_ok=True)

### Set the dataset and the model to explain

Indicate the dataset and the Cloud Storage URI of the model to explain.

In [None]:
DATASET_NAME = "beans"
MODEL_FILE_NAME = f"mobilenetv2-{DATASET_NAME}.tar.gz"
SOURCE_MODEL_URI = BUCKET_URI + "/model/" + MODEL_FILE_NAME

### Import libraries

In [None]:
# General
import io
from PIL import Image
import base64
from io import BytesIO
from google.cloud import storage
import tarfile
import json
import numpy as np
from matplotlib import pyplot as plt
import tensorflow_datasets as tfds
tfds.disable_progress_bar()
import tensorflow as tf
from tensorflow import keras

# Vertex AI
from google.cloud import aiplatform as vertex_ai
from google.cloud import aiplatform_v1beta1 as vertex_ai_v1beta1
from google.cloud.aiplatform_v1beta1.types import io as io_pb2
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value

### Constants

In [None]:
# API service endpoint
API_ENDPOINT = "{}-aiplatform.googleapis.com".format(REGION)

# Vertex location root path for your dataset, model and endpoint resources
PARENT = "projects/" + PROJECT_ID + "/locations/" + REGION

# Training
CHANNELS = 3
SIZE = (224, 224)
BATCH_SIZE = 32
NUM_BATCHES = -1

# Serving
MODEL_SOURCE_FILE_NAME = "model/" + MODEL_FILE_NAME
MODEL_DESTINATION_FILE_NAME = os.path.join(MODEL_DIR, MODEL_FILE_NAME)
MODEL_FOLDER_DIR = os.path.join(MODEL_DIR, f"mobilenetv2-{DATASET_NAME}")
ENBEDDINGS_URI = BUCKET_URI + f"/model/mobilenetv2-{DATASET_NAME}"
TRAIN_DATASET_FILE = DATASET_NAME + "train-images.jsonl"
TRAIN_SOURCE_JSON_PATH = os.path.join(DATA_DIR, TRAIN_DATASET_FILE)
TRAIN_DESTINATION_JSON_PATH = 'data/' + TRAIN_DATASET_FILE
TRAIN_DATASET_URI = BUCKET_URI + '/' + TRAIN_DESTINATION_JSON_PATH

### Helpers

In [None]:
def create_index_to_name_map(ds_info):
    """
    Creates a map from numerical label index to label name.
    Args:
        ds_info: DatasetInfo object.
    Returns:
        index_to_name: dict. Map from index to name.
    """
    index_to_name = {}
    num_classes = ds_info.features["label"].num_classes
    names = ds_info.features["label"].names
    for i in range(num_classes):
        index_to_name[i] = names[i]
    return index_to_name


def extract_images_and_labels(ds, num_batches):
    """
    Extract images and labels from a dataset.
    Args:
        ds: A dataset.
        num_batches: The number of batches to extract. -1 uses the whole dataset
    Returns:
        images: A numpy structure of images.
        labels: A numpy structure of labels.
    """
    data_slice = ds.take(num_batches)
    images = []
    labels = []
    for image, label in data_slice:
        images.append(image)
        labels.append(label)
    images = tf.concat(images, 0)
    labels = tf.concat(labels, 0)
    return images.numpy(), labels.numpy()

def get_instance(index, image):
    """
    Get the instance to send to the model
    Args:
        index: The index associated with image
        image: The image to send to the model
    Returns:
        The instance to send to the model
    """
    img_bytes = io.BytesIO()
    img = Image.fromarray(image.astype(np.uint8))
    img.save(img_bytes, format="PNG")
    instance = {
        "id": str(index),
        "bytes_inputs": {
            "b64": base64.b64encode(img_bytes.getvalue()).decode("utf-8")
        },
    }
    return instance

def write_jsonl(saved_jsonl_path, images):
    """
    Write the jsonl file to send to the model
    Args:
        saved_jsonl_path: The path to save the jsonl file
        images: The images to send to the model
    Returns:
        None
    """
    with open(saved_jsonl_path, "w") as f:
        for i, im in enumerate(images):
            instance = get_instance(i, im)
            json.dump(
                instance,
                f,
            )
            f.write("\n")

def upload_model(model_configuration):
    """
    Upload the model to Vertex AI
    Args:
        model_configuration: The model configuration
    Returns:
        The uploaded model
    """

    model = vertex_ai_v1beta1.Model(
        display_name=model_configuration["display_name"],
        artifact_uri=model_configuration["artifact_uri"],
        metadata_schema_uri=model_configuration["metadata_schema_uri"],
        explanation_spec=model_configuration["explanation_spec"],
        container_spec=model_configuration["container_spec"],
    )

    response = clients["model"].upload_model(parent=PARENT, model=model)
    print("Long running operation:", response.operation.name)
    uploaded_model = response.result(timeout=10000)
    print("upload_model_response")
    print(" model:", uploaded_model)
    return uploaded_model

def create_endpoint(endpoint_config):
    """
    Create an endpoint
    Args:
        endpoint_config: The endpoint configuration
    Returns:
        The created endpoint
    """
    response = clients["endpoint"].create_endpoint(parent=PARENT, endpoint=endpoint_config)
    print("Long running operation:", response.operation.name)
    endpoint = response.result()
    print("create_endpoint_response")
    print(" endpoint:", endpoint)
    return endpoint

def deploy_model(
        model, endpoint, deploy_config
):
    """
    Deploy a model to an endpoint
    Args:
        model: The model to deploy
        endpoint: The endpoint to deploy the model
        deploy_config: The model deployment configuration
    Returns:
        The deployed model
    """


    if deploy_config["deploy_gpu"]:
        machine_spec = {
            "machine_type": deploy_config["deploy_compute"],
            "accelerator_type": deploy_config["deploy_gpu"],
            "accelerator_count": deploy_config["deploy_ngpu"],
        }
    else:
        machine_spec = {
            "machine_type": deploy_config["deploy_compute"],
        }

    deployed_model = {
        "model": model,
        "display_name": deploy_config["deployed_model_display_name"],
        "dedicated_resources": {
            "min_replica_count": deploy_config["min_nodes"],
            "max_replica_count": deploy_config["max_nodes"],
            "machine_spec": machine_spec,
        },
        "enable_container_logging": False,
    }

    response = clients["endpoint"].deploy_model(
        endpoint=endpoint, deployed_model=deployed_model, traffic_split=deploy_config["traffic_split"]
    )

    print("Long running operation:", response.operation.name)
    deployed_model = response.result(timeout=10000)
    print("deploy_model_response")
    print(" deployed_model:", deployed_model)

    return deployed_model

def explain_image(formatted_data, endpoint, parameters, deployed_model_id):
    """
    Get example-based explanations for an image
    Args:
        formatted_data: The data to send to the model
        endpoint: The endpoint to send the data to
        parameters: The parameters to send to the model
        deployed_model_id: The deployed model id
    Returns:
        The response from the model
    """

    # The format of each instance should conform to the deployed model's prediction input schema.
    instances_list = formatted_data
    instances = [
        json_format.ParseDict(instance, Value()) for instance in instances_list
    ]

    response = clients["prediction"].explain(
        endpoint=endpoint,
        instances=instances,
        parameters=parameters,
        deployed_model_id=deployed_model_id,
    )
    print("response")
    print(" deployed_model_id:", response.deployed_model_id)
    predictions = response.predictions
    print("predictions")
    for prediction in predictions:
        print(" prediction:", prediction)

    explanations = response.explanations
    print("explanations")
    for explanation in explanations:
        print(" explanation:", explanation)

    return response

def plot_input_and_neighbors(
    val_img_idx,
    all_train_images,
    val_images,
    all_train_labels,
    val_labels,
    label_index_to_name,
    data_with_neighbors,
):
    """
    Plot the input image and its neighbors.
    Args:
        val_img_idx: Index of the input image.
        all_train_images: All training images.
        val_images: Validation images.
        all_train_labels: All training labels.
        val_labels: Validation labels.
        label_index_to_name: Dictionary mapping label indices to names.
        data_with_neighbors: Data with neighbors.
    Returns:
        None
    """
    image = val_images[val_img_idx]
    fig = plt.figure(figsize=(24, 12))
    ax_list = fig.subplots(3, 5)
    ax_list[0, 0].axis("off")
    ax_list[0, 1].axis("off")
    ax_list[0, 3].axis("off")
    ax_list[0, 4].axis("off")
    ax = ax_list[0, 2]
    class_label = val_labels[val_img_idx]
    ax.set_title(
        f"{class_label}:{label_index_to_name[class_label]} (example index: {val_img_idx})",
        fontsize=15,
    )
    ax.axis("off")
    ax.imshow(image.astype("uint8"))

    neighbor_list = data_with_neighbors[val_img_idx]["neighbors"]
    num_neighbors = len(neighbor_list)
    for n in range(num_neighbors):
        neighbor = neighbor_list[n]
        neighbor_idx = int(neighbor["neighborId"])
        neighbor_dist = neighbor["neighborDistance"]
        ax = ax_list[1 + n // 5, n % 5]
        class_label = all_train_labels[neighbor_idx]
        ax.set_title(
            f"{class_label}:{label_index_to_name[class_label]} (dist: {neighbor_dist:.3f})",
            fontsize=15,
        )
        ax.axis("off")
        ax.imshow(all_train_images[neighbor_idx].astype("uint8"))

def undeploy_model(deployed_model_id, endpoint):
    """
    Undeploy a model from an endpoint
    Args:
        deployed_model_id: The deployed model id
        endpoint: The endpoint to undeploy the model from
    Returns:
        None
    """
    response = clients["endpoint"].undeploy_model(
        endpoint=endpoint, deployed_model_id=deployed_model_id, traffic_split={}
    )
    print(response)

def delete_endpoint(endpoint_id):
    """
    Delete an endpoint
    Args:
        endpoint_id: The name of endpoint to delete
    Returns:
        None
    """
    response = clients["endpoint"].delete_endpoint(name=endpoint_id)
    print(response)

### Initialize Vertex AI SDK for Python

Initialize the Vertex AI SDK for Python for your project.

In [None]:
vertex_ai.init(project=PROJECT_ID, location=REGION, staging_bucket=BUCKET_URI)

### Set up clients

The Vertex AI client library follows a client/server model: each service is accessed through its own client object.

You use different clients for different steps of the workflow in this tutorial, so set them all up front.

- Model Service for `Model` resources.
- Endpoint Service for deployment.
- Prediction Service for serving.

In [None]:
# client options same for all services
client_options = {"api_endpoint": API_ENDPOINT}


def create_model_client():
    client = vertex_ai_v1beta1.ModelServiceClient(client_options=client_options)
    return client


def create_endpoint_client():
    client = vertex_ai_v1beta1.EndpointServiceClient(client_options=client_options)
    return client


def create_prediction_client():
    client = vertex_ai_v1beta1.PredictionServiceClient(client_options=client_options)
    return client


clients = {}
clients["model"] = create_model_client()
clients["endpoint"] = create_endpoint_client()
clients["prediction"] = create_prediction_client()

for client in clients.items():
    print(client)

## Working with Vertex Explainable Example-based API

The Vertex Explainable AI Example-based API provides a highly performant approximate nearest neighbor (ANN) service that returns examples similar to new predictions or instances.

To use Vertex AI example-based explanations, you complete the following steps:

- Index the entire dataset: You provide the Cloud Storage path of an embedding model, the Cloud Storage path of the training data, and the configuration for example-based explanations.

- Deploy index and model: You specify the machine type to use and the model identifier returned by the model upload step.

- Query for similar examples: You send an explain request and the model returns similar examples.

Below, you use a `MobileNetV2` [Keras Application](https://keras.io/api/applications/) deep learning model, which is available with pre-trained weights for fine-tuning, to create embeddings.

### Download and visualize the data

Download the data.

In [None]:
split_ds, ds_info = tfds.load(
    DATASET_NAME,
    split=["train", "test"],
    as_supervised=True,
    with_info=True,
    shuffle_files=False,
    data_dir=DATA_DIR,
)
train_ds, validation_ds = split_ds

Visualize the dataset.

In [None]:
tfds.show_examples(ds=train_ds, ds_info=ds_info)

### Download the model to explain

For your convenience, you copy and extract a pretrained MobileNetV2 model for this exercise.

In [None]:
! gsutil cp {SOURCE_MODEL_URI} {MODEL_DESTINATION_FILE_NAME}

In [None]:
# The context manager closes the archive on exit, so no explicit close() is needed.
with tarfile.open(MODEL_DESTINATION_FILE_NAME) as file:
    file.extractall(MODEL_DIR)

Print a summary of the model, including its layers, their output shapes, and the number of parameters in each layer.

In [None]:
model = tf.keras.models.load_model(MODEL_FOLDER_DIR)
model.summary()

### Index the entire dataset

To index the dataset you will use to get similar examples, you provide:

- embedding model
- training dataset
- config file for example-based explanation

#### Extract embeddings

To generate example-based explanations, you need to extract an embedding model from the model you want to explain.

In this case, you skip the data augmentation layer and drop the final softmax layer so that the model you loaded outputs embeddings instead of class probabilities.

In [None]:
embedding_model = keras.Sequential()
# Loop over layers
for layer in model.layers[:-1]:
    # Skip data augmentation layer
    if "sequential" not in layer.name:
        embedding_model.add(layer)
embedding_model.summary()

####  Prepare embeddings for Vertex Explainable AI

Next, you need to upload the embedding layers of the TF.Keras model to Vertex AI Model Registry as a Vertex `Model` resource.

During the index deployment process, the model is served to transform images into embeddings and create the index.

Images need some common preprocessing. When you use a TensorFlow pre-built container to serve the model and you want to include preprocessing, you need to define a serving function that converts data to the format your embedding model expects. You specify the serving function as the `serving_default` signature and save it with the underlying model using `tf.saved_model.save`.

In [None]:
CONCRETE_INPUT = "numpy_inputs"

def _preprocess(bytes_input):
    """
    The preprocess function.
    Args:
        bytes_input: The input image in bytes.
    Returns:
        The preprocessed image in numpy array.
    """
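    # tf.io.decode_jpeg also accepts PNG-encoded bytes, which is what get_instance produces.
    decoded = tf.io.decode_jpeg(bytes_input, channels=CHANNELS)
    # Convert uint8 pixels to float32 values in [0, 1].
    decoded = tf.image.convert_image_dtype(decoded, tf.float32)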
    resized = tf.image.resize(decoded, size=SIZE)
    rescale = tf.cast(resized, tf.float32)
    return rescale


@tf.function(input_signature=[tf.TensorSpec([None], tf.string)])
def preprocess_fn(bytes_inputs):
    """
    Preprocess the input image.
    Args:
        bytes_inputs: A list of raw image bytes.
    Returns:
        A list of preprocessed images.
    """
    decoded_images = tf.nest.map_structure(
        tf.stop_gradient, tf.map_fn(_preprocess, bytes_inputs, dtype=tf.float32)
    )
    return {CONCRETE_INPUT: decoded_images}


@tf.function(
    input_signature=[tf.TensorSpec([None], tf.string), tf.TensorSpec([None], tf.string)]
)
def serving_fn(id, bytes_inputs):
    """
    This function is used to serve the embeddings.
    Args:
        id: The id of the input.
        bytes_inputs: The input image.
    Returns:
        The output of the model.
    """
    images = preprocess_fn(bytes_inputs)
    embedding = m_call(**images)
    return {"id": id, "embedding": embedding}

#### Export Embeddings for Vertex Explainable AI

After you specify input signatures, you export embeddings as a SavedModel.

In [None]:
SIZE_3D = (None, 224, 224, 3)

m_call = tf.function(embedding_model.call).get_concrete_function(
    [tf.TensorSpec(shape=SIZE_3D, dtype=tf.float32, name=CONCRETE_INPUT)]
)

tf.saved_model.save(
    embedding_model,
    ENBEDDINGS_URI,
    signatures={
        "serving_default": serving_fn,
    },
)

In this tutorial, you will use a TensorFlow pre-built container on Vertex AI to serve example-based explanation. When you use a TensorFlow pre-built container to serve predictions, you need to provide the names of the input tensors and the output tensor of your model. These names will be part of an ExplanationMetadata message when you configure a Model for Vertex Explainable AI.

In [None]:
embedding_model_loaded = tf.saved_model.load(ENBEDDINGS_URI)

serving_input = list(
    embedding_model_loaded.signatures["serving_default"]
    .structured_input_signature[1]
    .keys()
)[0]

In [None]:
print("Serving function input:", serving_input)
serving_output = list(
    embedding_model_loaded.signatures["serving_default"].structured_outputs.keys()
)[0]
print("Serving function output:", serving_output)

input_name = model.input.name
print("Model input name:", input_name)
output_name = model.output.name
print("Model output name:", output_name)

#### Prepare the training dataset

Now that you have the embedding model, you need to prepare the training dataset by converting the images into a JSONL file.

In [None]:
train_ds = train_ds.map(lambda x, y: (tf.image.resize(x, SIZE), y))
train_ds = train_ds.batch(BATCH_SIZE).prefetch(buffer_size=10)
train_images, train_labels = extract_images_and_labels(
    train_ds, num_batches=NUM_BATCHES
)
write_jsonl(TRAIN_SOURCE_JSON_PATH, train_images)

In [None]:
! gsutil cp {TRAIN_SOURCE_JSON_PATH} {TRAIN_DATASET_URI}
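
Before uploading a large JSONL file, it can help to sanity-check that every line parses and carries a decodable image. The sketch below is a hypothetical checker (`check_jsonl` is not part of the notebook or of any library) that assumes the record layout written by `write_jsonl`: an `id` string and a base 64 PNG under `bytes_inputs.b64`. The demo writes stand-in bytes that only carry the PNG signature, not real images.

```python
import base64
import json
import os
import tempfile

PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def check_jsonl(path):
    """Return the number of records, raising ValueError on the first bad one."""
    seen_ids = set()
    with open(path) as f:
        for line_number, line in enumerate(f, start=1):
            record = json.loads(line)
            image_bytes = base64.b64decode(record["bytes_inputs"]["b64"])
            if not image_bytes.startswith(PNG_SIGNATURE):
                raise ValueError(f"line {line_number}: payload is not a PNG")
            if record["id"] in seen_ids:
                raise ValueError(f"line {line_number}: duplicate id {record['id']}")
            seen_ids.add(record["id"])
    return len(seen_ids)


# Demo with stand-in payloads (signature only, not real images).
fake_png = PNG_SIGNATURE + b"not-a-real-image"
with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = os.path.join(tmp_dir, "demo.jsonl")
    with open(demo_path, "w") as f:
        for i in range(3):
            record = {
                "id": str(i),
                "bytes_inputs": {"b64": base64.b64encode(fake_png).decode("utf-8")},
            }
            json.dump(record, f)
            f.write("\n")
    num_records = check_jsonl(demo_path)
print(num_records)
```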

#### Example-based explanation configuration

Finally, you need to define the example-based explanation configuration.

In particular, you need to specify:

- `parameters`, which indicate the explainability algorithm to use for explanations on your model. In this tutorial, you use `Examples`.

- `metadata`, which indicates how the algorithm is applied to your custom model.

##### Parameters

For the `parameters` of example-based explanations, you provide `examples`, which defines the conditions for returning the nearest neighbors from the provided dataset.

Example-based explanations are an explanation method with its own parameter configuration. The main properties you have to define are listed below.

- `dimensions`: The dimension of the embedding.
- `approximateNeighborsCount`: Number of neighbors to return.
- `distanceMeasureType`: The distance metric used to measure the nearness of examples. You can choose between `SQUARED_L2_DISTANCE`, `L1_DISTANCE`, `COSINE_DISTANCE`, and `DOT_PRODUCT_DISTANCE`.
- `featureNormType`: Whether to normalize the embeddings so that each has unit length. You can choose between `UNIT_L2_NORM` and `NONE`.
- `treeAhConfig`: Parameters controlling the trade-off between quality of approximation and speed. See the paper for technical details. Under the hood, it creates a shallow tree where the number of leaves is controlled by `leafNodeEmbeddingCount` and the search recall/speed trade-off is controlled by `leafNodesToSearchPercent`.
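
To illustrate how the distance measures differ, the following sketch computes each of them for two small vectors in plain Python. The function names are illustrative. The formulas follow the usual textbook definitions, with the dot-product distance taken as the negative dot product and the cosine distance as one minus the cosine similarity; treat these as assumptions about the service's conventions and confirm them against the Vertex AI documentation.

```python
import math


def squared_l2_distance(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def l1_distance(a, b):
    return sum(abs(x - y) for x, y in zip(a, b))


def dot_product_distance(a, b):
    # Negated so that a smaller value means "more similar" (assumed convention).
    return -sum(x * y for x, y in zip(a, b))


def cosine_distance(a, b):
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - sum(x * y for x, y in zip(a, b)) / (norm_a * norm_b)


def unit_l2_normalize(v):
    """Scale a vector to unit length, as UNIT_L2_NORM is described above."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


a, b = [3.0, 4.0], [4.0, 3.0]
print("squared L2:", squared_l2_distance(a, b))
print("L1:", l1_distance(a, b))
print("dot product:", dot_product_distance(a, b))
print("cosine:", round(cosine_distance(a, b), 4))
print("squared L2 after normalization:",
      round(squared_l2_distance(unit_l2_normalize(a), unit_l2_normalize(b)), 4))
```

For unit-length vectors, the squared L2 distance equals twice the cosine distance, so these two measures rank neighbors identically once the embeddings are normalized.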


In [None]:
dimensions = embedding_model.output.shape[1]

NUM_NEIGHBORS_TO_RETURN = 10

nearest_neighbor_search_config = {
    "contentsDeltaUri": "",
    "config": {
        "dimensions": dimensions,
        "approximateNeighborsCount": NUM_NEIGHBORS_TO_RETURN,
        "distanceMeasureType": "SQUARED_L2_DISTANCE",
        "featureNormType": "NONE",
        "algorithmConfig": {
            "treeAhConfig": {
                "leafNodeEmbeddingCount": 1000,
                "leafNodesToSearchPercent": 100,
            }
        },
    },
}

examples = vertex_ai_v1beta1.Examples(
    nearest_neighbor_search_config=nearest_neighbor_search_config,
    gcs_source=io_pb2.GcsSource(uris=[TRAIN_DATASET_URI]),
    neighbor_count=NUM_NEIGHBORS_TO_RETURN,
)

parameters = vertex_ai_v1beta1.ExplanationParameters(examples=examples)
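
As a rough way to reason about the `treeAhConfig` values, the sketch below estimates how many leaves an index would have and how many of them a query would visit. This is back-of-the-envelope arithmetic under the assumption that leaves hold about `leafNodeEmbeddingCount` embeddings each. The helper name and the dataset sizes are hypothetical, and the real index layout is decided by the service.

```python
import math


def estimate_leaf_search(num_embeddings, leaf_node_embedding_count,
                         leaf_nodes_to_search_percent):
    """Estimate (number of leaves, leaves searched per query)."""
    num_leaves = max(1, math.ceil(num_embeddings / leaf_node_embedding_count))
    leaves_searched = max(
        1, math.ceil(num_leaves * leaf_nodes_to_search_percent / 100)
    )
    return num_leaves, leaves_searched


# A small hypothetical dataset with the settings used in this tutorial:
# every leaf is searched, so the search is effectively exhaustive.
small = estimate_leaf_search(1034, 1000, 100)
# A larger hypothetical dataset where only 10% of the leaves are searched.
large = estimate_leaf_search(1_000_000, 1000, 10)
print(small, large)
```

With a small dataset and `leafNodesToSearchPercent` set to 100, accuracy is favored over speed. For larger datasets, lowering the percentage trades some recall for lower latency.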

##### Explanation Metadata

For the metadata, you need to indicate:

- `outputs`: A map from output names to output metadata. In this case, you expect embeddings.

- `inputs`: A map from input names to input metadata. In this case, you have the encoded image and the ID associated with it.

In [None]:
# for encoding parameter, 1 stands for 'IDENTITY'
IMAGE_INPUT_TENSOR_NAME = "bytes_inputs"
ID_INPUT_TENSOR_NAME = "id"
OUTPUT_TENSOR_NAME = "embedding"

explanation_inputs = {
    "my_input": vertex_ai_v1beta1.ExplanationMetadata.InputMetadata(
        {
            "input_tensor_name": IMAGE_INPUT_TENSOR_NAME,
            "encoding": vertex_ai_v1beta1.ExplanationMetadata.InputMetadata.Encoding(1),
            "modality": "image",
        }
    ),
    "id": vertex_ai_v1beta1.ExplanationMetadata.InputMetadata(
        {
            "input_tensor_name": ID_INPUT_TENSOR_NAME,
            "encoding": vertex_ai_v1beta1.ExplanationMetadata.InputMetadata.Encoding(1),
        }
    ),
}

explanation_outputs = {
    "embedding": vertex_ai_v1beta1.ExplanationMetadata.OutputMetadata(
        {"output_tensor_name": OUTPUT_TENSOR_NAME}
    )
}

explanation_meta_config = vertex_ai_v1beta1.ExplanationMetadata(
    inputs=explanation_inputs, outputs=explanation_outputs
)

explanation_spec = vertex_ai_v1beta1.ExplanationSpec(
    parameters=parameters, metadata=explanation_meta_config
)

### Deploy model and index

Now you are ready to deploy your model.

To deploy the model on Vertex AI, you need to create a `Model` resource and then deploy the model to an `Endpoint` resource.

#### Upload the model

You can use the `upload_model` helper function to upload your model, stored in SavedModel format, to the `Model` service, which instantiates a Vertex `Model` resource for your model. You need to define the following parameters:

- `display_name`: A human-readable name for the `Model` resource.
- `metadata_schema_uri`: Since your model was built without a Vertex `Dataset` resource, you leave this blank (`''`).
- `artifact_uri`: The Cloud Storage path where the embedding model is stored in SavedModel format.
- `container_spec`: The specification for the Docker container that is installed on the `Endpoint` resource, from which the `Model` resource serves predictions.
- `explanation_spec`: The specification for enabling explainability for your model.

Uploading a model into a Vertex `Model` resource returns a long-running operation, which can take some time to complete. With example-based explanations, uploading a model triggers a batch prediction job to calculate embeddings and index them.


##### Define serving container configuration

In [None]:
DEPLOY_IMAGE_URI = "gcr.io/cloud-aiplatform/prediction/tf2-cpu.2-11:latest"

container_config = {"image_uri": DEPLOY_IMAGE_URI}
container_spec = vertex_ai_v1beta1.ModelContainerSpec(container_config)

##### Define Model configuration

In [None]:
MODEL_NAME = f"mobilenetv2-{DATASET_NAME}-similarity"
model_config = {
    "display_name": MODEL_NAME,
    "artifact_uri": ENBEDDINGS_URI,
    "metadata_schema_uri": "",
    "container_spec": container_spec,
    "explanation_spec": explanation_spec,
}

##### Upload the model

Uploading the model can take more than an hour.


In [None]:
uploaded_model = upload_model(model_config)

#### Deploy the `Model` resource

To deploy the registered Vertex `Model` resource, you first create an `Endpoint` resource and then deploy the `Model` resource to it.

##### Create an `Endpoint` resource

You use `create_endpoint` to create an endpoint for serving the model. The configuration below specifies the name of the `Endpoint` resource and some additional information.

Creating an `Endpoint` resource returns a long running operation, since it may take a few moments to provision the `Endpoint` resource for serving.

In [None]:
ENDPOINT_NAME = f"{MODEL_NAME}-similarity-endpoint"
DESCRIPTION = "An endpoint for the similarity model"
LABELS = {"env": "prod", "status": "online"}

endpoint_config = {
    "display_name": ENDPOINT_NAME,
    "description": DESCRIPTION,
    "labels": LABELS,
}

endpoint = create_endpoint(endpoint_config)

##### Deploy model to endpoint

You use the `deploy_model` helper function to deploy the model to the endpoint you created. You need to define the following parameters:

- `model`: The Vertex fully qualified identifier of the uploaded `Model` resource to deploy.
- `endpoint`: The Vertex fully qualified `Endpoint` resource identifier to deploy the `Model` resource to.
- `deploy_config`: The deployment configuration, which defines the deployment resources (GPUs, machine type) and other settings such as the traffic split policy.

In [None]:
DEPLOYED_MODEL_NAME = f"{MODEL_NAME}-deployed"
uploaded_model_id = uploaded_model.model
endpoint_id = endpoint.name

deploy_config = {
    "deployed_model_display_name": DEPLOYED_MODEL_NAME,
    "deploy_gpu": None,
    "deploy_ngpu": 0,
    "deploy_compute": "n1-standard-4",
    "min_nodes": 1,
    "max_nodes": 1,
    "traffic_split": {"0": 100},
}

deployed_model = deploy_model(uploaded_model_id, endpoint_id, deploy_config)

### Query for similar examples

Finally, you can send an online prediction request to your deployed model to get similar examples, using a sample of the validation dataset.

In this tutorial, you send STL10 images as compressed PNG images encoded in base64, instead of the raw uncompressed bytes that were created previously.

Each instance in the prediction request is a dictionary of the form:

    {"id": id, "bytes_inputs": {"b64": content}}

- `id`: The unique identifier associated with the image.
- `bytes_inputs`: A map that holds the encoded image input.
- `"b64"`: A key that indicates the content is base64 encoded.
- `content`: The compressed image bytes as a base64-encoded string.

You use the `get_instance` helper function to create the instances for the prediction request.
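
A minimal sketch of how such an instance could be assembled is shown below. It is not the notebook's `get_instance` implementation: `build_instance` is a hypothetical helper that takes image bytes that are already compressed and uses only the standard library, so the image encoding step is left out. Converting the identifier to a string is also an assumption.

```python
import base64
import json


def build_instance(instance_id, image_bytes):
    """Wrap compressed image bytes in the request format described above."""
    content = base64.b64encode(image_bytes).decode("utf-8")
    # Assumption: the identifier is sent as a string.
    return {"id": str(instance_id), "bytes_inputs": {"b64": content}}


# Placeholder bytes standing in for a compressed image file.
fake_image_bytes = b"\x89PNG\r\n\x1a\n" + bytes(16)
instance = build_instance(0, fake_image_bytes)

# The instance is JSON-serializable, and the payload decodes back to the original bytes.
json.dumps(instance)
assert base64.b64decode(instance["bytes_inputs"]["b64"]) == fake_image_bytes
print(instance["id"], len(instance["bytes_inputs"]["b64"]))
```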

In [None]:
validation_ds = validation_ds.map(lambda x, y: (tf.image.resize(x, SIZE), y))
validation_ds = validation_ds.batch(BATCH_SIZE).prefetch(buffer_size=10)
val_images, val_labels = extract_images_and_labels(
    validation_ds, num_batches=NUM_BATCHES
)

In [None]:
val_data = []

for i, im in enumerate(val_images):
    val_instance = get_instance(i, im)
    val_data.append(val_instance)

#### Send the prediction with explanation request

To send the prediction request with explanations, you use the `explain_image` helper function, which takes the following parameters:

- `image`: A list of test image data as a numpy array.
- `endpoint`: The Vertex fully qualified identifier for the `Endpoint` resource where the `Model` resource was deployed.
- `parameters_dict`: Additional parameters for serving.
- `deployed_model_id`: The Vertex fully qualified identifier for the deployed model, needed when more than one model is deployed to the endpoint. If only one model is deployed, it can be set to `None`.

In [None]:
INSTANCE_SIZE = 8
NUM_VAL_DATA = 16
deployed_model_id = deployed_model.deployed_model.id

all_neighbors = []

for data_idx in range(0, NUM_VAL_DATA, INSTANCE_SIZE):
    end_idx = min(data_idx + INSTANCE_SIZE, NUM_VAL_DATA)
    formatted_data = val_data[data_idx:end_idx]
    response = explain_image(formatted_data, endpoint_id, None, deployed_model_id)
    # Append the explanations of this batch to the running list.
    all_neighbors.extend(json_format.MessageToDict(response._pb)["explanations"])

print(f"\nExamples processed: {len(all_neighbors)}")
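
The batching logic in the loop above can be isolated into a small generator, which makes it easy to check on its own. This is only a sketch: `iter_batches` is a hypothetical helper, and the list of integers stands in for the prediction instances.

```python
def iter_batches(items, batch_size, limit=None):
    """Yield consecutive slices of `items` with at most `batch_size` elements."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    end = len(items) if limit is None else min(limit, len(items))
    for start in range(0, end, batch_size):
        yield items[start : min(start + batch_size, end)]


fake_instances = list(range(20))
batches = list(iter_batches(fake_instances, batch_size=8, limit=16))
print([len(batch) for batch in batches])  # [8, 8]
```

Capping `end` at `len(items)` keeps the last batch from running past the data when fewer instances are available than the requested limit.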

#### Save input ids and the corresponding neighbors

For each input image you sent, you create a dictionary with the corresponding neighbors.

In [None]:
# Save input ids and the corresponding neighbors
data_with_neighbors = []
input_data_list = val_data[:NUM_VAL_DATA]

for i, input_data in enumerate(input_data_list):
    neighbor_dict = all_neighbors[i]
    neighbor_dict["input"] = input_data["id"]
    data_with_neighbors.append(neighbor_dict)

DEBUG = False
if DEBUG:
    val_idx = 0
    print(data_with_neighbors[val_idx])
    print(data_with_neighbors[val_idx]["neighbors"])
    print(data_with_neighbors[val_idx]["input"])
    print(len(data_with_neighbors[val_idx]["neighbors"]))
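
To illustrate how the neighbor dictionaries can be consumed, the sketch below ranks the neighbors of one input by distance. The field names `neighborId` and `neighborDistance` are an assumption about how `MessageToDict` renders the response, and the values are made up. Print one real entry (for example, by setting `DEBUG = True` above) to confirm the actual keys.

```python
def closest_neighbors(entry, k=3):
    """Return the `k` (neighbor_id, distance) pairs with the smallest distance."""
    ranked = sorted(entry["neighbors"], key=lambda n: n["neighborDistance"])
    return [(n["neighborId"], n["neighborDistance"]) for n in ranked[:k]]


# Made-up entry mimicking the assumed structure of `data_with_neighbors[i]`.
sample_entry = {
    "input": 0,
    "neighbors": [
        {"neighborId": "17", "neighborDistance": 0.42},
        {"neighborId": "3", "neighborDistance": 0.11},
        {"neighborId": "98", "neighborDistance": 0.27},
    ],
}

print(closest_neighbors(sample_entry, k=2))  # [('3', 0.11), ('98', 0.27)]
```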

#### Visualize the images with explanations

For each image you sent, the following plots show the ten closest examples that the API returned, according to the distance measure you defined.

As you can verify, the examples returned (identified by `example index`) are mostly images classified in the same category as the input. In some cases, however, the model wrongly identifies the category, and you can easily spot those cases by looking at the distances.

In [None]:
label_index_to_name = create_index_to_name_map(ds_info)

In [None]:
VAL_IMG_INDICES = [1, 2, 10]  # images to visually explore
for val_img_idx in VAL_IMG_INDICES:
    if val_img_idx > NUM_VAL_DATA - 1:
        raise ValueError(
            f"Data index {val_img_idx} does not exist in the requested explanations"
        )
    plot_input_and_neighbors(
        val_img_idx,
        train_images,
        val_images,
        train_labels,
        val_labels,
        label_index_to_name,
        data_with_neighbors,
    )

### Further exploration

If you want to continue exploring, here are some ideas:
1.   Isolate test points where the model is making mistakes (for example, a cat mislabeled as a bird), and visualize the example-based explanations to see if you can find any common patterns.
2.   If this analysis shows that your training data lacks some representative cases (for example, overhead images of cats), try adding such images to your dataset to see if that improves model performance.
3.   [Fine-tune](https://keras.io/guides/transfer_learning/) the lower layers of the model to see if you can improve the quality of example-based explanations by enabling the model to learn a better latent representation.
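
As a starting point for the first idea, the sketch below flags inputs whose true label disagrees with the majority label of their neighbors. Note that this is a proxy signal based on the neighbors rather than on the model's own predictions. It is an illustration on made-up data: `find_disagreements` is a hypothetical helper, and the neighbor labels are assumed to have been looked up already from the training labels.

```python
from collections import Counter


def find_disagreements(true_labels, neighbor_labels):
    """Return indices where the majority neighbor label differs from the true label."""
    flagged = []
    for idx, (label, neighbors) in enumerate(zip(true_labels, neighbor_labels)):
        majority_label, _ = Counter(neighbors).most_common(1)[0]
        if majority_label != label:
            flagged.append(idx)
    return flagged


# Toy data: three inputs with the labels of their five nearest neighbors.
toy_true_labels = ["cat", "bird", "cat"]
toy_neighbor_labels = [
    ["cat", "cat", "dog", "cat", "cat"],
    ["bird", "bird", "bird", "airplane", "bird"],
    ["bird", "bird", "cat", "bird", "bird"],
]

print(find_disagreements(toy_true_labels, toy_neighbor_labels))  # [2]
```

The flagged indices are candidates to pass to the plotting step so that you can inspect the input next to its neighbors.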


## Cleaning up

To clean up all Google Cloud resources used in this project, you can [delete the Google Cloud
project](https://cloud.google.com/resource-manager/docs/creating-managing-projects#shutting_down_projects) you used for the tutorial.

Otherwise, you can delete the individual resources you created in this tutorial.

You can undeploy your `Model` resource from the serving `Endpoint` resource with the `undeploy_model` helper function, which takes the following parameters:

- `deployed_model_id`: The model deployment identifier returned by the endpoint service when the `Model` resource was deployed.
- `endpoint`: The Vertex fully qualified identifier for the `Endpoint` resource where the `Model` resource is deployed.

In [None]:
# delete flags
undeploy_model_flag = False
delete_endpoint_flag = False
delete_bucket_flag = False

# Undeploy model resource
if undeploy_model_flag:
    undeploy_model(deployed_model_id, endpoint_id)

# Delete endpoint resource
if delete_endpoint_flag:
    delete_endpoint(endpoint_id)

# Delete Cloud Storage objects that were created
if delete_bucket_flag or os.getenv("IS_TESTING"):
    ! gsutil -m rm -r $BUCKET_URI