In [None]:
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Vertex AI Model Garden - Fine-tune gpt-oss models with Axolotl

<table><tbody><tr>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fvertex-ai-samples%2Fmain%2Fnotebooks%2Fcommunity%2Fmodel_garden%2Fmodel_garden_axolotl_gpt_oss_finetuning.ipynb">
      <img alt="Google Cloud Colab Enterprise logo" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" width="32px"><br> Run in Colab Enterprise
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_axolotl_gpt_oss_finetuning.ipynb">
      <img alt="GitHub logo" src="https://github.githubassets.com/assets/GitHub-Mark-ea2971cee799.png" width="32px"><br> View on GitHub
    </a>
  </td>
</tr></tbody></table>

## Overview
This notebook demonstrates fine-tuning the gpt-oss model using [Axolotl](https://github.com/axolotl-ai-cloud/axolotl). Axolotl streamlines AI model fine-tuning by providing a wide range of training recipes and supporting multiple configurations and architectures.

### Objective
- Train the gpt-oss model using Axolotl with Vertex AI Training.
- Deploy the trained model on Vertex AI and run predictions on Google Cloud.

### Resources required
The table below outlines the recommended machine specifications for different parts of the notebook to function correctly. Note that machine types with higher VRAM than recommended can also be used.
> | Model | Vertex AI Finetuning | Vertex AI Deployment |
| ----------- | ----------- | ----------- |
| openai/gpt-oss-20b | a2-ultragpu-8g or a3-highgpu-8g | a3-highgpu-1g |

Learn more about machine types by following [this doc](https://cloud.google.com/vertex-ai/docs/training/configure-compute#specifying_gpus).

### File a bug

File a bug on [GitHub](https://github.com/GoogleCloudPlatform/vertex-ai-samples/issues/new) if you encounter any issue with the notebook.

## Before you begin

In [None]:
# @title Import utility packages for fine-tuning

# Upgrade Vertex AI SDK.
! pip3 install --upgrade --quiet 'google-cloud-aiplatform==1.103.0'

# Import the necessary packages.
! rm -rf vertex-ai-samples && git clone https://github.com/GoogleCloudPlatform/vertex-ai-samples.git
! cd vertex-ai-samples

# Import the necessary packages.

import datetime
import importlib
import os
import pathlib
import time
import uuid
from typing import Tuple

import requests
import yaml
from google.cloud import aiplatform

# The checkout directory name contains hyphens, so the module cannot be named in
# a regular `import` statement; it is resolved from its dotted path instead.
common_util = importlib.import_module(
    "vertex-ai-samples.notebooks.community.model_garden"
    ".docker_source_codes.notebook_util.common_util"
)


def run_cmd_and_check_output(
    cmd: list[str], env: dict[str, str] = None, input: str = "", cwd: str = None
):
    """Runs a command, streams its output, and raises if the command fails.

    The child's stderr is merged into stdout and every output line is echoed
    as soon as it is produced.

    Args:
        cmd: Program and arguments to execute.
        env: Environment for the child process; `None` inherits the current one.
        input: Text written to the child's stdin before stdin is closed.
        cwd: Working directory for the child; `None` keeps the current one.

    Raises:
        ValueError: If the command exits with a non-zero return code.
    """
    # Imported locally because the notebook's import cell does not import
    # `subprocess`; without this the function fails with a NameError.
    import subprocess

    with subprocess.Popen(
        cmd,
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        bufsize=1,  # Line-buffered so output is shown while the command runs.
        env=env,
        cwd=cwd,
    ) as p:
        if input:
            p.stdin.write(input)
            p.stdin.flush()
        # Always close stdin so a child waiting for EOF does not block forever.
        p.stdin.close()
        for line in p.stdout:
            print(line, end="", flush=True)
    # Leaving the `with` block waits for the process, so `returncode` is set.
    if p.returncode:
        raise ValueError(
            f"Command '{' '.join(cmd)}' execution failed with return code {p.returncode}"
        )


from google import auth

# Notebook-wide state. `train_job` stays None until the fine-tuning cell
# creates a job; the model/endpoint registries start out empty.
train_job = None
models = {}
endpoints = {}
HF_TOKEN = ""

# Directory the notebook was started from; later cells build local paths on it.
WORKING_DIR = os.getcwd()
print(f"Current working directory for notebook: {WORKING_DIR}")

In [None]:
# @title Setup Google Cloud project

# @markdown 1. [Make sure that billing is enabled for your project](https://cloud.google.com/billing/docs/how-to/modify-project).

# @markdown 2. For finetuning using Vertex AI, we will use Dynamic Workload Scheduler. Learn more about Dynamic workload scheduler [here](https://cloud.google.com/vertex-ai/docs/training/schedule-jobs-dws). For Dynamic Workload Scheduler, check the [us-central1](https://console.cloud.google.com/iam-admin/quotas?location=us-central1&metric=aiplatform.googleapis.com%2Fcustom_model_training_preemptible_nvidia_h100_gpus) or [europe-west4](https://console.cloud.google.com/iam-admin/quotas?location=europe-west4&metric=aiplatform.googleapis.com%2Fcustom_model_training_preemptible_nvidia_h100_gpus) quota for Nvidia H100 GPUs, [us-central1](https://console.cloud.google.com/iam-admin/quotas?location=us-central1&metric=aiplatform.googleapis.com%2Fcustom_model_training_preemptible_nvidia_a100_80gb_gpus) quota for Nvidia A100 80GB GPUs. If you do not have enough GPUs, then you can follow [these instructions](https://cloud.google.com/docs/quotas/view-manage#viewing_your_quota_console) to request quota.

# @markdown 3. For serving, **[click here](https://console.cloud.google.com/iam-admin/quotas?location=us-central1&metric=aiplatform.googleapis.com%2Fcustom_model_serving_nvidia_l4_gpus)** to check if your project already has the required L4 GPUs in the us-central1 region. If you need more L4 GPUs for your project, then you can follow [these instructions](https://cloud.google.com/docs/quotas/view-manage#viewing_your_quota_console) to request more. Alternatively, if you want to run predictions with A100 80GB or H100 GPUs, we recommend using the regions listed below. **NOTE:** Make sure you have associated quota in selected regions. Click the links to see your current quota for each GPU type: [Nvidia A100 80GB](https://console.cloud.google.com/iam-admin/quotas?metric=aiplatform.googleapis.com%2Fcustom_model_serving_nvidia_a100_80gb_gpus), [Nvidia H100 80GB](https://console.cloud.google.com/iam-admin/quotas?metric=aiplatform.googleapis.com%2Fcustom_model_serving_nvidia_h100_gpus).

# @markdown > | Machine Type | Accelerator Type | Recommended Regions |
# @markdown | ----------- | ----------- | ----------- |
# @markdown | a2-ultragpu-1g | 1 NVIDIA_A100_80GB | us-central1, us-east4, europe-west4, asia-southeast1 |
# @markdown | a3-highgpu-2g | 2 NVIDIA_H100_80GB | us-west1, asia-southeast1, europe-west4 |
# @markdown | a3-highgpu-4g | 4 NVIDIA_H100_80GB | us-west1, asia-southeast1, europe-west4 |
# @markdown | a3-highgpu-8g | 8 NVIDIA_H100_80GB | us-central1, us-east5, europe-west4, us-west1, asia-southeast1 |

# @markdown 4. **[Optional]** [Create a Cloud Storage bucket](https://cloud.google.com/storage/docs/creating-buckets) for storing experiment outputs. Set the BUCKET_URI for the experiment environment. The specified Cloud Storage bucket (`BUCKET_URI`) should be located in the same region as where the notebook was launched. Note that a multi-region bucket (eg. "us") is not considered a match for a single region covered by the multi-region range (eg. "us-central1"). If not set, a unique GCS bucket will be created instead.

BUCKET_URI = "gs://"  # @param {type:"string"}

# @markdown 5. **[Optional]** Set region. If not set, the region will be set automatically according to Colab Enterprise environment.

REGION = ""  # @param {type:"string"}


# The project id always comes from the runtime environment.
PROJECT_ID = os.environ["GOOGLE_CLOUD_PROJECT"]

# An explicitly entered REGION wins; otherwise fall back to the region exposed
# by the runtime environment and fail if neither is available.
if not REGION:
    REGION = os.environ.get("GOOGLE_CLOUD_REGION", "")
    if not REGION:
        raise ValueError(
            "REGION must be set. See"
            " https://cloud.google.com/vertex-ai/docs/general/locations for"
            " available cloud locations."
        )

# Enable the Vertex AI API and Compute Engine API, if not already.
print("Enabling Vertex AI API and Compute Engine API.")
! gcloud services enable aiplatform.googleapis.com compute.googleapis.com

# Cloud Storage bucket for storing the experiment artifacts.
# A unique GCS bucket will be created for the purpose of this notebook. If you
# prefer using your own GCS bucket, change the value yourself below.
now = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
BUCKET_NAME = "/".join(BUCKET_URI.split("/")[:3])

if BUCKET_URI is None or BUCKET_URI.strip() == "" or BUCKET_URI == "gs://":
    BUCKET_URI = f"gs://{PROJECT_ID}-tmp-{now}-{str(uuid.uuid4())[:4]}"
    BUCKET_NAME = "/".join(BUCKET_URI.split("/")[:3])
    ! gcloud storage buckets create --location={REGION} {BUCKET_URI}
else:
    assert BUCKET_URI.startswith("gs://"), "BUCKET_URI must start with `gs://`."
    shell_output = ! gcloud storage ls --full --buckets {BUCKET_NAME} | grep "Location constraint:" | sed "s/Location constraint://"
    bucket_region = shell_output[0].strip().lower()
    if bucket_region != REGION:
        raise ValueError(
            "Bucket region %s is different from notebook region %s"
            % (bucket_region, REGION)
        )
print(f"Using this GCS Bucket: {BUCKET_URI}")

STAGING_BUCKET = os.path.join(BUCKET_URI, "temporal")
MODEL_BUCKET = os.path.join(BUCKET_URI, "axolotl")


# Initialize Vertex AI API.
print("Initializing Vertex AI API.")
aiplatform.init(project=PROJECT_ID, location=REGION, staging_bucket=STAGING_BUCKET)

# Gets the default SERVICE_ACCOUNT.
shell_output = ! gcloud projects describe $PROJECT_ID
project_number = shell_output[-1].split(":")[1].strip().replace("'", "")
SERVICE_ACCOUNT = f"{project_number}-compute@developer.gserviceaccount.com"
print("Using this default Service Account:", SERVICE_ACCOUNT)


# Provision permissions to the SERVICE_ACCOUNT with the GCS bucket
! # Note: Migrating scripts using gsutil iam ch is more complex than get or set. You need to replace the single iam ch command with a series of gcloud storage bucket add-iam-policy-binding and/or gcloud storage bucket remove-iam-policy-binding commands, or replicate the read-modify-write loop.
! gcloud storage buckets add-iam-policy-binding $BUCKET_NAME --member=serviceAccount:{SERVICE_ACCOUNT} --role=roles/storage.admin

! gcloud config set project $PROJECT_ID
! gcloud projects add-iam-policy-binding --no-user-output-enabled {PROJECT_ID} --member=serviceAccount:{SERVICE_ACCOUNT} --role="roles/storage.admin"
! gcloud projects add-iam-policy-binding --no-user-output-enabled {PROJECT_ID} --member=serviceAccount:{SERVICE_ACCOUNT} --role="roles/aiplatform.user"

## Finetune with Axolotl

In [None]:
# @title Set model to fine-tune
# @markdown Note: This overrides Axolotl's `base_model` flag.
# Hugging Face model ID to fine-tune. A later cell appends it to the Axolotl
# flag overrides as `--base_model=<HF_MODEL_ID>`.
HF_MODEL_ID = "openai/gpt-oss-20b"  # @param ["openai/gpt-oss-20b"]

In [None]:
# @title Set Axolotl config

# @markdown Axolotl is designed to work with YAML config files that contain everything you need to preprocess a dataset, train or fine-tune a model, run model inference or evaluation, and much more.
# @markdown The gpt-oss Axolotl configs are taken from [examples directory](https://github.com/axolotl-ai-cloud/axolotl/blob/ba3dba3e4f6fbe845b0249f517c3bff88d898e22/examples/gpt-oss).

# @markdown Suggestion for gpt-oss Axolotl configs:
# @markdown > | Model | Recommended Axolotl Config |
# @markdown | ----------- | ----------- |
# @markdown | openai/gpt-oss-20b | examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml |


# @markdown You can also customize the Axolotl config as per your requirements. To use a custom Axolotl config you can use `LOCAL` or `GCS` source option below.
# @markdown Alternatively, you can specify github axolotl config and override flags using `Setup Axolotl Flags` section below.

# @markdown 1. Set Axolotl config source.<br>
# @markdown For **GITHUB** as source, you can explore different Axolotl configurations in the [examples directory](https://github.com/axolotl-ai-cloud/axolotl/tree/6ba5c0ed2c42a0e069b28c83646ee5a2a6904430/examples). For `GITHUB` source, `AXOLOTL_CONFIG_PATH` should start with `examples/`. e.g. "examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml".<br>
# @markdown For **LOCAL** as source, create Axolotl config yaml file and specify correct path below. Note that the local file will be copied to GCS bucket before running Vertex AI training job. For `LOCAL` source, `AXOLOTL_CONFIG_PATH` should be an absolute path of the config file, e.g. /content/lora.yml.<br>
# @markdown For **GCS** as source, specify the GCS URI to the Axolotl config file. Make sure the file is accessible to service account used in the notebook. For `GCS` source, `AXOLOTL_CONFIG_PATH` should be a complete GCS URI of the config file, e.g. gs://bucket/path/to/config/file.yml.

AXOLOTL_SOURCE = "GITHUB"  # @param ["GITHUB", "LOCAL", "GCS"]

# @markdown 2. Set the Axolotl config file path.
AXOLOTL_CONFIG_PATH = "examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml"  # @param ["examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml"] {allow-input: true}

assert AXOLOTL_CONFIG_PATH, "AXOLOTL_CONFIG_PATH must be set."

# Load the selected config into `axolotl_config`, whichever source it comes from.
if AXOLOTL_SOURCE == "GITHUB":
    assert AXOLOTL_CONFIG_PATH.startswith(
        "examples/"
    ), "AXOLOTL_CONFIG_PATH must start with examples/ for GITHUB source."
    github_url = f"https://github.com/axolotl-ai-cloud/axolotl/raw/9d5c95db6f4d883252fdb1183e82d0b354ff76a2/{AXOLOTL_CONFIG_PATH}"
    # Bound the request and fail fast on HTTP errors (e.g. 404 for a mistyped
    # path); otherwise the error page body would be parsed as if it were YAML.
    r = requests.get(github_url, timeout=60)
    r.raise_for_status()
    axolotl_config = yaml.safe_load(r.content.decode("utf-8"))
elif AXOLOTL_SOURCE == "LOCAL":
    config_path = pathlib.Path(AXOLOTL_CONFIG_PATH)
    assert config_path.exists(), "AXOLOTL_CONFIG_PATH must exist for LOCAL source."
    file_content = config_path.read_text()
    axolotl_config = yaml.safe_load(file_content)
elif AXOLOTL_SOURCE == "GCS":
    local_path = pathlib.Path(f"{WORKING_DIR}/tmp/axolotl_config.yml")
    # Make sure the destination directory exists in case the download helper
    # does not create it.
    local_path.parent.mkdir(parents=True, exist_ok=True)
    common_util.download_gcs_file_to_local(AXOLOTL_CONFIG_PATH, local_path.absolute())
    file_content = local_path.read_text()
    axolotl_config = yaml.safe_load(file_content)
    # From here on the config is referenced through its GCS FUSE path.
    AXOLOTL_CONFIG_PATH = common_util.gcs_fuse_path(AXOLOTL_CONFIG_PATH)
else:
    raise ValueError(f"Unsupported AXOLOTL_SOURCE: {AXOLOTL_SOURCE}")

# Later cells index into the config (e.g. `axolotl_config["base_model"]`), so
# reject empty files or non-mapping YAML here with a clear message.
if not isinstance(axolotl_config, dict):
    raise ValueError(
        f"Axolotl config loaded from {AXOLOTL_CONFIG_PATH} is not a YAML mapping."
    )

OUTPUT_GCS_URI = MODEL_BUCKET

if not OUTPUT_GCS_URI.startswith("gs://"):
    OUTPUT_GCS_URI = f"gs://{OUTPUT_GCS_URI}"

# Derive a per-config output sub-directory name from the config path.
output_sub_dir = (
    AXOLOTL_CONFIG_PATH.replace("/", "_").replace(".yaml", "").replace(".yml", "")
)
BASE_AXOLOTL_OUTPUT_GCS_URI = f"{OUTPUT_GCS_URI}/{output_sub_dir}/axolotl_output"
BASE_AXOLOTL_OUTPUT_DIR = common_util.gcs_fuse_path(BASE_AXOLOTL_OUTPUT_GCS_URI)

# Placeholders for dataset settings.
datasets = []
test_datasets = []

In [None]:
# @title Setup HF token
# Hugging Face access token. When non-empty, the fine-tuning cell forwards it to
# the training job as `--huggingface_access_token=<token>`.
# NOTE(review): avoid saving or sharing the notebook with a real token filled in.
HF_TOKEN = ""  # @param {type:"string"}

In [None]:
# @title **[Optional]** Setup dataset

# @markdown This section configures the dataset used for fine-tuning.

# @markdown **Note: If you don't fill any of the dataset options given below, then the dataset used will be the one defined in the Axolotl config file.** You have two options to configure the dataset:

# @markdown **1. Use a Hugging Face Dataset**
# @markdown   - Requires specifying the dataset name and type.

# @markdown **2. Load from Google Cloud Storage (GCS)**
# @markdown   - Requires specifying the bucket name, dataset type, file type, and paths to training/test splits.

# @markdown **Choose ONE of the following options:**

# @markdown ---
# @markdown **Option 1: Hugging Face**

# @markdown **Hugging Face Dataset Name:**
HF_DATASET = ""  # @param {type:"string", placeholder: "e.g. trl-lib/chatbot_arena_completions"}
# @markdown **Set the dataset type:** Refer to [Axolotl config file](https://github.com/axolotl-ai-cloud/axolotl/blob/6ba5c0ed2c42a0e069b28c83646ee5a2a6904430/docs/config.qmd#L140) for more details.
HF_DATASET_TYPE = ""  # @param {type:"string", placeholder: "e.g. chat_template"}
# A Hugging Face dataset is only usable together with its type.
assert (
    not HF_DATASET or HF_DATASET_TYPE
), "HF_DATASET_TYPE must be set if HF_DATASET is set."

# @markdown ---
# @markdown **Option 2: GCS**

# @markdown **Path to Training Data :**

# @markdown E.g. `gs://cloud-samples-data/vertex-ai/model-garden/datasets/vertex-sample-chat-train.jsonl`
TRAIN_DATASET_PATH = ""  # @param {type:"string"}
# @markdown **File Type**. Refer to the [Axolotl config file](https://github.com/axolotl-ai-cloud/axolotl/blob/6ba5c0ed2c42a0e069b28c83646ee5a2a6904430/docs/config.qmd#L103).
FILE_TYPE = ""  # @param {type:"string", placeholder: "e.g. json"}
# @markdown **Messages Column**. Refer to the [Axolotl config file](https://github.com/axolotl-ai-cloud/axolotl/blob/6ba5c0ed2c42a0e069b28c83646ee5a2a6904430/docs/config.qmd#L155).
MESSAGES_COLUMN = "messages"  # @param {type:"string", placeholder: "e.g. messages"}

# @markdown **[Optional] Path to Test Data :**
# @markdown To use a dedicated validation set, provide the file path. Otherwise, the training data will be split to create a validation set.

# @markdown E.g. `gs://cloud-samples-data/vertex-ai/model-garden/datasets/vertex-sample-chat-validation.jsonl`
TEST_DATASET_PATH = ""  # @param {type:"string"}

# Validate the combination of options before building any dataset entry.
assert (
    not TRAIN_DATASET_PATH or FILE_TYPE
), "FILE_TYPE must be set if TRAIN_DATASET_PATH is set."

assert (
    not TEST_DATASET_PATH or TRAIN_DATASET_PATH
), "TRAIN_DATASET_PATH must be set if TEST_DATASET_PATH is set."

assert (
    not HF_DATASET or not TRAIN_DATASET_PATH
), "Only one of HF_DATASET or TRAIN_DATASET_PATH can be set."

# Start from empty lists so that re-running the cell never accumulates entries.
datasets = []
test_datasets = []

# Key order is kept stable because these dicts are later rendered into
# command-line flag values.
if TRAIN_DATASET_PATH:
    dataset = dict(
        path=TRAIN_DATASET_PATH,
        ds_type=FILE_TYPE,
        type="chat_template",
        field_messages=MESSAGES_COLUMN,
    )
    datasets.append(dataset)

if TEST_DATASET_PATH:
    dataset = dict(
        path=TEST_DATASET_PATH,
        ds_type=FILE_TYPE,
        type="chat_template",
        split="train",
        field_messages=MESSAGES_COLUMN,
    )
    test_datasets.append(dataset)

if HF_DATASET:
    datasets.append(dict(path=HF_DATASET, type=HF_DATASET_TYPE))

In [None]:
# @title Setup Axolotl Flags
# @markdown This section configures additional Axolotl flags. You can explore different Axolotl flags in the [Axolotl config file](https://github.com/axolotl-ai-cloud/axolotl/blob/6ba5c0ed2c42a0e069b28c83646ee5a2a6904430/docs/config.qmd).

# @markdown **To avoid OOM, you can reduce sequence length.** This can be done by setting `sequence_len` flag to some smaller value. But reducing sequence length might also reduce the fine-tuned model's quality.
# @markdown **Another alternative to avoid OOM is to use higher memory GPU.** It is recommended to use Vertex AI training for higher memory GPUs like A100 and H100. Vertex AI training offers greater availability of high-end GPUs.

# @markdown **Training can take a long time (20+ hours) to complete depending on the model, dataset and Axolotl config.** You can reduce the training time by reducing the max training steps. This can be done by setting `max_steps` flag to some smaller value. Note that this might also reduce the fine-tuned model's quality.

# @markdown For example, if you want to run only single step of training, then you can set `["--use-tensorboard=True", "--max_steps=1"]` in the `axolotl_flag_overrides` to achieve that.

axolotl_flag_overrides = ["--use-tensorboard=True"]  # @param {type:"raw"}
assert isinstance(
    axolotl_flag_overrides, list
), "axolotl_flag_overrides must be a list."

# The model selected in the notebook always overrides the config's base model.
axolotl_flag_overrides.append(f"--base_model={HF_MODEL_ID}")


# Check if duplicate flags are passed. Flags are compared by name (the part
# before "=") so that the same flag given twice with different values is
# rejected too, instead of being passed on ambiguously.
flags_seen = set()
for flag in axolotl_flag_overrides:
    flag_name = flag.split("=", 1)[0]
    if flag_name in flags_seen:
        raise ValueError(f"Duplicate flag: {flag}")
    flags_seen.add(flag_name)

# Resolve the effective base model: the config value unless overridden by flag.
base_model = axolotl_config["base_model"]
for overrides in axolotl_flag_overrides:
    if overrides.startswith("--base_model="):
        # Split only once so a value that itself contains "=" is kept intact.
        base_model = overrides.split("=", 1)[1]
        break
# `base_model` is expected in "<publisher>/<model>" form.
publisher = base_model.split("/")[0]
model_id = base_model.split("/")[1]
model_id = model_id.replace(".", "-")

### Finetune with Vertex AI Training

In [None]:
# @title Vertex AI fine-tuning job
# @markdown This section runs the Axolotl training using Vertex AI training job.
# @markdown **Note: This section can take a long time to run. You can reduce the training time by reducing the max training steps as mentioned in `Setup Axolotl Flags` section.**
# @markdown Refer to [Axolotl config](https://axolotl-ai-cloud.github.io/axolotl/docs/config.html) to override additional Axolotl flags.

from google.cloud.aiplatform.compat.types import \
    custom_job as gca_custom_job_compat

# @markdown Acceletor type to use for training.
training_accelerator_type = "NVIDIA_H100_80GB"  # @param ["NVIDIA_H100_80GB", "NVIDIA_A100_80GB"]


replica_count = 1
repo = "us-docker.pkg.dev/vertex-ai"
per_node_accelerator_count = 8
boot_disk_size_gb = 500
dws_kwargs = {
    "max_wait_duration": 5400,  # 90 minutes
    "scheduling_strategy": gca_custom_job_compat.Scheduling.Strategy.FLEX_START,
}
is_dynamic_workload_scheduler = True
if training_accelerator_type == "NVIDIA_A100_80GB":
    training_machine_type = "a2-ultragpu-8g"
elif training_accelerator_type == "NVIDIA_H100_80GB":
    training_machine_type = "a3-highgpu-8g"
    boot_disk_size_gb = 2000
else:
    raise ValueError(f"Unsupported accelerator type: {training_accelerator_type}")

TRAIN_DOCKER_URI = (
    f"{repo}/vertex-vision-model-garden-dockers/axolotl-train-dws:20250812-1800-rc1"
)

common_util.check_quota(
    project_id=PROJECT_ID,
    region=REGION,
    accelerator_type=training_accelerator_type,
    accelerator_count=per_node_accelerator_count * replica_count,
    is_for_training=True,
    is_restricted_image=False,
    is_dynamic_workload_scheduler=is_dynamic_workload_scheduler,
)

vertex_ai_config_path = AXOLOTL_CONFIG_PATH
# Copy the config file to the bucket.
if AXOLOTL_SOURCE == "LOCAL":
    ! gcloud storage cp $AXOLOTL_CONFIG_PATH $MODEL_BUCKET/config/
    vertex_ai_config_path = f"{common_util.gcs_fuse_path(MODEL_BUCKET)}/config/{pathlib.Path(AXOLOTL_CONFIG_PATH).name}"

job_name = common_util.get_job_name_with_datetime("axolotl-train")
AXOLOTL_OUTPUT_GCS_URI = f"{BASE_AXOLOTL_OUTPUT_GCS_URI}/{job_name}"
AXOLOTL_OUTPUT_DIR = f"{BASE_AXOLOTL_OUTPUT_DIR}/{job_name}"

TRAINING_JOB_OUTPUT_DIR = f"{AXOLOTL_OUTPUT_GCS_URI}/training_job_output"

# Set Axolotl flags.
axolotl_config_overwrites = []
axolotl_config_overwrites.append(f"--output_dir={AXOLOTL_OUTPUT_DIR}")
if len(datasets) > 0:
    axolotl_config_overwrites.append(f"--datasets={datasets}")
if len(test_datasets) > 0:
    axolotl_config_overwrites.append(f"--test_datasets={test_datasets}")
    axolotl_config_overwrites.append("--val_set_size=0")
axolotl_config_overwrites += axolotl_flag_overrides

train_job_args = []
train_job_args.append(f"--axolotl_config_path={vertex_ai_config_path}")
train_job_args += axolotl_config_overwrites
if HF_TOKEN:
    train_job_args.append(f"--huggingface_access_token={HF_TOKEN}")

job_name = common_util.get_job_name_with_datetime("axolotl-train")

# Add labels for the finetuning job.
labels = {
    "mg-source": "notebook",
    "mg-notebook-name": "model_garden_axolotl_gpt_oss_finetuning.ipynb".split(".")[0],
}

model_name = AXOLOTL_CONFIG_PATH.split("/")[1]
labels["mg-tune"] = f"publishers-{publisher}-models-{model_name}".lower()
labels["versioned-mg-tune"] = f"{labels['mg-tune']}-{model_id}".lower()
labels["versioned-mg-tune"] = labels["versioned-mg-tune"][
    : min(len(labels["versioned-mg-tune"]), 63)
]


# Pass training arguments and launch job.
train_job = aiplatform.CustomContainerTrainingJob(
    display_name=job_name,
    container_uri=TRAIN_DOCKER_URI,
    labels=labels,
)

# Run Vertex AI job.
print("Running training job with args:")
print(" \\\n".join(train_job_args))
train_job.run(
    args=train_job_args,
    replica_count=replica_count,
    machine_type=training_machine_type,
    accelerator_type=training_accelerator_type,
    accelerator_count=per_node_accelerator_count,
    boot_disk_size_gb=boot_disk_size_gb,
    service_account=SERVICE_ACCOUNT,
    base_output_dir=TRAINING_JOB_OUTPUT_DIR,
    sync=False,  # Non-blocking call to run.
    **dws_kwargs,
)

# Wait until resource has been created.
train_job.wait_for_resource_creation()

### Run TensorBoard

In [None]:
# @markdown This section shows how to launch TensorBoard in a [Cloud Shell](https://cloud.google.com/shell/docs).
# @markdown 1. Click the Cloud Shell icon(![terminal](https://github.com/google/material-design-icons/blob/master/png/action/terminal/materialicons/24dp/1x/baseline_terminal_black_24dp.png?raw=true)) on the top right to open the Cloud Shell.
# @markdown 2. Copy the `tensorboard` command shown below by running this cell.
# @markdown 3. Paste and run the command in the Cloud Shell to launch TensorBoard.
# @markdown 4. Once the command runs (You may have to click `Authorize` if prompted), click the link starting with `http://localhost`.

# @markdown Note: You may need to wait around 10 minutes after the job starts in order for the TensorBoard logs to be written to the GCS bucket.
# AXOLOTL_OUTPUT_GCS_URI is defined by the Vertex AI fine-tuning cell, so that
# cell must have been run before this one.
print(f"Command to copy: tensorboard --logdir {AXOLOTL_OUTPUT_GCS_URI}/node-0/runs/")

## Deploy using SGLang

In [None]:
# @markdown 1. Wait for the training job to finish.
# Block until the training job (if one was started in this session) is done.
if train_job and train_job.end_time is None:
    print("Waiting for the training job to finish...")
    train_job.wait()
    print("The training job has finished.")

# @markdown 2. Set up SGLang docker URI and model gcs uri.


SGLANG_DOCKER_URI = "us-docker.pkg.dev/deeplearning-platform-release/vertex-model-garden/sglang-serve.cu124.0-4.ubuntu2204.py310:model-garden.sglang-0-4-release_20250810.00_p0"

# For LoRA/QLoRA configs the model is served from the `node-0/merged`
# sub-directory of the training output; otherwise from the output root.
SGLANG_MODEL_GCS_URI = (
    f"{AXOLOTL_OUTPUT_GCS_URI}/node-0/merged"
    if axolotl_config.get("adapter") in ("lora", "qlora")
    else AXOLOTL_OUTPUT_GCS_URI
)

### Create model endpoint

In [None]:
# @markdown This section uploads the model to Model Registry and deploys it on the Endpoint. It takes 15 minutes to 1 hour to finish.
# @markdown 1. Set the machine type and accelerator type.
# @markdown Find Vertex AI prediction supported accelerators and regions [here](https://cloud.google.com/vertex-ai/docs/predictions/configure-compute).

# Only the 20b model has recommended serving hardware; reject anything else.
if "20b" not in HF_MODEL_ID:
    raise ValueError(
        f"Recommended machine settings not found for model: {HF_MODEL_ID}."
    )
accelerator_type = "NVIDIA_H100_80GB"
machine_type = "a3-highgpu-1g"
per_node_accelerator_count = 1

# Check the serving quota for the chosen accelerator before deploying.
common_util.check_quota(
    project_id=PROJECT_ID,
    region=REGION,
    accelerator_type=accelerator_type,
    accelerator_count=per_node_accelerator_count,
    is_for_training=False,
)

# @markdown Set `use_dedicated_endpoint` to False if you don't want to use [dedicated endpoint](https://cloud.google.com/vertex-ai/docs/general/deployment#create-dedicated-endpoint).
use_dedicated_endpoint = True  # @param {type:"boolean"}

gpu_memory_utilization = 0.95
# Model IDs containing "1b" get a shorter maximum length.
max_model_len = 32768 if "1b" in HF_MODEL_ID else 131072

def poll_operation(op_name: str) -> bool:  # noqa: F811
    """Checks whether a long-running operation has finished.

    Args:
        op_name: Operation name; it is appended to the regional Vertex AI
            endpoint URL to form the polling request.

    Returns:
        The value of the `done` field in the response, or False when the field
        is absent.

    Raises:
        ValueError: If the response contains an `error` field.
    """
    # `transport.requests` is a submodule of google-auth and is only reachable
    # as an attribute once it has been imported somewhere. Import it explicitly
    # rather than relying on another package having done so as a side effect.
    import google.auth.transport.requests

    creds, _ = auth.default()
    auth_req = google.auth.transport.requests.Request()
    # Refresh so that `creds.token` holds a currently valid access token.
    creds.refresh(auth_req)
    headers = {
        "Authorization": f"Bearer {creds.token}",
    }
    get_resp = requests.get(
        f"https://{REGION}-aiplatform.googleapis.com/ui/{op_name}",
        headers=headers,
        timeout=60,  # Do not let a stalled connection hang the notebook forever.
    )
    opjs = get_resp.json()
    if "error" in opjs:
        raise ValueError(f"Operation failed: {opjs['error']}")
    return opjs.get("done", False)


def poll_and_wait(op_name: str, total_wait: int, interval: int = 60):  # noqa: F811
    """Polls an operation until it is done, giving up after `total_wait`.

    Args:
        op_name: Operation name passed through to `poll_operation`.
        total_wait: Waiting budget in seconds; exceeding it raises.
        interval: Seconds to sleep between two consecutive polls.

    Raises:
        TimeoutError: If the operation is still not done after more than
            `total_wait` seconds of accumulated waiting.
    """
    elapsed = 0
    while True:
        if poll_operation(op_name):
            return
        # The budget is checked only after an unsuccessful poll, so the
        # operation always gets polled at least once.
        if elapsed > total_wait:
            raise TimeoutError("Operation timed out")
        # "\r" rewrites the same output line instead of adding a new one.
        print(
            f"\rStill waiting for operation... Waited time in second: {elapsed:<6}",
            end="",
            flush=True,
        )
        elapsed += interval
        time.sleep(interval)


def deploy_model_sglang_multihost(
    model_name: str,
    model_id: str,
    publisher: str,
    publisher_model_id: str,
    service_account: str = "",
    base_model_id: str = "",
    machine_type: str = "g2-standard-8",
    accelerator_type: str = "NVIDIA_L4",
    accelerator_count: int = 1,
    multihost_gpu_node_count: int = 1,
    gpu_memory_utilization: float | None = None,
    context_length: int | None = None,
    dtype: str | None = None,
    quantization: str | None = None,
    enable_trust_remote_code: bool = False,
    enable_torch_compile: bool = False,
    torch_compile_max_bs: int | None = None,
    attention_backend: str = "",
    enable_flashinfer_mla: bool = False,
    disable_cuda_graph: bool = False,
    speculative_algorithm: str | None = None,
    speculative_draft_model_path: str = "",
    speculative_num_steps: int = 3,
    speculative_eagle_topk: int = 1,
    speculative_num_draft_tokens: int = 4,
    enable_jit_deepgemm: bool = False,
    enable_dp_attention: bool = False,
    dp_size: int = 1,
    enable_multimodal: bool = False,
    use_dedicated_endpoint: bool = False,
    max_num_seqs: int | None = None,
    is_spot: bool = True,
    tool_call_parser: str | None = None,
) -> Tuple[aiplatform.Model, aiplatform.Endpoint]:
    """Deploys trained models with SGLang into Vertex AI.

    Creates an endpoint, uploads a model that uses the `SGLANG_DOCKER_URI`
    serving container, sends a `deployModel` REST request for that endpoint,
    and blocks until the returned operation finishes.

    Args:
        model_name: Display name of the uploaded and deployed model. The
            endpoint is named `{model_name}-endpoint`.
        model_id: Model location passed to SGLang as `--model`.
        publisher: Publisher part of the Model Garden source model name.
        publisher_model_id: Model part of the Model Garden source model name
            (`publishers/{publisher}/models/{publisher_model_id}`).
        service_account: If non-empty, set as the deployed model's
            `serviceAccount`.
        base_model_id: Value of the container's `MODEL_ID` environment
            variable. Falls back to `model_id` when empty.
        machine_type: `machineType` of the deployment's machine spec.
        accelerator_type: `acceleratorType` of the machine spec.
        accelerator_count: `acceleratorCount` of the machine spec (per node).
        multihost_gpu_node_count: `multihostGpuNodeCount` of the machine spec.
            `--tp` is set to `accelerator_count * multihost_gpu_node_count`.
        gpu_memory_utilization: If truthy, passed as `--mem-fraction-static`.
        context_length: If truthy, passed as `--context-length`.
        dtype: If truthy, passed as `--dtype`.
        quantization: If truthy, passed as `--quantization`.
        enable_trust_remote_code: Adds `--trust-remote-code`.
        enable_torch_compile: Adds `--enable-torch-compile`.
        torch_compile_max_bs: If truthy and torch compile is enabled, passed
            as `--torch-compile-max-bs`.
        attention_backend: If non-empty, passed as `--attention-backend`.
        enable_flashinfer_mla: Adds `--enable-flashinfer-mla`.
        disable_cuda_graph: Adds `--disable-cuda-graph`.
        speculative_algorithm: If truthy, passed as `--speculative-algorithm`
            together with the other `speculative_*` settings below.
        speculative_draft_model_path: Passed as
            `--speculative-draft-model-path` when non-empty.
        speculative_num_steps: Passed as `--speculative-num-steps`.
        speculative_eagle_topk: Passed as `--speculative-eagle-topk`.
        speculative_num_draft_tokens: Passed as
            `--speculative-num-draft-tokens`.
        enable_jit_deepgemm: Sets `SGL_ENABLE_JIT_DEEPGEMM=1` in the container
            environment.
        enable_dp_attention: Adds `--enable-dp-attention`.
        dp_size: Passed as `--dp`.
        enable_multimodal: Adds `--enable-multimodal`.
        use_dedicated_endpoint: Whether the endpoint is created with
            `dedicated_endpoint_enabled`.
        max_num_seqs: If truthy, passed as `--max-running-requests`.
        is_spot: When true (the default), requests spot capacity by setting
            `dedicatedResources.spot`.
        tool_call_parser: If truthy, passed as `--tool-call-parser`.

    Returns:
        The uploaded model and the endpoint it was deployed to.

    Raises:
        ValueError: If the deploy request does not return HTTP 200 with a JSON
            body containing an operation `name`. Errors raised while polling
            the operation propagate unchanged.
    """
    endpoint = aiplatform.Endpoint.create(
        display_name=f"{model_name}-endpoint",
        dedicated_endpoint_enabled=use_dedicated_endpoint,
    )

    if not base_model_id:
        base_model_id = model_id

    # See https://docs.sglang.ai/backend/server_arguments.html for a list of possible arguments with descriptions.
    sglang_args = [
        f"--model={model_id}",
        f"--tp={accelerator_count * multihost_gpu_node_count}",
        f"--dp={dp_size}",
    ]

    if context_length:
        sglang_args.append(f"--context-length={context_length}")

    if gpu_memory_utilization:
        sglang_args.append(f"--mem-fraction-static={gpu_memory_utilization}")

    if max_num_seqs:
        sglang_args.append(f"--max-running-requests={max_num_seqs}")

    if dtype:
        sglang_args.append(f"--dtype={dtype}")

    if quantization:
        sglang_args.append(f"--quantization={quantization}")

    if enable_trust_remote_code:
        sglang_args.append("--trust-remote-code")

    if enable_torch_compile:
        sglang_args.append("--enable-torch-compile")
        if torch_compile_max_bs:
            sglang_args.append(f"--torch-compile-max-bs={torch_compile_max_bs}")

    if attention_backend:
        sglang_args.append(f"--attention-backend={attention_backend}")

    if enable_flashinfer_mla:
        sglang_args.append("--enable-flashinfer-mla")

    if disable_cuda_graph:
        sglang_args.append("--disable-cuda-graph")

    if speculative_algorithm:
        sglang_args.append(f"--speculative-algorithm={speculative_algorithm}")
        # Only forward a draft model path that was actually provided; an empty
        # `--speculative-draft-model-path=` would hand the server a blank path.
        if speculative_draft_model_path:
            sglang_args.append(
                f"--speculative-draft-model-path={speculative_draft_model_path}"
            )
        sglang_args.append(f"--speculative-num-steps={speculative_num_steps}")
        sglang_args.append(f"--speculative-eagle-topk={speculative_eagle_topk}")
        sglang_args.append(
            f"--speculative-num-draft-tokens={speculative_num_draft_tokens}"
        )

    if enable_dp_attention:
        sglang_args.append("--enable-dp-attention")

    if enable_multimodal:
        sglang_args.append("--enable-multimodal")

    if tool_call_parser:
        sglang_args.append(f"--tool-call-parser={tool_call_parser}")

    env_vars = {
        "MODEL_ID": base_model_id,
        "DEPLOY_SOURCE": "notebook",
    }

    if enable_jit_deepgemm:
        env_vars["SGL_ENABLE_JIT_DEEPGEMM"] = "1"

    # HF_TOKEN is not a compulsory field and may not be defined.
    try:
        if HF_TOKEN:
            env_vars["HF_TOKEN"] = HF_TOKEN
    except NameError:
        pass

    model = aiplatform.Model.upload(
        display_name=model_name,
        serving_container_image_uri=SGLANG_DOCKER_URI,
        serving_container_args=sglang_args,
        serving_container_ports=[30000],
        serving_container_predict_route="/vertex_generate",
        serving_container_health_route="/health",
        serving_container_environment_variables=env_vars,
        serving_container_shared_memory_size_mb=(16 * 1024),  # 16 GB
        serving_container_deployment_timeout=7200,
        model_garden_source_model_name=(
            f"publishers/{publisher}/models/{publisher_model_id}"
        ),
    )
    print(
        f"Deploying {model_name} on {machine_type} with {int(accelerator_count * multihost_gpu_node_count)} {accelerator_type} GPU(s)."
    )

    # The deployment goes through a direct REST call rather than the SDK's
    # deploy method, so an access token is obtained explicitly here.
    creds, _ = auth.default()
    auth_req = auth.transport.requests.Request()
    creds.refresh(auth_req)

    url = f"https://{REGION}-aiplatform.googleapis.com/ui/projects/{PROJECT_ID}/locations/{REGION}/endpoints/{endpoint.name}:deployModel"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {creds.token}",
    }
    data = {
        "deployedModel": {
            "model": model.resource_name,
            "displayName": model_name,
            "dedicatedResources": {
                "machineSpec": {
                    "machineType": machine_type,
                    "multihostGpuNodeCount": multihost_gpu_node_count,
                    "acceleratorType": accelerator_type,
                    "acceleratorCount": accelerator_count,
                },
                "minReplicaCount": 1,
                "maxReplicaCount": 1,
            },
            "system_labels": {
                "NOTEBOOK_NAME": "model_garden_axolotl_gpt_oss_finetuning.ipynb",
                "NOTEBOOK_ENVIRONMENT": common_util.get_deploy_source(),
            },
        },
    }
    if service_account:
        data["deployedModel"]["serviceAccount"] = service_account
    if is_spot:
        data["deployedModel"]["dedicatedResources"]["spot"] = True
    response = requests.post(url, headers=headers, json=data)

    # Decode the body once. An error response is not guaranteed to be JSON,
    # and a decode failure while logging must not mask the status check below.
    try:
        response_json = response.json()
    except ValueError:
        response_json = None
    print(
        "Deploy Model response: "
        f"{response.text if response_json is None else response_json}"
    )
    if (
        response.status_code != 200
        or not isinstance(response_json, dict)
        or "name" not in response_json
    ):
        raise ValueError(f"Failed to deploy model: {response.text}")

    # Wait up to 7200 seconds for the deploy operation to finish.
    poll_and_wait(response_json["name"], 7200)
    print("endpoint_name:", endpoint.name)

    return model, endpoint


# Deploy the model at SGLANG_MODEL_GCS_URI and keep the returned model and
# endpoint under the "sglang_gpu" key so that later cells can query the endpoint
# and delete both resources.
# NOTE(review): `gpu_memory_utilization` and `max_model_len`, assigned earlier in
# this cell, are not passed to this call - confirm whether that is intentional.
models["sglang_gpu"], endpoints["sglang_gpu"] = deploy_model_sglang_multihost(
    model_name=common_util.get_job_name_with_datetime(prefix="axolotl-sglang-serve"),
    model_id=SGLANG_MODEL_GCS_URI,
    publisher=publisher.lower(),
    publisher_model_id=model_id.lower(),
    service_account=SERVICE_ACCOUNT,
    machine_type=machine_type,
    accelerator_type=accelerator_type,
    accelerator_count=per_node_accelerator_count,
    use_dedicated_endpoint=use_dedicated_endpoint,
)

### Perform Prediction

In [None]:
# @title Raw predict


# @markdown Once deployment succeeds, you can send requests to the endpoint with text prompts. Sampling parameters supported by SGLang can be found [here](https://docs.sglang.ai/backend/sampling_params.html).

# @markdown Example:

# @markdown ```
# @markdown User: What is the best way to diagnose and fix a flickering light in my house?
# @markdown Assistant: Okay, so I need to figure out how to diagnose and fix a flickering light in my house. Hmm, where do I start? Let's think. First, I remember that flickering lights can be caused by various issues. Maybe the bulb is loose? That's a common problem. Let me start with the simplest things first.
# @markdown ```
# @markdown Additionally, you can moderate the generated text with Vertex AI. See [Moderate text documentation](https://cloud.google.com/natural-language/docs/moderating-text) for more details.

# Loads an existing endpoint instance using the endpoint name:
# - Using `endpoint_name = endpoints["sglang_gpu"].name` allows us to get the
#   endpoint name of the endpoint created in the cell above.
# - Alternatively, you can set `endpoint_name = "1234567890123456789"` to load
#   an existing endpoint with the ID 1234567890123456789.
# You may uncomment the code below to load an existing endpoint. The loaded
# endpoint is stored under `endpoints["sglang_gpu"]`, which is the entry the
# predict call below reads.

# endpoint_name = ""  # @param {type:"string"}
# aip_endpoint_name = (
#     f"projects/{PROJECT_ID}/locations/{REGION}/endpoints/{endpoint_name}"
# )
# endpoints["sglang_gpu"] = aiplatform.Endpoint(aip_endpoint_name)

prompt = "What is a car?"  # @param {type: "string"}
# @markdown If you encounter an issue like `ServiceUnavailable: 503 Took too long to respond when processing`, you can reduce the maximum number of output tokens, by lowering `max_new_tokens`.
max_new_tokens = 1024  # @param {type:"integer"}
temperature = 0.6  # @param {type:"number"}
top_p = 0.95  # @param {type:"number"}

# Overrides parameters for inferences.
instances = [{"text": prompt}]
parameters = {
    "sampling_params": {
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
    }
}

# `parameters` must be sent with the request; without it the sampling values
# chosen in the form above never reach the server.
response = endpoints["sglang_gpu"].predict(
    instances=instances,
    parameters=parameters,
    use_dedicated_endpoint=use_dedicated_endpoint,
)

for prediction in response.predictions:
    print(prediction)

# @markdown Click "Show Code" to see more details.

## Clean up resources

In [None]:
# @markdown Delete the training job.

# Skipped when `train_job` is falsy (for example None).
if train_job:
    train_job.delete()

# @markdown Delete the experiment models and endpoints to recycle the resources
# @markdown and avoid unnecessary continuous charges that may otherwise be incurred.

# Undeploy model and delete endpoint.
# `force=True` asks the SDK to undeploy any deployed models before deleting.
for endpoint in endpoints.values():
    endpoint.delete(force=True)

# Delete models.
# Runs after the endpoint loop above, so the models are no longer deployed.
for model in models.values():
    model.delete()

# Opt-in: when enabled, recursively removes everything under BUCKET_NAME.
delete_bucket = False  # @param {type:"boolean"}
if delete_bucket:
    ! gcloud storage rm --recursive $BUCKET_NAME