From 24d12d0bec7c3d198aa5be2ecbed41a63a02e622 Mon Sep 17 00:00:00 2001 From: Kyle Montemayor Date: Wed, 27 May 2026 23:06:25 +0000 Subject: [PATCH 1/2] Expose GiGL env vars in Vertex AI launcher --- gigl/src/common/vertex_ai_launcher.py | 69 ++++++++- gigl/src/inference/v2/glt_inferencer.py | 2 + gigl/src/training/v2/glt_trainer.py | 2 + .../src/common/vertex_ai_launcher_test.py | 138 +++++++++++++++++- 4 files changed, 209 insertions(+), 2 deletions(-) diff --git a/gigl/src/common/vertex_ai_launcher.py b/gigl/src/common/vertex_ai_launcher.py index d5dae15d9..9828c5a25 100644 --- a/gigl/src/common/vertex_ai_launcher.py +++ b/gigl/src/common/vertex_ai_launcher.py @@ -17,6 +17,14 @@ ) from gigl.common.logger import Logger from gigl.common.services.vertex_ai import VertexAiJobConfig, VertexAIService +from gigl.env.constants import ( + GIGL_APPLIED_TASK_IDENTIFIER_ENV_KEY, + GIGL_COMPONENT_ENV_KEY, + GIGL_CPU_DOCKER_URI_ENV_KEY, + GIGL_CUDA_DOCKER_URI_ENV_KEY, + GIGL_RESOURCE_CONFIG_URI_ENV_KEY, + GIGL_TASK_CONFIG_URI_ENV_KEY, +) from gigl.env.distributed import COMPUTE_CLUSTER_LOCAL_WORLD_SIZE_ENV_KEY from gigl.src.common.constants.components import GiGLComponents from gigl.src.common.types.pb_wrappers.gigl_resource_config import ( @@ -43,6 +51,7 @@ def launch_single_pool_job( vertex_ai_resource_config: VertexAiResourceConfig, job_name: str, + applied_task_identifier: str, task_config_uri: Uri, resource_config_uri: Uri, process_command: str, @@ -58,6 +67,7 @@ def launch_single_pool_job( Args: vertex_ai_resource_config: The Vertex AI resource configuration job_name: Full name for the Vertex AI job + applied_task_identifier: The raw GiGL task identifier task_config_uri: URI to the task configuration resource_config_uri: URI to the resource configuration process_command: Command to run in the container @@ -88,7 +98,17 @@ def launch_single_pool_job( use_cuda=not is_cpu_execution, container_uri=container_uri, vertex_ai_resource_config=vertex_ai_resource_config, - env_vars=[env_var.EnvVar(name="TF_CPP_MIN_LOG_LEVEL", value="3")], + env_vars=[ + env_var.EnvVar(name="TF_CPP_MIN_LOG_LEVEL", value="3"), + *_build_common_gigl_env_vars( + applied_task_identifier=applied_task_identifier, + task_config_uri=task_config_uri, + resource_config_uri=resource_config_uri, + cpu_docker_uri=cpu_docker_uri, + cuda_docker_uri=cuda_docker_uri, + component=component, + ), + ], labels=resource_config_wrapper.get_resource_labels(component=component), ) logger.info(f"Launching {component.value} job with config: {job_config}") @@ -105,6 +125,7 @@ def launch_single_pool_job( def launch_graph_store_enabled_job( vertex_ai_graph_store_config: VertexAiGraphStoreConfig, job_name: str, + applied_task_identifier: str, task_config_uri: Uri, resource_config_uri: Uri, compute_commmand: str, @@ -121,6 +142,7 @@ def launch_graph_store_enabled_job( Args: vertex_ai_graph_store_config: The Vertex AI graph store configuration job_name: Full name for the Vertex AI job + applied_task_identifier: The raw GiGL task identifier task_config_uri: URI to the task configuration resource_config_uri: URI to the resource configuration compute_commmand: Command to run in the compute container @@ -167,6 +189,14 @@ def launch_graph_store_enabled_job( name=COMPUTE_CLUSTER_LOCAL_WORLD_SIZE_ENV_KEY, value=str(num_compute_processes), ), + *_build_common_gigl_env_vars( + applied_task_identifier=applied_task_identifier, + task_config_uri=task_config_uri, + resource_config_uri=resource_config_uri, + cpu_docker_uri=cpu_docker_uri, + cuda_docker_uri=cuda_docker_uri, + component=component, + ), ] labels = resource_config_wrapper.get_resource_labels(component=component) @@ -297,6 +327,43 @@ def _build_job_config( return job_config +def _build_common_gigl_env_vars( + applied_task_identifier: str, + task_config_uri: Uri, + resource_config_uri: Uri, + cpu_docker_uri: str, + cuda_docker_uri: str, + component: GiGLComponents, +) -> list[env_var.EnvVar]: + """Build common GiGL runtime context env vars for Vertex AI containers. + + Args: + applied_task_identifier: The raw GiGL task identifier. + task_config_uri: URI to the task configuration. + resource_config_uri: URI to the resource configuration. + cpu_docker_uri: Resolved CPU Docker image URI. + cuda_docker_uri: Resolved CUDA Docker image URI. + component: The GiGL component being launched. + + Returns: + Environment variables carrying shared GiGL launcher context. + """ + return [ + env_var.EnvVar( + name=GIGL_APPLIED_TASK_IDENTIFIER_ENV_KEY, + value=str(applied_task_identifier), + ), + env_var.EnvVar(name=GIGL_TASK_CONFIG_URI_ENV_KEY, value=str(task_config_uri)), + env_var.EnvVar( + name=GIGL_RESOURCE_CONFIG_URI_ENV_KEY, + value=str(resource_config_uri), + ), + env_var.EnvVar(name=GIGL_COMPONENT_ENV_KEY, value=component.name), + env_var.EnvVar(name=GIGL_CPU_DOCKER_URI_ENV_KEY, value=cpu_docker_uri), + env_var.EnvVar(name=GIGL_CUDA_DOCKER_URI_ENV_KEY, value=cuda_docker_uri), + ] + + def _build_reservation_affinity( affinity: VertexAiReservationAffinity, ) -> Optional[ReservationAffinity]: diff --git a/gigl/src/inference/v2/glt_inferencer.py b/gigl/src/inference/v2/glt_inferencer.py index 1587828da..c4ff82d7f 100644 --- a/gigl/src/inference/v2/glt_inferencer.py +++ b/gigl/src/inference/v2/glt_inferencer.py @@ -63,6 +63,7 @@ def __execute_VAI_inference( launch_single_pool_job( vertex_ai_resource_config=resource_config_wrapper.inferencer_config, job_name=job_name, + applied_task_identifier=applied_task_identifier, task_config_uri=task_config_uri, resource_config_uri=resource_config_uri, process_command=inference_process_command, @@ -79,6 +80,7 @@ def __execute_VAI_inference( launch_graph_store_enabled_job( vertex_ai_graph_store_config=resource_config_wrapper.inferencer_config, job_name=job_name, + applied_task_identifier=applied_task_identifier, task_config_uri=task_config_uri, resource_config_uri=resource_config_uri, compute_commmand=inference_process_command, diff --git a/gigl/src/training/v2/glt_trainer.py b/gigl/src/training/v2/glt_trainer.py index 2f8ecbbbe..d4d5a3d5c 100644 --- a/gigl/src/training/v2/glt_trainer.py +++ b/gigl/src/training/v2/glt_trainer.py @@ -61,6 +61,7 @@ def __execute_VAI_training( launch_single_pool_job( vertex_ai_resource_config=resource_config.trainer_config, job_name=job_name, + applied_task_identifier=applied_task_identifier, task_config_uri=task_config_uri, resource_config_uri=resource_config_uri, process_command=training_process_command, @@ -75,6 +76,7 @@ def __execute_VAI_training( launch_graph_store_enabled_job( vertex_ai_graph_store_config=resource_config.trainer_config, job_name=job_name, + applied_task_identifier=applied_task_identifier, task_config_uri=task_config_uri, resource_config_uri=resource_config_uri, compute_commmand=training_process_command, diff --git a/tests/unit/src/common/vertex_ai_launcher_test.py b/tests/unit/src/common/vertex_ai_launcher_test.py index 9170f6041..91abd5df4 100644 --- a/tests/unit/src/common/vertex_ai_launcher_test.py +++ b/tests/unit/src/common/vertex_ai_launcher_test.py @@ -3,8 +3,22 @@ from unittest.mock import Mock, patch from absl.testing import absltest +from google.cloud.aiplatform_v1.types import env_var from gigl.common import Uri +from gigl.common.constants import ( + DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU, + DEFAULT_GIGL_RELEASE_SRC_IMAGE_CUDA, +) +from gigl.env.constants import ( + GIGL_APPLIED_TASK_IDENTIFIER_ENV_KEY, + GIGL_COMPONENT_ENV_KEY, + GIGL_CPU_DOCKER_URI_ENV_KEY, + GIGL_CUDA_DOCKER_URI_ENV_KEY, + GIGL_RESOURCE_CONFIG_URI_ENV_KEY, + GIGL_TASK_CONFIG_URI_ENV_KEY, +) +from gigl.env.distributed import COMPUTE_CLUSTER_LOCAL_WORLD_SIZE_ENV_KEY from gigl.src.common.constants.components import GiGLComponents from gigl.src.common.types.pb_wrappers.gigl_resource_config import ( GiglResourceConfigWrapper, @@ -105,6 +119,42 @@ def _create_gigl_resource_config_with_single_pool_inference( ) +def _env_vars_to_dict( + environment_variables: list[env_var.EnvVar] | None, +) -> dict[str, str]: + """Convert Vertex AI EnvVar objects to a name/value dictionary.""" + if environment_variables is None: + return {} + return { + environment_variable.name: environment_variable.value + for environment_variable in environment_variables + } + + +def _assert_common_gigl_env_vars( + test_case: TestCase, + environment_variables: list[env_var.EnvVar] | None, + applied_task_identifier: str, + task_config_uri: Uri, + resource_config_uri: Uri, + cpu_docker_uri: str, + cuda_docker_uri: str, + component: GiGLComponents, +) -> None: + """Assert Vertex AI job env vars include the common GiGL launcher context.""" + env_vars = _env_vars_to_dict(environment_variables) + test_case.assertEqual( + env_vars[GIGL_APPLIED_TASK_IDENTIFIER_ENV_KEY], applied_task_identifier + ) + test_case.assertEqual(env_vars[GIGL_TASK_CONFIG_URI_ENV_KEY], str(task_config_uri)) + test_case.assertEqual( + env_vars[GIGL_RESOURCE_CONFIG_URI_ENV_KEY], str(resource_config_uri) + ) + test_case.assertEqual(env_vars[GIGL_CPU_DOCKER_URI_ENV_KEY], cpu_docker_uri) + test_case.assertEqual(env_vars[GIGL_CUDA_DOCKER_URI_ENV_KEY], cuda_docker_uri) + test_case.assertEqual(env_vars[GIGL_COMPONENT_ENV_KEY], component.name) + + class TestVertexAILauncher(TestCase): """Test suite for vertex_ai_launcher module.""" @@ -113,6 +163,7 @@ def test_launch_training_graph_store_cuda(self, mock_vertex_ai_service_class): """Test launching a training job with graph store enabled and GPU/CUDA configuration.""" # Define test inputs job_name = "test-training-job" + applied_task_identifier = "test-training-task" task_config_uri = Uri("gs://bucket/task_config.yaml") resource_config_uri = Uri("gs://bucket/resource_config.yaml") process_command = "python -m gigl.src.training.v2.glt_trainer" @@ -142,6 +193,7 @@ def test_launch_training_graph_store_cuda(self, mock_vertex_ai_service_class): launch_graph_store_enabled_job( vertex_ai_graph_store_config=graph_store_config, job_name=job_name, + applied_task_identifier=applied_task_identifier, task_config_uri=task_config_uri, resource_config_uri=resource_config_uri, compute_commmand=process_command, @@ -212,9 +264,29 @@ def test_launch_training_graph_store_cuda(self, mock_vertex_ai_service_class): ev.name: ev.value for ev in compute_job_config.environment_variables } self.assertEqual( - compute_env_vars["COMPUTE_CLUSTER_LOCAL_WORLD_SIZE"], + compute_env_vars[COMPUTE_CLUSTER_LOCAL_WORLD_SIZE_ENV_KEY], str(graph_store_config.compute_cluster_local_world_size), ) + _assert_common_gigl_env_vars( + test_case=self, + environment_variables=compute_job_config.environment_variables, + applied_task_identifier=applied_task_identifier, + task_config_uri=task_config_uri, + resource_config_uri=resource_config_uri, + cpu_docker_uri=cpu_docker_uri, + cuda_docker_uri=cuda_docker_uri, + component=component, + ) + _assert_common_gigl_env_vars( + test_case=self, + environment_variables=storage_job_config.environment_variables, + applied_task_identifier=applied_task_identifier, + task_config_uri=task_config_uri, + resource_config_uri=resource_config_uri, + cpu_docker_uri=cpu_docker_uri, + cuda_docker_uri=cuda_docker_uri, + component=component, + ) # Verify resource labels expected_labels = { @@ -229,6 +301,7 @@ def test_launch_inference_single_pool_cpu(self, mock_vertex_ai_service_class): """Test launching an inference job with single pool and CPU-only configuration.""" # Define test inputs job_name = "test-inference-job" + applied_task_identifier = "test-inference-task" task_config_uri = Uri("gs://bucket/inference_config.yaml") resource_config_uri = Uri("gs://bucket/resource_config.yaml") process_command = "python -m gigl.src.inference.v2.glt_inferencer" @@ -262,6 +335,7 @@ def test_launch_inference_single_pool_cpu(self, mock_vertex_ai_service_class): launch_single_pool_job( vertex_ai_resource_config=vertex_ai_config, job_name=job_name, + applied_task_identifier=applied_task_identifier, task_config_uri=task_config_uri, resource_config_uri=resource_config_uri, process_command=process_command, @@ -312,6 +386,16 @@ def test_launch_inference_single_pool_cpu(self, mock_vertex_ai_service_class): f"--output_path={process_runtime_args['output_path']}", job_config.args ) self.assertNotIn("--use_cuda", job_config.args) + _assert_common_gigl_env_vars( + test_case=self, + environment_variables=job_config.environment_variables, + applied_task_identifier=applied_task_identifier, + task_config_uri=task_config_uri, + resource_config_uri=resource_config_uri, + cpu_docker_uri=cpu_docker_uri, + cuda_docker_uri=cuda_docker_uri, + component=component, + ) # Verify resource labels expected_labels = { @@ -321,6 +405,58 @@ def test_launch_inference_single_pool_cpu(self, mock_vertex_ai_service_class): } self.assertEqual(job_config.labels, expected_labels) + @patch("gigl.src.common.vertex_ai_launcher.VertexAIService") + def test_launch_single_pool_defaults_docker_uri_env_vars( + self, mock_vertex_ai_service_class + ): + """Test Vertex AI env vars expose default Docker URIs when inputs are unset.""" + job_name = "test-default-image-job" + applied_task_identifier = "test-default-image-task" + task_config_uri = Uri("gs://bucket/task_config.yaml") + resource_config_uri = Uri("gs://bucket/resource_config.yaml") + component = GiGLComponents.Trainer + + gigl_resource_config_proto = ( + _create_gigl_resource_config_with_single_pool_inference() + ) + resource_config_wrapper = GiglResourceConfigWrapper( + resource_config=gigl_resource_config_proto + ) + vertex_ai_config = gigl_resource_config_pb2.VertexAiResourceConfig( + machine_type="n1-standard-8", + num_replicas=1, + ) + + mock_service_instance = Mock() + mock_vertex_ai_service_class.return_value = mock_service_instance + + launch_single_pool_job( + vertex_ai_resource_config=vertex_ai_config, + job_name=job_name, + applied_task_identifier=applied_task_identifier, + task_config_uri=task_config_uri, + resource_config_uri=resource_config_uri, + process_command="python -m gigl.src.training.v2.glt_trainer", + process_runtime_args={}, + resource_config_wrapper=resource_config_wrapper, + cpu_docker_uri=None, + cuda_docker_uri=None, + component=component, + vertex_ai_region="us-central1", + ) + + job_config = mock_service_instance.launch_job.call_args.kwargs["job_config"] + _assert_common_gigl_env_vars( + test_case=self, + environment_variables=job_config.environment_variables, + applied_task_identifier=applied_task_identifier, + task_config_uri=task_config_uri, + resource_config_uri=resource_config_uri, + cpu_docker_uri=DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU, + cuda_docker_uri=DEFAULT_GIGL_RELEASE_SRC_IMAGE_CUDA, + component=component, + ) + if __name__ == "__main__": absltest.main() From 2a77df3a7576a2eed841f666d6ed093bb4528621 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Thu, 28 May 2026 20:18:32 +0000 Subject: [PATCH 2/2] Fix Vertex AI launcher job name handling --- gigl/env/constants.py | 7 +- gigl/src/common/vertex_ai_launcher.py | 66 +++++-- gigl/src/inference/v2/glt_inferencer.py | 8 +- gigl/src/training/v2/glt_trainer.py | 8 +- .../src/common/vertex_ai_launcher_test.py | 27 ++- tests/unit/src/inference/v2/__init__.py | 0 .../src/inference/v2/glt_inferencer_test.py | 180 ++++++++++++++++++ tests/unit/src/training/v2/__init__.py | 0 .../unit/src/training/v2/glt_trainer_test.py | 180 ++++++++++++++++++ 9 files changed, 428 insertions(+), 48 deletions(-) create mode 100644 tests/unit/src/inference/v2/__init__.py create mode 100644 tests/unit/src/inference/v2/glt_inferencer_test.py create mode 100644 tests/unit/src/training/v2/__init__.py create mode 100644 tests/unit/src/training/v2/glt_trainer_test.py diff --git a/gigl/env/constants.py b/gigl/env/constants.py index ce3b9da95..2a97dd374 100644 --- a/gigl/env/constants.py +++ b/gigl/env/constants.py @@ -1,8 +1,9 @@ """Environment-variable keys used across GiGL. -Most of these keys are set on subprocess env (never on the parent -``os.environ``) by ``gigl.src.common.custom_launcher.launch_custom`` so that -receiving CLIs can ``os.environ.get(...)`` their runtime context. +Most of these keys are set on launched process envs by +``gigl.src.common.custom_launcher.launch_custom`` and +``gigl.src.common.vertex_ai_launcher`` so that receiving CLIs can +``os.environ.get(...)`` their runtime context. ``GIGL_RESOURCE_CONFIG_URI`` is also written to the parent ``os.environ`` by ``gigl.env.pipelines_config.get_resource_config`` so that downstream readers diff --git a/gigl/src/common/vertex_ai_launcher.py b/gigl/src/common/vertex_ai_launcher.py index 9828c5a25..c237b57e8 100644 --- a/gigl/src/common/vertex_ai_launcher.py +++ b/gigl/src/common/vertex_ai_launcher.py @@ -51,7 +51,6 @@ def launch_single_pool_job( vertex_ai_resource_config: VertexAiResourceConfig, job_name: str, - applied_task_identifier: str, task_config_uri: Uri, resource_config_uri: Uri, process_command: str, @@ -66,8 +65,7 @@ def launch_single_pool_job( Args: vertex_ai_resource_config: The Vertex AI resource configuration - job_name: Full name for the Vertex AI job - applied_task_identifier: The raw GiGL task identifier + job_name: Raw GiGL applied task identifier task_config_uri: URI to the task configuration resource_config_uri: URI to the resource configuration process_command: Command to run in the container @@ -82,6 +80,10 @@ def launch_single_pool_job( raise ValueError( f"Invalid component: {component}. Expected one of: {_LAUNCHABLE_COMPONENTS}" ) + vertex_ai_job_name = _build_vertex_ai_job_name( + job_name=job_name, + component=component, + ) is_cpu_execution = _determine_if_cpu_execution( vertex_ai_resource_config=vertex_ai_resource_config ) @@ -90,7 +92,8 @@ def launch_single_pool_job( container_uri = cpu_docker_uri if is_cpu_execution else cuda_docker_uri job_config = _build_job_config( - job_name=job_name, + vertex_ai_job_name=vertex_ai_job_name, + applied_task_identifier=job_name, task_config_uri=task_config_uri, resource_config_uri=resource_config_uri, command_str=process_command, @@ -101,7 +104,7 @@ def launch_single_pool_job( env_vars=[ env_var.EnvVar(name="TF_CPP_MIN_LOG_LEVEL", value="3"), *_build_common_gigl_env_vars( - applied_task_identifier=applied_task_identifier, + applied_task_identifier=job_name, task_config_uri=task_config_uri, resource_config_uri=resource_config_uri, cpu_docker_uri=cpu_docker_uri, @@ -125,7 +128,6 @@ def launch_single_pool_job( def launch_graph_store_enabled_job( vertex_ai_graph_store_config: VertexAiGraphStoreConfig, job_name: str, - applied_task_identifier: str, task_config_uri: Uri, resource_config_uri: Uri, compute_commmand: str, @@ -141,8 +143,7 @@ def launch_graph_store_enabled_job( Args: vertex_ai_graph_store_config: The Vertex AI graph store configuration - job_name: Full name for the Vertex AI job - applied_task_identifier: The raw GiGL task identifier + job_name: Raw GiGL applied task identifier task_config_uri: URI to the task configuration resource_config_uri: URI to the resource configuration compute_commmand: Command to run in the compute container @@ -158,6 +159,10 @@ def launch_graph_store_enabled_job( raise ValueError( f"Invalid component: {component}. Expected one of: {_LAUNCHABLE_COMPONENTS}" ) + vertex_ai_job_name = _build_vertex_ai_job_name( + job_name=job_name, + component=component, + ) storage_pool_config = vertex_ai_graph_store_config.graph_store_pool compute_pool_config = vertex_ai_graph_store_config.compute_pool @@ -190,7 +195,7 @@ def launch_graph_store_enabled_job( value=str(num_compute_processes), ), *_build_common_gigl_env_vars( - applied_task_identifier=applied_task_identifier, + applied_task_identifier=job_name, task_config_uri=task_config_uri, resource_config_uri=resource_config_uri, cpu_docker_uri=cpu_docker_uri, @@ -203,7 +208,8 @@ def launch_graph_store_enabled_job( # Create compute pool job config compute_job_config = _build_job_config( - job_name=job_name, + vertex_ai_job_name=vertex_ai_job_name, + applied_task_identifier=job_name, task_config_uri=task_config_uri, resource_config_uri=resource_config_uri, command_str=compute_commmand, @@ -217,7 +223,8 @@ def launch_graph_store_enabled_job( # Create storage pool job config storage_job_config = _build_job_config( - job_name=job_name, + vertex_ai_job_name=vertex_ai_job_name, + applied_task_identifier=job_name, task_config_uri=task_config_uri, resource_config_uri=resource_config_uri, command_str=storage_command, @@ -249,7 +256,8 @@ def launch_graph_store_enabled_job( def _build_job_config( - job_name: str, + vertex_ai_job_name: str, + applied_task_identifier: str, task_config_uri: Uri, resource_config_uri: Uri, command_str: str, @@ -263,12 +271,12 @@ def _build_job_config( """Build a VertexAiJobConfig for training or inference jobs. This function constructs a configuration object for running GiGL training or inference - jobs on Vertex AI. It assembles job arguments, sets appropriate job naming conventions, - and configures resource specifications based on the provided parameters. + jobs on Vertex AI. It assembles job arguments and configures resource specifications + based on the provided parameters. Args: - job_name (str): The base name for the job. Will be prefixed with "gigl_train_" or "gigl_infer_". - is_inference (bool): Whether this is an inference job (True) or training job (False). + vertex_ai_job_name (str): The Vertex AI CustomJob display name. + applied_task_identifier (str): Raw GiGL applied task identifier passed to the process. task_config_uri (Uri): URI to the task configuration file. resource_config_uri (Uri): URI to the resource configuration file. command_str (str): The command to run in the container (will be split on spaces). @@ -285,7 +293,7 @@ def _build_job_config( """ job_args = ( [ - f"--job_name={job_name}", + f"--job_name={applied_task_identifier}", f"--task_config_uri={task_config_uri}", f"--resource_config_uri={resource_config_uri}", ] @@ -296,7 +304,7 @@ def _build_job_config( command = command_str.strip().split(" ") job_config = VertexAiJobConfig( - job_name=job_name, + job_name=vertex_ai_job_name, container_uri=container_uri, command=command, args=job_args, @@ -327,6 +335,28 @@ def _build_job_config( return job_config +def _build_vertex_ai_job_name(job_name: str, component: GiGLComponents) -> str: + """Build the Vertex AI CustomJob display name from a raw GiGL job name. + + Args: + job_name: Raw GiGL applied task identifier. + component: The GiGL component being launched. + + Returns: + The component-prefixed Vertex AI CustomJob display name. + + Raises: + ValueError: If ``component`` is not a launchable Vertex AI component. + """ + if component == GiGLComponents.Trainer: + return f"gigl_train_{job_name}" + if component == GiGLComponents.Inferencer: + return f"gigl_infer_{job_name}" + raise ValueError( + f"Invalid component: {component}. Expected one of: {_LAUNCHABLE_COMPONENTS}" + ) + + def _build_common_gigl_env_vars( applied_task_identifier: str, task_config_uri: Uri, diff --git a/gigl/src/inference/v2/glt_inferencer.py b/gigl/src/inference/v2/glt_inferencer.py index c4ff82d7f..d1e5e4f5d 100644 --- a/gigl/src/inference/v2/glt_inferencer.py +++ b/gigl/src/inference/v2/glt_inferencer.py @@ -55,15 +55,12 @@ def __execute_VAI_inference( gbml_config_pb_wrapper.inferencer_config.inferencer_args ) - job_name = f"gigl_infer_{applied_task_identifier}" - if isinstance( resource_config_wrapper.inferencer_config, VertexAiResourceConfig ): launch_single_pool_job( vertex_ai_resource_config=resource_config_wrapper.inferencer_config, - job_name=job_name, - applied_task_identifier=applied_task_identifier, + job_name=applied_task_identifier, task_config_uri=task_config_uri, resource_config_uri=resource_config_uri, process_command=inference_process_command, @@ -79,8 +76,7 @@ def __execute_VAI_inference( ): launch_graph_store_enabled_job( vertex_ai_graph_store_config=resource_config_wrapper.inferencer_config, - job_name=job_name, - applied_task_identifier=applied_task_identifier, + job_name=applied_task_identifier, task_config_uri=task_config_uri, resource_config_uri=resource_config_uri, compute_commmand=inference_process_command, diff --git a/gigl/src/training/v2/glt_trainer.py b/gigl/src/training/v2/glt_trainer.py index d4d5a3d5c..ff2acc5f9 100644 --- a/gigl/src/training/v2/glt_trainer.py +++ b/gigl/src/training/v2/glt_trainer.py @@ -55,13 +55,10 @@ def __execute_VAI_training( gbml_config_pb_wrapper.trainer_config.trainer_args ) - job_name = f"gigl_train_{applied_task_identifier}" - if isinstance(resource_config.trainer_config, VertexAiResourceConfig): launch_single_pool_job( vertex_ai_resource_config=resource_config.trainer_config, - job_name=job_name, - applied_task_identifier=applied_task_identifier, + job_name=applied_task_identifier, task_config_uri=task_config_uri, resource_config_uri=resource_config_uri, process_command=training_process_command, @@ -75,8 +72,7 @@ def __execute_VAI_training( elif isinstance(resource_config.trainer_config, VertexAiGraphStoreConfig): launch_graph_store_enabled_job( vertex_ai_graph_store_config=resource_config.trainer_config, - job_name=job_name, - applied_task_identifier=applied_task_identifier, + job_name=applied_task_identifier, task_config_uri=task_config_uri, resource_config_uri=resource_config_uri, compute_commmand=training_process_command, diff --git a/tests/unit/src/common/vertex_ai_launcher_test.py b/tests/unit/src/common/vertex_ai_launcher_test.py index 91abd5df4..69a0d5967 100644 --- a/tests/unit/src/common/vertex_ai_launcher_test.py +++ b/tests/unit/src/common/vertex_ai_launcher_test.py @@ -162,8 +162,8 @@ class TestVertexAILauncher(TestCase): def test_launch_training_graph_store_cuda(self, mock_vertex_ai_service_class): """Test launching a training job with graph store enabled and GPU/CUDA configuration.""" # Define test inputs - job_name = "test-training-job" - applied_task_identifier = "test-training-task" + job_name = "test-training-task" + expected_vertex_ai_job_name = f"gigl_train_{job_name}" task_config_uri = Uri("gs://bucket/task_config.yaml") resource_config_uri = Uri("gs://bucket/resource_config.yaml") process_command = "python -m gigl.src.training.v2.glt_trainer" @@ -193,7 +193,6 @@ def test_launch_training_graph_store_cuda(self, mock_vertex_ai_service_class): launch_graph_store_enabled_job( vertex_ai_graph_store_config=graph_store_config, job_name=job_name, - applied_task_identifier=applied_task_identifier, task_config_uri=task_config_uri, resource_config_uri=resource_config_uri, compute_commmand=process_command, @@ -227,7 +226,7 @@ def test_launch_training_graph_store_cuda(self, mock_vertex_ai_service_class): self.assertEqual(compute_job_config.machine_type, compute_pool.machine_type) self.assertEqual(compute_job_config.accelerator_type, compute_pool.gpu_type) self.assertEqual(compute_job_config.accelerator_count, compute_pool.gpu_limit) - self.assertEqual(compute_job_config.job_name, job_name) + self.assertEqual(compute_job_config.job_name, expected_vertex_ai_job_name) # Verify compute pool command and args self.assertEqual( @@ -251,6 +250,7 @@ def test_launch_training_graph_store_cuda(self, mock_vertex_ai_service_class): # Verify storage pool config self.assertEqual(storage_job_config.machine_type, storage_pool.machine_type) self.assertEqual(storage_job_config.container_uri, cpu_docker_uri) + self.assertEqual(storage_job_config.job_name, expected_vertex_ai_job_name) self.assertIn( "gigl.distributed.graph_store.storage_main", " ".join(storage_job_config.command), @@ -270,7 +270,7 @@ def test_launch_training_graph_store_cuda(self, mock_vertex_ai_service_class): _assert_common_gigl_env_vars( test_case=self, environment_variables=compute_job_config.environment_variables, - applied_task_identifier=applied_task_identifier, + applied_task_identifier=job_name, task_config_uri=task_config_uri, resource_config_uri=resource_config_uri, cpu_docker_uri=cpu_docker_uri, @@ -280,7 +280,7 @@ def test_launch_training_graph_store_cuda(self, mock_vertex_ai_service_class): _assert_common_gigl_env_vars( test_case=self, environment_variables=storage_job_config.environment_variables, - applied_task_identifier=applied_task_identifier, + applied_task_identifier=job_name, task_config_uri=task_config_uri, resource_config_uri=resource_config_uri, cpu_docker_uri=cpu_docker_uri, @@ -300,8 +300,8 @@ def test_launch_training_graph_store_cuda(self, mock_vertex_ai_service_class): def test_launch_inference_single_pool_cpu(self, mock_vertex_ai_service_class): """Test launching an inference job with single pool and CPU-only configuration.""" # Define test inputs - job_name = "test-inference-job" - applied_task_identifier = "test-inference-task" + job_name = "test-inference-task" + expected_vertex_ai_job_name = f"gigl_infer_{job_name}" task_config_uri = Uri("gs://bucket/inference_config.yaml") resource_config_uri = Uri("gs://bucket/resource_config.yaml") process_command = "python -m gigl.src.inference.v2.glt_inferencer" @@ -335,7 +335,6 @@ def test_launch_inference_single_pool_cpu(self, mock_vertex_ai_service_class): launch_single_pool_job( vertex_ai_resource_config=vertex_ai_config, job_name=job_name, - applied_task_identifier=applied_task_identifier, task_config_uri=task_config_uri, resource_config_uri=resource_config_uri, process_command=process_command, @@ -366,7 +365,7 @@ def test_launch_inference_single_pool_cpu(self, mock_vertex_ai_service_class): # Verify CPU execution settings self.assertEqual(job_config.container_uri, cpu_docker_uri) self.assertEqual(job_config.machine_type, vertex_ai_config.machine_type) - self.assertEqual(job_config.job_name, job_name) + self.assertEqual(job_config.job_name, expected_vertex_ai_job_name) self.assertEqual(job_config.accelerator_count, 0) self.assertEqual(job_config.accelerator_type, "") self.assertEqual(job_config.timeout_s, vertex_ai_config.timeout) @@ -389,7 +388,7 @@ def test_launch_inference_single_pool_cpu(self, mock_vertex_ai_service_class): _assert_common_gigl_env_vars( test_case=self, environment_variables=job_config.environment_variables, - applied_task_identifier=applied_task_identifier, + applied_task_identifier=job_name, task_config_uri=task_config_uri, resource_config_uri=resource_config_uri, cpu_docker_uri=cpu_docker_uri, @@ -410,8 +409,7 @@ def test_launch_single_pool_defaults_docker_uri_env_vars( self, mock_vertex_ai_service_class ): """Test Vertex AI env vars expose default Docker URIs when inputs are unset.""" - job_name = "test-default-image-job" - applied_task_identifier = "test-default-image-task" + job_name = "test-default-image-task" task_config_uri = Uri("gs://bucket/task_config.yaml") resource_config_uri = Uri("gs://bucket/resource_config.yaml") component = GiGLComponents.Trainer @@ -433,7 +431,6 @@ def test_launch_single_pool_defaults_docker_uri_env_vars( launch_single_pool_job( vertex_ai_resource_config=vertex_ai_config, job_name=job_name, - applied_task_identifier=applied_task_identifier, task_config_uri=task_config_uri, resource_config_uri=resource_config_uri, process_command="python -m gigl.src.training.v2.glt_trainer", @@ -449,7 +446,7 @@ def test_launch_single_pool_defaults_docker_uri_env_vars( _assert_common_gigl_env_vars( test_case=self, environment_variables=job_config.environment_variables, - applied_task_identifier=applied_task_identifier, + applied_task_identifier=job_name, task_config_uri=task_config_uri, resource_config_uri=resource_config_uri, cpu_docker_uri=DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU, diff --git a/tests/unit/src/inference/v2/__init__.py b/tests/unit/src/inference/v2/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/src/inference/v2/glt_inferencer_test.py b/tests/unit/src/inference/v2/glt_inferencer_test.py new file mode 100644 index 000000000..2e6b700f2 --- /dev/null +++ b/tests/unit/src/inference/v2/glt_inferencer_test.py @@ -0,0 +1,180 @@ +"""Tests for ``gigl.src.inference.v2.glt_inferencer`` Vertex AI dispatch.""" + +from typing import Final +from unittest.mock import patch + +from absl.testing import absltest + +from gigl.common import Uri +from gigl.src.common.constants.components import GiGLComponents +from gigl.src.common.types import AppliedTaskIdentifier +from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper +from gigl.src.common.types.pb_wrappers.gigl_resource_config import ( + GiglResourceConfigWrapper, +) +from gigl.src.inference.v2.glt_inferencer import GLTInferencer +from snapchat.research.gbml import gbml_config_pb2, gigl_resource_config_pb2 +from tests.test_assets.test_case import TestCase + +_PROCESS_COMMAND: Final[str] = "python -m gigl.src.inference.v2.glt_inferencer" +_STORAGE_COMMAND: Final[str] = "python -m gigl.distributed.graph_store.storage_main" + + +def _build_shared_resource_config() -> gigl_resource_config_pb2.SharedResourceConfig: + return gigl_resource_config_pb2.SharedResourceConfig( + resource_labels={ + "env": "test", + "cost_resource_group_tag": "unittest_COMPONENT", + "cost_resource_group": "gigl_test", + }, + common_compute_config=( + gigl_resource_config_pb2.SharedResourceConfig.CommonComputeConfig( + project="test-project", + region="us-central1", + temp_assets_bucket="gs://test-temp-bucket", + temp_regional_assets_bucket="gs://test-temp-regional-bucket", + perm_assets_bucket="gs://test-perm-bucket", + temp_assets_bq_dataset_name="test_temp_dataset", + embedding_bq_dataset_name="test_embeddings_dataset", + gcp_service_account_email="test-sa@test-project.iam.gserviceaccount.com", + dataflow_runner="DataflowRunner", + ) + ), + ) + + +def _build_resource_config_with_vertex_ai_inferencer() -> ( + gigl_resource_config_pb2.GiglResourceConfig +): + return gigl_resource_config_pb2.GiglResourceConfig( + shared_resource_config=_build_shared_resource_config(), + inferencer_resource_config=gigl_resource_config_pb2.InferencerResourceConfig( + vertex_ai_inferencer_config=gigl_resource_config_pb2.VertexAiResourceConfig( + machine_type="n1-standard-8", + num_replicas=1, + ), + ), + ) + + +def _build_resource_config_with_graph_store_inferencer() -> ( + gigl_resource_config_pb2.GiglResourceConfig +): + return gigl_resource_config_pb2.GiglResourceConfig( + shared_resource_config=_build_shared_resource_config(), + inferencer_resource_config=gigl_resource_config_pb2.InferencerResourceConfig( + vertex_ai_graph_store_inferencer_config=( + gigl_resource_config_pb2.VertexAiGraphStoreConfig( + compute_pool=gigl_resource_config_pb2.VertexAiResourceConfig( + machine_type="n1-standard-8", + num_replicas=1, + ), + graph_store_pool=gigl_resource_config_pb2.VertexAiResourceConfig( + machine_type="n1-highmem-16", + num_replicas=1, + ), + ) + ), + ), + ) + + +def _build_gbml_config_with_inferencer_command() -> gbml_config_pb2.GbmlConfig: + return gbml_config_pb2.GbmlConfig( + inferencer_config=gbml_config_pb2.GbmlConfig.InferencerConfig( + command=_PROCESS_COMMAND, + inferencer_args={"batch_size": "64"}, + graph_store_storage_config=( + gbml_config_pb2.GbmlConfig.GraphStoreStorageConfig( + command=_STORAGE_COMMAND, + storage_args={"dataset_uri": "gs://bucket/dataset"}, + ) + ), + ), + ) + + +class TestGLTInferencerVertexAiDispatch(TestCase): + """Asserts GLT inferencer forwards raw job names to launcher helpers.""" + + @patch("gigl.src.inference.v2.glt_inferencer.launch_single_pool_job") + @patch( + "gigl.src.inference.v2.glt_inferencer.GbmlConfigPbWrapper" + ".get_gbml_config_pb_wrapper_from_uri" + ) + @patch("gigl.src.inference.v2.glt_inferencer.get_resource_config") + def test_single_pool_dispatch_passes_raw_job_name( + self, + mock_get_resource_config, + mock_get_gbml, + mock_launch_single_pool_job, + ) -> None: + mock_get_resource_config.return_value = GiglResourceConfigWrapper( + resource_config=_build_resource_config_with_vertex_ai_inferencer() + ) + mock_get_gbml.return_value = GbmlConfigPbWrapper( + gbml_config_pb=_build_gbml_config_with_inferencer_command() + ) + + task_uri = Uri("gs://bucket/task.yaml") + resource_uri = Uri("gs://bucket/resource.yaml") + GLTInferencer().run( + applied_task_identifier=AppliedTaskIdentifier("job_99"), + task_config_uri=task_uri, + resource_config_uri=resource_uri, + cpu_docker_uri="gcr.io/p/cpu:tag", + cuda_docker_uri="gcr.io/p/cuda:tag", + ) + + mock_launch_single_pool_job.assert_called_once() + call_kwargs = mock_launch_single_pool_job.call_args.kwargs + self.assertEqual(call_kwargs["job_name"], "job_99") + self.assertNotIn("applied_task_identifier", call_kwargs) + self.assertEqual(call_kwargs["component"], GiGLComponents.Inferencer) + self.assertEqual(call_kwargs["task_config_uri"], task_uri) + self.assertEqual(call_kwargs["resource_config_uri"], resource_uri) + self.assertEqual(call_kwargs["process_command"], _PROCESS_COMMAND) + self.assertEqual( + dict(call_kwargs["process_runtime_args"]), + {"batch_size": "64"}, + ) + + @patch("gigl.src.inference.v2.glt_inferencer.launch_graph_store_enabled_job") + @patch( + "gigl.src.inference.v2.glt_inferencer.GbmlConfigPbWrapper" + ".get_gbml_config_pb_wrapper_from_uri" + ) + @patch("gigl.src.inference.v2.glt_inferencer.get_resource_config") + def test_graph_store_dispatch_passes_raw_job_name( + self, + mock_get_resource_config, + mock_get_gbml, + mock_launch_graph_store_enabled_job, + ) -> None: + mock_get_resource_config.return_value = GiglResourceConfigWrapper( + resource_config=_build_resource_config_with_graph_store_inferencer() + ) + mock_get_gbml.return_value = GbmlConfigPbWrapper( + gbml_config_pb=_build_gbml_config_with_inferencer_command() + ) + + GLTInferencer().run( + applied_task_identifier=AppliedTaskIdentifier("job_100"), + task_config_uri=Uri("gs://bucket/task.yaml"), + resource_config_uri=Uri("gs://bucket/resource.yaml"), + ) + + mock_launch_graph_store_enabled_job.assert_called_once() + call_kwargs = mock_launch_graph_store_enabled_job.call_args.kwargs + self.assertEqual(call_kwargs["job_name"], "job_100") + self.assertNotIn("applied_task_identifier", call_kwargs) + self.assertEqual(call_kwargs["component"], GiGLComponents.Inferencer) + self.assertEqual(call_kwargs["storage_command"], _STORAGE_COMMAND) + self.assertEqual( + dict(call_kwargs["storage_args"]), + {"dataset_uri": "gs://bucket/dataset"}, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/unit/src/training/v2/__init__.py b/tests/unit/src/training/v2/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/src/training/v2/glt_trainer_test.py b/tests/unit/src/training/v2/glt_trainer_test.py new file mode 100644 index 000000000..55cb6c513 --- /dev/null +++ b/tests/unit/src/training/v2/glt_trainer_test.py @@ -0,0 +1,180 @@ +"""Tests for ``gigl.src.training.v2.glt_trainer`` Vertex AI dispatch.""" + +from typing import Final +from unittest.mock import patch + +from absl.testing import absltest + +from gigl.common import Uri +from gigl.src.common.constants.components import GiGLComponents +from gigl.src.common.types import AppliedTaskIdentifier +from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper +from gigl.src.common.types.pb_wrappers.gigl_resource_config import ( + GiglResourceConfigWrapper, +) +from gigl.src.training.v2.glt_trainer import GLTTrainer +from snapchat.research.gbml import gbml_config_pb2, gigl_resource_config_pb2 +from tests.test_assets.test_case import TestCase + +_PROCESS_COMMAND: Final[str] = "python -m gigl.src.training.v2.glt_trainer" +_STORAGE_COMMAND: Final[str] = "python -m gigl.distributed.graph_store.storage_main" + + +def _build_shared_resource_config() -> gigl_resource_config_pb2.SharedResourceConfig: + return gigl_resource_config_pb2.SharedResourceConfig( + resource_labels={ + "env": "test", + "cost_resource_group_tag": "unittest_COMPONENT", + "cost_resource_group": "gigl_test", + }, + common_compute_config=( + gigl_resource_config_pb2.SharedResourceConfig.CommonComputeConfig( + project="test-project", + region="us-central1", + temp_assets_bucket="gs://test-temp-bucket", + temp_regional_assets_bucket="gs://test-temp-regional-bucket", + perm_assets_bucket="gs://test-perm-bucket", + temp_assets_bq_dataset_name="test_temp_dataset", + embedding_bq_dataset_name="test_embeddings_dataset", + gcp_service_account_email="test-sa@test-project.iam.gserviceaccount.com", + dataflow_runner="DataflowRunner", + ) + ), + ) + + +def _build_resource_config_with_vertex_ai_trainer() -> ( + gigl_resource_config_pb2.GiglResourceConfig +): + return gigl_resource_config_pb2.GiglResourceConfig( + shared_resource_config=_build_shared_resource_config(), + trainer_resource_config=gigl_resource_config_pb2.TrainerResourceConfig( + vertex_ai_trainer_config=gigl_resource_config_pb2.VertexAiResourceConfig( + machine_type="n1-standard-8", + num_replicas=1, + ), + ), + ) + + +def _build_resource_config_with_graph_store_trainer() -> ( + gigl_resource_config_pb2.GiglResourceConfig +): + return gigl_resource_config_pb2.GiglResourceConfig( + shared_resource_config=_build_shared_resource_config(), + trainer_resource_config=gigl_resource_config_pb2.TrainerResourceConfig( + vertex_ai_graph_store_trainer_config=( + gigl_resource_config_pb2.VertexAiGraphStoreConfig( + compute_pool=gigl_resource_config_pb2.VertexAiResourceConfig( + machine_type="n1-standard-8", + num_replicas=1, + ), + graph_store_pool=gigl_resource_config_pb2.VertexAiResourceConfig( + machine_type="n1-highmem-16", + num_replicas=1, + ), + ) + ), + ), + ) + + +def _build_gbml_config_with_trainer_command() -> gbml_config_pb2.GbmlConfig: + return gbml_config_pb2.GbmlConfig( + trainer_config=gbml_config_pb2.GbmlConfig.TrainerConfig( + command=_PROCESS_COMMAND, + trainer_args={"lr": "0.01", "epochs": "5"}, + graph_store_storage_config=( + gbml_config_pb2.GbmlConfig.GraphStoreStorageConfig( + command=_STORAGE_COMMAND, + storage_args={"dataset_uri": "gs://bucket/dataset"}, + ) + ), + ), + ) + + +class TestGLTTrainerVertexAiDispatch(TestCase): + """Asserts GLT trainer forwards raw job names to launcher helpers.""" + + @patch("gigl.src.training.v2.glt_trainer.launch_single_pool_job") + @patch( + "gigl.src.training.v2.glt_trainer.GbmlConfigPbWrapper" + ".get_gbml_config_pb_wrapper_from_uri" + ) + @patch("gigl.src.training.v2.glt_trainer.get_resource_config") + def test_single_pool_dispatch_passes_raw_job_name( + self, + mock_get_resource_config, + mock_get_gbml, + mock_launch_single_pool_job, + ) -> None: + mock_get_resource_config.return_value = GiglResourceConfigWrapper( + resource_config=_build_resource_config_with_vertex_ai_trainer() + ) + mock_get_gbml.return_value = GbmlConfigPbWrapper( + gbml_config_pb=_build_gbml_config_with_trainer_command() + ) + + task_uri = Uri("gs://bucket/task.yaml") + resource_uri = Uri("gs://bucket/resource.yaml") + GLTTrainer().run( + applied_task_identifier=AppliedTaskIdentifier("job_77"), + task_config_uri=task_uri, + resource_config_uri=resource_uri, + cpu_docker_uri="gcr.io/p/cpu:tag", + cuda_docker_uri="gcr.io/p/cuda:tag", + ) + + mock_launch_single_pool_job.assert_called_once() + call_kwargs = mock_launch_single_pool_job.call_args.kwargs + self.assertEqual(call_kwargs["job_name"], "job_77") + self.assertNotIn("applied_task_identifier", call_kwargs) + self.assertEqual(call_kwargs["component"], GiGLComponents.Trainer) + self.assertEqual(call_kwargs["task_config_uri"], task_uri) + self.assertEqual(call_kwargs["resource_config_uri"], resource_uri) + self.assertEqual(call_kwargs["process_command"], _PROCESS_COMMAND) + self.assertEqual( + dict(call_kwargs["process_runtime_args"]), + {"lr": "0.01", "epochs": "5"}, + ) + + @patch("gigl.src.training.v2.glt_trainer.launch_graph_store_enabled_job") + @patch( + "gigl.src.training.v2.glt_trainer.GbmlConfigPbWrapper" + ".get_gbml_config_pb_wrapper_from_uri" + ) + @patch("gigl.src.training.v2.glt_trainer.get_resource_config") + def test_graph_store_dispatch_passes_raw_job_name( + self, + mock_get_resource_config, + mock_get_gbml, + mock_launch_graph_store_enabled_job, + ) -> None: + mock_get_resource_config.return_value = GiglResourceConfigWrapper( + resource_config=_build_resource_config_with_graph_store_trainer() + ) + mock_get_gbml.return_value = GbmlConfigPbWrapper( + gbml_config_pb=_build_gbml_config_with_trainer_command() + ) + + GLTTrainer().run( + applied_task_identifier=AppliedTaskIdentifier("job_88"), + task_config_uri=Uri("gs://bucket/task.yaml"), + resource_config_uri=Uri("gs://bucket/resource.yaml"), + ) + + mock_launch_graph_store_enabled_job.assert_called_once() + call_kwargs = mock_launch_graph_store_enabled_job.call_args.kwargs + self.assertEqual(call_kwargs["job_name"], "job_88") + self.assertNotIn("applied_task_identifier", call_kwargs) + self.assertEqual(call_kwargs["component"], GiGLComponents.Trainer) + self.assertEqual(call_kwargs["storage_command"], _STORAGE_COMMAND) + self.assertEqual( + dict(call_kwargs["storage_args"]), + {"dataset_uri": "gs://bucket/dataset"}, + ) + + +if __name__ == "__main__": + absltest.main()