In [None]:
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Vertex AI Supervised Tuning token count and cost estimation.

<table align="left">
  <td style="text-align: center">
    <a href="https://colab.research.google.com/github/GoogleCloudPlatform/generative-ai/blob/main/gemini/tuning/vertexai_supervised_tuning_token_count_and_cost_estimation.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/colab-logo-32px.png" alt="Google Colaboratory logo"><br> Open in Colab
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fgenerative-ai%2Fmain%2Fgemini%2Ftuning%2Fvertexai_supervised_tuning_token_count_and_cost_estimation.ipynb">
      <img width="32px" src="https://cloud.google.com/ml-engine/images/colab-enterprise-logo-32px.png" alt="Google Cloud Colab Enterprise logo"><br> Open in Colab Enterprise
    </a>
  </td>    
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/generative-ai/main/gemini/tuning/vertexai_supervised_tuning_token_count_and_cost_estimation.ipynb">
      <img src="https://lh3.googleusercontent.com/UiNooY4LUgW_oTvpsNhPpQzsstV5W8F7rYgxgGBD85cWJoLmrOzhVs_ksK_vgx40SHs7jCqkTkCk=e14-rj-sc0xffffff-h130-w32" alt="Vertex AI logo"><br> Open in Workbench
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/tuning/vertexai_supervised_tuning_token_count_and_cost_estimation.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/github-logo-32px.png" alt="GitHub logo"><br> View on GitHub
    </a>
  </td>
</table>

| | |
|-|-|
| Author(s) | [Lehui Liu](https://github.com/liulehui), [Erwin Huizenga](https://github.com/Huize501) |

## Overview

This notebook is a tool for validating your tuning dataset and estimating its token count, and the resulting cost, for supervised tuning of [`gemini-1.0-pro-002`](https://cloud.google.com/vertex-ai/generative-ai/docs/models/gemini-supervised-tuning).


To learn how to prepare a dataset for tuning Gemini, refer to this [tutorial](https://cloud.google.com/vertex-ai/generative-ai/docs/models/gemini-supervised-tuning-about).

## Get started

### Install Vertex AI SDK and other required packages


In [None]:
%pip install --upgrade --user --quiet google-cloud-aiplatform[tokenization] numpy==1.26.4 tensorflow

### Restart runtime

To use the newly installed packages in this Jupyter runtime, you must restart the runtime. You can do this by running the cell below, which restarts the current kernel.

The restart might take a minute or longer. After it's restarted, continue to the next step.

In [2]:
import IPython

# Shut the kernel down with restart=True so it comes back up with the newly
# installed packages importable.
kernel_app = IPython.Application.instance()
kernel_app.kernel.do_shutdown(True)

{'status': 'ok', 'restart': True}

<div class="alert alert-block alert-warning">
<b>⚠️ The kernel is going to restart. Wait until it's finished before continuing to the next step. ⚠️</b>
</div>


### Authenticate your notebook environment (Colab only)

If you're running this notebook on Google Colab, run the cell below to authenticate your environment.

In [1]:
import sys

# Only Colab needs an explicit interactive login here; other environments are
# assumed to already have application credentials configured.
IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    from google.colab import auth

    auth.authenticate_user()

### Set Google Cloud project information and initialize Vertex AI SDK

To get started using Vertex AI, you must have an existing Google Cloud project and [enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com).

Learn more about [setting up a project and a development environment](https://cloud.google.com/vertex-ai/docs/start/cloud-environment).

In [2]:
import os

PROJECT_ID = "[your-project-id]"  # @param {type:"string"}
LOCATION = "us-central1"  # @param {type:"string"}

# If the placeholder above was not replaced, fall back to the project exposed
# by the environment (GOOGLE_CLOUD_PROJECT, when set), instead of passing the
# literal placeholder string to vertexai.init.
if not PROJECT_ID or PROJECT_ID == "[your-project-id]":
    PROJECT_ID = os.environ.get("GOOGLE_CLOUD_PROJECT", PROJECT_ID)


import vertexai

vertexai.init(project=PROJECT_ID, location=LOCATION)

## Tuning token count and cost estimation.

### Import libraries

In [3]:
import tensorflow as tf
import json
import numpy as np
import pandas as pd
from collections import defaultdict
import dataclasses
import math

from vertexai.generative_models import Content, Part
from vertexai.preview.tokenization import get_tokenizer_for_model
from vertexai.generative_models import GenerativeModel
from google.cloud import aiplatform

### Load the dataset

Define the Google Cloud Storage URIs pointing to your training and validation datasets, or continue using the URIs provided.

In [5]:
# Base model whose tokenizer is used for all counts; the form offers a single choice.
BASE_MODEL = "gemini-1.0-pro-002"  # @param ['gemini-1.0-pro-002']{type:"string"}
# JSONL training / validation datasets to analyze (Cloud Storage URIs).
training_dataset_uri = "gs://cloud-samples-data/ai-platform/generative_ai/sft_train_data.jsonl"  # @param {type:"string"}
validation_dataset_uri = "gs://cloud-samples-data/ai-platform/generative_ai/sft_validation_data.jsonl"  # @param {type:"string"}

# Tokenizer for BASE_MODEL; every token count later in this notebook goes through it.
tokenizer = get_tokenizer_for_model(BASE_MODEL)

We'll now load the dataset and conduct some basic statistical analysis to understand its structure and content.


In [None]:
def load_jsonl_dataset(uri):
    """Reads a JSON Lines file (e.g. a gs:// URI) into a list of parsed examples.

    Blank lines are skipped so a stray empty line does not raise a
    JSONDecodeError.
    """
    with tf.io.gfile.GFile(uri) as dataset_jsonl_file:
        return [
            json.loads(dataset_line)
            for dataset_line in dataset_jsonl_file
            if dataset_line.strip()
        ]


example_training_dataset = load_jsonl_dataset(training_dataset_uri)

# The validation dataset is optional. Default to an empty list so that later
# `if example_validation_dataset:` checks work instead of raising NameError.
example_validation_dataset = (
    load_jsonl_dataset(validation_dataset_uri) if validation_dataset_uri else []
)

# Initial dataset stats
print("Num training examples:", len(example_training_dataset))
if example_training_dataset:
    print("First example:")
    for message in example_training_dataset[0]["messages"]:
        print(message)
        print(tokenizer.count_tokens(message.get("content")))

if example_validation_dataset:
    print("Num validation examples:", len(example_validation_dataset))

### Validate the format of the data

The following checks validate that each tuning example in the dataset follows the format expected by the tuning API. Errors are grouped by type to make debugging easier.
  
To learn how to prepare a dataset for tuning Gemini, refer to this [tutorial](https://cloud.google.com/vertex-ai/generative-ai/docs/models/gemini-supervised-tuning-about).

1. **Presence of Message List**: Checks that a `messages` list is present in each entry. Error type: `missing_messages_list`.
2. **Message Keys Check**: Validates that each message in the `messages` list contains the keys `role` and `content`. Error type: `message_missing_key`.
3. **Role Validation**: Ensures the role is one of `system`, `user`, or `model`. Error type: `unrecognized_role`. Note: only the first message can have `system` as its role.
4. **Content Validation**: Verifies that content has textual data and is a string. Error type: `missing_content`.
5. **Consecutive Turns**: In the chat history, messages must alternate between the `user` and `model` roles; two consecutive messages with the same role are reported as an error. Error type: `consecutive_turns`.


In [9]:
def validate_dataset_format(dataset):
    """Validates that every example in a tuning dataset has the expected format.

    Each example must be a dict whose ``messages`` list passes
    ``validate_messages``. Problems are grouped by error type, and the row
    indices of the offending examples are printed.

    Args:
      dataset: List of parsed JSONL examples to validate.
    """
    if not dataset:
        print("Input dataset file is empty or inaccessible.")
        return

    format_errors = defaultdict(list)
    for row_idx, example in enumerate(dataset):
        # Verify presence of messages list
        if not isinstance(example, dict):
            format_errors["missing_messages_list"].append(row_idx)
            continue
        messages = example.get("messages", None)
        try:
            validate_messages(messages, format_errors, row_idx)
        except (TypeError, AttributeError, KeyError) as e:
            # Structurally malformed input (e.g. a message that is not a dict).
            print(f"Invalid input during validation: {e}")
            format_errors["invalid_input"].append(row_idx)

    if format_errors:
        print("Found errors for this dataset:")
        for k, v in format_errors.items():
            print(f"{k}: {v}")
    else:
        print("No errors found for this dataset.")


def validate_messages(messages, format_errors, row_index):
    """Validates the format of one example's ``messages`` list.

    Stops at the first problem found and appends ``row_index`` to
    ``format_errors[<error type>]``. Checks, in order:
      * the list is present and non-empty (``missing_messages_list``);
      * every message has ``role`` and ``content`` keys (``message_missing_key``);
      * roles are ``user`` or ``model``; ``system`` is allowed only as the first
        message (``unrecognized_role``);
      * content is a non-blank string (``missing_content``);
      * two consecutive messages do not share a role (``consecutive_turns``).

    Args:
      messages: The value of the example's ``messages`` field.
      format_errors: ``defaultdict(list)`` mapping error type to row indices;
        mutated in place.
      row_index: Index of the example in the dataset, used in error reports.

    Raises:
      TypeError, AttributeError or KeyError on structurally malformed input
      (e.g. a message that is not a dict); the caller reports these as
      ``invalid_input``.
    """
    if not messages:
        format_errors["missing_messages_list"].append(row_index)
        return

    # A system instruction is only allowed as the very first message; strip it
    # so any later "system" message is reported as an unrecognized role.
    if messages[0].get("role", "").lower() == "system":
        messages = messages[1:]

    prev_role = None
    for message in messages:
        if "role" not in message or "content" not in message:
            format_errors["message_missing_key"].append(row_index)
            return

        role = message.get("role", "").lower()
        if role not in ("user", "model"):
            format_errors["unrecognized_role"].append(row_index)
            return

        content = message.get("content", None)
        if not isinstance(content, str) or not content.strip():
            format_errors["missing_content"].append(row_index)
            return

        # Turns must alternate between user and model.
        if role == prev_role:
            format_errors["consecutive_turns"].append(row_index)
            return

        prev_role = role

Now you can check the data for any issues.

In [18]:
# Always validate the training data; validate the validation data only if it was loaded.
datasets_to_validate = [example_training_dataset]
if example_validation_dataset:
    datasets_to_validate.append(example_validation_dataset)
for dataset_to_validate in datasets_to_validate:
    validate_dataset_format(dataset_to_validate)

No errors found for this dataset.
No errors found for this dataset.


### Utils for dataset analysis and token counting

This section focuses on analyzing the structure and token counts of your datasets. You will also define some utility functions to streamline subsequent steps in the notebook.

* Load and inspect sample data from the training and validation datasets.
* Calculate token counts for messages to understand the dataset's characteristics.
* Define utility functions for calculating token distributions and dataset statistics. These will help assess the suitability of your data for supervised tuning and estimate potential costs.


In [12]:
@dataclasses.dataclass
class DatasetDistribution:
    """Dataset disbribution for given a population of values.

    It optionally contains a histogram consists of bucketized data representing
    the distribution of those values. The summary statistics are the sum, min,
    max, mean, median, p5, p95.

    Attributes:
      sum: Sum of the values in the population.
      max: Max of the values in the population.
      min: Min of the values in the population.
      mean: The arithmetic mean of the values in the population.
      median: The median of the values in the population.
      p5: P5 quantile of the values in the population.
      p95: P95 quantile of the values in the population.
    """

    sum: int | None = None
    max: float | None = None
    min: float | None = None
    mean: float | None = None
    median: float | None = None
    p5: float | None = None
    p95: float | None = None


@dataclasses.dataclass
class DatasetStatistics:
    """Dataset statistics used for dataset profiling.

    Attributes:
      total_number_of_dataset_examples: Number of tuning examples in the dataset.
      total_number_of_records_for_training: Number of tuning records after
        formatting. Each model turn in the chat messages counts as one record.
      total_number_of_billable_tokens: Estimated number of billable tokens in
        the dataset, including a fixed per-record padding allowance.
      user_input_token_length_stats: Distribution of per-record input token
        counts, where a record's input is the whole conversation preceding
        its model turn.
      user_output_token_length_stats: Distribution of per-record output token
        counts, i.e. the tokens of each model turn (despite the ``user_``
        prefix in the name).
    """

    total_number_of_dataset_examples: int | None = None
    total_number_of_records_for_training: int | None = None
    total_number_of_billable_tokens: int | None = None
    user_input_token_length_stats: DatasetDistribution | None = None
    user_output_token_length_stats: DatasetDistribution | None = None


# Per-record token budget (input + model turn). Records above it are flagged
# as likely to be truncated during tuning.
MAX_TOKENS_PER_EXAMPLE = 32_768  # 32 * 1024
# Rough per-record allowance added on top of the raw message tokens when
# estimating billable tokens (presumably formatting tokens; an estimate only).
ESTIMATE_PADDING_TOKEN_PER_EXAMPLE = 8

In [15]:
def calculate_distribution_for_population(population) -> DatasetDistribution:
    """Calculates summary statistics for a population of values.

    Args:
      population: A non-empty sequence (list, tuple or NumPy array) of numbers.

    Returns:
      DatasetDistribution with the sum, max, min, mean, median and the
      nearest-rank P5/P95 of the population. Values are plain Python
      ``int``/``float`` objects, so they print the same across NumPy versions.

    Raises:
      ValueError: If the population is empty.
    """
    values = np.asarray(population)
    # `not population` is ambiguous (raises) for NumPy arrays; check size instead.
    if values.size == 0:
        raise ValueError("population is empty")

    # `.item()` converts NumPy scalars to Python scalars while keeping int vs
    # float, so reprs stay e.g. `712` rather than `np.int64(712)` on NumPy 2.
    return DatasetDistribution(
        sum=values.sum().item(),
        max=values.max().item(),
        min=values.min().item(),
        mean=values.mean().item(),
        median=np.median(values).item(),
        p5=np.percentile(values, 5, method="nearest").item(),
        p95=np.percentile(values, 95, method="nearest").item(),
    )


def get_token_distribution_for_one_tuning_dataset_example(example):
    """Counts tokens for every tuning record contained in one dataset example.

    Each ``model`` message is one training record. Its input is the whole
    conversation preceding it (every earlier message, including a system
    message), and its output is the model message itself.

    Args:
      example: A dataset example with a ``messages`` list of
        ``{"role": ..., "content": ...}`` dicts.

    Returns:
      A tuple ``(input_token_list, model_turn_token_list, num_records,
      num_billable_tokens, n_too_long)``:
        * the two lists hold one entry per record;
        * ``num_billable_tokens`` is the int sum of both lists;
        * ``n_too_long`` is 1 if a record exceeded ``MAX_TOKENS_PER_EXAMPLE``
          (counting stops at that record), otherwise 0.
    """
    model_turn_token_list = []
    input_token_list = []
    history = []  # Conversation so far, as Content objects for the tokenizer.
    n_too_long = 0
    number_of_records_for_training = 0  # each model turn is one tuning record
    for message in example["messages"]:
        role = message.get("role").lower()
        text = message.get("content")

        if role == "model":
            input_tokens = tokenizer.count_tokens(history).total_tokens
            # Tokenize the model turn once and reuse the count below.
            output_tokens = tokenizer.count_tokens(text).total_tokens
            input_token_list.append(input_tokens)
            model_turn_token_list.append(output_tokens)
            number_of_records_for_training += 1
            if input_tokens + output_tokens > MAX_TOKENS_PER_EXAMPLE:
                n_too_long += 1
                break

        history.append(Content(role=role, parts=[Part.from_text(text)]))

    return (
        input_token_list,
        model_turn_token_list,
        number_of_records_for_training,
        # Built-in sum keeps this an int; np.sum([]) would return 0.0.
        sum(model_turn_token_list) + sum(input_token_list),
        n_too_long,
    )


def get_dataset_stats_for_dataset(dataset):
    """Aggregates per-example token counts into a DatasetStatistics.

    Prints how many examples exceed MAX_TOKENS_PER_EXAMPLE, then returns the
    statistics. The billable-token total adds ESTIMATE_PADDING_TOKEN_PER_EXAMPLE
    for every training record.

    Args:
      dataset: List of tuning examples (dicts with a ``messages`` list).

    Returns:
      DatasetStatistics for the whole dataset.
    """
    all_input_tokens = []
    all_model_turn_tokens = []
    record_count = 0
    billable_tokens = 0
    too_long_count = 0

    for example in dataset:
        (
            example_input_tokens,
            example_model_turn_tokens,
            example_record_count,
            example_billable_tokens,
            example_too_long,
        ) = get_token_distribution_for_one_tuning_dataset_example(example)
        all_input_tokens += example_input_tokens
        all_model_turn_tokens += example_model_turn_tokens
        record_count += example_record_count
        billable_tokens += example_billable_tokens
        too_long_count += example_too_long

    print(
        f"\n{too_long_count} examples may be over the {MAX_TOKENS_PER_EXAMPLE} token limit, they will be truncated during tuning."
    )

    padding_tokens = record_count * ESTIMATE_PADDING_TOKEN_PER_EXAMPLE
    return DatasetStatistics(
        total_number_of_dataset_examples=len(dataset),
        total_number_of_records_for_training=record_count,
        total_number_of_billable_tokens=billable_tokens + padding_tokens,
        user_input_token_length_stats=calculate_distribution_for_population(
            all_input_tokens
        ),
        user_output_token_length_stats=calculate_distribution_for_population(
            all_model_turn_tokens
        ),
    )


def print_dataset_stats(dataset):
    """Computes statistics for ``dataset``, prints a summary, and returns them.

    Args:
      dataset: List of tuning examples.

    Returns:
      The DatasetStatistics produced by get_dataset_stats_for_dataset.
    """
    dataset_stats = get_dataset_stats_for_dataset(dataset)
    report_lines = [
        "Below you can find the dataset statistics:",
        f"Total number of examples in the dataset: {dataset_stats.total_number_of_dataset_examples}",
        f"Total number of records for training: {dataset_stats.total_number_of_records_for_training}",
        f"Total number of billable tokens in the dataset: {dataset_stats.total_number_of_billable_tokens}",
        f"User input token length distribution: {dataset_stats.user_input_token_length_stats}",
        f"User output token length distribution: {dataset_stats.user_output_token_length_stats}",
    ]
    for report_line in report_lines:
        print(report_line)
    return dataset_stats

Next you can analyze the structure and token counts of your datasets.

In [17]:
# Default to None so the cost-estimation cell can tell whether a validation
# dataset was analyzed (otherwise the name would be unbound).
validation_dataset_stats = None

training_dataset_stats = print_dataset_stats(example_training_dataset)

if example_validation_dataset:
    validation_dataset_stats = print_dataset_stats(example_validation_dataset)


0 examples may be over the 32768 token limit, they will be truncated during tuning.
Below you can find the dataset statistics:
Total number of examples in the dataset: 500
Total number of records for training: 500
Total number of billable tokens in the dataset: 130300
User input token length distribution: DatasetDistribution(sum=109172, max=712, min=70, mean=218.344, median=198.5, p5=89, p95=403)
User output token length distribution: DatasetDistribution(sum=17128, max=124, min=12, mean=34.256, median=31.0, p5=17, p95=63)

0 examples may be over the 32768 token limit, they will be truncated during tuning.
Below you can find the dataset statistics:
Total number of examples in the dataset: 100
Total number of records for training: 100
Total number of billable tokens in the dataset: 28414
User input token length distribution: DatasetDistribution(sum=23922, max=829, min=70, mean=239.22, median=225.5, p5=92, p95=430)
User output token length distribution: DatasetDistribution(sum=3692, max=

### Cost Estimation for Supervised Fine-tuning
In this final section, you will estimate the total cost of supervised fine-tuning based on the number of tokens processed. You are charged for the number of tokens used; refer to the [pricing page](https://cloud.google.com/vertex-ai/generative-ai/pricing#gemini-models) for the rate.

**Important Note:** The final cost may vary slightly from this estimate due to dataset formatting and truncation logic during training.

The code calculates the total number of billable tokens by summing the tokens from the training dataset and (if provided) the validation dataset. It then estimates the total number of tokens you will be charged for by multiplying the billable tokens by the number of training epochs (default: 4). Multiply that token count by the rate on the pricing page to get the cost.

### Cost estimation

This section estimates the total number of tokens used for supervised tuning; you are charged for these tokens.

There might be a slight difference between the estimation and actual cost due to dataset formatting and truncation logic.

In [19]:
epoch_count = 4  # @param {type:"integer"}
if epoch_count is None:
    epoch_count = 4
if epoch_count < 1:
    raise ValueError(f"epoch_count must be a positive integer, got {epoch_count}.")


total_number_of_billable_tokens = training_dataset_stats.total_number_of_billable_tokens

# `validation_dataset_stats` may be unbound or None when no validation dataset
# was analyzed, so look it up defensively instead of raising NameError.
validation_stats = globals().get("validation_dataset_stats")
if validation_stats:
    total_number_of_billable_tokens += validation_stats.total_number_of_billable_tokens

print(f"Dataset has ~{total_number_of_billable_tokens} tokens that will be charged")
print(f"By default, you'll train for {epoch_count} epochs on this dataset.")
print(
    f"By default, you'll be charged for ~{epoch_count * total_number_of_billable_tokens} tokens."
)

Dataset has ~158714 tokens that will be charged
By default, you'll train for 4 epochs on this dataset.
By default, you'll be charged for ~634856 tokens.
