In [None]:
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Vertex AI Model Garden - Stable Diffusion Gradio UI

<table><tbody><tr>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fvertex-ai-samples%2Fmain%2Fnotebooks%2Fcommunity%2Fmodel_garden%2Fmodel_garden_pytorch_stable_diffusion_gradio.ipynb">
      <img alt="Google Cloud Colab Enterprise logo" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" width="32px"><br> Run in Colab Enterprise
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_pytorch_stable_diffusion_gradio.ipynb">
      <img alt="GitHub logo" src="https://cloud.google.com/ml-engine/images/github-logo-32px.png" width="32px"><br> View on GitHub
    </a>
  </td>
</tr></tbody></table>

## Overview

This notebook demonstrates starting a playground based on [Gradio UI](https://www.gradio.app/), which allows users to interact with the stable diffusion models more easily and intuitively. The playground supports `text-to-image`, `image-inpainting`, `instruct-pix2pix`, and `SD 4x upscaler` tasks.

### Objective

- Deploy a model to a [Vertex AI Endpoint resource](https://cloud.google.com/vertex-ai/docs/predictions/using-private-endpoints).
- Run online predictions for `text-to-image`, `image-inpainting`, `instruct-pix2pix`, and `SD 4x upscaler` tasks, from the UI.
- Adjust the parameters, such as prompt, negative_prompt, num_inference_steps, and check out the generated images.

### Costs

This tutorial uses billable components of Google Cloud:

* Vertex AI
* Cloud Storage

Learn about [Vertex AI pricing](https://cloud.google.com/vertex-ai/pricing) and [Cloud Storage pricing](https://cloud.google.com/storage/pricing), and use the [Pricing Calculator](https://cloud.google.com/products/calculator/) to generate a cost estimate based on your projected usage.

In [None]:
# @title Setup Google Cloud project

# @markdown 1. [Make sure that billing is enabled for your project](https://cloud.google.com/billing/docs/how-to/modify-project).

! pip3 install --upgrade gradio==3.48.0

import os
import sys

from google.cloud import aiplatform

# Get the default cloud project id.
PROJECT_ID = os.environ["GOOGLE_CLOUD_PROJECT"]

# Get the default region for launching jobs.
REGION = os.environ["GOOGLE_CLOUD_REGION"]

aiplatform.init(project=PROJECT_ID, location=REGION)

# Enable the Vertex AI API and Compute Engine API, if not already.
! gcloud services enable aiplatform.googleapis.com compute.googleapis.com

# Set up the default SERVICE_ACCOUNT.
SERVICE_ACCOUNT = None
shell_output = ! gcloud projects describe $PROJECT_ID
project_number = shell_output[-1].split(":")[1].strip().replace("'", "")
SERVICE_ACCOUNT = f"{project_number}-compute@developer.gserviceaccount.com"

print("Using this default Service Account:", SERVICE_ACCOUNT)

# The pre-built serving docker image. It contains serving scripts and models.
SERVE_DOCKER_URI = "us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/pytorch-diffusers-serve-opt:20240223_1230_RC00"

if "google.colab" in sys.modules:
    from google.colab import auth

    auth.authenticate_user(project_id=PROJECT_ID)

In [None]:
# @title Start the SD Playground on Gradio UI

# @markdown This is a simple playground similar to the popular [stable diffusion webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui).
# @markdown This UI is available in a separate browser tab if you click the public URL after the cell runs.
# @markdown The public URL is something similar to "https://####.gradio.live". Click the URL to open the playground.

# @markdown Before you start, select from the dropdown list an existing Vertex prediction endpoint
# @markdown that has already been deployed in this project and region. If no models have been deployed yet, you can
# @markdown create a new Vertex prediction endpoint by selecting your favorite model and clicking "Deploy".

# @markdown Four tasks `text-to-image`, `image-inpainting`, `instruct-pix2pix` and `SD 4x upscaler` are currently supported.

import base64
from io import BytesIO

import gradio as gr
from google.cloud import aiplatform
from PIL import Image

# The pre-built serving docker image. It contains serving scripts and models.
# NOTE: this re-assigns SERVE_DOCKER_URI from the setup cell; because it runs
# later, this tag is the one `deploy_model` (Model.upload) actually uses.
SERVE_DOCKER_URI = "us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/pytorch-diffusers-serve-opt:20240306_1230_RC00"


def base64_to_image(image_str: str) -> "Image.Image":
    """Decodes a base64-encoded image file into a PIL image.

    Args:
        image_str: Base64 encoding of an image file, e.g. one entry of the
            prediction endpoint's `response.predictions`.

    Returns:
        The decoded PIL image.
    """
    image = Image.open(BytesIO(base64.b64decode(image_str)))
    return image


def image_to_base64(image: "Image.Image", format: str = "JPEG") -> str:
    """Serializes a PIL image and returns it as a base64 string.

    Args:
        image: The image to encode.
        format: Pillow format name used to serialize the image (default "JPEG").

    Returns:
        The base64 encoding of the serialized image, decoded as UTF-8 text.
    """
    # JPEG cannot store alpha or palette data, so saving e.g. an RGBA PNG upload
    # as JPEG raises. Flatten such images to RGB first; other formats are
    # written unchanged.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    if (format or "").upper() in ("JPEG", "JPG") and image.mode not in jpeg_modes:
        image = image.convert("RGB")

    buffer = BytesIO()
    image.save(buffer, format=format)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def list_endpoints() -> list[str]:
    """Returns the prediction endpoints in the project and region that serve a model.

    Each entry has the form "<endpoint id> - <first 40 chars of display name>";
    `get_endpoint` splits on " - " to recover the endpoint id.

    Returns:
        The formatted endpoint names, newest first. If no endpoint has a
        deployed model, a warning is shown and an empty list is returned, so the
        UI can still be built and a first model deployed from it.
    """
    # Gets all the endpoints in the project and region, newest first.
    endpoints = aiplatform.Endpoint.list(order_by="create_time desc")

    # Keeps only endpoints with a deployed model (non-empty traffic split).
    endpoint_names = [
        f"{endpoint.name} - {endpoint.display_name[:40]}"
        for endpoint in endpoints
        if endpoint.traffic_split
    ]

    if not endpoint_names:
        # `gr.Warning` is a notification function, not an exception class;
        # `raise gr.Warning(...)` would fail with a TypeError and abort the UI
        # build. Just warn.
        gr.Warning(
            "No prediction endpoints were found. Please create an Endpoint first."
        )

    return endpoint_names


def get_endpoint(endpoint_name: str) -> aiplatform.Endpoint:
    """Builds the Vertex endpoint handle referenced by a dropdown entry.

    Args:
        endpoint_name: A dropdown entry of the form "<endpoint id> - <display name>".

    Returns:
        The `aiplatform.Endpoint` with that id in PROJECT_ID / REGION.
    """
    # Only the leading id matters; the display-name suffix is for humans.
    endpoint_id = endpoint_name.split(" - ", 1)[0]
    resource_name = (
        f"projects/{PROJECT_ID}/locations/{REGION}/endpoints/{endpoint_id}"
    )
    return aiplatform.Endpoint(resource_name)


def get_task_name(model_name: str) -> str:
    """Maps a deployable model name to the TASK value for the serving container.

    Raises:
        gr.Error: If `model_name` is not one of the supported models.
    """
    supported_tasks = {
        "runwayml/stable-diffusion-v1-5": "text-to-image",
        "stabilityai/stable-diffusion-2-1": "text-to-image",
        "stabilityai/stable-diffusion-xl-base-1.0": "text-to-image-sdxl",
        "stabilityai/stable-diffusion-xl-base-1.0 - refiner": "text-to-image-refiner",
        "latent-consistency/lcm-sdxl": "text-to-image-sdxl-lcm",
        "latent-consistency/lcm-lora-sdxl": "text-to-image-sdxl-lcm-lora",
        "stabilityai/sdxl-turbo": "text-to-image-sdxl-turbo",
        "runwayml/stable-diffusion-inpainting": "image-inpainting",
        "kandinsky-community/kandinsky-2-2-decoder-inpaint": "image-inpainting",
        "diffusers/stable-diffusion-xl-1.0-inpainting-0.1": "image-inpainting",
        "timbrooks/instruct-pix2pix": "instruct-pix2pix",
        "stabilityai/stable-diffusion-x4-upscaler": "conditioned-super-res",
    }

    task = supported_tasks.get(model_name)
    if task is None:
        raise gr.Error("Please select a valid model name for Endpoint creation.")
    return task


def deploy_model(
    model_name: str,
    machine_type: str = "g2-standard-8",
    accelerator_type: str = "NVIDIA_L4",
    accelerator_count: int = 1,
) -> aiplatform.Endpoint:
    """Creates a new Vertex prediction endpoint and deploys `model_name` to it.

    Args:
        model_name: One of the choices of the "Deploy a new model to Vertex"
            dropdown (the names accepted by `get_task_name`).
        machine_type: Machine type of the deployment.
        accelerator_type: Accelerator type attached to the deployment.
        accelerator_count: Number of accelerators attached to the deployment.

    Returns:
        The endpoint the model was deployed to.

    Raises:
        gr.Error: If no model is selected or the model is not supported.
    """
    if not model_name:
        raise gr.Error("Please select a valid model name for model list.")

    # Resolve the task first, so an unsupported model is reported before the
    # user is told that a deployment has started.
    task_name = get_task_name(model_name)

    gr.Info("Model is being deployed. It may take ~20 minutes to complete.")

    # These variants are served from the SDXL base checkpoint; they differ only
    # in TASK (and REFINER_MODEL_ID for the refiner variant). Presumably the
    # serving container loads the extra weights based on these — confirm
    # against the serving image.
    sdxl_base_variants = {
        "stabilityai/stable-diffusion-xl-base-1.0 - refiner",
        "latent-consistency/lcm-sdxl",
        "latent-consistency/lcm-lora-sdxl",
    }
    model_id = (
        "stabilityai/stable-diffusion-xl-base-1.0"
        if model_name in sdxl_base_variants
        else model_name
    )

    serving_env = {
        "MODEL_ID": model_id,
        "TASK": task_name,
    }
    if model_name == "stabilityai/stable-diffusion-xl-base-1.0 - refiner":
        serving_env["REFINER_MODEL_ID"] = "stabilityai/stable-diffusion-xl-refiner-1.0"

    endpoint = aiplatform.Endpoint.create(display_name=model_name)
    model = aiplatform.Model.upload(
        display_name=model_name,
        serving_container_image_uri=SERVE_DOCKER_URI,
        serving_container_ports=[7080],
        serving_container_predict_route="/predictions/diffusers_serving",
        serving_container_health_route="/ping",
        serving_container_environment_variables=serving_env,
    )

    model.deploy(
        endpoint=endpoint,
        machine_type=machine_type,
        accelerator_type=accelerator_type,
        accelerator_count=accelerator_count,
        deploy_request_timeout=1800,
        service_account=SERVICE_ACCOUNT,
    )

    gr.Info("Model has been deployed successfully.")

    return endpoint


def get_default_dimension(model_name: str) -> int:
    """Returns the default output height/width for the given model_name.

    Matching is by substring, so endpoint dropdown entries work as well as bare
    model ids. SDXL-family names give 1024, Stable Diffusion 2.1 gives 768, and
    anything else (including an empty or None name) gives 512.
    """
    if model_name:
        if "stable-diffusion-xl" in model_name or "sdxl" in model_name:
            return 1024
        if "stable-diffusion-2-1" in model_name:
            return 768
    return 512


def get_default_guidance_scale(model_name: str) -> float:
    """Returns the default guidance scale for the given model_name.

    Matching is by substring, so endpoint dropdown entries work as well as bare
    model ids. LCM and SDXL-Turbo names give 0.0. Every other name, including
    an empty or None name, gives 7.5.
    """
    if model_name and ("lcm" in model_name or "sdxl-turbo" in model_name):
        return 0.0
    return 7.5


def get_default_num_inference_steps(model_name: str) -> int:
    """Returns the default num_inference_steps for the given model_name.

    Matching is by substring. LCM names give 8 steps; otherwise SDXL-Turbo
    names give 2. Everything else, including an empty or None name, gives 25.
    """
    if not model_name:
        return 25
    if "lcm" in model_name:
        return 8
    if "sdxl-turbo" in model_name:
        return 2
    return 25


def generate_images(
    endpoint_name,
    prompt,
    negative_prompt="",
    num_samples=1,
    guidance_scale=7.5,
    num_inference_steps=25,
    height=512,
    width=512,
) -> list[Image.Image]:
    """Runs text-to-image prediction on the selected endpoint.

    Args:
        endpoint_name: Dropdown entry "<endpoint id> - <display name>".
        prompt: Text prompt.
        negative_prompt: Negative text prompt.
        num_samples: Number of images to request; one identical instance is
            sent per sample (values below 1 still send one instance).
        guidance_scale: Forwarded unchanged to the serving container.
        num_inference_steps: Forwarded unchanged to the serving container.
        height: Forwarded unchanged to the serving container.
        width: Forwarded unchanged to the serving container.

    Returns:
        The decoded images, one per prediction in the response.

    Raises:
        gr.Error: If no endpoint is selected.
    """
    if not endpoint_name:
        raise gr.Error("Please select (or deploy) a model first!")

    instance = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "height": height,
        "width": width,
        "guidance_scale": guidance_scale,
        "num_inference_steps": num_inference_steps,
    }
    # List repetition needs an int; cast defensively in case the slider
    # delivers a float such as 2.0.
    instances = [instance] * max(1, int(num_samples))

    response = get_endpoint(endpoint_name).predict(instances=instances)
    return [base64_to_image(image) for image in response.predictions]


def inpaint_generate_images(
    endpoint_name: str,
    prompt="",
    negative_prompt="",
    num_samples=1,
    guidance_scale=7.5,
    num_inference_steps=25,
    dict=None,
    height=512,
    width=512,
) -> list[Image.Image]:
    """Runs image inpainting on the selected endpoint.

    Args:
        endpoint_name: Dropdown entry "<endpoint id> - <display name>".
        prompt: Text prompt for the masked area.
        negative_prompt: Negative text prompt.
        num_samples: Number of images to request; one identical instance is
            sent per sample (values below 1 still send one instance).
        guidance_scale: Forwarded unchanged to the serving container.
        num_inference_steps: Forwarded unchanged to the serving container.
        dict: Value of the sketch-tool image input: the uploaded picture under
            "image" and the drawn mask under "mask". The name shadows the
            builtin but is kept for backward compatibility.
        height: Forwarded unchanged to the serving container.
        width: Forwarded unchanged to the serving container.

    Returns:
        The decoded images, one per prediction in the response.

    Raises:
        gr.Error: If no endpoint is selected, or no image or mask is provided.
    """
    if not endpoint_name:
        raise gr.Error("Please select (or deploy) a model first!")
    if not dict or dict.get("image") is None:
        raise gr.Error("Please upload an image first!")
    if dict.get("mask") is None:
        raise gr.Error("Please draw a mask over the area to inpaint!")

    # Inputs are resized to a 1024x1024 square for the SDXL inpainting model
    # (`stable-diffusion-xl-1.0-inpainting-0.1`), and to 512x512 otherwise.
    default_dimension = 1024 if "stable-diffusion-xl" in endpoint_name else 512
    size = (default_dimension, default_dimension)

    init_image = dict["image"].convert("RGB").resize(size)
    mask = dict["mask"].convert("RGB").resize(size)

    instance = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "image": image_to_base64(init_image),
        "mask_image": image_to_base64(mask),
        "height": height,
        "width": width,
        "guidance_scale": guidance_scale,
        "num_inference_steps": num_inference_steps,
    }
    # List repetition needs an int; cast defensively in case the slider
    # delivers a float.
    instances = [instance] * max(1, int(num_samples))

    response = get_endpoint(endpoint_name).predict(instances=instances)
    return [base64_to_image(image) for image in response.predictions]


def instruct_pix2pix_generate_images(
    endpoint_name: str,
    prompt="",
    negative_prompt="",
    num_samples=1,
    guidance_scale=7.5,
    num_inference_steps=25,
    init_image=None,
    height=512,
    width=512,
) -> list[Image.Image]:
    """Runs instruct-pix2pix image editing on the selected endpoint.

    Args:
        endpoint_name: Dropdown entry "<endpoint id> - <display name>".
        prompt: Editing instruction.
        negative_prompt: Negative text prompt.
        num_samples: Number of images to request; one identical instance is
            sent per sample (values below 1 still send one instance).
        guidance_scale: Forwarded unchanged to the serving container.
        num_inference_steps: Forwarded unchanged to the serving container.
        init_image: The uploaded image to edit.
        height: Forwarded unchanged to the serving container.
        width: Forwarded unchanged to the serving container.

    Returns:
        The decoded images, one per prediction in the response.

    Raises:
        gr.Error: If no endpoint is selected or no image is uploaded.
    """
    if not endpoint_name:
        raise gr.Error("Please select (or deploy) a model first!")
    if init_image is None:
        raise gr.Error("Please upload an image first!")

    # The image is sent as JPEG, which cannot store alpha or palette data, so
    # normalize uploads such as RGBA PNGs to RGB.
    init_image = init_image.convert("RGB")

    instance = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "image": image_to_base64(init_image),
        "height": height,
        "width": width,
        "guidance_scale": guidance_scale,
        "num_inference_steps": num_inference_steps,
    }
    # List repetition needs an int; cast defensively in case the slider
    # delivers a float.
    instances = [instance] * max(1, int(num_samples))

    response = get_endpoint(endpoint_name).predict(instances=instances)
    return [base64_to_image(image) for image in response.predictions]


def upscaler_generate_images(
    endpoint_name: str,
    prompt="",
    negative_prompt="",
    num_samples=1,
    guidance_scale=7.5,
    num_inference_steps=25,
    init_image=None,
    height=512,
    width=512,
) -> list[Image.Image]:
    """Runs the SD 4x upscaler on the selected endpoint.

    Args:
        endpoint_name: Dropdown entry "<endpoint id> - <display name>".
        prompt: Text prompt conditioning the upscale.
        negative_prompt: Negative text prompt.
        num_samples: Number of images to request; one identical instance is
            sent per sample (values below 1 still send one instance).
        guidance_scale: Forwarded unchanged to the serving container.
        num_inference_steps: Forwarded unchanged to the serving container.
        init_image: The uploaded image. It is resized to 256x256 before sending.
        height: Forwarded unchanged to the serving container.
        width: Forwarded unchanged to the serving container.

    Returns:
        The decoded images, one per prediction in the response.

    Raises:
        gr.Error: If no endpoint is selected or no image is uploaded.
    """
    if not endpoint_name:
        raise gr.Error("Please select (or deploy) a model first!")
    if init_image is None:
        raise gr.Error("Please upload an image first!")

    # The upscaler input is always sent as a 256x256 RGB image.
    default_dimension = 256
    init_image = init_image.convert("RGB").resize(
        (default_dimension, default_dimension)
    )

    instance = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "image": image_to_base64(init_image),
        "height": height,
        "width": width,
        "guidance_scale": guidance_scale,
        "num_inference_steps": num_inference_steps,
    }
    # List repetition needs an int; cast defensively in case the slider
    # delivers a float.
    instances = [instance] * max(1, int(num_samples))

    response = get_endpoint(endpoint_name).predict(instances=instances)
    return [base64_to_image(image) for image in response.predictions]


def select_interface(interface_name: str):
    """Returns UI updates that switch the playground to `interface_name`.

    For every known task:
    - The endpoint dropdown and both prompts are shown and cleared.
    - The image input is hidden for text-to-image; for the other tasks it is
      shown, cleared, and given that task's tool.
    - Only the selected task's Generate button stays visible.

    Returns None for an unrecognized interface name.
    """
    # Task name -> the Generate button that belongs to it.
    buttons_by_task = {
        "Text2Image pipeline": generate_button,
        "Inpaint Pipeline": inpaint_generate_button,
        "Instruct pix2pix Pipeline": instruct_pix2pix_generate_button,
        "SD 4x Upscaler Pipeline": upscaler_generate_button,
    }
    # Tasks that take an input image -> the image tool they use.
    image_tool_by_task = {
        "Inpaint Pipeline": "sketch",
        "Instruct pix2pix Pipeline": "None",
        "SD 4x Upscaler Pipeline": "None",
    }

    if interface_name not in buttons_by_task:
        return None

    if interface_name in image_tool_by_task:
        image_update = gr.update(
            visible=True, value=None, tool=image_tool_by_task[interface_name]
        )
    else:
        image_update = gr.update(visible=False, value=None)

    updates = {
        endpoint_name: gr.update(visible=True, value=None),
        prompt: gr.update(visible=True, value=None),
        negative_prompt: gr.update(visible=True, value=None),
        image_input: image_update,
    }
    active_button = buttons_by_task[interface_name]
    for button in buttons_by_task.values():
        updates[button] = gr.update(visible=button is active_button)
    return updates


def update_default_parameters(model_name: str):
    """Resets the inference sliders to the defaults for the selected endpoint/model.

    Returns a Gradio update dict for the guidance scale, sampling steps, height
    and width sliders. Height and width share the same default dimension.
    """
    dimension = get_default_dimension(model_name)
    steps = get_default_num_inference_steps(model_name)
    scale = get_default_guidance_scale(model_name)
    return {
        guidance_scale: gr.update(value=scale),
        num_inference_steps: gr.update(value=steps),
        height: gr.update(value=dimension),
        width: gr.update(value=dimension),
    }


with gr.Blocks(
    theme=gr.themes.Default(primary_hue="orange", secondary_hue="blue")
) as demo:
    gr.Markdown("# Stable Diffusion Playground ")

    # Task selector; `select_interface` shows the inputs/buttons for the choice.
    with gr.Tab("Tasks"):
        interfaces_box = gr.Radio(
            show_label=False,
            choices=[
                "Text2Image pipeline",
                "Inpaint Pipeline",
                "Instruct pix2pix Pipeline",
                "SD 4x Upscaler Pipeline",
            ],
            value="Text2Image pipeline",
        )

    # Prompts on the left; endpoint selection and model deployment on the right.
    with gr.Row(equal_height=True):
        with gr.Column(scale=3):
            prompt = gr.Textbox(label="Prompt", lines=1)
            negative_prompt = gr.Textbox(label="Negative Prompt", lines=1)
        with gr.Column(scale=1):
            endpoint_name = gr.Dropdown(
                label="Select a model previously deployed on Vertex",
                choices=list_endpoints(),
                value=None,
            )
            with gr.Row():
                selected_model = gr.Dropdown(
                    scale=7,
                    label="Deploy a new model to Vertex",
                    choices=[
                        "runwayml/stable-diffusion-v1-5",
                        "stabilityai/stable-diffusion-2-1",
                        "stabilityai/stable-diffusion-xl-base-1.0",
                        "stabilityai/stable-diffusion-xl-base-1.0 - refiner",
                        "latent-consistency/lcm-sdxl",
                        "latent-consistency/lcm-lora-sdxl",
                        "stabilityai/sdxl-turbo",
                        "runwayml/stable-diffusion-inpainting",
                        "kandinsky-community/kandinsky-2-2-decoder-inpaint",
                        "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
                        "timbrooks/instruct-pix2pix",
                        "stabilityai/stable-diffusion-x4-upscaler",
                    ],
                    value=None,
                )
                deploy_model_button = gr.Button(
                    "Deploy", scale=1, variant="primary", min_width=10
                )

    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            # One Generate button per task; only the active task's is visible.
            generate_button = gr.Button("Generate", variant="primary")
            inpaint_generate_button = gr.Button(
                "Generate", variant="primary", visible=False
            )
            instruct_pix2pix_generate_button = gr.Button(
                "Generate", variant="primary", visible=False
            )
            upscaler_generate_button = gr.Button(
                "Generate", variant="primary", visible=False
            )

            # Initial slider values; `update_default_parameters` resets them
            # whenever an endpoint is selected.
            num_samples = gr.Slider(
                label="Number of samples", value=1, step=1, minimum=1, maximum=4
            )
            height = gr.Slider(
                label="Height", value=768, step=256, minimum=512, maximum=1024
            )
            width = gr.Slider(
                label="Width", value=768, step=256, minimum=512, maximum=1024
            )
            num_inference_steps = gr.Slider(
                label="Sampling steps", value=25, step=1, minimum=1, maximum=100
            )
            guidance_scale = gr.Slider(
                label="Guidance scale", value=7.5, step=0.5, minimum=0, maximum=20.0
            )

        with gr.Column(scale=3):
            with gr.Row(equal_height=True):
                image_input = gr.Image(
                    source="upload",
                    tool="sketch",
                    type="pil",
                    label="Upload",
                    visible=False,
                    height=400,
                )
                image_output = gr.Gallery(
                    show_label=False, rows=1, height=400, preview=True
                )

    interfaces_box.change(
        select_interface,
        interfaces_box,
        [
            endpoint_name,
            prompt,
            negative_prompt,
            image_input,
            generate_button,
            inpaint_generate_button,
            instruct_pix2pix_generate_button,
            upscaler_generate_button,
        ],
    )

    endpoint_name.change(
        update_default_parameters,
        endpoint_name,
        [
            guidance_scale,
            num_inference_steps,
            height,
            width,
        ],
    )

    deploy_model_button.click(
        deploy_model,
        inputs=[selected_model],
        outputs=[],
    ).then(
        # Reload the dropdown choices so a newly deployed endpoint can be
        # selected without re-running this cell.
        lambda: gr.update(choices=list_endpoints()),
        inputs=None,
        outputs=[endpoint_name],
    )

    generate_button.click(
        generate_images,
        inputs=[
            endpoint_name,
            prompt,
            negative_prompt,
            num_samples,
            guidance_scale,
            num_inference_steps,
            height,
            width,
        ],
        outputs=image_output,
    )

    inpaint_generate_button.click(
        inpaint_generate_images,
        inputs=[
            endpoint_name,
            prompt,
            negative_prompt,
            num_samples,
            guidance_scale,
            num_inference_steps,
            image_input,
            height,
            width,
        ],
        outputs=image_output,
    )

    instruct_pix2pix_generate_button.click(
        instruct_pix2pix_generate_images,
        inputs=[
            endpoint_name,
            prompt,
            negative_prompt,
            num_samples,
            guidance_scale,
            num_inference_steps,
            image_input,
            height,
            width,
        ],
        outputs=image_output,
    )

    upscaler_generate_button.click(
        upscaler_generate_images,
        inputs=[
            endpoint_name,
            prompt,
            negative_prompt,
            num_samples,
            guidance_scale,
            num_inference_steps,
            image_input,
            height,
            width,
        ],
        outputs=image_output,
    )

# Queue incoming requests: at most 10 may wait, and up to 5 are processed
# concurrently.
demo.queue(max_size=10, concurrency_count=5)

# `share=True` publishes the temporary public "*.gradio.live" URL.
demo.launch(
    share=True,
    inline=False,
    inbrowser=True,
    debug=True,
    show_error=True,
)