From 74d8df17b7e8f8301a45cbd3957e45bc06dece95 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 7 Oct 2025 15:30:13 +0000 Subject: [PATCH 1/8] Add utils to parse VAI CLUSTER_SPEC --- python/gigl/common/utils/vertex_ai_context.py | 106 +++++++++++- python/gigl/env/distributed.py | 21 +++ .../common/utils/vertex_ai_context_test.py | 159 +++++++++++++++++- 3 files changed, 281 insertions(+), 5 deletions(-) create mode 100644 python/gigl/env/distributed.py diff --git a/python/gigl/common/utils/vertex_ai_context.py b/python/gigl/common/utils/vertex_ai_context.py index dfdf569c1..d155e29ac 100644 --- a/python/gigl/common/utils/vertex_ai_context.py +++ b/python/gigl/common/utils/vertex_ai_context.py @@ -1,14 +1,19 @@ """Utility functions to be used by machines running on Vertex AI.""" +import json import os import subprocess import time +from dataclasses import dataclass +from typing import Optional + +from google.cloud.aiplatform_v1.types import CustomJobSpec from gigl.common import GcsUri from gigl.common.logger import Logger from gigl.common.services.vertex_ai import LEADER_WORKER_INTERNAL_IP_FILE_PATH_ENV_KEY from gigl.common.utils.gcs import GcsUtils -from gigl.distributed import DistributedContext +from gigl.env.distributed import DistributedContext logger = Logger() @@ -156,6 +161,105 @@ def connect_worker_pool() -> DistributedContext: ) +def get_num_storage_and_compute_nodes() -> tuple[int, int]: + """ + Returns the number of storage and compute nodes for a Vertex AI job. + + Raises: + ValueError: If not running in a Vertex AI job. + """ + if not is_currently_running_in_vertex_ai_job(): + raise ValueError("Not running in a Vertex AI job.") + + cluster_spec = _parse_cluster_spec() + if len(cluster_spec.cluster) != 4: + raise ValueError( + f"Cluster specification must have 4 worker pools to fetch the number of storage and compute nodes. Found {len(cluster_spec.cluster)} worker pools." + ) + num_storage_nodes = len(cluster_spec.cluster["workerpool0"]) + len( + cluster_spec.cluster["workerpool1"] + ) + num_compute_nodes = len(cluster_spec.cluster["workerpool3"]) + + return num_storage_nodes, num_compute_nodes + + +@dataclass +class TaskInfo: + """Information about the current task running on this node.""" + + type: str # The type of worker pool this task is running in (e.g., "workerpool0") + index: int # The zero-based index of the task + trial: Optional[ + str + ] = None # Hyperparameter tuning trial identifier (if applicable) + + +@dataclass +class ClusterSpec: + """Represents the cluster specification for a Vertex AI custom job.""" + + cluster: dict[str, list[str]] # Worker pool names mapped to their replica lists + environment: str # The environment string (e.g., "cloud") + task: TaskInfo # Information about the current task + job: Optional[CustomJobSpec] = None # The CustomJobSpec for the current job + + +def _parse_cluster_spec() -> ClusterSpec: + """ + Parse the cluster specification from the CLUSTER_SPEC environment variable. + Based on the spec given at: + https://cloud.google.com/vertex-ai/docs/training/distributed-training#cluster-variables + + Returns: + ClusterSpec: Parsed cluster specification data. + + Raises: + ValueError: If not running in a Vertex AI job or CLUSTER_SPEC is not found. + json.JSONDecodeError: If CLUSTER_SPEC contains invalid JSON. + """ + if not is_currently_running_in_vertex_ai_job(): + raise ValueError("Not running in a Vertex AI job.") + + cluster_spec_json = os.environ.get("CLUSTER_SPEC") + if not cluster_spec_json: + raise ValueError("CLUSTER_SPEC not found in environment variables.") + + try: + cluster_spec_data = json.loads(cluster_spec_json) + except json.JSONDecodeError as e: + raise json.JSONDecodeError( + f"Failed to parse CLUSTER_SPEC JSON: {e.msg}", e.doc, e.pos + ) + + # Parse the task information + task_data = cluster_spec_data.get("task", {}) + task_info = TaskInfo( + type=task_data.get("type", ""), + index=task_data.get("index", 0), + trial=task_data.get("trial"), + ) + + # Parse the cluster specification + cluster_data = cluster_spec_data.get("cluster", {}) + + # Parse the environment + environment = cluster_spec_data.get("environment", "cloud") + + # Parse the job specification (optional) + job_data = cluster_spec_data.get("job") + job_spec = None + if job_data: + # Convert the dictionary to CustomJobSpec + # Note: This assumes the job_data is already in the correct format + # You may need to adjust this based on the actual structure + job_spec = CustomJobSpec(**job_data) + + return ClusterSpec( + cluster=cluster_data, environment=environment, task=task_info, job=job_spec + ) + + def _get_leader_worker_internal_ip_file_path() -> str: """ Get the file path to the leader worker's internal IP address. diff --git a/python/gigl/env/distributed.py b/python/gigl/env/distributed.py new file mode 100644 index 000000000..84466dde4 --- /dev/null +++ b/python/gigl/env/distributed.py @@ -0,0 +1,21 @@ +"""Information about distributed environments.""" + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class DistributedContext: + """ + GiGL Distributed Context + """ + + # TODO (mkolodner-sc): Investigate adding local rank and local world size + + # Main Worker's IP Address for RPC communication + main_worker_ip_address: str + + # Rank of machine + global_rank: int + + # Total number of machines + global_world_size: int diff --git a/python/tests/unit/common/utils/vertex_ai_context_test.py b/python/tests/unit/common/utils/vertex_ai_context_test.py index d3a5132b2..7c4649e14 100644 --- a/python/tests/unit/common/utils/vertex_ai_context_test.py +++ b/python/tests/unit/common/utils/vertex_ai_context_test.py @@ -1,3 +1,4 @@ +import json import os import unittest from unittest.mock import call, patch @@ -5,17 +6,17 @@ from gigl.common import GcsUri from gigl.common.services.vertex_ai import LEADER_WORKER_INTERNAL_IP_FILE_PATH_ENV_KEY from gigl.common.utils.vertex_ai_context import ( - DistributedContext, + _parse_cluster_spec, connect_worker_pool, get_host_name, get_leader_hostname, get_leader_port, + get_num_storage_and_compute_nodes, get_rank, get_vertex_ai_job_id, get_world_size, is_currently_running_in_vertex_ai_job, ) -from gigl.distributed import DistributedContext class TestVertexAIContext(unittest.TestCase): @@ -76,7 +77,7 @@ def test_throws_if_not_on_vai(self): }, ) def test_connect_worker_pool_leader(self, mock_upload, mock_sleep, mock_subprocess): - distributed_context: DistributedContext = connect_worker_pool() + distributed_context = connect_worker_pool() self.assertEqual(distributed_context.main_worker_ip_address, "127.0.0.1") self.assertEqual(distributed_context.global_rank, 0) self.assertEqual(distributed_context.global_world_size, 2) @@ -102,7 +103,7 @@ def test_connect_worker_pool_worker( self, mock_upload, mock_read, mock_sleep, mock_subprocess, mock_ping_host ): mock_ping_host.side_effect = [False, True] - distributed_context: DistributedContext = connect_worker_pool() + distributed_context = connect_worker_pool() self.assertEqual(distributed_context.main_worker_ip_address, "127.0.0.1") self.assertEqual(distributed_context.global_rank, 1) self.assertEqual(distributed_context.global_world_size, 2) @@ -113,6 +114,156 @@ def test_connect_worker_pool_worker( ] ) + def test_get_num_storage_and_compute_nodes_success(self): + """Test successful retrieval of storage and compute node counts.""" + cluster_spec_json = json.dumps( + { + "cluster": { + "workerpool0": ["replica-0", "replica-1"], + "workerpool1": ["replica-0"], + "workerpool2": ["replica-0"], + "workerpool3": ["replica-0", "replica-1", "replica-2"], + }, + "task": {"type": "workerpool0", "index": 0}, + "environment": "cloud", + } + ) + + with patch.dict( + os.environ, self.VAI_JOB_ENV | {"CLUSTER_SPEC": cluster_spec_json} + ): + num_storage, num_compute = get_num_storage_and_compute_nodes() + self.assertEqual(num_storage, 3) # workerpool0 (2) + workerpool1 (1) + self.assertEqual(num_compute, 3) # workerpool3 (3) + + def test_get_num_storage_and_compute_nodes_not_on_vai(self): + """Test that function raises ValueError when not running in Vertex AI.""" + with self.assertRaises(ValueError) as context: + get_num_storage_and_compute_nodes() + self.assertIn("Not running in a Vertex AI job", str(context.exception)) + + def test_get_num_storage_and_compute_nodes_invalid_worker_pools(self): + """Test that function raises ValueError when cluster doesn't have 4 worker pools.""" + cluster_spec_json = json.dumps( + { + "cluster": {"workerpool0": ["replica-0"], "workerpool1": ["replica-0"]}, + "task": {"type": "workerpool0", "index": 0}, + "environment": "cloud", + } + ) + + with patch.dict( + os.environ, self.VAI_JOB_ENV | {"CLUSTER_SPEC": cluster_spec_json} + ): + with self.assertRaises(ValueError) as context: + get_num_storage_and_compute_nodes() + self.assertIn( + "Cluster specification must have 4 worker pools", str(context.exception) + ) + self.assertIn("Found 2 worker pools", str(context.exception)) + + def test_parse_cluster_spec_success(self): + """Test successful parsing of cluster specification.""" + cluster_spec_json = json.dumps( + { + "cluster": { + "workerpool0": ["replica-0", "replica-1"], + "workerpool1": ["replica-0"], + "workerpool2": ["replica-0"], + "workerpool3": ["replica-0", "replica-1"], + }, + "task": {"type": "workerpool0", "index": 1, "trial": "trial-123"}, + "environment": "cloud", + "job": { + "worker_pool_specs": [ + {"machine_spec": {"machine_type": "n1-standard-4"}} + ] + }, + } + ) + + with patch.dict( + os.environ, self.VAI_JOB_ENV | {"CLUSTER_SPEC": cluster_spec_json} + ): + cluster_spec = _parse_cluster_spec() + + # Test cluster data + self.assertEqual(len(cluster_spec.cluster), 4) + self.assertEqual( + cluster_spec.cluster["workerpool0"], ["replica-0", "replica-1"] + ) + self.assertEqual(cluster_spec.cluster["workerpool1"], ["replica-0"]) + + # Test task info + self.assertEqual(cluster_spec.task.type, "workerpool0") + self.assertEqual(cluster_spec.task.index, 1) + self.assertEqual(cluster_spec.task.trial, "trial-123") + + # Test environment + self.assertEqual(cluster_spec.environment, "cloud") + + # Test job spec + self.assertIsNotNone(cluster_spec.job) + + def test_parse_cluster_spec_minimal(self): + """Test parsing of minimal cluster specification without optional fields.""" + cluster_spec_json = json.dumps( + { + "cluster": { + "workerpool0": ["replica-0"], + "workerpool1": ["replica-0"], + "workerpool2": ["replica-0"], + "workerpool3": ["replica-0"], + }, + "task": {"type": "workerpool0", "index": 0}, + "environment": "cloud", + } + ) + + with patch.dict( + os.environ, self.VAI_JOB_ENV | {"CLUSTER_SPEC": cluster_spec_json} + ): + cluster_spec = _parse_cluster_spec() + + # Test cluster data + self.assertEqual(len(cluster_spec.cluster), 4) + + # Test task info with defaults + self.assertEqual(cluster_spec.task.type, "workerpool0") + self.assertEqual(cluster_spec.task.index, 0) + self.assertIsNone(cluster_spec.task.trial) + + # Test environment + self.assertEqual(cluster_spec.environment, "cloud") + + # Test job spec (should be None when not provided) + self.assertIsNone(cluster_spec.job) + + def test_parse_cluster_spec_not_on_vai(self): + """Test that function raises ValueError when not running in Vertex AI.""" + with self.assertRaises(ValueError) as context: + _parse_cluster_spec() + self.assertIn("Not running in a Vertex AI job", str(context.exception)) + + def test_parse_cluster_spec_missing_cluster_spec(self): + """Test that function raises ValueError when CLUSTER_SPEC is missing.""" + with patch.dict(os.environ, self.VAI_JOB_ENV): + with self.assertRaises(ValueError) as context: + _parse_cluster_spec() + self.assertIn( + "CLUSTER_SPEC not found in environment variables", + str(context.exception), + ) + + def test_parse_cluster_spec_invalid_json(self): + """Test that function raises JSONDecodeError for invalid JSON.""" + with patch.dict( + os.environ, self.VAI_JOB_ENV | {"CLUSTER_SPEC": "invalid json"} + ): + with self.assertRaises(json.JSONDecodeError) as context: + _parse_cluster_spec() + self.assertIn("Failed to parse CLUSTER_SPEC JSON", str(context.exception)) + if __name__ == "__main__": unittest.main() From de1de6ac470a05634cb65ec5af439c936566b69d Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 7 Oct 2025 15:33:35 +0000 Subject: [PATCH 2/8] comments --- python/gigl/common/utils/vertex_ai_context.py | 38 +++--------- .../common/utils/vertex_ai_context_test.py | 61 ++----------------- 2 files changed, 15 insertions(+), 84 deletions(-) diff --git a/python/gigl/common/utils/vertex_ai_context.py b/python/gigl/common/utils/vertex_ai_context.py index d155e29ac..f8572906f 100644 --- a/python/gigl/common/utils/vertex_ai_context.py +++ b/python/gigl/common/utils/vertex_ai_context.py @@ -161,29 +161,6 @@ def connect_worker_pool() -> DistributedContext: ) -def get_num_storage_and_compute_nodes() -> tuple[int, int]: - """ - Returns the number of storage and compute nodes for a Vertex AI job. - - Raises: - ValueError: If not running in a Vertex AI job. - """ - if not is_currently_running_in_vertex_ai_job(): - raise ValueError("Not running in a Vertex AI job.") - - cluster_spec = _parse_cluster_spec() - if len(cluster_spec.cluster) != 4: - raise ValueError( - f"Cluster specification must have 4 worker pools to fetch the number of storage and compute nodes. Found {len(cluster_spec.cluster)} worker pools." - ) - num_storage_nodes = len(cluster_spec.cluster["workerpool0"]) + len( - cluster_spec.cluster["workerpool1"] - ) - num_compute_nodes = len(cluster_spec.cluster["workerpool3"]) - - return num_storage_nodes, num_compute_nodes - - @dataclass class TaskInfo: """Information about the current task running on this node.""" @@ -197,15 +174,21 @@ class TaskInfo: @dataclass class ClusterSpec: - """Represents the cluster specification for a Vertex AI custom job.""" + """Represents the cluster specification for a Vertex AI custom job. + See the docs for more info: + https://cloud.google.com/vertex-ai/docs/training/distributed-training#cluster-variables + """ cluster: dict[str, list[str]] # Worker pool names mapped to their replica lists environment: str # The environment string (e.g., "cloud") task: TaskInfo # Information about the current task - job: Optional[CustomJobSpec] = None # The CustomJobSpec for the current job + # The CustomJobSpec for the current job + # See the docs for more info: + # https://cloud.google.com/vertex-ai/docs/reference/rest/v1/CustomJobSpec + job: Optional[CustomJobSpec] = None -def _parse_cluster_spec() -> ClusterSpec: +def parse_cluster_spec() -> ClusterSpec: """ Parse the cluster specification from the CLUSTER_SPEC environment variable. Based on the spec given at: @@ -250,9 +233,6 @@ def _parse_cluster_spec() -> ClusterSpec: job_data = cluster_spec_data.get("job") job_spec = None if job_data: - # Convert the dictionary to CustomJobSpec - # Note: This assumes the job_data is already in the correct format - # You may need to adjust this based on the actual structure job_spec = CustomJobSpec(**job_data) return ClusterSpec( diff --git a/python/tests/unit/common/utils/vertex_ai_context_test.py b/python/tests/unit/common/utils/vertex_ai_context_test.py index 7c4649e14..bcd3a4ca4 100644 --- a/python/tests/unit/common/utils/vertex_ai_context_test.py +++ b/python/tests/unit/common/utils/vertex_ai_context_test.py @@ -6,16 +6,15 @@ from gigl.common import GcsUri from gigl.common.services.vertex_ai import LEADER_WORKER_INTERNAL_IP_FILE_PATH_ENV_KEY from gigl.common.utils.vertex_ai_context import ( - _parse_cluster_spec, connect_worker_pool, get_host_name, get_leader_hostname, get_leader_port, - get_num_storage_and_compute_nodes, get_rank, get_vertex_ai_job_id, get_world_size, is_currently_running_in_vertex_ai_job, + parse_cluster_spec, ) @@ -114,54 +113,6 @@ def test_connect_worker_pool_worker( ] ) - def test_get_num_storage_and_compute_nodes_success(self): - """Test successful retrieval of storage and compute node counts.""" - cluster_spec_json = json.dumps( - { - "cluster": { - "workerpool0": ["replica-0", "replica-1"], - "workerpool1": ["replica-0"], - "workerpool2": ["replica-0"], - "workerpool3": ["replica-0", "replica-1", "replica-2"], - }, - "task": {"type": "workerpool0", "index": 0}, - "environment": "cloud", - } - ) - - with patch.dict( - os.environ, self.VAI_JOB_ENV | {"CLUSTER_SPEC": cluster_spec_json} - ): - num_storage, num_compute = get_num_storage_and_compute_nodes() - self.assertEqual(num_storage, 3) # workerpool0 (2) + workerpool1 (1) - self.assertEqual(num_compute, 3) # workerpool3 (3) - - def test_get_num_storage_and_compute_nodes_not_on_vai(self): - """Test that function raises ValueError when not running in Vertex AI.""" - with self.assertRaises(ValueError) as context: - get_num_storage_and_compute_nodes() - self.assertIn("Not running in a Vertex AI job", str(context.exception)) - - def test_get_num_storage_and_compute_nodes_invalid_worker_pools(self): - """Test that function raises ValueError when cluster doesn't have 4 worker pools.""" - cluster_spec_json = json.dumps( - { - "cluster": {"workerpool0": ["replica-0"], "workerpool1": ["replica-0"]}, - "task": {"type": "workerpool0", "index": 0}, - "environment": "cloud", - } - ) - - with patch.dict( - os.environ, self.VAI_JOB_ENV | {"CLUSTER_SPEC": cluster_spec_json} - ): - with self.assertRaises(ValueError) as context: - get_num_storage_and_compute_nodes() - self.assertIn( - "Cluster specification must have 4 worker pools", str(context.exception) - ) - self.assertIn("Found 2 worker pools", str(context.exception)) - def test_parse_cluster_spec_success(self): """Test successful parsing of cluster specification.""" cluster_spec_json = json.dumps( @@ -185,7 +136,7 @@ def test_parse_cluster_spec_success(self): with patch.dict( os.environ, self.VAI_JOB_ENV | {"CLUSTER_SPEC": cluster_spec_json} ): - cluster_spec = _parse_cluster_spec() + cluster_spec = parse_cluster_spec() # Test cluster data self.assertEqual(len(cluster_spec.cluster), 4) @@ -223,7 +174,7 @@ def test_parse_cluster_spec_minimal(self): with patch.dict( os.environ, self.VAI_JOB_ENV | {"CLUSTER_SPEC": cluster_spec_json} ): - cluster_spec = _parse_cluster_spec() + cluster_spec = parse_cluster_spec() # Test cluster data self.assertEqual(len(cluster_spec.cluster), 4) @@ -242,14 +193,14 @@ def test_parse_cluster_spec_minimal(self): def test_parse_cluster_spec_not_on_vai(self): """Test that function raises ValueError when not running in Vertex AI.""" with self.assertRaises(ValueError) as context: - _parse_cluster_spec() + parse_cluster_spec() self.assertIn("Not running in a Vertex AI job", str(context.exception)) def test_parse_cluster_spec_missing_cluster_spec(self): """Test that function raises ValueError when CLUSTER_SPEC is missing.""" with patch.dict(os.environ, self.VAI_JOB_ENV): with self.assertRaises(ValueError) as context: - _parse_cluster_spec() + parse_cluster_spec() self.assertIn( "CLUSTER_SPEC not found in environment variables", str(context.exception), @@ -261,7 +212,7 @@ def test_parse_cluster_spec_invalid_json(self): os.environ, self.VAI_JOB_ENV | {"CLUSTER_SPEC": "invalid json"} ): with self.assertRaises(json.JSONDecodeError) as context: - _parse_cluster_spec() + parse_cluster_spec() self.assertIn("Failed to parse CLUSTER_SPEC JSON", str(context.exception)) From fc0dca428596d4c8843aa893ca7c38a34d475c52 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 7 Oct 2025 15:34:23 +0000 Subject: [PATCH 3/8] rename --- python/gigl/common/utils/vertex_ai_context.py | 2 +- .../unit/common/utils/vertex_ai_context_test.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/python/gigl/common/utils/vertex_ai_context.py b/python/gigl/common/utils/vertex_ai_context.py index f8572906f..0f3f032bc 100644 --- a/python/gigl/common/utils/vertex_ai_context.py +++ b/python/gigl/common/utils/vertex_ai_context.py @@ -188,7 +188,7 @@ class ClusterSpec: job: Optional[CustomJobSpec] = None -def parse_cluster_spec() -> ClusterSpec: +def get_cluster_spec() -> ClusterSpec: """ Parse the cluster specification from the CLUSTER_SPEC environment variable. Based on the spec given at: diff --git a/python/tests/unit/common/utils/vertex_ai_context_test.py b/python/tests/unit/common/utils/vertex_ai_context_test.py index bcd3a4ca4..8234725f6 100644 --- a/python/tests/unit/common/utils/vertex_ai_context_test.py +++ b/python/tests/unit/common/utils/vertex_ai_context_test.py @@ -14,7 +14,7 @@ get_vertex_ai_job_id, get_world_size, is_currently_running_in_vertex_ai_job, - parse_cluster_spec, + get_cluster_spec, ) @@ -136,7 +136,7 @@ def test_parse_cluster_spec_success(self): with patch.dict( os.environ, self.VAI_JOB_ENV | {"CLUSTER_SPEC": cluster_spec_json} ): - cluster_spec = parse_cluster_spec() + cluster_spec = get_cluster_spec() # Test cluster data self.assertEqual(len(cluster_spec.cluster), 4) @@ -174,7 +174,7 @@ def test_parse_cluster_spec_minimal(self): with patch.dict( os.environ, self.VAI_JOB_ENV | {"CLUSTER_SPEC": cluster_spec_json} ): - cluster_spec = parse_cluster_spec() + cluster_spec = get_cluster_spec() # Test cluster data self.assertEqual(len(cluster_spec.cluster), 4) @@ -193,14 +193,14 @@ def test_parse_cluster_spec_minimal(self): def test_parse_cluster_spec_not_on_vai(self): """Test that function raises ValueError when not running in Vertex AI.""" with self.assertRaises(ValueError) as context: - parse_cluster_spec() + get_cluster_spec() self.assertIn("Not running in a Vertex AI job", str(context.exception)) def test_parse_cluster_spec_missing_cluster_spec(self): """Test that function raises ValueError when CLUSTER_SPEC is missing.""" with patch.dict(os.environ, self.VAI_JOB_ENV): with self.assertRaises(ValueError) as context: - parse_cluster_spec() + get_cluster_spec() self.assertIn( "CLUSTER_SPEC not found in environment variables", str(context.exception), @@ -212,7 +212,7 @@ def test_parse_cluster_spec_invalid_json(self): os.environ, self.VAI_JOB_ENV | {"CLUSTER_SPEC": "invalid json"} ): with self.assertRaises(json.JSONDecodeError) as context: - parse_cluster_spec() + get_cluster_spec() self.assertIn("Failed to parse CLUSTER_SPEC JSON", str(context.exception)) From d3319d61c7148ab467f19c94063bbf5bbec2e513 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 7 Oct 2025 15:40:46 +0000 Subject: [PATCH 4/8] fixes --- python/gigl/common/utils/vertex_ai_context.py | 4 ++-- python/gigl/distributed/dist_context.py | 22 +++++-------------- .../common/utils/vertex_ai_context_test.py | 2 +- 3 files changed, 9 insertions(+), 19 deletions(-) diff --git a/python/gigl/common/utils/vertex_ai_context.py b/python/gigl/common/utils/vertex_ai_context.py index 0f3f032bc..c21571091 100644 --- a/python/gigl/common/utils/vertex_ai_context.py +++ b/python/gigl/common/utils/vertex_ai_context.py @@ -161,7 +161,7 @@ def connect_worker_pool() -> DistributedContext: ) -@dataclass +@dataclass(frozen=True) class TaskInfo: """Information about the current task running on this node.""" @@ -172,7 +172,7 @@ class TaskInfo: ] = None # Hyperparameter tuning trial identifier (if applicable) -@dataclass +@dataclass(frozen=True) class ClusterSpec: """Represents the cluster specification for a Vertex AI custom job. See the docs for more info: diff --git a/python/gigl/distributed/dist_context.py b/python/gigl/distributed/dist_context.py index da513ab87..0f222b956 100644 --- a/python/gigl/distributed/dist_context.py +++ b/python/gigl/distributed/dist_context.py @@ -1,19 +1,9 @@ -from dataclasses import dataclass +from gigl.env.distributed import DistributedContext +# TODO (mkolodner-sc): Deprecate this file. +__all__ = [ + "DeprecatedDistributedContext", +] -@dataclass(frozen=True) -class DistributedContext: - """ - GiGL Distributed Context - """ - # TODO (mkolodner-sc): Investigate adding local rank and local world size - - # Main Worker's IP Address for RPC communication - main_worker_ip_address: str - - # Rank of machine - global_rank: int - - # Total number of machines - global_world_size: int +DeprecatedDistributedContext = DistributedContext diff --git a/python/tests/unit/common/utils/vertex_ai_context_test.py b/python/tests/unit/common/utils/vertex_ai_context_test.py index 8234725f6..ee2f39ecd 100644 --- a/python/tests/unit/common/utils/vertex_ai_context_test.py +++ b/python/tests/unit/common/utils/vertex_ai_context_test.py @@ -7,6 +7,7 @@ from gigl.common.services.vertex_ai import LEADER_WORKER_INTERNAL_IP_FILE_PATH_ENV_KEY from gigl.common.utils.vertex_ai_context import ( connect_worker_pool, + get_cluster_spec, get_host_name, get_leader_hostname, get_leader_port, @@ -14,7 +15,6 @@ get_vertex_ai_job_id, get_world_size, is_currently_running_in_vertex_ai_job, - get_cluster_spec, ) From 090566427621d031021444c28fb61300560a3167 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 7 Oct 2025 15:42:28 +0000 Subject: [PATCH 5/8] fixes --- python/gigl/distributed/dist_context.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/python/gigl/distributed/dist_context.py b/python/gigl/distributed/dist_context.py index 0f222b956..1078bee5a 100644 --- a/python/gigl/distributed/dist_context.py +++ b/python/gigl/distributed/dist_context.py @@ -2,8 +2,5 @@ # TODO (mkolodner-sc): Deprecate this file. __all__ = [ - "DeprecatedDistributedContext", + "DistributedContext", ] - - -DeprecatedDistributedContext = DistributedContext From 112d0ad605ba62b10dc541b70705e3c0d716f9a0 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 7 Oct 2025 15:44:27 +0000 Subject: [PATCH 6/8] fix --- .../common/utils/vertex_ai_context_test.py | 24 +++++++++++-------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/python/tests/unit/common/utils/vertex_ai_context_test.py b/python/tests/unit/common/utils/vertex_ai_context_test.py index ee2f39ecd..28e98b4e6 100644 --- a/python/tests/unit/common/utils/vertex_ai_context_test.py +++ b/python/tests/unit/common/utils/vertex_ai_context_test.py @@ -118,10 +118,10 @@ def test_parse_cluster_spec_success(self): cluster_spec_json = json.dumps( { "cluster": { - "workerpool0": ["replica-0", "replica-1"], - "workerpool1": ["replica-0"], - "workerpool2": ["replica-0"], - "workerpool3": ["replica-0", "replica-1"], + "workerpool0": ["replica0-0", "replica0-1"], + "workerpool1": ["replica1-0"], + "workerpool2": ["replica2-0"], + "workerpool3": ["replica3-0", "replica3-1"], }, "task": {"type": "workerpool0", "index": 1, "trial": "trial-123"}, "environment": "cloud", @@ -141,9 +141,13 @@ def test_parse_cluster_spec_success(self): # Test cluster data self.assertEqual(len(cluster_spec.cluster), 4) self.assertEqual( - cluster_spec.cluster["workerpool0"], ["replica-0", "replica-1"] + cluster_spec.cluster["workerpool0"], ["replica0-0", "replica0-1"] + ) + self.assertEqual(cluster_spec.cluster["workerpool1"], ["replica1-0"]) + self.assertEqual(cluster_spec.cluster["workerpool2"], ["replica2-0"]) + self.assertEqual( + cluster_spec.cluster["workerpool3"], ["replica3-0", "replica3-1"] ) - self.assertEqual(cluster_spec.cluster["workerpool1"], ["replica-0"]) # Test task info self.assertEqual(cluster_spec.task.type, "workerpool0") @@ -161,10 +165,10 @@ def test_parse_cluster_spec_minimal(self): cluster_spec_json = json.dumps( { "cluster": { - "workerpool0": ["replica-0"], - "workerpool1": ["replica-0"], - "workerpool2": ["replica-0"], - "workerpool3": ["replica-0"], + "workerpool0": ["replica0-0"], + "workerpool1": ["replica1-0"], + "workerpool2": ["replica2-0"], + "workerpool3": ["replica3-0"], }, "task": {"type": "workerpool0", "index": 0}, "environment": "cloud", From f621bc7f120fca86d4760856d7c340d8b726d43e Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Wed, 8 Oct 2025 16:13:58 +0000 Subject: [PATCH 7/8] address comments --- python/gigl/common/utils/vertex_ai_context.py | 56 ++++++-------- .../common/utils/vertex_ai_context_test.py | 75 +++++++++---------- 2 files changed, 59 insertions(+), 72 deletions(-) diff --git a/python/gigl/common/utils/vertex_ai_context.py b/python/gigl/common/utils/vertex_ai_context.py index c21571091..667044435 100644 --- a/python/gigl/common/utils/vertex_ai_context.py +++ b/python/gigl/common/utils/vertex_ai_context.py @@ -7,6 +7,7 @@ from dataclasses import dataclass from typing import Optional +import omegaconf from google.cloud.aiplatform_v1.types import CustomJobSpec from gigl.common import GcsUri @@ -187,6 +188,25 @@ class ClusterSpec: # https://cloud.google.com/vertex-ai/docs/reference/rest/v1/CustomJobSpec job: Optional[CustomJobSpec] = None + # We use a custom method for parsing, because we need to handle the DictConfig -> Proto conversion + @classmethod + def from_json(cls, json_str: str) -> "ClusterSpec": + """Instantiates ClusterSpec from an OmegaConf DictConfig.""" + cluster_spec_json = json.loads(json_str) + if "job" in cluster_spec_json and cluster_spec_json["job"] is not None: + job_spec = CustomJobSpec(**cluster_spec_json.pop("job")) + else: + job_spec = None + conf = omegaconf.OmegaConf.create(cluster_spec_json) + if isinstance(conf, omegaconf.ListConfig): + raise ValueError("ListConfig is not supported") + return cls( + cluster=conf.cluster, + environment=conf.environment, + task=conf.task, + job=job_spec, + ) + def get_cluster_spec() -> ClusterSpec: """ @@ -204,40 +224,12 @@ def get_cluster_spec() -> ClusterSpec: if not is_currently_running_in_vertex_ai_job(): raise ValueError("Not running in a Vertex AI job.") - cluster_spec_json = os.environ.get("CLUSTER_SPEC") - if not cluster_spec_json: + cluster_spec_str = os.environ.get("CLUSTER_SPEC") + if not cluster_spec_str: raise ValueError("CLUSTER_SPEC not found in environment variables.") - try: - cluster_spec_data = json.loads(cluster_spec_json) - except json.JSONDecodeError as e: - raise json.JSONDecodeError( - f"Failed to parse CLUSTER_SPEC JSON: {e.msg}", e.doc, e.pos - ) - - # Parse the task information - task_data = cluster_spec_data.get("task", {}) - task_info = TaskInfo( - type=task_data.get("type", ""), - index=task_data.get("index", 0), - trial=task_data.get("trial"), - ) - - # Parse the cluster specification - cluster_data = cluster_spec_data.get("cluster", {}) - - # Parse the environment - environment = cluster_spec_data.get("environment", "cloud") - - # Parse the job specification (optional) - job_data = cluster_spec_data.get("job") - job_spec = None - if job_data: - job_spec = CustomJobSpec(**job_data) - - return ClusterSpec( - cluster=cluster_data, environment=environment, task=task_info, job=job_spec - ) + cluster_spec = ClusterSpec.from_json(cluster_spec_str) + return cluster_spec def _get_leader_worker_internal_ip_file_path() -> str: diff --git a/python/tests/unit/common/utils/vertex_ai_context_test.py b/python/tests/unit/common/utils/vertex_ai_context_test.py index 28e98b4e6..db6ad7cc7 100644 --- a/python/tests/unit/common/utils/vertex_ai_context_test.py +++ b/python/tests/unit/common/utils/vertex_ai_context_test.py @@ -3,9 +3,13 @@ import unittest from unittest.mock import call, patch +from google.cloud.aiplatform_v1.types import CustomJobSpec + from gigl.common import GcsUri from gigl.common.services.vertex_ai import LEADER_WORKER_INTERNAL_IP_FILE_PATH_ENV_KEY from gigl.common.utils.vertex_ai_context import ( + ClusterSpec, + TaskInfo, connect_worker_pool, get_cluster_spec, get_host_name, @@ -137,40 +141,34 @@ def test_parse_cluster_spec_success(self): os.environ, self.VAI_JOB_ENV | {"CLUSTER_SPEC": cluster_spec_json} ): cluster_spec = get_cluster_spec() - - # Test cluster data - self.assertEqual(len(cluster_spec.cluster), 4) - self.assertEqual( - cluster_spec.cluster["workerpool0"], ["replica0-0", "replica0-1"] - ) - self.assertEqual(cluster_spec.cluster["workerpool1"], ["replica1-0"]) - self.assertEqual(cluster_spec.cluster["workerpool2"], ["replica2-0"]) - self.assertEqual( - cluster_spec.cluster["workerpool3"], ["replica3-0", "replica3-1"] + expected_cluster_spec = ClusterSpec( + cluster={ + "workerpool0": ["replica0-0", "replica0-1"], + "workerpool1": ["replica1-0"], + "workerpool2": ["replica2-0"], + "workerpool3": ["replica3-0", "replica3-1"], + }, + environment="cloud", + task=TaskInfo(type="workerpool0", index=1, trial="trial-123"), + job=CustomJobSpec( + worker_pool_specs=[ + {"machine_spec": {"machine_type": "n1-standard-4"}} + ] + ), ) + self.assertEqual(cluster_spec, expected_cluster_spec) - # Test task info - self.assertEqual(cluster_spec.task.type, "workerpool0") - self.assertEqual(cluster_spec.task.index, 1) - self.assertEqual(cluster_spec.task.trial, "trial-123") - - # Test environment - self.assertEqual(cluster_spec.environment, "cloud") - - # Test job spec - self.assertIsNotNone(cluster_spec.job) - - def test_parse_cluster_spec_minimal(self): - """Test parsing of minimal cluster specification without optional fields.""" + def test_parse_cluster_spec_success_without_job(self): + """Test successful parsing of cluster specification.""" cluster_spec_json = json.dumps( { "cluster": { - "workerpool0": ["replica0-0"], + "workerpool0": ["replica0-0", "replica0-1"], "workerpool1": ["replica1-0"], "workerpool2": ["replica2-0"], - "workerpool3": ["replica3-0"], + "workerpool3": ["replica3-0", "replica3-1"], }, - "task": {"type": "workerpool0", "index": 0}, + "task": {"type": "workerpool0", "index": 1, "trial": "trial-123"}, "environment": "cloud", } ) @@ -179,20 +177,18 @@ def test_parse_cluster_spec_minimal(self): os.environ, self.VAI_JOB_ENV | {"CLUSTER_SPEC": cluster_spec_json} ): cluster_spec = get_cluster_spec() + expected_cluster_spec = ClusterSpec( + cluster={ + "workerpool0": ["replica0-0", "replica0-1"], + "workerpool1": ["replica1-0"], + "workerpool2": ["replica2-0"], + "workerpool3": ["replica3-0", "replica3-1"], + }, + environment="cloud", + task=TaskInfo(type="workerpool0", index=1, trial="trial-123"), + ) - # Test cluster data - self.assertEqual(len(cluster_spec.cluster), 4) - - # Test task info with defaults - self.assertEqual(cluster_spec.task.type, "workerpool0") - self.assertEqual(cluster_spec.task.index, 0) - self.assertIsNone(cluster_spec.task.trial) - - # Test environment - self.assertEqual(cluster_spec.environment, "cloud") - - # Test job spec (should be None when not provided) - self.assertIsNone(cluster_spec.job) + self.assertEqual(cluster_spec, expected_cluster_spec) def test_parse_cluster_spec_not_on_vai(self): """Test that function raises ValueError when not running in Vertex AI.""" @@ -217,7 +213,6 @@ def test_parse_cluster_spec_invalid_json(self): ): with self.assertRaises(json.JSONDecodeError) as context: get_cluster_spec() - self.assertIn("Failed to parse CLUSTER_SPEC JSON", str(context.exception)) if __name__ == "__main__": From 9b9970638b34097e97459ca8f9a614d8012c575f Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Wed, 8 Oct 2025 16:17:18 +0000 Subject: [PATCH 8/8] reword --- python/gigl/common/utils/vertex_ai_context.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/gigl/common/utils/vertex_ai_context.py b/python/gigl/common/utils/vertex_ai_context.py index 667044435..35d1d90ad 100644 --- a/python/gigl/common/utils/vertex_ai_context.py +++ b/python/gigl/common/utils/vertex_ai_context.py @@ -188,10 +188,10 @@ class ClusterSpec: # https://cloud.google.com/vertex-ai/docs/reference/rest/v1/CustomJobSpec job: Optional[CustomJobSpec] = None - # We use a custom method for parsing, because we need to handle the DictConfig -> Proto conversion + # We use a custom method for parsing, because CustomJobSpec is a protobuf message. @classmethod def from_json(cls, json_str: str) -> "ClusterSpec": - """Instantiates ClusterSpec from an OmegaConf DictConfig.""" + """Instantiates ClusterSpec from a JSON string.""" cluster_spec_json = json.loads(json_str) if "job" in cluster_spec_json and cluster_spec_json["job"] is not None: job_spec = CustomJobSpec(**cluster_spec_json.pop("job"))