From 872ecbeb8ad61b3fb7eb3177e6698744978c56f2 Mon Sep 17 00:00:00 2001 From: Juan Dominguez Date: Thu, 2 Oct 2025 12:56:44 -0300 Subject: [PATCH 1/5] feat(genai): add new batch prediction samples and update SDK --- genai/snippets/pom.xml | 2 +- .../BatchPredictionEmbeddingsWithGcs.java | 112 ++++++++++++++ .../BatchPredictionWithGcs.java | 114 ++++++++++++++ .../batchprediction/BatchPredictionIT.java | 140 ++++++++++++++++++ 4 files changed, 367 insertions(+), 1 deletion(-) create mode 100644 genai/snippets/src/main/java/genai/batchprediction/BatchPredictionEmbeddingsWithGcs.java create mode 100644 genai/snippets/src/main/java/genai/batchprediction/BatchPredictionWithGcs.java create mode 100644 genai/snippets/src/test/java/genai/batchprediction/BatchPredictionIT.java diff --git a/genai/snippets/pom.xml b/genai/snippets/pom.xml index 32015ec0bf4..44374a359a3 100644 --- a/genai/snippets/pom.xml +++ b/genai/snippets/pom.xml @@ -51,7 +51,7 @@ com.google.genai google-genai - 1.15.0 + 1.20.0 junit diff --git a/genai/snippets/src/main/java/genai/batchprediction/BatchPredictionEmbeddingsWithGcs.java b/genai/snippets/src/main/java/genai/batchprediction/BatchPredictionEmbeddingsWithGcs.java new file mode 100644 index 00000000000..f6544b0a1ef --- /dev/null +++ b/genai/snippets/src/main/java/genai/batchprediction/BatchPredictionEmbeddingsWithGcs.java @@ -0,0 +1,112 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package genai.batchprediction; + +// [START googlegenaisdk_batchpredict_embeddings_with_gcs] + +import static com.google.genai.types.JobState.Known.JOB_STATE_CANCELLED; +import static com.google.genai.types.JobState.Known.JOB_STATE_FAILED; +import static com.google.genai.types.JobState.Known.JOB_STATE_PAUSED; +import static com.google.genai.types.JobState.Known.JOB_STATE_SUCCEEDED; + +import com.google.genai.Client; +import com.google.genai.types.BatchJob; +import com.google.genai.types.BatchJobDestination; +import com.google.genai.types.BatchJobSource; +import com.google.genai.types.CreateBatchJobConfig; +import com.google.genai.types.HttpOptions; +import com.google.genai.types.JobState; +import java.util.HashSet; +import java.util.Set; + +public class BatchPredictionEmbeddingsWithGcs { + + public static void main(String[] args) throws InterruptedException { + // TODO(developer): Replace these variables before running the sample. + String modelId = "text-embedding-005"; + String outputGcsUri = "gs://your-bucket/your-prefix"; + createBatchJob(modelId, outputGcsUri); + } + + // Creates a batch prediction job with embedding model and Google Cloud Storage + public static JobState createBatchJob(String modelId, String outputGcsUri) + throws InterruptedException { + // Client Initialization. Once created, it can be reused for multiple requests. 
+ try (Client client = + Client.builder() + .location("us-central1") + .vertexAI(true) + .httpOptions(HttpOptions.builder().apiVersion("v1").build()) + .build()) { + + // See the documentation: + // https://googleapis.github.io/java-genai/javadoc/com/google/genai/Batches.html + BatchJobSource batchJobSource = + BatchJobSource.builder() + // Source link: + // https://storage.cloud.google.com/cloud-samples-data/generative-ai/embeddings/embeddings_input.jsonl + .gcsUri("gs://cloud-samples-data/generative-ai/embeddings/embeddings_input.jsonl") + .format("jsonl") + .build(); + + CreateBatchJobConfig batchJobConfig = + CreateBatchJobConfig.builder() + .displayName("your-display-name") + .dest(BatchJobDestination.builder().gcsUri(outputGcsUri).format("jsonl").build()) + .build(); + + BatchJob batchJob = client.batches.create(modelId, batchJobSource, batchJobConfig); + + String jobName = + batchJob.name().orElseThrow(() -> new IllegalStateException("Failed to get job name.")); + JobState jobState = + batchJob.state().orElseThrow(() -> new IllegalStateException("Failed to get job state.")); + + System.out.println("Job name: " + jobName); + System.out.println("Job state: " + jobState); + // Job name: + // projects/{PROJECT_ID}/locations/us-central1/batchPredictionJobs/6205497615459549184 + // Job state: JOB_STATE_PENDING + + // See the documentation: + // https://googleapis.github.io/java-genai/javadoc/com/google/genai/types/BatchJob.html + Set completedStates = new HashSet<>(); + completedStates.add(JOB_STATE_SUCCEEDED); + completedStates.add(JOB_STATE_FAILED); + completedStates.add(JOB_STATE_CANCELLED); + completedStates.add(JOB_STATE_PAUSED); + + while (!completedStates.contains(jobState.knownEnum())) { + Thread.sleep(30000); + batchJob = client.batches.get(jobName, null); + jobState = + batchJob + .state() + .orElseThrow(() -> new IllegalStateException("Failed to get job state.")); + System.out.println("Job state: " + jobState); + } + // Example response: + // Job state: 
JOB_STATE_QUEUED + // Job state: JOB_STATE_RUNNING + // Job state: JOB_STATE_RUNNING + // ... + // Job state: JOB_STATE_SUCCEEDED + return jobState; + } + } +} +// [END googlegenaisdk_batchpredict_embeddings_with_gcs] diff --git a/genai/snippets/src/main/java/genai/batchprediction/BatchPredictionWithGcs.java b/genai/snippets/src/main/java/genai/batchprediction/BatchPredictionWithGcs.java new file mode 100644 index 00000000000..e0827f8738b --- /dev/null +++ b/genai/snippets/src/main/java/genai/batchprediction/BatchPredictionWithGcs.java @@ -0,0 +1,114 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package genai.batchprediction; + +// [START googlegenaisdk_batchpredict_with_gcs] + +import static com.google.genai.types.JobState.Known.JOB_STATE_CANCELLED; +import static com.google.genai.types.JobState.Known.JOB_STATE_FAILED; +import static com.google.genai.types.JobState.Known.JOB_STATE_PAUSED; +import static com.google.genai.types.JobState.Known.JOB_STATE_SUCCEEDED; + +import com.google.genai.Client; +import com.google.genai.types.BatchJob; +import com.google.genai.types.BatchJobDestination; +import com.google.genai.types.BatchJobSource; +import com.google.genai.types.CreateBatchJobConfig; +import com.google.genai.types.HttpOptions; +import com.google.genai.types.JobState; +import java.util.HashSet; +import java.util.Set; + +public class BatchPredictionWithGcs { + + public static void main(String[] args) throws InterruptedException { + // TODO(developer): Replace these variables before running the sample. + // To use a tuned model, set the model param to your tuned model using the following format: + // modelId = "projects/{PROJECT_ID}/locations/{LOCATION}/models/{MODEL_ID}" + String modelId = "gemini-2.5-flash"; + String outputGcsUri = "gs://your-bucket/your-prefix"; + createBatchJob(modelId, outputGcsUri); + } + + // Creates a batch prediction job with Google Cloud Storage + public static JobState createBatchJob(String modelId, String outputGcsUri) + throws InterruptedException { + // Client Initialization. Once created, it can be reused for multiple requests. 
+ try (Client client = + Client.builder() + .location("us-central1") + .vertexAI(true) + .httpOptions(HttpOptions.builder().apiVersion("v1").build()) + .build()) { + // See the documentation: + // https://googleapis.github.io/java-genai/javadoc/com/google/genai/Batches.html + BatchJobSource batchJobSource = + BatchJobSource.builder() + // Source link: + // https://storage.cloud.google.com/cloud-samples-data/batch/prompt_for_batch_gemini_predict.jsonl + .gcsUri("gs://cloud-samples-data/batch/prompt_for_batch_gemini_predict.jsonl") + .format("jsonl") + .build(); + + CreateBatchJobConfig batchJobConfig = + CreateBatchJobConfig.builder() + .displayName("your-display-name") + .dest(BatchJobDestination.builder().gcsUri(outputGcsUri).format("jsonl").build()) + .build(); + + BatchJob batchJob = client.batches.create(modelId, batchJobSource, batchJobConfig); + + String jobName = + batchJob.name().orElseThrow(() -> new IllegalStateException("Failed to get job name.")); + JobState jobState = + batchJob.state().orElseThrow(() -> new IllegalStateException("Failed to get job state.")); + + System.out.println("Job name: " + jobName); + System.out.println("Job state: " + jobState); + // Example response: + // Job name: + // projects/{PROJECT_ID}/locations/us-central1/batchPredictionJobs/6205497615459549184 + // Job state: JOB_STATE_PENDING + + // See the documentation: + // https://googleapis.github.io/java-genai/javadoc/com/google/genai/types/BatchJob.html + Set completedStates = new HashSet<>(); + completedStates.add(JOB_STATE_SUCCEEDED); + completedStates.add(JOB_STATE_FAILED); + completedStates.add(JOB_STATE_CANCELLED); + completedStates.add(JOB_STATE_PAUSED); + + while (!completedStates.contains(jobState.knownEnum())) { + Thread.sleep(30000); + batchJob = client.batches.get(jobName, null); + jobState = + batchJob + .state() + .orElseThrow(() -> new IllegalStateException("Failed to get job state.")); + System.out.println("Job state: " + jobState); + } + // Example response: + 
// Job state: JOB_STATE_QUEUED + // Job state: JOB_STATE_RUNNING + // Job state: JOB_STATE_RUNNING + // ... + // Job state: JOB_STATE_SUCCEEDED + return jobState; + } + } +} +// [END googlegenaisdk_batchpredict_with_gcs] diff --git a/genai/snippets/src/test/java/genai/batchprediction/BatchPredictionIT.java b/genai/snippets/src/test/java/genai/batchprediction/BatchPredictionIT.java new file mode 100644 index 00000000000..95fb0ff43ed --- /dev/null +++ b/genai/snippets/src/test/java/genai/batchprediction/BatchPredictionIT.java @@ -0,0 +1,140 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package genai.batchprediction; + +import static com.google.common.truth.Truth.assertThat; +import static com.google.common.truth.Truth.assertWithMessage; +import static com.google.genai.types.JobState.Known.JOB_STATE_SUCCEEDED; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.RETURNS_SELF; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.genai.Batches; +import com.google.genai.Client; +import com.google.genai.types.BatchJob; +import com.google.genai.types.BatchJobSource; +import com.google.genai.types.CreateBatchJobConfig; +import com.google.genai.types.JobState; +import java.io.ByteArrayOutputStream; +import java.io.PrintStream; +import java.lang.reflect.Field; +import java.util.Optional; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.MockedStatic; + +@RunWith(JUnit4.class) +public class BatchPredictionIT { + + private static final String GEMINI_FLASH = "gemini-2.5-flash"; + private static final String EMBEDDING_MODEL = "text-embedding-005"; + private static String jobName; + private static String outputGcsUri; + private ByteArrayOutputStream bout; + private PrintStream out; + private Client mockedClient; + private MockedStatic mockedStatic; + + // Check if the required environment variables are set. 
+ public static void requireEnvVar(String envVarName) { + assertWithMessage(String.format("Missing environment variable '%s' ", envVarName)) + .that(System.getenv(envVarName)) + .isNotEmpty(); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_CLOUD_PROJECT"); + jobName = "projects/project_id/locations/us-central1/batchPredictionJobs/job_id"; + outputGcsUri = "gs://your-bucket/your-prefix"; + } + + @Before + public void setUp() throws NoSuchFieldException, IllegalAccessException { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + System.setOut(out); + + // Mock builder, client, batches and response + Client.Builder mockedBuilder = mock(Client.Builder.class, RETURNS_SELF); + mockedClient = mock(Client.class); + Batches mockedBatches = mock(Batches.class); + BatchJob mockedBatchJobResponse = mock(BatchJob.class); + // Static mock of Client.builder() + mockedStatic = mockStatic(Client.class); + mockedStatic.when(Client::builder).thenReturn(mockedBuilder); + when(mockedBuilder.build()).thenReturn(mockedClient); + + // Inject mockBatches into mockClient by using reflection because + // 'batches' is a final field and cannot be mockable directly + Field field = Client.class.getDeclaredField("batches"); + field.setAccessible(true); + field.set(mockedClient, mockedBatches); + + when(mockedClient.batches.create( + anyString(), any(BatchJobSource.class), any(CreateBatchJobConfig.class))) + .thenReturn(mockedBatchJobResponse); + + when(mockedBatchJobResponse.name()).thenReturn(Optional.of(jobName)); + when(mockedBatchJobResponse.state()).thenReturn(Optional.of(new JobState(JOB_STATE_SUCCEEDED))); + } + + @After + public void tearDown() { + System.setOut(null); + bout.reset(); + mockedStatic.close(); + } + + @Test + public void testBatchPredictionWithGcs() throws InterruptedException { + + JobState response = BatchPredictionWithGcs.createBatchJob(GEMINI_FLASH, outputGcsUri); + + verify(mockedClient.batches, times(1)) + 
.create(anyString(), any(BatchJobSource.class), any(CreateBatchJobConfig.class)); + + assertThat(response.toString()).isNotEmpty(); + assertThat(response.toString()).isEqualTo("JOB_STATE_SUCCEEDED"); + assertThat(bout.toString()).contains("Job name: " + jobName); + assertThat(bout.toString()).contains("Job state: JOB_STATE_SUCCEEDED"); + } + + @Test + public void testBatchPredictionEmbeddingsWithGcs() throws InterruptedException { + + JobState response = + BatchPredictionEmbeddingsWithGcs.createBatchJob(EMBEDDING_MODEL, outputGcsUri); + + verify(mockedClient.batches, times(1)) + .create(anyString(), any(BatchJobSource.class), any(CreateBatchJobConfig.class)); + + assertThat(response.toString()).isNotEmpty(); + assertThat(response.toString()).isEqualTo("JOB_STATE_SUCCEEDED"); + assertThat(bout.toString()).contains("Job name: " + jobName); + assertThat(bout.toString()).contains("Job state: JOB_STATE_SUCCEEDED"); + } +} From dc504d7c91e7f3ac57a550085124d327f1884bb2 Mon Sep 17 00:00:00 2001 From: Juan Dominguez Date: Thu, 2 Oct 2025 16:54:20 -0300 Subject: [PATCH 2/5] refactor: change HashSet to EnumSet and use TimeUnit instead of Thread.sleep --- .../BatchPredictionEmbeddingsWithGcs.java | 12 +++++------- .../batchprediction/BatchPredictionWithGcs.java | 12 +++++------- .../genai/batchprediction/BatchPredictionIT.java | 16 ++++++++++------ 3 files changed, 20 insertions(+), 20 deletions(-) diff --git a/genai/snippets/src/main/java/genai/batchprediction/BatchPredictionEmbeddingsWithGcs.java b/genai/snippets/src/main/java/genai/batchprediction/BatchPredictionEmbeddingsWithGcs.java index f6544b0a1ef..77287bc44e1 100644 --- a/genai/snippets/src/main/java/genai/batchprediction/BatchPredictionEmbeddingsWithGcs.java +++ b/genai/snippets/src/main/java/genai/batchprediction/BatchPredictionEmbeddingsWithGcs.java @@ -30,8 +30,9 @@ import com.google.genai.types.CreateBatchJobConfig; import com.google.genai.types.HttpOptions; import com.google.genai.types.JobState; -import 
java.util.HashSet; +import java.util.EnumSet; import java.util.Set; +import java.util.concurrent.TimeUnit; public class BatchPredictionEmbeddingsWithGcs { @@ -84,14 +85,11 @@ public static JobState createBatchJob(String modelId, String outputGcsUri) // See the documentation: // https://googleapis.github.io/java-genai/javadoc/com/google/genai/types/BatchJob.html - Set completedStates = new HashSet<>(); - completedStates.add(JOB_STATE_SUCCEEDED); - completedStates.add(JOB_STATE_FAILED); - completedStates.add(JOB_STATE_CANCELLED); - completedStates.add(JOB_STATE_PAUSED); + Set completedStates = + EnumSet.of(JOB_STATE_SUCCEEDED, JOB_STATE_FAILED, JOB_STATE_CANCELLED, JOB_STATE_PAUSED); while (!completedStates.contains(jobState.knownEnum())) { - Thread.sleep(30000); + TimeUnit.SECONDS.sleep(30); batchJob = client.batches.get(jobName, null); jobState = batchJob diff --git a/genai/snippets/src/main/java/genai/batchprediction/BatchPredictionWithGcs.java b/genai/snippets/src/main/java/genai/batchprediction/BatchPredictionWithGcs.java index e0827f8738b..47a640c2837 100644 --- a/genai/snippets/src/main/java/genai/batchprediction/BatchPredictionWithGcs.java +++ b/genai/snippets/src/main/java/genai/batchprediction/BatchPredictionWithGcs.java @@ -30,8 +30,9 @@ import com.google.genai.types.CreateBatchJobConfig; import com.google.genai.types.HttpOptions; import com.google.genai.types.JobState; -import java.util.HashSet; +import java.util.EnumSet; import java.util.Set; +import java.util.concurrent.TimeUnit; public class BatchPredictionWithGcs { @@ -86,14 +87,11 @@ public static JobState createBatchJob(String modelId, String outputGcsUri) // See the documentation: // https://googleapis.github.io/java-genai/javadoc/com/google/genai/types/BatchJob.html - Set completedStates = new HashSet<>(); - completedStates.add(JOB_STATE_SUCCEEDED); - completedStates.add(JOB_STATE_FAILED); - completedStates.add(JOB_STATE_CANCELLED); - completedStates.add(JOB_STATE_PAUSED); + Set completedStates = 
+ EnumSet.of(JOB_STATE_SUCCEEDED, JOB_STATE_FAILED, JOB_STATE_CANCELLED, JOB_STATE_PAUSED); while (!completedStates.contains(jobState.knownEnum())) { - Thread.sleep(30000); + TimeUnit.SECONDS.sleep(30); batchJob = client.batches.get(jobName, null); jobState = batchJob diff --git a/genai/snippets/src/test/java/genai/batchprediction/BatchPredictionIT.java b/genai/snippets/src/test/java/genai/batchprediction/BatchPredictionIT.java index 95fb0ff43ed..027be9e4e3a 100644 --- a/genai/snippets/src/test/java/genai/batchprediction/BatchPredictionIT.java +++ b/genai/snippets/src/test/java/genai/batchprediction/BatchPredictionIT.java @@ -55,7 +55,10 @@ public class BatchPredictionIT { private static String outputGcsUri; private ByteArrayOutputStream bout; private PrintStream out; + private Client.Builder mockedBuilder; private Client mockedClient; + private Batches mockedBatches; + private BatchJob mockedJobResponse; private MockedStatic mockedStatic; // Check if the required environment variables are set. 
@@ -79,10 +82,11 @@ public void setUp() throws NoSuchFieldException, IllegalAccessException { System.setOut(out); // Mock builder, client, batches and response - Client.Builder mockedBuilder = mock(Client.Builder.class, RETURNS_SELF); + mockedBuilder = mock(Client.Builder.class, RETURNS_SELF); mockedClient = mock(Client.class); - Batches mockedBatches = mock(Batches.class); - BatchJob mockedBatchJobResponse = mock(BatchJob.class); + mockedBatches = mock(Batches.class); + mockedJobResponse = mock(BatchJob.class); + // Static mock of Client.builder() mockedStatic = mockStatic(Client.class); mockedStatic.when(Client::builder).thenReturn(mockedBuilder); @@ -96,10 +100,10 @@ public void setUp() throws NoSuchFieldException, IllegalAccessException { when(mockedClient.batches.create( anyString(), any(BatchJobSource.class), any(CreateBatchJobConfig.class))) - .thenReturn(mockedBatchJobResponse); + .thenReturn(mockedJobResponse); + when(mockedJobResponse.name()).thenReturn(Optional.of(jobName)); + when(mockedJobResponse.state()).thenReturn(Optional.of(new JobState(JOB_STATE_SUCCEEDED))); - when(mockedBatchJobResponse.name()).thenReturn(Optional.of(jobName)); - when(mockedBatchJobResponse.state()).thenReturn(Optional.of(new JobState(JOB_STATE_SUCCEEDED))); } @After From 6e48a6352083dcca25521014e474046a7d317892 Mon Sep 17 00:00:00 2001 From: Juan Dominguez Date: Thu, 2 Oct 2025 18:29:49 -0300 Subject: [PATCH 3/5] refactor: change tests to ensure that the API is called and not mocked --- genai/snippets/pom.xml | 4 + .../batchprediction/BatchPredictionIT.java | 90 +++++-------------- 2 files changed, 28 insertions(+), 66 deletions(-) diff --git a/genai/snippets/pom.xml b/genai/snippets/pom.xml index 44374a359a3..585d72061cc 100644 --- a/genai/snippets/pom.xml +++ b/genai/snippets/pom.xml @@ -53,6 +53,10 @@ google-genai 1.20.0 + + com.google.cloud + google-cloud-storage + junit junit diff --git a/genai/snippets/src/test/java/genai/batchprediction/BatchPredictionIT.java 
b/genai/snippets/src/test/java/genai/batchprediction/BatchPredictionIT.java index 027be9e4e3a..77bb654454d 100644 --- a/genai/snippets/src/test/java/genai/batchprediction/BatchPredictionIT.java +++ b/genai/snippets/src/test/java/genai/batchprediction/BatchPredictionIT.java @@ -18,48 +18,33 @@ import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertWithMessage; -import static com.google.genai.types.JobState.Known.JOB_STATE_SUCCEEDED; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.Mockito.RETURNS_SELF; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.mockStatic; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import com.google.genai.Batches; -import com.google.genai.Client; -import com.google.genai.types.BatchJob; -import com.google.genai.types.BatchJobSource; -import com.google.genai.types.CreateBatchJobConfig; +import com.google.api.gax.paging.Page; +import com.google.cloud.storage.Blob; +import com.google.cloud.storage.Storage; +import com.google.cloud.storage.StorageOptions; import com.google.genai.types.JobState; import java.io.ByteArrayOutputStream; import java.io.PrintStream; -import java.lang.reflect.Field; -import java.util.Optional; +import java.util.UUID; import org.junit.After; +import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -import org.mockito.MockedStatic; @RunWith(JUnit4.class) public class BatchPredictionIT { private static final String GEMINI_FLASH = "gemini-2.5-flash"; private static final String EMBEDDING_MODEL = "text-embedding-005"; - private static String jobName; - private static String outputGcsUri; + private static final String BUCKET_NAME = "java-docs-samples-testing"; + private 
static final String PREFIX = "genai-batch-prediction" + UUID.randomUUID(); + private static final String OUTPUT_GCS_URI = String.format("gs://%s/%s", BUCKET_NAME, PREFIX); private ByteArrayOutputStream bout; private PrintStream out; - private Client.Builder mockedBuilder; - private Client mockedClient; - private Batches mockedBatches; - private BatchJob mockedJobResponse; - private MockedStatic mockedStatic; // Check if the required environment variables are set. public static void requireEnvVar(String envVarName) { @@ -71,74 +56,47 @@ public static void requireEnvVar(String envVarName) { @BeforeClass public static void checkRequirements() { requireEnvVar("GOOGLE_CLOUD_PROJECT"); - jobName = "projects/project_id/locations/us-central1/batchPredictionJobs/job_id"; - outputGcsUri = "gs://your-bucket/your-prefix"; + } + + @AfterClass + public static void cleanup() { + Storage storage = StorageOptions.getDefaultInstance().getService(); + Page blobs = storage.list(BUCKET_NAME, Storage.BlobListOption.prefix(PREFIX)); + + for (Blob blob : blobs.iterateAll()) { + storage.delete(blob.getBlobId()); + } } @Before - public void setUp() throws NoSuchFieldException, IllegalAccessException { + public void setUp() { bout = new ByteArrayOutputStream(); out = new PrintStream(bout); System.setOut(out); - - // Mock builder, client, batches and response - mockedBuilder = mock(Client.Builder.class, RETURNS_SELF); - mockedClient = mock(Client.class); - mockedBatches = mock(Batches.class); - mockedJobResponse = mock(BatchJob.class); - - // Static mock of Client.builder() - mockedStatic = mockStatic(Client.class); - mockedStatic.when(Client::builder).thenReturn(mockedBuilder); - when(mockedBuilder.build()).thenReturn(mockedClient); - - // Inject mockBatches into mockClient by using reflection because - // 'batches' is a final field and cannot be mockable directly - Field field = Client.class.getDeclaredField("batches"); - field.setAccessible(true); - field.set(mockedClient, mockedBatches); 
- - when(mockedClient.batches.create( - anyString(), any(BatchJobSource.class), any(CreateBatchJobConfig.class))) - .thenReturn(mockedJobResponse); - when(mockedJobResponse.name()).thenReturn(Optional.of(jobName)); - when(mockedJobResponse.state()).thenReturn(Optional.of(new JobState(JOB_STATE_SUCCEEDED))); - } @After public void tearDown() { System.setOut(null); bout.reset(); - mockedStatic.close(); } @Test public void testBatchPredictionWithGcs() throws InterruptedException { - - JobState response = BatchPredictionWithGcs.createBatchJob(GEMINI_FLASH, outputGcsUri); - - verify(mockedClient.batches, times(1)) - .create(anyString(), any(BatchJobSource.class), any(CreateBatchJobConfig.class)); - + JobState response = BatchPredictionWithGcs.createBatchJob(GEMINI_FLASH, OUTPUT_GCS_URI); assertThat(response.toString()).isNotEmpty(); assertThat(response.toString()).isEqualTo("JOB_STATE_SUCCEEDED"); - assertThat(bout.toString()).contains("Job name: " + jobName); + assertThat(bout.toString()).contains("Job name: "); assertThat(bout.toString()).contains("Job state: JOB_STATE_SUCCEEDED"); } @Test public void testBatchPredictionEmbeddingsWithGcs() throws InterruptedException { - JobState response = - BatchPredictionEmbeddingsWithGcs.createBatchJob(EMBEDDING_MODEL, outputGcsUri); - - verify(mockedClient.batches, times(1)) - .create(anyString(), any(BatchJobSource.class), any(CreateBatchJobConfig.class)); - + BatchPredictionEmbeddingsWithGcs.createBatchJob(EMBEDDING_MODEL, OUTPUT_GCS_URI); assertThat(response.toString()).isNotEmpty(); assertThat(response.toString()).isEqualTo("JOB_STATE_SUCCEEDED"); - assertThat(bout.toString()).contains("Job name: " + jobName); + assertThat(bout.toString()).contains("Job name: "); assertThat(bout.toString()).contains("Job state: JOB_STATE_SUCCEEDED"); } } From 508af3304639025f6da0a70f81026a5428af5b0f Mon Sep 17 00:00:00 2001 From: Juan Dominguez Date: Wed, 8 Oct 2025 13:13:17 -0300 Subject: [PATCH 4/5] refactor: change batch polling logic and 
return type --- .../BatchPredictionEmbeddingsWithGcs.java | 27 +++++++---------- .../BatchPredictionWithGcs.java | 29 ++++++++----------- .../batchprediction/BatchPredictionIT.java | 16 ++++++---- 3 files changed, 33 insertions(+), 39 deletions(-) diff --git a/genai/snippets/src/main/java/genai/batchprediction/BatchPredictionEmbeddingsWithGcs.java b/genai/snippets/src/main/java/genai/batchprediction/BatchPredictionEmbeddingsWithGcs.java index 77287bc44e1..22519cc0bb1 100644 --- a/genai/snippets/src/main/java/genai/batchprediction/BatchPredictionEmbeddingsWithGcs.java +++ b/genai/snippets/src/main/java/genai/batchprediction/BatchPredictionEmbeddingsWithGcs.java @@ -28,9 +28,11 @@ import com.google.genai.types.BatchJobDestination; import com.google.genai.types.BatchJobSource; import com.google.genai.types.CreateBatchJobConfig; +import com.google.genai.types.GetBatchJobConfig; import com.google.genai.types.HttpOptions; import com.google.genai.types.JobState; import java.util.EnumSet; +import java.util.Optional; import java.util.Set; import java.util.concurrent.TimeUnit; @@ -44,7 +46,7 @@ public static void main(String[] args) throws InterruptedException { } // Creates a batch prediction job with embedding model and Google Cloud Storage - public static JobState createBatchJob(String modelId, String outputGcsUri) + public static Optional createBatchJob(String modelId, String outputGcsUri) throws InterruptedException { // Client Initialization. Once created, it can be reused for multiple requests. 
try (Client client = @@ -73,14 +75,11 @@ public static JobState createBatchJob(String modelId, String outputGcsUri) BatchJob batchJob = client.batches.create(modelId, batchJobSource, batchJobConfig); String jobName = - batchJob.name().orElseThrow(() -> new IllegalStateException("Failed to get job name.")); - JobState jobState = - batchJob.state().orElseThrow(() -> new IllegalStateException("Failed to get job state.")); - + batchJob.name().orElseThrow(() -> new IllegalStateException("Missing job name")); System.out.println("Job name: " + jobName); - System.out.println("Job state: " + jobState); - // Job name: - // projects/{PROJECT_ID}/locations/us-central1/batchPredictionJobs/6205497615459549184 + Optional jobState = batchJob.state(); + jobState.ifPresent(state -> System.out.println("Job state: " + state)); + // Job name: projects/project_id/locations/us-central1/batchPredictionJobs/6205497615459549184 // Job state: JOB_STATE_PENDING // See the documentation: @@ -88,19 +87,15 @@ public static JobState createBatchJob(String modelId, String outputGcsUri) Set completedStates = EnumSet.of(JOB_STATE_SUCCEEDED, JOB_STATE_FAILED, JOB_STATE_CANCELLED, JOB_STATE_PAUSED); - while (!completedStates.contains(jobState.knownEnum())) { + while (jobState.isPresent() && !completedStates.contains(jobState.get().knownEnum())) { TimeUnit.SECONDS.sleep(30); - batchJob = client.batches.get(jobName, null); - jobState = - batchJob - .state() - .orElseThrow(() -> new IllegalStateException("Failed to get job state.")); - System.out.println("Job state: " + jobState); + batchJob = client.batches.get(jobName, GetBatchJobConfig.builder().build()); + jobState = batchJob.state(); + batchJob.state().ifPresent(state -> System.out.println("Job state: " + state)); } // Example response: // Job state: JOB_STATE_QUEUED // Job state: JOB_STATE_RUNNING - // Job state: JOB_STATE_RUNNING // ... 
// Job state: JOB_STATE_SUCCEEDED return jobState; diff --git a/genai/snippets/src/main/java/genai/batchprediction/BatchPredictionWithGcs.java b/genai/snippets/src/main/java/genai/batchprediction/BatchPredictionWithGcs.java index 47a640c2837..2020ebd4e59 100644 --- a/genai/snippets/src/main/java/genai/batchprediction/BatchPredictionWithGcs.java +++ b/genai/snippets/src/main/java/genai/batchprediction/BatchPredictionWithGcs.java @@ -28,9 +28,11 @@ import com.google.genai.types.BatchJobDestination; import com.google.genai.types.BatchJobSource; import com.google.genai.types.CreateBatchJobConfig; +import com.google.genai.types.GetBatchJobConfig; import com.google.genai.types.HttpOptions; import com.google.genai.types.JobState; import java.util.EnumSet; +import java.util.Optional; import java.util.Set; import java.util.concurrent.TimeUnit; @@ -46,7 +48,7 @@ public static void main(String[] args) throws InterruptedException { } // Creates a batch prediction job with Google Cloud Storage - public static JobState createBatchJob(String modelId, String outputGcsUri) + public static Optional createBatchJob(String modelId, String outputGcsUri) throws InterruptedException { // Client Initialization. Once created, it can be reused for multiple requests. 
try (Client client = @@ -74,30 +76,23 @@ public static JobState createBatchJob(String modelId, String outputGcsUri) BatchJob batchJob = client.batches.create(modelId, batchJobSource, batchJobConfig); String jobName = - batchJob.name().orElseThrow(() -> new IllegalStateException("Failed to get job name.")); - JobState jobState = - batchJob.state().orElseThrow(() -> new IllegalStateException("Failed to get job state.")); - + batchJob.name().orElseThrow(() -> new IllegalStateException("Missing job name")); System.out.println("Job name: " + jobName); - System.out.println("Job state: " + jobState); - // Example response: - // Job name: - // projects/{PROJECT_ID}/locations/us-central1/batchPredictionJobs/6205497615459549184 + Optional jobState = batchJob.state(); + jobState.ifPresent(state -> System.out.println("Job state: " + state)); + // Job name: projects/project_id/locations/us-central1/batchPredictionJobs/6205497615459549184 // Job state: JOB_STATE_PENDING // See the documentation: // https://googleapis.github.io/java-genai/javadoc/com/google/genai/types/BatchJob.html Set completedStates = - EnumSet.of(JOB_STATE_SUCCEEDED, JOB_STATE_FAILED, JOB_STATE_CANCELLED, JOB_STATE_PAUSED); + EnumSet.of(JOB_STATE_SUCCEEDED, JOB_STATE_FAILED, JOB_STATE_CANCELLED, JOB_STATE_PAUSED); - while (!completedStates.contains(jobState.knownEnum())) { + while (jobState.isPresent() && !completedStates.contains(jobState.get().knownEnum())) { TimeUnit.SECONDS.sleep(30); - batchJob = client.batches.get(jobName, null); - jobState = - batchJob - .state() - .orElseThrow(() -> new IllegalStateException("Failed to get job state.")); - System.out.println("Job state: " + jobState); + batchJob = client.batches.get(jobName, GetBatchJobConfig.builder().build()); + jobState = batchJob.state(); + batchJob.state().ifPresent(state -> System.out.println("Job state: " + state)); } // Example response: // Job state: JOB_STATE_QUEUED diff --git 
a/genai/snippets/src/test/java/genai/batchprediction/BatchPredictionIT.java b/genai/snippets/src/test/java/genai/batchprediction/BatchPredictionIT.java index 77bb654454d..b0e2c46f6b0 100644 --- a/genai/snippets/src/test/java/genai/batchprediction/BatchPredictionIT.java +++ b/genai/snippets/src/test/java/genai/batchprediction/BatchPredictionIT.java @@ -26,6 +26,7 @@ import com.google.genai.types.JobState; import java.io.ByteArrayOutputStream; import java.io.PrintStream; +import java.util.Optional; import java.util.UUID; import org.junit.After; import org.junit.AfterClass; @@ -83,19 +84,22 @@ public void tearDown() { @Test public void testBatchPredictionWithGcs() throws InterruptedException { - JobState response = BatchPredictionWithGcs.createBatchJob(GEMINI_FLASH, OUTPUT_GCS_URI); - assertThat(response.toString()).isNotEmpty(); - assertThat(response.toString()).isEqualTo("JOB_STATE_SUCCEEDED"); + Optional response = + BatchPredictionWithGcs.createBatchJob(GEMINI_FLASH, OUTPUT_GCS_URI); + assertThat(response).isPresent(); + assertThat(response.get().toString()).isNotEmpty(); + assertThat(response.get().toString()).isEqualTo("JOB_STATE_SUCCEEDED"); assertThat(bout.toString()).contains("Job name: "); assertThat(bout.toString()).contains("Job state: JOB_STATE_SUCCEEDED"); } @Test public void testBatchPredictionEmbeddingsWithGcs() throws InterruptedException { - JobState response = + Optional response = BatchPredictionEmbeddingsWithGcs.createBatchJob(EMBEDDING_MODEL, OUTPUT_GCS_URI); - assertThat(response.toString()).isNotEmpty(); - assertThat(response.toString()).isEqualTo("JOB_STATE_SUCCEEDED"); + assertThat(response).isPresent(); + assertThat(response.get().toString()).isNotEmpty(); + assertThat(response.get().toString()).isEqualTo("JOB_STATE_SUCCEEDED"); assertThat(bout.toString()).contains("Job name: "); assertThat(bout.toString()).contains("Job state: JOB_STATE_SUCCEEDED"); } From a6d777d7d9927d4b8f36af84aab96f102784f68c Mon Sep 17 00:00:00 2001 From: Juan 
Dominguez Date: Thu, 9 Oct 2025 12:41:50 -0300 Subject: [PATCH 5/5] fix lint --- .../java/genai/batchprediction/BatchPredictionWithGcs.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/genai/snippets/src/main/java/genai/batchprediction/BatchPredictionWithGcs.java b/genai/snippets/src/main/java/genai/batchprediction/BatchPredictionWithGcs.java index 2020ebd4e59..675038a6de7 100644 --- a/genai/snippets/src/main/java/genai/batchprediction/BatchPredictionWithGcs.java +++ b/genai/snippets/src/main/java/genai/batchprediction/BatchPredictionWithGcs.java @@ -76,7 +76,7 @@ public static Optional createBatchJob(String modelId, String outputGcs BatchJob batchJob = client.batches.create(modelId, batchJobSource, batchJobConfig); String jobName = - batchJob.name().orElseThrow(() -> new IllegalStateException("Missing job name")); + batchJob.name().orElseThrow(() -> new IllegalStateException("Missing job name")); System.out.println("Job name: " + jobName); Optional jobState = batchJob.state(); jobState.ifPresent(state -> System.out.println("Job state: " + state)); @@ -86,7 +86,7 @@ public static Optional createBatchJob(String modelId, String outputGcs // See the documentation: // https://googleapis.github.io/java-genai/javadoc/com/google/genai/types/BatchJob.html Set completedStates = - EnumSet.of(JOB_STATE_SUCCEEDED, JOB_STATE_FAILED, JOB_STATE_CANCELLED, JOB_STATE_PAUSED); + EnumSet.of(JOB_STATE_SUCCEEDED, JOB_STATE_FAILED, JOB_STATE_CANCELLED, JOB_STATE_PAUSED); while (jobState.isPresent() && !completedStates.contains(jobState.get().knownEnum())) { TimeUnit.SECONDS.sleep(30);