diff --git a/aiplatform/src/main/java/aiplatform/EmbeddingModelTuningSample.java b/aiplatform/src/main/java/aiplatform/EmbeddingModelTuningSample.java index 86805bbdeda..2e859958e41 100644 --- a/aiplatform/src/main/java/aiplatform/EmbeddingModelTuningSample.java +++ b/aiplatform/src/main/java/aiplatform/EmbeddingModelTuningSample.java @@ -29,11 +29,8 @@ import java.util.regex.Matcher; import java.util.regex.Pattern; -// [END aiplatform_sdk_embedding_model_tuning] - public class EmbeddingModelTuningSample { public static void main(String[] args) throws IOException { - // [START aiplatform_sdk_embedding_model_tuning] // TODO(developer): Replace these variables before running this sample. String apiEndpoint = "us-central1-aiplatform.googleapis.com:443"; String project = "PROJECT"; @@ -41,12 +38,14 @@ public static void main(String[] args) throws IOException { String taskType = "DEFAULT"; String pipelineJobDisplayName = "PIPELINE_JOB_DISPLAY_NAME"; String outputDir = "OUTPUT_DIR"; - String queriesPath = "QUERIES"; - String corpusPath = "CORPUS"; - String trainLabelPath = "TRAIN_LABEL"; - String testLabelPath = "TEST_LABEL"; + String queriesPath = "QUERIES_PATH"; + String corpusPath = "CORPUS_PATH"; + String trainLabelPath = "TRAIN_LABEL_PATH"; + String testLabelPath = "TEST_LABEL_PATH"; + double learningRateMultiplier = 1.0; + int outputDimensionality = 768; int batchSize = 128; - int iterations = 1000; + int trainSteps = 1000; createEmbeddingModelTuningPipelineJob( apiEndpoint, @@ -59,12 +58,12 @@ public static void main(String[] args) throws IOException { corpusPath, trainLabelPath, testLabelPath, + learningRateMultiplier, + outputDimensionality, batchSize, - iterations); - // [END aiplatform_sdk_embedding_model_tuning] + trainSteps); } - // [START aiplatform_sdk_embedding_model_tuning] public static PipelineJob createEmbeddingModelTuningPipelineJob( String apiEndpoint, String project, @@ -76,28 +75,30 @@ public static PipelineJob createEmbeddingModelTuningPipelineJob( String corpusPath, String trainLabelPath, String testLabelPath, + double learningRateMultiplier, + int outputDimensionality, int batchSize, - int iterations) + int trainSteps) throws IOException { Matcher matcher = Pattern.compile("^(?\\w+-\\w+)").matcher(apiEndpoint); String location = matcher.matches() ? matcher.group("Location") : "us-central1"; String templateUri = - "https://us-kfp.pkg.dev/ml-pipeline/llm-text-embedding/tune-text-embedding-model/v1.1.2"; + "https://us-kfp.pkg.dev/ml-pipeline/llm-text-embedding/tune-text-embedding-model/v1.1.3"; PipelineServiceSettings settings = PipelineServiceSettings.newBuilder().setEndpoint(apiEndpoint).build(); try (PipelineServiceClient client = PipelineServiceClient.create(settings)) { Map parameterValues = Map.of( - "project", valueOf(project), "base_model_version_id", valueOf(baseModelVersionId), "task_type", valueOf(taskType), - "location", valueOf(location), "queries_path", valueOf(queriesPath), "corpus_path", valueOf(corpusPath), "train_label_path", valueOf(trainLabelPath), "test_label_path", valueOf(testLabelPath), + "learning_rate_multiplier", valueOf(learningRateMultiplier), + "output_dimensionality", valueOf(outputDimensionality), "batch_size", valueOf(batchSize), - "iterations", valueOf(iterations)); + "train_steps", valueOf(trainSteps)); PipelineJob pipelineJob = PipelineJob.newBuilder() .setTemplateUri(templateUri) @@ -124,5 +125,9 @@ private static Value valueOf(String s) { private static Value valueOf(int n) { return Value.newBuilder().setNumberValue(n).build(); } - // [END aiplatform_sdk_embedding_model_tuning] + + private static Value valueOf(double n) { + return Value.newBuilder().setNumberValue(n).build(); + } } +// [END aiplatform_sdk_embedding_model_tuning] diff --git a/aiplatform/src/test/java/aiplatform/EmbeddingModelTuningSampleTest.java b/aiplatform/src/test/java/aiplatform/EmbeddingModelTuningSampleTest.java index de07adffed3..61e8b77926b 100644 --- a/aiplatform/src/test/java/aiplatform/EmbeddingModelTuningSampleTest.java +++ b/aiplatform/src/test/java/aiplatform/EmbeddingModelTuningSampleTest.java @@ -53,7 +53,7 @@ public class EmbeddingModelTuningSampleTest { private static final String API_ENDPOINT = "us-central1-aiplatform.googleapis.com:443"; private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); - private static final String BASE_MODEL_VERSION_ID = "textembedding-gecko@003"; + private static final String BASE_MODEL_VERSION_ID = "text-embedding-004"; private static final String TASK_TYPE = "DEFAULT"; private static final String JOB_DISPLAY_NAME = "embedding-customization-pipeline-sample"; private static final String QUERIES = @@ -64,6 +64,8 @@ public class EmbeddingModelTuningSampleTest { private static final String TEST_LABEL = "gs://embedding-customization-pipeline/dataset/test.tsv"; private static final String OUTPUT_DIR = "gs://ucaip-samples-us-central1/training_pipeline_output"; + private static final double LEARNING_RATE_MULTIPLIER = 0.3; + private static final int OUTPUT_DIMENSIONALITY = 512; private static final int BATCH_SIZE = 50; private static final int ITERATIONS = 300; @@ -135,6 +137,8 @@ public void createPipelineJobEmbeddingModelTuningSample() throws IOException { CORPUS, TRAIN_LABEL, TEST_LABEL, + LEARNING_RATE_MULTIPLIER, + OUTPUT_DIMENSIONALITY, BATCH_SIZE, ITERATIONS); assertThat(job.getState()).isNotEqualTo(PipelineState.PIPELINE_STATE_FAILED);