In [None]:
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Vertex AI Model Garden GenAI Workshop for Instant ID

<table><tbody><tr>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fvertex-ai-samples%2Fmain%2Fnotebooks%2Fcommunity%2Fmodel_garden%2Fmodel_garden_pytorch_instant_id_gradio.ipynb">
      <img alt="Google Cloud Colab Enterprise logo" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" width="32px"><br> Run in Colab Enterprise
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_pytorch_instant_id_gradio.ipynb">
      <img alt="GitHub logo" src="https://cloud.google.com/ml-engine/images/github-logo-32px.png" width="32px"><br> View on GitHub
    </a>
  </td>
</tr></tbody></table>

## Overview

This notebook demonstrates how to start a playground for the [InstantX/InstantID](https://huggingface.co/InstantX/InstantID) model based on [Gradio UI](https://www.gradio.app/), which lets users interact with the identity-preserving image generation model more easily and intuitively.

### Objective

- Deploy model to a [Vertex AI Endpoint resource](https://cloud.google.com/vertex-ai/docs/predictions/using-private-endpoints).
- Run online predictions for `instant-id` tasks, from the UI.
- Adjust the parameters, such as prompt, negative_prompt, num_inference_steps, and check out the generated images for best image quality.

### Costs

This tutorial uses billable components of Google Cloud:

* Vertex AI
* Cloud Storage

Learn about [Vertex AI pricing](https://cloud.google.com/vertex-ai/pricing) and [Cloud Storage pricing](https://cloud.google.com/storage/pricing), and use the [Pricing Calculator](https://cloud.google.com/products/calculator/) to generate a cost estimate based on your projected usage.

## Run the playground

In [None]:
# @title Setup Google Cloud project and prepare the dependencies

# @markdown [Make sure that billing is enabled for your project](https://cloud.google.com/billing/docs/how-to/modify-project).

! pip3 install --upgrade gradio==4.29.0 opencv-python
# Uninstall nest-asyncio and uvloop as a workaround to https://github.com/gradio-app/gradio/issues/8238#issuecomment-2101066984
! pip3 uninstall --yes nest-asyncio uvloop

import os
import sys

from google.cloud import aiplatform

# Get the default cloud project id.
PROJECT_ID = os.environ["GOOGLE_CLOUD_PROJECT"]

# Get the default region for launching jobs.
REGION = os.environ["GOOGLE_CLOUD_REGION"]

aiplatform.init(project=PROJECT_ID, location=REGION)

# Enable the Vertex AI API and Compute Engine API, if not already.
! gcloud services enable aiplatform.googleapis.com compute.googleapis.com

# Set up the default SERVICE_ACCOUNT.
SERVICE_ACCOUNT = None
shell_output = ! gcloud projects describe $PROJECT_ID
project_number = shell_output[-1].split(":")[1].strip().replace("'", "")
SERVICE_ACCOUNT = f"{project_number}-compute@developer.gserviceaccount.com"

print("Using this default Service Account:", SERVICE_ACCOUNT)

# The pre-built serving docker image. It contains serving scripts and models.
SERVE_DOCKER_URI = "us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/pytorch-diffusers-serve-opt:20240605_1400_RC00"

if "google.colab" in sys.modules:
    from google.colab import auth

    auth.authenticate_user(project_id=PROJECT_ID)

In [None]:
# @title Start the playground

# @markdown This is a playground for identity-preserving image generation models like [InstantX/InstantID](https://huggingface.co/InstantX/InstantID).
# @markdown After the cell runs, this playground is available in a separate browser tab when you click the public URL.
# @markdown Look for something similar to ["https://####.gradio.live"](#) in the output of the cell.

# @markdown **How to use:**
# @markdown 1. Important: Notebook cell reruns create new public URLs. Previous URLs will stop working.
# @markdown 1. Before you start, select a Vertex prediction endpoint from the endpoint dropdown list.
# @markdown The list shows endpoints in this project and region that have a matching model deployed.
# @markdown 1. If no models were deployed in the past, you can create a new Vertex prediction
# @markdown endpoint by selecting your favorite model and clicking "Deploy".
# @markdown 1. New model deployment takes ~20 minutes. You can check the progress at [Vertex Online Prediction](https://console.cloud.google.com/vertex-ai/online-prediction/endpoints).
# @markdown 1. Adjust the prompt, negative prompt, image dimension, inference steps, and guidance scale to balance image quality and inference latency.
# @markdown 1. Don't forget to undeploy the models after all the experiments to avoid continuous charges to the project.

# @markdown Note: this workshop/notebook is specially built for the [InstantX/InstantID](https://huggingface.co/InstantX/InstantID) model.
# @markdown Other models may work, but they are not tested; please use them with caution.

import base64
from datetime import datetime
from io import BytesIO

import gradio as gr
from google.cloud import aiplatform
from PIL import Image

# Prompt style presets offered in the UI's "Image Style" radio group.
# Each preset has:
#   - "name": the label shown in the UI (also the lookup key in `styles`),
#   - "prompt": a template in which the literal "{prompt}" is replaced by the
#     user's prompt (see apply_style),
#   - "negative_prompt": style-specific terms combined with the user's
#     negative prompt (see apply_style).
style_list = [
    {
        "name": "(No style)",
        "prompt": "{prompt}",
        "negative_prompt": "",
    },
    {
        "name": "Cinematic",
        "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
        "negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
    },
    {
        "name": "Photographic",
        "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
        "negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
    },
    {
        "name": "Anime",
        "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime,  highly detailed",
        "negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
    },
    {
        "name": "Manga",
        "prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
        "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
    },
    {
        "name": "Digital Art",
        "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
        "negative_prompt": "photo, photorealistic, realism, ugly",
    },
    {
        "name": "Pixel art",
        "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
        "negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
    },
    {
        "name": "Fantasy art",
        "prompt": "ethereal fantasy concept art of  {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
        "negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white",
    },
    {
        "name": "Neonpunk",
        "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
        "negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
    },
    {
        "name": "3D Model",
        "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
        "negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
    },
]

# Lookup table: style name -> (prompt template, style negative prompt).
styles = dict(
    (preset["name"], (preset["prompt"], preset["negative_prompt"]))
    for preset in style_list
)
# Choices for the style radio group, in the same order as style_list.
STYLE_NAMES = [*styles]
DEFAULT_STYLE_NAME = "(No style)"


def create_job_name(prefix):
    """Builds a timestamped name of the form "<prefix>-gradio-YYMMDD-HHMMSS".

    The timestamp comes from `datetime.now()`, i.e. the local clock.
    """
    timestamp = datetime.now().strftime("%y%m%d-%H%M%S")
    return "{}-gradio-{}".format(prefix, timestamp)


def base64_to_image(image_str: str) -> Image.Image:
    """Decodes a base64-encoded image file into a PIL image.

    Args:
        image_str: Base64 text of an encoded image file (e.g. one entry of an
            endpoint's `predictions`).

    Returns:
        The opened PIL image. The annotation is `Image.Image` (the class), not
        the `Image` module.
    """
    image = Image.open(BytesIO(base64.b64decode(image_str)))
    return image


def image_to_base64(image: Image.Image, format="JPEG") -> str:
    """Serializes a PIL image in the given format and returns it as base64 text.

    Args:
        image: The image to encode (e.g. an upload from a `gr.Image(type="pil")`).
        format: PIL format name used for `Image.save`; "JPEG" by default.

    Returns:
        The base64 encoding of the saved file bytes, as an ASCII/UTF-8 string.
    """
    # JPEG cannot store alpha or palette images (e.g. a transparent PNG upload),
    # and PIL raises when asked to; flatten such images to RGB first.
    if (format or "").upper() in ("JPEG", "JPG") and image.mode not in ("RGB", "L"):
        image = image.convert("RGB")
    buffer = BytesIO()
    image.save(buffer, format=format)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def is_instantid_endpoint(endpoint: aiplatform.Endpoint) -> bool:
    """Tells whether an endpoint's display name marks it as an Instant ID endpoint.

    The match is a case-insensitive substring test. Endpoints created by
    `deploy_model` are named via `create_job_name(model_id)`, e.g.
    "instantx/instantid-gradio-...", so they are recognized.
    """
    lowered_name = endpoint.display_name.lower()
    # "instant" is contained in both "instant_id" and "instant-id", so this
    # single test accepts exactly the names the three separate checks accepted.
    return "instant" in lowered_name


def list_endpoints() -> list[str]:
    """Lists the Instant ID endpoints in the project and region that serve a model.

    Endpoints are returned newest first, each formatted for the dropdown as
    "<endpoint id> - <first 40 characters of the display name>".
    """
    newest_first = aiplatform.Endpoint.list(order_by="create_time desc")
    # Skip endpoints without a deployed model (empty traffic split) and
    # endpoints that are not for Instant ID.
    return [
        f"{endpoint.name} - {endpoint.display_name[:40]}"
        for endpoint in newest_first
        if endpoint.traffic_split and is_instantid_endpoint(endpoint)
    ]


def get_endpoint(endpoint_name: str) -> aiplatform.Endpoint:
    """Turns a dropdown entry back into a Vertex endpoint object.

    Args:
        endpoint_name: An entry produced by `list_endpoints`, i.e.
            "<endpoint id> - <display name>"; only the id part is used.
    """
    endpoint_id, _, _ = endpoint_name.partition(" - ")
    resource_name = f"projects/{PROJECT_ID}/locations/{REGION}/endpoints/{endpoint_id}"
    return aiplatform.Endpoint(resource_name)


def get_task_name(model_name: str) -> str:
    """Returns the serving task name for the given model id.

    Args:
        model_name: Model id such as "InstantX/InstantID"; matched
            case-insensitively.

    Returns:
        The task name (used by `deploy_model` as the container's TASK variable).

    Raises:
        gr.Error: If the model id has no known task.
    """
    model_to_task_dict = {
        "instantx/instantid": "instant-id",
    }

    task_name = model_to_task_dict.get(model_name.lower())
    if task_name is None:
        # Report the offending id in the UI error rather than a stray print().
        raise gr.Error(
            f"Unsupported model '{model_name}'. "
            "Please select a valid model name for Endpoint creation."
        )
    return task_name


def deploy_model(model_name: str) -> aiplatform.Endpoint:
    """Creates a new Vertex prediction endpoint and starts deploying a model to it.

    The deployment is started with `sync=False`, so this function returns once
    the endpoint and model resources exist; the deployment itself keeps running
    in the background.

    Args:
        model_name: Dropdown entry of the form "<task>: <model id>",
            e.g. "instant-id: instantx/instantid".

    Returns:
        The newly created endpoint (its deployment may still be in progress).

    Raises:
        gr.Error: If no model is selected, the entry is not of the form
            "<task>: <model id>", or the model id is unsupported.
    """
    if not model_name:
        raise gr.Error("Please select a valid model name for model list.")

    # Only the model id after "<task>: " is used; reject malformed entries
    # instead of failing with an IndexError.
    _, separator, model_id = model_name.partition(": ")
    if not separator or not model_id:
        raise gr.Error("Please select a valid model name for model list.")
    # Validate the model before creating any endpoint resources.
    task_name = get_task_name(model_id)

    gr.Info("Model deployment started. Please wait...")

    display_name = create_job_name(model_id)
    endpoint = aiplatform.Endpoint.create(display_name=display_name)
    serving_env = {
        "MODEL_ID": model_id,
        "TASK": task_name,
        "DEPLOY_SOURCE": "notebook_gradio",
    }

    model = aiplatform.Model.upload(
        display_name=model_id,
        serving_container_image_uri=SERVE_DOCKER_URI,
        serving_container_ports=[7080],
        serving_container_predict_route="/predictions/diffusers_serving",
        serving_container_health_route="/ping",
        serving_container_environment_variables=serving_env,
    )
    machine_type = "g2-standard-8"
    accelerator_type = "NVIDIA_L4"

    model.deploy(
        endpoint=endpoint,
        machine_type=machine_type,
        accelerator_type=accelerator_type,
        accelerator_count=1,
        deploy_request_timeout=1800,
        service_account=SERVICE_ACCOUNT,
        sync=False,
    )

    gr.Info(
        f"Model {display_name} is being deployed. It may take ~20 minutes to complete."
    )

    return endpoint


def get_default_dimension(model_name: str) -> int:
    """Gives the default image height/width for `model_name`.

    Every supported model shares the same default, so `model_name` is
    currently not consulted.
    """
    default_side_length = 1024
    return default_side_length


def get_default_guidance_scale(model_name: str) -> float:
    """Returns the default guidance scale for the given model_name.

    The default is fractional (1.2), so the return type is `float`.
    Every supported model shares this default, so `model_name` is currently
    not consulted.
    """
    return 1.2


def get_default_num_inference_steps(model_name: str) -> int:
    """Gives the default number of sampling steps for `model_name`.

    The same value is used for every supported model, so `model_name` is
    currently ignored.
    """
    default_steps = 5
    return default_steps


def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
    """Applies a named style preset to the user's prompts.

    Args:
        style_name: Key into `styles`; unknown names (including None) fall
            back to the DEFAULT_STYLE_NAME preset.
        positive: User prompt, substituted for "{prompt}" in the style template.
        negative: User negative prompt, appended after the style's negative terms.

    Returns:
        A `(prompt, negative_prompt)` tuple.
    """
    template, style_negative = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
    # Join with ", " so the last style term and the first user term are not
    # fused into one word ("disfigured" + "blurry" -> "disfiguredblurry");
    # empty parts are dropped so no stray separator is produced.
    negative_parts = (part.strip() for part in (style_negative, negative or ""))
    combined_negative = ", ".join(part for part in negative_parts if part)
    return template.replace("{prompt}", positive or ""), combined_negative


def generate_images(
    endpoint_name,
    style_name=None,
    prompt="",
    negative_prompt="",
    guidance_scale=1.2,
    num_inference_steps=5,
    image_dimension=1024,
    face_image=None,
    pose_image=None,
) -> list[Image.Image]:
    """Sends one Instant ID request to the selected endpoint and decodes the result.

    Args:
        endpoint_name: Dropdown entry "<endpoint id> - <display name>".
        style_name: Style preset applied to the prompts (see apply_style).
        prompt: User prompt.
        negative_prompt: User negative prompt.
        guidance_scale: Guidance scale forwarded to the endpoint.
        num_inference_steps: Number of sampling steps forwarded to the endpoint.
        image_dimension: Used as both the requested height and width.
        face_image: Required PIL image of the face whose identity is kept.
        pose_image: Optional PIL reference-pose image.

    Returns:
        The images decoded from the endpoint's predictions.

    Raises:
        gr.Error: If no endpoint is selected or no face image was uploaded.
    """
    if not endpoint_name:
        raise gr.Error("Please select (or deploy) a model first!")
    # Without this check a missing upload fails deep inside image encoding
    # with an opaque AttributeError on None.
    if face_image is None:
        raise gr.Error("Please upload a photo of your face first!")

    prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
    payload = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        # int() guards against slider values arriving as floats (e.g. 1024.0).
        "height": int(image_dimension),
        "width": int(image_dimension),
        "guidance_scale": guidance_scale,
        "num_inference_steps": int(num_inference_steps),
        "face_image": image_to_base64(face_image),
    }
    if pose_image is not None:
        payload["pose_image"] = image_to_base64(pose_image)

    response = get_endpoint(endpoint_name).predict(instances=[payload])
    return [base64_to_image(image) for image in response.predictions]


def update_default_parameters(model_name: str):
    """Resets the inference sliders to the defaults for the selection.

    Args:
        model_name: The value of the endpoint dropdown that triggered the
            change; the current defaults do not depend on it.

    Returns:
        A mapping from each slider component (the module-level
        `guidance_scale`, `num_inference_steps` and `image_dimension`
        created in the UI) to a `gr.update` carrying its default value.
    """
    slider_defaults = [
        (guidance_scale, get_default_guidance_scale(model_name)),
        (num_inference_steps, get_default_num_inference_steps(model_name)),
        (image_dimension, get_default_dimension(model_name)),
    ]
    return {slider: gr.update(value=default) for slider, default in slider_defaults}


# Help text rendered inside the collapsible "How To Use" accordion below.
tip_text = r"""
1. Select a Vertex prediction endpoint with a model deployed for your chosen task. Mismatched models can lead to unreliable outcomes.
2. New model deployment takes ~20 minutes. You can check the progress at [Vertex Online Prediction](https://console.cloud.google.com/vertex-ai/online-prediction/endpoints).
3. After the model deployment is complete, restart the playground in Colab to see the updated endpoint list.
"""

# Custom CSS passed to gr.Blocks: widens the app container to 90% of the page.
css = """
.gradio-container {
  width: 90% !important
}
"""
# Build the playground UI. Components appear in the order they are created
# inside each Row/Column/Accordion context, so statement order defines layout.
with gr.Blocks(
    css=css, theme=gr.themes.Default(primary_hue="orange", secondary_hue="blue")
) as demo:
    gr.Markdown("# Model Garden Playground for InstantID")

    with gr.Accordion("How To Use", open=False):
        tip = gr.Markdown(tip_text)

    # Top row: prompt inputs (left) and endpoint selection/deployment (right).
    with gr.Row(equal_height=True):
        with gr.Column(scale=3):
            prompt = gr.Textbox(label="Prompt", lines=1)
            negative_prompt = gr.Textbox(label="Negative Prompt", lines=1)
        with gr.Column(scale=1):
            # Choices are fetched once, when this cell runs; endpoints deployed
            # afterwards only appear after the cell is re-run.
            endpoint_name = gr.Dropdown(
                label="Select a model previously deployed on Vertex",
                choices=list_endpoints(),
                value=None,
            )
            with gr.Row():
                # Entries use the "<task>: <model id>" form parsed by deploy_model.
                selected_model = gr.Dropdown(
                    scale=7,
                    label="Deploy a new model to Vertex",
                    choices=[
                        "instant-id: instantx/instantid",
                    ],
                    value=None,
                )
                deploy_model_button = gr.Button(
                    "Deploy", scale=1, variant="primary", min_width=10
                )

    # Bottom row: generation controls (left), input images and gallery (right).
    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            generate_button = gr.Button("Generate", variant="primary")

            # Initial slider values match the get_default_* helpers.
            image_dimension = gr.Slider(
                label="Image dimension", value=1024, step=128, minimum=512, maximum=1024
            )
            num_inference_steps = gr.Slider(
                label="Sampling steps", value=5, step=1, minimum=1, maximum=25
            )
            guidance_scale = gr.Slider(
                label="Guidance scale", value=1.2, step=0.1, minimum=0, maximum=10.0
            )
            with gr.Accordion("Styles", open=False):
                # Radio choices are the preset names from style_list.
                style_selection = gr.Radio(
                    show_label=True,
                    container=True,
                    interactive=True,
                    choices=STYLE_NAMES,
                    value=DEFAULT_STYLE_NAME,
                    label="Image Style",
                )

        with gr.Column(scale=4):
            with gr.Row(equal_height=True):
                with gr.Column(scale=2):
                    # type="pil" hands PIL images to generate_images.
                    face_image_input = gr.Image(
                        type="pil",
                        label="Upload a photo of your face",
                        sources="upload",
                        height=350,
                        interactive=True,
                    )
                    pose_image_input = gr.Image(
                        type="pil",
                        label="Upload a reference pose image (Optional)",
                        sources="upload",
                        height=350,
                        interactive=True,
                    )
                with gr.Column(scale=3):
                    image_output = gr.Gallery(
                        label="Generated Images",
                        rows=1,
                        height=715,
                        preview=True,
                    )

    # Choosing an endpoint resets the sliders to their default values.
    endpoint_name.change(
        update_default_parameters,
        endpoint_name,
        [
            guidance_scale,
            num_inference_steps,
            image_dimension,
        ],
    )

    # deploy_model starts an asynchronous deployment; with outputs=[] nothing
    # is written back, so the endpoint dropdown is not refreshed automatically.
    deploy_model_button.click(
        deploy_model,
        inputs=[selected_model],
        outputs=[],
    )

    # Inputs are passed positionally, so this order must match the parameter
    # order of generate_images.
    generate_button.click(
        generate_images,
        inputs=[
            endpoint_name,
            style_selection,
            prompt,
            negative_prompt,
            guidance_scale,
            num_inference_steps,
            image_dimension,
            face_image_input,
            pose_image_input,
        ],
        outputs=image_output,
    )

# Colab form toggle, forwarded to demo.launch(debug=...).
show_debug_logs = True  # @param {type: "boolean"}
# Enable Gradio's request queue before launching.
demo.queue()
# share=True publishes a public *.gradio.live URL: anyone holding the URL can
# use this playground, and therefore send prediction requests to your endpoints.
demo.launch(share=True, inline=False, debug=show_debug_logs, show_error=True)