In [None]:
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Vertex AI Model Garden - PaliGemma (Deployment)

<table><tbody><tr>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fvertex-ai-samples%2Fmain%2Fnotebooks%2Fcommunity%2Fmodel_garden%2Fmodel_garden_jax_paligemma_deployment.ipynb">
      <img alt="Google Cloud Colab Enterprise logo" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" width="32px"><br> Run in Colab Enterprise
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_jax_paligemma_deployment.ipynb">
      <img alt="GitHub logo" src="https://cloud.google.com/ml-engine/images/github-logo-32px.png" width="32px"><br> View on GitHub
    </a>
  </td>
</tr></tbody></table>

## Overview

This notebook demonstrates deploying PaliGemma to a Vertex AI Endpoint and making online predictions for tasks listed below. The notebook also demonstrates creating a shareable link to a web interface that allows querying with the deployed PaliGemma model using [Gradio](https://www.gradio.app/).


### Objective

- Deploy PaliGemma to a Vertex AI Endpoint.
- Make predictions to the endpoint including:
  - Answering questions about a given image.
  - Captioning images.
  - Extracting texts.
  - Detecting objects.
- Create a playground website to use with the PaliGemma Vertex AI Endpoint.

### Costs

This tutorial uses billable components of Google Cloud:

* Vertex AI
* Cloud Storage

Learn about [Vertex AI pricing](https://cloud.google.com/vertex-ai/pricing), [Cloud Storage pricing](https://cloud.google.com/storage/pricing), and use the [Pricing Calculator](https://cloud.google.com/products/calculator/) to generate a cost estimate based on your projected usage.

## Before you begin

In [None]:
# @title Setup Google Cloud project
# @markdown ### Prerequisites
# @markdown 1. [Make sure that billing is enabled for your project](https://cloud.google.com/billing/docs/how-to/modify-project).

# @markdown 2. [Optional] [Create a Cloud Storage bucket](https://cloud.google.com/storage/docs/creating-buckets) for storing experiment outputs. Set the BUCKET_URI for the experiment environment. The specified Cloud Storage bucket (`BUCKET_URI`) should be located in the same region as where the notebook was launched. Note that a multi-region bucket (e.g. "us") is not considered a match for a single region covered by the multi-region range (e.g. "us-central1"). If not set, a unique GCS bucket will be created instead.

# Import the necessary packages
! pip install -q gradio==4.21.0
import base64
import enum
import io
import json
import os
import re
from datetime import datetime
from io import BytesIO
from typing import List, Sequence, Tuple

import gradio as gr
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import requests
from google.cloud import aiplatform
from PIL import Image

# Get the default cloud project id.
# Raises KeyError if GOOGLE_CLOUD_PROJECT is not set in the environment.
PROJECT_ID = os.environ["GOOGLE_CLOUD_PROJECT"]

# Get the default region for launching jobs.
# Raises KeyError if GOOGLE_CLOUD_REGION is not set in the environment.
REGION = os.environ["GOOGLE_CLOUD_REGION"]

# Enable the Vertex AI API and Compute Engine API, if not already.
print("Enabling Vertex AI API and Compute Engine API.")
! gcloud services enable aiplatform.googleapis.com compute.googleapis.com

# Cloud Storage bucket for storing the experiment artifacts.
# A unique GCS bucket will be created for the purpose of this notebook. If you
# prefer using your own GCS bucket, change the value yourself below.
# Timestamp used to make the auto-created bucket name unique.
now = datetime.now().strftime("%Y%m%d%H%M%S")
BUCKET_URI = "gs://"  # @param {type:"string"}
assert BUCKET_URI.startswith("gs://"), "BUCKET_URI must start with `gs://`."
# Create a unique GCS bucket for this notebook, if not specified by the user
# NOTE(review): the assert above already rejects an empty string (and fails with
# AttributeError on None), so only the bare "gs://" placeholder can reach this
# branch; the other two conditions look unreachable.
if BUCKET_URI is None or BUCKET_URI.strip() == "" or BUCKET_URI == "gs://":
    BUCKET_URI = f"gs://{PROJECT_ID}-tmp-{now}"
    ! gsutil mb -l {REGION} {BUCKET_URI}
else:
    # Keep only "gs://<bucket>" and drop any object path after the bucket name.
    BUCKET_NAME = "/".join(BUCKET_URI.split("/")[:3])
    shell_output = ! gsutil ls -Lb {BUCKET_NAME} | grep "Location constraint:" | sed "s/Location constraint://"
    # NOTE(review): indexes the first output line unconditionally; this raises
    # IndexError when the pipeline prints nothing (e.g. bucket not accessible).
    bucket_region = shell_output[0].strip().lower()
    if bucket_region != REGION:
        raise ValueError(
            f"Bucket region {bucket_region} is different from notebook region {REGION}"
        )
print(f"Using this GCS Bucket: {BUCKET_URI}")

# Sub-paths of the bucket: SDK staging area and the copied PaliGemma artifacts.
STAGING_BUCKET = os.path.join(BUCKET_URI, "temporal")
MODEL_BUCKET = os.path.join(BUCKET_URI, "paligemma")

# Initialize Vertex AI API.
print("Initializing Vertex AI API.")
aiplatform.init(project=PROJECT_ID, location=REGION, staging_bucket=STAGING_BUCKET)

# Set up default SERVICE_ACCOUNT
SERVICE_ACCOUNT = None
shell_output = ! gcloud projects describe $PROJECT_ID
# Takes the text after the first ":" on the last output line and strips quotes.
# NOTE(review): presumably that line is `projectNumber: '<number>'` -- confirm,
# since this depends on the field order of the gcloud output.
project_number = shell_output[-1].split(":")[1].strip().replace("'", "")
SERVICE_ACCOUNT = f"{project_number}-compute@developer.gserviceaccount.com"
print("Using this default Service Account:", SERVICE_ACCOUNT)

# Provision permissions to the SERVICE_ACCOUNT with the GCS bucket
# Recomputed here because the auto-created-bucket branch above never sets it.
BUCKET_NAME = "/".join(BUCKET_URI.split("/")[:3])
! gsutil iam ch serviceAccount:{SERVICE_ACCOUNT}:roles/storage.admin $BUCKET_NAME


# @markdown ### Access PaliGemma models on Vertex AI for GPU based serving
# @markdown Accept the model agreement to access the models:
# @markdown 1. Open the [PaliGemma model card](https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/363) from [Vertex AI Model Garden](https://cloud.google.com/model-garden).
# @markdown 1. Review and accept the agreement in the pop-up window on the model card page. If you have previously accepted the model agreement, there will not be a pop-up window on the model card page and this step is not needed.
# @markdown 1. After accepting the agreement of PaliGemma, a `gs://` URI containing PaliGemma pretrained models will be shared.
# @markdown 1. Paste the link in the `VERTEX_AI_MODEL_GARDEN_PALIGEMMA` field below.
# @markdown 1. The PaliGemma models will be copied into `BUCKET_URI`.
# @markdown The file transfer can take anywhere from 15 minutes to 30 minutes.
VERTEX_AI_MODEL_GARDEN_PALIGEMMA = "gs://"  # @param {type:"string", isTemplate:true}
assert (
    VERTEX_AI_MODEL_GARDEN_PALIGEMMA
), "Click the agreement of PaliGemma in Vertex AI Model Garden, and get the GCS path of PaliGemma model artifacts."
print(
    "Copying PaliGemma model artifacts from",
    VERTEX_AI_MODEL_GARDEN_PALIGEMMA,
    "to ",
    MODEL_BUCKET,
)

! gsutil -m cp -R $VERTEX_AI_MODEL_GARDEN_PALIGEMMA/* $MODEL_BUCKET

model_path_prefix = MODEL_BUCKET

# The pre-built serving docker images.
SERVE_DOCKER_URI = "us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/jax-paligemma-serve-gpu:20240513_0916_RC00"

# File-name infix used for each precision type ("float32" has none).
_PRECISION_SUFFIXES = {"float32": "", "float16": ".f16", "bfloat16": ".bf16"}

# (model name prefix, checkpoint file stem, available resolutions).
_CHECKPOINT_FAMILIES = (
    ("paligemma", "pt", (224, 448, 896)),
    ("paligemma-mix", "mix", (224, 448)),
)

# Maps "<prefix>-<resolution>-<precision>" model names to the checkpoint file
# name of the corresponding prebuilt model.
pretrained_filename_lookup = {
    f"{name_prefix}-{res}-{precision}": f"{file_stem}_{res}{suffix}.npz"
    for name_prefix, file_stem, resolutions in _CHECKPOINT_FAMILIES
    for res in resolutions
    for precision, suffix in _PRECISION_SUFFIXES.items()
}


def get_job_name_with_datetime(prefix: str) -> str:
    """Builds a job name by appending the current local timestamp to `prefix`.

    Args:
        prefix: Text placed in front of the timestamp.

    Returns:
        `prefix` followed by `_YYYYMMDD_HHMMSS`.
    """
    timestamp_suffix = datetime.now().strftime("_%Y%m%d_%H%M%S")
    return "".join((prefix, timestamp_suffix))


def deploy_model(
    model_name: str,
    checkpoint_path: str,
    machine_type: str = "g2-standard-32",
    accelerator_type: str = "NVIDIA_L4",
    accelerator_count: int = 1,
    resolution: int = 224,
) -> Tuple[aiplatform.Model, aiplatform.Endpoint]:
    """Uploads a PaliGemma serving model and deploys it to a new endpoint.

    Args:
        model_name: Base name for the model; a timestamp is appended to form
            the display names, and the raw value is passed to the container
            as `MODEL_ID`.
        checkpoint_path: Checkpoint location passed to the container as
            `CKPT_PATH`.
        machine_type: Machine type used for the deployment.
        accelerator_type: Accelerator type used for the deployment.
        accelerator_count: Number of accelerators per replica.
        resolution: Value passed to the container as `RESOLUTION`.

    Returns:
        The uploaded model and the endpoint it was deployed to.
    """
    display_name = get_job_name_with_datetime(model_name)

    # The endpoint is created before the model is uploaded.
    endpoint = aiplatform.Endpoint.create(display_name=f"{display_name}-endpoint")

    container_env = {
        "CKPT_PATH": checkpoint_path,
        "RESOLUTION": resolution,
        "MODEL_ID": model_name,
    }
    model = aiplatform.Model.upload(
        display_name=display_name,
        serving_container_image_uri=SERVE_DOCKER_URI,
        serving_container_ports=[8080],
        serving_container_predict_route="/predict",
        serving_container_health_route="/health",
        serving_container_environment_variables=container_env,
    )

    print(
        f"Deploying {display_name} on {machine_type} with {accelerator_count} {accelerator_type} GPU(s)."
    )
    # sync=True blocks until the deployment call returns.
    model.deploy(
        endpoint=endpoint,
        machine_type=machine_type,
        accelerator_type=accelerator_type,
        accelerator_count=accelerator_count,
        deploy_request_timeout=1800,
        service_account=SERVICE_ACCOUNT,
        enable_access_logging=True,
        min_replica_count=1,
        sync=True,
    )
    return model, endpoint


def download_image(url: str, timeout: float = 60.0) -> Image.Image:
    """Downloads an image from the specified URL.

    Args:
        url: Location of the image to fetch.
        timeout: Seconds to wait for the server before giving up, so a stalled
            connection cannot hang the notebook indefinitely.

    Returns:
        The downloaded image opened with PIL.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not respond within `timeout`.
    """
    response = requests.get(url, timeout=timeout)
    # Fail on HTTP errors here instead of handing an error page to PIL, which
    # would otherwise surface as a confusing "cannot identify image" error.
    response.raise_for_status()
    return Image.open(BytesIO(response.content))


def resize_image(image: Image.Image, new_width: int = 1000) -> Image.Image:
    """Resizes an image to `new_width`, keeping its aspect ratio.

    The original and resized dimensions are printed.

    Args:
        image: Image to resize; must expose `size` and `resize`.
        new_width: Target width in pixels.

    Returns:
        The resized image.
    """
    width, height = image.size
    print(f"original input image size: {width}, {height}")
    # Clamp to one pixel: for very wide images the scaled height truncates to
    # 0, which is not a valid image dimension.
    new_height = max(1, int(height * new_width / width))
    new_img = image.resize((new_width, new_height))
    print(f"resized input image size: {new_width}, {new_height}")
    return new_img


def image_to_base64(image: Image.Image, format="JPEG") -> str:
    """Converts an image to a base64 string.

    Args:
        image: Image to encode.
        format: Image file format used for encoding.

    Returns:
        The encoded image bytes as a base64 UTF-8 string.
    """
    # JPEG cannot store an alpha channel or a palette, so images in such modes
    # (e.g. RGBA or P, common for PNG inputs) are converted to RGB first
    # instead of failing inside `save`.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    if format and format.upper() in ("JPEG", "JPG") and image.mode not in jpeg_modes:
        image = image.convert("RGB")

    buffer = BytesIO()
    image.save(buffer, format=format)
    image_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return image_str


def vqa_predict(
    endpoint: aiplatform.Endpoint,
    image: Image.Image,
    prompts: List[str],
    new_width: int = 1000,
) -> List[str]:
    """Predicts the answer to questions about an image using an Endpoint.

    Args:
        endpoint: Endpoint serving the model.
        image: Image the questions refer to.
        prompts: Questions to ask; empty entries are skipped.
        new_width: Width the image is resized to before it is sent.

    Returns:
        One answer per non-empty prompt, in order. Empty when no non-empty
        prompt was given.
    """
    questions = [question_prompt for question_prompt in prompts if question_prompt]
    # Nothing to ask: skip the request rather than sending an empty
    # `instances` list to the endpoint.
    if not questions:
        return []

    # Resize and convert image to base64 string.
    resized_image = resize_image(image, new_width)
    resized_image_base64 = image_to_base64(resized_image)

    # Format question prompt
    question_prompt_format = "answer en {}\n"

    # Every question is paired with the same encoded image.
    instances = [
        {
            "prompt": question_prompt_format.format(question_prompt),
            "image": resized_image_base64,
        }
        for question_prompt in questions
    ]

    response = endpoint.predict(instances=instances)
    return [pred.get("response") for pred in response.predictions]


def caption_predict(
    endpoint: aiplatform.Endpoint,
    image: Image.Image = None,
    language_code: str = "en",
    new_width: int = 1000,
) -> str:
    """Requests a caption for an image from an Endpoint.

    Args:
        endpoint: Endpoint serving the model.
        image: Image to caption.
        language_code: Language code inserted into the caption prompt.
        new_width: Width the image is resized to before it is sent.

    Returns:
        The `response` field of the first prediction.
    """
    # Shrink the image, then encode it for the request payload.
    encoded_image = image_to_base64(resize_image(image, new_width))

    request = {
        "prompt": f"caption {language_code}\n",
        "image": encoded_image,
    }
    first_prediction = endpoint.predict(instances=[request]).predictions[0]
    return first_prediction.get("response")


def ocr_predict(
    endpoint: aiplatform.Endpoint,
    image: Image.Image = None,
    new_width: int = 1000,
) -> str:
    """Requests the text found in an image from an Endpoint.

    Args:
        endpoint: Endpoint serving the model.
        image: Image to read text from.
        new_width: Width the image is resized to before it is sent.

    Returns:
        The `response` field of the first prediction.
    """
    # Shrink the image, then encode it for the request payload.
    encoded_image = image_to_base64(resize_image(image, new_width))

    request = {"prompt": "ocr", "image": encoded_image}
    first_prediction = endpoint.predict(instances=[request]).predictions[0]
    return first_prediction.get("response")


def detect_predict(
    endpoint: aiplatform.Endpoint,
    image: Image.Image,
    prompt: str,
    new_width: int = 1000,
):
    """Requests object detections for an image from an Endpoint.

    Args:
        endpoint: Endpoint serving the model.
        image: Image to search for objects.
        prompt: Object description(s) appended to the `detect` prompt.
        new_width: Width the image is resized to before it is sent.

    Returns:
        The `response` field of the first prediction.
    """
    # Shrink the image, then encode it for the request payload.
    encoded_image = image_to_base64(resize_image(image, new_width))

    request = {"prompt": f"detect {prompt}", "image": encoded_image}
    first_prediction = endpoint.predict(instances=[request]).predictions[0]
    return first_prediction.get("response")


def parse_detections(txt):
    """Parses bounding boxes from a detection string.

    The string holds `;`-separated detections, each starting with four
    `<locNNNN>` tokens in the order y0, x0, y1, x1. Whitespace around the
    separator is ignored and empty segments are skipped.

    Args:
        txt: Detection string to parse.

    Returns:
        A list with one numpy array `[y0, x0, y1, x1]` per detection, each
        coordinate divided by 1024.

    Raises:
        ValueError: If a segment does not start with four location tokens, or
            if the string contains no detection at all.
    """
    loc_pattern = re.compile(
        r"<loc(?P<y0>\d\d\d\d)><loc(?P<x0>\d\d\d\d)><loc(?P<y1>\d\d\d\d)><loc(?P<x1>\d\d\d\d)>"
    )

    bboxes = []
    for loc_text in txt.split(";"):
        loc_text = loc_text.strip()
        # Tolerate stray separators such as a trailing ";".
        if not loc_text:
            continue
        m = loc_pattern.match(loc_text)
        if m is None:
            raise ValueError(f"{txt} is not a valid detection string.")
        box = np.array(
            [float(m.group(key)) / 1024.0 for key in ("y0", "x0", "y1", "x1")]
        )
        bboxes.append(box)

    # A string without any detection (e.g. empty output) is still an error.
    if not bboxes:
        raise ValueError(f"{txt} is not a valid detection string.")
    return bboxes


def plot_bounding_boxes(im: Image.Image, bboxes: Sequence[np.ndarray]) -> Image.Image:
    """Draws bounding boxes over an image and returns the rendered figure.

    Args:
        im: Image to draw on.
        bboxes: Boxes as `[y0, x0, y1, x1]`, each multiplied by the image
            height/width to obtain pixel coordinates.

    Returns:
        The saved figure, reopened as a PIL image.
    """
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.imshow(im, zorder=-1)
    # Re-apply the current limits (presumably so the added patches cannot
    # rescale the axes).
    ax.set_xlim(*ax.get_xlim())
    ax.set_ylim(*ax.get_ylim())

    img_width, img_height = im.size
    pixel_scale = (img_height, img_width, img_height, img_width)
    for norm_top, norm_left, norm_bottom, norm_right in bboxes:
        top, left, bottom, right = (
            np.array([norm_top, norm_left, norm_bottom, norm_right]) * pixel_scale
        )
        outline = mpl.patches.Rectangle(
            (left, top),
            right - left,
            bottom - top,
            linewidth=1,
            edgecolor="r",
            facecolor="none",
        )
        ax.add_patch(outline)

    # NOTE(review): the figure is left open on purpose here; the notebook cell
    # that calls this function discards the return value, so closing the figure
    # may change what the cell displays -- confirm before adding plt.close(fig).
    rendered = io.BytesIO()
    fig.savefig(rendered)
    rendered.seek(0)
    return Image.open(rendered)


# Module-level placeholder; the Deploy cell rebinds `model` when it runs
# `model, endpoint = deploy_model(...)`.
model = None


def get_quota(project_id: str, region: str, resource_id: str) -> int:
  """Returns the quota for a resource in a region. Returns -1 if can not figure out the quota."""
  service_endpoint = "aiplatform.googleapis.com"
  quota_list_output = !gcloud alpha services quota list --service=$service_endpoint  --consumer=projects/$project_id --filter="$service_endpoint/$resource_id" --format=json
  # Use '.s' on the command output because it is an SList type.
  quota_data = json.loads(quota_list_output.s)
  if len(quota_data) == 0 or "consumerQuotaLimits" not in quota_data[0]:
    return -1
  if len(quota_data[0]["consumerQuotaLimits"]) == 0 or "quotaBuckets" not in quota_data[0]["consumerQuotaLimits"][0]:
    return -1
  all_regions_data = quota_data[0]["consumerQuotaLimits"][0]["quotaBuckets"]
  for region_data in all_regions_data:
    if region_data.get('dimensions') and region_data['dimensions']['region'] == region:
      if 'effectiveLimit' in region_data:
        return int(region_data['effectiveLimit'])
      else:
        return 0
  return -1


def get_resource_id(accelerator_type: str, is_for_training: bool) -> str:
  """Looks up the quota resource id for an accelerator type and use case.

  Args:
    accelerator_type: The accelerator type.
    is_for_training: True for the training use case, False for serving.

  Returns:
    The resource id.

  Raises:
    ValueError: If the accelerator type is unknown for the use case.
  """
  use_case = "training" if is_for_training else "serving"
  resource_ids_by_use_case = {
      "training": {
          "NVIDIA_TESLA_V100": "custom_model_training_nvidia_v100_gpus",
          "NVIDIA_L4": "custom_model_training_nvidia_l4_gpus",
          "NVIDIA_TESLA_A100": "custom_model_training_nvidia_a100_gpus",
      },
      "serving": {
          "NVIDIA_TESLA_V100": "custom_model_serving_nvidia_v100_gpus",
          "NVIDIA_L4": "custom_model_serving_nvidia_l4_gpus",
          "NVIDIA_TESLA_A100": "custom_model_serving_nvidia_a100_gpus",
      },
  }
  resource_ids = resource_ids_by_use_case[use_case]
  if accelerator_type not in resource_ids:
    raise ValueError(
        f"Could not find accelerator type: {accelerator_type} for {use_case}."
    )
  return resource_ids[accelerator_type]


def check_quota(project_id:str, region: str, accelerator_type: str,
                accelerator_count: int, is_for_training: bool) -> None:
  """Checks if the project and the region has the required quota.

  Args:
    project_id: Project whose quota is checked.
    region: Region whose quota is checked.
    accelerator_type: Accelerator type, mapped to a quota resource id.
    accelerator_count: Number of accelerators required.
    is_for_training: True for the training use case, False for serving.

  Raises:
    ValueError: If the quota lookup returns -1 (not found), if the quota is
      lower than `accelerator_count`, or if the accelerator type is unknown
      (raised by `get_resource_id`).
  """
  resource_id = get_resource_id(accelerator_type, is_for_training)
  quota = get_quota(project_id, region, resource_id)
  # Shared tail appended to both error messages below.
  quota_request_instruction = ("Either use "
            "a different region or request additional quota. Follow "
            "instructions here "
            "https://cloud.google.com/docs/quotas/view-manage#requesting_higher_quota"
            " to check quota in a region or request additional quota for "
            "your project.")
  # -1 is the lookup's "could not determine the quota" result.
  if quota == -1:
    raise ValueError(
            f"""Quota not found for: {resource_id} in {region}.
            {quota_request_instruction}"""
        )
  if quota < accelerator_count:
    raise ValueError(
            f"""Quota not enough for {resource_id} in {region}:
            {quota} < {accelerator_count}.
            {quota_request_instruction}"""
        )

## Deploy PaliGemma to a Vertex AI Endpoint

In [None]:
# @title Deploy

# @markdown This section uploads the prebuilt PaliGemma model to Model Registry and deploys it to a Vertex AI Endpoint. It takes approximately 15 minutes to finish.

# @markdown Select the desired resolution and precision of prebuilt model to deploy, leaving the optional `custom_paligemma_model_uri` as is. Higher resolution and precision_type can result in better inference results, but may require additional GPU.

# @markdown You can also serve a finetuned PaliGemma model by setting `resolution` and `precision_type` to the resolution and precision type of the original base model and then setting `custom_paligemma_model_uri` to the GCS URI containing the model.

# @markdown **Note**: You cannot use accelerator type `NVIDIA_TESLA_V100` to serve prebuilt or finetuned PaliGemma models with resolution `896` and precision_type `float32`.

model_variant = "mix"  # @param ["mix", "pt"]
resolution = 224  # @param [224, 448, 896]
precision_type = "float32"  # @param ["float32", "float16", "bfloat16"]
custom_paligemma_model_uri = "gs://"  # @param {type: "string"}

if model_variant == "mix":
    model_name_prefix = "paligemma-mix"
else:
    model_name_prefix = "paligemma"

# The bare "gs://" placeholder (or an empty value) means "use a prebuilt model".
if custom_paligemma_model_uri == "gs://" or not custom_paligemma_model_uri:
    print("Deploying prebuilt PaliGemma model.")
    model_name = f"{model_name_prefix}-{resolution}-{precision_type}"
    # Not every variant/resolution combination has a prebuilt checkpoint (the
    # lookup table has no "mix" entry at resolution 896), so fail with a clear
    # message instead of a bare KeyError.
    if model_name not in pretrained_filename_lookup:
        raise ValueError(
            f"No prebuilt checkpoint found for: {model_name}. Available prebuilt models: {sorted(pretrained_filename_lookup)}"
        )
    checkpoint_filename = pretrained_filename_lookup[model_name]
    checkpoint_path = os.path.join(model_path_prefix, checkpoint_filename)
else:
    print("Deploying custom PaliGemma model.")
    model_name = f"{model_name_prefix}-{resolution}-{precision_type}-custom"
    checkpoint_path = custom_paligemma_model_uri

# @markdown Select the accelerator type to use to deploy the model:
accelerator_type = "NVIDIA_L4"  # @param ["NVIDIA_L4", "NVIDIA_TESLA_V100"]
if accelerator_type == "NVIDIA_L4":
    machine_type = "g2-standard-16"
    accelerator_count = 1
elif accelerator_type == "NVIDIA_TESLA_V100":
    if resolution == 896 and precision_type == "float32":
        raise ValueError(
            "NVIDIA_TESLA_V100 is not sufficient. Multi-gpu is not supported for PaLIGemma."
        )
    else:
        machine_type = "n1-highmem-8"
        accelerator_count = 1
else:
    raise ValueError(
        f"Recommended machine settings not found for: {accelerator_type}. To use another accelerator, edit this code block to pass in an appropriate `machine_type`, `accelerator_type`, and `accelerator_count` to the deploy_model function by clicking `Show Code` and then modifying the code."
    )

# Fail early when the project lacks serving quota for the chosen accelerator.
check_quota(
    project_id=PROJECT_ID,
    region=REGION,
    accelerator_type=accelerator_type,
    accelerator_count=accelerator_count,
    is_for_training=False,
)
# @markdown If you want to use other accelerator types not listed above, then check other Vertex AI prediction supported accelerators and regions at https://cloud.google.com/vertex-ai/docs/predictions/configure-compute. You may need to manually set the `machine_type`, `accelerator_type`, and `accelerator_count` in the code by clicking `Show code` first.

model, endpoint = deploy_model(
    model_name=model_name,
    checkpoint_path=checkpoint_path,
    machine_type=machine_type,
    accelerator_type=accelerator_type,
    accelerator_count=accelerator_count,
    resolution=resolution,
)

In [None]:
# @title [Optional] Loading an existing Endpoint
# @markdown If you've already deployed an Endpoint, you can load it by filling in the Endpoint's ID below.
# @markdown You can view deployed Endpoints at [Vertex Online Prediction](https://console.cloud.google.com/vertex-ai/online-prediction/endpoints).
endpoint_id = ""  # @param {type: "string"}

# An empty ID leaves any `endpoint` that is already defined untouched.
if endpoint_id:
    existing_endpoint_args = {
        "endpoint_name": endpoint_id,
        "project": PROJECT_ID,
        "location": REGION,
    }
    endpoint = aiplatform.Endpoint(**existing_endpoint_args)

### Predict

The following sections will use images from [pexels.com](https://www.pexels.com/) for demoing purposes. All the images have the following license: https://www.pexels.com/license/.

Images will be resized to a width of 1000 pixels by default, since requests made to a Vertex AI Endpoint are limited to 1.5 MB in size.

In [None]:
# @title Visual Question Answering

# @markdown This section uses the deployed PaliGemma model to answer questions about a given image.

# @markdown ![](https://images.pexels.com/photos/4012966/pexels-photo-4012966.jpeg?w=1260&h=750)
image_url = "https://images.pexels.com/photos/4012966/pexels-photo-4012966.jpeg"  # @param {type:"string"}

# Fetch the demo image and show it inline before asking questions about it.
image = download_image(image_url)
display(image)

# @markdown You may leave question prompts empty and they will be ignored.
question_prompt_1 = "Which of laptop, book, pencil, clock, flower are in the image?"  # @param {type: "string"}
question_prompt_2 = "Do the book and the cup have the same color?"  # @param {type: "string"}
question_prompt_3 = "Is there a person in the image?"  # @param {type: "string"}
question_prompt_4 = "How many laptop are in the image?"  # @param {type: "string"}
question_prompt_5 = "桌子是什么颜色的?"  # @param {type: "string"}


# @markdown The question prompt can be non-English languages.
questions_list = [
    question_prompt_1,
    question_prompt_2,
    question_prompt_3,
    question_prompt_4,
    question_prompt_5,
]
# Drop empty prompts here so the questions stay aligned with the answers when
# they are zipped below (vqa_predict also skips empty prompts).
questions_list = [question for question in questions_list if question]


# All questions are sent to the endpoint in a single call.
answers = vqa_predict(endpoint, image, questions_list)

for question, answer in zip(questions_list, answers):
    print(f"Question: {question}")
    print(f"Answer: {answer}")
# @markdown Click "Show Code" to see more details.

In [None]:
# @title Image Captioning

# @markdown This section uses the deployed PaliGemma model to caption and describe an image in a chosen language.

# @markdown ![](https://images.pexels.com/photos/20427316/pexels-photo-20427316/free-photo-of-a-moped-parked-in-front-of-a-blue-door.jpeg?auto=compress&cs=tinysrgb&w=630&h=375&dpr=2)

image_url = "https://images.pexels.com/photos/20427316/pexels-photo-20427316/free-photo-of-a-moped-parked-in-front-of-a-blue-door.jpeg?auto=compress&cs=tinysrgb&w=1260&h=750&dpr=2"  # @param {type:"string"}

# Fetch the demo image and show it inline before captioning it.
image = download_image(image_url)
display(image)

# Make a prediction.
# NOTE(review): `image_base64` is not used in this cell; caption_predict
# resizes and encodes the image itself. Looks removable -- confirm no later
# cell reads it.
image_base64 = image_to_base64(image)
language_code = "en"  # @param {type: "string"}
caption = caption_predict(endpoint, image, language_code)

print("Caption: ", caption)
# @markdown Click "Show Code" to see more details.

In [None]:
# @title OCR
# @markdown This section uses the deployed PaliGemma model to extract text from an image, starting from the top left.

# @markdown ![](https://images.pexels.com/photos/8919535/pexels-photo-8919535.jpeg?auto=compress&cs=tinysrgb&w=630&h=375&dpr=2)
image_url = "https://images.pexels.com/photos/8919535/pexels-photo-8919535.jpeg?auto=compress&cs=tinysrgb&w=1260&h=750&dpr=2"  # @param {type:"string"}

# Fetch the demo image and show it inline before extracting its text.
image = download_image(image_url)
display(image)
# Ask the deployed model for the text found in the image.
text_found = ocr_predict(endpoint, image)

print(f"Text found: {text_found}")
# @markdown Click "Show Code" to see more details.

In [None]:
# @title Object Detection
# @markdown This section uses the deployed PaliGemma model to output bounding boxes for the specified objects in a given image.
# @markdown The text output will be parsed into bounding boxes and overlaid on the original image.

# @markdown ![](https://images.pexels.com/photos/1006293/pexels-photo-1006293.jpeg?auto=compress&cs=tinysrgb&w=630&h=375&dpr=2)

image_url = "https://images.pexels.com/photos/1006293/pexels-photo-1006293.jpeg?auto=compress&cs=tinysrgb&w=1260&h=750&dpr=2"  # @param {type:"string"}

# @markdown Specify what object to detect. To specify multiple objects, enter them as a semicolon separated list as shown below.

objects = "plant ; pineapple ; glasses"  # @param {type:"string"}
# Fetch the demo image and show it inline before running detection.
image = download_image(image_url)
display(image)

# Make a prediction.
detection_response = detect_predict(endpoint, image, objects)

# Turn the textual response into boxes and draw them over the image.
bboxes = parse_detections(detection_response)
# NOTE(review): the returned image is discarded here; the overlay presumably
# shows up only because the matplotlib figure is rendered inline -- confirm.
plot_bounding_boxes(image, bboxes)
print("Output: ", detection_response)
# @markdown Click "Show Code" to see more details.

## Creating a webpage playground with Gradio

In [None]:
# @title How to use
# @markdown This is a playground similar to the popular [Stable Diffusion WebUI](https://github.com/AUTOMATIC1111/stable-diffusion-webui).

# @markdown **Prerequisites**
# @markdown -  Before you can upload an image to make a prediction, you need to select a Vertex prediction endpoint serving PaliGemma
# @markdown from the endpoint dropdown list that has been deployed in the current project and region.
# @markdown -  If no models have been deployed, you can create a new Vertex prediction
# @markdown endpoint by clicking "Deploy to Vertex" in the playground or running the `Deploy` cell above.
# @markdown   * New model deployment takes approximately 15 minutes. You can check the progress at [Vertex Online Prediction](https://console.cloud.google.com/vertex-ai/online-prediction/endpoints).

# @markdown **How to use**

# @markdown Just run this cell and a link to the playground formatted as `https://####.gradio.live` will be outputted.
# @markdown This link will take you to the playground in a separate browser tab.


class Task(enum.Enum):
    """Tasks offered by the playground.

    The member values are the strings compared against the selected interface
    name in `select_interface` and `predict_handler`.
    """

    VQA = "Visual Question Answering"
    CAPTION = "Image Captioning"
    OCR = "OCR"
    DETECT = "Object Detection"


def list_paligemma_endpoints() -> list[str]:
    """Lists the PaliGemma prediction endpoints in the project and region.

    Endpoints are listed newest first and kept only when they have a traffic
    split and "pali" appears in their display name. When none remain, a
    Gradio warning is issued.

    Returns:
        One "<endpoint name> - <display name, first 40 characters>" label per
        matching endpoint.
    """
    endpoint_names = [
        f"{candidate.name} - {candidate.display_name[:40]}"
        for candidate in aiplatform.Endpoint.list(order_by="create_time desc")
        if candidate.traffic_split and "pali" in candidate.display_name.lower()
    ]

    if not endpoint_names:
        gr.Warning("No prediction endpoints were found. Create an Endpoint first.")

    return endpoint_names


def get_endpoint(endpoint_name: str) -> aiplatform.Endpoint:
    """Builds the Vertex endpoint object for a dropdown label.

    Args:
        endpoint_name: Label whose text before the first " - " is the
            endpoint ID.

    Returns:
        The endpoint addressed by its full resource name in the notebook's
        project and region.
    """
    endpoint_id = endpoint_name.partition(" - ")[0]
    resource_name = f"projects/{PROJECT_ID}/locations/{REGION}/endpoints/{endpoint_id}"
    return aiplatform.Endpoint(resource_name)


def select_interface(interface_name: str):
    """Returns the component updates that switch the UI to the chosen task.

    Args:
        interface_name: One of the `Task` values.

    Returns:
        A dict mapping the module-level Gradio components (text input,
        language box, submit button, text output, image output) to their
        updates.

    Raises:
        gr.Error: If `interface_name` is not a known task.
    """
    # Per task: (submit button label, text box update, show language box,
    # show image output).
    layouts = {
        Task.VQA.value: (
            "Answer",
            {"label": "Question", "value": None, "visible": True},
            False,
            False,
        ),
        Task.CAPTION.value: (
            "Caption",
            {"value": None, "visible": False},
            True,
            False,
        ),
        Task.OCR.value: (
            "Extract text",
            {"value": None, "visible": False},
            False,
            False,
        ),
        Task.DETECT.value: (
            "Detect",
            {"label": "Object(s)", "value": None, "visible": True},
            False,
            True,
        ),
    }
    if interface_name not in layouts:
        raise gr.Error(f"Invalid interface name: {interface_name}")

    button_label, text_box_update, show_language, show_image = layouts[interface_name]
    # Previous outputs are cleared on every switch.
    return {
        text_input_box: gr.update(**text_box_update),
        language_code_box: gr.update(visible=show_language),
        submit_button: gr.update(value=button_label),
        text_output: gr.update(value=None),
        image_output: gr.update(value=None, visible=show_image),
    }


def deploy_model_handler(model_choice: str) -> None:
    """Deploys the prebuilt model picked in the playground.

    Args:
        model_choice: Four "-"-separated parts with the resolution third; a
            "-pt-" part is dropped to form the checkpoint lookup key.
    """
    gr.Info("Starting model deployment.")

    # The lookup table keys carry no "-pt-" part.
    lookup_key = model_choice.replace("-pt-", "-")
    checkpoint_file = pretrained_filename_lookup[lookup_key]

    _, _, resolution_text, _ = model_choice.split("-")
    resolution_px = int(resolution_text)

    deployed_model, deployed_endpoint = deploy_model(
        model_name=model_choice,
        checkpoint_path=os.path.join(model_path_prefix, checkpoint_file),
        machine_type="g2-standard-16",
        accelerator_type="NVIDIA_L4",
        accelerator_count=1,
        resolution=resolution_px,
    )
    gr.Info(
        f"Deploying model ID: {deployed_model.name}, endpoint ID: {deployed_endpoint.name}"
    )


def predict_handler(
    interface_name: str,
    endpoint_name: str,
    image: Image.Image,
    prompt: str,
    language_code: str,
) -> Tuple[str, Image.Image]:
    """Runs the selected task against the chosen endpoint.

    Args:
        interface_name: Value of the selected `Task` member.
        endpoint_name: Name of the endpoint to query.
        image: Uploaded input image.
        prompt: Free-text input; used by the VQA and detection tasks.
        language_code: Language code; used by the captioning task.

    Returns:
        A `(text, image)` pair. The image slot is `None` for every task
        except detection, where it holds the input image with the parsed
        bounding boxes drawn on it.

    Raises:
        gr.Error: If no endpoint is selected, no image is uploaded, or the
            interface name does not match a known task.
    """
    # Validate the UI state before touching the endpoint.
    if not endpoint_name:
        raise gr.Error("Select (or deploy) a model first!")
    if not image:
        raise gr.Error("You must upload an image!")

    endpoint = get_endpoint(endpoint_name)

    if interface_name == Task.VQA.value:
        # A single-question batch is sent; only its first answer is returned.
        answers = vqa_predict(endpoint, image, [prompt])
        return answers[0], None

    if interface_name == Task.CAPTION.value:
        caption = caption_predict(endpoint, image, language_code)
        return caption, None

    if interface_name == Task.OCR.value:
        extracted_text = ocr_predict(endpoint, image)
        return extracted_text, None

    if interface_name == Task.DETECT.value:
        text_output = detect_predict(endpoint, image, prompt)
        annotated_image = plot_bounding_boxes(image, parse_detections(text_output))
        return text_output, annotated_image

    raise gr.Error("Select an interface first!")


# Usage tips rendered as Markdown at the top of the playground. The backticked
# names must match the button labels defined in the UI below.
tip_text = r"""
<b> Tips: </b>
1. Select a Vertex prediction endpoint with a deployed PaliGemma model or click `Deploy a new model` to deploy PaliGemma to Vertex.
2. New model deployment takes approximately 15 minutes. You can check the progress at [Vertex Online Prediction](https://console.cloud.google.com/vertex-ai/online-prediction/endpoints).
3. After the model deployment is complete, click `Refresh Endpoints list` to view the new endpoint in the dropdown list.
"""

# Custom CSS passed to `gr.Blocks`: sets the app container width to 85%.
css = """
.gradio-container {
  width: 85% !important
}
"""
with gr.Blocks(
    theme=gr.themes.Default(primary_hue="orange", secondary_hue="blue"),
    css=css,
) as demo:
    gr.Markdown("# Model Garden Playground for PaliGemma")

    # Top row: usage tips on the left, endpoint selection and deployment
    # controls on the right.
    with gr.Row(equal_height=True):
        with gr.Column(scale=3):
            gr.Markdown(tip_text)
        with gr.Column(scale=2):
            with gr.Row():
                endpoint_name = gr.Dropdown(
                    label="Select a model previously deployed on Vertex",
                    choices=list_paligemma_endpoints(),
                    value=None,
                    scale=7,
                )
                refresh_button = gr.Button(
                    value="Refresh Endpoints list",
                    variant="primary",
                    scale=1,
                    min_width=10,
                )
            with gr.Row():
                selected_model = gr.Dropdown(
                    label="Deploy a new model to Vertex",
                    # Expands to paligemma-<variant>-<resolution>-<dtype>,
                    # listing the "mix" variants first, then the "pt" ones.
                    choices=[
                        f"paligemma-{variant}-{resolution}-{dtype}"
                        for variant, resolutions in (
                            ("mix", (224, 448)),
                            ("pt", (224, 448, 896)),
                        )
                        for resolution in resolutions
                        for dtype in ("float32", "float16", "bfloat16")
                    ],
                    value=None,
                    scale=7,
                )
                deploy_model_button = gr.Button(
                    value="Deploy a new model",
                    variant="primary",
                    scale=1,
                    min_width=10,
                )

    # Bottom row: inputs (image, task, prompt) on the left, model responses
    # on the right.
    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            image_input = gr.Image(
                label="Upload",
                show_label=True,
                type="pil",
                height=400,
                visible=True,
            )
            with gr.Group():
                with gr.Tab("Task"):
                    interfaces_box = gr.Radio(
                        choices=[
                            task.value
                            for task in (
                                Task.VQA,
                                Task.CAPTION,
                                Task.OCR,
                                Task.DETECT,
                            )
                        ],
                        value=Task.VQA.value,
                        show_label=False,
                    )
                text_input_box = gr.Textbox(label="Question", lines=1)
                language_code_box = gr.Textbox(
                    label="Language code", value="en", lines=1, visible=False
                )
                submit_button = gr.Button(value="Answer", variant="primary")
        with gr.Column(scale=1):
            image_output = gr.Image(label="Image response:", visible=False)
            text_output = gr.Textbox(label="Text response:")

    # Event wiring.
    # Re-query the endpoint list on demand.
    refresh_button.click(
        fn=lambda: gr.update(choices=list_paligemma_endpoints()),
        outputs=[endpoint_name],
    )
    # Deployment reports progress through toasts only, hence no outputs.
    deploy_model_button.click(
        fn=deploy_model_handler,
        inputs=[selected_model],
        outputs=[],
    )
    # Switching task updates the input widgets and both response widgets.
    interfaces_box.change(
        fn=select_interface,
        inputs=interfaces_box,
        outputs=[
            text_input_box,
            language_code_box,
            submit_button,
            text_output,
            image_output,
        ],
    )
    submit_button.click(
        fn=predict_handler,
        inputs=[
            interfaces_box,
            endpoint_name,
            image_input,
            text_input_box,
            language_code_box,
        ],
        outputs=[text_output, image_output],
    )
show_debug_logs = True  # @param {type: "boolean"}
# Enable request queuing, then start the app with a shareable link.
demo.queue()
demo.launch(
    share=True,
    inline=False,
    inbrowser=True,
    debug=show_debug_logs,
    show_error=True,
)

# @markdown Click "Show Code" to see more details.

## Clean up resources

In [None]:
# @title Run
# @markdown Delete the experiment models and endpoints to recycle the resources
# @markdown and avoid unnecessary continuous charges that may be incurred.

# Undeploy model and delete endpoint.
# NOTE(review): `endpoint` and `model` are not defined in this cell; they are
# presumably left in the kernel by an earlier deployment cell — confirm that a
# fresh "Run all" actually defines them before this point.
endpoint.delete(force=True)

# Delete models.
# Runs after the endpoint deletion above; skipped when `model` is falsy.
if model:
    model.delete()

# Optional: recursively remove everything under BUCKET_URI. Off by default so
# stored artifacts are kept unless the user opts in through the form checkbox.
delete_bucket = False  # @param {type:"boolean"}
if delete_bucket:
    ! gsutil -m rm -r $BUCKET_URI