In [2]:
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Vertex AI Model Garden - Model Co-hosting Serving

<table><tbody><tr>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/notebooks/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/vertex-ai-samples/main/notebooks/community/model_garden/model_garden_model_cohost.ipynb">
      <img alt="Workbench logo" src="https://lh3.googleusercontent.com/UiNooY4LUgW_oTvpsNhPpQzsstV5W8F7rYgxgGBD85cWJoLmrOzhVs_ksK_vgx40SHs7jCqkTkCk=e14-rj-sc0xffffff-h130-w32" width="32px"><br> Run in Workbench
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fvertex-ai-samples%2Fmain%2Fnotebooks%2Fcommunity%2Fmodel_garden%2Fmodel_garden_model_cohost.ipynb">
      <img alt="Google Cloud Colab Enterprise logo" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" width="32px"><br> Run in Colab Enterprise
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_model_cohost.ipynb">
      <img alt="GitHub logo" src="https://github.githubassets.com/assets/GitHub-Mark-ea2971cee799.png" width="32px"><br> View on GitHub
    </a>
  </td>
</tr></tbody></table>

## Overview

This notebook provides a step-by-step guide to (1) single-model multi-replica serving, and (2) multi-model serving. For single-model multi-replica serving, the notebook demonstrates a container-level solution using the Model Garden vLLM model co-hosting container and an infrastructure-level solution using pod co-scheduling and NVIDIA Multi-Instance GPU (MIG). For multi-model serving, the notebook demonstrates a container-level solution using the Model Garden vLLM model co-hosting container. The notebook additionally demonstrates finding the optimal serving recipe for the Model Garden vLLM model co-hosting container using a benchmark utility.

### Objective

The goal is to efficiently serve a single model with multiple replicas and serve multiple models on a full-shape VM, and to automate the process of testing various serving strategies (pipeline parallelism, tensor parallelism, and creating model replicas) to identify the recipe that provides the best performance (throughput and latency). We will then deploy the winning recipe to a Vertex AI Endpoint.

### Steps

#### Single-model Multi-replica Serving

1.  **Setup**: Install libraries, authenticate with Google Cloud, and configure your environment.
1.  **Prepare Benchmark Files**: Write the benchmark utility and the benchmark client script, and prepare the benchmark dataset.
1.  **Run Benchmark**: Execute the benchmark utility to test different serving configurations under various concurrencies.
1.  **Review Reference Benchmark Results [Case Study]**: Review a reference set of benchmark results to learn how to interpret benchmark results and learn heuristics for optimal serving recipes.
1.  **Analyze Benchmark Results**: Analyze the generated benchmark outputs, visualize the performance metrics, and select the optimal serving recipe.
1.  **Deploy to Vertex AI and Test the Endpoint**: Upload the model to the Vertex AI Model Registry and deploy it to an Endpoint following the optimal serving recipe. Send a prediction request to the newly deployed endpoint.
1.  **[Alternative Solution: Pod Co-scheduling + MIG] Review Reference Benchmark Results**: Review a reference set of benchmark results to understand the performance scaling of pod co-scheduling + MIG.
1.  **[Alternative Solution: Pod Co-scheduling + MIG] Deploy to Vertex AI and Test the Endpoint**: Use the infrastructure-level solution, pod co-scheduling + MIG, to deploy the model with multiple replicas following the optimal serving recipe. Send a prediction request to the newly deployed endpoint.
1.  **Clean Up**: Delete the created Vertex AI resources.

#### Multi-model Serving

1.  **Setup**: Install libraries, authenticate with Google Cloud, and configure your environment.
1.  **Learn to Configure the Model Co-hosting Server**: Learn about the Model Garden model co-hosting server and how to use it to serve multiple models with the same container.
1.  **Deploy to Vertex AI and Test the Endpoint**: Upload the model to the Vertex AI Model Registry and deploy it to an Endpoint following the optimal serving recipe. Send a prediction request to the newly deployed endpoint.
1.  **Clean Up**: Delete the created Vertex AI resources.

### File a bug

File a bug on [GitHub](https://github.com/GoogleCloudPlatform/vertex-ai-samples/issues/new) if you encounter any issue with the notebook.

### Costs

This tutorial uses billable components of Google Cloud:

* Vertex AI
* Cloud Storage

Learn about [Vertex AI pricing](https://cloud.google.com/vertex-ai/pricing), [Cloud Storage pricing](https://cloud.google.com/storage/pricing), and use the [Pricing Calculator](https://cloud.google.com/products/calculator/) to generate a cost estimate based on your projected usage.

## Single-model Multi-replica Serving

## 1. Setup

First, let's install the necessary packages and set up your Google Cloud project environment.

In [None]:
# @title Setup Google Cloud project

# @markdown 1. [Make sure that billing is enabled for your project](https://cloud.google.com/billing/docs/how-to/modify-project).

# @markdown 2. **[Optional]** Set the region. If not set, the region is chosen automatically from the Colab Enterprise environment.

REGION = ""  # @param {type:"string"}

# @markdown 3. If you want to run predictions with H100 GPUs or H200 GPUs, we recommend using the regions listed below. **NOTE:** Make sure you have sufficient quota in the selected regions. Click the links to see your current quota for H100s: [`CustomModelServingH100GPUsPerProjectPerRegion`](https://console.cloud.google.com/iam-admin/quotas?metric=aiplatform.googleapis.com%2Fcustom_model_serving_nvidia_h100_gpus) and H200s: [`CustomModelServingH200GPUsPerProjectPerRegion`](https://console.cloud.google.com/iam-admin/quotas?metric=aiplatform.googleapis.com%2Fcustom_model_serving_nvidia_h200_gpus). You can request more quota by following the instructions at ["Request a higher quota"](https://cloud.google.com/docs/quota/view-manage#requesting_higher_quota).

# @markdown | Machine Type | Accelerator Type | Recommended Regions |
# @markdown | ----------- | ----------- | ----------- |
# @markdown | a3-highgpu-8g | 8 NVIDIA_H100_80GB | asia-southeast1, europe-west4, us-central1, us-east5, us-west1 |
# @markdown | a3-ultragpu-8g | 8 NVIDIA_H200_141GB | asia-south2, us-south1 |

# Upgrade Vertex AI SDK.
! pip3 install --upgrade --quiet 'google-cloud-aiplatform==1.103.0'
! pip3 install --upgrade --quiet aiohttp matplotlib pandas seaborn

# Import the necessary packages
import importlib
import os
from typing import Tuple

import requests
from google import auth
from google.cloud import aiplatform

# Upgrade TensorFlow when not running in Colab Enterprise.
if os.environ.get("VERTEX_PRODUCT") != "COLAB_ENTERPRISE":
    ! pip install --upgrade tensorflow
! git clone https://github.com/GoogleCloudPlatform/vertex-ai-samples.git

common_util = importlib.import_module(
    "vertex-ai-samples.notebooks.community.model_garden.docker_source_codes.notebook_util.common_util"
)

models, endpoints = {}, {}

# Get the default cloud project id.
PROJECT_ID = os.environ["GOOGLE_CLOUD_PROJECT"]

# Get the default region for launching jobs.
if not REGION:
    REGION = os.environ["GOOGLE_CLOUD_REGION"]

# Initialize Vertex AI API.
print("Initializing Vertex AI API.")
aiplatform.init(project=PROJECT_ID, location=REGION)

! gcloud config set project $PROJECT_ID

import vertexai

vertexai.init(
    project=PROJECT_ID,
    location=REGION,
)

## 2. Prepare Benchmark Files

The core of this workflow is a benchmark utility that automates the benchmark process for single-model multi-replica serving. The benchmark utility launches the vLLM model server through Docker and launches a benchmark client that sends prediction requests to the model server. This utility depends on a Python script that implements the benchmark client and a benchmark dataset. In this section, we prepare the necessary benchmark files.

### The Main Benchmark Utility

The utility takes as input (1) the vLLM container version and model for
launching the vLLM server and (2) the benchmark setup (benchmark script,
dataset, input length, output length, number of prompts, and concurrencies) for
launching the benchmark client. The utility launches the vLLM server with
docker, waits for the vLLM server to be ready, and then launches the benchmark
client. When launching the benchmark client, the utility iterates over the
possible combinations of pipeline parallel size, tensor parallel size, and
number of model replicas that together use all available GPUs, under different
concurrencies. In addition, the utility accepts optional maximum latency
limits (median and P99 TTFT and TPOT). If any limits are set, the utility
checks whether each benchmark run satisfies them and marks the outcome in the
benchmark results. If `--skip-concurrencies-given-latency` is also set, the
utility skips the remaining larger concurrencies once one or more limits are
missed at some concurrency. The utility generates an analysis figure plotting
metrics versus concurrencies.
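
As a rough illustration of the search space, the sketch below enumerates the serving recipes for a given GPU count, mirroring the divisibility check in `main()` of `benchmark_util.py`, and shows how an optional latency limit can be evaluated. The names `enumerate_recipes` and `meets_latency` are hypothetical helpers introduced here for illustration; they are not part of the utility.

```python
from typing import Iterator, Optional


def enumerate_recipes(total_gpus: int) -> Iterator[tuple[int, int, int]]:
    """Yields (pp_size, tp_size, model_replicas) recipes that use every GPU."""
    for pp_size in range(1, total_gpus + 1):
        for tp_size in range(1, total_gpus + 1):
            gpus_per_replica = pp_size * tp_size
            # Keep only layouts whose replicas tile the machine exactly.
            if total_gpus % gpus_per_replica == 0:
                yield pp_size, tp_size, total_gpus // gpus_per_replica


def meets_latency(measured_ms: float, max_ms: Optional[float]) -> bool:
    """Returns True when no limit is set or the measurement is within it."""
    return max_ms is None or measured_ms <= max_ms


recipes = list(enumerate_recipes(8))
for pp_size, tp_size, model_replicas in recipes:
    print(f"PP={pp_size}, TP={tp_size}, Replicas={model_replicas}")
```

Each recipe launches a fresh server container, so the number of recipes multiplied by the number of concurrencies gives a sense of how long a full sweep may take.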

In [None]:
%%writefile benchmark_util.py
"""Utility for benchmarking vLLM under different setups and concurrencies.

The utility takes as input (1) the vLLM container version and model for
launching the vLLM server and (2) the benchmark setup (benchmark script,
dataset, input length, output length, number of prompts, and concurrencies) for
launching the benchmark client. The utility launches the vLLM server with
docker, waits for the vLLM server to be ready, and then launches the benchmark
client. When launching the benchmark client, the utility iterates over the
possible combinations of pipeline parallel size, tensor parallel size, and
number of model replicas that together use all available GPUs, under different
concurrencies. In addition, the utility accepts optional maximum latency
limits. If any limits are set, the utility checks whether each benchmark run
satisfies them and marks the outcome in the benchmark results. If
--skip-concurrencies-given-latency is also set, the utility skips the
remaining larger concurrencies once one or more limits are missed at some
concurrency (--no-skip-concurrencies-given-latency, the default, runs every
concurrency). The utility generates an analysis figure plotting metrics versus
concurrencies.

Sample command:

python benchmark_util.py \
  --total-gpus 8 \
  --input-length 1200 \
  --output-length 250 \
  --num-prompts 2000 \
  --sonnet-prefix-len 49 \
  --concurrencies 1 8 16 \
  --max-median-ttft-ms 1000 \
  --max-p99-ttft-ms 10000 \
  --max-median-tpot-ms 100 \
  --max-p99-tpot-ms 1000 \
  --model /path/to/model \
  --docker-uri us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/pytorch-vllm-serve:20250808_0916_RC01_maas \
  --server-init-timeout 600 \
  --benchmark-script-path /path/to/benchmark_serving.py \
  --dataset-path /path/to/sonnet.txt \
  --results-output-path /path/to/benchmark_results.csv \
  --figure-output-path /path/to/benchmark_figure.png \
  --no-skip-concurrencies-given-latency
"""
import argparse
import json
import os
import subprocess
import time
from typing import List

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns


BENCHMARK_BACKEND = "chat_completions"
METRICS_TO_PLOT = [
    "request_throughput",
    "input_throughput",
    "output_throughput",
    "median_latency_ms",
    "median_ttft_ms",
    "median_tpot_ms",
]

parser = argparse.ArgumentParser(
    description="vLLM Docker benchmark script."
)
parser.add_argument(
    "--total-gpus",
    type=int,
    default=8,
    help="Total number of GPUs available on the machine.",
)
parser.add_argument(
    "--input-length",
    type=int,
    default=1200,
    help="Benchmark input length.",
)
parser.add_argument(
    "--output-length",
    type=int,
    default=250,
    help="Benchmark output length.",
)
parser.add_argument(
    "--num-prompts",
    type=int,
    default=100,
    help="Number of prompts to use in benchmark.",
)
parser.add_argument(
    "--sonnet-prefix-len",
    type=int,
    default=30,
    help="Number of prefix tokens per request, used for sonnet dataset.",
)
parser.add_argument(
    "--concurrencies",
    type=int,
    nargs="+",
    default=[1, 8, 16, 32, 64, 128],
    help="List of target concurrencies to test.",
)
parser.add_argument(
    "--max-median-ttft-ms",
    type=float,
    default=None,
    help="Maximum allowed median Time to First Token (TTFT) in milliseconds.",
)
parser.add_argument(
    "--max-p99-ttft-ms",
    type=float,
    default=None,
    help="Maximum allowed P99 Time to First Token (TTFT) in milliseconds.",
)
parser.add_argument(
    "--max-median-tpot-ms",
    type=float,
    default=None,
    help=(
        "Maximum allowed median Time per Output Token (TPOT) in milliseconds."
    ),
)
parser.add_argument(
    "--max-p99-tpot-ms",
    type=float,
    default=None,
    help="Maximum allowed P99 Time per Output Token (TPOT) in milliseconds.",
)
parser.add_argument(
    "--model",
    type=str,
    required=True,
    help="Local path to the model or HuggingFace model ID.",
)
parser.add_argument(
    "--docker-uri",
    type=str,
    default="us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/pytorch-vllm-serve:20250808_0916_RC01_maas",
    help="Docker image URI for the vLLM server.",
)
parser.add_argument(
    "--server-init-timeout",
    type=int,
    default=600,
    help="Timeout limit (in seconds) for server initialization.",
)
parser.add_argument(
    "--benchmark-script-path",
    type=str,
    required=True,
    help="Path to the benchmark_serving.py script.",
)
parser.add_argument(
    "--dataset-path",
    type=str,
    required=True,
    help="Path to the benchmark dataset.",
)
parser.add_argument(
    "--results-output-path",
    type=str,
    required=True,
    help="Path to output benchmark results.",
)
parser.add_argument(
    "--figure-output-path",
    type=str,
    required=True,
    help="Path to output the analysis figure.",
)
parser.add_argument(
    "--skip-concurrencies-given-latency",
    action=argparse.BooleanOptionalAction,
    default=False,
    help=(
        "Skip larger concurrencies when one or more latency requirements are "
        "not met at a concurrency."
    ),
)
args = parser.parse_args()


def wait_for_server(container_id: str, timeout: int = 1200) -> bool:
    """
    Polls the Docker container's logs to wait for the server startup message.
    """
    start_time = time.time()
    while True:
        if time.time() - start_time > timeout:
            print(f"Error: Server did not start within {timeout} seconds.")
            return False

        try:
            # Use 'docker logs --tail 1' to check the last log line
            output = subprocess.check_output(
                ["docker", "logs", "--tail", "1", container_id],
                text=True,
                stderr=subprocess.STDOUT
            ).strip()

            if "Application startup complete" in output:
                time.sleep(5)  # Wait for 5 seconds for all model servers
                print("vLLM server is ready! 🚀")
                return True
        except subprocess.CalledProcessError as e:
            print(
                f"Error checking logs for container {container_id}: {e.output}"
            )
            return False

        time.sleep(5)  # Wait for 5 seconds before checking again


def run_benchmark(
    pp_size: int,
    tp_size: int,
    model_replicas: int,
    concurrency_list: List[int],
    input_length: int,
    output_length: int,
    num_prompts: int,
    sonnet_prefix_len: int,
    model: str,
    server_init_timeout: int,
    benchmark_script: str,
    dataset: str,
    vllm_host: str = "0.0.0.0",
    vllm_port: int = 7080,
    max_median_ttft_ms: float = None,
    max_p99_ttft_ms: float = None,
    max_median_tpot_ms: float = None,
    max_p99_tpot_ms: float = None,
    skip_concurrencies_given_latency: bool = False,
) -> pd.DataFrame:
    """Launches the vLLM server and benchmark client."""

    print(
        f"Starting vLLM server with PP={pp_size}, TP={tp_size} and Replicas="
        f"{model_replicas}..."
    )

    if model_replicas > 1:
        api_server = "vllm.entrypoints.nginx_server"
    else:
        api_server = "vllm.entrypoints.api_server"

    # Prepare vLLM server command
    vllm_server_cmd = [
        "python", "-m", api_server,
        f"--host={vllm_host}",
        f"--port={vllm_port}",
        f"--model={model}",
        f"--pipeline-parallel-size={pp_size}",
        f"--tensor-parallel-size={tp_size}",
        f"--data-parallel-size=1",
        "--swap-space=16",
        "--gpu-memory-utilization=0.9",
        "--no-enable-prefix-caching",
    ]

    if model_replicas > 1:
        vllm_server_cmd.extend([
            f"--num_instances={model_replicas}",
            f"--total_gpus={int(tp_size * model_replicas)}"
        ])

    # Prepare Docker command
    docker_uri = args.docker_uri
    gpu_devices = ",".join([str(i) for i in range(args.total_gpus)])
    docker_cmd = [
        "docker", "run",
        "--entrypoint", "bash",
        "-e", f"NVIDIA_VISIBLE_DEVICES={gpu_devices}",
        "--gpus", "all",
        "--network=host",
        "-v", f"{os.path.expanduser('~')}:{os.path.expanduser('~')}",
        "--shm-size", "19.2gb",
        "-itd",  # Run in detached mode
        docker_uri,
        "-c", " ".join(vllm_server_cmd)
    ]

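    # Start the vLLM server inside Docker. This runs before the try/finally
    # below so that cleanup never references an undefined container_id.
    print("Running Docker command:", " ".join(docker_cmd))
    try:
        container_id = subprocess.check_output(docker_cmd, text=True).strip()
    except subprocess.CalledProcessError as e:
        print(f"Error starting Docker container: {e}")
        return pd.DataFrame()

    try:
        print(f"vLLM server starting in container: {container_id}")

        # Wait for the server to be ready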
        if not wait_for_server(
            container_id=container_id,
            timeout=server_init_timeout,
        ):
            return pd.DataFrame()

        all_results_df = pd.DataFrame()

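        for concurrency in concurrency_list:
            print(f"Benchmarking with concurrency: {concurrency}")
            # Reset per concurrency so the skip check below never reads an
            # undefined or stale value when the results file is missing.
            missed_latency_requirement = False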

            benchmark_cmd = [
                "python", benchmark_script,
                f"--backend={BENCHMARK_BACKEND}",
                f"--model={model}",
                f"--tokenizer={model}",
                f"--host={vllm_host}",
                f"--port={vllm_port}",
                f"--dataset={dataset}",
                f"--max-input-length={input_length}",
                f"--max-output-length={output_length}",
                f"--num-prompts={num_prompts}",
                f"--sonnet-prefix-len={sonnet_prefix_len}",
                f"--c={concurrency}",
                f"--output-dir={os.getcwd()}",
                f"--name={pp_size}_{tp_size}_{model_replicas}_{concurrency}",
            ]

            # Execute benchmark script
            result = subprocess.run(
                benchmark_cmd, capture_output=True, text=True, check=True
            )
            print("Benchmark command executed successfully.")

            full_results_filename = f"{pp_size}_{tp_size}_{model_replicas}_{concurrency}_aggregated_results.json"
            full_results_path = os.path.join(os.getcwd(), full_results_filename)

            if os.path.exists(full_results_path):
                df = pd.read_json(full_results_path, lines=True)

                # Add configuration columns
                df["pp_size"] = pp_size
                df["tp_size"] = tp_size
                df["model_replicas"] = model_replicas
                df["docker_cmd"] = " ".join(docker_cmd)
                df["benchmark_cmd"] = " ".join(benchmark_cmd)

                # Compare latency metrics against optional max latency requirements
                missed_latency_requirement = False
                if max_median_ttft_ms is not None:
                    median_ttft_ms = df['median_ttft_ms'].iloc[0] if not df['median_ttft_ms'].isnull().all() else float('inf')
                    df["median_ttft_ok"] = median_ttft_ms <= max_median_ttft_ms
                    if median_ttft_ms > max_median_ttft_ms:
                        missed_latency_requirement = True
                if max_p99_ttft_ms is not None:
                    p99_ttft_ms = df['p99_ttft_ms'].iloc[0] if not df['p99_ttft_ms'].isnull().all() else float('inf')
                    df["p99_ttft_ok"] = p99_ttft_ms <= max_p99_ttft_ms
                    if p99_ttft_ms > max_p99_ttft_ms:
                        missed_latency_requirement = True
                if max_median_tpot_ms is not None:
                    median_tpot_ms = df['median_tpot_ms'].iloc[0] if not df['median_tpot_ms'].isnull().all() else float('inf')
                    df["median_tpot_ok"] = median_tpot_ms <= max_median_tpot_ms
                    if median_tpot_ms > max_median_tpot_ms:
                        missed_latency_requirement = True
                if max_p99_tpot_ms is not None:
                    p99_tpot_ms = df['p99_tpot_ms'].iloc[0] if not df['p99_tpot_ms'].isnull().all() else float('inf')
                    df["p99_tpot_ok"] = p99_tpot_ms <= max_p99_tpot_ms
                    if p99_tpot_ms > max_p99_tpot_ms:
                        missed_latency_requirement = True

                all_results_df = pd.concat([all_results_df, df], ignore_index=True)
            else:
                print(f"Warning: Full results file not found at {full_results_path}")

            print(f"Benchmark for PP={pp_size}, TP={tp_size}, Replicas={model_replicas}, Concurrency={concurrency} complete.")

            if skip_concurrencies_given_latency and missed_latency_requirement:
                print(f"Latency requirement(s) not met at concurrency={concurrency}. Skip larger concurrencies.")
                break

    except subprocess.CalledProcessError as e:
        print(f"Error during benchmark: {e.stderr}")
        return pd.DataFrame()
    finally:
        # Stop and remove the Docker container
        subprocess.run(["docker", "stop", container_id], check=False, text=True)
        subprocess.run(["docker", "rm", container_id], check=False, text=True)
        print(f"Container {container_id} stopped and removed.")

    return all_results_df


def plot_metric_by_concurrency(
    results_path: str,
    target_metrics: List[str] = ["request_throughput"],
    figure_path: str = "benchmark_figure.png",
):
    """Creates analysis figures based on benchmark results."""
    # Load benchmark results
    all_results_df = pd.read_csv(results_path)

    # Create a new column to represent each model server setup
    all_results_df["Server Config"] = all_results_df.apply(
        lambda row: (
            f"PP={int(row['pp_size'])}, TP={int(row['tp_size'])}, Replicas="
            f"{int(row['model_replicas'])}"
        ),
        axis=1,
    )

    # Melt the DataFrame to a long format for easier plotting
    all_results_df_melted = all_results_df.melt(
        id_vars=["concurrent_requests", "Server Config"],
        value_vars=target_metrics,
        var_name="metric",
        value_name="value",
    )

    sns.set_theme(style="whitegrid")

    # Create the multi-plot grid using relplot
    g = sns.relplot(
        data=all_results_df_melted,
        x="concurrent_requests",
        y="value",
        hue="Server Config",
        col="metric",
        col_wrap=3,
        kind="line",
        marker="o",
        height=4,
        aspect=1.2,
        facet_kws={'sharey': False},
    )

    def get_formatted_name(metric_name):
        name_map = {
            'request_throughput': 'Request Throughput (req/s)',
            'output_throughput': 'Output Throughput (tok/s)',
            'input_throughput': 'Input Throughput (tok/s)',
            'median_latency_ms': 'Median E2E Latency (ms)',
            'p99_latency_ms': 'P99 E2E Latency (ms)',
            'median_ttft_ms': 'Median TTFT (ms)',
            'p99_ttft_ms': 'P99 TTFT (ms)',
            'median_tpot_ms': 'Median TPOT (ms)',
            'p99_tpot_ms': 'P99 TPOT (ms)'
        }
        return name_map.get(metric_name, metric_name)

    for ax in g.axes.flatten():
        ax.set_title(get_formatted_name(ax.get_title().replace("metric = ", "")))
        ax.set_xlabel("Concurrency")
        ax.set_ylabel("")

    g.figure.subplots_adjust(top=0.9)
    plt.suptitle(
        "Performance Metrics vs. Concurrency Across Server Configs",
        fontsize=20,
    )

    plt.savefig(figure_path, dpi=300)
    print(f"Analysis figure saved to {figure_path}.")


def main():
    total_gpus = args.total_gpus
    benchmark_script = args.benchmark_script_path

    # Check if the benchmark script exists
    if not os.path.exists(benchmark_script):
        print(f"Error: The benchmark script at {benchmark_script} does not exist.")
        return

    # Pull the target Docker image
    print("Pulling the Docker image...")
    try:
        subprocess.run(["docker", "pull", args.docker_uri], check=True, text=True)
        print("Docker image pulled successfully.")
    except subprocess.CalledProcessError as e:
        # stderr is not captured by this call, so report the exception itself.
        print(f"Error pulling Docker image: {e}")
        return

    final_results = pd.DataFrame()

    filtered_concurrency_list = list(set(
        [c for c in args.concurrencies if c > 0]
    ))
    filtered_concurrency_list = sorted(filtered_concurrency_list)

    # Iterate through different PP, TP and replica settings
    for pp_size in range(1, total_gpus + 1):
        for tp_size in range(1, total_gpus + 1):
            if total_gpus % (pp_size * tp_size) == 0:
                model_replicas = total_gpus // (pp_size * tp_size)

                results_for_config = run_benchmark(
                    pp_size=pp_size,
                    tp_size=tp_size,
                    model_replicas=model_replicas,
                    concurrency_list=filtered_concurrency_list,
                    input_length=args.input_length,
                    output_length=args.output_length,
                    num_prompts=args.num_prompts,
                    sonnet_prefix_len=args.sonnet_prefix_len,
                    model=args.model,
                    server_init_timeout=args.server_init_timeout,
                    benchmark_script=benchmark_script,
                    dataset=args.dataset_path,
                    max_median_ttft_ms=args.max_median_ttft_ms,
                    max_p99_ttft_ms=args.max_p99_ttft_ms,
                    max_median_tpot_ms=args.max_median_tpot_ms,
                    max_p99_tpot_ms=args.max_p99_tpot_ms,
                    skip_concurrencies_given_latency=args.skip_concurrencies_given_latency,
                )
                final_results = pd.concat([final_results, results_for_config], ignore_index=True)
                final_results.to_csv(args.results_output_path)
                print(f"Intermediate benchmark results saved to {args.results_output_path}")

    # Print the final results table
    if not final_results.empty:
        print("\n" + "="*80)
        print("Final Benchmark Results Summary")
        print("="*80)
        # to_string avoids the optional `tabulate` dependency of to_markdown.
        print(final_results.head().to_string(index=False)) # print first 5 rows
        print("... and so on")
        print("Total rows collected:", len(final_results))
        print("="*80)

        final_results.to_csv(args.results_output_path)
        print(f"Benchmark results saved to {args.results_output_path}")

    # Create analysis figure (skipped when no benchmark results were collected)
    if not final_results.empty:
        plot_metric_by_concurrency(
            results_path=args.results_output_path,
            target_metrics=METRICS_TO_PLOT,
            figure_path=args.figure_output_path,
        )


if __name__ == "__main__":
    main()


### Benchmark Client (`benchmark_serving.py`)

This script, called by the main utility, is responsible for sending concurrent requests to the vLLM server and measuring performance.
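
To make "concurrent requests" concrete, here is a minimal sketch of one common way to cap the number of in-flight requests with `asyncio.Semaphore` while timing each one. It is an illustration only and not necessarily how `benchmark_serving.py` implements concurrency: the function name `run_with_concurrency` is hypothetical, and `asyncio.sleep` stands in for the HTTP call to the server.

```python
import asyncio
import time


async def run_with_concurrency(num_requests: int, concurrency: int) -> list[float]:
    """Runs simulated requests with at most `concurrency` in flight at once."""
    semaphore = asyncio.Semaphore(concurrency)

    async def one_request() -> float:
        async with semaphore:  # Blocks while `concurrency` requests are active.
            start = time.perf_counter()
            await asyncio.sleep(0.01)  # Stand-in for the HTTP call to the server.
            return (time.perf_counter() - start) * 1000.0  # Latency in ms.

    return await asyncio.gather(*(one_request() for _ in range(num_requests)))


latencies_ms = asyncio.run(run_with_concurrency(num_requests=20, concurrency=4))
mean_ms = sum(latencies_ms) / len(latencies_ms)
print(f"{len(latencies_ms)} requests, mean latency {mean_ms:.1f} ms")
```

When run as a script this works as written. Inside a notebook cell an event loop is already running, so you would instead write `latencies_ms = await run_with_concurrency(20, 4)`.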

In [None]:
%%writefile benchmark_serving.py
"""Benchmark client for LLM serving."""

# pylint: disable=g-multiple-import
# pylint: disable=g-importing-member
# pylint: disable=logging-fstring-interpolation
# pylint: disable=f-string-without-interpolation

from abc import ABC
from abc import abstractmethod
import argparse
import asyncio
from collections.abc import AsyncGenerator
import dataclasses
from dataclasses import dataclass
from dataclasses import field
from datetime import datetime
import json
import logging
import os
import random
import sys
import time
import traceback
from typing import Any, Optional

import aiohttp
import numpy as np
import pandas as pd
from tenacity import RetryCallState, retry, stop_after_attempt, wait_exponential
from tqdm.asyncio import tqdm
from transformers import AutoTokenizer


CLIENT_TIMEOUT_SEC = 3 * 60 * 60
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=CLIENT_TIMEOUT_SEC)


class BaseTokenizer(ABC):
    """Abstract class for tokenizers."""

    @abstractmethod
    def encode(self, text: str, add_special_tokens: bool = True) -> list[int]:
        pass

    @abstractmethod
    def decode(self, token_ids: list[int]) -> str:
        pass

    @abstractmethod
    def apply_chat_template(
        self,
        message: list[dict[str, Any]],
        add_generation_prompt: bool = True,
        tokenize: bool = False,
    ) -> str:
        pass

    @abstractmethod
    def all_special_ids(self) -> list[int]:
        pass

    @abstractmethod
    def get_vocab(self) -> dict[str, int]:
        pass

    @abstractmethod
    def bos_token(self) -> str:
        pass


class Llama3Tokenizer(BaseTokenizer):
    """Llama3-specific tokenizer, based on Tiktoken."""

    def __init__(self, tokenizer_path: str):
        from saxml.server.pax.lm import vocabularies  # pylint: disable=g-import-not-at-top

        self._tokenizer = vocabularies.LLama3Vocabulary(tokenizer_path)

    def encode(self, text: str, add_special_tokens: bool = True) -> list[int]:
        del add_special_tokens
        return list(self._tokenizer.encode(text))

    def decode(self, token_ids: list[int]) -> str:
        return self._tokenizer.decode(token_ids)

    def apply_chat_template(
        self,
        message: list[dict[str, Any]],
        add_generation_prompt: bool = True,
        tokenize: bool = False,
    ) -> str:
        del add_generation_prompt, tokenize, message
        # This is not required for the servomatic backend.
        # The formatted prompt is ignored and regular prompt is used.
        logging.debug("apply_chat_template is not supported for Llama3Tokenizer.")
        return ""

    def all_special_ids(self) -> list[int]:
        raise NotImplementedError("Not implemented for Llama3Tokenizer.")

    def get_vocab(self) -> dict[str, int]:
        raise NotImplementedError("Not implemented for Llama3Tokenizer.")

    def bos_token(self) -> str:
        raise NotImplementedError("Not implemented for Llama3Tokenizer.")


class GeneralTokenizer(BaseTokenizer):
    """General tokenizer, based on transformers.AutoTokenizer, used for OSS runs."""

    def __init__(self, tokenizer_path: str, trust_remote_code: bool = False):
        logging.info("GeneralTokenizer: tokenizer_path: %s", tokenizer_path)
        self._tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_path, trust_remote_code=trust_remote_code
        )

    def encode(self, text: str, add_special_tokens: bool = True) -> list[int]:
        return list(
            self._tokenizer.encode(text, add_special_tokens=add_special_tokens)
        )

    def decode(self, token_ids: list[int]) -> str:
        return self._tokenizer.decode(token_ids)

    def apply_chat_template(
        self,
        message: list[dict[str, Any]],
        add_generation_prompt: bool = True,
        tokenize: bool = False,
    ) -> str:
        return self._tokenizer.apply_chat_template(
            message, add_generation_prompt=add_generation_prompt, tokenize=tokenize
        )

    def all_special_ids(self) -> list[int]:
        return self._tokenizer.all_special_ids

    def get_vocab(self) -> dict[str, int]:
        return self._tokenizer.get_vocab()

    def bos_token(self) -> str:
        return self._tokenizer.bos_token


def str2bool(v: str) -> Optional[bool]:
    """Parses a command-line boolean such as "yes", "true", "0", or "n"."""
    if v is None:
        return None
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("Boolean value expected.")


def sample_sonnet_requests(
    dataset_path: str,
    num_requests: int,
    min_input_len: int,
    max_input_len: int,
    min_output_len: int,
    max_output_len: int,
    prefix_len: int,
    tokenizer: BaseTokenizer,
    fixed_input_length: Optional[int] = None,
    fixed_output_length: Optional[int] = None,
) -> list[tuple[str, str, int, int, int]]:
    """Samples requests from the Sonnet dataset.

    Args:
        dataset_path: Path to the Sonnet dataset.
        num_requests: Number of requests to sample.
        min_input_len: Minimum input length.
        max_input_len: Maximum input length.
        min_output_len: Minimum output length.
        max_output_len: Maximum output length.
        prefix_len: Number of prefix tokens per request.
        tokenizer: Tokenizer to use.
        fixed_input_length: If specified, forces input_len to be fixed_input_length.
        fixed_output_length: If specified, forces output_len to be
            fixed_output_length.

    Returns:
        A list of tuples containing the prompt, formatted prompt, prompt length,
        formatted prompt length, and output length.
    """

    # Load the dataset.
    with open(dataset_path) as f:
        poem_lines = f.readlines()
    poem_lines = poem_lines * 100

    # Tokenize the poem lines.
    poem_token_ids = [tokenizer.encode(poem_line) for poem_line in poem_lines]
    average_poem_len = sum(len(token_ids) for token_ids in poem_token_ids) / len(
        poem_token_ids
    )

    # Base prefix for all requests.
    if dataset_path.endswith("code-sonnet.txt"):
        base_prompt = (
            "Repeatedly pick as many questions from each line and write the answer to"
            " each question infinitely.\n"
        )
    else:
        base_prompt = "Pick as many lines as you can from these poem lines:\n"
    base_message = [{
        "role": "user",
        "content": base_prompt,
    }]
    base_prompt_formatted = tokenizer.apply_chat_template(
        base_message, add_generation_prompt=True, tokenize=False
    )
    base_prompt_offset = len(tokenizer.encode(base_prompt_formatted))

    logging.info("prefix_len: %s", prefix_len)
    logging.info("base_prompt_offset: %s", base_prompt_offset)
    logging.info("base_prompt_formatted: %s", base_prompt_formatted)
    logging.info(
        "base_prompt_formatted.input_ids: %s",
        tokenizer.encode(base_prompt_formatted),
    )

    # First approximately `prefix_len` number of tokens in the
    # prompt are fixed poem lines.
    assert (
        prefix_len > base_prompt_offset
    ), f"Set 'args.sonnet-prefix-len' higher than {base_prompt_offset}."

    num_prefix_lines = round((prefix_len - base_prompt_offset) / average_poem_len)
    prefix_lines = poem_lines[:num_prefix_lines]

    # Sample the rest of lines per request.
    sampled_requests: list[tuple[str, str, int, int, int]] = []
    for _ in range(num_requests):
        if fixed_input_length:
            input_len = fixed_input_length
        else:
            input_len = (
                random.randrange(min_input_len, max_input_len)
                if max_input_len > min_input_len
                else min_input_len
            )
        assert (
            input_len > prefix_len
        ), "'args.sonnet-input-len' must be greater than 'args.sonnet-prefix-len'."
        assert (
            input_len > base_prompt_offset
        ), f"Set 'args.sonnet-input-len' higher than {base_prompt_offset}."
        num_input_lines = round((input_len - base_prompt_offset) / average_poem_len)

        if fixed_output_length:
            output_len = fixed_output_length
        else:
            output_len = (
                random.randrange(min_output_len, max_output_len)
                if max_output_len > min_output_len
                else min_output_len
            )

        sampled_lines = "".join(
            prefix_lines
            + random.sample(poem_lines, num_input_lines - num_prefix_lines)
        )

        prompt = f"{base_prompt}{sampled_lines}"
        message = [
            {
                "role": "user",
                "content": prompt,
            },
        ]
        prompt_formatted = tokenizer.apply_chat_template(
            message, add_generation_prompt=True, tokenize=False
        )

        prompt_len = len(tokenizer.encode(prompt))
        prompt_formatted_len = len(tokenizer.encode(prompt_formatted))
        sampled_requests.append(
            (prompt, prompt_formatted, prompt_len, prompt_formatted_len, output_len)
        )

    return sampled_requests
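

# Illustrative sketch, not called by the benchmark: how a token budget is turned
# into a number of poem lines in `sample_sonnet_requests`. The helper name and
# its parameter names are hypothetical; only the arithmetic mirrors the
# `num_input_lines` computation above.
def _sketch_num_poem_lines(
    target_tokens: int, base_prompt_tokens: int, avg_line_tokens: float
) -> int:
    """Estimates how many poem lines fit in the remaining token budget."""
    # Tokens left once the instruction prefix has been accounted for.
    remaining_tokens = target_tokens - base_prompt_tokens
    # Lines are whole units, so the estimate is rounded to the nearest line.
    return round(remaining_tokens / avg_line_tokens)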


async def get_request(
    input_requests: list[tuple[str, int, int]],
) -> AsyncGenerator[tuple[str, int, int], None]:
    """Gets request async."""
    input_requests = iter(input_requests)
    for request in input_requests:
        yield request


@dataclass
class RequestFuncInput:
    """Input to the request function.

    Attributes:
        backend: Backend to benchmark.
        api_url: The API URL to send the request to.
        prompt: The prompt to send to the model.
        prompt_len: The length of the prompt.
        output_len: Expected output length.
        enable_retry: Whether to enable retry on failure.
        model: Model name.
        extra_body: Extra body to send in the request.
        max_context_length: Maximum context length.
    """

    backend: str = ""
    api_url: str = ""
    prompt: str = ""
    prompt_len: int = 0
    output_len: int = 0
    enable_retry: bool = False
    model: str = ""
    extra_body: str | dict[str, Any] | None = None
    max_context_length: Optional[int] = None


@dataclass
class RequestFuncOutput:
    """Output of the request function.

    Attributes:
        backend: Backend to benchmark.
        model: Model name.
        generated_text: Generated text in case of non-servomatic.
        generated_token_ids: List of generated token ids in case of servomatic and
            evergreen.
        success: Whether the request was successful.
        start_time: Timestamp when the request was sent.
        latency: Total request latency in seconds.
        prompt_len: Input prompt length in tokens.
        error: Error message, if any.
        ttft: Time to first token in seconds.
        itl: Inter-token latencies in seconds.
        requested_output_len: Requested output length in tokens.
    """

    backend: str = ""
    model: str = ""
    generated_text: str = ""
    generated_token_ids: Optional[list[int]] = None
    success: bool = False
    start_time: float = 0.0
    latency: float = 0.0
    prompt_len: int = 0
    error: str = ""
    ttft: Optional[float] = None  # Time to first token
    itl: list[float] = field(
        default_factory=list
    )  # List of inter-token latencies
    requested_output_len: Optional[int] = None


def get_api_key() -> str:
    """Gets the API key from the OPENAI_API_KEY or API_KEY environment variable."""
    api_key = os.environ.get("OPENAI_API_KEY", os.environ.get("API_KEY", ""))
    return api_key


def create_retry_predicate(enable_retry: bool):
    """Create a retry gate."""

    def retry_if_status_is_429(retry_state: RetryCallState) -> bool:
        """Retry if the status is 429."""
        assert retry_state.outcome is not None

        if not enable_retry:
            return False
        exception = retry_state.outcome.exception()
        return (
            isinstance(exception, aiohttp.ClientResponseError)
            and exception.status == 429  # pytype: disable=attribute-error
        )

    return retry_if_status_is_429
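

# Illustrative sketch, not called by the benchmark: a generic capped exponential
# backoff schedule of the kind used for 429 retries. This is plain Python, not
# tenacity's implementation, and the exact delays tenacity produces may differ
# by version. The helper name and its defaults are assumptions for illustration.
def _sketch_backoff_delays(
    attempts: int, base: float = 2.0, minimum: float = 2.0, maximum: float = 1000.0
) -> list[float]:
    """Returns the delay in seconds to wait before each retry attempt."""
    # Grow the delay geometrically, then clamp it into [minimum, maximum].
    return [min(max(base**i, minimum), maximum) for i in range(attempts)]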


async def make_chat_completions_request(
    session: aiohttp.ClientSession,
    headers: dict[str, str],
    request_input: RequestFuncInput,
    payload: dict[str, Any],
    stream: Optional[bool] = True,
    ttft: float = 0.0,
    most_recent_timestamp: float = 0.0,
    generated_text: str = "",
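    output: Optional[RequestFuncOutput] = None,
) -> RequestFuncOutput:
    """Make a chat completions request."""
    # Avoid a shared mutable default: create a fresh output per call if needed.
    if output is None:
        output = RequestFuncOutput()
    st = time.perf_counter()  # Reset st for each retry.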
    async with session.post(
        url=request_input.api_url, json=payload, headers=headers
    ) as response:
        if response.status == 200:
            output.success = True
            async for chunk_bytes in response.content:
                chunk_bytes = chunk_bytes.strip()
                if not chunk_bytes:
                    continue

                chunk = chunk_bytes.decode("utf-8").removeprefix("data:").strip()
                logging.debug("chunk: %s", chunk)
                if chunk != "[DONE]":
                    try:
                        data = json.loads(chunk)
                    except json.decoder.JSONDecodeError:
                        logging.error("Failed to parse response chunk: %s", chunk)
                        output.success = False
                        continue
                    timestamp = time.perf_counter()
                    if "choices" not in data or not data["choices"]:
                        logging.info("empty chunk: %s", chunk)
                        continue
                    if stream:
                        if "delta" not in data["choices"][0]:
                            logging.info("empty delta in chunk: %s", chunk)
                            continue
                        delta = data["choices"][0]["delta"]
                        tag = "content"
                        if not delta.get("content", None):
                            tag = "reasoning_content"
                        if delta.get(tag, None):
                            # First token
                            if ttft == 0.0:
                                ttft = time.perf_counter() - st
                                output.ttft = ttft

                            # Decoding phase
                            else:
                                output.itl.append(timestamp - most_recent_timestamp)
                            generated_text += delta[tag]
                    else:
                        assert not generated_text
                        if "message" not in data["choices"][0]:
                            logging.info("empty message in chunk: %s", chunk)
                            continue
                        if "content" not in data["choices"][0]["message"]:
                            logging.info("empty message.content in chunk: %s", chunk)
                            continue
                        generated_text = data["choices"][0]["message"]["content"]

                    most_recent_timestamp = timestamp

            if not generated_text:
                logging.error("Received empty response")
                output.success = False
            output.generated_text = generated_text
            output.latency = time.perf_counter() - st
        else:
            if response.content_type == "application/json":
                try:
                    response_json = await response.json()
                    logging.error(
                        "Error from Server (JSON):\n"
                        f"{json.dumps(response_json, indent=2)}"
                    )
                except aiohttp.ContentTypeError:
                    logging.error("Response body expected JSON but failed to parse.")
                    logging.error(f"Raw response text: {await response.text()}")
            else:
                logging.error(
                    f"Response Content-Type is {response.content_type}. Reading"
                    " as text."
                )
                logging.error(f"Raw response text: {await response.text()}")
            response.raise_for_status()

    return output
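

# Illustrative sketch, not called by the benchmark: extracting the text delta
# from one streamed chunk, following the same "data:" prefix and "[DONE]"
# sentinel handling as `make_chat_completions_request`. The helper name is
# hypothetical, and the chunk layout is assumed to match the fields read above.
def _sketch_parse_stream_chunk(chunk_bytes: bytes) -> str | None:
    """Returns the text carried by a streamed chunk, or None if it has none."""
    import json

    chunk = chunk_bytes.strip().decode("utf-8").removeprefix("data:").strip()
    # Blank keep-alive lines and the end-of-stream sentinel carry no text.
    if not chunk or chunk == "[DONE]":
        return None
    data = json.loads(chunk)
    choices = data.get("choices") or []
    if not choices:
        return None
    delta = choices[0].get("delta") or {}
    # Prefer regular content, falling back to reasoning content when present.
    return delta.get("content") or delta.get("reasoning_content") or None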


async def send_chat_completions_request(
    request_input: RequestFuncInput,
    sem: asyncio.Semaphore,
    pbar: Optional[tqdm] = None,
    stream: Optional[bool] = True,
    ignore_eos: bool = True,
) -> RequestFuncOutput:
    """Sends a request to an OpenAI-compatible Chat Completions API.

    The request is streamed unless `stream` is explicitly set to False.
    """
    assert request_input.api_url.endswith(
        "chat/completions"
    ), "OpenAI Chat Completions API URL must end with 'chat/completions'."

    if stream is None:
        stream = True  # defaults to True

    async with sem:
        async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
            content = request_input.prompt

            payload = {
                "model": request_input.model,
                "messages": [
                    {
                        "role": "user",
                        "content": content,
                    },
                ],
                "temperature": 0.0,
                "max_tokens": request_input.output_len,
                "stream": stream,
                "ignore_eos": ignore_eos,
            }

            output = RequestFuncOutput()
            output.backend = request_input.backend
            output.model = request_input.model
            output.prompt_len = request_input.prompt_len
            output.requested_output_len = request_input.output_len

            if request_input.extra_body:
                payload["extra_body"] = request_input.extra_body
            api_key = get_api_key()
            headers = {
                "Content-Type": "application/json",
            }

            if api_key:
                headers["Authorization"] = f"Bearer {api_key}"

            generated_text = ""
            ttft = 0.0
            st = time.perf_counter()
            most_recent_timestamp = st
            output.start_time = time.time()
            try:
                logging.debug("request: %s", json.dumps(payload, indent=2))
                retry_decorator = retry(
                    stop=stop_after_attempt(8),
                    wait=wait_exponential(
                        multiplier=1, min=2, max=1000
                    ),  # Wait 2s, then 4s, 8s, ...
                    retry=create_retry_predicate(request_input.enable_retry),
                )
                output = await retry_decorator(make_chat_completions_request)(
                    session,
                    headers,
                    request_input,
                    payload,
                    stream,
                    ttft,
                    most_recent_timestamp,
                    generated_text,
                    output,
                )
            except Exception:  # pylint: disable=broad-except
                output.success = False
                exc_info = sys.exc_info()
                output.error = "".join(traceback.format_exception(*exc_info))
                logging.warning(output.error)

            if pbar:
                pbar.update(1)
            return output
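

# Illustrative sketch, not called by the benchmark: how a shared
# `asyncio.Semaphore` caps the number of requests in flight, as `sem` does in
# `send_chat_completions_request`. The helper name is hypothetical and the
# workers only yield control instead of sending real requests.
def _sketch_max_in_flight(num_tasks: int, limit: int) -> int:
    """Returns the peak number of workers that held the semaphore at once."""
    import asyncio

    async def run() -> int:
        sem = asyncio.Semaphore(limit)
        in_flight = 0
        peak = 0

        async def worker() -> None:
            nonlocal in_flight, peak
            async with sem:
                in_flight += 1
                peak = max(peak, in_flight)
                # Yield so other workers get a chance to enter the semaphore.
                await asyncio.sleep(0)
                in_flight -= 1

        await asyncio.gather(*(worker() for _ in range(num_tasks)))
        return peak

    return asyncio.run(run())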


@dataclass
class BenchmarkMetrics:
    """Aggregated metrics for a benchmark run."""

    requested: int
    completed: int
    total_input: int
    total_output: int
    request_throughput: float
    input_throughput: float
    output_throughput: float
    mean_ttft_ms: Optional[float]
    median_ttft_ms: Optional[float]
    p99_ttft_ms: Optional[float]
    mean_tpot_ms: Optional[float]
    median_tpot_ms: Optional[float]
    p99_tpot_ms: Optional[float]
    mean_latency_ms: Optional[float]
    median_latency_ms: Optional[float]
    p99_latency_ms: Optional[float]
    accept_length: Optional[float]


def calculate_metrics(
    outputs: list[RequestFuncOutput],
    duration_sec: float,
    tokenizer: BaseTokenizer,
) -> tuple[BenchmarkMetrics, pd.DataFrame]:
    """Calculates the aggregated metrics for a benchmark run.

    Args:
        outputs: Benchmark outputs.
        duration_sec: Duration of the benchmark run.
        tokenizer: Tokenizer used for the benchmark.

    Returns:
        A tuple of the aggregated BenchmarkMetrics and a dataframe with the
        detailed per-request benchmark results.
    """
    actual_output_lens = []
    total_input = 0
    completed = 0
    results = []
    tpots = []
    ttfts = []
    latencies = []
    accept_lens = []

    for i in range(len(outputs)):
        dt = dataclasses.asdict(outputs[i])
        if outputs[i].success:
            dt.pop("generated_text")
            dt.pop("generated_token_ids")

            if outputs[i].generated_token_ids:
                output_len = len(outputs[i].generated_token_ids)
            else:
                output_len = len(tokenizer.encode(outputs[i].generated_text))
            if output_len != outputs[i].requested_output_len:
                logging.debug(
                    "Output length mismatch: requested len: %d vs actual len:%d",
                    outputs[i].requested_output_len,
                    output_len,
                )
            if "itl" in dt and len(dt["itl"]) != output_len and dt["itl"]:
                accept_lens.append(output_len / len(dt["itl"]))

            if outputs[i].backend == "vllm" or outputs[i].backend == "vllm_stream":
                output_len -= outputs[i].prompt_len
            dt["output_len"] = output_len
            actual_output_lens.append(output_len)
            total_input += outputs[i].prompt_len
            completed += 1
            latencies.append(outputs[i].latency)
            if outputs[i].ttft:
                if output_len > 1:
                    tpots.append(
                        (outputs[i].latency - outputs[i].ttft) / (output_len - 1)
                    )
                ttfts.append(outputs[i].ttft)
        else:
            dt["output_len"] = 0
            actual_output_lens.append(0)
        results.append(dt)

    metrics = BenchmarkMetrics(
        # number of requested requests
        requested=len(outputs),
        # number of successful requests
        completed=completed,
        # sum of input prompts length
        total_input=total_input,
        # sum of output length
        total_output=sum(actual_output_lens),
        # throughput requests / sec
        request_throughput=completed / duration_sec,
        # input throughput input tokens / sec
        input_throughput=total_input / duration_sec,
        # output throughput output tokens / sec
        output_throughput=sum(actual_output_lens) / duration_sec,
        mean_ttft_ms=np.mean(ttfts) * 1000 if ttfts else None,
        median_ttft_ms=np.median(ttfts) * 1000 if ttfts else None,
        p99_ttft_ms=np.percentile(ttfts, 99) * 1000 if ttfts else None,
        mean_tpot_ms=np.mean(tpots) * 1000 if tpots else None,
        median_tpot_ms=np.median(tpots) * 1000 if tpots else None,
        p99_tpot_ms=np.percentile(tpots, 99) * 1000 if tpots else None,
        mean_latency_ms=np.mean(latencies) * 1000 if latencies else None,
        median_latency_ms=np.median(latencies) * 1000 if latencies else None,
        p99_latency_ms=np.percentile(latencies, 99) * 1000 if latencies else None,
        accept_length=np.mean(accept_lens) if accept_lens else None,
    )

    return metrics, pd.DataFrame.from_dict(results)  # pytype: disable=wrong-arg-types
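

# Illustrative sketch, not called by the benchmark: the time-per-output-token
# (TPOT) formula used in `calculate_metrics`, isolated for a single request.
# The helper name is hypothetical; inputs are in seconds, the result is in ms.
def _sketch_tpot_ms(
    latency_sec: float, ttft_sec: float, output_len: int
) -> float | None:
    """Returns TPOT in milliseconds, excluding the first token."""
    # With one token or fewer there is no decoding interval to average over.
    if output_len <= 1:
        return None
    # Spread the time after the first token over the remaining tokens.
    return (latency_sec - ttft_sec) / (output_len - 1) * 1000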


async def benchmark(
    args: argparse.Namespace,
    api_urls: list[str],
    input_requests: list[tuple[str, int, int]],
    tokenizer: BaseTokenizer,
    prefix: str,
    max_input: int,
    max_output: int,
    concurrent_requests: Optional[int] = None,
):
    """Runs benchmark with asynchronous requests."""
    print(
        f"Running benchmark for {args.backend}, max input: {max_input}, max"
        f" output: {max_output}, concurrent requests: {concurrent_requests},"
        f" request rate: {args.request_rate}, fixed qps: {args.fixed_qps}"
    )

    tasks: list[asyncio.Task] = []
    pbar = tqdm(total=len(input_requests))

    benchmark_start_time = time.perf_counter()
    start_time = datetime.now().strftime("%d/%m/%Y %H:%M:%S")
    sem = (
        asyncio.Semaphore(concurrent_requests)
        if concurrent_requests
        else asyncio.Semaphore(len(input_requests))
    )
    async for request in get_request(
        input_requests,
    ):
        prompt, prompt_len, output_len = request
        request_extra_body = None
        if args.request_extra_body is not None:
            try:
                request_extra_body = json.loads(args.request_extra_body)
            except json.decoder.JSONDecodeError:
                request_extra_body = args.request_extra_body
        api_url = random.choice(api_urls)
        logging.debug("api url: %s", api_url)
        request_input = RequestFuncInput(
            backend=args.backend,
            api_url=api_url,
            prompt=prompt,
            prompt_len=prompt_len,
            output_len=output_len,
            enable_retry=args.enable_retry,
            model=args.model,
            extra_body=request_extra_body,
            max_context_length=args.max_context_length,
        )
        if args.backend == "chat_completions":
            request_func = send_chat_completions_request
        else:
            raise ValueError(f"Unsupported backend: {args.backend}")
        task = asyncio.create_task(
            request_func(
                request_input,
                sem,
                pbar,
                args.stream,
                args.fixed_output_length,
            )
        )
        if args.fixed_qps is not None:
            # await here would force task to start when running in fixed_qps mode.
            await asyncio.sleep(1.0 / args.fixed_qps)
        tasks.append(task)
    outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
    duration_sec = time.perf_counter() - benchmark_start_time
    end_time = datetime.now().strftime("%d/%m/%Y %H:%M:%S")

    if pbar is not None:
        pbar.close()

    metrics, full_results = calculate_metrics(outputs, duration_sec, tokenizer)
    if concurrent_requests:
        full_results = full_results.assign(
            concurrent_requests=concurrent_requests,
        )
    elif args.fixed_qps:
        full_results = full_results.assign(
            fixed_qps=args.fixed_qps,
        )
    else:
        full_results = full_results.assign(
            request_rate=args.request_rate,
        )

    full_results = full_results.assign(
        max_input=max_input,
        max_output=max_output,
    )

    if args.save_full_results:
        full_results_path = os.path.join(
            args.output_dir if args.output_dir else os.getcwd(),
            f"{prefix}_full_results.json",
        )
        # Use a context manager so the file is closed even if the write fails.
        with open(full_results_path, mode="a") as f:
            f.write(full_results.to_json(orient="records", lines=True))

    print("{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
    print("{:<40} {:<10}".format("Total requests:", metrics.requested))
    print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
    print("{:<40} {:<10.2f}".format("Benchmark duration (s):", duration_sec))
    print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
    print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
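    # Guard against division by zero when no request succeeded.
    completed_or_one = max(metrics.completed, 1)
    print(
        "{:<40} {:<10}".format(
            "Average input length:", metrics.total_input / completed_or_one
        )
    )
    print(
        "{:<40} {:<10}".format(
            "Average output length:", metrics.total_output / completed_or_one
        )
    )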
    print(
        "{:<40} {:<10.3f}".format(
            "Request throughput (req/s):", metrics.request_throughput
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Input token throughput (tok/s):", metrics.input_throughput
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Output token throughput (tok/s):", metrics.output_throughput
        )
    )
    if metrics.mean_ttft_ms:
        print("{s:{c}^{n}}".format(s="Time to First Token", n=50, c="-"))
        print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
        print(
            "{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms)
        )
        print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
        if metrics.mean_tpot_ms:
            print(
                "{s:{c}^{n}}".format(
                    s="Time per Output Token (excl. 1st token)", n=50, c="-"
                )
            )
            print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
            print(
                "{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms)
            )
            print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
    if metrics.mean_latency_ms:
        print("{s:{c}^{n}}".format(s="Latencies", n=50, c="-"))
        print(
            "{:<40} {:<10.2f}".format("Mean Latency (ms):", metrics.mean_latency_ms)
        )
        print(
            "{:<40} {:<10.2f}".format(
                "Median Latency (ms):", metrics.median_latency_ms
            )
        )
        print(
            "{:<40} {:<10.2f}".format("P99 Latency (ms):", metrics.p99_latency_ms)
        )
    if metrics.accept_length:
        print("{s:{c}^{n}}".format(s="Accept Length", n=50, c="-"))
        print(
            "{:<40} {:<10.2f}".format(
                "Mean Accept Length (tokens):", metrics.accept_length
            )
        )
    print("=" * 50)

    result = {
        "backend": args.backend,
        "start": start_time,
        "end": end_time,
        "duration": duration_sec,
        "completed": metrics.completed,
        "total_input_tokens": metrics.total_input,
        "total_output_tokens": metrics.total_output,
        "request_throughput": metrics.request_throughput,
        "input_throughput": metrics.input_throughput,
        "output_throughput": metrics.output_throughput,
        "mean_latency_ms": metrics.mean_latency_ms,
        "median_latency_ms": metrics.median_latency_ms,
        "p99_latency_ms": metrics.p99_latency_ms,
    }
    if metrics.mean_ttft_ms:
        result |= {
            "mean_ttft_ms": metrics.mean_ttft_ms,
            "median_ttft_ms": metrics.median_ttft_ms,
            "p99_ttft_ms": metrics.p99_ttft_ms,
            "mean_tpot_ms": metrics.mean_tpot_ms,
            "median_tpot_ms": metrics.median_tpot_ms,
            "p99_tpot_ms": metrics.p99_tpot_ms,
        }
    if metrics.accept_length:
        result |= {
            "accept_length": metrics.accept_length,
        }
    return result


def main(args: argparse.Namespace):
    random.seed(args.seed)
    np.random.seed(args.seed)

    log_levels = {
        "debug": logging.DEBUG,
        "info": logging.INFO,
        "warning": logging.WARNING,
        "error": logging.ERROR,
        "critical": logging.CRITICAL,
    }

    # Configure the logging
    logging.basicConfig(level=log_levels[args.verbosity])

    endpoint = args.endpoint
    if not args.endpoint:
        if args.backend == "chat_completions":
            endpoint = "v1/chat/completions"
        else:
            raise ValueError(f"Unsupported backend: {args.backend}")

    port_str = ":" + str(args.port) if args.port else ""
    protocol = "" if args.host.startswith("http") else "http://"
    base_api_url = f"{protocol}{args.host}{port_str}"
    api_url = f"{base_api_url}/{endpoint}"

    api_urls = []
    if args.endpoints:
        with open(args.endpoints, "r") as f:
            for line in f:
                extra_endpoint = line.strip()
                # Skip blank lines so they do not produce malformed URLs.
                if not extra_endpoint:
                    continue
                api_url = f"{base_api_url}/{extra_endpoint}"
                logging.debug("api url added to list: %s", api_url)
                api_urls.append(api_url)
    else:
        logging.debug("api url added to list: %s", api_url)
        api_urls.append(api_url)

    if args.tokenizer_type == "llama3":
        tokenizer = Llama3Tokenizer(args.tokenizer)
    else:
        tokenizer = GeneralTokenizer(args.tokenizer, args.trust_remote_code)

    prefix = args.name if args.name else args.backend
    fname = os.path.join(
        args.output_dir if args.output_dir else os.getcwd(),
        f"{prefix}_aggregated_results.json",
    )

    logging.info("preparing requests")
    for max_input in args.max_input_length:
        for max_output in args.max_output_length:
            if args.dataset.endswith("sonnet.txt"):
                min_input_len = int(max_input / 2)
                max_input_len = max_input + min_input_len
                min_output_len = int(max_output / 2)
                max_output_len = max_output + min_output_len
                input_requests = sample_sonnet_requests(
                    dataset_path=args.dataset,
                    num_requests=args.num_prompts,
                    min_input_len=min_input_len,
                    max_input_len=max_input_len,
                    min_output_len=min_output_len,
                    max_output_len=max_output_len,
                    prefix_len=args.sonnet_prefix_len,
                    tokenizer=tokenizer,
                    fixed_input_length=(max_input if args.fixed_input_length else None),
                    fixed_output_length=(
                        max_output if args.fixed_output_length else None
                    ),
                )
                if args.backend == "chat_completions":
                    input_requests = [
                        (prompt, prompt_len, output_len)
                        for prompt, _, prompt_len, _, output_len in input_requests
                    ]
                else:
                    raise ValueError(f"Unsupported backend: {args.backend}")
            else:
                raise ValueError(
                    f"Unsupported dataset: {args.dataset}. Expected sonnet.txt."
                )

            logging.info("starting benchmark")
            c_list = args.c
            if c_list is None:
                c_list = [None]
            for concurrent_requests in c_list:
                results = asyncio.run(
                    benchmark(
                        args,
                        api_urls,
                        input_requests,
                        tokenizer,
                        prefix,
                        max_input,
                        max_output,
                        concurrent_requests,
                    )
                )
                print(f"results: {results}")

                bm_configs = dict(vars(args).copy())
                bm_configs.pop("save_full_results")
                bm_configs.pop("c")
                bm_configs.pop("max_input_length")
                bm_configs.pop("max_output_length")
                bm_configs["max_input_len"] = max_input
                bm_configs["max_output_len"] = max_output
                if concurrent_requests is not None:
                    bm_configs["concurrent_requests"] = concurrent_requests
                    bm_configs.pop("request_rate")
                    bm_configs.pop("fixed_qps")
                results = results | bm_configs
                df = pd.DataFrame([results])
                with open(fname, mode="a") as f:
                    f.write(df.to_json(orient="records", lines=True))
    print(f"Saved results to {fname}")
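

# Illustrative sketch, not called by the benchmark: how `main` assembles the
# request URL from host, optional port, and endpoint path. The helper name is
# hypothetical; the logic mirrors the `protocol` and `port_str` handling above.
def _sketch_build_api_url(host: str, port: int | None, endpoint: str) -> str:
    """Joins host, optional port, and endpoint path into a request URL."""
    # Only prepend a scheme when the host does not already carry one.
    protocol = "" if host.startswith("http") else "http://"
    port_str = f":{port}" if port else ""
    return f"{protocol}{host}{port_str}/{endpoint}"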


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark the online serving throughput."
    )
    parser.add_argument(
        "--backend",
        type=str,
        default="chat_completions",
        choices=["chat_completions"],
    )
    parser.add_argument(
        "--model",
        type=str,
        default="",
        help="Model name to send request to at API server.",
    )
    parser.add_argument("--endpoint", type=str, default=None)
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=None)
    parser.add_argument("--dataset", type=str, help="Path to the dataset.")
    parser.add_argument(
        "--endpoints",
        type=str,
        default=None,
        help="Path to a file containing a list of endpoints.",
    )
    parser.add_argument(
        "--tokenizer",
        type=str,
        required=True,
        help="Name or path of the tokenizer.",
    )
    parser.add_argument(
        "--tokenizer-type",
        type=str,
        required=False,
        choices=[
            "general",
            "llama3",
        ],
        help=(
            "If provided, use the specified tokenizer type rather than relying on"
            " implicit logic."
        ),
    )
    parser.add_argument(
        "--stream",
        type=str2bool,
        default=None,
        help="Whether to use the streaming API.",
    )
    parser.add_argument(
        "--save-full-results",
        type=str2bool,
        default=False,
        help="Whether to save the full (per request) results.",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help=(
            "Directory for the output result file. Defaults to the current"
            " directory."
        ),
    )
    parser.add_argument(
        "--num-prompts",
        type=int,
        default=1000,
        help="Number of prompts to process.",
    )

    def _list_of_ints(arg: str) -> list[int]:
        return list(map(int, arg.split(",")))

    parser.add_argument(
        "--max-input-length",
        type=_list_of_ints,
        default=[1024],
        help=(
            "Maximum number of input tokens for filtering the benchmark dataset."
            " This argument can be a list of integers separated by ','."
        ),
    )
    parser.add_argument(
        "--fixed-input-length",
        type=str2bool,
        default=False,
        help="If true, force the input length to be --max-input-length.",
    )
    parser.add_argument(
        "--max-output-length",
        type=_list_of_ints,
        default=[1024],
        help=(
            "Maximum number of output tokens for filtering the benchmark dataset."
            " This argument can be a list of integers separated by ','."
        ),
    )
    parser.add_argument(
        "--fixed-output-length",
        type=str2bool,
        default=False,
        help="If true, force the output length to be --max-output-length.",
    )
    parser.add_argument(
        "--max-context-length",
        type=int,
        default=32768,
        help=(
            "The maximum context length for the model. Some serving dockers"
            " support overriding this value, such as Ollama."
        ),
    )
    parser.add_argument(
        "--sonnet-prefix-len",
        type=int,
        default=30,
        help="Number of prefix tokens per request, used only for sonnet dataset.",
    )
    parser.add_argument(
        "--c",
        "--concurrent-requests",
        type=_list_of_ints,
        default=None,
        help=(
            "The number of concurrent requests to send. This argument can be a"
            " list of integers separated by ','."
        ),
    )
    parser.add_argument(
        "--enable-retry",
        action="store_true",
        default=False,
        help="Whether to enable retry on retriable errors.",
    )
    parser.add_argument(
        "--request-rate",
        type=float,
        default=float("inf"),
        help=(
            "If this is inf, all requests are sent at time 0. Otherwise, we take"
            " 1 divided by this argument value to be the parameter of the Poisson"
            " distribution for modeling the request arrival times. Ignored if"
            " --concurrent-requests is set."
        ),
    )
    parser.add_argument(
        "--fixed-qps",
        type=float,
        help=(
            "Number of requests per second sent with equal intervals. If this"
            " argument is set, we ignore request_rate and use a fixed QPS for"
            " sending the requests. Ignored if --concurrent-requests is set."
        ),
    )
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument(
        "--trust-remote-code",
        action="store_true",
        help="Trust remote code from Hugging Face.",
    )
    parser.add_argument(
        "--name",
        type=str,
        default="",
        help=(
            "The name of the benchmark. Will be used as the prefix of the saved"
            " results files."
        ),
    )
    # pylint: disable=line-too-long
    parser.add_argument(
        "--request-extra-body",
        type=str,
        default="",
        help=(
            "Extra body to send with request. To disable LLamaGuard, set it to:"
            ' \'{"google": { "model_safety_settings": {"enabled": False,'
            ' "llama_guard_settings": {}}}}\''
        ),
    )
    parser.add_argument(
        "-v",
        "--verbosity",
        help="Set the logging level (default: %(default)s)",
        default="warning",
        choices=["debug", "info", "warning", "error", "critical"],
    )

    cmd_args = parser.parse_args()
    main(cmd_args)


### Benchmark Dataset

The benchmark client needs a dataset for constructing prompts. We will download the [sonnet.txt](https://github.com/vllm-project/vllm/blob/main/benchmarks/sonnet.txt) dataset directly from the official vLLM project repository.

In [None]:
# Download the dataset from the vLLM GitHub repository
!wget https://raw.githubusercontent.com/vllm-project/vllm/main/benchmarks/sonnet.txt

print("sonnet.txt downloaded successfully.")

## 3. Run Benchmark

Now we're ready to run the benchmark.

⚠️ **Important**: This step requires a local Docker environment and access to NVIDIA GPUs. If you are running this notebook on a machine without GPUs (like a standard Colab instance), this command will fail. You should run this on a **GPU-enabled environment**, such as a Vertex AI Workbench instance or a GCE VM equipped with GPUs.

### Why Different Setups Need Different Recipes 🧠

Before we run the command, it's crucial to understand *why* the optimal serving configuration isn't a one-size-fits-all solution. Finding the best "recipe" is a complex balancing act between a model's memory requirements, your hardware, and your expected user traffic. Let's break down the strategies you'll be testing.

#### Key Serving Strategies Explained

##### **Pipeline Parallelism (PP): The Assembly Line**
**Pipeline Parallelism** splits a model's layers into sequential stages, placing each stage on a different GPU.

* **Analogy 🚗:** Think of a car manufacturing assembly line. GPU 1 (Station 1) installs the engine (processes layers 1-16), then passes the car to GPU 2 (Station 2) to add the chassis (processes layers 17-32), and so on. A request "moves" from one GPU to the next until it's complete.
* **Use Case:** Useful for serving a model that doesn't fit onto a single GPU, especially for multi-host serving of large models. Its main drawbacks are potential GPU idle time, as later stages must wait for earlier ones to finish (known as the "pipeline bubble"), and potentially uneven allocation of layers across GPUs.

##### **Tensor Parallelism (TP): The Chef Team**
**Tensor Parallelism** splits individual layers—and the mathematical operations within them—across multiple GPUs.

* **Analogy 🍕:** Imagine several chefs (GPUs) working on a single, enormous pizza (a single model layer) at the same time. Each chef is responsible for a slice, and they must communicate constantly to ensure the toppings are distributed perfectly.
* **Use Case:** Useful for serving a model that doesn't fit onto a single GPU and can reduce inference latency by splitting large tensor computations into smaller components. It introduces communication overhead during per-layer computations.

##### **Model Replicas (vLLM Instances): Multiple Kitchens**
A **Model Replica** is a full, independent copy of the vLLM server instance. If a model can fit on a subset of your available GPUs, you can run multiple replicas to handle more users at once.

* **Analogy 👨‍🍳:** Instead of building one giant, complex kitchen (a single large model instance), you open several smaller, independent kitchens (replicas). Each kitchen can take a customer's order and fulfill it from start to finish without waiting for the others.
* **Use Case:** This is the primary strategy to maximize throughput.

✨ **Special Feature**: The Vertex AI Model Garden vLLM container used in this section allows you to run **multiple independent vLLM server instances (replicas) as one container**. This feature is specifically designed to maximize throughput for high-concurrency applications.

**Note**: The Model Garden vLLM container used in this tutorial builds upon vLLM's version: `vllm/vllm-openai@sha256:43892706699a4a390dab480e6a3b2f144203de11e0caebdbcb0c29ca1bce63c6`. It doesn't modify the kernel or engine implementation.

#### How Your Use Case Affects the Optimal Recipe

Now, let's understand how these strategies apply to different use cases:

1.  **Model Size**: Larger models require a minimum setting of **TP** and/or **PP** to fit the weights into memory. In addition, duplication of their model weights in model replicas creates larger overhead in memory usage. Smaller models can fit onto a single GPU, allowing more possible combinations of **PP**, **TP** and **Model Replica** settings, and duplication of their model weights creates less overhead.

2.  **Input and Output Length**: The input length and output length of requests affect the prefill and decode time it takes to process the requests, and the server's compute and memory usage patterns.

3.  **Concurrency**: The concurrency of requests, that is, the number of requests running or queued at the server, also affects the server's compute and memory usage patterns. It is a key factor in the server's latency and throughput tradeoff.
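
As a rough illustration of the model-size point above, the following sketch estimates the minimum number of GPUs needed just to hold a model's weights. It is not part of the benchmark utility. The function name, the nominal parameter counts, and the defaults (16-bit weights at 2 bytes per parameter, 80 GB per GPU, a usable memory fraction of 0.9) are assumptions for illustration. The estimate ignores KV cache and activation memory, so treat it as a lower bound.

```python
import math


def min_gpus_for_weights(
    num_params_billion: float,
    bytes_per_param: int = 2,  # assumption: 16-bit weights
    gpu_memory_gb: float = 80.0,  # assumption: 80 GB per GPU
    usable_fraction: float = 0.9,  # assumption: fraction of GPU memory available
) -> int:
    """Returns a lower bound on the GPUs needed to hold the weights of one replica."""
    # 1e9 parameters * bytes per parameter / 1e9 bytes per GB
    weights_gb = num_params_billion * bytes_per_param
    usable_gb = gpu_memory_gb * usable_fraction
    return math.ceil(weights_gb / usable_gb)


# Nominal sizes only; real parameter counts differ slightly.
for name, size_b in [("9B", 9), ("32B", 32), ("70B", 70)]:
    print(f"{name}: at least {min_gpus_for_weights(size_b)} GPU(s) per replica")
```

Under these assumptions, the nominal 9B and 32B models fit on one GPU, while the 70B model needs at least two GPUs per replica, which limits how many replicas an 8-GPU machine can host.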

By running the benchmark utility, you are empirically testing these trade-offs to find the sweet spot for *your specific model, hardware, and expected traffic*.

First, let's define the parameters for our benchmark run.

In [None]:
# The model we want to benchmark
MODEL_PATH = "/path/to/Llama-3.3-70B-Instruct"  # @param {type:"string"}

# The MG vLLM serving container supporting model replicas.
VLLM_DOCKER_URI = "us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/pytorch-vllm-serve:20250808_0916_RC01_maas"  # @param {type:"string"}
SERVER_INIT_TIMEOUT = 300  # @param {type:"integer"}

# The total number of GPUs available on the machine.
# This should match your hardware setup (e.g., 8 for an a3-highgpu-8g machine).
TOTAL_GPUS = 8  # @param {type:"integer"}

# Benchmark settings
INPUT_LENGTH = 1200  # @param {type:"integer"}
OUTPUT_LENGTH = 250  # @param {type:"integer"}
NUM_PROMPTS = 10  # @param {type:"integer"}
CONCURRENCIES = "1 8 64"  # @param {type:"string"}

# Latency requirements (optional)
# The utility flags runs that don't meet the latency requirements in the
# benchmark results. If SKIP_CONCURRENCIES_GIVEN_LATENCY is set to True, the
# utility skips the remaining runs with larger concurrencies if the current
# run doesn't satisfy any of the specified latency requirements.
MAX_MEDIAN_TTFT_MS = 1000  # @param {type:"number"}
MAX_MEDIAN_TPOT_MS = 200  # @param {type:"number"}
SKIP_CONCURRENCIES_GIVEN_LATENCY = False  # @param {type:"boolean"}
SKIP_CONCURRENCIES_GIVEN_LATENCY_ARG = (
    "--skip-concurrencies-given-latency" if SKIP_CONCURRENCIES_GIVEN_LATENCY else ""
)

# Output file paths
RESULTS_CSV_PATH = "benchmark_results.csv"
FIGURE_PATH = "benchmark_figure.png"

Next, construct and execute the command. The utility will run benchmarks iterating through all valid combinations of pipeline parallelism (`PP`), tensor parallelism (`TP`), and model replicas that can be formed with the `TOTAL_GPUS`.
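
To give a sense of the search space, the sketch below enumerates candidate recipes for a given GPU count. It assumes a recipe is valid when `PP * TP * replicas` equals the total number of GPUs. This rule and the function name `enumerate_recipes` are assumptions for illustration; the utility's own validity rules may differ (for example, it may also filter out recipes that cannot fit the model).

```python
def enumerate_recipes(total_gpus: int) -> list[tuple[int, int, int]]:
    """Lists (pp, tp, replicas) triples whose product equals total_gpus."""
    recipes = []
    for pp in range(1, total_gpus + 1):
        if total_gpus % pp:
            continue  # pp must divide the GPU count evenly
        remaining = total_gpus // pp
        for tp in range(1, remaining + 1):
            if remaining % tp:
                continue  # tp must divide the GPUs left after choosing pp
            recipes.append((pp, tp, remaining // tp))
    return recipes


for pp, tp, replicas in enumerate_recipes(8):
    print(f"PP={pp}, TP={tp}, Replicas={replicas}")
```

With 8 GPUs this sketch yields 10 candidate recipes, including `PP=1, TP=8, Replicas=1` and `PP=1, TP=4, Replicas=2`.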

In [None]:
!python benchmark_util.py \
  --total-gpus $TOTAL_GPUS \
  --input-length $INPUT_LENGTH \
  --output-length $OUTPUT_LENGTH \
  --num-prompts $NUM_PROMPTS \
  --concurrencies $CONCURRENCIES \
  --max-median-ttft-ms $MAX_MEDIAN_TTFT_MS \
  --max-median-tpot-ms $MAX_MEDIAN_TPOT_MS \
  --model $MODEL_PATH \
  --docker-uri $VLLM_DOCKER_URI \
  --server-init-timeout $SERVER_INIT_TIMEOUT \
  --benchmark-script-path benchmark_serving.py \
  --dataset-path sonnet.txt \
  --results-output-path $RESULTS_CSV_PATH \
  --figure-output-path $FIGURE_PATH \
  $SKIP_CONCURRENCIES_GIVEN_LATENCY_ARG

## 4. Review Reference Benchmark Results [Case Study] 

This section is a case study on reference benchmark results. It offers general recommendations and shows what benchmark results look like and how to interpret them.

**Note**: The reference benchmark results and recommendations shared in this section are specific to Vertex AI's offering of 8 x H100 VMs and a certain vLLM configuration and benchmark methodology. They are intended only as general guidance.

### Part 1: General Recommendations for 8 x H100 Setups

While the benchmark utility is the best way to find the precise optimal recipe for your specific needs, we would like to provide general recommendations on efficiently serving a model on 8 x H100 setups.

| Model Size | Sample Model | Recommended Recipe |
|------------|--------------|--------------------|
| Small | [google/gemma-2-9b-it](https://huggingface.co/google/gemma-2-9b-it) | At smaller concurrencies, `TP=8` gives the best performance. At larger concurrencies, multiple model replicas give the best performance. As concurrency increases, the optimal recipe shifts from `TP=8` to progressively smaller tensor parallelism (`TP=4`, `TP=2`, `TP=1`) with progressively more model replicas (2, 4, and 8). |
| Medium | [Qwen/Qwen2.5-32B-Instruct](https://huggingface.co/Qwen/Qwen2.5-32B-Instruct) | As with a small model, the optimal recipe goes from `TP=8` to more model replicas as concurrency increases. With a larger model, however, the concurrency thresholds at which more model replicas pay off are higher, so the optimal recipe has comparatively fewer model replicas. The reason is that copies of a larger model's weights incur more memory overhead. |
| Large | [meta-llama/Llama-3.3-70B-Instruct](https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct) | The same trend holds: as concurrency increases, the optimal recipe goes from `TP=8` to more model replicas. With a large model that doesn't fit onto one GPU, there are fewer possible combinations of serving strategies. In addition, we find empirically that tensor parallelism combined with model replicas offers better performance than pipeline parallelism. |

**Note**:
- The model size is denoted relative to the 8 x H100 VM.
- The recommendations apply specifically to Vertex AI's 8 x H100 VMs. Different infrastructures and accelerators can require different recipes. For instance, for accelerators without efficient cross-GPU communication, pipeline parallelism (PP) can perform more favorably.

Even though the hardware and models are specific, the underlying principles and trade-offs are broadly applicable. Reference these examples to build an intuition for how to approach your unique setup.

### Part 2: Scalability of Model Replicas (Server Instances)

The Vertex AI Model Garden vLLM container can co-host multiple model replicas within a single container by running multiple vLLM server instances. To demonstrate the scalability of this feature, we run a benchmark with H100 GPUs using [google/gemma-2-9b-it](https://huggingface.co/google/gemma-2-9b-it), a certain server configuration, and an approximate input length of 1200 and output length of 250.

| Number of Model Replicas | GPUs | Concurrency | Request Throughput (req/s) | Median Request Latency (ms) | Interpretation |
|--------------------------|------|-------------|----------------------------|-----------------------------|----------------|
| 1 | 1 x H100 | 2048 per GPU* | 11.245 | 180977 | A single model replica and vLLM server instance. Baseline. |
| 8 | 8 x H100 | 2048 per GPU* | 88.006 (**7.8x**) | 182453 (**1.0x**) | With 8 model replicas, implemented with 8 vLLM server instances, we obtain a near-linear improvement in request throughput, with no meaningful regression in request latency. |

\* A concurrency of 2048 per GPU is heavy traffic that saturates the server. This represents a maximum-throughput scenario which approaches offline inference.

The results show that at saturating traffic, the model replicas implementation can achieve **near-linear throughput scaling with virtually no latency overhead**.
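
The multipliers in the table can be reproduced from the raw numbers. The short sketch below computes the throughput speedup, the scaling efficiency (speedup divided by the increase in replica count), and the latency ratio. The dictionary layout and function name are illustrative; the values are copied from the table above.

```python
baseline = {"replicas": 1, "throughput": 11.245, "median_latency_ms": 180977}
scaled = {"replicas": 8, "throughput": 88.006, "median_latency_ms": 182453}


def scaling_summary(base: dict, other: dict) -> tuple[float, float, float]:
    """Returns (throughput speedup, scaling efficiency, latency ratio)."""
    speedup = other["throughput"] / base["throughput"]
    # Efficiency of 1.0 would mean perfectly linear scaling with replica count.
    efficiency = speedup / (other["replicas"] / base["replicas"])
    latency_ratio = other["median_latency_ms"] / base["median_latency_ms"]
    return speedup, efficiency, latency_ratio


speedup, efficiency, latency_ratio = scaling_summary(baseline, scaled)
print(f"Throughput speedup: {speedup:.2f}x")
print(f"Scaling efficiency: {efficiency:.1%}")
print(f"Latency ratio: {latency_ratio:.2f}x")
```

For these numbers the speedup is about 7.83x, which corresponds to roughly 98% scaling efficiency, and the median latency grows by less than 1%.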

### Part 3: Concurrency Crossover: The Optimal Recipe Changes

The optimal serving recipe depends on the traffic. A critical insight from benchmarking is understanding the "crossover point" where the optimal serving recipe changes.

To illustrate this, let's compare two configurations on an 8 x H100 serving [meta-llama/Llama-3.3-70B-Instruct](https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct), benchmarked with a certain server configuration and an approximate input length of 1200 and output length of 250:

| Setup | Concurrency | Request Throughput (req/s) | Median TTFT (ms) | Median TPOT (ms) |
|-------|-------------|----------------------------|----------------------|----------------|
| Setup A: `TP=8, Replicas=1` | 8 | 2.329 | 72.590 | 13.515 |
| Setup B: `TP=4, Replicas=2` | 8 | 1.760 | 100.157 | 17.866 |

The winner is setup A: `TP=8, Replicas=1`.

| Setup | Concurrency | Request Throughput (req/s) | Median TTFT (ms) | Median TPOT (ms) |
|-------|-------------|----------------------------|----------------------|----------------|
| Setup A: `TP=8, Replicas=1` | 256 | 13.534 | 206.216 | 73.741 |
| Setup B: `TP=4, Replicas=2` | 256 | 15.636 | 167.483 | 61.023 |

The winner is setup B: `TP=4, Replicas=2`.

**Key Takeaway**: There is no single best configuration--it depends on your traffic. Setup A, with a larger TP, is better for lower concurrency setups. As user traffic increases, the crossover happens. Setup B, with 2 server instances, is better for higher concurrency setups. This demonstrates the value of the benchmark utility and the importance of running benchmark experiments to **tailor the optimal serving recipe to your specific use case**.
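
The comparison above can be expressed as a small script that picks the highest-throughput setup at each concurrency and reports where the winner changes. This is a minimal sketch using only the four throughput values from the tables; the data layout and function names are illustrative. With only two measured concurrencies, it can say only that the crossover lies somewhere between them.

```python
# Request throughput (req/s) per setup, keyed by concurrency.
measurements = {
    8: {"TP=8, Replicas=1": 2.329, "TP=4, Replicas=2": 1.760},
    256: {"TP=8, Replicas=1": 13.534, "TP=4, Replicas=2": 15.636},
}


def best_setup_per_concurrency(data: dict) -> dict:
    """Maps each concurrency to the setup with the highest throughput."""
    return {
        concurrency: max(throughputs, key=throughputs.get)
        for concurrency, throughputs in sorted(data.items())
    }


def crossover_intervals(winners: dict) -> list[tuple[int, int]]:
    """Returns adjacent concurrency pairs between which the winner changes."""
    levels = list(winners)
    return [
        (low, high)
        for low, high in zip(levels, levels[1:])
        if winners[low] != winners[high]
    ]


winners = best_setup_per_concurrency(measurements)
for concurrency, setup in winners.items():
    print(f"Concurrency {concurrency}: {setup}")
print("Crossover between concurrencies:", crossover_intervals(winners))
```

Benchmarking more concurrency levels between 8 and 256 would narrow down where the crossover occurs for a given workload.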

## 5. Analyze Benchmark Results

Now, we will analyze the results from the benchmark run. This process involves three steps:

1. Visualize: Display the summary figure generated by the benchmark utility to get a high-level visual understanding of the performance trade-offs.
2. Filter: Load the raw data and filter out configurations that failed to meet any specified latency requirements.
3. Optimize: From the valid configurations, compile the optimal recipe for each concurrency in a table, and identify the optimal recipe for a target concurrency.

In [None]:
import os

import pandas as pd
from IPython.display import Image, display

# --- Configuration ---
RESULTS_CSV_PATH = "benchmark_results.csv"  # @param {type:"string"}
FIGURE_PATH = "benchmark_figure.png"  # @param {type:"string"}
TARGET_CONCURRENCY = 8  # @param {type:"integer"}

# --- 1. Visualize Results ---
print("--- 1. Visualizing Performance Chart ---")
if os.path.exists(FIGURE_PATH):
    print(f"Displaying benchmark summary from: {FIGURE_PATH}")
    display(Image(filename=FIGURE_PATH))
else:
    print(
        f"Warning: Figure file not found at '{FIGURE_PATH}'. Ensure the benchmark ran successfully."
    )

# --- 2. Load and Analyze Data ---
print("\n\n--- 2. Analyzing Optimal Configuration per Concurrency ---")
try:
    results_df = pd.read_csv(RESULTS_CSV_PATH)

    # Filter for configurations that meet all latency requirements
    valid_configs = results_df.copy()
    latency_checks = ["median_ttft_ok", "p99_ttft_ok", "median_tpot_ok", "p99_tpot_ok"]
    for check in latency_checks:
        if check in valid_configs.columns:
            valid_configs = valid_configs[valid_configs[check]]

    if not valid_configs.empty:
        # Find the best configuration for each concurrency level
        # Group by concurrency and find the index of the max throughput in each group
        optimal_indices = valid_configs.groupby("concurrent_requests")[
            "request_throughput"
        ].idxmax()
        optimal_per_concurrency = valid_configs.loc[optimal_indices]

        print("🏆 Optimal Configuration per Concurrency (Meeting Latency Goals) 🏆\n")

        display_cols = [
            "concurrent_requests",
            "pp_size",
            "tp_size",
            "model_replicas",
            "request_throughput",
            "median_ttft_ms",
            "median_tpot_ms",
        ]
        print(optimal_per_concurrency[display_cols].to_markdown(index=False))

        # --- 3. Find Best Configuration for Target Concurrency ---
        print(
            f"\n\n--- 3. Selecting Best Configuration for Target Concurrency: {TARGET_CONCURRENCY} ---"
        )

        target_config_df = optimal_per_concurrency[
            optimal_per_concurrency["concurrent_requests"] == TARGET_CONCURRENCY
        ]

        if not target_config_df.empty:
            best_overall_config = target_config_df.iloc[0]
            print(
                f"✅ Found optimal configuration for target concurrency {TARGET_CONCURRENCY}."
            )
        else:
            print(
                f"⚠️ Warning: No valid configuration found for target concurrency {TARGET_CONCURRENCY}."
            )
            best_overall_config = optimal_per_concurrency.loc[
                optimal_per_concurrency["request_throughput"].idxmax()
            ]
            print(
                f"Falling back to the configuration with the highest overall throughput at concurrency {int(best_overall_config['concurrent_requests'])}."
            )

        OPTIMAL_PP_SIZE = int(best_overall_config["pp_size"])
        OPTIMAL_TP_SIZE = int(best_overall_config["tp_size"])
        OPTIMAL_REPLICAS = int(best_overall_config["model_replicas"])

        print("=" * 50)
        print("🔧 Configuration for Deployment 🔧")
        print("=" * 50)
        print(f"Pipeline Parallel Size (PP): {OPTIMAL_PP_SIZE}")
        print(f"Tensor Parallel Size (TP): {OPTIMAL_TP_SIZE}")
        print(f"Model Replicas: {OPTIMAL_REPLICAS}")
        print("=" * 50)
    else:
        print("\nCould not find any configuration that met all latency requirements.")
        OPTIMAL_PP_SIZE = 1
        OPTIMAL_TP_SIZE = 8
        OPTIMAL_REPLICAS = 1
        print("Defaulting to PP=1, TP=8, Replicas=1 for deployment.")
except FileNotFoundError:
    print(
        f"Results file '{RESULTS_CSV_PATH}' not found. Defaulting to PP=1, TP=8, Replicas=1 for deployment."
    )
    OPTIMAL_PP_SIZE = 1
    OPTIMAL_TP_SIZE = 8
    OPTIMAL_REPLICAS = 1

## 6. Deploy to Vertex AI and Test the Endpoint

With the optimal recipe identified, we now deploy the model to a Vertex AI Endpoint on an 8 x H100 VM.

In [None]:
# Set model to deploy
base_model_name = "Llama-3.3-70B-Instruct"  # @param {type:"string"}
model_id = "meta-llama/Llama-3.3-70B-Instruct"  # @param {type:"string"}
HF_TOKEN = ""  # @param {type:"string"}
hf_model_id = model_id
publisher = "meta"
publisher_model_id = "llama3-3"

# Find Vertex AI prediction supported accelerators and regions at https://cloud.google.com/vertex-ai/docs/predictions/configure-compute.
accelerator_type = "NVIDIA_H100_80GB"
accelerator_count = 8
machine_type = "a3-highgpu-8g"
multihost_gpu_node_count = 1
gpu_memory_utilization = 0.9

In [None]:
# @title Deploy with customized configs

# @markdown Set use_dedicated_endpoint to False if you don't want to use a [dedicated endpoint](https://cloud.google.com/vertex-ai/docs/general/deployment#create-dedicated-endpoint). Note that [dedicated endpoints do not support VPC Service Controls](https://cloud.google.com/vertex-ai/docs/predictions/choose-endpoint-type); uncheck the box if you are using VPC-SC.
use_dedicated_endpoint = True  # @param {type:"boolean"}

# @markdown Choose whether to use a [Spot VM](https://cloud.google.com/compute/docs/instances/spot) for the deployment.
is_spot = False  # @param {type:"boolean"}

common_util.check_quota(
    project_id=PROJECT_ID,
    region=REGION,
    accelerator_type=accelerator_type,
    accelerator_count=int(accelerator_count * multihost_gpu_node_count),
    is_for_training=False,
    is_spot=is_spot,
)

# @markdown To enable the auto-scaling in deployment, you can set the following options:

min_replica_count = 1  # @param {type:"integer"}
max_replica_count = 1  # @param {type:"integer"}
required_replica_count = 1  # @param {type:"integer"}

# @markdown Set the target of GPU duty cycle or CPU usage between 1 and 100 for auto-scaling.
autoscale_by_gpu_duty_cycle_target = 0  # @param {type:"integer"}
autoscale_by_cpu_usage_target = 0  # @param {type:"integer"}

# @markdown Note: GPU duty cycle is not the most accurate metric for scaling workloads. More advanced auto-scaling metrics are coming soon. See [the public doc](https://cloud.google.com/vertex-ai/docs/reference/rest/v1/DedicatedResources#AutoscalingMetricSpec) for more details.


def deploy_model_vllm_single_model_cohost(
    model_name: str,
    model_id: str,
    publisher: str,
    publisher_model_id: str,
    service_account: str = None,
    base_model_id: str = None,
    machine_type: str = "a3-highgpu-8g",
    accelerator_type: str = "NVIDIA_H100_80GB",
    accelerator_count: int = 8,
    gpu_partition_size: str = "",
    multihost_gpu_node_count: int = 1,
    pipeline_parallel_size: int = 1,
    tensor_parallel_size: int = 8,
    model_replicas: int = 1,
    gpu_memory_utilization: float = 0.95,
    enable_trust_remote_code: bool = False,
    use_dedicated_endpoint: bool = False,
    min_replica_count: int = 1,
    max_replica_count: int = 1,
    required_replica_count: int = 1,
    autoscale_by_gpu_duty_cycle_target: int = 0,
    autoscale_by_cpu_usage_target: int = 0,
    is_spot: bool = True,
    model_cohost_feature: str = "single-model-cohost",
) -> Tuple[aiplatform.Model, aiplatform.Endpoint]:
    """Deploys models with vLLM to Vertex AI."""
    endpoint = aiplatform.Endpoint.create(
        display_name=f"{model_name}-endpoint",
        dedicated_endpoint_enabled=use_dedicated_endpoint,
    )

    if not base_model_id:
        base_model_id = model_id

    if model_replicas > 1:
        api_server = "vllm.entrypoints.nginx_server"
    else:
        api_server = "vllm.entrypoints.api_server"

    # See https://docs.vllm.ai/en/latest/models/engine_args.html for a list of possible arguments with descriptions.
    vllm_args = [
        "python",
        "-m",
        api_server,
        "--host=0.0.0.0",
        "--port=8080",
        f"--model={model_id}",
        f"--pipeline-parallel-size={pipeline_parallel_size}",
        f"--tensor-parallel-size={tensor_parallel_size}",
        "--data-parallel-size=1",
        "--swap-space=16",
    ]

    if multihost_gpu_node_count > 1:
        vllm_args = ["/vllm-workspace/ray_launcher.sh"] + vllm_args

    if gpu_memory_utilization:
        vllm_args.append(f"--gpu-memory-utilization={gpu_memory_utilization}")

    if enable_trust_remote_code:
        vllm_args.append("--trust-remote-code")

    if model_replicas > 1:
        vllm_args.extend(
            [
                f"--num_instances={model_replicas}",
                f"--total_gpus={accelerator_count}",
            ]
        )

    env_vars = {
        "MODEL_ID": base_model_id,
        "DEPLOY_SOURCE": "notebook",
    }

    # HF_TOKEN is not a compulsory field and may not be defined.
    try:
        if HF_TOKEN:
            env_vars["HF_TOKEN"] = HF_TOKEN
    except NameError:
        pass

    model = aiplatform.Model.upload(
        display_name=model_name,
        serving_container_image_uri=VLLM_DOCKER_URI,
        serving_container_args=vllm_args,
        serving_container_ports=[8080],
        serving_container_predict_route="/generate",
        serving_container_health_route="/ping",
        serving_container_environment_variables=env_vars,
        serving_container_shared_memory_size_mb=(16 * 1024),  # 16 GB
        serving_container_deployment_timeout=7200,
        model_garden_source_model_name=(
            f"publishers/{publisher}/models/{publisher_model_id}"
        ),
    )
    print(
        f"Deploying {model_name} on {multihost_gpu_node_count} host(s) of {machine_type} with {accelerator_type} GPU(s)."
    )

    creds, _ = auth.default()
    auth_req = auth.transport.requests.Request()
    creds.refresh(auth_req)

    url = f"https://{REGION}-aiplatform.googleapis.com/ui/projects/{PROJECT_ID}/locations/{REGION}/endpoints/{endpoint.name}:deployModel"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {creds.token}",
    }
    data = {
        "deployedModel": {
            "model": model.resource_name,
            "displayName": model_name,
            "dedicatedResources": {
                "machineSpec": {
                    "machineType": machine_type,
                    "multihostGpuNodeCount": multihost_gpu_node_count,
                    "acceleratorType": accelerator_type,
                    "acceleratorCount": accelerator_count,
                    "gpuPartitionSize": gpu_partition_size,
                },
                "minReplicaCount": min_replica_count,
                "requiredReplicaCount": required_replica_count,
                "maxReplicaCount": max_replica_count,
            },
            "system_labels": {
                "NOTEBOOK_NAME": "model_garden_model_cohost.ipynb",
                "NOTEBOOK_ENVIRONMENT": common_util.get_deploy_source(),
                "mg-serving-feature-model-cohost": model_cohost_feature,
            },
        },
    }
    if service_account:
        data["deployedModel"]["serviceAccount"] = service_account
    if is_spot:
        data["deployedModel"]["dedicatedResources"]["spot"] = True
    if autoscale_by_gpu_duty_cycle_target > 0 or autoscale_by_cpu_usage_target > 0:
        data["deployedModel"]["dedicatedResources"]["autoscalingMetricSpecs"] = []
        if autoscale_by_gpu_duty_cycle_target > 0:
            data["deployedModel"]["dedicatedResources"][
                "autoscalingMetricSpecs"
            ].append(
                {
                    "metricName": "aiplatform.googleapis.com/prediction/online/accelerator/duty_cycle",
                    "target": autoscale_by_gpu_duty_cycle_target,
                }
            )
        if autoscale_by_cpu_usage_target > 0:
            data["deployedModel"]["dedicatedResources"][
                "autoscalingMetricSpecs"
            ].append(
                {
                    "metricName": "aiplatform.googleapis.com/prediction/online/cpu/utilization",
                    "target": autoscale_by_cpu_usage_target,
                }
            )
    response = requests.post(url, headers=headers, json=data)
    print(f"Deploy Model response: {response.json()}")
    if response.status_code != 200 or "name" not in response.json():
        raise ValueError(f"Failed to deploy model: {response.text}")
    common_util.poll_and_wait(response.json()["name"], REGION, 7200)
    print("endpoint_name:", endpoint.name)

    return model, endpoint


(
    models["vllm_gpu_single_model"],
    endpoints["vllm_gpu_single_model"],
) = deploy_model_vllm_single_model_cohost(
    model_name=common_util.get_job_name_with_datetime(prefix="single-model-cohost"),
    model_id=model_id,
    publisher=publisher,
    publisher_model_id=publisher_model_id,
    base_model_id=hf_model_id,
    machine_type=machine_type,
    accelerator_type=accelerator_type,
    accelerator_count=accelerator_count,
    multihost_gpu_node_count=multihost_gpu_node_count,
    pipeline_parallel_size=OPTIMAL_PP_SIZE,
    tensor_parallel_size=OPTIMAL_TP_SIZE,
    model_replicas=OPTIMAL_REPLICAS,
    gpu_memory_utilization=gpu_memory_utilization,
    use_dedicated_endpoint=use_dedicated_endpoint,
    min_replica_count=min_replica_count,
    max_replica_count=max_replica_count,
    required_replica_count=required_replica_count,
    autoscale_by_gpu_duty_cycle_target=autoscale_by_gpu_duty_cycle_target,
    autoscale_by_cpu_usage_target=autoscale_by_cpu_usage_target,
    is_spot=is_spot,
    model_cohost_feature="single-model-cohost",
)
# @markdown Click "Show Code" to see more details.

In [None]:
# @title Raw predict

# @markdown Once deployment succeeds, you can send requests to the endpoint with text prompts. Sampling parameters supported by vLLM can be found [here](https://docs.vllm.ai/en/latest/dev/sampling_params.html).

# @markdown Example:

# @markdown ```
# @markdown Human: What is a car?
# @markdown Assistant:  A car, or a motor car, is a road-connected human-transportation system used to move people or goods from one place to another. The term also encompasses a wide range of vehicles, including motorboats, trains, and aircrafts. Cars typically have four wheels, a cabin for passengers, and an engine or motor. They have been around since the early 19th century and are now one of the most popular forms of transportation, used for daily commuting, shopping, and other purposes.
# @markdown ```
# @markdown Additionally, you can moderate the generated text with Vertex AI. See [Moderate text documentation](https://cloud.google.com/natural-language/docs/moderating-text) for more details.

# Loads an existing endpoint instance using the endpoint name:
# - Using `endpoint_name = endpoint.name` allows us to get the
#   endpoint name of the endpoint `endpoint` created in the cell
#   above.
# - Alternatively, you can set `endpoint_name = "1234567890123456789"` to load
#   an existing endpoint with the ID 1234567890123456789.
# You may uncomment the code below to load an existing endpoint.

# endpoint_name = ""  # @param {type:"string"}
# aip_endpoint_name = (
#     f"projects/{PROJECT_ID}/locations/{REGION}/endpoints/{endpoint_name}"
# )
# endpoint = aiplatform.Endpoint(aip_endpoint_name)

prompt = "What is a car?"  # @param {type: "string"}
# @markdown If you encounter an issue like `ServiceUnavailable: 503 Took too long to respond when processing`, you can reduce the maximum number of output tokens by lowering `max_tokens`.
max_tokens = 50  # @param {type:"integer"}
temperature = 1.0  # @param {type:"number"}
top_p = 1.0  # @param {type:"number"}
top_k = 1  # @param {type:"integer"}
# @markdown Set `raw_response` to `True` to obtain the raw model output. Set `raw_response` to `False` to apply additional formatting in the structure of `"Prompt:\n{prompt.strip()}\nOutput:\n{output}"`.
raw_response = False  # @param {type:"boolean"}

# Overrides parameters for inferences.
instances = [
    {
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "raw_response": raw_response,
    },
]
response = endpoints["vllm_gpu_single_model"].predict(
    instances=instances, use_dedicated_endpoint=use_dedicated_endpoint
)

for prediction in response.predictions:
    print(prediction)

# @markdown Click "Show Code" to see more details.

In [None]:
# @title Chat completion

if use_dedicated_endpoint:
    DEDICATED_ENDPOINT_DNS = endpoints[
        "vllm_gpu_single_model"
    ].gca_resource.dedicated_endpoint_dns
ENDPOINT_RESOURCE_NAME = endpoints["vllm_gpu_single_model"].resource_name

# @title Chat Completions Inference

# @markdown Once deployment succeeds, you can send requests to the endpoint using the OpenAI SDK.

# @markdown First you will need to install the SDK and some auth-related dependencies.

! pip install -qU openai google-auth requests

# @markdown Next fill out some request parameters:

user_message = "How is your day going?"  # @param {type: "string"}
# @markdown If you encounter an issue like `ServiceUnavailable: 503 Took too long to respond when processing`, you can reduce the maximum number of output tokens, for example by setting `max_tokens` to 20.
max_tokens = 50  # @param {type: "integer"}
temperature = 1.0  # @param {type: "number"}
stream = False  # @param {type: "boolean"}

# @markdown Now we can send a request.

import google.auth
import google.auth.transport.requests  # Needed for the Request() call below.
import openai

creds, project = google.auth.default()
auth_req = google.auth.transport.requests.Request()
creds.refresh(auth_req)

BASE_URL = (
    f"https://{REGION}-aiplatform.googleapis.com/v1beta1/{ENDPOINT_RESOURCE_NAME}"
)
try:
    if use_dedicated_endpoint:
        BASE_URL = f"https://{DEDICATED_ENDPOINT_DNS}/v1beta1/{ENDPOINT_RESOURCE_NAME}"
except NameError:
    pass

client = openai.OpenAI(base_url=BASE_URL, api_key=creds.token)

model_response = client.chat.completions.create(
    model="",
    messages=[{"role": "user", "content": user_message}],
    temperature=temperature,
    max_tokens=max_tokens,
    stream=stream,
)

if stream:
    usage = None
    contents = []
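    for chunk in model_response:
        if chunk.usage is not None:
            usage = chunk.usage
            continue
        # Skip chunks without choices or with an empty delta so "None" is not printed.
        if not chunk.choices or chunk.choices[0].delta.content is None:
            continue
        print(chunk.choices[0].delta.content, end="")
        contents.append(chunk.choices[0].delta.content)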
    print(f"\n\n{usage}")
else:
    print(model_response)

# @markdown Click "Show Code" to see more details.

## 7. [Alternative Solution: Pod Co-scheduling + MIG] Review Reference Benchmark Results

An alternative to the container-level solution is the infrastructure-level solution with pod co-scheduling + NVIDIA Multi-Instance GPU (MIG), which is available in [Preview](https://cloud.google.com/products?e=48754805&hl=en#product-launch-stages).

- Co-scheduling with whole GPUs: We can assign one or more full hardware accelerators to each model replica. For example, we can deploy up to eight replicas on an 8 x H100 VM.

- Partitioning with NVIDIA Multi-Instance GPU (MIG): For even greater efficiency with smaller workloads, we can partition a single physical GPU into multiple, smaller, fully isolated instances using NVIDIA MIG. This allows us to assign resources at a sub-GPU level, maximizing the utilization of each accelerator.
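
The whole-GPU case above reduces to simple integer arithmetic. The sketch below is illustrative only: `max_replicas_per_vm` is a hypothetical helper, not part of the Vertex AI SDK, and it assumes every replica is assigned the same number of whole GPUs.

```python
def max_replicas_per_vm(gpus_per_vm: int, gpus_per_replica: int) -> int:
    """Returns how many whole-GPU replicas fit on one VM (hypothetical helper)."""
    if gpus_per_vm <= 0 or gpus_per_replica <= 0:
        raise ValueError("GPU counts must be positive.")
    # Leftover GPUs that cannot hold a full replica are ignored.
    return gpus_per_vm // gpus_per_replica


# One GPU per replica on an 8 x H100 VM gives up to eight replicas.
print(max_replicas_per_vm(gpus_per_vm=8, gpus_per_replica=1))
# Four GPUs per replica leaves room for two replicas.
print(max_replicas_per_vm(gpus_per_vm=8, gpus_per_replica=4))
```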

### Comparable Performance

We ran a benchmark on 8 x H100 Vertex AI endpoints using [google/gemma-2-9b-it](https://huggingface.co/google/gemma-2-9b-it), with a specific server configuration, an approximate input length of 1200, and an approximate output length of 250.

| Solution | Number of Model Replicas | Concurrency | Request Throughput (req/s) | Median Request Latency (ms) |
|----------|--------------------------|-------------|----------------------------|-----------------------------|
| Model Co-hosting Container | 8 | 128 | 34.367 | 3512 |
| Pod Co-scheduling | 8 | 128 | 32.404 | 3656 |

The results show that the two solutions offer **comparable performance when serving multiple replicas of a model**.
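
To put "comparable" in numbers, the relative gaps can be computed directly from the table. This is a small arithmetic sketch using only the figures reported above; the variable and function names are illustrative.

```python
def relative_difference(value: float, baseline: float) -> float:
    """Returns (value - baseline) / baseline."""
    return (value - baseline) / baseline


# Figures copied from the benchmark table.
cohost = {"throughput_rps": 34.367, "median_latency_ms": 3512}
coschedule = {"throughput_rps": 32.404, "median_latency_ms": 3656}

throughput_gain = relative_difference(
    cohost["throughput_rps"], coschedule["throughput_rps"]
)
latency_increase = relative_difference(
    coschedule["median_latency_ms"], cohost["median_latency_ms"]
)

print(f"Co-hosting container throughput advantage: {throughput_gain:.1%}")
print(f"Pod co-scheduling median latency increase: {latency_increase:.1%}")
```

In this run the co-hosting container delivers roughly 6% more throughput, and pod co-scheduling shows roughly 4% higher median latency.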

## 8. [Alternative Solution: Pod Co-scheduling + MIG] Deploy to Vertex AI and Test the Endpoint

With the optimal recipe identified previously, we now deploy the model to a Vertex AI Endpoint with pod co-scheduling. Note that `accelerator_count` in the deployment request sets the accelerator count per replica, and each container (pod) serves one replica.

When MIG is enabled, GPU sharing is unavailable: each replica is limited to a MIG partition within a single GPU. Consequently, when `gpu_partition_size` is specified, `accelerator_count` must be set to 1. An example `gpu_partition_size` value is `"1g.10gb"`.
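
The constraint above can be checked before sending a deployment request. The following is a minimal sketch: `validate_mig_config` is a hypothetical helper, not part of the Vertex AI SDK, and it encodes only the rule stated in this section.

```python
def validate_mig_config(accelerator_count: int, gpu_partition_size: str = "") -> None:
    """Raises ValueError if the MIG rule described above is violated.

    Hypothetical helper: a non-empty `gpu_partition_size` requires
    `accelerator_count == 1`.
    """
    if accelerator_count < 1:
        raise ValueError("accelerator_count must be at least 1.")
    if gpu_partition_size and accelerator_count != 1:
        raise ValueError(
            "accelerator_count must be 1 when gpu_partition_size is set "
            f"(got accelerator_count={accelerator_count}, "
            f"gpu_partition_size={gpu_partition_size!r})."
        )


# Whole-GPU co-scheduling: no partition size, several GPUs per replica.
validate_mig_config(accelerator_count=4)
# MIG: one partition of a single GPU per replica.
validate_mig_config(accelerator_count=1, gpu_partition_size="1g.10gb")

try:
    validate_mig_config(accelerator_count=4, gpu_partition_size="1g.10gb")
except ValueError as err:
    print(f"Rejected: {err}")
```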

In [None]:
# Set model to deploy
base_model_name = "Llama-3.3-70B-Instruct"  # @param {type:"string"}
model_id = "meta-llama/Llama-3.3-70B-Instruct"  # @param {type:"string"}
HF_TOKEN = ""  # @param {type:"string"}
hf_model_id = model_id
publisher = "meta"
publisher_model_id = "llama3-3"

# Find Vertex AI prediction supported accelerators and regions at https://cloud.google.com/vertex-ai/docs/predictions/configure-compute.
accelerator_type = "NVIDIA_H100_80GB"
accelerator_count = 8
accelerator_count_per_replica = 4
gpu_partition_size = ""
machine_type = "a3-highgpu-8g"
multihost_gpu_node_count = 1
pipeline_parallel_size = 1
tensor_parallel_size = accelerator_count_per_replica
model_replicas = 1
gpu_memory_utilization = 0.9

In [None]:
# @title Deploy with customized configs

# @markdown Set `use_dedicated_endpoint` to False if you don't want to use a [dedicated endpoint](https://cloud.google.com/vertex-ai/docs/general/deployment#create-dedicated-endpoint). Note that [dedicated endpoints do not support VPC Service Controls](https://cloud.google.com/vertex-ai/docs/predictions/choose-endpoint-type); uncheck the box if you are using VPC-SC.
use_dedicated_endpoint = True  # @param {type:"boolean"}

# @markdown Choose whether to use a [Spot VM](https://cloud.google.com/compute/docs/instances/spot) for the deployment.
is_spot = False  # @param {type:"boolean"}

common_util.check_quota(
    project_id=PROJECT_ID,
    region=REGION,
    accelerator_type=accelerator_type,
    accelerator_count=int(accelerator_count * multihost_gpu_node_count),
    is_for_training=False,
    is_spot=is_spot,
)

# @markdown To enable auto-scaling for the deployment, set the following options:

min_replica_count = 1  # @param {type:"integer"}
max_replica_count = 1  # @param {type:"integer"}
required_replica_count = 1  # @param {type:"integer"}

# @markdown Set the target of GPU duty cycle or CPU usage between 1 and 100 for auto-scaling.
autoscale_by_gpu_duty_cycle_target = 0  # @param {type:"integer"}
autoscale_by_cpu_usage_target = 0  # @param {type:"integer"}

# @markdown Note: GPU duty cycle is not the most accurate metric for scaling workloads. More advanced auto-scaling metrics are coming soon. See [the public doc](https://cloud.google.com/vertex-ai/docs/reference/rest/v1/DedicatedResources#AutoscalingMetricSpec) for more details.


def deploy_model_vllm_single_model_cohost(
    model_name: str,
    model_id: str,
    publisher: str,
    publisher_model_id: str,
    service_account: str = None,
    base_model_id: str = None,
    machine_type: str = "a3-highgpu-8g",
    accelerator_type: str = "NVIDIA_H100_80GB",
    accelerator_count: int = 8,
    gpu_partition_size: str = "",
    multihost_gpu_node_count: int = 1,
    pipeline_parallel_size: int = 1,
    tensor_parallel_size: int = 8,
    model_replicas: int = 1,
    gpu_memory_utilization: float = 0.95,
    enable_trust_remote_code: bool = False,
    use_dedicated_endpoint: bool = False,
    min_replica_count: int = 1,
    max_replica_count: int = 1,
    required_replica_count: int = 1,
    autoscale_by_gpu_duty_cycle_target: int = 0,
    autoscale_by_cpu_usage_target: int = 0,
    is_spot: bool = True,
    model_cohost_feature: str = "single-model-cohost",
) -> Tuple[aiplatform.Model, aiplatform.Endpoint]:
    """Deploys models with vLLM to Vertex AI."""
    endpoint = aiplatform.Endpoint.create(
        display_name=f"{model_name}-endpoint",
        dedicated_endpoint_enabled=use_dedicated_endpoint,
    )

    if not base_model_id:
        base_model_id = model_id

    if model_replicas > 1:
        api_server = "vllm.entrypoints.nginx_server"
    else:
        api_server = "vllm.entrypoints.api_server"

    # See https://docs.vllm.ai/en/latest/models/engine_args.html for a list of possible arguments with descriptions.
    vllm_args = [
        "python",
        "-m",
        api_server,
        "--host=0.0.0.0",
        "--port=8080",
        f"--model={model_id}",
        f"--pipeline-parallel-size={pipeline_parallel_size}",
        f"--tensor-parallel-size={tensor_parallel_size}",
        "--data-parallel-size=1",
        "--swap-space=16",
    ]

    if multihost_gpu_node_count > 1:
        vllm_args = ["/vllm-workspace/ray_launcher.sh"] + vllm_args

    if gpu_memory_utilization:
        vllm_args.append(f"--gpu-memory-utilization={gpu_memory_utilization}")

    if enable_trust_remote_code:
        vllm_args.append("--trust-remote-code")

    if model_replicas > 1:
        vllm_args.extend(
            [
                f"--num_instances={model_replicas}",
                f"--total_gpus={accelerator_count}",
            ]
        )

    env_vars = {
        "MODEL_ID": base_model_id,
        "DEPLOY_SOURCE": "notebook",
    }

    # HF_TOKEN is not a compulsory field and may not be defined.
    try:
        if HF_TOKEN:
            env_vars["HF_TOKEN"] = HF_TOKEN
    except NameError:
        pass

    model = aiplatform.Model.upload(
        display_name=model_name,
        serving_container_image_uri=VLLM_DOCKER_URI,
        serving_container_args=vllm_args,
        serving_container_ports=[8080],
        serving_container_predict_route="/generate",
        serving_container_health_route="/ping",
        serving_container_environment_variables=env_vars,
        serving_container_shared_memory_size_mb=(16 * 1024),  # 16 GB
        serving_container_deployment_timeout=7200,
        model_garden_source_model_name=(
            f"publishers/{publisher}/models/{publisher_model_id}"
        ),
    )
    print(
        f"Deploying {model_name} on {multihost_gpu_node_count} host(s) of {machine_type} with {accelerator_type} GPU(s)."
    )

    creds, _ = auth.default()
    auth_req = auth.transport.requests.Request()
    creds.refresh(auth_req)

    url = f"https://{REGION}-aiplatform.googleapis.com/ui/projects/{PROJECT_ID}/locations/{REGION}/endpoints/{endpoint.name}:deployModel"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {creds.token}",
    }
    data = {
        "deployedModel": {
            "model": model.resource_name,
            "displayName": model_name,
            "dedicatedResources": {
                "machineSpec": {
                    "machineType": machine_type,
                    "multihostGpuNodeCount": multihost_gpu_node_count,
                    "acceleratorType": accelerator_type,
                    "acceleratorCount": accelerator_count,
                    "gpuPartitionSize": gpu_partition_size,
                },
                "minReplicaCount": min_replica_count,
                "requiredReplicaCount": required_replica_count,
                "maxReplicaCount": max_replica_count,
            },
            "system_labels": {
                "NOTEBOOK_NAME": "model_garden_model_cohost.ipynb",
                "NOTEBOOK_ENVIRONMENT": common_util.get_deploy_source(),
                "mg-serving-feature-model-cohost": model_cohost_feature,
            },
        },
    }
    if service_account:
        data["deployedModel"]["serviceAccount"] = service_account
    if is_spot:
        data["deployedModel"]["dedicatedResources"]["spot"] = True
    if autoscale_by_gpu_duty_cycle_target > 0 or autoscale_by_cpu_usage_target > 0:
        data["deployedModel"]["dedicatedResources"]["autoscalingMetricSpecs"] = []
        if autoscale_by_gpu_duty_cycle_target > 0:
            data["deployedModel"]["dedicatedResources"][
                "autoscalingMetricSpecs"
            ].append(
                {
                    "metricName": "aiplatform.googleapis.com/prediction/online/accelerator/duty_cycle",
                    "target": autoscale_by_gpu_duty_cycle_target,
                }
            )
        if autoscale_by_cpu_usage_target > 0:
            data["deployedModel"]["dedicatedResources"][
                "autoscalingMetricSpecs"
            ].append(
                {
                    "metricName": "aiplatform.googleapis.com/prediction/online/cpu/utilization",
                    "target": autoscale_by_cpu_usage_target,
                }
            )
    response = requests.post(url, headers=headers, json=data)
    print(f"Deploy Model response: {response.json()}")
    if response.status_code != 200 or "name" not in response.json():
        raise ValueError(f"Failed to deploy model: {response.text}")
    common_util.poll_and_wait(response.json()["name"], REGION, 7200)
    print("endpoint_name:", endpoint.name)

    return model, endpoint


(
    models["vllm_gpu_pod_coschedule_mig"],
    endpoints["vllm_gpu_pod_coschedule_mig"],
) = deploy_model_vllm_single_model_cohost(
    model_name=common_util.get_job_name_with_datetime(
        prefix="single-model-cohost-pod-coschedule-mig"
    ),
    model_id=model_id,
    publisher=publisher,
    publisher_model_id=publisher_model_id,
    base_model_id=hf_model_id,
    machine_type=machine_type,
    accelerator_type=accelerator_type,
    accelerator_count=accelerator_count_per_replica,
    gpu_partition_size=gpu_partition_size,
    multihost_gpu_node_count=multihost_gpu_node_count,
    pipeline_parallel_size=pipeline_parallel_size,
    tensor_parallel_size=tensor_parallel_size,
    model_replicas=model_replicas,
    gpu_memory_utilization=gpu_memory_utilization,
    use_dedicated_endpoint=use_dedicated_endpoint,
    min_replica_count=min_replica_count,
    max_replica_count=max_replica_count,
    required_replica_count=required_replica_count,
    autoscale_by_gpu_duty_cycle_target=autoscale_by_gpu_duty_cycle_target,
    autoscale_by_cpu_usage_target=autoscale_by_cpu_usage_target,
    is_spot=is_spot,
    model_cohost_feature="single-model-cohost-pod-coschedule-mig",
)
# @markdown Click "Show Code" to see more details.

In [None]:
# @title Raw predict

# @markdown Once deployment succeeds, you can send requests to the endpoint with text prompts. Sampling parameters supported by vLLM can be found [here](https://docs.vllm.ai/en/latest/dev/sampling_params.html).

# @markdown Example:

# @markdown ```
# @markdown Human: What is a car?
# @markdown Assistant:  A car, or a motor car, is a road-connected human-transportation system used to move people or goods from one place to another. The term also encompasses a wide range of vehicles, including motorboats, trains, and aircrafts. Cars typically have four wheels, a cabin for passengers, and an engine or motor. They have been around since the early 19th century and are now one of the most popular forms of transportation, used for daily commuting, shopping, and other purposes.
# @markdown ```
# @markdown Additionally, you can moderate the generated text with Vertex AI. See [Moderate text documentation](https://cloud.google.com/natural-language/docs/moderating-text) for more details.

# Loads an existing endpoint instance using the endpoint name:
# - Using `endpoint_name = endpoint.name` allows us to get the
#   endpoint name of the endpoint `endpoint` created in the cell
#   above.
# - Alternatively, you can set `endpoint_name = "1234567890123456789"` to load
#   an existing endpoint with the ID 1234567890123456789.
# You may uncomment the code below to load an existing endpoint.

# endpoint_name = ""  # @param {type:"string"}
# aip_endpoint_name = (
#     f"projects/{PROJECT_ID}/locations/{REGION}/endpoints/{endpoint_name}"
# )
# endpoint = aiplatform.Endpoint(aip_endpoint_name)

prompt = "What is a car?"  # @param {type: "string"}
# @markdown If you encounter an issue like `ServiceUnavailable: 503 Took too long to respond when processing`, you can reduce the maximum number of output tokens by lowering `max_tokens`.
max_tokens = 50  # @param {type:"integer"}
temperature = 1.0  # @param {type:"number"}
top_p = 1.0  # @param {type:"number"}
top_k = 1  # @param {type:"integer"}
# @markdown Set `raw_response` to `True` to obtain the raw model output. Set `raw_response` to `False` to apply additional formatting in the structure of `"Prompt:\n{prompt.strip()}\nOutput:\n{output}"`.
raw_response = False  # @param {type:"boolean"}

# Overrides parameters for inferences.
instances = [
    {
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "raw_response": raw_response,
    },
]
response = endpoints["vllm_gpu_pod_coschedule_mig"].predict(
    instances=instances, use_dedicated_endpoint=use_dedicated_endpoint
)

for prediction in response.predictions:
    print(prediction)

# @markdown Click "Show Code" to see more details.

In [None]:
# @title Chat completion

if use_dedicated_endpoint:
    DEDICATED_ENDPOINT_DNS = endpoints[
        "vllm_gpu_pod_coschedule_mig"
    ].gca_resource.dedicated_endpoint_dns
ENDPOINT_RESOURCE_NAME = endpoints["vllm_gpu_pod_coschedule_mig"].resource_name

# @title Chat Completions Inference

# @markdown Once deployment succeeds, you can send requests to the endpoint using the OpenAI SDK.

# @markdown First you will need to install the SDK and some auth-related dependencies.

! pip install -qU openai google-auth requests

# @markdown Next fill out some request parameters:

user_message = "How is your day going?"  # @param {type: "string"}
# @markdown If you encounter an issue like `ServiceUnavailable: 503 Took too long to respond when processing`, you can reduce the maximum number of output tokens, for example by setting `max_tokens` to 20.
max_tokens = 50  # @param {type: "integer"}
temperature = 1.0  # @param {type: "number"}
stream = False  # @param {type: "boolean"}

# @markdown Now we can send a request.

import google.auth
import google.auth.transport.requests  # Needed for the Request() call below.
import openai

creds, project = google.auth.default()
auth_req = google.auth.transport.requests.Request()
creds.refresh(auth_req)

BASE_URL = (
    f"https://{REGION}-aiplatform.googleapis.com/v1beta1/{ENDPOINT_RESOURCE_NAME}"
)
try:
    if use_dedicated_endpoint:
        BASE_URL = f"https://{DEDICATED_ENDPOINT_DNS}/v1beta1/{ENDPOINT_RESOURCE_NAME}"
except NameError:
    pass

client = openai.OpenAI(base_url=BASE_URL, api_key=creds.token)

model_response = client.chat.completions.create(
    model="",
    messages=[{"role": "user", "content": user_message}],
    temperature=temperature,
    max_tokens=max_tokens,
    stream=stream,
)

if stream:
    usage = None
    contents = []
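    for chunk in model_response:
        if chunk.usage is not None:
            usage = chunk.usage
            continue
        # Skip chunks without choices or with an empty delta so "None" is not printed.
        if not chunk.choices or chunk.choices[0].delta.content is None:
            continue
        print(chunk.choices[0].delta.content, end="")
        contents.append(chunk.choices[0].delta.content)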
    print(f"\n\n{usage}")
else:
    print(model_response)

# @markdown Click "Show Code" to see more details.

## 9. Clean Up

To avoid incurring ongoing charges, it's important to clean up the resources you've created.

In [None]:
# @title Delete the models and endpoints

# @markdown  Delete the experiment models and endpoints to release the resources
# @markdown  and avoid incurring unnecessary ongoing charges.

# Undeploy model and delete endpoint.
for endpoint in endpoints.values():
    endpoint.delete(force=True)

# Delete models.
for model in models.values():
    model.delete()

## Multi-model Serving

## 1. Setup (Same as Before)

First, let's install the necessary packages and set up your Google Cloud project environment.

In [None]:
# @title Setup Google Cloud project

# @markdown 1. [Make sure that billing is enabled for your project](https://cloud.google.com/billing/docs/how-to/modify-project).

# @markdown 2. **[Optional]** Set region. If not set, the region will be set automatically according to the Colab Enterprise environment.

REGION = ""  # @param {type:"string"}

# @markdown 3. If you want to run predictions with H100 GPUs or H200 GPUs, we recommend using the regions listed below. **NOTE:** Make sure you have the associated quota in the selected regions. Click the links to see your current quota for H100s: [`CustomModelServingH100GPUsPerProjectPerRegion`](https://console.cloud.google.com/iam-admin/quotas?metric=aiplatform.googleapis.com%2Fcustom_model_serving_nvidia_h100_gpus) and H200s: [`CustomModelServingH200GPUsPerProjectPerRegion`](https://console.cloud.google.com/iam-admin/quotas?metric=aiplatform.googleapis.com%2Fcustom_model_serving_nvidia_h200_gpus). You can request quota by following the instructions at ["Request a higher quota"](https://cloud.google.com/docs/quota/view-manage#requesting_higher_quota).

# @markdown | Machine Type | Accelerator Type | Recommended Regions |
# @markdown | ----------- | ----------- | ----------- |
# @markdown | a3-highgpu-8g | 8 NVIDIA_H100_80GB | asia-southeast1, europe-west4, us-central1, us-east5, us-west1 |
# @markdown | a3-ultragpu-8g | 8 NVIDIA_H200_141GB | asia-south2, us-south1 |

# Upgrade Vertex AI SDK.
! pip3 install --upgrade --quiet 'google-cloud-aiplatform==1.103.0'
! pip3 install --upgrade --quiet aiohttp matplotlib pandas seaborn

# Import the necessary packages
import importlib  # noqa: F811
import os  # noqa: F811
from typing import Tuple  # noqa: F811

import requests  # noqa: F811
from google import auth  # noqa: F811
from google.cloud import aiplatform  # noqa: F811

# Upgrade TensorFlow when running outside Colab Enterprise, then fetch the shared notebook utilities.
if os.environ.get("VERTEX_PRODUCT") != "COLAB_ENTERPRISE":
    ! pip install --upgrade tensorflow
! git clone https://github.com/GoogleCloudPlatform/vertex-ai-samples.git

common_util = importlib.import_module(
    "vertex-ai-samples.notebooks.community.model_garden.docker_source_codes.notebook_util.common_util"
)

models, endpoints = {}, {}

# Get the default cloud project id.
PROJECT_ID = os.environ["GOOGLE_CLOUD_PROJECT"]

# Get the default region for launching jobs.
if not REGION:
    REGION = os.environ["GOOGLE_CLOUD_REGION"]

# Initialize Vertex AI API.
print("Initializing Vertex AI API.")
aiplatform.init(project=PROJECT_ID, location=REGION)

! gcloud config set project $PROJECT_ID

import vertexai

vertexai.init(
    project=PROJECT_ID,
    location=REGION,
)

## 2. Learn to Configure the Model Co-hosting Server

✨ **Special Feature**: The Vertex AI Model Garden vLLM container used in this section allows you to **co-host multiple models in one container, with each model having its own pipeline parallelism (PP), tensor parallelism (TP), and model replica settings**.

### Launch Arguments

The key launch arguments are listed below.

---

#### 🛠️ Model Specification and Memory Allocation

| Argument | Requirement | Example(s) | Description |
| :--- | :--- | :--- | :--- |
| **`--model`** | Required | `model_a,model_b,model_c` | Comma-separated list of HuggingFace model IDs or GCS paths to load. |
| **`--served-model-name`** | Optional | `model_x,model_y,model_z` | Comma-separated list of **model identifiers** to use in the API for each model. If not set, the value of `--model` is used. |
| **`--gpu-memory-partition`** | Required | `0.5,0.25,0.25` <br> `0.25` | Comma-separated list of **GPU memory ratios** to reserve for each model out of the full VM (e.g., first model gets 50%, second gets 25%, etc.). A single value applies to all. *(New argument)* |
| **`--model-replicas`** | Optional | `4,1,2` <br> `1` | Comma-separated list of the **number of model replicas** to create for each model, or a single number for all. If not set, all models have one replica. *(New argument)* |
| **`--max-model-len`** | Optional | `1024,8192,8192` <br> `1024` | Comma-separated list of **maximum context lengths** for each model, or a single length for all. If unset, the length is derived from the model's configuration. |

---

#### ⚙️ Parallelism

| Argument | Requirement | Example(s) | Description |
| :--- | :--- | :--- | :--- |
| **`--tensor-parallel-size`** | Optional | `1,2,2` <br> `8` | Comma-separated list of **Tensor Parallelism (TP) sizes** for each model, or a single size for all. If unset, TP size defaults to the number of available GPUs. |
| **`--pipeline-parallel-size`** | Optional | `1,2,2` <br> `1` | Comma-separated list of **Pipeline Parallelism (PP) sizes** for each model, or a single size for all. If unset, PP size defaults to 1. |

---

**Note**: The Model Garden vLLM container used in this section builds on vLLM at commit https://github.com/vllm-project/vllm/commit/c8851a47235f5dfd3da3abf6c89453b3bdb41ad1. It doesn't modify the kernel or engine implementation.
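
Because most of these arguments accept either one value for all models or one value per model, it can help to expand and sanity-check them before deploying. The sketch below is not the container's own parser: `expand_per_model`, `build_plans`, and `ModelPlan` are hypothetical names, and the only rules assumed are the ones stated in the tables above (comma-separated lists, with a single value applying to all models).

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ModelPlan:
    """Per-model settings after expansion (hypothetical structure)."""

    model: str
    gpu_memory_partition: float
    tensor_parallel_size: int
    pipeline_parallel_size: int
    model_replicas: int


def expand_per_model(value: str, num_models: int, name: str) -> List[str]:
    """Splits a comma-separated value; a single value applies to all models."""
    parts = [part.strip() for part in value.split(",")]
    if len(parts) == 1:
        return parts * num_models
    if len(parts) != num_models:
        raise ValueError(
            f"{name} has {len(parts)} values but {num_models} models were given."
        )
    return parts


def build_plans(
    model: str,
    gpu_memory_partition: str,
    tensor_parallel_size: str,
    pipeline_parallel_size: str = "1",
    model_replicas: str = "1",
) -> List[ModelPlan]:
    """Builds one ModelPlan per model from the comma-separated arguments."""
    models = [m.strip() for m in model.split(",")]
    n = len(models)
    memory = expand_per_model(gpu_memory_partition, n, "--gpu-memory-partition")
    tp = expand_per_model(tensor_parallel_size, n, "--tensor-parallel-size")
    pp = expand_per_model(pipeline_parallel_size, n, "--pipeline-parallel-size")
    replicas = expand_per_model(model_replicas, n, "--model-replicas")
    return [
        ModelPlan(models[i], float(memory[i]), int(tp[i]), int(pp[i]), int(replicas[i]))
        for i in range(n)
    ]


plans = build_plans(
    model="google/gemma-3n-E2B-it,meta-llama/Llama-3.1-8B-Instruct",
    gpu_memory_partition="0.4,0.4",
    tensor_parallel_size="1,4",
    pipeline_parallel_size="1",  # A single value applies to both models.
    model_replicas="4,1",
)
for plan in plans:
    print(plan)
```

A mismatch such as three models with two `--tensor-parallel-size` values raises a `ValueError` locally, which is cheaper to catch than a failed deployment.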

In [None]:
# The MG vLLM model co-hosting serving container.
VLLM_DOCKER_URI = "us-docker.pkg.dev/vertex-imageplatform/vertex-model-garden/vllm-inference-restricted-ubuntu22.04-py3.12:model-garden.vllm-restricted-x86-release_20251028.02_p0"  # @param {type:"string"}

## 3. Deploy to Vertex AI and Test the Endpoint

With the optimal recipe identified, we now deploy the models to a Vertex AI Endpoint on an 8 x H100 VM.

In [None]:
# Set model to deploy
base_model_name_a = "gemma-3n-E2B-it"  # @param {type:"string"}
model_id_a = "google/gemma-3n-E2B-it"  # @param {type:"string"}
hf_model_id_a = model_id_a
publisher_a = "google"
publisher_model_id_a = "gemma3n"

base_model_name_b = "Llama-3.1-8B-Instruct"  # @param {type:"string"}
model_id_b = "meta-llama/Llama-3.1-8B-Instruct"  # @param {type:"string"}
hf_model_id_b = model_id_b
publisher_b = "meta"
publisher_model_id_b = "llama3_1"

base_model_name = ",".join([base_model_name_a, base_model_name_b])
served_model_name = base_model_name
model_id = ",".join([model_id_a, model_id_b])

gpu_memory_partition = "0.4,0.4"  # @param {type:"string"}
pipeline_parallel_size = "1,1"  # @param {type:"string"}
tensor_parallel_size = "1,4"  # @param {type:"string"}
model_replicas = "4,1"  # @param {type:"string"}

HF_TOKEN = ""  # @param {type:"string"}

# Find Vertex AI prediction supported accelerators and regions at https://cloud.google.com/vertex-ai/docs/predictions/configure-compute.
accelerator_type = "NVIDIA_H100_80GB"
accelerator_count = 8
machine_type = "a3-highgpu-8g"
multihost_gpu_node_count = 1

In [None]:
# @title Deploy with customized configs

# @markdown Set `use_dedicated_endpoint` to False if you don't want to use a [dedicated endpoint](https://cloud.google.com/vertex-ai/docs/general/deployment#create-dedicated-endpoint). Note that [dedicated endpoints do not support VPC Service Controls](https://cloud.google.com/vertex-ai/docs/predictions/choose-endpoint-type); uncheck the box if you are using VPC-SC.
use_dedicated_endpoint = True  # @param {type:"boolean"}

# @markdown Choose whether to use a [Spot VM](https://cloud.google.com/compute/docs/instances/spot) for the deployment.
is_spot = False  # @param {type:"boolean"}

common_util.check_quota(
    project_id=PROJECT_ID,
    region=REGION,
    accelerator_type=accelerator_type,
    accelerator_count=int(accelerator_count * multihost_gpu_node_count),
    is_for_training=False,
    is_spot=is_spot,
)

# @markdown To enable auto-scaling for the deployment, set the following options:

min_replica_count = 1  # @param {type:"integer"}
max_replica_count = 1  # @param {type:"integer"}
required_replica_count = 1  # @param {type:"integer"}

# @markdown Set the target of GPU duty cycle or CPU usage between 1 and 100 for auto-scaling.
autoscale_by_gpu_duty_cycle_target = 0  # @param {type:"integer"}
autoscale_by_cpu_usage_target = 0  # @param {type:"integer"}

# @markdown Note: GPU duty cycle is not the most accurate metric for scaling workloads. More advanced auto-scaling metrics are coming soon. See [the public doc](https://cloud.google.com/vertex-ai/docs/reference/rest/v1/DedicatedResources#AutoscalingMetricSpec) for more details.


def deploy_model_vllm_multi_model_cohost(
    model_name: str,
    model_id: str,
    gpu_memory_partition: str,
    publisher: str,
    publisher_model_id: str,
    service_account: str = None,
    base_model_id: str = None,
    served_model_name: str = "",
    machine_type: str = "a3-highgpu-8g",
    accelerator_type: str = "NVIDIA_H100_80GB",
    accelerator_count: int = 8,
    multihost_gpu_node_count: int = 1,
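    # Per-model values are comma-separated strings such as "1,1"; a single value applies to all models.
    pipeline_parallel_size: str = "1",
    tensor_parallel_size: str = "8",
    model_replicas: str = "1",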
    enable_trust_remote_code: bool = False,
    use_dedicated_endpoint: bool = False,
    min_replica_count: int = 1,
    max_replica_count: int = 1,
    required_replica_count: int = 1,
    autoscale_by_gpu_duty_cycle_target: int = 0,
    autoscale_by_cpu_usage_target: int = 0,
    is_spot: bool = True,
    model_cohost_feature: str = "multi-model-cohost",
) -> Tuple[aiplatform.Model, aiplatform.Endpoint]:
    """Deploys models with vLLM to Vertex AI."""
    endpoint = aiplatform.Endpoint.create(
        display_name=f"{model_name}-endpoint",
        dedicated_endpoint_enabled=use_dedicated_endpoint,
    )

    if not base_model_id:
        base_model_id = model_id

    api_server = "vllm.entrypoints.model_cohost_server"

    # See https://docs.vllm.ai/en/latest/models/engine_args.html for a list of possible arguments with descriptions.
    vllm_args = [
        "python",
        "-m",
        api_server,
        "--host=0.0.0.0",
        "--port=8080",
        f"--model={model_id}",
        f"--gpu-memory-partition={gpu_memory_partition}",
        f"--pipeline-parallel-size={pipeline_parallel_size}",
        f"--tensor-parallel-size={tensor_parallel_size}",
        f"--model-replicas={model_replicas}",
        "--data-parallel-size=1",
        "--swap-space=16",
    ]

    if multihost_gpu_node_count > 1:
        vllm_args = ["/vllm-workspace/ray_launcher.sh"] + vllm_args

    if served_model_name:
        vllm_args.append(f"--served-model-name={served_model_name}")

    if enable_trust_remote_code:
        vllm_args.append("--trust-remote-code")

    env_vars = {
        "MODEL_ID": base_model_id,
        "DEPLOY_SOURCE": "notebook",
    }

    # HF_TOKEN is not a compulsory field and may not be defined.
    try:
        if HF_TOKEN:
            env_vars["HF_TOKEN"] = HF_TOKEN
    except NameError:
        pass

    model = aiplatform.Model.upload(
        display_name=model_name,
        serving_container_image_uri=VLLM_DOCKER_URI,
        serving_container_args=vllm_args,
        serving_container_ports=[8080],
        serving_container_predict_route="/generate",
        serving_container_health_route="/ping",
        serving_container_environment_variables=env_vars,
        serving_container_shared_memory_size_mb=(16 * 1024),  # 16 GB
        serving_container_deployment_timeout=7200,
        model_garden_source_model_name=(
            f"publishers/{publisher}/models/{publisher_model_id}"
        ),
    )
    print(
        f"Deploying {model_name} on {multihost_gpu_node_count} host(s) of {machine_type} with {accelerator_type} GPU(s)."
    )

    creds, _ = auth.default()
    auth_req = auth.transport.requests.Request()
    creds.refresh(auth_req)

    url = f"https://{REGION}-aiplatform.googleapis.com/ui/projects/{PROJECT_ID}/locations/{REGION}/endpoints/{endpoint.name}:deployModel"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {creds.token}",
    }
    data = {
        "deployedModel": {
            "model": model.resource_name,
            "displayName": model_name,
            "dedicatedResources": {
                "machineSpec": {
                    "machineType": machine_type,
                    "multihostGpuNodeCount": multihost_gpu_node_count,
                    "acceleratorType": accelerator_type,
                    "acceleratorCount": accelerator_count,
                },
                "minReplicaCount": min_replica_count,
                "requiredReplicaCount": required_replica_count,
                "maxReplicaCount": max_replica_count,
            },
            "system_labels": {
                "NOTEBOOK_NAME": "model_garden_model_cohost.ipynb",
                "NOTEBOOK_ENVIRONMENT": common_util.get_deploy_source(),
                "mg-serving-feature-model-cohost": model_cohost_feature,
            },
        },
    }
    if service_account:
        data["deployedModel"]["serviceAccount"] = service_account
    if is_spot:
        data["deployedModel"]["dedicatedResources"]["spot"] = True
    if autoscale_by_gpu_duty_cycle_target > 0 or autoscale_by_cpu_usage_target > 0:
        data["deployedModel"]["dedicatedResources"]["autoscalingMetricSpecs"] = []
        if autoscale_by_gpu_duty_cycle_target > 0:
            data["deployedModel"]["dedicatedResources"][
                "autoscalingMetricSpecs"
            ].append(
                {
                    "metricName": "aiplatform.googleapis.com/prediction/online/accelerator/duty_cycle",
                    "target": autoscale_by_gpu_duty_cycle_target,
                }
            )
        if autoscale_by_cpu_usage_target > 0:
            data["deployedModel"]["dedicatedResources"][
                "autoscalingMetricSpecs"
            ].append(
                {
                    "metricName": "aiplatform.googleapis.com/prediction/online/cpu/utilization",
                    "target": autoscale_by_cpu_usage_target,
                }
            )
    response = requests.post(url, headers=headers, json=data)
    # Print the raw body so that non-JSON error responses are still visible.
    print(f"Deploy Model response: {response.text}")
    if response.status_code != 200 or "name" not in response.json():
        raise ValueError(f"Failed to deploy model: {response.text}")
    common_util.poll_and_wait(response.json()["name"], REGION, 7200)
    print("endpoint_name:", endpoint.name)

    return model, endpoint


(
    models["vllm_gpu_multi_model"],
    endpoints["vllm_gpu_multi_model"],
) = deploy_model_vllm_multi_model_cohost(
    model_name=common_util.get_job_name_with_datetime(prefix="multi-model-cohost"),
    model_id=model_id,
    gpu_memory_partition=gpu_memory_partition,
    served_model_name=served_model_name,
    publisher=publisher_a,
    publisher_model_id=publisher_model_id_a,
    base_model_id=hf_model_id_a,
    machine_type=machine_type,
    accelerator_type=accelerator_type,
    accelerator_count=accelerator_count,
    multihost_gpu_node_count=multihost_gpu_node_count,
    pipeline_parallel_size=pipeline_parallel_size,
    tensor_parallel_size=tensor_parallel_size,
    model_replicas=model_replicas,
    use_dedicated_endpoint=use_dedicated_endpoint,
    min_replica_count=min_replica_count,
    max_replica_count=max_replica_count,
    required_replica_count=required_replica_count,
    autoscale_by_gpu_duty_cycle_target=autoscale_by_gpu_duty_cycle_target,
    autoscale_by_cpu_usage_target=autoscale_by_cpu_usage_target,
    is_spot=is_spot,
    model_cohost_feature="multi-model-cohost",
)
# @markdown Click "Show Code" to see more details.
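
The autoscaling branch of the deployment function above can be read in isolation. The following is a minimal sketch, not part of the notebook's helpers: `build_autoscaling_metric_specs` is a hypothetical name, and the range check for targets above 100 is an assumption based on the "between 1 and 100" note in the form (the original function does not validate the upper bound). The metric names are copied from the cell above.

```python
from typing import Any, Dict, List

# Metric names copied from the deployment cell above.
GPU_DUTY_CYCLE_METRIC = (
    "aiplatform.googleapis.com/prediction/online/accelerator/duty_cycle"
)
CPU_UTILIZATION_METRIC = (
    "aiplatform.googleapis.com/prediction/online/cpu/utilization"
)


def build_autoscaling_metric_specs(
    gpu_duty_cycle_target: int = 0,
    cpu_usage_target: int = 0,
) -> List[Dict[str, Any]]:
    """Hypothetical helper that builds the `autoscalingMetricSpecs` list.

    A target of 0 (or less) leaves that metric out, mirroring the `> 0`
    checks in the deployment function. Rejecting targets above 100 is an
    assumption added for this sketch.
    """
    specs: List[Dict[str, Any]] = []
    for metric_name, target in (
        (GPU_DUTY_CYCLE_METRIC, gpu_duty_cycle_target),
        (CPU_UTILIZATION_METRIC, cpu_usage_target),
    ):
        if target <= 0:
            continue  # Metric disabled.
        if target > 100:
            raise ValueError(
                f"Target for {metric_name} must be between 1 and 100, got {target}."
            )
        specs.append({"metricName": metric_name, "target": target})
    return specs


# Example: scale on GPU duty cycle only.
print(build_autoscaling_metric_specs(gpu_duty_cycle_target=60))
```

An empty list means neither target was set. In that case the deployment function above does not add the `autoscalingMetricSpecs` key to the request body at all.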

In [None]:
# @title Chat completion

if use_dedicated_endpoint:
    DEDICATED_ENDPOINT_DNS = endpoints[
        "vllm_gpu_multi_model"
    ].gca_resource.dedicated_endpoint_dns
ENDPOINT_RESOURCE_NAME = endpoints["vllm_gpu_multi_model"].resource_name

# @markdown Once deployment succeeds, you can send requests to the endpoint using the OpenAI SDK.

# @markdown First you will need to install the SDK and some auth-related dependencies.

! pip install -qU openai google-auth requests

# @markdown Next fill out some request parameters:

model = ""  # @param {type: "string"}
user_message = "How is your day going?"  # @param {type: "string"}
# @markdown If you encounter an error like `ServiceUnavailable: 503 Took too long to respond when processing`, reduce the maximum number of output tokens, for example by setting `max_tokens` to 20.
max_tokens = 50  # @param {type: "integer"}
temperature = 1.0  # @param {type: "number"}
stream = False  # @param {type: "boolean"}

# @markdown Now we can send a request.

import google.auth
import google.auth.transport.requests
import openai

creds, project = google.auth.default()
auth_req = google.auth.transport.requests.Request()
creds.refresh(auth_req)

BASE_URL = (
    f"https://{REGION}-aiplatform.googleapis.com/v1beta1/{ENDPOINT_RESOURCE_NAME}"
)
try:
    if use_dedicated_endpoint:
        BASE_URL = f"https://{DEDICATED_ENDPOINT_DNS}/v1beta1/{ENDPOINT_RESOURCE_NAME}"
except NameError:
    pass

client = openai.OpenAI(base_url=BASE_URL, api_key=creds.token)

model_response = client.chat.completions.create(
    model=model,
    messages=[{"role": "user", "content": user_message}],
    temperature=temperature,
    max_tokens=max_tokens,
    stream=stream,
)

if stream:
    usage = None
    contents = []
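    for chunk in model_response:
        if chunk.usage is not None:
            usage = chunk.usage
            continue
        # Some chunks carry no choices or an empty delta; skip them instead of printing "None".
        if not chunk.choices or chunk.choices[0].delta.content is None:
            continue
        print(chunk.choices[0].delta.content, end="")
        contents.append(chunk.choices[0].delta.content)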
    print(f"\n\n{usage}")
else:
    print(model_response)

# @markdown Click "Show Code" to see more details.
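
The cell above does two things that are easy to check offline: it picks a base URL depending on the endpoint type, and it accumulates text from streamed chunks. The sketch below isolates both steps. `build_base_url`, `collect_stream`, and `_fake_chunk` are hypothetical names introduced here for illustration, the DNS value in the example is a placeholder, and the fake chunks only imitate the `usage` / `choices[0].delta.content` attributes that the cell above reads.

```python
from types import SimpleNamespace
from typing import Any, Iterable, Optional, Tuple


def build_base_url(
    region: str,
    endpoint_resource_name: str,
    dedicated_dns: Optional[str] = None,
) -> str:
    """Returns the base URL, preferring the dedicated DNS when one is given."""
    host = dedicated_dns or f"{region}-aiplatform.googleapis.com"
    return f"https://{host}/v1beta1/{endpoint_resource_name}"


def collect_stream(chunks: Iterable[Any]) -> Tuple[str, Optional[Any]]:
    """Joins streamed text deltas and returns (text, usage).

    Chunks without choices, or with a `None` delta, contribute no text.
    """
    usage = None
    parts = []
    for chunk in chunks:
        if getattr(chunk, "usage", None) is not None:
            usage = chunk.usage
        if not chunk.choices:
            continue
        text = chunk.choices[0].delta.content
        if text:
            parts.append(text)
    return "".join(parts), usage


def _fake_chunk(content: Optional[str] = None, usage: Any = None, empty: bool = False):
    """Builds a stand-in for a streamed chunk (illustration only)."""
    choices = [] if empty else [SimpleNamespace(delta=SimpleNamespace(content=content))]
    return SimpleNamespace(usage=usage, choices=choices)


fake_stream = [
    _fake_chunk("Hel"),
    _fake_chunk(None),  # Empty delta, for example a role-only chunk.
    _fake_chunk("lo"),
    _fake_chunk(usage={"total_tokens": 5}, empty=True),
]
text, usage = collect_stream(fake_stream)
print(text, usage)
print(build_base_url("us-central1", "projects/p/locations/us-central1/endpoints/123"))
```

Keeping the URL choice in one function avoids the `try`/`except NameError` pattern when the DNS value is passed explicitly. Whether that is preferable in a form-driven notebook is a judgment call.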

## 4. Clean Up

To avoid incurring ongoing charges, it's important to clean up the resources you've created.

In [None]:
# @title Delete the models and endpoints

# @markdown  Delete the experiment models and endpoints to release the resources
# @markdown  and avoid ongoing charges.

# Undeploy model and delete endpoint.
for endpoint in endpoints.values():
    endpoint.delete(force=True)

# Delete models.
for model in models.values():
    model.delete()
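
If one deletion fails, the loops above stop at the first exception and leave the remaining resources (and their charges) in place. A minimal sketch of a more tolerant cleanup follows. `delete_all` is a hypothetical helper, not part of the notebook or the Vertex AI SDK; it only assumes that each resource exposes a `delete(...)` method, as the `endpoints` and `models` entries above do.

```python
from typing import Any, Dict, List, Tuple


def delete_all(
    resources: Dict[str, Any], **delete_kwargs: Any
) -> List[Tuple[str, Exception]]:
    """Calls `delete(**delete_kwargs)` on every resource and collects failures.

    Returns a list of (key, exception) pairs for resources that could not
    be deleted, so the caller can retry or clean them up manually.
    """
    failures: List[Tuple[str, Exception]] = []
    for key, resource in resources.items():
        try:
            resource.delete(**delete_kwargs)
        except Exception as exc:  # Broad on purpose: keep deleting the rest.
            failures.append((key, exc))
    return failures


# Intended use with the dictionaries from this notebook (not executed here):
# failures = delete_all(endpoints, force=True) + delete_all(models)
# for key, exc in failures:
#     print(f"Could not delete {key}: {exc}")
```

Endpoints are still deleted before models, matching the order of the cell above.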