In [None]:
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Vertex AI Model Garden - Gemma Finetuning

<table><tbody><tr>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fvertex-ai-samples%2Fmain%2Fnotebooks%2Fcommunity%2Fmodel_garden%2Fmodel_garden_gemma_finetuning_on_vertex.ipynb">
      <img alt="Google Cloud Colab Enterprise logo" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" width="32px"><br> Run in Colab Enterprise
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_gemma_finetuning_on_vertex.ipynb">
      <img alt="GitHub logo" src="https://cloud.google.com/ml-engine/images/github-logo-32px.png" width="32px"><br> View on GitHub
    </a>
  </td>
</tr></tbody></table>

## Overview

This notebook demonstrates finetuning and deploying Gemma models with [Vertex AI Custom Training Job](https://cloud.google.com/vertex-ai/docs/training/create-custom-job). All of the examples in this notebook use parameter efficient finetuning methods [PEFT (LoRA)](https://github.com/huggingface/peft) to reduce training and storage costs. LoRA (Low-Rank Adaptation) is one approach to Parameter-Efficient Fine-Tuning (PEFT), where pretrained model weights are frozen and rank decomposition matrices representing the change in model weights are trained during finetuning. Read more about LoRA in the following publication: [Hu, E.J., Shen, Y., Wallis, P., Allen-Zhu, Z., Li, Y., Wang, S., Wang, L. and Chen, W., 2021. LoRA: Low-rank adaptation of large language models. *arXiv preprint arXiv:2106.09685*](https://arxiv.org/abs/2106.09685).


After finetuning, you can deploy the models on Vertex AI with GPUs or TPUs.


### Objective

- Finetune and deploy Gemma models with Vertex AI Custom Training Jobs.
- Send prediction requests to your finetuned Gemma model.


### Costs

This tutorial uses billable components of Google Cloud:

* Vertex AI
* Cloud Storage

Learn about [Vertex AI pricing](https://cloud.google.com/vertex-ai/pricing), [Cloud Storage pricing](https://cloud.google.com/storage/pricing), and use the [Pricing Calculator](https://cloud.google.com/products/calculator/) to generate a cost estimate based on your projected usage.

## Before you begin

In [None]:
# @title Install Python Packages for Finetuning

# @markdown 1. Install google-cloud-aiplatform package and restart the session if instructed.
# Minimum Vertex AI SDK version this notebook is written against.
! pip install --upgrade --quiet 'google-cloud-aiplatform>=1.66.0'

# @markdown 2. Install packages to validate dataset with template.
# These are only needed locally, to load the tokenizer and validate the
# datasets against the template before launching the job (the finetuning
# itself runs inside the pre-built training container). Versions are pinned so
# the local validation is reproducible.
! pip install --upgrade --quiet accelerate==0.31.0
! pip install --upgrade --quiet transformers==4.43.1
! pip install --upgrade --quiet datasets==2.19.2

# Load the TensorBoard notebook extension (registers the `%tensorboard` magic).
%load_ext tensorboard

In [None]:
# @title Setup Google Cloud project

# @markdown 1. [Make sure that billing is enabled for your project](https://cloud.google.com/billing/docs/how-to/modify-project).

# @markdown 2. For finetuning, **[click here](https://console.cloud.google.com/iam-admin/quotas?location=us-central1&metric=aiplatform.googleapis.com%2Frestricted_image_training_nvidia_a100_80gb_gpus)** to check if your project already has the required 8 Nvidia A100 80 GB GPUs in the us-central1 region. If yes, then run this notebook in the us-central1 region. If you do not have 8 Nvidia A100 80 GB GPUs or have more GPU requirements than this, then schedule your job with Nvidia H100 GPUs via Dynamic Workload Scheduler using [these instructions](https://cloud.google.com/vertex-ai/docs/training/schedule-jobs-dws). For Dynamic Workload Scheduler, check the [us-central1](https://console.cloud.google.com/iam-admin/quotas?location=us-central1&metric=aiplatform.googleapis.com%2Fcustom_model_training_preemptible_nvidia_h100_gpus) or [europe-west4](https://console.cloud.google.com/iam-admin/quotas?location=europe-west4&metric=aiplatform.googleapis.com%2Fcustom_model_training_preemptible_nvidia_h100_gpus) quota for Nvidia H100 GPUs. If you do not have enough GPUs, then you can follow [these instructions](https://cloud.google.com/docs/quotas/view-manage#viewing_your_quota_console) to request quota.

# @markdown 3. For serving, **[click here](https://console.cloud.google.com/iam-admin/quotas?location=us-central1&metric=aiplatform.googleapis.com%2Fcustom_model_serving_nvidia_l4_gpus)** to check if your project already has the required 1 L4 GPU in the us-central1 region.  If yes, then run this notebook in the us-central1 region. If you need more L4 GPUs for your project, then you can follow [these instructions](https://cloud.google.com/docs/quotas/view-manage#viewing_your_quota_console) to request more. Alternatively, if you want to run predictions with A100 80GB or H100 GPUs, we recommend using the regions listed below. **NOTE:** Make sure you have associated quota in selected regions. Click the links to see your current quota for each GPU type: [Nvidia A100 80GB](https://console.cloud.google.com/iam-admin/quotas?metric=aiplatform.googleapis.com%2Fcustom_model_serving_nvidia_a100_80gb_gpus), [Nvidia H100 80GB](https://console.cloud.google.com/iam-admin/quotas?metric=aiplatform.googleapis.com%2Fcustom_model_serving_nvidia_h100_gpus).

# @markdown > | Machine Type | Accelerator Type | Recommended Regions |
# @markdown | ----------- | ----------- | ----------- |
# @markdown | a2-ultragpu-1g | 1 NVIDIA_A100_80GB | us-central1, us-east4, europe-west4, asia-southeast1 |
# @markdown | a3-highgpu-2g | 2 NVIDIA_H100_80GB | us-west1, asia-southeast1, europe-west4 |
# @markdown | a3-highgpu-4g | 4 NVIDIA_H100_80GB | us-west1, asia-southeast1, europe-west4 |
# @markdown | a3-highgpu-8g | 8 NVIDIA_H100_80GB | us-central1, us-east5, europe-west4, us-west1, asia-southeast1 |

# @markdown 4. **[Optional]** [Create a Cloud Storage bucket](https://cloud.google.com/storage/docs/creating-buckets) for storing experiment outputs. Set the BUCKET_URI for the experiment environment. The specified Cloud Storage bucket (`BUCKET_URI`) should be located in the same region as where the notebook was launched. Note that a multi-region bucket (eg. "us") is not considered a match for a single region covered by the multi-region range (eg. "us-central1"). If not set, a unique GCS bucket will be created instead.

BUCKET_URI = "gs://"  # @param {type:"string"}

# @markdown 5. **[Optional]** Set region. If not set, the region will be set automatically according to the Colab Enterprise environment.

REGION = ""  # @param {type:"string"}

# Clone the vertex-ai-samples repo and pin it to a fixed commit, so the
# notebook utilities imported from it below (common_util,
# dataset_validation_util) do not change between runs.
! rm -rf vertex-ai-samples && git clone https://github.com/GoogleCloudPlatform/vertex-ai-samples.git
! cd vertex-ai-samples && git reset --hard 0727e19520cf7957bceb701c248221bd3dbe4f1f

import datetime
import importlib
import os
import uuid
from typing import Tuple

from google.cloud import aiplatform
from google.cloud.aiplatform.compat.types import \
    custom_job as gca_custom_job_compat

common_util = importlib.import_module(
    "vertex-ai-samples.community-content.vertex_model_garden.model_oss.notebook_util.common_util"
)

models, endpoints = {}, {}

# Get the default cloud project id.
PROJECT_ID = os.environ["GOOGLE_CLOUD_PROJECT"]

# Get the default region for launching jobs.
if not REGION:
    REGION = os.environ["GOOGLE_CLOUD_REGION"]

# Enable the Vertex AI API and Compute Engine API, if not already.
print("Enabling Vertex AI API and Compute Engine API.")
! gcloud services enable aiplatform.googleapis.com compute.googleapis.com

# Cloud Storage bucket for storing the experiment artifacts.
# A unique GCS bucket will be created for the purpose of this notebook. If you
# prefer using your own GCS bucket, change the value yourself below.
now = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
BUCKET_NAME = "/".join(BUCKET_URI.split("/")[:3])

if BUCKET_URI is None or BUCKET_URI.strip() == "" or BUCKET_URI == "gs://":
    BUCKET_URI = f"gs://{PROJECT_ID}-tmp-{now}-{str(uuid.uuid4())[:4]}"
    BUCKET_NAME = "/".join(BUCKET_URI.split("/")[:3])
    ! gsutil mb -l {REGION} {BUCKET_URI}
else:
    assert BUCKET_URI.startswith("gs://"), "BUCKET_URI must start with `gs://`."
    shell_output = ! gsutil ls -Lb {BUCKET_NAME} | grep "Location constraint:" | sed "s/Location constraint://"
    bucket_region = shell_output[0].strip().lower()
    if bucket_region != REGION:
        raise ValueError(
            "Bucket region %s is different from notebook region %s"
            % (bucket_region, REGION)
        )
print(f"Using this GCS Bucket: {BUCKET_URI}")

STAGING_BUCKET = os.path.join(BUCKET_URI, "temporal")
MODEL_BUCKET = os.path.join(BUCKET_URI, "gemma")


# Initialize Vertex AI API.
print("Initializing Vertex AI API.")
aiplatform.init(project=PROJECT_ID, location=REGION, staging_bucket=STAGING_BUCKET)

# Gets the default SERVICE_ACCOUNT.
shell_output = ! gcloud projects describe $PROJECT_ID
project_number = shell_output[-1].split(":")[1].strip().replace("'", "")
SERVICE_ACCOUNT = f"{project_number}-compute@developer.gserviceaccount.com"
print("Using this default Service Account:", SERVICE_ACCOUNT)


# Provision permissions to the SERVICE_ACCOUNT with the GCS bucket
! gsutil iam ch serviceAccount:{SERVICE_ACCOUNT}:roles/storage.admin $BUCKET_NAME

! gcloud config set project $PROJECT_ID
! gcloud projects add-iam-policy-binding --no-user-output-enabled {PROJECT_ID} --member=serviceAccount:{SERVICE_ACCOUNT} --role="roles/storage.admin"
! gcloud projects add-iam-policy-binding --no-user-output-enabled {PROJECT_ID} --member=serviceAccount:{SERVICE_ACCOUNT} --role="roles/aiplatform.user"

# @markdown ## Access Gemma Models
# @markdown For GPU based finetuning and serving, choose between accessing Gemma models on [Hugging Face](https://huggingface.co/)
# @markdown or Vertex AI as described below.

# @markdown If you already obtained access to Gemma models on [Hugging Face](https://huggingface.co/), you can load models from there.
# @markdown Alternatively, you can also load the original Gemma models for finetuning and serving from Vertex AI after accepting the agreement.

# @markdown For TPU based finetuning and serving with KerasNLP, choose the Kaggle option.

# @markdown **Select and fill one of the three following sections.**
LOAD_MODEL_FROM = "Hugging Face"  # @param ["Hugging Face", "Google Cloud", "Kaggle"] {isTemplate:true}

# @markdown ---

# @markdown ### Access Gemma models on Hugging Face for GPU based finetuning and serving
# @markdown You must provide a Hugging Face User Access Token (read) to access the Gemma models. You can follow the [Hugging Face documentation](https://huggingface.co/docs/hub/en/security-tokens) to create a **read** access token and put it in the `HF_TOKEN` field below.

HF_TOKEN = ""  # @param {type:"string", isTemplate:true}
if LOAD_MODEL_FROM == "Hugging Face":
    assert (
        HF_TOKEN
    ), "Provide a read HF_TOKEN to load models from Hugging Face, or select a different model source."

# @markdown *--- Or ---*
# @markdown ### Access Gemma models on Vertex AI for GPU based finetuning and serving
# @markdown Accept the model agreement to access the models:
# @markdown 1. Open the [Gemma model card](https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/335) from [Vertex AI Model Garden](https://cloud.google.com/model-garden).
# @markdown 2. Review the agreement on the model card page.
# @markdown 3. After accepting the agreement of Gemma, a `https://` link containing Gemma pretrained and finetuned models will be shared.
# @markdown 4. Paste the link in the `VERTEX_MODEL_GARDEN_GEMMA` field below.
# @markdown **Note:** This will unzip and copy the Gemma model artifacts to your Cloud Storage bucket, which will take around 1 hour.

VERTEX_AI_MODEL_GARDEN_GEMMA = ""  # @param {type:"string", isTemplate:true}

if LOAD_MODEL_FROM == "Google Cloud":
    assert (
        VERTEX_AI_MODEL_GARDEN_GEMMA
    ), "Accept the agreement of Gemma in Vertex AI Model Garden and get the URL to Gemma model artifacts, or select a different model source."

    # Only use the last part in case a full command is pasted.
    signed_url = VERTEX_AI_MODEL_GARDEN_GEMMA.split(" ")[-1].strip('"')

    ! mkdir -p ./gemma
    ! curl -X GET "{signed_url}" | tar -xzvf - -C ./gemma/
    ! gsutil -m cp -R ./gemma/* {MODEL_BUCKET}

    model_path_prefix = MODEL_BUCKET
    HF_TOKEN = ""
else:
    model_path_prefix = "google/"

conversion_job = None

# @markdown *--- Or ---*
# @markdown ### Access Gemma models from Kaggle for TPU based finetuning and serving
# @markdown Kaggle credentials are required for KerasNLP training and Hex-LLM deployment with TPUs.
# @markdown Generate the Kaggle username and key by following [these instructions](https://github.com/Kaggle/kaggle-api?tab=readme-ov-file#api-credentials).
# @markdown You will need to review and accept the model license.
KAGGLE_USERNAME = ""  # @param {type:"string", isTemplate:true}
KAGGLE_KEY = ""  # @param {type:"string", isTemplate:true}
if LOAD_MODEL_FROM == "Kaggle":
    assert (
        KAGGLE_USERNAME and KAGGLE_KEY
    ), "Provide Kaggle credentials to load models from Kaggle, or select a different model source."
# @markdown ---

## Finetune with HuggingFace PEFT and Deploy with vLLM on GPUs

In [None]:
# @title Set dataset

# @markdown Use the Vertex AI SDK to create and run the custom training jobs.

# @markdown This notebook uses [timdettmers/openassistant-guanaco](https://huggingface.co/datasets/timdettmers/openassistant-guanaco) dataset as an example.
# @markdown You can set `train_dataset_name` and `eval_dataset_name` to any existing [Hugging Face dataset](https://huggingface.co/datasets) name, and set `instruct_column_in_dataset` to the name of the dataset column containing training data. The [timdettmers/openassistant-guanaco](https://huggingface.co/datasets/timdettmers/openassistant-guanaco) dataset has only one column, `text`, and therefore we set `instruct_column_in_dataset` to `text` in this notebook.

# @markdown ### (Optional) Prepare a custom JSONL dataset for finetuning

# @markdown You can prepare a JSONL file where each line is a valid JSON string as your custom training dataset. For example, here is one line from the [timdettmers/openassistant-guanaco](https://huggingface.co/datasets/timdettmers/openassistant-guanaco) dataset:
# @markdown ```
# @markdown {"text": "### Human: Hola### Assistant: \u00a1Hola! \u00bfEn qu\u00e9 puedo ayudarte hoy?"}
# @markdown ```

# @markdown The JSON object has a key `text`, which should match `instruct_column_in_dataset`; the value should be one training data point, i.e. a string. After you have prepared your JSONL file, you can either upload it to [Hugging Face datasets](https://huggingface.co/datasets) or [Google Cloud Storage](https://cloud.google.com/storage).

# @markdown - To upload a JSONL dataset to [Hugging Face datasets](https://huggingface.co/datasets), follow the instructions on [Uploading Datasets](https://huggingface.co/docs/hub/en/datasets-adding). Then, set `train_dataset_name` (and `eval_dataset_name`) to the name of your newly created dataset on Hugging Face.

# @markdown - To upload a JSONL dataset to [Google Cloud Storage](https://cloud.google.com/storage), follow the instructions on [Upload objects from a filesystem](https://cloud.google.com/storage/docs/uploading-objects). Then, set `train_dataset_name` (and `eval_dataset_name`) to the `gs://` URI to your JSONL file. For example: `gs://cloud-samples-data/vertex-ai/model-evaluation/peft_train_sample.jsonl`.

# @markdown Optionally update the `instruct_column_in_dataset` field below if your JSON objects use a key other than the default `text`.

# @markdown ### (Optional) Format your data with custom JSON template

# @markdown Sometimes, your dataset might have multiple text columns and you want to construct the training data with a template. You can prepare a JSON template in the following format:

# @markdown ```
# @markdown {
# @markdown   "description": "Template that accepts text-bison format.",
# @markdown   "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/models/tune-text-models-supervised#dataset-format",
# @markdown   "prompt_input": "\n\n<|start_header_id|>user<|end_header_id|>\n\n{input_text}<|eot_id|>\n\n<|start_header_id|>assistant<|end_header_id|>\n\n{output_text}<|eot_id|>",
# @markdown   "instruction_separator": "<|start_header_id|>user<|end_header_id|>\n\n",
# @markdown   "response_separator": "<|start_header_id|>assistant<|end_header_id|>\n\n"
# @markdown }
# @markdown ```

# @markdown As an example, the template above can be used to format the following training data (this line comes from `gs://cloud-samples-data/vertex-ai/model-evaluation/peft_train_sample.jsonl`):

# @markdown ```
# @markdown {"input_text":"TRANSCRIPT: \nREASON FOR EVALUATION:,\n\n LABEL:","output_text":"Chiropractic"}
# @markdown ```

# @markdown This example template simply concatenates `input_text` with `output_text` with some special tokens in between.
# @markdown
# @markdown To try such custom dataset, you can make the following changes:
# @markdown 1. Set `template` to `llama3-text-bison`
# @markdown 1. Set `train_dataset_name` to `gs://cloud-samples-data/vertex-ai/model-evaluation/peft_train_sample.jsonl`
# @markdown 1. Set `train_split_name` to `train`
# @markdown 1. Set `eval_dataset_name` to `gs://cloud-samples-data/vertex-ai/model-evaluation/peft_eval_sample.jsonl`
# @markdown 1. Set `eval_split_name` to `train` (**NOT** `test`)
# @markdown 1. Set `instruct_column_in_dataset` as `input_text`.

# NOTE(review): the example template above uses Llama 3 chat tokens
# (`<|start_header_id|>`, `<|eot_id|>`); confirm that is intended for Gemma,
# whose chat format uses different markers.

# Template name or gs:// URI to a custom template.
template = "openassistant-guanaco"  # @param {type:"string"}

# Hugging Face dataset name or gs:// URI to a custom JSONL dataset.
train_dataset_name = "timdettmers/openassistant-guanaco"  # @param {type:"string"}
train_split_name = "train"  # @param {type:"string"}
eval_dataset_name = "timdettmers/openassistant-guanaco"  # @param {type:"string"}
eval_split_name = "test"  # @param {type:"string"}

# Name of the dataset column containing training text input.
instruct_column_in_dataset = "text"  # @param {type:"string"}

In [None]:
# @title Set model

# @markdown Select a Gemma model variant.
base_model_id = "gemma-2b"  # @param["gemma-2b", "gemma-2b-it", "gemma-7b", "gemma-7b-it", "gemma-1.1-2b-it", "gemma-1.1-7b-it"] {isTemplate:true}
# Either a Hugging Face repo id (e.g. `google/gemma-2b`) or, when the models
# were copied from Vertex AI Model Garden, a path under MODEL_BUCKET in GCS.
pretrained_model_id = os.path.join(model_path_prefix, base_model_id)

In [None]:
# @title Validate Dataset with Template

# @markdown This section validates the train and eval datasets with the template before starting the fine tuning process.

import transformers

dataset_validation_util = importlib.import_module(
    "vertex-ai-samples.community-content.vertex_model_garden.model_oss.notebook_util.dataset_validation_util"
)

if dataset_validation_util.is_gcs_path(pretrained_model_id):
    # Download tokenizer.
    ! mkdir tokenizer
    ! gsutil cp {pretrained_model_id}/tokenizer.json ./tokenizer
    ! gsutil cp {pretrained_model_id}/config.json ./tokenizer
    tokenizer_path = "./tokenizer"
    access_token = ""
else:
    tokenizer_path = pretrained_model_id
    access_token = HF_TOKEN

tokenizer = transformers.AutoTokenizer.from_pretrained(
    tokenizer_path,
    trust_remote_code=False,
    use_fast=True,
    token=access_token,
)

# Validate the train dataset.
dataset_validation_util.validate_dataset_with_template(
    dataset_name=train_dataset_name,
    split=train_split_name,
    input_column=instruct_column_in_dataset,
    template=template,
    use_multiprocessing=False,
    tokenizer=tokenizer,
)

# Validate the eval dataset.
dataset_validation_util.validate_dataset_with_template(
    dataset_name=eval_dataset_name,
    split=eval_split_name,
    input_column=instruct_column_in_dataset,
    template=template,
    use_multiprocessing=False,
    tokenizer=tokenizer,
)

In [None]:
# @title Finetune
# @markdown This section demonstrates how to finetune the Gemma model and merge the finetuned LoRA adapter with the base model on Vertex AI. It uses the Vertex AI SDK to create and run the custom training jobs.

# @markdown The training job takes approximately 10 to 20 minutes to set up. Once done, the training job is expected to take around 20 minutes with the default configuration. To find the training time, throughput, and memory usage of your training job, you can go to the training logs and check the log line of the last training epoch.

# @markdown **Note**:
# @markdown 1. We recommend setting `finetuning_precision_mode` to `4bit` because it enables using fewer hardware resources for finetuning.
# @markdown 1. If `max_steps > 0`, it takes precedence over `num_epochs`. One can set a small `max_steps` value to quickly check the pipeline.

if LOAD_MODEL_FROM == "Kaggle":
    print(
        "Skipped: Cannot load model from Kaggle, which is only supported in the KerasNLP section."
    )
else:
    # @markdown Accelerator type to use for training.
    accelerator_type = "NVIDIA_A100_80GB"  # @param ["NVIDIA_A100_80GB", "NVIDIA_H100_80GB"]

    # All per-accelerator settings in one place: the docker image repo, the
    # quota flavor checked below, scheduling options, and the worker pool shape.
    if accelerator_type == "NVIDIA_A100_80GB":
        # A100s use the restricted image repo with regular scheduling.
        repo = "us-docker.pkg.dev/vertex-ai-restricted"
        is_restricted_image = True
        is_dynamic_workload_scheduler = False
        dws_kwargs = {}
        per_node_accelerator_count = 8
        machine_type = "a2-ultragpu-8g"
    elif accelerator_type == "NVIDIA_H100_80GB":
        # H100s are scheduled via Dynamic Workload Scheduler (FLEX_START).
        repo = "us-docker.pkg.dev/vertex-ai"
        is_restricted_image = False
        is_dynamic_workload_scheduler = True
        dws_kwargs = {
            "max_wait_duration": 1800,  # 30 minutes
            "scheduling_strategy": gca_custom_job_compat.Scheduling.Strategy.FLEX_START,
        }
        per_node_accelerator_count = 8
        machine_type = "a3-highgpu-8g"
    else:
        raise ValueError(
            f"Recommended machine settings not found for: {accelerator_type}. To use another accelerator type, edit this code block to pass in an appropriate `machine_type`, `accelerator_type`, and `per_node_accelerator_count` to the `train_job.run` call by clicking `Show Code` and then modifying the code."
        )

    # The pre-built training docker image.
    TRAIN_DOCKER_URI = (
        f"{repo}/vertex-vision-model-garden-dockers/pytorch-peft-train:stable_20240909"
    )

    # @markdown Batch size for finetuning.
    per_device_train_batch_size = 1  # @param{type:"integer"}
    # @markdown Number of update steps to accumulate the gradients for, before performing a backward/update pass.
    gradient_accumulation_steps = 4  # @param{type:"integer"}
    # @markdown Maximum sequence length.
    max_seq_length = 4096  # @param{type:"integer"}
    # @markdown Setting a positive `max_steps` here will override `num_epochs`.
    max_steps = -1  # @param{type:"integer"}
    num_epochs = 1.0  # @param{type:"number"}
    # @markdown Precision mode for finetuning.
    finetuning_precision_mode = "4bit"  # @param ["4bit", "8bit", "float16"]
    # @markdown Learning rate.
    learning_rate = 5e-5  # @param{type:"number"}
    # @markdown The scheduler type to use.
    lr_scheduler_type = "cosine"  # @param{type:"string"}
    # @markdown LoRA parameters.
    lora_rank = 16  # @param{type:"integer"}
    lora_alpha = 32  # @param{type:"integer"}
    lora_dropout = 0.05  # @param{type:"number"}
    # Activates gradient checkpointing for the current model (may be referred to as activation checkpointing or checkpoint activations in other frameworks).
    enable_gradient_checkpointing = True
    # Attention implementation to use in the model.
    attn_implementation = "eager"
    # The optimizer for which to schedule the learning rate.
    optimizer = "paged_adamw_32bit"
    # Proportion of training dedicated to a linear warmup, during which the
    # learning rate gradually increases. A number, like the other numeric
    # hyperparameters; it renders as "0.01" on the command line.
    warmup_ratio = 0.01
    # The list or string of integrations to report the results and logs to.
    report_to = "tensorboard"
    # Number of update steps between two checkpoint saves.
    save_steps = 10
    # Number of update steps between two logs.
    logging_steps = save_steps
    # Train precision of the model.
    train_precision = "bfloat16"

    replica_count = 1

    common_util.check_quota(
        project_id=PROJECT_ID,
        region=REGION,
        accelerator_type=accelerator_type,
        accelerator_count=per_node_accelerator_count * replica_count,
        is_for_training=True,
        is_restricted_image=is_restricted_image,
        is_dynamic_workload_scheduler=is_dynamic_workload_scheduler,
    )

    job_name = common_util.get_job_name_with_datetime("gemma-lora-train")

    base_output_dir = os.path.join(STAGING_BUCKET, job_name)
    # Create a GCS folder to store the LORA adapter.
    lora_output_dir = os.path.join(base_output_dir, "adapter")
    # Create a GCS folder to store the merged model with the base model and the
    # finetuned LORA adapter.
    merged_model_output_dir = os.path.join(base_output_dir, "merged-model")

    # Evaluation runs every `save_steps` steps on the eval split, reporting loss.
    eval_args = [
        f"--eval_dataset_path={eval_dataset_name}",
        f"--eval_column={instruct_column_in_dataset}",
        f"--eval_template={template}",
        f"--eval_split={eval_split_name}",
        f"--eval_steps={save_steps}",
        "--eval_tasks=builtin_eval",
        "--eval_metric_name=loss",
    ]

    # NOTE: the DeepSpeed config below is named for 8 GPUs, matching
    # `per_node_accelerator_count` for both supported accelerator types.
    train_job_args = [
        "--config_file=vertex_vision_model_garden_peft/deepspeed_zero2_8gpu.yaml",
        "--task=instruct-lora",
        "--completion_only=True",
        f"--pretrained_model_id={pretrained_model_id}",
        f"--dataset_name={train_dataset_name}",
        f"--train_split_name={train_split_name}",
        f"--instruct_column_in_dataset={instruct_column_in_dataset}",
        f"--output_dir={lora_output_dir}",
        f"--merge_base_and_lora_output_dir={merged_model_output_dir}",
        f"--per_device_train_batch_size={per_device_train_batch_size}",
        f"--gradient_accumulation_steps={gradient_accumulation_steps}",
        f"--lora_rank={lora_rank}",
        f"--lora_alpha={lora_alpha}",
        f"--lora_dropout={lora_dropout}",
        f"--max_steps={max_steps}",
        f"--max_seq_length={max_seq_length}",
        f"--learning_rate={learning_rate}",
        f"--lr_scheduler_type={lr_scheduler_type}",
        f"--precision_mode={finetuning_precision_mode}",
        f"--train_precision={train_precision}",
        f"--enable_gradient_checkpointing={enable_gradient_checkpointing}",
        f"--num_epochs={num_epochs}",
        f"--attn_implementation={attn_implementation}",
        f"--optimizer={optimizer}",
        f"--warmup_ratio={warmup_ratio}",
        f"--report_to={report_to}",
        f"--logging_output_dir={base_output_dir}",
        f"--save_steps={save_steps}",
        f"--logging_steps={logging_steps}",
        f"--template={template}",
        f"--huggingface_access_token={HF_TOKEN}",
    ] + eval_args

    # Pass training arguments and launch job.
    train_job = aiplatform.CustomContainerTrainingJob(
        display_name=job_name,
        container_uri=TRAIN_DOCKER_URI,
    )

    # Never echo the Hugging Face token into the notebook output (outputs are
    # often saved and shared). NOTE: the token is still part of the job's
    # container args, which are visible to anyone who can view the job.
    printable_train_job_args = [
        "--huggingface_access_token=<redacted>"
        if HF_TOKEN and arg.startswith("--huggingface_access_token=")
        else arg
        for arg in train_job_args
    ]
    print("Running training job with args:")
    print(" \\\n".join(printable_train_job_args))
    train_job.run(
        args=train_job_args,
        replica_count=replica_count,
        machine_type=machine_type,
        accelerator_type=accelerator_type,
        accelerator_count=per_node_accelerator_count,
        boot_disk_size_gb=500,
        service_account=SERVICE_ACCOUNT,
        base_output_dir=base_output_dir,
        sync=False,  # Non-blocking call to run.
        **dws_kwargs,
    )

    # Wait until resource has been created.
    train_job.wait_for_resource_creation()

    print("LoRA adapter will be saved in:", lora_output_dir)
    print("Trained and merged models will be saved in:", merged_model_output_dir)

# @markdown Click "Show Code" to see more details.

In [None]:
# @title Run TensorBoard
# @markdown This section shows how to launch TensorBoard in a [Cloud Shell](https://cloud.google.com/shell/docs).
# @markdown 1. Click the Cloud Shell icon (![terminal](https://github.com/google/material-design-icons/blob/master/png/action/terminal/materialicons/24dp/1x/baseline_terminal_black_24dp.png?raw=true)) on the top right to open the Cloud Shell.
# @markdown 2. Copy the `tensorboard` command shown below by running this cell.
# @markdown 3. Paste and run the command in the Cloud Shell to launch TensorBoard.
# @markdown 4. Once the command runs (you may have to click `Authorize` if prompted), click the link starting with `http://localhost`.

# @markdown Note: You may need to wait around 10 minutes after the job starts in order for the TensorBoard logs to be written to the GCS bucket.
if LOAD_MODEL_FROM == "Kaggle":
    # The Finetune section skips Kaggle models, so `base_output_dir` is never
    # defined in that mode; referencing it would raise NameError on Run All.
    print("Skipped: No PEFT training job was launched for models loaded from Kaggle.")
else:
    # The Finetune section passes `base_output_dir` as `--logging_output_dir`;
    # the trainer's TensorBoard logs are expected under its `logs/` subfolder.
    print(f"Command to copy: tensorboard --logdir {base_output_dir}/logs")

In [None]:
# @title Deploy

# @markdown This section uploads the model to Model Registry and deploys it on the Endpoint. It takes 15 minutes to 1 hour to finish.

# The pre-built serving docker image for vLLM, pinned to a specific build.
# NOTE(review): presumably used as the serving container by deploy_model_vllm
# below; its usage is outside the visible part of this cell.
VLLM_DOCKER_URI = "us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/pytorch-vllm-serve:20240815_1634_RC00"

# @markdown Set `use_dedicated_endpoint` to False if you don't want to use a [dedicated endpoint](https://cloud.google.com/vertex-ai/docs/general/deployment#create-dedicated-endpoint).
use_dedicated_endpoint = True  # @param {type:"boolean"}


def deploy_model_vllm(
    model_name: str,
    model_id: str,
    service_account: str,
    base_model_id: str = None,
    machine_type: str = "g2-standard-8",
    accelerator_type: str = "NVIDIA_L4",
    accelerator_count: int = 1,
    gpu_memory_utilization: float = 0.9,
    max_model_len: int = 4096,
    dtype: str = "auto",
    enable_trust_remote_code: bool = False,
    enforce_eager: bool = False,
    enable_lora: bool = False,
    max_loras: int = 1,
    max_cpu_loras: int = 8,
    use_dedicated_endpoint: bool = False,
    max_num_seqs: int = 256,
    model_type: str = None,
) -> Tuple[aiplatform.Model, aiplatform.Endpoint]:
    """Deploys trained models with vLLM into Vertex AI.

    Creates a new endpoint and registers a model in the Model Registry that
    serves ``model_id`` with the vLLM image (``VLLM_DOCKER_URI``). It then
    deploys that model to the endpoint on ``accelerator_count`` GPUs of
    ``machine_type``.

    Args:
        model_name: Display name of the uploaded model. The endpoint is named
            ``f"{model_name}-endpoint"``.
        model_id: Weights location handed to vLLM as ``--model``.
        service_account: Service account used by the deployment.
        base_model_id: Exported as the ``MODEL_ID`` environment variable.
            Falls back to ``model_id`` when empty.
        accelerator_count: GPUs per replica. Also used as vLLM's
            ``--tensor-parallel-size``.
        enable_trust_remote_code, enforce_eager, enable_lora: Each one appends
            the matching vLLM flag when truthy.
        model_type: Forwarded as ``--model-type`` when set.
        The remaining arguments map one-to-one onto vLLM server flags or
        ``Model.deploy`` arguments.

    Returns:
        ``(model, endpoint)``: the uploaded model and the endpoint serving it.
    """
    endpoint = aiplatform.Endpoint.create(
        display_name=f"{model_name}-endpoint",
        dedicated_endpoint_enabled=use_dedicated_endpoint,
    )

    base_model_id = base_model_id or model_id

    # Server launch command. See https://docs.vllm.ai/en/latest/models/engine_args.html
    # for the meaning of each engine argument.
    launcher = ["python", "-m", "vllm.entrypoints.api_server"]
    server_flags = [
        "--host=0.0.0.0",
        "--port=8080",
        f"--model={model_id}",
        f"--tensor-parallel-size={accelerator_count}",
        "--swap-space=16",
        f"--gpu-memory-utilization={gpu_memory_utilization}",
        f"--max-model-len={max_model_len}",
        f"--dtype={dtype}",
        f"--max-loras={max_loras}",
        f"--max-cpu-loras={max_cpu_loras}",
        f"--max-num-seqs={max_num_seqs}",
        "--disable-log-stats",
    ]
    # Boolean switches, appended in a fixed order when enabled.
    toggles = (
        (enable_trust_remote_code, "--trust-remote-code"),
        (enforce_eager, "--enforce-eager"),
        (enable_lora, "--enable-lora"),
    )
    server_flags += [flag for enabled, flag in toggles if enabled]
    if model_type:
        server_flags.append(f"--model-type={model_type}")
    vllm_args = launcher + server_flags

    env_vars = {"MODEL_ID": base_model_id, "DEPLOY_SOURCE": "notebook"}

    # HF_TOKEN is optional. It is only forwarded when an earlier cell defined it
    # with a non-empty value.
    try:
        hf_token = HF_TOKEN
    except NameError:
        hf_token = None
    if hf_token:
        env_vars["HF_TOKEN"] = hf_token

    model = aiplatform.Model.upload(
        display_name=model_name,
        serving_container_image_uri=VLLM_DOCKER_URI,
        serving_container_args=vllm_args,
        serving_container_ports=[8080],
        serving_container_predict_route="/generate",
        serving_container_health_route="/ping",
        serving_container_environment_variables=env_vars,
        serving_container_shared_memory_size_mb=16 * 1024,  # 16 GB
        serving_container_deployment_timeout=7200,
    )

    print(
        f"Deploying {model_name} on {machine_type} with {accelerator_count} {accelerator_type} GPU(s)."
    )
    model.deploy(
        endpoint=endpoint,
        machine_type=machine_type,
        accelerator_type=accelerator_type,
        accelerator_count=accelerator_count,
        deploy_request_timeout=1800,
        service_account=service_account,
    )
    print("endpoint_name:", endpoint.name)

    return model, endpoint


if LOAD_MODEL_FROM != "Kaggle":
    # Deployment needs the merged model, so block until finetuning has completed.
    if train_job.end_time is None:
        print("Waiting for the training job to finish...")
        train_job.wait()
        print("The training job has finished.")

    print("Deploying models in: ", merged_model_output_dir)

    # Vertex AI prediction accelerators and regions are listed at
    # https://cloud.google.com/vertex-ai/docs/predictions/configure-compute.
    # A single L4 GPU (24G) serves the Gemma model.
    machine_type, accelerator_type, accelerator_count = (
        "g2-standard-12",
        "NVIDIA_L4",
        1,
    )

    common_util.check_quota(
        project_id=PROJECT_ID,
        region=REGION,
        accelerator_type=accelerator_type,
        accelerator_count=accelerator_count,
        is_for_training=False,
    )

    # A larger max_model_len needs more GPU memory.
    max_model_len = 2048

    serving_model_name = common_util.get_job_name_with_datetime(
        prefix="gemma-vllm-serve"
    )
    models["vllm_gpu"], endpoints["vllm_gpu"] = deploy_model_vllm(
        model_name=serving_model_name,
        model_id=merged_model_output_dir,
        base_model_id=f"google/{base_model_id}",
        service_account=SERVICE_ACCOUNT,
        machine_type=machine_type,
        accelerator_type=accelerator_type,
        accelerator_count=accelerator_count,
        max_model_len=max_model_len,
        use_dedicated_endpoint=use_dedicated_endpoint,
    )
    print("endpoint_name:", endpoints["vllm_gpu"].name)
else:
    print(
        "Skipped: Cannot load model from Kaggle, which is only supported in the KerasNLP section."
    )

In [None]:
# @title Predict

# @markdown Once deployment succeeds, you can send requests to the endpoint with text prompts.

# @markdown Here we use an example from the [timdettmers/openassistant-guanaco](https://huggingface.co/datasets/timdettmers/openassistant-guanaco) to show the finetuning outcome:

# @markdown ```
# @markdown ### Human: How would the Future of AI in 10 Years look?### Assistant: Predicting the future is always a challenging task, but here are some possible ways that AI could evolve over the next 10 years: Continued advancements in deep learning: Deep learning has been one of the main drivers of recent AI breakthroughs, and we can expect continued advancements in this area. This may include improvements to existing algorithms, as well as the development of new architectures that are better suited to specific types of data and tasks. Increased use of AI in healthcare: AI has the potential to revolutionize healthcare, by improving the accuracy of diagnoses, developing new treatments, and personalizing patient care. We can expect to see continued investment in this area, with more healthcare providers and researchers using AI to improve patient outcomes. Greater automation in the workplace: Automation is already transforming many industries, and AI is likely to play an increasingly important role in this process. We can expect to see more jobs being automated, as well as the development of new types of jobs that require a combination of human and machine skills. More natural and intuitive interactions with technology: As AI becomes more advanced, we can expect to see more natural and intuitive ways of interacting with technology. This may include voice and gesture recognition, as well as more sophisticated chatbots and virtual assistants. Increased focus on ethical considerations: As AI becomes more powerful, there will be a growing need to consider its ethical implications. This may include issues such as bias in AI algorithms, the impact of automation on employment, and the use of AI in surveillance and policing. Overall, the future of AI in 10 years is likely to be shaped by a combination of technological advancements, societal changes, and ethical considerations.
# @markdown While there are many exciting possibilities for AI in the future, it will be important to carefully consider its potential impact on society and to work towards ensuring that its benefits are shared fairly and equitably.
# @markdown ```

# @markdown Additionally, you can moderate the generated text with Vertex AI. See [Moderate text documentation](https://cloud.google.com/natural-language/docs/moderating-text) for more details.

# To send requests to an endpoint deployed earlier (for example after a kernel
# restart), uncomment the code below and set `endpoint_name` to the endpoint ID,
# e.g. "1234567890123456789". The `endpoint_name:` line printed by the Deploy
# cell shows the ID of the endpoint it created. The request below is sent to
# `endpoints["vllm_gpu"]`, so the loaded endpoint is stored under that key.

# endpoint_name = ""  # @param {type:"string"}
# aip_endpoint_name = (
#     f"projects/{PROJECT_ID}/locations/{REGION}/endpoints/{endpoint_name}"
# )
# endpoints["vllm_gpu"] = aiplatform.Endpoint(aip_endpoint_name)

if LOAD_MODEL_FROM == "Kaggle":
    print(
        "Skipped: Cannot load model from Kaggle, which is only supported in the KerasNLP section."
    )
else:
    prompt = "How would the Future of AI in 10 Years look?"  # @param {type: "string"}
    max_tokens = 128  # @param {type:"integer"}
    temperature = 1.0  # @param {type:"number"}
    top_p = 0.9  # @param {type:"number"}
    top_k = 1  # @param {type:"integer"}

    # Overrides max_tokens and top_k parameters during inferences.
    # If you encounter the issue like `ServiceUnavailable: 503 Took too long to respond when processing`,
    # you can reduce the max length, such as set max_tokens as 20.
    instances = [
        {
            # Wrap the prompt in the "### Human: ... ### Assistant:" format of
            # the guanaco example shown above.
            "prompt": f"### Human: {prompt}### Assistant: ",
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
        },
    ]
    response = endpoints["vllm_gpu"].predict(
        instances=instances, use_dedicated_endpoint=use_dedicated_endpoint
    )

    for prediction in response.predictions:
        print(prediction)

## Finetune with KerasNLP PEFT and Deploy with HexLLM on TPUs

In [None]:
# @title Set Dataset

# @markdown For the training dataset, you can
# @markdown  - either use a pre-built TensorFlow
# @markdown    dataset such as this [imdb_reviews](https://www.tensorflow.org/datasets/catalog/imdb_reviews) dataset
# @markdown  - or your own dataset in [JSONL format](https://jsonlines.org/examples/) such as this [databricks-dolly-15K](https://huggingface.co/datasets/databricks/databricks-dolly-15k/blob/main/databricks-dolly-15k.jsonl) JSONL file. [See license information here.](https://huggingface.co/datasets/databricks/databricks-dolly-15k#licenseattribution)

# @markdown Whether you use a TensorFlow dataset or a JSONL dataset, each data-item
# @markdown will be in the form of a dictionary containing multiple key-value pairs. For example,
# @markdown [the `imdb_reviews` data-item dictionary](https://www.tensorflow.org/datasets/catalog/imdb_reviews#imdb_reviewsplain_text_default_config)
# @markdown contains keys `text` and `label` and [the `databricks-dolly-15k` data-item dictionary](https://huggingface.co/datasets/databricks/databricks-dolly-15k) contains
# @markdown keys `instruction`, `context`, `response`, and  `category`.
# @markdown However, the Gemma model only takes a single string as a training example. To arbitrarily select and combine
# @markdown multiple key-values into a single training string, you can set a `template` variable in the next section.
# @markdown For example, for the [`databricks-dolly-15k`](https://huggingface.co/datasets/databricks/databricks-dolly-15k), you can set the
# @markdown `template` as `Instruction: {instruction} Response: {response}`
# @markdown which will then automatically fill each `instruction` and `response` key-values into this string template
# @markdown and generate a single training string which will look like:

# @markdown ```
# @markdown  Instruction: Why can camels survive for long without water? Response: Camels use the fat in their humps to keep them filled with energy and hydration for long periods of time.
# @markdown  ```

# @markdown And for the [imdb_reviews](https://www.tensorflow.org/datasets/catalog/imdb_reviews#imdb_reviewsplain_text_default_config) you can set the `template` as `{text}`
# @markdown which will then select each `text` key-value as a single training string which will look like:

# @markdown ```
# @markdown I have been known to fall asleep during films, but this is usually due to a combination of things including, really tired, being warm and comfortable on the sette and having just eaten a lot. However on this occasion I fell asleep because the film was rubbish. The plot development was constant. Constantly slow and boring. Things seemed to happen, but with no explanation of what was causing them or why. I admit, I may have missed part of the film, but i watched the majority of it and everything just seemed to happen of its own accord without any real concern for anything else. I cant recommend this film at all.
# @markdown ```

# @markdown Set a template suitable for the selected dataset whether TensorFlow Dataset or JSONL format. The following value is
# @markdown set for the `databricks-dolly-15k` dataset.
template = "Instruction: {instruction} Response: {response}"  # @param {type:"string"}

# @markdown ### Fill only one of the sections below:
# @markdown ---

# @markdown For example `imdb_reviews`.

# The TensorFlow dataset name.
tfds_dataset_name = ""  # @param {type:"string"}
# The dataset split to use.
tfds_dataset_split = "train"  # @param {type:"string"}
# @markdown ---

# @markdown or

# @markdown ---

# @markdown `jsonl_dataset_file` can be your Cloud Storage path
# @markdown such as `<BUCKET_URI>/<path-to-your-jsonl-file>` or a link to an online JSONL file
# @markdown in which case the code here will download and then copy the file to `BUCKET_URI`.
# @markdown If you want to upload a JSONL dataset to [Google Cloud Storage](https://cloud.google.com/storage) by yourself, then follow the instructions on [Upload objects from a filesystem](https://cloud.google.com/storage/docs/uploading-objects). Then, set `jsonl_dataset_file` to the `gs://` URI to your JSONL file such as `gs://cloud-samples-data/vertex-ai/model-evaluation/peft_train_sample.jsonl`.

jsonl_dataset_file = "https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl"  # @param {type:"string"}
# @markdown ---

# @markdown Click `Show code` to see more details.

assert (
    tfds_dataset_name or jsonl_dataset_file
), "Fill in either `tfds_dataset_name` or `jsonl_dataset_file`."
assert not (
    tfds_dataset_name and jsonl_dataset_file
), "Fill in only one of `tfds_dataset_name` or `jsonl_dataset_file`."

# Download the JSONL dataset.
jsonl_dataset_uri_gcsfuse = ""
if jsonl_dataset_file:
    if jsonl_dataset_file.startswith("gs://"):
        # Using cloud storage location.
        jsonl_dataset_uri = jsonl_dataset_file
    else:
        # Download the file and copy to cloud storage.
        !wget -O dataset.jsonl $jsonl_dataset_file
        jsonl_dataset_uri = f"{BUCKET_URI}/dataset.jsonl"
        print("Copying dataset.jsonl to ", jsonl_dataset_uri)
        !gsutil cp dataset.jsonl $jsonl_dataset_uri
        print("JSONL url copied to: ", jsonl_dataset_uri)
    jsonl_dataset_uri_gcsfuse = jsonl_dataset_uri.replace("gs://", "/gcs/")

In [None]:
# @title Finetune
# @markdown Use the Vertex AI SDK to create and run the custom training jobs.
# @markdown The training job uses a TPU v3 with 8 cores and takes around 10 mins to
# @markdown finish once it starts running.
# @markdown Click `View backing custom job` link in the output of this cell to follow training job progress.
# @markdown **Note that to make the training run faster, only a subset of dataset (2000 examples) is used here during fine tuning and the fine tuning runs for just one epoch. To improve the performance of the model, use more training samples, fine tune for more epochs and experiment with increasing the LoRA rank.**
# @markdown Click `Show code` to see more details.

# The pre-built training docker images for KerasNLP training.
# KERAS_MODEL_CONVERSION_DOCKER_URI is consumed by the "Convert model" cell.
KERAS_TRAIN_DOCKER_URI = "us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/jax-keras-train-tpu:20240422_0939_RC00"
KERAS_MODEL_CONVERSION_DOCKER_URI = "us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/jax-keras-model-conversion:20240422_0949_RC00"

# This KerasNLP path only runs when the model source is Kaggle.
if LOAD_MODEL_FROM != "Kaggle":
    print("Skipped: Expect to load model from Kaggle, got", LOAD_MODEL_FROM)
else:
    # The Gemma base model.
    KAGGLE_MODEL_ID = "gemma_2b_en"  # @param["gemma_2b_en", "gemma_instruct_2b_en", "gemma_7b_en", "gemma_instruct_7b_en", "gemma_1.1_instruct_2b_en", "gemma_1.1_instruct_7b_en"] {isTemplate:true}
    # @markdown Set `num_train_subset_samples` as `-1` to use all the training samples.
    num_train_subset_samples = 2000  # @param {type:"integer"}
    # Number of train epochs.
    num_epochs = 1  # @param{type:"integer"}
    # Learning rate.
    learning_rate = 5e-5  # @param{type:"number"}
    # Weight decay.
    weight_decay = 0.01  # @param{type:"number"}
    # Input sequence length. It determines the memory required by the model.
    input_sequence_length = 512  # @param{type:"integer"}
    # LoRA rank.
    lora_rank = 4  # @param{type:"integer"}
    # Batch size for training.
    train_batch_size = 2  # @param{type:"integer"}
    # The KerasNLP checkpoint filename.
    # Note: Do not add folder name here. The "Convert model" cell reads it back
    # from `output_folder_gcsfuse`.
    checkpoint_filename = "fine_tuned.weights.h5"  # @param{type:"string"}

    # Worker pool spec.
    machine_type = "cloud-tpu"
    # NOTE: The models have been tested only with 8 cores.
    accelerator_type = "TPU_V3"
    # Number of TPU cores.
    accelerator_count = 8
    # Set model parallelism related parameters for 8 cores.
    # NOTE(review): presumably batch_dim * model_dim must equal accelerator_count
    # (1 x 8 = 8 here). Confirm before changing either value.
    model_parallel_batch_dim = 1
    model_parallel_model_dim = 8

    replica_count = 1

    common_util.check_quota(
        project_id=PROJECT_ID,
        region=REGION,
        accelerator_type=accelerator_type,
        accelerator_count=accelerator_count,
        is_for_training=True,
    )

    # Setup training job.
    job_name = common_util.get_job_name_with_datetime("gemma-keras-lora-train")

    # Pass training arguments and launch job.
    train_job = aiplatform.CustomContainerTrainingJob(
        display_name=job_name,
        container_uri=KERAS_TRAIN_DOCKER_URI,
    )

    # Create a GCS folder to save the finetuned model. The job receives the
    # /gcs/ form of the path so it can write into the bucket as a local folder.
    output_folder = os.path.join(BUCKET_URI, job_name)
    output_folder_gcsfuse = output_folder.replace("gs://", "/gcs/")

    # Unlike the earlier training cell, run() is called without `sync=False`, so
    # this presumably blocks until the job finishes.
    # Exactly one of --tfds_dataset_name / --jsonl_dataset_file is non-empty
    # (enforced by the asserts in the Set Dataset cell).
    train_job.run(
        args=[
            f"--model_type={KAGGLE_MODEL_ID}",
            f"--num_epochs={num_epochs}",
            f"--learning_rate={learning_rate}",
            f"--weight_decay={weight_decay}",
            f"--input_sequence_length={input_sequence_length}",
            f"--lora_rank={lora_rank}",
            f"--model_parallel_batch_dim={model_parallel_batch_dim}",
            f"--model_parallel_model_dim={model_parallel_model_dim}",
            f"--tfds_dataset_name={tfds_dataset_name}",
            f"--tfds_dataset_split={tfds_dataset_split}",
            f"--jsonl_dataset_file={jsonl_dataset_uri_gcsfuse}",
            f"--template={template}",
            f"--train_batch_size={train_batch_size}",
            f"--num_train_subset_samples={num_train_subset_samples}",
            f"--output_folder={output_folder_gcsfuse}",
            f"--checkpoint_filename={checkpoint_filename}",
        ],
        # Kaggle credentials are forwarded to the container, presumably so it
        # can download the KerasNLP preset from Kaggle.
        environment_variables={
            "KAGGLE_USERNAME": KAGGLE_USERNAME,
            "KAGGLE_KEY": KAGGLE_KEY,
        },
        replica_count=replica_count,
        machine_type=machine_type,
        accelerator_type=accelerator_type,
        accelerator_count=accelerator_count,
        service_account=SERVICE_ACCOUNT,
    )

    print("Trained model is saved in: ", output_folder)

In [None]:
# @title Convert model
# @markdown Convert the KerasNLP model checkpoint to Hex-LLM format.
# @markdown  Use the Vertex AI SDK to create and run the custom job.
# @markdown Click `View backing custom job` link in the output of this cell to follow job progress.
# @markdown  The job takes around 6 mins to finish.
# @markdown  Click `Show code` to see more details.

if LOAD_MODEL_FROM != "Kaggle":
    print("Skipped: Expect to load model from Kaggle, got", LOAD_MODEL_FROM)
else:
    # Gemma size of each supported KerasNLP preset. The converter receives it
    # as `--size`.
    model_type_to_size = {
        "gemma_2b_en": "2b",
        "gemma_instruct_2b_en": "2b",
        "gemma_7b_en": "7b",
        "gemma_instruct_7b_en": "7b",
        "gemma_1.1_instruct_2b_en": "2b",
        "gemma_1.1_instruct_7b_en": "7b",
    }
    if KAGGLE_MODEL_ID not in model_type_to_size:
        raise ValueError(
            f"Unsupported KAGGLE_MODEL_ID {KAGGLE_MODEL_ID!r}; expected one of "
            f"{sorted(model_type_to_size)}."
        )
    size = model_type_to_size[KAGGLE_MODEL_ID]
    # NOTE: The Hex-LLM serving code looks for model type tag in
    # the checkpoint filename. Derive the tag from the same `size` that is
    # passed to the converter so the two can never disagree.
    model_type_to_file_suffix = f"_gemma-{size}.ckpt"
    hexllm_checkpoint_file = "finetuned_hexllm" + model_type_to_file_suffix

    # Worker pool spec. No accelerator is requested, so conversion runs on CPU.
    machine_type = "n1-highmem-16"

    replica_count = 1

    # Set up the conversion job.
    job_name = common_util.get_job_name_with_datetime("gemma-keras-model-conversion")

    # Pass conversion arguments and launch the job. It reads the KerasNLP
    # weights from, and writes the Hex-LLM checkpoint to, the finetuning
    # output folder (via its /gcs/ path).
    conversion_job = aiplatform.CustomContainerTrainingJob(
        display_name=job_name,
        container_uri=KERAS_MODEL_CONVERSION_DOCKER_URI,
    )

    conversion_job.run(
        args=[
            f"--weights_file={output_folder_gcsfuse}/{checkpoint_filename}",
            f"--size={size}",
            f"--output_file={output_folder_gcsfuse}/{hexllm_checkpoint_file}",
        ],
        environment_variables={
            "KAGGLE_USERNAME": KAGGLE_USERNAME,
            "KAGGLE_KEY": KAGGLE_KEY,
        },
        replica_count=replica_count,
        machine_type=machine_type,
        service_account=SERVICE_ACCOUNT,
    )

    print(
        "Converted Hexllm checkpoint is saved in: ",
        output_folder + "/" + hexllm_checkpoint_file,
    )

In [None]:
# @title Deploy
# @markdown This section uploads the model to Model Registry and deploys it on the Endpoint. It takes 15 minutes to 1 hour to finish.
# @markdown **Hex-LLM** is a **H**igh-**E**fficiency **L**arge **L**anguage **M**odel (LLM) TPU serving solution built with **XLA**, which is being developed by Google Cloud. This notebook uses TPU v5e machines. Click `Show code` to see more details.

# @markdown Region to deploy the model on TPU.
# deploy_model_hexllm (below) creates the endpoint and uploads the model in this
# region, which may differ from REGION.
TPU_DEPLOYMENT_REGION = "us-west1"  # @param ["us-west1"] {isTemplate:true}

# The pre-built serving docker image for Hex-LLM.
# NOTE: unlike the pinned vLLM image, this uses the floating `deploy` tag, so
# the served version can change between runs.
HEXLLM_DOCKER_URI = "us-docker.pkg.dev/vertex-ai-restricted/vertex-vision-model-garden-dockers/hex-llm-serve:deploy"

# @markdown Set use_dedicated_endpoint to False if you don't want to use [dedicated endpoint](https://cloud.google.com/vertex-ai/docs/general/deployment#create-dedicated-endpoint).
# Re-assigns the global also set in the vLLM Deploy cell. It is used both when
# creating the endpoint and when sending predictions.
use_dedicated_endpoint = True  # @param {type:"boolean"}


def deploy_model_hexllm(
    model_name: str,
    model_id: str,
    service_account: str,
    base_model_id: str = None,
    data_parallel_size: int = 1,
    tensor_parallel_size: int = 1,
    machine_type: str = "ct5lp-hightpu-1t",
    tpu_topology: str = "1x1",
    hbm_utilization_factor: float = 0.6,
    max_running_seqs: int = 256,
    max_model_len: int = 4096,
    endpoint_id: str = "",
    min_replica_count: int = 1,
    max_replica_count: int = 1,
    use_dedicated_endpoint: bool = False,
) -> Tuple[aiplatform.Model, aiplatform.Endpoint]:
    """Deploys models with Hex-LLM on TPU in Vertex AI.

    The endpoint and the uploaded model both live in ``TPU_DEPLOYMENT_REGION``.

    Args:
        model_name: Display name of the uploaded model. A newly created endpoint
            is named ``f"{model_name}-endpoint"``.
        model_id: Model location handed to the server as ``--model``.
        service_account: Service account used by the deployment.
        base_model_id: Exported as the ``MODEL_ID`` environment variable.
            Falls back to ``model_id`` when empty.
        data_parallel_size: Forwarded as ``--data_parallel_size``.
        tensor_parallel_size: Forwarded as ``--tensor_parallel_size``. When
            falsy, it is derived from the chip-count digit of ``machine_type``
            (e.g. ``ct5lp-hightpu-4t`` -> 4).
        machine_type: TPU machine type of each replica.
        tpu_topology: The first dimension is used as the number of hosts. The
            topology is only passed to ``Model.deploy`` when that is > 1.
        hbm_utilization_factor, max_running_seqs, max_model_len: Forwarded as
            the matching Hex-LLM server flags.
        endpoint_id: ID of an existing endpoint in ``TPU_DEPLOYMENT_REGION`` to
            reuse. A new endpoint is created when empty.
        min_replica_count, max_replica_count: Autoscaling bounds.
        use_dedicated_endpoint: Only used when a new endpoint is created.

    Returns:
        ``(model, endpoint)``: the uploaded model and the endpoint serving it.
    """
    if endpoint_id:
        # The endpoint must be in the same region as the model uploaded below
        # (TPU_DEPLOYMENT_REGION), otherwise the deploy call targets a
        # cross-region endpoint.
        aip_endpoint_name = f"projects/{PROJECT_ID}/locations/{TPU_DEPLOYMENT_REGION}/endpoints/{endpoint_id}"
        endpoint = aiplatform.Endpoint(aip_endpoint_name)
    else:
        endpoint = aiplatform.Endpoint.create(
            display_name=f"{model_name}-endpoint",
            location=TPU_DEPLOYMENT_REGION,
            dedicated_endpoint_enabled=use_dedicated_endpoint,
        )

    if not base_model_id:
        base_model_id = model_id

    if not tensor_parallel_size:
        # The second-to-last character of names such as "ct5lp-hightpu-4t" is
        # the single-digit chip count.
        tensor_parallel_size = int(machine_type[-2])

    num_hosts = int(tpu_topology.split("x")[0])

    # Learn more about the supported arguments and environment variables at https://cloud.google.com/vertex-ai/generative-ai/docs/open-models/use-hex-llm#config-server.
    hexllm_args = [
        "--host=0.0.0.0",
        "--port=7080",
        f"--model={model_id}",
        f"--data_parallel_size={data_parallel_size}",
        f"--tensor_parallel_size={tensor_parallel_size}",
        f"--num_hosts={num_hosts}",
        f"--hbm_utilization_factor={hbm_utilization_factor}",
        f"--max_running_seqs={max_running_seqs}",
        f"--max_model_len={max_model_len}",
    ]

    env_vars = {
        "MODEL_ID": base_model_id,
        "HEX_LLM_LOG_LEVEL": "info",
        "DEPLOY_SOURCE": "notebook",
    }

    # HF_TOKEN is not a compulsory field and may not be defined. Only a missing
    # name is tolerated; any other error should surface.
    try:
        if HF_TOKEN:
            env_vars.update({"HF_TOKEN": HF_TOKEN})
    except NameError:
        pass

    model = aiplatform.Model.upload(
        display_name=model_name,
        serving_container_image_uri=HEXLLM_DOCKER_URI,
        serving_container_command=["python", "-m", "hex_llm.server.api_server"],
        serving_container_args=hexllm_args,
        serving_container_ports=[7080],
        serving_container_predict_route="/generate",
        serving_container_health_route="/ping",
        serving_container_environment_variables=env_vars,
        serving_container_shared_memory_size_mb=(16 * 1024),  # 16 GB
        serving_container_deployment_timeout=7200,
        location=TPU_DEPLOYMENT_REGION,
    )

    model.deploy(
        endpoint=endpoint,
        machine_type=machine_type,
        # Single-host deployments do not pass a topology.
        tpu_topology=tpu_topology if num_hosts > 1 else None,
        deploy_request_timeout=1800,
        service_account=service_account,
        min_replica_count=min_replica_count,
        max_replica_count=max_replica_count,
    )
    return model, endpoint


if LOAD_MODEL_FROM != "Kaggle":
    print("Skipped: Expect to load model from Kaggle, got", LOAD_MODEL_FROM)
else:
    if "2b" in KAGGLE_MODEL_ID:
        # Sets ct5lp-hightpu-1t (1 TPU chip) to deploy Gemma 2B models.
        machine_type = "ct5lp-hightpu-1t"
        accelerator_type = "TPU_V5e"
        # Note: 1 TPU-V5e chip has only 1 core.
        accelerator_count = 1
    else:
        # Sets ct5lp-hightpu-4t (4 TPU chips) to deploy Gemma 7B models.
        machine_type = "ct5lp-hightpu-4t"
        accelerator_type = "TPU_V5e"
        # Note: 1 TPU-V5e chip has only 1 core.
        accelerator_count = 4

    # The model is deployed in TPU_DEPLOYMENT_REGION, so check quota there.
    common_util.check_quota(
        project_id=PROJECT_ID,
        region=TPU_DEPLOYMENT_REGION,
        accelerator_type=accelerator_type,
        accelerator_count=accelerator_count,
        is_for_training=False,
    )

    # Memory/padding knobs: deploy_model_hexllm accepts hbm_utilization_factor,
    # max_running_seqs and max_model_len. It does not accept
    # max_num_batched_tokens / tokens_pad_multiple / seqs_pad_multiple, and
    # passing them raises TypeError.

    print("Using model from: ", output_folder)
    models["hexllm_tpu"], endpoints["hexllm_tpu"] = deploy_model_hexllm(
        model_name=common_util.get_job_name_with_datetime(prefix="gemma-serve-hexllm"),
        base_model_id=f"google/{KAGGLE_MODEL_ID}",
        model_id=output_folder,
        service_account=SERVICE_ACCOUNT,
        machine_type=machine_type,
        # Shard the model across every chip of the host. With the function's
        # default of 1 (and data_parallel_size=1), presumably only one chip
        # would hold the model.
        tensor_parallel_size=accelerator_count,
        use_dedicated_endpoint=use_dedicated_endpoint,
    )
    print("endpoint_name:", endpoints["hexllm_tpu"].name)

In [None]:
# @title Predict
# @markdown Once deployment succeeds, you can send requests to the endpoint with text prompts based on your `template`.

# @markdown Additionally, you can moderate the generated text with Vertex AI. See [Moderate text documentation](https://cloud.google.com/natural-language/docs/moderating-text) for more details.

# To send requests to a Hex-LLM endpoint deployed earlier (for example after a
# kernel restart), uncomment the code below and set `endpoint_name` to the
# endpoint ID, e.g. "1234567890123456789". The Deploy cell creates the endpoint
# in TPU_DEPLOYMENT_REGION (not REGION). The request below is sent to
# `endpoints["hexllm_tpu"]`, so the loaded endpoint is stored under that key.

# endpoint_name = ""  # @param {type:"string"}
# aip_endpoint_name = (
#     f"projects/{PROJECT_ID}/locations/{TPU_DEPLOYMENT_REGION}/endpoints/{endpoint_name}"
# )
# endpoints["hexllm_tpu"] = aiplatform.Endpoint(aip_endpoint_name)

# @markdown **Note:** The following input corresponds to the default `template` set for the `databricks-dolly-15k` which uses `instruction` and `response` keys.
# @markdown   If you modify the `template` or use another dataset, then modify the `prompt` accordingly. For example for the `imdb_reviews` dataset where `template = "{text}"`, set `prompt = "Inception is "`.

if LOAD_MODEL_FROM != "Kaggle":
    print("Skipped: Expect to load model from Kaggle, got", LOAD_MODEL_FROM)
else:
    prompt = "Instruction: What should I do on a trip to Europe? Response: "  # @param {type: "string"}
    # Overrides max_tokens and top_k parameters during inferences.
    # If you encounter the issue like `ServiceUnavailable: 503 Took too long to respond when processing`,
    # you can reduce the max length, such as set max_tokens as 20.
    max_tokens = 50  # @param {type:"integer"}
    temperature = 1.0  # @param {type:"number"}
    top_p = 1.0  # @param {type:"number"}
    # top_k counts candidate tokens, so it is an integer (as in the vLLM Predict cell).
    top_k = 1  # @param {type:"integer"}

    # @markdown **Note that the first few prompts will take longer to execute.**

    instances = [
        {
            "prompt": prompt,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
        },
    ]
    prediction_response = endpoints["hexllm_tpu"].predict(
        instances=instances, use_dedicated_endpoint=use_dedicated_endpoint
    )

    for prediction in prediction_response.predictions:
        print(prediction)

## Clean up resources

In [None]:
# Delete the train job.
train_job.delete()

# Delete the conversion job.
if conversion_job:
    conversion_job.delete()

# @markdown  Delete the experiment models and endpoints to recycle the resources
# @markdown  and avoid unnecessary continuous charges that may incur.

# Undeploy model and delete endpoint.
for endpoint in endpoints.values():
    endpoint.delete(force=True)

# Delete models.
for model in models.values():
    model.delete()

delete_bucket = False  # @param {type:"boolean"}
if delete_bucket:
    ! gsutil -m rm -r $BUCKET_NAME