In [None]:
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Vertex AI Model Garden - Chat Completions With Streaming Playground

<table><tbody><tr>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fvertex-ai-samples%2Fmain%2Fnotebooks%2Fcommunity%2Fmodel_garden%2Fmodel_garden_gradio_streaming_chat_completions.ipynb">
      <img alt="Google Cloud Colab Enterprise logo" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" width="32px"><br> Run in Colab Enterprise
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_gradio_streaming_chat_completions.ipynb">
      <img alt="GitHub logo" src="https://cloud.google.com/ml-engine/images/github-logo-32px.png" width="32px"><br> View on GitHub
    </a>
  </td>
</tr></tbody></table>

## Overview

This notebook demonstrates how to start a playground built with [Gradio](https://www.gradio.app/) that lets you chat with instruction-tuned text generation models through a chatbot UI.

### Objective

- Chat with instruction-tuned text generation models deployed on the [Vertex Online Prediction](https://cloud.google.com/vertex-ai/docs/predictions/get-online-predictions) endpoints.
- (Optional) Deploy demo models to [Vertex Online Prediction](https://cloud.google.com/vertex-ai/docs/predictions/get-online-predictions) endpoints with one click.

### Costs

This tutorial uses billable components of Google Cloud:

* Vertex AI
* Cloud Storage

Learn about [Vertex AI pricing](https://cloud.google.com/vertex-ai/pricing), [Cloud Storage pricing](https://cloud.google.com/storage/pricing), and use the [Pricing Calculator](https://cloud.google.com/products/calculator/) to generate a cost estimate based on your projected usage.

## Run the notebook

In [None]:
# @title Setup Google Cloud project and install dependencies
import os

from google.cloud import aiplatform

# Install the Gradio version this notebook was written against.
# `~=4.40.0` is a compatible-release pin: any 4.40.x patch release is accepted.
! pip3 install --upgrade gradio~=4.40.0

# Get the default cloud project id.
# Read with [] rather than .get(): a missing variable raises KeyError here
# instead of failing later with a less obvious error.
PROJECT_ID = os.environ["GOOGLE_CLOUD_PROJECT"]

# Get the default region for endpoints.
REGION = os.environ["GOOGLE_CLOUD_REGION"]

# Make the project and region the defaults for all later aiplatform calls
# (endpoint listing, model upload, deployment).
aiplatform.init(project=PROJECT_ID, location=REGION)

In [None]:
# @title Start the playground

# @markdown This is a chatbot playground for instruction-tuned text generation models.
# @markdown After the cell runs, this playground is available in a separate browser tab if you click the public URL,
# @markdown i.e. ["https://####.gradio.live"](#) in the output of the cell.

# @markdown **How to use:**
# @markdown 1. **Important**: Notebook cell reruns create new public URLs. Previous URLs will stop working.
# @markdown 1. Before you start, you need to select a Vertex prediction endpoint with a matching model
# @markdown from the endpoint dropdown list in the same project and region where you run this notebook.
# @markdown 1. This playground only supports new deployments with
# @markdown text-generation-inference (`us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/pytorch-hf-tgi-serve`),
# @markdown vLLM (`us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/pytorch-vllm-serve`),
# @markdown or HexLLM (`us-docker.pkg.dev/vertex-ai-restricted/vertex-vision-model-garden-dockers/hex-llm-serve`).
# @markdown
# @markdown    **Endpoints deployed with older serving containers or before August 20, 2024 might not work**. We recommend deploying a new endpoint from the listed demo models inside the Gradio app.
# @markdown 1. After experiments, do not forget to undeploy the models from [Vertex Online Prediction](https://console.cloud.google.com/vertex-ai/online-prediction/endpoints) to avoid continuous charges to the project.

import dataclasses
import json
import subprocess
from typing import Any, Callable, Tuple

import gradio as gr
import requests
from google.cloud import aiplatform

# Upper bound sent as "max_tokens" with every chat completion request.
MAX_TOKENS = 512
# Optional Hugging Face token. When non-empty, the deploy helpers forward it to
# the serving container as the HF_TOKEN environment variable.
HF_TOKEN = ""

# Serving container images used by the one-click demo deployments.
VLLM_DOCKER_URI = "us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/pytorch-vllm-serve:20240819_0916_RC00"
TGI_DOCKER_URI = "us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/pytorch-hf-tgi-serve:20240820_0936_RC01"

# Server types are detected by substring match against a deployed model's
# container image URI (see get_server_type).
SERVER_TYPE_VLLM = "vllm"
SERVER_TYPE_HEXLLM = "hex-llm"
SERVER_TYPE_TGI = "tgi"
# Checked in this order; the first type whose name occurs in an image URI wins.
SERVER_TYPES = [
    SERVER_TYPE_VLLM,
    SERVER_TYPE_HEXLLM,
    SERVER_TYPE_TGI,
]


@dataclasses.dataclass
class Endpoint:
    display_name: str
    location: str
    resource_name: str
    server_type: str


# Extra endpoints known to the playground itself. list_endpoints() appends them
# to the dropdown choices and StreamingClient.set_endpoint() looks here first.
# Nothing in this cell adds entries, so the list stays empty unless other code
# populates it.
PLAYGROUND_ENDPOINTS: list[Endpoint] = []


@dataclasses.dataclass
class DeployConfig:
    """One entry of the "Deploy Model" dropdown: a label, a default name and a deploy callable."""

    # Label shown in the dropdown; also the key used to look the config up again.
    display_name: str
    # Default text placed in the "Model Name" box when this entry is selected.
    model_name: str
    # Called with the name from the "Model Name" box; annotated as returning the
    # (model, endpoint) pair that was created.
    func: Callable[[str], tuple[aiplatform.Model, aiplatform.Endpoint]]


def deploy_model_vllm(
    model_name: str,
    model_id: str,
    service_account: str,
    base_model_id: str = None,
    machine_type: str = "g2-standard-8",
    accelerator_type: str = "NVIDIA_L4",
    accelerator_count: int = 1,
    gpu_memory_utilization: float = 0.9,
    max_model_len: int = 4096,
    dtype: str = "auto",
) -> Tuple[aiplatform.Model, aiplatform.Endpoint]:
    """Uploads a model backed by the vLLM container and deploys it to a new endpoint.

    Args:
        model_name: Display name of the uploaded model; the endpoint is named
            "<model_name>-endpoint".
        model_id: Value passed to the server's --model flag.
        service_account: Service account forwarded to Model.deploy.
        base_model_id: Value of the MODEL_ID environment variable; falls back
            to model_id when empty.
        machine_type: Machine type forwarded to Model.deploy.
        accelerator_type: Accelerator type forwarded to Model.deploy.
        accelerator_count: Accelerator count; also used as the tensor parallel size.
        gpu_memory_utilization: Value of --gpu-memory-utilization.
        max_model_len: Value of --max-model-len.
        dtype: Value of --dtype.

    Returns:
        The uploaded model and the endpoint it was deployed to.
    """
    # The endpoint is created before the upload, in the same order as always.
    vertex_endpoint = aiplatform.Endpoint.create(display_name=f"{model_name}-endpoint")

    server_command = ["python", "-m", "vllm.entrypoints.api_server"]
    server_command += [
        "--host=0.0.0.0",
        "--port=7080",
        f"--model={model_id}",
        f"--tensor-parallel-size={accelerator_count}",
        "--swap-space=16",
        f"--gpu-memory-utilization={gpu_memory_utilization}",
        f"--max-model-len={max_model_len}",
        f"--dtype={dtype}",
        "--disable-log-stats",
    ]

    container_env = {
        "MODEL_ID": base_model_id or model_id,
        "DEPLOY_SOURCE": "notebook",
    }

    # HF_TOKEN is optional and may not even be defined in the notebook.
    try:
        hf_token = HF_TOKEN
    except NameError:
        hf_token = None
    if hf_token:
        container_env["HF_TOKEN"] = hf_token

    shared_memory_mb = 16 * 1024  # 16 GB
    uploaded_model = aiplatform.Model.upload(
        display_name=model_name,
        serving_container_image_uri=VLLM_DOCKER_URI,
        serving_container_args=server_command,
        serving_container_ports=[7080],
        serving_container_predict_route="/generate",
        serving_container_health_route="/ping",
        serving_container_environment_variables=container_env,
        serving_container_shared_memory_size_mb=shared_memory_mb,
        serving_container_deployment_timeout=7200,
    )

    print(
        f"Deploying {model_name} on {machine_type} with {accelerator_count} {accelerator_type} GPU(s)."
    )
    uploaded_model.deploy(
        endpoint=vertex_endpoint,
        machine_type=machine_type,
        accelerator_type=accelerator_type,
        accelerator_count=accelerator_count,
        deploy_request_timeout=1800,
        service_account=service_account,
    )
    print("endpoint_name:", vertex_endpoint.name)

    return uploaded_model, vertex_endpoint


def deploy_model_tgi(
    model_name: str,
    model_id: str,
    service_account: str,
    machine_type: str = "g2-standard-8",
    accelerator_type: str = "NVIDIA_L4",
    accelerator_count: int = 1,
    max_input_length: int = 2047,
    max_total_tokens: int = 2048,
    max_batch_prefill_tokens: int = 2048,
) -> Tuple[aiplatform.Model, aiplatform.Endpoint]:
    """Uploads a model backed by the TGI container and deploys it to a new endpoint.

    Args:
        model_name: Display name of the uploaded model; the endpoint is named
            "<model_name>-endpoint".
        model_id: Value of the MODEL_ID environment variable.
        service_account: Service account forwarded to Model.deploy.
        machine_type: Machine type forwarded to Model.deploy.
        accelerator_type: Accelerator type forwarded to Model.deploy.
        accelerator_count: Accelerator count; also exported as NUM_SHARD.
        max_input_length: Exported as MAX_INPUT_LENGTH.
        max_total_tokens: Exported as MAX_TOTAL_TOKENS.
        max_batch_prefill_tokens: Exported as MAX_BATCH_PREFILL_TOKENS.

    Returns:
        The uploaded model and the endpoint it was deployed to.
    """
    vertex_endpoint = aiplatform.Endpoint.create(display_name=f"{model_name}-endpoint")

    # The container is configured entirely through environment variables,
    # which must all be strings.
    container_env = {
        "MODEL_ID": model_id,
        "NUM_SHARD": str(accelerator_count),
        "MAX_INPUT_LENGTH": str(max_input_length),
        "MAX_TOTAL_TOKENS": str(max_total_tokens),
        "MAX_BATCH_PREFILL_TOKENS": str(max_batch_prefill_tokens),
        "DEPLOY_SOURCE": "notebook",
    }

    # HF_TOKEN is optional and may not even be defined in the notebook.
    try:
        hf_token = HF_TOKEN
    except NameError:
        hf_token = None
    if hf_token:
        container_env["HF_TOKEN"] = hf_token

    shared_memory_mb = 16 * 1024  # 16 GB
    uploaded_model = aiplatform.Model.upload(
        display_name=model_name,
        serving_container_image_uri=TGI_DOCKER_URI,
        serving_container_ports=[80],
        serving_container_environment_variables=container_env,
        serving_container_shared_memory_size_mb=shared_memory_mb,
    )

    uploaded_model.deploy(
        endpoint=vertex_endpoint,
        machine_type=machine_type,
        accelerator_type=accelerator_type,
        accelerator_count=accelerator_count,
        deploy_request_timeout=1800,
        service_account=service_account,
    )
    return uploaded_model, vertex_endpoint


# Demo models offered in the "Deploy Model" dropdown.
# `model_name` is the default text for the "Model Name" box. The deploy helpers
# append "-endpoint" themselves when naming the endpoint, so the defaults must
# not carry that suffix; otherwise the endpoint is named "...-endpoint-endpoint"
# and the uploaded model itself is called "...-endpoint".
DEPLOY_CONFIGS = [
    DeployConfig(
        display_name="microsoft/Phi-3-mini-4k-instruct (vLLM)",
        model_name="vllm-Phi-3-mini-4k-instruct",
        func=lambda name: deploy_model_vllm(
            name, "microsoft/Phi-3-mini-4k-instruct", service_account=None
        ),
    ),
    DeployConfig(
        display_name="Qwen/Qwen2-7B-Instruct (TGI)",
        model_name="tgi-Qwen2-7B-Instruct",
        func=lambda name: deploy_model_tgi(
            name, "Qwen/Qwen2-7B-Instruct", service_account=None
        ),
    ),
]


def get_server_type(endpoint: aiplatform.Endpoint) -> str | None:
    """Identifies the model server behind an endpoint from its container images.

    Each entry of SERVER_TYPES is tried in order; the first one that occurs as
    a substring of any deployed model's container image URI is returned.
    Returns None when no entry matches.
    """
    deployed_models = [
        aiplatform.Model(deployed.model) for deployed in endpoint.list_models()
    ]

    def served_by(kind: str) -> bool:
        """True when at least one deployed model's image URI mentions `kind`."""
        return any(kind in m.container_spec.image_uri for m in deployed_models)

    return next((kind for kind in SERVER_TYPES if served_by(kind)), None)


def format_payload(messages: list[dict[str, str]]) -> dict[str, str]:
    return {
        "messages": messages,
        "max_tokens": MAX_TOKENS,
        "stream": True,
    }


def list_endpoints() -> list[tuple[str, str]]:
    """Returns (display name, resource name) pairs for the endpoint dropdown.

    Endpoints are listed newest first and kept only when they have a non-empty
    traffic split and a recognized server type. Entries from
    PLAYGROUND_ENDPOINTS are appended after them.
    """
    choices = []
    for candidate in aiplatform.Endpoint.list(order_by="create_time desc"):
        if candidate.traffic_split and get_server_type(candidate):
            choices.append((candidate.display_name, candidate.resource_name))
    for registered in PLAYGROUND_ENDPOINTS:
        choices.append((registered.display_name, registered.resource_name))
    return choices


class StreamingClient:
    """Streams chat completions from the currently selected Vertex AI endpoint.

    The endpoint is chosen with set_endpoint(); predict() is the generator
    passed to the Gradio chat interface.
    """

    # Endpoint that predict() talks to; None until set_endpoint() has been called.
    endpoint: Endpoint | None = None

    def set_endpoint(self, endpoint: str):
        """Sets the prediction endpoint.

        Args:
            endpoint: Resource name of the endpoint (the dropdown value).
        """
        playground_endpoint = [
            e for e in PLAYGROUND_ENDPOINTS if e.resource_name == endpoint
        ]
        if not playground_endpoint:
            # Not a playground-registered endpoint: look it up in Vertex AI.
            vertex_endpoint = aiplatform.Endpoint(endpoint)
            server_type = get_server_type(vertex_endpoint)
            self.endpoint = Endpoint(
                display_name=vertex_endpoint.display_name,
                location=vertex_endpoint.location,
                resource_name=endpoint,
                server_type=server_type,
            )
        else:
            self.endpoint = playground_endpoint[0]
        print(
            "Selected endpoint:",
            self.endpoint.resource_name,
            "Server:",
            self.endpoint.server_type,
        )

    @staticmethod
    def _get_access_token() -> str:
        """Returns an access token printed by `gcloud auth print-access-token`.

        Only stdout is read, so warnings that gcloud writes to stderr cannot be
        mistaken for the token.

        Raises:
            gr.Error: If gcloud cannot be run, exits non-zero or prints nothing.
        """
        try:
            result = subprocess.run(
                ["gcloud", "auth", "print-access-token"],
                capture_output=True,
                text=True,
                check=True,
            )
        except (OSError, subprocess.CalledProcessError) as e:
            details = getattr(e, "stderr", None) or str(e)
            raise gr.Error(
                f"Failed to get an access token from gcloud: {details}"
            ) from e
        token = result.stdout.strip()
        if not token:
            raise gr.Error("gcloud returned an empty access token.")
        return token

    def predict(self, message: str, chat_history: list[tuple[str, str]]):
        """Streams the model's reply to `message`.

        Args:
            message: The new user message.
            chat_history: Earlier (user, assistant) turns, oldest first.

        Yields:
            The reply accumulated so far, once per received content delta.

        Raises:
            gr.Error: If no endpoint is selected, the HTTP request fails or the
                endpoint reports an error in the stream.
        """
        if not self.endpoint:
            raise gr.Error("Select an endpoint first.")

        url = f"https://{self.endpoint.location}-aiplatform.googleapis.com/v1beta1/{self.endpoint.resource_name}/chat/completions"
        messages = []
        for user_turn, assistant_turn in chat_history:
            messages.append({"role": "user", "content": user_turn})
            messages.append({"role": "assistant", "content": assistant_turn})
        messages.append({"role": "user", "content": message})
        payload = format_payload(messages)
        access_token = self._get_access_token()

        # The context manager closes the streamed response on every exit path:
        # normal completion, "[DONE]", an error, or the consumer abandoning the
        # generator midway.
        with requests.post(
            url,
            headers={"Authorization": f"Bearer {access_token}"},
            json=payload,
            stream=True,
        ) as response:
            if not response.ok:
                # Include status and body; the bare Response object would only
                # render as "<Response [NNN]>".
                raise gr.Error(
                    f"Request failed with HTTP {response.status_code}: {response.text}"
                )
            prediction = ""
            for line in response.iter_lines(chunk_size=8192, decode_unicode=False):
                if not line:
                    continue
                chunk = line.decode("utf-8").removeprefix("data:").strip()
                if chunk == "[DONE]":
                    break
                data = json.loads(chunk)
                if not isinstance(data, dict) or "error" in data:
                    raise gr.Error(f"Endpoint returned an error: {data}")
                # Some chunks carry no choices at all; skip them instead of
                # failing with IndexError.
                choices = data.get("choices") or []
                if not choices:
                    continue
                delta = (choices[0].get("delta") or {}).get("content")
                if delta:
                    prediction += delta
                    yield prediction


# Single client instance shared by the whole app: the endpoint dropdown calls
# its set_endpoint() and the chat interface calls its predict(), so the
# selected endpoint is shared by every browser session of this app.
streaming_client = StreamingClient()


def create_endpoint_selector():
    """Builds the endpoint picker row: an endpoint dropdown and a refresh button.

    Picking an entry forwards its value to streaming_client.set_endpoint.
    Clicking "Refresh" queries the endpoints again and replaces the choices.
    """
    selector_help = (
        "Only TGI, vLLM, and HexLLM endpoints deployed after August 20, 2024 with a new container image support chat completions and streaming features. "
        "If you are not sure, you can deploy a demo endpoint directly from below. "
    )

    with gr.Row():
        selector = gr.Dropdown(
            list_endpoints(),
            label="Endpoint",
            scale=1,
            info=selector_help,
        )
        # .input fires only on user interaction, not on programmatic updates.
        selector.input(streaming_client.set_endpoint, inputs=[selector], outputs=[])

        reload_button = gr.Button("Refresh", scale=0)
        reload_button.click(
            lambda: gr.Dropdown(choices=list_endpoints()),
            inputs=[],
            outputs=[selector],
        )


def create_deploy_selector():
    """Creates the deploy row: a demo model dropdown, a name box and a deploy button."""

    def find_deploy_config(display_name: str) -> DeployConfig:
        """Finds the deploy config from display name.

        Raises:
            gr.Error: If no config has that display name (nothing selected).
        """
        matches = [c for c in DEPLOY_CONFIGS if c.display_name == display_name]
        if not matches:
            raise gr.Error("Select a model to deploy first.")
        return matches[0]

    def deploy(endpoint_name: str, display_name: str):
        """Deploys the selected demo model under the given name.

        Args:
            endpoint_name: Text from the "Model Name" box.
            display_name: Selected entry of the "Deploy Model" dropdown.

        Raises:
            gr.Error: If no model is selected or the name is blank.
        """
        config = find_deploy_config(display_name)
        # Reject a blank name up front: the deploy helpers create the endpoint
        # before uploading the model, so a bad name would otherwise only fail
        # after an endpoint called "-endpoint" already exists.
        endpoint_name = (endpoint_name or "").strip()
        if not endpoint_name:
            raise gr.Error("Enter a model name before deploying.")
        gr.Info(f"Deploying to {endpoint_name}...")
        config.func(endpoint_name)
        gr.Info(f"Deployed to {endpoint_name}. Refresh the endpoints to see it.")

    with gr.Row():
        deploy_dropdown = gr.Dropdown(
            [x.display_name for x in DEPLOY_CONFIGS],
            label="Deploy Model",
            scale=1,
            info="Model deployment will take ~20 minutes. After you finish your experiments, "
            + "undeploy the endpoint from Vertex Online Prediction to avoid continuous charges to the project.",
        )
        model_name = gr.Textbox(
            label="Model Name",
            placeholder="Enter a custom model name for endpoint creation",
            interactive=True,
        )
        # Pre-fill the name box with the selected config's default name.
        deploy_dropdown.change(
            lambda x: find_deploy_config(x).model_name,
            inputs=[deploy_dropdown],
            outputs=[model_name],
        )

        deploy_btn = gr.Button("Deploy", scale=0)
        # Three chained steps: disable the button, run the deployment, then
        # restore the button.
        deploy_btn.click(
            lambda: gr.Button("Deploying...", interactive=False),
            inputs=[],
            outputs=[deploy_btn],
        ).then(deploy, inputs=[model_name, deploy_dropdown], outputs=[]).then(
            lambda: gr.Button("Deploy", interactive=True), [], [deploy_btn]
        )


# Assemble the app top to bottom: endpoint picker, deploy row, then the chat UI
# whose replies come from streaming_client.predict.
with gr.Blocks(title="Vertex Model Garden Chat", fill_height=True) as demo:
    create_endpoint_selector()
    create_deploy_selector()
    gr.ChatInterface(streaming_client.predict)


# Colab form field; passed to launch() as the `debug` flag.
show_debug_logs = True  # @param {type: "boolean"}
demo.queue()
# share=True requests a public URL (see the usage notes at the top of this
# cell), inline=False keeps the app out of the cell output, and show_error=True
# asks Gradio to surface errors in the UI.
demo.launch(share=True, inline=False, debug=show_debug_logs, show_error=True)