In [None]:
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Vertex AI Supervised tuning token count and cost estimation.

<table align="left">
  <td style="text-align: center">
    <a href="https://colab.research.google.com/github/GoogleCloudPlatform/generative-ai/blob/main/gemini/tuning/vertexai_supervised_tuning_token_count_and_cost_estimation.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/colab-logo-32px.png" alt="Google Colaboratory logo"><br> Open in Colab
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fgenerative-ai%2Fmain%2Fgemini%2Ftuning%2Fvertexai_supervised_tuning_token_count_and_cost_estimation.ipynb">
      <img width="32px" src="https://cloud.google.com/ml-engine/images/colab-enterprise-logo-32px.png" alt="Google Cloud Colab Enterprise logo"><br> Open in Colab Enterprise
    </a>
  </td>    
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/generative-ai/main/gemini/tuning/vertexai_supervised_tuning_token_count_and_cost_estimation.ipynb">
      <img src="https://lh3.googleusercontent.com/UiNooY4LUgW_oTvpsNhPpQzsstV5W8F7rYgxgGBD85cWJoLmrOzhVs_ksK_vgx40SHs7jCqkTkCk=e14-rj-sc0xffffff-h130-w32" alt="Vertex AI logo"><br> Open in Workbench
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/tuning/vertexai_supervised_tuning_token_count_and_cost_estimation.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/github-logo-32px.png" alt="GitHub logo"><br> View on GitHub
    </a>
  </td>
</table>

<div style="clear: both;"></div>


| | |
|-|-|
| Author(s) | [Lehui Liu](https://github.com/liulehui), [Erwin Huizenga](https://github.com/Huize501) |

## Overview

This notebook serves as a tool to preprocess your dataset and estimate the token counts and tuning costs for supervised tuning of [`gemini-1.5-pro-002`](https://cloud.google.com/vertex-ai/generative-ai/docs/models/gemini-supervised-tuning).

At the end, you will also find the code to do the same for `gemini-1.0-pro-002`. If you are just getting started, start with `gemini-1.5-pro-002`.

To learn how to prepare a dataset for tuning Gemini, refer to this [tutorial](https://cloud.google.com/vertex-ai/generative-ai/docs/models/gemini-supervised-tuning-about).

## Get started

### Install Vertex AI SDK and other required packages


In [24]:
%pip install --upgrade --user --quiet google-cloud-aiplatform[tokenization] numpy==1.26.4 tensorflow

### Restart runtime

To use the newly installed packages in this Jupyter runtime, you must restart the runtime. You can do this by running the cell below, which restarts the current kernel.

The restart might take a minute or longer. After it's restarted, continue to the next step.

In [25]:
import IPython

app = IPython.Application.instance()
app.kernel.do_shutdown(True)

{'status': 'ok', 'restart': True}

<div class="alert alert-block alert-warning">
<b>⚠️ The kernel is going to restart. Wait until it's finished before continuing to the next step. ⚠️</b>
</div>


### Authenticate your notebook environment (Colab only)

If you're running this notebook on Google Colab, run the cell below to authenticate your environment.

In [14]:
import sys

if "google.colab" in sys.modules:
    from google.colab import auth

    auth.authenticate_user()

### Set Google Cloud project information and initialize Vertex AI SDK

To get started using Vertex AI, you must have an existing Google Cloud project and [enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com).

Learn more about [setting up a project and a development environment](https://cloud.google.com/vertex-ai/docs/start/cloud-environment).

In [17]:
PROJECT_ID = "[your-project-id]"  # @param {type:"string"}
LOCATION = "us-central1"  # @param {type:"string"}


import vertexai

vertexai.init(project=PROJECT_ID, location=LOCATION)

## Tuning token count and cost estimation: `Gemini 1.5 Pro` and `Gemini 1.5 Flash`

### Import libraries

In [115]:
from collections import defaultdict
import dataclasses
import json

from google.cloud import storage
import numpy as np
import tensorflow as tf
from vertexai.generative_models import Content, Part
from vertexai.preview.tokenization import get_tokenizer_for_model

### Load the dataset

This example is for text only. Define the Google Cloud Storage URIs pointing to your training and validation datasets or continue using the URIs provided.
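
For orientation, one line of such a JSONL file can be sketched as follows. The field names (`systemInstruction`, `contents`, `role`, `parts`, `text`) follow the validation and conversion code in this notebook; the text values are invented for illustration.

```python
import json

# Hypothetical example record; the text values are made up.
example_record = {
    "systemInstruction": {
        "role": "system",
        "parts": [{"text": "You are a helpful assistant."}],
    },
    "contents": [
        {"role": "user", "parts": [{"text": "Summarize: the cat sat on the mat."}]},
        {"role": "model", "parts": [{"text": "A cat sat on a mat."}]},
    ],
}

# Each tuning example occupies exactly one line of the JSONL file.
jsonl_line = json.dumps(example_record)
round_tripped = json.loads(jsonl_line)
print([item["role"] for item in round_tripped["contents"]])
```

The `systemInstruction` field is optional; the `contents` list carries the alternating user and model turns.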

In [135]:
BASE_MODEL = "gemini-1.5-pro-002"  # @param ['gemini-1.5-pro-002']{type:"string"}
training_dataset_uri = "gs://github-repo/generative-ai/gemini/tuning/train_sft_train_samples.jsonl"  # @param {type:"string"}
validation_dataset_uri = "gs://github-repo/generative-ai/gemini/tuning/val_sft_val_samples.jsonl"  # @param {type:"string"}

tokenizer = get_tokenizer_for_model("gemini-1.5-pro-001")

We'll now load the dataset and conduct some basic statistical analysis to understand its structure and content.


In [136]:
example_training_dataset = []
example_validation_dataset = []

try:
    with tf.io.gfile.GFile(training_dataset_uri) as dataset_jsonl_file:
        example_training_dataset = [
            json.loads(dataset_line) for dataset_line in dataset_jsonl_file
        ]
except json.JSONDecodeError as e:
    # json.loads raises JSONDecodeError (not KeyError) on a malformed line
    print(
        f"JSONDecodeError: Please check if your file '{training_dataset_uri}' is a JSONL file with correct JSON format. Error: {e}"
    )
    # Exit the script if there's an error in the training data
    import sys

    sys.exit(1)

print()

if validation_dataset_uri:
    try:
        with tf.io.gfile.GFile(validation_dataset_uri) as dataset_jsonl_file:
            example_validation_dataset = [
                json.loads(dataset_line) for dataset_line in dataset_jsonl_file
            ]
    except json.JSONDecodeError as e:
        # json.loads raises JSONDecodeError (not KeyError) on a malformed line
        print(
            f"JSONDecodeError: Please check if your file '{validation_dataset_uri}' is a JSONL file with correct JSON format. Error: {e}"
        )
        # Exit the script if there's an error in the validation data
        import sys

        sys.exit(1)

# Initial dataset stats
print("Num training examples:", len(example_training_dataset))
if example_training_dataset:  # Check if the list is not empty
    print("First example:")
    for item in example_training_dataset[0]["contents"]:
        print(item)
        text_content = item.get("parts", [{}])[0].get("text", "")
        print(tokenizer.count_tokens(text_content))  # Make sure 'tokenizer' is defined

if example_validation_dataset:
    print("Num validation examples:", len(example_validation_dataset))


Num training examples: 500
First example:
{'role': 'user', 'parts': [{'text': 'Honesty is usually the best policy. It is disrespectful to lie to someone. If you don\'t want to date someone, you should say so.  Sometimes it is easy to be honest. For example, you might be able to truthfully say, "No, thank you, I already have a date for that party." Other times, you might need to find a kinder way to be nice. Maybe you are not attracted to the person. Instead of bluntly saying that, try saying, "No, thank you, I just don\'t think we would be a good fit." Avoid making up a phony excuse. For instance, don\'t tell someone you will be out of town this weekend if you won\'t be. There\'s a chance that you might then run into them at the movies, which would definitely cause hurt feelings. A compliment sandwich is a really effective way to provide feedback. Essentially, you "sandwich" your negative comment between two positive things. Try using this method when you need to reject someone.  An e

### Validate the format of the data

You can perform various error checks to validate that each tuning example in the dataset adheres to the format expected by the tuning API. Errors are categorized by their nature for easier debugging.

To learn how to prepare a dataset for tuning Gemini, refer to this [tutorial](https://cloud.google.com/vertex-ai/generative-ai/docs/models/gemini-supervised-tuning-about).

1. **Presence of System Instruction:** Checks whether a system instruction is present and, if so, whether it is present in every row. The system instruction is optional. Warning type: `systemInstruction is missing in some rows`.
2. **Presence of Contents List:** Checks if a `contents` list is present in each entry. Error type: `missing_contents_list`.
3. **Content Item Format:** Validates that each item in the `contents` list is a dictionary. Error type: `invalid_content_item`.
4. **Content Item Keys:** Validates that each item in the `contents` list contains the keys `role` and `parts`. Error type: `content_item_missing_key`.
5. **Role Validation:** Checks that the role is `user` or `model` for items in the `contents` list, and `system` for `systemInstruction`. Error type: `unrecognized_role`.
6. **Parts List Validation:** Verifies that the `parts` key contains a list. Error type: `missing_or_invalid_parts`.
7. **Part Format:** Checks if each part in the `parts` list is a dictionary and contains the key `text`. Error type: `invalid_part`.
8. **Text Validation:** Ensures that the `text` key has textual data and is a string. Error type: `missing_text`.
9. **Consecutive Turns:** Message roles in the chat history must alternate (user, then model, then user, and so on). Error type: `consecutive_turns`. This check does not apply to `systemInstruction`.
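
The per-item checks in this list (role, parts, text, and alternating turns) could be sketched roughly as below. This is a standalone illustration under stated assumptions, not the notebook's `validate_contents`: the helper name `check_content_items` and its return value (a flat list of error types) are hypothetical.

```python
def check_content_items(contents):
    """Returns the error types found in one `contents` list (illustrative only)."""
    errors = []
    prev_role = None
    for item in contents:
        if not isinstance(item, dict):
            errors.append("invalid_content_item")
            continue

        # Role must be user or model, and must differ from the previous turn.
        role = item.get("role")
        role = role.lower() if isinstance(role, str) else None
        if role not in ("user", "model"):
            errors.append("unrecognized_role")
        elif role == prev_role:
            errors.append("consecutive_turns")

        # Parts must be a non-empty list of dicts, each with a non-empty text string.
        parts = item.get("parts")
        if not isinstance(parts, list) or not parts:
            errors.append("missing_or_invalid_parts")
        else:
            for part in parts:
                if not isinstance(part, dict) or "text" not in part:
                    errors.append("invalid_part")
                elif not isinstance(part["text"], str) or not part["text"].strip():
                    errors.append("missing_text")

        prev_role = role
    return errors


good = [
    {"role": "user", "parts": [{"text": "Hi"}]},
    {"role": "model", "parts": [{"text": "Hello"}]},
]
bad = [
    {"role": "user", "parts": [{"text": "Hi"}]},
    {"role": "user", "parts": [{"text": ""}]},
]
print(check_content_items(good))
print(check_content_items(bad))
```

Returning a list of error types keeps the sketch easy to test; the notebook's own validator instead records row indices per error type in a `defaultdict`.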


In [137]:
from collections import defaultdict


def validate_dataset_format(dataset):
    """Validates the dataset.

    Args:
      dataset: The list of parsed tuning examples (one dict per JSONL line) to validate.
    """
    format_errors = defaultdict(list)
    system_instruction_missing = False  # Flag to track missing systemInstruction

    if not dataset or len(dataset) == 0:
        print("Input dataset file is empty or inaccessible.")
        return

    for row_idx, example in enumerate(dataset):
        # Verify presence of contents list
        if not isinstance(example, dict):
            format_errors["invalid_input"].append(row_idx)
            continue

        # Check for systemInstruction and validate if present
        system_instruction = example.get("systemInstruction", None)
        if system_instruction:
            try:
                # Validate the list within "parts"
                validate_contents(
                    system_instruction.get("parts", []),
                    format_errors,
                    row_idx,
                    is_system_instruction=True,
                )
            except (TypeError, AttributeError, KeyError) as e:
                print(f"Invalid input during system instruction validation: {e}")
                format_errors["invalid_system_instruction"].append(row_idx)
        else:
            system_instruction_missing = True  # Set the flag if missing

        contents = example.get("contents", None)
        if not contents:
            format_errors["missing_contents_list"].append(row_idx)
            continue
        try:
            validate_contents(contents, format_errors, row_idx)
        except (TypeError, AttributeError, KeyError) as e:
            print(f"Invalid input during contents validation: {e}")
            format_errors["invalid_input"].append(row_idx)

    if format_errors:
        print("Found errors for this dataset:")
        for k, v in format_errors.items():
            print(f"{k}: {v}")
    else:
        print("No errors found for this dataset.")

    # Print warning only once after processing all rows
    if system_instruction_missing:
        print("Warning: systemInstruction is missing in some rows.")


def validate_contents(contents, format_errors, row_index, is_system_instruction=False):
    """Validates contents list format."""

    if not isinstance(contents, list):
        format_errors["invalid_contents_list"].append(row_index)
        return

    prev_role = None
    for content_item in contents:  # Iterate over content items in the "contents" list
        if not isinstance(content_item, dict):
            format_errors["invalid_content_item"].append(row_index)
            return

        # Skip key checks for system instructions
        if not is_system_instruction and (
            "role" not in content_item or "parts" not in content_item
        ):
            format_errors["content_item_missing_key"].append(row_index)
            return

        # ... (rest of the validation logic remains the same)

In [138]:
validate_dataset_format(example_training_dataset)
if example_validation_dataset:
    validate_dataset_format(example_validation_dataset)

No errors found for this dataset.
No errors found for this dataset.


### Utils for dataset analysis and token counting

This section focuses on analyzing the structure and token counts of your datasets. You will also define some utility functions to streamline subsequent steps in the notebook.

* Load and inspect sample data from the training and validation datasets.
* Calculate token counts for messages to understand the dataset's characteristics.
* Define utility functions for calculating token distributions and dataset statistics. These will help assess the suitability of your data for supervised tuning and estimate potential costs.
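
To make the record accounting concrete: each model turn becomes one training record whose input is everything that precedes it in the conversation. The sketch below is a minimal illustration that uses a whitespace split as a stand-in for the real tokenizer, so the counts are not real token counts. The conversation and the helper names `toy_count` and `toy_billable_tokens` are invented; the padding value mirrors `ESTIMATE_PADDING_TOKEN_PER_EXAMPLE` defined in this notebook.

```python
PADDING_PER_RECORD = 8  # mirrors ESTIMATE_PADDING_TOKEN_PER_EXAMPLE


def toy_count(text):
    """Stand-in tokenizer: counts whitespace-separated words, not real tokens."""
    return len(text.split())


def toy_billable_tokens(contents):
    """Returns (number of training records, estimated billable tokens)."""
    context_tokens = 0  # tokens in all turns seen so far
    billable = 0
    records = 0
    for turn in contents:
        n = toy_count(turn["text"])
        if turn["role"] == "model":
            # One record per model turn: prior context + this output + padding.
            billable += context_tokens + n + PADDING_PER_RECORD
            records += 1
        context_tokens += n
    return records, billable


conversation = [
    {"role": "user", "text": "what is two plus two"},
    {"role": "model", "text": "four"},
    {"role": "user", "text": "and times three"},
    {"role": "model", "text": "that makes twelve"},
]
records, billable = toy_billable_tokens(conversation)
print(records, billable)
```

Because earlier turns are counted again as input for every later model turn, multi-turn conversations contribute more billable tokens than the sum of their individual messages.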

In [140]:
@dataclasses.dataclass
class DatasetDistribution:
    """Dataset distribution for a given population of values.

    The summary statistics are the sum, min, max, mean, median, p5, and p95.

    Attributes:
      sum: Sum of the values in the population.
      max: Max of the values in the population.
      min: Min of the values in the population.
      mean: The arithmetic mean of the values in the population.
      median: The median of the values in the population.
      p5: P5 quantile of the values in the population.
      p95: P95 quantile of the values in the population.
    """

    sum: int | None = None
    max: float | None = None
    min: float | None = None
    mean: float | None = None
    median: float | None = None
    p5: float | None = None
    p95: float | None = None


@dataclasses.dataclass
class DatasetStatistics:
    """Dataset statistics used for dataset profiling.

    Attributes:
      total_number_of_dataset_examples: Number of tuning examples in the dataset.
      total_number_of_records_for_training: Number of tuning records after
        formatting. Each model turn in the chat message will be considered as a record for tuning.
      total_number_of_billable_tokens: Number of total billable tokens in the
        dataset.
      user_input_token_length_stats: Stats for input token length.
      user_output_token_length_stats: Stats for output token length.
    """

    total_number_of_dataset_examples: int | None = None
    total_number_of_records_for_training: int | None = None
    total_number_of_billable_tokens: int | None = None
    user_input_token_length_stats: DatasetDistribution | None = None
    user_output_token_length_stats: DatasetDistribution | None = None


MAX_TOKENS_PER_EXAMPLE = 32 * 1024
ESTIMATE_PADDING_TOKEN_PER_EXAMPLE = 8

In [141]:
def calculate_distribution_for_population(population) -> DatasetDistribution:
    """Calculates the distribution from the population of values.

    Args:
      population: The population of values to calculate distribution for.

    Returns:
      DatasetDistribution of the given population of values.
    """
    if not population:
        raise ValueError("population is empty")

    return DatasetDistribution(
        sum=np.sum(population),
        max=np.max(population),
        min=np.min(population),
        mean=np.mean(population),
        median=np.median(population),
        p5=np.percentile(population, 5, method="nearest"),
        p95=np.percentile(population, 95, method="nearest"),
    )


def get_token_distribution_for_one_tuning_dataset_example(example):
    """Counts input and output tokens for each model turn in one example.

    Returns a tuple of (input token counts, model turn token counts, number of
    training records, billable tokens, number of over-length records).
    """
    model_turn_token_list = []
    input_token_list = []
    input_contents = []  # running conversation context; avoids shadowing built-in `input`
    n_too_long = 0
    number_of_records_for_training = 0  # each model turn in the chat message will be considered as a record for tuning

    # Handle optional systemInstruction
    system_instruction = example.get("systemInstruction")
    if system_instruction:
        text = system_instruction.get("parts")[0].get(
            "text"
        )  # Assuming single part in system instruction
        input_contents.append(Content(role="system", parts=[Part.from_text(text)]))

    for content_item in example["contents"]:
        role = content_item.get("role").lower()
        text = content_item.get("parts")[0].get(
            "text"
        )  # Assuming single part in content item

        if role == "model":
            # Everything before this model turn is the input for this record
            input_tokens = tokenizer.count_tokens(input_contents).total_tokens
            output_tokens = tokenizer.count_tokens(text).total_tokens
            input_token_list.append(input_tokens)
            model_turn_token_list.append(output_tokens)
            number_of_records_for_training += 1
            if input_tokens + output_tokens > MAX_TOKENS_PER_EXAMPLE:
                n_too_long += 1
                break

        input_contents.append(Content(role=role, parts=[Part.from_text(text)]))

    return (
        input_token_list,
        model_turn_token_list,
        number_of_records_for_training,
        np.sum(model_turn_token_list) + np.sum(input_token_list),
        n_too_long,
    )


def get_dataset_stats_for_dataset(dataset):
    results = map(get_token_distribution_for_one_tuning_dataset_example, dataset)
    user_input_token_list = []
    model_turn_token_list = []
    number_of_records_for_training = 0
    total_number_of_billable_tokens = 0
    n_too_long_for_dataset = 0
    for (
        input_token_list_per_example,
        model_turn_token_list_per_example,
        number_of_records_for_training_per_example,
        number_of_billable_token_per_example,
        n_too_long,
    ) in results:
        user_input_token_list.extend(input_token_list_per_example)
        model_turn_token_list.extend(model_turn_token_list_per_example)
        number_of_records_for_training += number_of_records_for_training_per_example
        total_number_of_billable_tokens += number_of_billable_token_per_example
        n_too_long_for_dataset += n_too_long

    print(
        f"\n{n_too_long_for_dataset} examples may be over the {MAX_TOKENS_PER_EXAMPLE} token limit, they will be truncated during tuning."
    )

    return DatasetStatistics(
        total_number_of_dataset_examples=len(dataset),
        total_number_of_records_for_training=number_of_records_for_training,
        total_number_of_billable_tokens=total_number_of_billable_tokens
        + number_of_records_for_training * ESTIMATE_PADDING_TOKEN_PER_EXAMPLE,
        user_input_token_length_stats=calculate_distribution_for_population(
            user_input_token_list
        ),
        user_output_token_length_stats=calculate_distribution_for_population(
            model_turn_token_list
        ),
    )


def print_dataset_stats(dataset):
    dataset_stats = get_dataset_stats_for_dataset(dataset)
    print("Below you can find the dataset statistics:")
    print(
        f"Total number of examples in the dataset: {dataset_stats.total_number_of_dataset_examples}"
    )
    print(
        f"Total number of records for training: {dataset_stats.total_number_of_records_for_training}"
    )
    print(
        f"Total number of billable tokens in the dataset: {dataset_stats.total_number_of_billable_tokens}"
    )
    print(
        f"User input token length distribution: {dataset_stats.user_input_token_length_stats}"
    )
    print(
        f"User output token length distribution: {dataset_stats.user_output_token_length_stats}"
    )
    return dataset_stats

Next you can analyze the structure and token counts of your datasets.

In [142]:
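training_dataset_stats = print_dataset_stats(example_training_dataset)

# Default to None so the cost estimation cell also works without a validation dataset
validation_dataset_stats = None
if example_validation_dataset:
    validation_dataset_stats = print_dataset_stats(example_validation_dataset)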


0 examples may be over the 32768 token limit, they will be truncated during tuning.
Below you can find the dataset statistics:
Total number of examples in the dataset: 500
Total number of records for training: 500
Total number of billable tokens in the dataset: 259243
User input token length distribution: DatasetDistribution(sum=233592, max=2932, min=25, mean=467.184, median=414.5, p5=101, p95=1002)
User output token length distribution: DatasetDistribution(sum=21651, max=237, min=3, mean=43.302, median=37.0, p5=15, p95=89)

0 examples may be over the 32768 token limit, they will be truncated during tuning.
Below you can find the dataset statistics:
Total number of examples in the dataset: 100
Total number of records for training: 100
Total number of billable tokens in the dataset: 50154
User input token length distribution: DatasetDistribution(sum=45535, max=1418, min=29, mean=455.35, median=413.5, p5=145, p95=846)
User output token length distribution: DatasetDistribution(sum=3819, 

### Cost Estimation for Supervised Fine-tuning
In this section, you will estimate the total cost of supervised fine-tuning based on the number of tokens processed. You are charged for the number of tokens used during tuning. Refer to the [pricing page for the rate](https://cloud.google.com/vertex-ai/generative-ai/pricing#gemini-models).

**Important Note:** The final cost may vary slightly from this estimate due to dataset formatting and truncation logic during training.

The code calculates the total number of billable tokens by summing the tokens from the training dataset and (if provided) the validation dataset. It then estimates the total number of tokens you will be charged for by multiplying the billable tokens by the number of training epochs (default is 4).
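
As a worked example of this arithmetic, the sketch below uses the billable token totals reported for the sample datasets in this notebook (259,243 for training and 50,154 for validation). The price per million tokens is a placeholder and not a real rate; take the actual value from the pricing page.

```python
training_billable_tokens = 259_243  # from the sample training dataset statistics
validation_billable_tokens = 50_154  # from the sample validation dataset statistics
epochs = 4

dataset_tokens = training_billable_tokens + validation_billable_tokens
charged_tokens = dataset_tokens * epochs

# PLACEHOLDER rate for illustration only; look up the real rate on the pricing page.
hypothetical_usd_per_million_tokens = 1.00
estimated_cost_usd = charged_tokens / 1_000_000 * hypothetical_usd_per_million_tokens

print(dataset_tokens, charged_tokens)
print(f"${estimated_cost_usd:.2f} at the placeholder rate")
```

Note that each dataset total already includes the padding estimate of 8 tokens per training record: 233,592 input tokens plus 21,651 output tokens plus 500 × 8 padding tokens gives the 259,243 reported for the training set.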

In [143]:
epoch_count = 4  # @param {type:"integer"}
if epoch_count is None:
    epoch_count = 4


total_number_of_billable_tokens = training_dataset_stats.total_number_of_billable_tokens


if validation_dataset_stats:
    total_number_of_billable_tokens += (
        validation_dataset_stats.total_number_of_billable_tokens
    )

print(f"Dataset has ~{total_number_of_billable_tokens} tokens that will be charged")
print(f"By default, you'll train for {epoch_count} epochs on this dataset.")
print(
    f"By default, you'll be charged for ~{epoch_count * total_number_of_billable_tokens} tokens."
)

Dataset has ~309397 tokens that will be charged
By default, you'll train for 4 epochs on this dataset.
By default, you'll be charged for ~1237588 tokens.


## Convert `Gemini 1.0 Pro` fine-tuning dataset to `Gemini 1.5 Pro` dataset.
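
Before running the conversion against Cloud Storage, it can help to see the mapping on a single in-memory record. The sketch below mirrors the per-record logic of `convert_json_object` defined later in this section: each legacy `messages` entry with `role` and `content` becomes a `contents` entry with `role` and `parts`. The helper name `to_new_format` and the sample record are invented for illustration.

```python
def to_new_format(old_record, system_instruction=None):
    """Maps a legacy {"messages": [...]} record to the {"contents": [...]} layout."""
    new_record = {}
    if system_instruction:
        new_record["systemInstruction"] = {
            "role": "system",
            "parts": [{"text": system_instruction}],
        }
    new_record["contents"] = [
        {"role": message["role"], "parts": [{"text": message["content"]}]}
        for message in old_record.get("messages", [])
    ]
    return new_record


# Hypothetical legacy record.
old = {
    "messages": [
        {"role": "user", "content": "Translate 'hello' to French."},
        {"role": "model", "content": "bonjour"},
    ]
}
new = to_new_format(old, system_instruction="You are a translator.")
print(new["contents"][1])
```

Roles are copied unchanged in this sketch, so a legacy record whose first message has the `system` role would need that message moved into `systemInstruction` separately.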

In [144]:
source_uri = (
    "gs://next-23-tuning-demo/example-fine-tuning.json"  # @param {type:"string"}
)
destination_uri = (
    "gs://next-23-tuning-demo/new-data-format.jsonl"  # @param {type:"string"}
)
system_instruction = "You are a helpful and friendly AI assistant"  # Optional

In [117]:
def convert_jsonl_format(
    source_uri: str,
    destination_uri: str,
    system_instruction: str = None,
):
    """Converts a JSONL file from the old format to the new format.

    Args:
        source_uri: Google Cloud Storage URI of the source JSONL file.
        destination_uri: Google Cloud Storage URI for the new JSONL file.
        system_instruction: Optional system instruction text.
                            If provided, it will be added as "systemInstruction" in the new format.
    """
    storage_client = storage.Client()

    # Extract bucket and file name from source URI
    source_bucket_name, source_blob_name = extract_bucket_and_blob_name(source_uri)
    source_bucket = storage_client.bucket(source_bucket_name)
    source_blob = source_bucket.blob(source_blob_name)

    # Extract bucket and file name from destination URI
    dest_bucket_name, dest_blob_name = extract_bucket_and_blob_name(destination_uri)
    dest_bucket = storage_client.bucket(dest_bucket_name)
    dest_blob = dest_bucket.blob(dest_blob_name)

    # Download the source JSONL file
    # download_as_string is a deprecated alias of download_as_bytes
    source_data = source_blob.download_as_bytes().decode("utf-8")

    new_data = []
    for line in source_data.splitlines():
        try:
            json_data = json.loads(line)
            new_json_data = convert_json_object(json_data, system_instruction)
            new_data.append(new_json_data)
        except json.JSONDecodeError as e:
            print(f"Skipping invalid JSON line: {line} - Error: {e}")

    # Upload the new JSONL file
    new_data_str = "\n".join([json.dumps(data) for data in new_data])
    dest_blob.upload_from_string(new_data_str)

    print(f"Successfully converted and uploaded to {destination_uri}")


def convert_json_object(json_data: dict, system_instruction: str = None) -> dict:
    """Converts a single JSON object from the old format to the new format.

    Args:
        json_data: The JSON object to convert.
        system_instruction: Optional system instruction text.

    Returns:
        The converted JSON object.
    """
    new_json_data = {}  # Create an empty dict instead of initializing with "contents"

    if system_instruction:
        new_json_data["systemInstruction"] = {
            "role": "system",
            "parts": [{"text": system_instruction}],
        }

    new_json_data["contents"] = []  # Initialize "contents" after "systemInstruction"

    for message in json_data.get("messages", []):
        new_message = {"role": message["role"], "parts": [{"text": message["content"]}]}
        new_json_data["contents"].append(new_message)

    return new_json_data


def extract_bucket_and_blob_name(gcs_uri: str) -> tuple:
    """Extracts the bucket name and blob name from a Google Cloud Storage URI.

    Args:
        gcs_uri: The Google Cloud Storage URI (e.g., "gs://my-bucket/my-file.jsonl")

    Returns:
        A tuple containing the bucket name and blob name.
    """
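    if not gcs_uri.startswith("gs://"):
        raise ValueError("Invalid Google Cloud Storage URI")
    parts = gcs_uri[5:].split("/", 1)
    # Guard against URIs that name only a bucket, such as "gs://my-bucket"
    if len(parts) != 2 or not parts[0] or not parts[1]:
        raise ValueError("Google Cloud Storage URI must include a bucket and a blob name")
    return parts[0], parts[1]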

In [118]:
convert_jsonl_format(source_uri, destination_uri, system_instruction)

Successfully converted and uploaded to gs://next-23-tuning-demo/new-data-format.jsonl


## Tuning token count and cost estimation for `Gemini 1.0 Pro` legacy users.

Only use this part if you still use `Gemini 1.0 Pro`. It's best to upgrade to [`gemini-1.5-pro-002`](https://cloud.google.com/vertex-ai/generative-ai/docs/models/gemini-supervised-tuning).

### Load the dataset

Define the Google Cloud Storage URIs pointing to your training and validation datasets or continue using the URIs provided.

In [None]:
BASE_MODEL = "gemini-1.0-pro-002"  # @param ['gemini-1.0-pro-002']{type:"string"}
training_dataset_uri = "gs://cloud-samples-data/ai-platform/generative_ai/sft_train_data.jsonl"  # @param {type:"string"}
validation_dataset_uri = "gs://cloud-samples-data/ai-platform/generative_ai/sft_validation_data.jsonl"  # @param {type:"string"}

tokenizer = get_tokenizer_for_model(BASE_MODEL)

We'll now load the dataset and conduct some basic statistical analysis to understand its structure and content.


In [None]:
with tf.io.gfile.GFile(training_dataset_uri) as dataset_jsonl_file:
    example_training_dataset = [
        json.loads(dataset_line) for dataset_line in dataset_jsonl_file
    ]

if validation_dataset_uri:
    with tf.io.gfile.GFile(validation_dataset_uri) as dataset_jsonl_file:
        example_validation_dataset = [
            json.loads(dataset_line) for dataset_line in dataset_jsonl_file
        ]

# Initial dataset stats
print("Num training examples:", len(example_training_dataset))
print("First example:")
for message in example_training_dataset[0]["messages"]:
    print(message)
    print(tokenizer.count_tokens(message.get("content")))

if example_validation_dataset:
    print("Num validation examples:", len(example_validation_dataset))

Num training examples: 500
First example:
{'role': 'user', 'content': "#Person1#: Hi, Mr. Smith. I'm Doctor Hawkins. Why are you here today?\n#Person2#: I found it would be a good idea to get a check-up.\n#Person1#: Yes, well, you haven't had one for 5 years. You should have one every year.\n#Person2#: I know. I figure as long as there is nothing wrong, why go see the doctor?\n#Person1#: Well, the best way to avoid serious illnesses is to find out about them early. So try to come at least once a year for your own good.\n#Person2#: Ok.\n#Person1#: Let me see here. Your eyes and ears look fine. Take a deep breath, please. Do you smoke, Mr. Smith?\n#Person2#: Yes.\n#Person1#: Smoking is the leading cause of lung cancer and heart disease, you know. You really should quit.\n#Person2#: I've tried hundreds of times, but I just can't seem to kick the habit.\n#Person1#: Well, we have classes and some medications that might help. I'll give you more information before you leave.\n#Person2#: Ok, t

### Validate the format of the data

You can perform various error checks to validate that each tuning example in the dataset adheres to the format expected by the tuning API. Errors are categorized based on their nature for easier debugging.  
  
To learn how to prepare a dataset for tuning Gemini, refer to this [tutorial](https://cloud.google.com/vertex-ai/generative-ai/docs/models/gemini-supervised-tuning-about).

1. **Presence of Message List**: Checks if a `messages` list is present in each entry. Error type: `missing_messages_list`:
2. **Message Keys Check**: Validates that each message in the messages list contains the keys `role` and `content`. Error type: `message_missing_key`.
3. **Role Validation**: Ensures the role is one of `system`, `user`, or `model`. Error type: `unrecognized_role`. Note: only the first message can have `system` as role.
5. **Content Validation**: Verifies that content has textual data and is a string. Error type: `missing_content`.
6. **Consecutive Turns**. For the chat history, it is enforced that the message must can repeat in an alternating manner. Error type: `consecutive_turns`.
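
As an illustration of the checks above, the following sketch shows one well-formed example and a small standalone helper that reports the first rule an example breaks. The helper name `first_format_error` and the sample records are hypothetical and only mirror the rules listed here; they are not part of the tuning API.

```python
def first_format_error(example):
    """Returns the first error type found in `example`, or None if none is found.

    Hypothetical helper for illustration; it mirrors the checks listed above.
    """
    messages = example.get("messages") if isinstance(example, dict) else None
    if not isinstance(messages, list) or not messages:
        return "missing_messages_list"

    prev_role = None
    for idx, message in enumerate(messages):
        if (
            not isinstance(message, dict)
            or "role" not in message
            or "content" not in message
        ):
            return "message_missing_key"
        role = str(message["role"]).lower()
        # `system` is only accepted as the very first message.
        allowed = ("system", "user", "model") if idx == 0 else ("user", "model")
        if role not in allowed:
            return "unrecognized_role"
        content = message["content"]
        if not isinstance(content, str) or not content.strip():
            return "missing_content"
        if role == prev_role:
            return "consecutive_turns"
        prev_role = role
    return None


# Illustrative records (not taken from the notebook's dataset).
valid_example = {
    "messages": [
        {"role": "system", "content": "Summarize the dialogue."},
        {"role": "user", "content": "#Person1#: Hi.\n#Person2#: Hello."},
        {"role": "model", "content": "Two people greet each other."},
    ]
}
repeated_turns = {
    "messages": [
        {"role": "user", "content": "Hi"},
        {"role": "user", "content": "Anyone there?"},
    ]
}
unknown_role = {"messages": [{"role": "assistant", "content": "Hi"}]}

for name, example in [
    ("valid_example", valid_example),
    ("repeated_turns", repeated_turns),
    ("unknown_role", unknown_role),
]:
    print(name, "->", first_format_error(example))
```

Returning only the first error keeps the sketch short; the notebook's own validator instead collects row indices per error type across the whole dataset.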


In [None]:
def validate_dataset_format(dataset):
    """Validates the format of every example in the dataset.

    Args:
      dataset: The list of tuning examples (dicts) to validate.
    """
    format_errors = defaultdict(list)
    if not dataset:
        print("Input dataset file is empty or inaccessible.")
        return

    for row_idx, example in enumerate(dataset):
        # Verify presence of messages list
        if not isinstance(example, dict):
            format_errors["missing_messages_list"].append(row_idx)
            continue
        messages = example.get("messages", None)
        try:
            validate_messages(messages, format_errors, row_idx)
        except (TypeError, AttributeError, KeyError) as e:
            print(f"Invalid input during validation: {e}")
            format_errors["invalid_input"].append(row_idx)

    if format_errors:
        print("Found errors for this dataset:")
        for k, v in format_errors.items():
            print(f"{k}: {v}")
    else:
        print("No errors found for this dataset.")


def validate_messages(messages, format_errors, row_index):
    """Validates messages list format."""
    if not messages:
        format_errors["missing_messages_list"].append(row_index)
        return

    # Check if the first role is for system instruction
    if messages[0].get("role", "").lower() == "system":
        messages = messages[1:]
    else:
        messages = messages[:]

    prev_role = None

    for message in messages:
        if "role" not in message or "content" not in message:
            format_errors["message_missing_key"].append(row_index)
            return

        if message.get("role", "").lower() not in ("user", "model"):
            format_errors["unrecognized_role"].append(row_index)
            return

        content = message.get("content", None)
        # Content must be a non-empty string.
        if not content or not isinstance(content, str):
            format_errors["missing_content"].append(row_index)
            return

        role = message.get("role", "").lower()
        # Messages must alternate between roles.
        if role == prev_role:
            format_errors["consecutive_turns"].append(row_index)
            return

        prev_role = role

Now you can check the data for any issues.

In [None]:
validate_dataset_format(example_training_dataset)
if example_validation_dataset:
    validate_dataset_format(example_validation_dataset)

No errors found for this dataset.
No errors found for this dataset.


### Utils for dataset analysis and token counting

This section focuses on analyzing the structure and token counts of your datasets. You will also define some utility functions to streamline subsequent steps in the notebook.

* Load and inspect sample data from the training and validation datasets.
* Calculate token counts for messages to understand the dataset's characteristics.
* Define utility functions for calculating token distributions and dataset statistics. These will help assess the suitability of your data for supervised tuning and estimate potential costs.
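
To make the counting scheme concrete, the sketch below shows how a multi-turn chat expands into training records: each `model` turn becomes one record, its input is every message that came before it, and its output is the model message itself. The function names and the whitespace-based counter are assumptions for illustration; the notebook uses the Vertex AI tokenizer, which produces different counts.

```python
def split_into_training_records(messages, count_tokens):
    """Returns one (input_tokens, output_tokens) pair per model turn."""
    history = []  # every message seen so far, including earlier model turns
    records = []
    for message in messages:
        if message["role"].lower() == "model":
            input_tokens = sum(count_tokens(m["content"]) for m in history)
            output_tokens = count_tokens(message["content"])
            records.append((input_tokens, output_tokens))
        history.append(message)
    return records


def whitespace_token_count(text):
    # Stand-in for a real tokenizer: counts whitespace-separated words.
    return len(text.split())


chat = [
    {"role": "user", "content": "What is supervised tuning?"},
    {"role": "model", "content": "Training on labeled examples."},
    {"role": "user", "content": "Is it billed per token?"},
    {"role": "model", "content": "Yes."},
]

records = split_into_training_records(chat, whitespace_token_count)
print(records)  # [(4, 4), (13, 1)]
```

The second record's input (13) includes the first user message, the first model reply, and the second user message, which is why longer conversations contribute more billable tokens than the sum of their individual messages.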


In [None]:
@dataclasses.dataclass
class DatasetDistribution:
    """Dataset distribution for a given population of values.

    Holds summary statistics for the population: the sum, min, max, mean,
    median, p5, and p95.

    Attributes:
      sum: Sum of the values in the population.
      max: Max of the values in the population.
      min: Min of the values in the population.
      mean: The arithmetic mean of the values in the population.
      median: The median of the values in the population.
      p5: P5 quantile of the values in the population.
      p95: P95 quantile of the values in the population.
    """

    sum: int | None = None
    max: float | None = None
    min: float | None = None
    mean: float | None = None
    median: float | None = None
    p5: float | None = None
    p95: float | None = None


@dataclasses.dataclass
class DatasetStatistics:
    """Dataset statistics used for dataset profiling.

    Attributes:
      total_number_of_dataset_examples: Number of tuning examples in the dataset.
      total_number_of_records_for_training: Number of tuning records after
        formatting. Each model turn in the chat message will be considered as a record for tuning.
      total_number_of_billable_tokens: Number of total billable tokens in the
        dataset.
      user_input_token_length_stats: Stats for input token length.
      user_output_token_length_stats: Stats for output token length.
    """

    total_number_of_dataset_examples: int | None = None
    total_number_of_records_for_training: int | None = None
    total_number_of_billable_tokens: int | None = None
    user_input_token_length_stats: DatasetDistribution | None = None
    user_output_token_length_stats: DatasetDistribution | None = None


MAX_TOKENS_PER_EXAMPLE = 32 * 1024
ESTIMATE_PADDING_TOKEN_PER_EXAMPLE = 8

In [None]:
def calculate_distribution_for_population(population) -> DatasetDistribution:
    """Calculates the distribution from the population of values.

    Args:
      population: The population of values to calculate distribution for.

    Returns:
      DatasetDistribution of the given population of values.
    """
    if not population:
        raise ValueError("population is empty")

    return DatasetDistribution(
        sum=np.sum(population),
        max=np.max(population),
        min=np.min(population),
        mean=np.mean(population),
        median=np.median(population),
        p5=np.percentile(population, 5, method="nearest"),
        p95=np.percentile(population, 95, method="nearest"),
    )


def get_token_distribution_for_one_tuning_dataset_example(example):
    """Counts input and output tokens for every model turn in one example."""
    model_turn_token_list = []
    input_token_list = []
    history = []  # all messages seen so far; forms the input for the next model turn
    n_too_long = 0
    number_of_records_for_training = 0  # each model turn in the chat message will be considered as a record for tuning
    for message in example["messages"]:
        role = message.get("role").lower()
        text = message.get("content")

        if role == "model":
            input_tokens = tokenizer.count_tokens(history).total_tokens
            output_tokens = tokenizer.count_tokens(text).total_tokens
            input_token_list.append(input_tokens)
            model_turn_token_list.append(output_tokens)
            number_of_records_for_training += 1
            if input_tokens + output_tokens > MAX_TOKENS_PER_EXAMPLE:
                n_too_long += 1
                break

        history.append(Content(role=role, parts=[Part.from_text(text)]))

    return (
        input_token_list,
        model_turn_token_list,
        number_of_records_for_training,
        np.sum(model_turn_token_list) + np.sum(input_token_list),
        n_too_long,
    )


def get_dataset_stats_for_dataset(dataset):
    results = map(get_token_distribution_for_one_tuning_dataset_example, dataset)
    user_input_token_list = []
    model_turn_token_list = []
    number_of_records_for_training = 0
    total_number_of_billable_tokens = 0
    n_too_long_for_dataset = 0
    for (
        input_token_list_per_example,
        model_turn_token_list_per_example,
        number_of_records_for_training_per_example,
        number_of_billable_token_per_example,
        n_too_long,
    ) in results:
        user_input_token_list.extend(input_token_list_per_example)
        model_turn_token_list.extend(model_turn_token_list_per_example)
        number_of_records_for_training += number_of_records_for_training_per_example
        total_number_of_billable_tokens += number_of_billable_token_per_example
        n_too_long_for_dataset += n_too_long

    print(
        f"\n{n_too_long_for_dataset} examples may be over the {MAX_TOKENS_PER_EXAMPLE} token limit, they will be truncated during tuning."
    )

    return DatasetStatistics(
        total_number_of_dataset_examples=len(dataset),
        total_number_of_records_for_training=number_of_records_for_training,
        total_number_of_billable_tokens=total_number_of_billable_tokens
        + number_of_records_for_training * ESTIMATE_PADDING_TOKEN_PER_EXAMPLE,
        user_input_token_length_stats=calculate_distribution_for_population(
            user_input_token_list
        ),
        user_output_token_length_stats=calculate_distribution_for_population(
            model_turn_token_list
        ),
    )


def print_dataset_stats(dataset):
    dataset_stats = get_dataset_stats_for_dataset(dataset)
    print("Below you can find the dataset statistics:")
    print(
        f"Total number of examples in the dataset: {dataset_stats.total_number_of_dataset_examples}"
    )
    print(
        f"Total number of records for training: {dataset_stats.total_number_of_records_for_training}"
    )
    print(
        f"Total number of billable tokens in the dataset: {dataset_stats.total_number_of_billable_tokens}"
    )
    print(
        f"User input token length distribution: {dataset_stats.user_input_token_length_stats}"
    )
    print(
        f"User output token length distribution: {dataset_stats.user_output_token_length_stats}"
    )
    return dataset_stats

Next you can analyze the structure and token counts of your datasets.

In [None]:
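training_dataset_stats = print_dataset_stats(example_training_dataset)

# Default to None so later cells can check it even without a validation dataset.
validation_dataset_stats = None
if example_validation_dataset:
    validation_dataset_stats = print_dataset_stats(example_validation_dataset)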


0 examples may be over the 32768 token limit, they will be truncated during tuning.
Below you can find the dataset statistics:
Total number of examples in the dataset: 500
Total number of records for training: 500
Total number of billable tokens in the dataset: 130300
User input token length distribution: DatasetDistribution(sum=109172, max=712, min=70, mean=218.344, median=198.5, p5=89, p95=403)
User output token length distribution: DatasetDistribution(sum=17128, max=124, min=12, mean=34.256, median=31.0, p5=17, p95=63)

0 examples may be over the 32768 token limit, they will be truncated during tuning.
Below you can find the dataset statistics:
Total number of examples in the dataset: 100
Total number of records for training: 100
Total number of billable tokens in the dataset: 28414
User input token length distribution: DatasetDistribution(sum=23922, max=829, min=70, mean=239.22, median=225.5, p5=92, p95=430)
User output token length distribution: DatasetDistribution(sum=3692, max=
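
The billable token totals printed above can be reproduced from the other reported figures: the sum of input tokens, plus the sum of output tokens, plus a padding estimate of 8 tokens per training record. The sketch below redoes that arithmetic with the numbers from the output; the function name `estimate_billable_tokens` is hypothetical.

```python
PADDING_TOKENS_PER_RECORD = 8  # mirrors ESTIMATE_PADDING_TOKEN_PER_EXAMPLE


def estimate_billable_tokens(
    input_token_sum,
    output_token_sum,
    num_records,
    padding_per_record=PADDING_TOKENS_PER_RECORD,
):
    """Input tokens + output tokens + estimated padding per training record."""
    return input_token_sum + output_token_sum + num_records * padding_per_record


# Figures copied from the printed statistics above.
training_billable = estimate_billable_tokens(109172, 17128, 500)
validation_billable = estimate_billable_tokens(23922, 3692, 100)

print(training_billable)  # 130300
print(validation_billable)  # 28414
```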

### Cost Estimation for Supervised Fine-tuning

In this final section, you will estimate the total number of tokens billed for supervised fine-tuning. You are charged based on the number of tokens processed during tuning; refer to the [pricing page](https://cloud.google.com/vertex-ai/generative-ai/pricing#gemini-models) for the rate.

**Important Note:** The final cost may differ slightly from this estimate due to dataset formatting and truncation logic during training.

The code sums the billable tokens from the training dataset and, if provided, the validation dataset. It then multiplies that total by the number of training epochs (default is 4) to estimate the number of tokens you will be charged for.

In [None]:
epoch_count = 4  # @param {type:"integer"}
if epoch_count is None:
    epoch_count = 4


total_number_of_billable_tokens = training_dataset_stats.total_number_of_billable_tokens


if validation_dataset_stats:
    total_number_of_billable_tokens += (
        validation_dataset_stats.total_number_of_billable_tokens
    )

print(f"Dataset has ~{total_number_of_billable_tokens} tokens that will be charged")
print(f"By default, you'll train for {epoch_count} epochs on this dataset.")
print(
    f"By default, you'll be charged for ~{epoch_count * total_number_of_billable_tokens} tokens."
)

Dataset has ~158714 tokens that will be charged
By default, you'll train for 4 epochs on this dataset.
By default, you'll be charged for ~634856 tokens.
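
To turn the token estimate into a currency figure, multiply it by the per-token rate from the pricing page. The sketch below shows the arithmetic only: `PLACEHOLDER_PRICE_PER_MILLION_TOKENS` is a made-up value, not an actual Vertex AI price, and the function name is hypothetical. Replace the placeholder with the current rate for your model.

```python
def estimate_tuning_cost(billable_tokens_per_epoch, epochs, price_per_million_tokens):
    """Returns (total billed tokens, estimated cost) for a tuning job."""
    total_tokens = billable_tokens_per_epoch * epochs
    cost = total_tokens / 1_000_000 * price_per_million_tokens
    return total_tokens, cost


# Placeholder only: NOT a real price. Look up the actual rate on the pricing page.
PLACEHOLDER_PRICE_PER_MILLION_TOKENS = 1.00

# 158714 billable tokens and 4 epochs are the values printed above.
total_tokens, cost = estimate_tuning_cost(
    158714, 4, PLACEHOLDER_PRICE_PER_MILLION_TOKENS
)
print(f"~{total_tokens} tokens, ~{cost:.2f} at the placeholder rate")
```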
