diff --git a/genai/tuning/continuous_tuning_create.py b/genai/tuning/continuous_tuning_create.py new file mode 100644 index 00000000000..3e1ad9f52e3 --- /dev/null +++ b/genai/tuning/continuous_tuning_create.py @@ -0,0 +1,78 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def create_continuous_tuning_job(tuned_model_name: str, checkpoint_id: str) -> str: + # [START googlegenaisdk_continuous_tuning_create] + import time + + from google import genai + from google.genai.types import HttpOptions, TuningDataset, CreateTuningJobConfig + + # TODO(developer): Update and un-comment below line + # tuned_model_name = "projects/123456789012/locations/us-central1/models/1234567890@1" + # checkpoint_id = "1" + + client = genai.Client(http_options=HttpOptions(api_version="v1beta1")) + + training_dataset = TuningDataset( + gcs_uri="gs://cloud-samples-data/ai-platform/generative_ai/gemini/text/sft_train_data.jsonl", + ) + validation_dataset = TuningDataset( + gcs_uri="gs://cloud-samples-data/ai-platform/generative_ai/gemini/text/sft_validation_data.jsonl", + ) + + tuning_job = client.tunings.tune( + base_model=tuned_model_name, + training_dataset=training_dataset, + config=CreateTuningJobConfig( + tuned_model_display_name="Example tuning job", + validation_dataset=validation_dataset, + pre_tuned_model_checkpoint_id=checkpoint_id, + ), + ) + + running_states = set([ + "JOB_STATE_PENDING", + "JOB_STATE_RUNNING", + ]) + + while tuning_job.state in running_states: + print(tuning_job.state) + tuning_job = client.tunings.get(name=tuning_job.name) + time.sleep(60) + + print(tuning_job.tuned_model.model) + print(tuning_job.tuned_model.endpoint) + print(tuning_job.experiment) + # Example response: + # projects/123456789012/locations/us-central1/models/1234567890@2 + # projects/123456789012/locations/us-central1/endpoints/123456789012345 + # projects/123456789012/locations/us-central1/metadataStores/default/contexts/tuning-experiment-2025010112345678 + + if tuning_job.tuned_model.checkpoints: + for i, checkpoint in enumerate(tuning_job.tuned_model.checkpoints): + print(f"Checkpoint {i + 1}: ", checkpoint) + # Example response: + # Checkpoint 1: checkpoint_id='1' epoch=1 step=10 endpoint='projects/123456789012/locations/us-central1/endpoints/123456789000000' + # Checkpoint 2: checkpoint_id='2' epoch=2 step=20 endpoint='projects/123456789012/locations/us-central1/endpoints/123456789012345' + + # [END googlegenaisdk_continuous_tuning_create] + return tuning_job.name + + +if __name__ == "__main__": + pre_tuned_model_name = input("Pre-tuned model name: ") + pre_tuned_model_checkpoint_id = input("Pre-tuned model checkpoint id: ") + create_continuous_tuning_job(pre_tuned_model_name, pre_tuned_model_checkpoint_id) diff --git a/genai/tuning/requirements.txt b/genai/tuning/requirements.txt index d64aa5a57b1..391bcb9cc1f 100644 --- a/genai/tuning/requirements.txt +++ b/genai/tuning/requirements.txt @@ -1 +1 @@ -google-genai==1.30.0 +google-genai==1.39.1 diff --git a/genai/tuning/test_tuning_examples.py b/genai/tuning/test_tuning_examples.py index 0974c769483..2f54d55a0b0 100644 --- a/genai/tuning/test_tuning_examples.py +++ b/genai/tuning/test_tuning_examples.py @@ -20,6 +20,7 @@ from google.genai import types import pytest +import continuous_tuning_create import tuning_job_create import tuning_job_get import tuning_job_list @@ -306,3 +307,23 @@ def test_tuning_with_checkpoints_textgen_with_txt(mock_genai_client: MagicMock) call(model="test-endpoint-1", contents="Why is the sky blue?"), call(model="test-endpoint-2", contents="Why is the sky blue?"), ] + + +@patch("google.genai.Client") +def test_continuous_tuning_create(mock_genai_client: MagicMock) -> None: + # Mock the API response + mock_tuning_job = types.TuningJob( + name="test-tuning-job", + experiment="test-experiment", + tuned_model=types.TunedModel( + model="test-model-2", + endpoint="test-endpoint" + ) + ) + mock_genai_client.return_value.tunings.tune.return_value = mock_tuning_job + + response = continuous_tuning_create.create_continuous_tuning_job(tuned_model_name="test-model", checkpoint_id="1") + + mock_genai_client.assert_called_once_with(http_options=types.HttpOptions(api_version="v1beta1")) + mock_genai_client.return_value.tunings.tune.assert_called_once() + assert response == "test-tuning-job"