In [None]:
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Hugging Face DLCs: Serving Gemma 2 with multiple LoRA adapters with Text Generation Inference (TGI) on Vertex AI

<table align="left">
  <td style="text-align: center">
    <a href="https://colab.research.google.com/github/GoogleCloudPlatform/generative-ai/blob/main/open-models/serving/vertex_ai_tgi_gemma_multi_lora_adapters_deployment.ipynb">
      <img width="32px" src="https://www.gstatic.com/pantheon/images/bigquery/welcome_page/colab-logo.svg" alt="Google Colaboratory logo"><br> Open in Colab
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fgenerative-ai%2Fmain%2Fopen-models%2Fserving%2Fvertex_ai_tgi_gemma_multi_lora_adapters_deployment.ipynb">
      <img width="32px" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" alt="Google Cloud Colab Enterprise logo"><br> Open in Colab Enterprise
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/generative-ai/main/open-models/serving/vertex_ai_tgi_gemma_multi_lora_adapters_deployment.ipynb">
      <img src="https://www.gstatic.com/images/branding/gcpiconscolors/vertexai/v1/32px.svg" alt="Vertex AI logo"><br> Open in Vertex AI Workbench
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/GoogleCloudPlatform/generative-ai/blob/main/open-models/serving/vertex_ai_tgi_gemma_multi_lora_adapters_deployment.ipynb">
      <img width="32px" src="https://www.svgrepo.com/download/217753/github.svg" alt="GitHub logo"><br> View on GitHub
    </a>
  </td>
</table>

<div style="clear: both;"></div>


| | |
|-|-|
| Author(s) | [Ivan Nardini](https://github.com/inardini) |

## Overview

> [**Gemma**](https://ai.google.dev/gemma) is a family of lightweight, state-of-the-art open models built from the same research and technology used to create the Gemini models, developed by Google DeepMind and other teams across Google.

> [**Hugging Face DLCs**](https://github.com/huggingface/Google-Cloud-Containers) are pre-built and optimized Deep Learning Containers (DLCs) maintained by Hugging Face and Google Cloud teams to simplify environment configuration for your ML workloads.

> [**Google Cloud Vertex AI**](https://cloud.google.com/vertex-ai) is a Machine Learning (ML) platform that lets you train and deploy ML models and AI applications, and customize large language models (LLMs) for use in your AI-powered applications.

This notebook shows how to deploy Gemma 2 2B from the Hugging Face Hub with multiple LoRA adapters fine-tuned for different purposes, such as coding or SQL generation, using Hugging Face's Text Generation Inference (TGI) Deep Learning Container (DLC) in combination with a [custom handler](https://huggingface.co/docs/inference-endpoints/en/guides/custom_handler#create-custom-inference-handler) on Vertex AI.

By the end of this notebook, you will learn how to:

- Create a custom handler and test it
- Register any LLM from the Hugging Face Hub on Vertex AI
- Deploy an LLM on Vertex AI
- Send online predictions on Vertex AI

## Get started

### (Optional) Set the runtime

Depending on your notebook environment, consider setting a GPU runtime.


### Install Vertex AI SDK and other required packages


In [None]:
%pip install --upgrade --user --quiet 'torch' 'torchvision' 'torchaudio'
%pip install --upgrade --user --quiet 'huggingface_hub[hf_transfer]' 'transformers' 'accelerate>=0.26.0'
%pip install --upgrade --user --quiet 'google-cloud-aiplatform' 'crcmod' 'etils'

### Restart runtime

To use the newly installed packages in this Jupyter runtime, you must restart the runtime. You can do this by running the cell below, which restarts the current kernel.

The restart might take a minute or longer. After it's restarted, continue to the next step.

In [None]:
import IPython

app = IPython.Application.instance()
app.kernel.do_shutdown(True)

<div class="alert alert-block alert-warning">
<b>⚠️ The kernel is going to restart. In Colab or Colab Enterprise, you might see an error message that says "Your session crashed for an unknown reason." This is expected. Wait until it's finished before continuing to the next step. ⚠️</b>
</div>


### Authenticate your notebook environment (Colab only)

If you're running this notebook on Google Colab, run the cell below to authenticate your environment.

In [None]:
import sys

if "google.colab" in sys.modules:
    from google.colab import auth

    auth.authenticate_user()

### Set Hugging Face variables

In [None]:
import os

from etils.epath import Path

ROOT_PATH = Path(".")
TUTORIAL_PATH = ROOT_PATH / "deploy_gemma_with_multi_lora_adapters_tutorial"

os.environ["HF_HOME"] = str(TUTORIAL_PATH)
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

### Authenticate your Hugging Face account

As [`google/gemma-2-2b`](https://huggingface.co/google/gemma-2-2b) is a gated model, you need a Hugging Face Hub account and must accept Google's usage license for Gemma. Once that's done, generate a new user access token with read-only access so that the weights can be downloaded from the Hub in the Hugging Face DLC for TGI.

> Note that the user access token can only be generated via [the Hugging Face Hub UI](https://huggingface.co/settings/tokens/new), where you can either select read-only access to your account, or follow the recommendations and generate a fine-grained token with read-only access to [`google/gemma-2-2b`](https://huggingface.co/google/gemma-2-2b).

The `huggingface_hub` package installed earlier provides a login helper that stores the token you generated, so that it can later be retrieved safely via `huggingface_hub.get_token`.

In [None]:
from huggingface_hub import interpreter_login

interpreter_login()

Read more about [Hugging Face Security](https://huggingface.co/docs/hub/en/security), specifically about [Hugging Face User Access Tokens](https://huggingface.co/docs/hub/en/security-tokens).

### Requirements

#### Set Project ID and Location

To get started using Vertex AI, you must have an existing Google Cloud project and [enable these APIs](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com,artifactregistry.googleapis.com).

Learn more about [setting up a project and a development environment](https://cloud.google.com/vertex-ai/docs/start/cloud-environment).

In [None]:
# Use the environment variable if the user doesn't provide Project ID.
import os

PROJECT_ID = "[your-project-id]"  # @param {type: "string", placeholder: "[your-project-id]", isTemplate: true}

if not PROJECT_ID or PROJECT_ID == "[your-project-id]":
    PROJECT_ID = str(os.environ.get("GOOGLE_CLOUD_PROJECT"))

# The shell capture returns a list of output lines; the first one is taken below.
PROJECT_NUMBER = !gcloud projects describe {PROJECT_ID} --format="get(projectNumber)"
PROJECT_NUMBER = PROJECT_NUMBER[0]

LOCATION = os.environ.get("GOOGLE_CLOUD_REGION", "us-central1")

#### Create a bucket in Google Cloud Storage (GCS)

Create a storage bucket to store intermediate artifacts such as models.

In [None]:
BUCKET_NAME = "[your-bucket-name]"  # @param {type:"string", isTemplate: true}

if BUCKET_NAME == "[your-bucket-name]":
    raise ValueError("A valid BUCKET_NAME needs to be specified")

BUCKET_URI = f"gs://{BUCKET_NAME}"
os.environ["BUCKET_URI"] = BUCKET_URI

Run the `gcloud storage buckets create` command below if you need to create a bucket on GCS; skip the cell if the bucket already exists.

In [None]:
!gcloud storage buckets create $BUCKET_URI --project $PROJECT_ID --location=$LOCATION --default-storage-class=STANDARD --uniform-bucket-level-access

#### Set Service Account and permissions

You will need to have the following IAM roles set:

- Vertex AI User (roles/aiplatform.user)
- Artifact Registry Reader (roles/artifactregistry.reader)
- Storage Object Admin (roles/storage.objectAdmin)

For more information about granting roles, see [Manage access](https://cloud.google.com/iam/docs/granting-changing-revoking-access).


> If you are using Vertex AI Workbench, run the following commands directly in the terminal.


In [None]:
SERVICE_ACCOUNT = f"{PROJECT_NUMBER}-compute@developer.gserviceaccount.com"

In [None]:
for role in ['aiplatform.user', 'storage.objectAdmin', 'artifactregistry.reader']:

    ! gcloud projects add-iam-policy-binding {PROJECT_ID} \
      --member=serviceAccount:{SERVICE_ACCOUNT} \
      --role=roles/{role} --condition=None

### Initialize Vertex AI SDK

Initialize the Vertex AI client session.

In [None]:
import vertexai

vertexai.init(project=PROJECT_ID, location=LOCATION, staging_bucket=BUCKET_URI)

### Import libraries

Import relevant libraries.

In [None]:
import gc
import json
from pprint import pprint as pp

from etils import epath
from google.cloud import aiplatform
from google.cloud.aiplatform import Endpoint, Model
from huggingface_hub import get_token, snapshot_download
import requests
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig as TGenerationConfig
from vertexai.generative_models import GenerationConfig, GenerativeModel

### Helpers

Define some helpers.
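The `hf_model_path` helper in this section relies on the Hugging Face Hub cache layout, where each repository lives under `models--<org>--<name>/snapshots/<revision>`. The following is a minimal sketch of that lookup using only the standard library; `snapshot_dir` and the fake revision name `abc123` are illustrative and not part of the notebook.

```python
import tempfile
from pathlib import Path


def snapshot_dir(cache_dir: Path, repo_id: str) -> Path:
    """Return the first snapshot folder for `repo_id` inside a Hub-style cache."""
    # Each repo is stored under "models--<org>--<name>/snapshots/<revision>".
    repo_dir = cache_dir / ("models--" + repo_id.replace("/", "--"))
    snapshots = sorted(p for p in (repo_dir / "snapshots").iterdir() if p.is_dir())
    if not snapshots:
        raise FileNotFoundError(f"No snapshot found for {repo_id}")
    return snapshots[0]


# Build a fake cache so the sketch runs without downloading anything.
with tempfile.TemporaryDirectory() as tmp:
    cache = Path(tmp)
    fake_revision = cache / "models--google--gemma-2-2b-it" / "snapshots" / "abc123"
    fake_revision.mkdir(parents=True)
    found_name = snapshot_dir(cache, "google/gemma-2-2b-it").name
    print(found_name)
```

Returning the snapshot folder rather than the repository folder matters because the model files (config, tokenizer, weights) sit inside the revision directory.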

In [None]:
def hf_model_path(hf_home_path: str, model_name: str | None = None) -> str:
    """Return the path of a model snapshot in the Hugging Face cache."""
    # Convert and expand the input path
    base_path = Path(hf_home_path).expanduser()

    if not base_path.exists():
        raise FileNotFoundError(f"Cache path does not exist: {base_path}")

    if model_name:
        # Search in specific model directory
        model_dir_name = "models--" + model_name.replace("/", "--")
        model_dirs = [base_path / model_dir_name]
        if not model_dirs[0].exists():
            raise FileNotFoundError(f"Model directory not found: {model_dirs[0]}")
    else:
        # Search all directories (original behavior)
        model_dirs = [d for d in base_path.iterdir() if d.is_dir()]

    # Find the deepest snapshot directory
    snapshot_path = None
    for model_dir in model_dirs:
        # Look for 'snapshots' directory
        snapshots_dir = model_dir / "snapshots"
        if not snapshots_dir.exists():
            continue

        # Get the first snapshot folder (usually there's only one)
        for snapshot in snapshots_dir.iterdir():
            if snapshot.is_dir():
                snapshot_path = snapshot
                break

        if snapshot_path:
            break

    if not snapshot_path:
        error_msg = (
            f"No snapshot directory found for model: {model_name}"
            if model_name
            else "No snapshot directory found in the cache"
        )
        raise FileNotFoundError(error_msg)

    return str(snapshot_path)


def get_cuda_device_names():
    """Return the indices of the available NVIDIA GPUs as strings, or None if CUDA is unavailable."""
    if not torch.cuda.is_available():
        return None

    return [str(i) for i in range(torch.cuda.device_count())]


def empty_gpu_ram():
    """A function to empty the GPU RAM"""
    gc.collect()
    torch.cuda.empty_cache()

## Prepare a custom handler to serve LoRA adapters on Vertex AI

Custom Handlers are Python classes that define the pre-processing, inference, and post-processing steps required to run inference with the Hugging Face PyTorch Inference container on Vertex AI.

Think of Custom Handlers as personalized instructions for Hugging Face models. They define how to prepare the input data, run the model, and handle the results. In this sense, Custom Handlers add flexibility. They let you customize how data is prepared and processed, add extra steps, and even build in custom measurements or logging. This means you can tailor the process to your exact needs when the standard setup isn't sufficient.

These instructions are stored in a file named `handler.py`. If you need additional dependencies, you can list them in a `requirements.txt` file. The PyTorch container automatically finds and uses these files if they're present.
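As a rough illustration of the shape such a file takes, here is a toy handler that upper-cases each prompt instead of calling a model. The `EndpointHandler` class name and the `instances`/`predictions` keys mirror the handler built later in this notebook; the upper-casing step and the `/tmp/model` path are placeholders.

```python
from typing import Any, Dict, List


class EndpointHandler:
    """Toy handler: upper-cases each prompt instead of calling a model."""

    def __init__(self, model_dir: str = "", **kwargs: Any) -> None:
        # A real handler would load the tokenizer, model, and adapters here, once.
        self.model_dir = model_dir

    def __call__(self, data: Dict[str, Any]) -> Dict[str, List[Any]]:
        predictions = []
        for instance in data["instances"]:
            # Pre-processing: validate the payload.
            if "inputs" not in instance:
                raise ValueError("The request body must contain the `inputs` key.")
            # "Inference": placeholder for a call such as model.generate(...).
            output = instance["inputs"].upper()
            # Post-processing: collect one prediction per instance.
            predictions.append(output)
        return {"predictions": predictions}


handler = EndpointHandler(model_dir="/tmp/model")
result = handler({"instances": [{"inputs": "select 1"}]})
print(result)
```

Loading heavy resources in `__init__` and keeping `__call__` limited to per-request work means the model is loaded once per container rather than once per request.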

Have a look at [🤗 Serve Anything with Inference Endpoints + Custom Handlers](https://huggingface.co/blog/alvarobartt/serve-anything-inference-endpoints-custom-code) to learn more.

### Test the `handler` locally

Before building the handler module, you can test its code locally.

#### Download Gemma 2 and adapters locally from Hugging Face Hub

In [None]:
# Download Gemma 2 and adapters
base_model_id = "google/gemma-2-2b-it"
sql_adapter_id = "google-cloud-partnership/gemma-2-2b-it-lora-sql"
magicoder_adapter_id = "google-cloud-partnership/gemma-2-2b-it-lora-magicoder"

snapshot_download(repo_id=base_model_id, token=get_token())
snapshot_download(repo_id=sql_adapter_id, token=get_token())
snapshot_download(repo_id=magicoder_adapter_id, token=get_token())

#### Load Gemma model

Load the pre-trained Gemma model for text generation. You first set up the tokenizer, which breaks text into words or sub-words. Then you load the model itself in half precision (`float16`) to reduce memory usage, with automatic device placement (a GPU if one is available).

In [None]:
gemma_path = hf_model_path(TUTORIAL_PATH / "hub", base_model_id)

tokenizer = AutoTokenizer.from_pretrained(gemma_path)

model = AutoModelForCausalLM.from_pretrained(
    gemma_path, low_cpu_mem_usage=True, torch_dtype=torch.float16, device_map="auto"
)

#### Load Gemma's LoRA adapters


In [None]:
sql_adapter_path = hf_model_path(TUTORIAL_PATH / "hub", sql_adapter_id)
magicoder_adapter_path = hf_model_path(TUTORIAL_PATH / "hub", magicoder_adapter_id)

model.load_adapter(
    sql_adapter_path,
    adapter_name="sql_adapter",
    is_trainable=False,
    device_map="auto",
    offload_folder="/tmp/offload",
)
model.load_adapter(
    magicoder_adapter_path,
    adapter_name="magicoder_adapter",
    is_trainable=False,
    device_map="auto",
    offload_folder="/tmp/offload",
)

#### Prepare the prediction request

Prepare the text generation request by creating a dictionary with user prompt and generation parameters.


In [None]:
# Named `instances` to avoid shadowing the imported `requests` module.
instances = []
user_prompts = [
    "I have a table called orders with columns order_id (INT), customer_id (INT), order_date (DATE), and total_amount (DECIMAL). I need to find the total revenue generated in the month of October 2023. How can I write a SQL query to achieve this?",
    "# Context: You have a list of numbers called `my_numbers`.\n# Question: How do I calculate the sum of all the numbers in `my_numbers` using a built-in function?\n# Example `my_numbers` list:\nmy_numbers = [1, 2, 3, 4, 5]",
]

for prompt in user_prompts:
    instances.append(
        {
            "inputs": prompt,
            "parameters": {"max_new_tokens": 10, "temperature": 0.7, "do_sample": True},
        }
    )

prediction_request = {
    "instances": instances,
}

In [None]:
pp(prediction_request, indent=3)

#### Define an LLM-based router

To send each user request to the right adapter, define an LLM-based router that takes the user input and determines whether it is a SQL query generation task (SQL) or a code generation task (CODE).
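The router's contract is small: it returns a JSON object with a `classification` field, and anything unexpected falls back to `CODE`. The sketch below shows that contract in isolation, without calling Gemini; `parse_classification` and `VALID_LABELS` are hypothetical names introduced for illustration.

```python
import json

VALID_LABELS = {"SQL", "CODE"}


def parse_classification(raw_response: str, default: str = "CODE") -> str:
    """Extract the router label from a JSON string, falling back to `default`."""
    try:
        label = json.loads(raw_response)["classification"]
        return label if label in VALID_LABELS else default
    except (json.JSONDecodeError, KeyError, TypeError):
        # Malformed JSON, a missing key, or a payload of the wrong type.
        return default


print(parse_classification('{"classification": "SQL"}'))
print(parse_classification("not json"))
print(parse_classification('{"classification": "POEM"}'))
```

Defaulting to one adapter keeps the serving path alive when the router is unavailable, at the cost of occasionally sending a SQL prompt to the code adapter.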

In [None]:
router_model_id = "gemini-1.5-flash"
router_model = GenerativeModel(router_model_id)

In [None]:
def route_prompt(prompt, router_model):
    router_prompt = """Analyze the following prompt and determine if it's a SQL query generation task (SQL) or a code generation task (CODE).

    Guidelines:
    - SQL: Database queries, table operations, data retrieval, SQL explanations, database schema questions
    - CODE: Programming problems, algorithm implementations, software development tasks, code optimization

    Consider these examples:
    1. "Given a users table with columns (id, name, age), show me all users above 25" -> SQL
    2. "Write a function to reverse a linked list in Python" -> CODE
    3. "How can I join these two tables and filter by date?" -> SQL
    4. "Implement a binary search tree in Rust" -> CODE

    User prompt: "{prompt}"
    Response:
    """

    response_schema = {
        "type": "OBJECT",
        "properties": {
            "classification": {
                "type": "string",
                "enum": ["SQL", "CODE"],
            },
        },
    }

    try:
        response = router_model.generate_content(
            router_prompt.format(prompt=prompt),
            generation_config=GenerationConfig(
                response_mime_type="application/json", response_schema=response_schema
            ),
        ).text
        result = json.loads(response)
        return result
    except Exception:  # Fall back to the code adapter if routing fails
        return {"classification": "CODE"}

In [None]:
for instance in prediction_request["instances"]:
    prompt = instance["inputs"]
    print(prompt[:55], "-->", route_prompt(prompt, router_model))

#### Generate prediction

Process a list of prediction requests, each containing a text prompt.

> Note that, depending on your runtime, generating predictions may take approximately 10 minutes.
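Two steps in the generation loop are easy to miss: the router label is mapped to an adapter name, and the prompt tokens are sliced off the generated sequence, because `generate` on a decoder-only model returns the prompt followed by the continuation. A minimal sketch with plain lists standing in for tensors follows; `pick_adapter`, `strip_prompt`, and the token IDs are illustrative only.

```python
from typing import Dict, List

ADAPTER_BY_LABEL: Dict[str, str] = {"SQL": "sql_adapter", "CODE": "magicoder_adapter"}


def pick_adapter(label: str) -> str:
    """Map a router label to an adapter name, defaulting to the code adapter."""
    return ADAPTER_BY_LABEL.get(label, "magicoder_adapter")


def strip_prompt(generated_ids: List[int], prompt_len: int) -> List[int]:
    """Drop the echoed prompt tokens so only the new tokens remain."""
    return generated_ids[prompt_len:]


prompt_ids = [2, 106, 1645]  # stand-in token IDs for the prompt
generated = prompt_ids + [7, 8, 9]  # prompt followed by the continuation
new_tokens = strip_prompt(generated, len(prompt_ids))
print(pick_adapter("SQL"), new_tokens)
```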

In [None]:
predictions = []

for instance in prediction_request["instances"]:
    # Check the prompt
    if "inputs" not in instance:
        raise ValueError("The request body must contain the `inputs` key.")

    # Get the adapter label
    prompt = instance["inputs"]
    prompt_classification = route_prompt(prompt, router_model)["classification"]

    # Set the adapter model
    if prompt_classification == "SQL":
        model.set_adapter("sql_adapter")
    else:
        model.set_adapter("magicoder_adapter")

    # Prepare input
    messages = [{"role": "user", "content": prompt}]
    input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt").to(
        model.device
    )

    # Generate prediction
    input_len = input_ids.shape[-1]
    with torch.inference_mode():
        generation_config = instance.get(
            "parameters", {"max_new_tokens": 10, "temperature": 0.7, "do_sample": True}
        )
        generation = model.generate(
            input_ids=input_ids,
            generation_config=TGenerationConfig(**generation_config),
        )
        generation = generation[0][input_len:]
        response = tokenizer.decode(generation, skip_special_tokens=True).removeprefix(
            "model\n"
        )
        predictions.append(response)

Get the prediction.

In [None]:
for prediction in predictions:
    print("------- Prediction -------")
    print(prediction)
    print("--------------------------\n")

### Define the `handler.py` module

After testing the handler code, you assemble the code in a Python module which defines the custom inference handler for Gemma.

You write this handler code to a file named `handler.py` in the model directory, together with a `requirements.txt` file that lists the dependencies needed to run the handler.

In [None]:
serve_uri = epath.Path(BUCKET_URI) / "serve"
serve_uri.mkdir(parents=True, exist_ok=True)

In [None]:
handler_module = '''
from typing import Any, Dict, List
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation import GenerationConfig as TGenerationConfig
from vertexai.generative_models import GenerationConfig, GenerativeModel
import json
import os
import logging
import sys
from huggingface_inference_toolkit.logging import logger

def route_prompt(prompt, router_model):
    router_prompt = """Analyze the following prompt and determine if it's a SQL query generation task (SQL) or a code generation task (CODE).

    Guidelines:
    - SQL: Database queries, table operations, data retrieval, SQL explanations, database schema questions
    - CODE: Programming problems, algorithm implementations, software development tasks, code optimization

    Consider these examples:
    1. "Given a users table with columns (id, name, age), show me all users above 25" -> SQL
    2. "Write a function to reverse a linked list in Python" -> CODE
    3. "How can I join these two tables and filter by date?" -> SQL
    4. "Implement a binary search tree in Rust" -> CODE

    User prompt: "{prompt}"
    Response:
    """

    response_schema = {
    "type": "OBJECT",
    "properties": {
        "classification": {
            "type": "string",
            "enum": ["SQL", "CODE"],
        },
    }}

    try:
      response = router_model.generate_content(
        router_prompt.format(prompt=prompt),
        generation_config=GenerationConfig(
        response_mime_type="application/json", response_schema=response_schema
        ),
      ).text
      result = json.loads(response)
      return result
    except Exception:
      return {"classification": "CODE"}

class EndpointHandler:
    def __init__(
        self,
        model_dir: str = "google/gemma-2-2b-it",
        sql_adapter_id: str = "/tmp/model/google-cloud-partnership/gemma-2-2b-it-lora-sql",
        magicoder_adapter_id: str = "/tmp/model/google-cloud-partnership/gemma-2-2b-it-lora-magicoder",
        router_model_id: str = "gemini-1.5-flash",
        **kwargs: Any,
    ) -> None:

        # Stored as `self.tokenizer` because `__call__` refers to it by that name.
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, token=os.getenv("HF_TOKEN"))
        self.model = AutoModelForCausalLM.from_pretrained(
                    model_dir,
                    low_cpu_mem_usage=True,
                    torch_dtype=torch.float16,
                    device_map="auto",
                    token=os.getenv("HF_TOKEN")
                )

        self.model.load_adapter(sql_adapter_id, adapter_name="sql_adapter", is_trainable=False, device_map="auto", offload_folder="/tmp/offload")
        self.model.load_adapter(magicoder_adapter_id, adapter_name="magicoder_adapter", is_trainable=False, device_map="auto", offload_folder="/tmp/offload")
        self.router_model = GenerativeModel(router_model_id)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, List[Any]]:
        logger.info("Processing new request")
        predictions = []

        for instance in data["instances"]:
            # Check the prompt
            if "inputs" not in instance:
                raise ValueError("The request body must contain the `inputs` key.")

            # Get the adapter label
            logger.info(f'Getting the adapter label for the prompt: {instance["inputs"]}')
            prompt = instance["inputs"]
            prompt_classification = route_prompt(prompt, self.router_model)["classification"]

            # Set the adapter model
            logger.info(f'Setting the model to {prompt_classification} adapter')
            if prompt_classification == "SQL":
                self.model.set_adapter("sql_adapter")
            else:
                self.model.set_adapter("magicoder_adapter")

            # Prepare input
            logger.info('Preparing the input for the prompt')
            messages = [{"role": "user", "content": prompt}]
            input_ids = self.tokenizer.apply_chat_template(
                messages,
                return_tensors="pt"
            ).to(self.model.device)

            # Generate prediction
            logger.info('Generating the prediction')
            input_len = input_ids.shape[-1]
            with torch.inference_mode():
                generation_config = instance.get(
                    "parameters", {"temperature": 0.7, "do_sample": True}
                )
                generation = self.model.generate(
                    input_ids=input_ids,
                    generation_config=TGenerationConfig(**generation_config),
                )
                generation = generation[0][input_len:]
                response = self.tokenizer.decode(generation, skip_special_tokens=True)
                logger.info(f'Generated response: {response[:50]}...')
                predictions.append(response)

        logger.info(f"Successfully processed {len(predictions)} instances")
        return {"predictions": predictions}
'''

# The `with` block closes the file automatically.
with serve_uri.joinpath("handler.py").open("w") as f:
    f.write(handler_module)

### Provide requirements file

Save a `requirements.txt` file with handler's dependencies.

In [None]:
requirements_file = """
google-cloud-aiplatform
"""

# The `with` block closes the file automatically.
with serve_uri.joinpath("requirements.txt").open("w") as f:
    f.write(requirements_file)

### Copy the model with the custom handler to a Cloud Storage bucket

Efficiently upload the model directory to Google Cloud Storage using `gsutil`.

Note that `-m` enables multi-threaded uploads for faster transfer, especially for large directories. `-o GSUtil:parallel_composite_upload_threshold=150M` optimizes large file uploads by splitting them into smaller parts for parallel transfer, significantly speeding up the process for files larger than 150MB.
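The copy commands in this section place the base model at the root of `serve/` and each adapter in a subfolder named after its repository ID. The sketch below prints that mapping; the bucket name is a placeholder, and `/tmp/model` is an assumption taken from the handler's default adapter paths rather than from container documentation.

```python
import posixpath

# Hypothetical bucket; replace with your own BUCKET_URI.
SERVE_URI = "gs://my-bucket/serve"
ADAPTER_IDS = [
    "google-cloud-partnership/gemma-2-2b-it-lora-sql",
    "google-cloud-partnership/gemma-2-2b-it-lora-magicoder",
]
# Assumption: the container exposes the artifacts under this directory,
# as the handler's default adapter paths suggest.
CONTAINER_MODEL_DIR = "/tmp/model"

layout = {}
for adapter_id in ADAPTER_IDS:
    gcs_path = posixpath.join(SERVE_URI, adapter_id)
    container_path = posixpath.join(CONTAINER_MODEL_DIR, adapter_id)
    layout[gcs_path] = container_path
    print(f"{gcs_path} -> {container_path}")
```

Keeping the adapter subfolders identical on both sides is what lets the handler load the adapters from local paths without extra configuration.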

In [None]:
! gsutil -o GSUtil:parallel_composite_upload_threshold=150M -m cp -r {str(gemma_path)}/* {str(serve_uri)}
! gsutil -o GSUtil:parallel_composite_upload_threshold=150M -m cp -r {str(sql_adapter_path)}/* {str(serve_uri)}/{sql_adapter_id}
! gsutil -o GSUtil:parallel_composite_upload_threshold=150M -m cp -r {str(magicoder_adapter_path)}/* {str(serve_uri)}/{magicoder_adapter_id}

## Register Google Gemma on Vertex AI

To serve Gemma with PyTorch Inference on Vertex AI, you start by importing the model into Vertex AI Model Registry, a central repository where you can manage the lifecycle of your ML models on Vertex AI.

Before going into the code to upload or import a model on Vertex AI, let's quickly review the arguments provided to the `aiplatform.Model.upload` method:

* **`display_name`** is the name that will be shown in the Vertex AI Model Registry.

* **`serving_container_image_uri`** is the location of the Hugging Face DLC for PyTorch Inference that will be used for serving the model.

* (optional) **`serving_container_ports`** is the port where the Vertex AI endpoint will be exposed, by default 8080.

For more information on the supported `aiplatform.Model.upload` arguments, check [its Python reference](https://cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform.Model#google_cloud_aiplatform_Model_upload).

In [None]:
model = Model.upload(
    display_name="google--gemma2-tgi-multi-lora-model",
    artifact_uri=str(serve_uri),
    serving_container_image_uri="us-docker.pkg.dev/deeplearning-platform-release/gcr.io/huggingface-pytorch-inference-cu121.2-3.transformers.4-46.ubuntu2204.py311",
    serving_container_ports=[8080],
    serving_container_environment_variables={
        "HUGGING_FACE_HUB_TOKEN": get_token(),
    },
)
model.wait()

## Deploy Google Gemma on Vertex AI

After the model is registered on Vertex AI, you can deploy the model to an endpoint.

You need to first deploy a model to an endpoint before that model can be used to serve online predictions. Deploying a model associates physical resources with the model so it can serve online predictions with low latency.

Before going into the code to deploy a model to an endpoint, let's quickly review the arguments provided to the `aiplatform.Model.deploy` method:

- **`endpoint`** is the endpoint to deploy the model to, which is optional, and by default will be set to the model display name with the `_endpoint` suffix.
- **`machine_type`**, **`accelerator_type`** and **`accelerator_count`** are arguments that define which instance to use, and additionally, the accelerator to use and the number of accelerators, respectively. The `machine_type` and the `accelerator_type` are tied together, so you will need to select an instance that supports the accelerator that you are using and vice-versa. More information about the different instances at [Compute Engine Documentation - GPU machine types](https://cloud.google.com/compute/docs/gpus), and about the `accelerator_type` naming at [Vertex AI Documentation - MachineSpec](https://cloud.google.com/vertex-ai/docs/reference/rest/v1/MachineSpec).

For more information on the supported `aiplatform.Model.deploy` arguments, you can check [its Python reference](https://cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform.Model#google_cloud_aiplatform_Model_deploy).
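To sanity-check an accelerator choice, a back-of-the-envelope estimate of the weight memory can help. The sketch below assumes roughly 2.6 billion parameters for Gemma 2 2B and about 24 GB of memory on a single L4; it ignores the LoRA adapters, activations, and the KV cache, so treat the result as a lower bound.

```python
def weights_gib(num_params: float, bytes_per_param: int) -> float:
    """Approximate memory needed for model weights, in GiB."""
    return num_params * bytes_per_param / 1024**3


# Assumptions: ~2.6e9 parameters for Gemma 2 2B and ~24 GB on one L4.
GEMMA_2_2B_PARAMS = 2.6e9
L4_MEMORY_GIB = 24

fp16 = weights_gib(GEMMA_2_2B_PARAMS, bytes_per_param=2)
fp32 = weights_gib(GEMMA_2_2B_PARAMS, bytes_per_param=4)
print(f"float16: {fp16:.1f} GiB, float32: {fp32:.1f} GiB, L4: {L4_MEMORY_GIB} GiB")
```

Under these assumptions the half-precision weights use only a fraction of one L4, which leaves headroom for the adapters and generation-time buffers.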

In [None]:
deployed_model = model.deploy(
    endpoint=Endpoint.create(display_name="google--gemma2-tgi-multi-lora-endpoint"),
    machine_type="g2-standard-4",
    accelerator_type="NVIDIA_L4",
    accelerator_count=1,
)

> Note that the model deployment on Vertex AI can take around 15 to 25 minutes; most of the time being the allocation / reservation of the resources, setting up the network and security, and such.

## Online predictions on Vertex AI

Once the model is deployed on Vertex AI, you can run the online predictions using the `aiplatform.Endpoint.predict` method, which will send the requests to the running endpoint in the `/predict` route specified within the container following Vertex AI I/O payload formatting.
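For reference, the payload exchanged with the endpoint has a simple shape: the request carries a list under `instances`, and the response returns one entry per instance under `predictions`. The sketch below builds such a request as JSON; the prompt and the placeholder response text are illustrative, and the per-instance `inputs`/`parameters` keys are the ones this notebook's handler reads.

```python
import json

request_body = {
    "instances": [
        {
            "inputs": "Write a SQL query that counts rows in `orders`.",
            "parameters": {"max_new_tokens": 64, "temperature": 0.7, "do_sample": True},
        },
    ]
}
# Illustrative response; the real text depends on the model.
response_body = {"predictions": ["<generated text>"]}

wire = json.dumps(request_body)
print(wire)
```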

### Via Python

#### Within the same session

To run the online prediction via the Vertex AI SDK, you can simply use the `predict` method.

In [None]:
output = deployed_model.predict(instances=prediction_request["instances"])
for prediction in output.predictions:
    print("------- Prediction -------")
    print(prediction)
    print("--------------------------\n")

#### From a different session

To run the online prediction from a different session, you can run the following snippet.

In [None]:
import os

from google.cloud import aiplatform
import requests

PROJECT_ID = "[your-project-id]"  # @param {type:"string", isTemplate: true}
if PROJECT_ID == "[your-project-id]":
    PROJECT_ID = str(os.environ.get("GOOGLE_CLOUD_PROJECT"))

LOCATION = os.environ.get("GOOGLE_CLOUD_REGION", "us-central1")

aiplatform.init(project=PROJECT_ID, location=LOCATION)

ENDPOINT_DISPLAY_NAME = "google--gemma2-tgi-multi-lora-endpoint"

# Iterates over all the Vertex AI Endpoints within the current project and keeps the first match (if any), otherwise set to None
ENDPOINT_ID = next(
    (
        endpoint.name
        for endpoint in aiplatform.Endpoint.list()
        if endpoint.display_name == ENDPOINT_DISPLAY_NAME
    ),
    None,
)
assert ENDPOINT_ID, (
    "`ENDPOINT_ID` is not set, please make sure that `ENDPOINT_DISPLAY_NAME` is correct at "
    f"https://console.cloud.google.com/vertex-ai/online-prediction/endpoints?project={PROJECT_ID}"
)

# Instantiate the endpoint from its full resource name
endpoint = aiplatform.Endpoint(
    f"projects/{PROJECT_ID}/locations/{LOCATION}/endpoints/{ENDPOINT_ID}"
)

# Set instances
instances = []
user_prompts = [
    "I have a table called orders with columns order_id (INT), customer_id (INT), order_date (DATE), and total_amount (DECIMAL). I need to find the total revenue generated in the month of October 2023. How can I write a SQL query to achieve this?",
    "# Context: You have a list of numbers called `my_numbers`.\n# Question: How do I calculate the sum of all the numbers in `my_numbers` using a built-in function?\n# Example `my_numbers` list:\nmy_numbers = [1, 2, 3, 4, 5]",
]

for prompt in user_prompts:
    instances.append(
        {
            "inputs": prompt,
            "parameters": {"temperature": 0.7, "do_sample": True},
        }
    )

# Generate the prediction
output = endpoint.predict(instances=instances)
for prediction in output.predictions:
    print("------- Prediction -------")
    print(prediction)
    print("--------------------------\n")

### Via gcloud

You can also send the requests using the `gcloud` CLI via the `gcloud ai endpoints` command.

> Note that, before proceeding, you should either replace the values in the command below or set the following environment variables from the Python variables defined in the example:
>
> ```python
> import os
> os.environ["PROJECT_ID"] = PROJECT_ID
> os.environ["LOCATION"] = LOCATION
> os.environ["ENDPOINT_NAME"] = "google--gemma2-tgi-multi-lora-endpoint"
> ```

In [None]:
%%bash
# Get Endpoint ID
ENDPOINT_ID=$(gcloud ai endpoints list \
  --project=$PROJECT_ID \
  --region=$LOCATION \
  --filter="display_name=$ENDPOINT_NAME" \
  --format="value(name)" \
  | cut -d'/' -f6)

# Generate the prediction
echo '{
  "instances": [
   {
      "inputs":"I have a table called orders with columns order_id (INT), customer_id (INT), order_date (DATE), and total_amount (DECIMAL). I need to find the total revenue generated in the month of October 2023. How can I write a SQL query to achieve this?",
      "parameters":{
         "temperature":0.7,
         "do_sample":true
      }
   },
   {
      "inputs":"# Context: You have a list of numbers called `my_numbers`.\n# Question: How do I calculate the sum of all the numbers in `my_numbers` using a built-in function?\n# Example `my_numbers` list:\nmy_numbers = [1, 2, 3, 4, 5]",
      "parameters":{
         "temperature":0.7,
         "do_sample":true
      }
   }]
}' | gcloud ai endpoints predict $ENDPOINT_ID \
  --project=$PROJECT_ID \
  --region=$LOCATION \
  --json-request="-"
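
The `cut -d'/' -f6` step above relies on the shape of the endpoint resource name. The sketch below shows the same extraction in Python, assuming the full `projects/<project>/locations/<region>/endpoints/<id>` form implied by that command. The helper `endpoint_id_from_resource_name` and the sample values are hypothetical, for illustration only.

```python
def endpoint_id_from_resource_name(resource_name: str) -> str:
    """Hypothetical helper: return the trailing ID of an endpoint resource name."""
    parts = resource_name.split("/")
    # Expected layout: projects/<project>/locations/<region>/endpoints/<id>
    if len(parts) != 6 or parts[0::2] != ["projects", "locations", "endpoints"]:
        raise ValueError(f"Unexpected endpoint resource name: {resource_name!r}")
    # The sixth slash-separated field is the endpoint ID, matching `cut -f6`
    return parts[5]


# Placeholder resource name, not a real endpoint
sample_id = endpoint_id_from_resource_name(
    "projects/123456789/locations/us-central1/endpoints/987654321"
)
print(sample_id)
```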

### Via cURL

Alternatively, you can send the requests via `cURL`.

> Note that, before proceeding, you should either replace the values in the command below or set the following environment variables from the Python variables defined in the example:
>
> ```python
> import os
> os.environ["PROJECT_ID"] = PROJECT_ID
> os.environ["LOCATION"] = LOCATION
> os.environ["ENDPOINT_NAME"] = "google--gemma2-tgi-multi-lora-endpoint"
> ```

In [None]:
%%bash
# Get Endpoint ID
ENDPOINT_ID=$(gcloud ai endpoints list \
  --project=$PROJECT_ID \
  --region=$LOCATION \
  --filter="display_name=$ENDPOINT_NAME" \
  --format="value(name)" \
  | cut -d'/' -f6)

# Generate the prediction
curl -X POST \
    -H "Authorization: Bearer $(gcloud auth print-access-token)" \
    -H "Content-Type: application/json" \
    "https://${LOCATION}-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/${LOCATION}/endpoints/${ENDPOINT_ID}:predict" \
    -d @- <<'EOF'
{
    "instances": [
   {
      "inputs":"I have a table called orders with columns order_id (INT), customer_id (INT), order_date (DATE), and total_amount (DECIMAL). I need to find the total revenue generated in the month of October 2023. How can I write a SQL query to achieve this?",
      "parameters":{
         "temperature":0.7,
         "do_sample":true
      }
   },
   {
      "inputs":"# Context: You have a list of numbers called `my_numbers`.\n# Question: How do I calculate the sum of all the numbers in `my_numbers` using a built-in function?\n# Example `my_numbers` list:\nmy_numbers = [1, 2, 3, 4, 5]",
      "parameters":{
         "temperature":0.7,
         "do_sample":true
      }
   }
]
}
EOF
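
For reference, the REST URL used in the cURL call can be assembled from the same three values. This is a small sketch that mirrors the URL pattern shown above; the function name `build_predict_url` and the sample values are hypothetical.

```python
def build_predict_url(project_id: str, location: str, endpoint_id: str) -> str:
    """Hypothetical helper: build the regional `:predict` URL for an endpoint."""
    # The region appears twice: in the hostname and in the resource path
    host = f"https://{location}-aiplatform.googleapis.com"
    path = (
        f"/v1/projects/{project_id}/locations/{location}"
        f"/endpoints/{endpoint_id}:predict"
    )
    return host + path


# Placeholder values for illustration only
predict_url = build_predict_url("my-project", "us-central1", "1234567890")
print(predict_url)
```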

## Cleaning up

In [None]:
delete_endpoint = False
delete_model = False
delete_bucket = False

if delete_endpoint:
    # `deployed_model` is the endpoint returned by `model.deploy`
    deployed_model.undeploy_all()
    deployed_model.delete()

if delete_model:
    model.delete()

if delete_bucket:
    ! gsutil rm -r {BUCKET_URI}