Skip to content

Commit

Permalink
Update embedding tuning sample to use the v1.1.3 pipeline. (#9333)
Browse files Browse the repository at this point in the history
  • Loading branch information
skarukas committed May 23, 2024
1 parent dade187 commit b770340
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 18 deletions.
39 changes: 22 additions & 17 deletions aiplatform/src/main/java/aiplatform/EmbeddingModelTuningSample.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,24 +29,23 @@
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// [END aiplatform_sdk_embedding_model_tuning]

public class EmbeddingModelTuningSample {
public static void main(String[] args) throws IOException {
// [START aiplatform_sdk_embedding_model_tuning]
// TODO(developer): Replace these variables before running this sample.
String apiEndpoint = "us-central1-aiplatform.googleapis.com:443";
String project = "PROJECT";
String baseModelVersionId = "BASE_MODEL_VERSION_ID";
String taskType = "DEFAULT";
String pipelineJobDisplayName = "PIPELINE_JOB_DISPLAY_NAME";
String outputDir = "OUTPUT_DIR";
String queriesPath = "QUERIES";
String corpusPath = "CORPUS";
String trainLabelPath = "TRAIN_LABEL";
String testLabelPath = "TEST_LABEL";
String queriesPath = "QUERIES_PATH";
String corpusPath = "CORPUS_PATH";
String trainLabelPath = "TRAIN_LABEL_PATH";
String testLabelPath = "TEST_LABEL_PATH";
double learningRateMultiplier = 1.0;
int outputDimensionality = 768;
int batchSize = 128;
int iterations = 1000;
int trainSteps = 1000;

createEmbeddingModelTuningPipelineJob(
apiEndpoint,
Expand All @@ -59,12 +58,12 @@ public static void main(String[] args) throws IOException {
corpusPath,
trainLabelPath,
testLabelPath,
learningRateMultiplier,
outputDimensionality,
batchSize,
iterations);
// [END aiplatform_sdk_embedding_model_tuning]
trainSteps);
}

// [START aiplatform_sdk_embedding_model_tuning]
public static PipelineJob createEmbeddingModelTuningPipelineJob(
String apiEndpoint,
String project,
Expand All @@ -76,28 +75,30 @@ public static PipelineJob createEmbeddingModelTuningPipelineJob(
String corpusPath,
String trainLabelPath,
String testLabelPath,
double learningRateMultiplier,
int outputDimensionality,
int batchSize,
int iterations)
int trainSteps)
throws IOException {
Matcher matcher = Pattern.compile("^(?<Location>\\w+-\\w+)").matcher(apiEndpoint);
String location = matcher.matches() ? matcher.group("Location") : "us-central1";
String templateUri =
"https://us-kfp.pkg.dev/ml-pipeline/llm-text-embedding/tune-text-embedding-model/v1.1.2";
"https://us-kfp.pkg.dev/ml-pipeline/llm-text-embedding/tune-text-embedding-model/v1.1.3";
PipelineServiceSettings settings =
PipelineServiceSettings.newBuilder().setEndpoint(apiEndpoint).build();
try (PipelineServiceClient client = PipelineServiceClient.create(settings)) {
Map<String, Value> parameterValues =
Map.of(
"project", valueOf(project),
"base_model_version_id", valueOf(baseModelVersionId),
"task_type", valueOf(taskType),
"location", valueOf(location),
"queries_path", valueOf(queriesPath),
"corpus_path", valueOf(corpusPath),
"train_label_path", valueOf(trainLabelPath),
"test_label_path", valueOf(testLabelPath),
"learning_rate_multiplier", valueOf(learningRateMultiplier),
"output_dimensionality", valueOf(outputDimensionality),
"batch_size", valueOf(batchSize),
"iterations", valueOf(iterations));
"train_steps", valueOf(trainSteps));
PipelineJob pipelineJob =
PipelineJob.newBuilder()
.setTemplateUri(templateUri)
Expand All @@ -124,5 +125,9 @@ private static Value valueOf(String s) {
private static Value valueOf(int n) {
return Value.newBuilder().setNumberValue(n).build();
}
// [END aiplatform_sdk_embedding_model_tuning]

private static Value valueOf(double n) {
return Value.newBuilder().setNumberValue(n).build();
}
}
// [END aiplatform_sdk_embedding_model_tuning]
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public class EmbeddingModelTuningSampleTest {

private static final String API_ENDPOINT = "us-central1-aiplatform.googleapis.com:443";
private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
private static final String BASE_MODEL_VERSION_ID = "textembedding-gecko@003";
private static final String BASE_MODEL_VERSION_ID = "text-embedding-004";
private static final String TASK_TYPE = "DEFAULT";
private static final String JOB_DISPLAY_NAME = "embedding-customization-pipeline-sample";
private static final String QUERIES =
Expand All @@ -64,6 +64,8 @@ public class EmbeddingModelTuningSampleTest {
private static final String TEST_LABEL = "gs://embedding-customization-pipeline/dataset/test.tsv";
private static final String OUTPUT_DIR =
"gs://ucaip-samples-us-central1/training_pipeline_output";
private static final double LEARNING_RATE_MULTIPLIER = 0.3;
private static final int OUTPUT_DIMENSIONALITY = 512;
private static final int BATCH_SIZE = 50;
private static final int ITERATIONS = 300;

Expand Down Expand Up @@ -135,6 +137,8 @@ public void createPipelineJobEmbeddingModelTuningSample() throws IOException {
CORPUS,
TRAIN_LABEL,
TEST_LABEL,
LEARNING_RATE_MULTIPLIER,
OUTPUT_DIMENSIONALITY,
BATCH_SIZE,
ITERATIONS);
assertThat(job.getState()).isNotEqualTo(PipelineState.PIPELINE_STATE_FAILED);
Expand Down

0 comments on commit b770340

Please sign in to comment.