diff --git a/python/gigl/common/utils/vertex_ai_context.py b/python/gigl/common/utils/vertex_ai_context.py index dfdf569c1..35d1d90ad 100644 --- a/python/gigl/common/utils/vertex_ai_context.py +++ b/python/gigl/common/utils/vertex_ai_context.py @@ -1,14 +1,20 @@ """Utility functions to be used by machines running on Vertex AI.""" +import json import os import subprocess import time +from dataclasses import dataclass +from typing import Optional + +import omegaconf +from google.cloud.aiplatform_v1.types import CustomJobSpec from gigl.common import GcsUri from gigl.common.logger import Logger from gigl.common.services.vertex_ai import LEADER_WORKER_INTERNAL_IP_FILE_PATH_ENV_KEY from gigl.common.utils.gcs import GcsUtils -from gigl.distributed import DistributedContext +from gigl.env.distributed import DistributedContext logger = Logger() @@ -156,6 +162,76 @@ def connect_worker_pool() -> DistributedContext: ) +@dataclass(frozen=True) +class TaskInfo: + """Information about the current task running on this node.""" + + type: str # The type of worker pool this task is running in (e.g., "workerpool0") + index: int # The zero-based index of the task + trial: Optional[ + str + ] = None # Hyperparameter tuning trial identifier (if applicable) + + +@dataclass(frozen=True) +class ClusterSpec: + """Represents the cluster specification for a Vertex AI custom job. + See the docs for more info: + https://cloud.google.com/vertex-ai/docs/training/distributed-training#cluster-variables + """ + + cluster: dict[str, list[str]] # Worker pool names mapped to their replica lists + environment: str # The environment string (e.g., "cloud") + task: TaskInfo # Information about the current task + # The CustomJobSpec for the current job + # See the docs for more info: + # https://cloud.google.com/vertex-ai/docs/reference/rest/v1/CustomJobSpec + job: Optional[CustomJobSpec] = None + + # We use a custom method for parsing, because CustomJobSpec is a protobuf message. 
+ @classmethod + def from_json(cls, json_str: str) -> "ClusterSpec": + """Instantiates ClusterSpec from a JSON string.""" + cluster_spec_json = json.loads(json_str) + if "job" in cluster_spec_json and cluster_spec_json["job"] is not None: + job_spec = CustomJobSpec(**cluster_spec_json.pop("job")) + else: + job_spec = None + conf = omegaconf.OmegaConf.create(cluster_spec_json) + if isinstance(conf, omegaconf.ListConfig): + raise ValueError("ListConfig is not supported") + return cls( + cluster=conf.cluster, + environment=conf.environment, + task=conf.task, + job=job_spec, + ) + + +def get_cluster_spec() -> ClusterSpec: + """ + Parse the cluster specification from the CLUSTER_SPEC environment variable. + Based on the spec given at: + https://cloud.google.com/vertex-ai/docs/training/distributed-training#cluster-variables + + Returns: + ClusterSpec: Parsed cluster specification data. + + Raises: + ValueError: If not running in a Vertex AI job or CLUSTER_SPEC is not found. + json.JSONDecodeError: If CLUSTER_SPEC contains invalid JSON. + """ + if not is_currently_running_in_vertex_ai_job(): + raise ValueError("Not running in a Vertex AI job.") + + cluster_spec_str = os.environ.get("CLUSTER_SPEC") + if not cluster_spec_str: + raise ValueError("CLUSTER_SPEC not found in environment variables.") + + cluster_spec = ClusterSpec.from_json(cluster_spec_str) + return cluster_spec + + def _get_leader_worker_internal_ip_file_path() -> str: """ Get the file path to the leader worker's internal IP address. 
diff --git a/python/gigl/distributed/dist_context.py b/python/gigl/distributed/dist_context.py index da513ab87..1078bee5a 100644 --- a/python/gigl/distributed/dist_context.py +++ b/python/gigl/distributed/dist_context.py @@ -1,19 +1,6 @@ -from dataclasses import dataclass +from gigl.env.distributed import DistributedContext - -@dataclass(frozen=True) -class DistributedContext: - """ - GiGL Distributed Context - """ - - # TODO (mkolodner-sc): Investigate adding local rank and local world size - - # Main Worker's IP Address for RPC communication - main_worker_ip_address: str - - # Rank of machine - global_rank: int - - # Total number of machines - global_world_size: int +# TODO (mkolodner-sc): Deprecate this file. +__all__ = [ + "DistributedContext", +] diff --git a/python/gigl/env/distributed.py b/python/gigl/env/distributed.py new file mode 100644 index 000000000..84466dde4 --- /dev/null +++ b/python/gigl/env/distributed.py @@ -0,0 +1,21 @@ +"""Information about distributed environments.""" + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class DistributedContext: + """ + GiGL Distributed Context + """ + + # TODO (mkolodner-sc): Investigate adding local rank and local world size + + # Main Worker's IP Address for RPC communication + main_worker_ip_address: str + + # Rank of machine + global_rank: int + + # Total number of machines + global_world_size: int diff --git a/python/tests/unit/common/utils/vertex_ai_context_test.py b/python/tests/unit/common/utils/vertex_ai_context_test.py index d3a5132b2..db6ad7cc7 100644 --- a/python/tests/unit/common/utils/vertex_ai_context_test.py +++ b/python/tests/unit/common/utils/vertex_ai_context_test.py @@ -1,12 +1,17 @@ +import json import os import unittest from unittest.mock import call, patch +from google.cloud.aiplatform_v1.types import CustomJobSpec + from gigl.common import GcsUri from gigl.common.services.vertex_ai import LEADER_WORKER_INTERNAL_IP_FILE_PATH_ENV_KEY from 
gigl.common.utils.vertex_ai_context import ( - DistributedContext, + ClusterSpec, + TaskInfo, connect_worker_pool, + get_cluster_spec, get_host_name, get_leader_hostname, get_leader_port, @@ -15,7 +20,6 @@ get_world_size, is_currently_running_in_vertex_ai_job, ) -from gigl.distributed import DistributedContext class TestVertexAIContext(unittest.TestCase): @@ -76,7 +80,7 @@ def test_throws_if_not_on_vai(self): }, ) def test_connect_worker_pool_leader(self, mock_upload, mock_sleep, mock_subprocess): - distributed_context: DistributedContext = connect_worker_pool() + distributed_context = connect_worker_pool() self.assertEqual(distributed_context.main_worker_ip_address, "127.0.0.1") self.assertEqual(distributed_context.global_rank, 0) self.assertEqual(distributed_context.global_world_size, 2) @@ -102,7 +106,7 @@ def test_connect_worker_pool_worker( self, mock_upload, mock_read, mock_sleep, mock_subprocess, mock_ping_host ): mock_ping_host.side_effect = [False, True] - distributed_context: DistributedContext = connect_worker_pool() + distributed_context = connect_worker_pool() self.assertEqual(distributed_context.main_worker_ip_address, "127.0.0.1") self.assertEqual(distributed_context.global_rank, 1) self.assertEqual(distributed_context.global_world_size, 2) @@ -113,6 +117,103 @@ def test_connect_worker_pool_worker( ] ) + def test_parse_cluster_spec_success(self): + """Test successful parsing of cluster specification.""" + cluster_spec_json = json.dumps( + { + "cluster": { + "workerpool0": ["replica0-0", "replica0-1"], + "workerpool1": ["replica1-0"], + "workerpool2": ["replica2-0"], + "workerpool3": ["replica3-0", "replica3-1"], + }, + "task": {"type": "workerpool0", "index": 1, "trial": "trial-123"}, + "environment": "cloud", + "job": { + "worker_pool_specs": [ + {"machine_spec": {"machine_type": "n1-standard-4"}} + ] + }, + } + ) + + with patch.dict( + os.environ, self.VAI_JOB_ENV | {"CLUSTER_SPEC": cluster_spec_json} + ): + cluster_spec = get_cluster_spec() + 
expected_cluster_spec = ClusterSpec( + cluster={ + "workerpool0": ["replica0-0", "replica0-1"], + "workerpool1": ["replica1-0"], + "workerpool2": ["replica2-0"], + "workerpool3": ["replica3-0", "replica3-1"], + }, + environment="cloud", + task=TaskInfo(type="workerpool0", index=1, trial="trial-123"), + job=CustomJobSpec( + worker_pool_specs=[ + {"machine_spec": {"machine_type": "n1-standard-4"}} + ] + ), + ) + self.assertEqual(cluster_spec, expected_cluster_spec) + + def test_parse_cluster_spec_success_without_job(self): + """Test successful parsing of a cluster specification that has no "job" field.""" + cluster_spec_json = json.dumps( + { + "cluster": { + "workerpool0": ["replica0-0", "replica0-1"], + "workerpool1": ["replica1-0"], + "workerpool2": ["replica2-0"], + "workerpool3": ["replica3-0", "replica3-1"], + }, + "task": {"type": "workerpool0", "index": 1, "trial": "trial-123"}, + "environment": "cloud", + } + ) + + with patch.dict( + os.environ, self.VAI_JOB_ENV | {"CLUSTER_SPEC": cluster_spec_json} + ): + cluster_spec = get_cluster_spec() + expected_cluster_spec = ClusterSpec( + cluster={ + "workerpool0": ["replica0-0", "replica0-1"], + "workerpool1": ["replica1-0"], + "workerpool2": ["replica2-0"], + "workerpool3": ["replica3-0", "replica3-1"], + }, + environment="cloud", + task=TaskInfo(type="workerpool0", index=1, trial="trial-123"), + ) + + self.assertEqual(cluster_spec, expected_cluster_spec) + + def test_parse_cluster_spec_not_on_vai(self): + """Test that function raises ValueError when not running in Vertex AI.""" + with self.assertRaises(ValueError) as context: + get_cluster_spec() + self.assertIn("Not running in a Vertex AI job", str(context.exception)) + + def test_parse_cluster_spec_missing_cluster_spec(self): + """Test that function raises ValueError when CLUSTER_SPEC is missing.""" + with patch.dict(os.environ, self.VAI_JOB_ENV): + with self.assertRaises(ValueError) as context: + get_cluster_spec() + self.assertIn( + "CLUSTER_SPEC not found in environment variables", + 
str(context.exception), + ) + + def test_parse_cluster_spec_invalid_json(self): + """Test that function raises JSONDecodeError for invalid JSON.""" + with patch.dict( + os.environ, self.VAI_JOB_ENV | {"CLUSTER_SPEC": "invalid json"} + ): + with self.assertRaises(json.JSONDecodeError) as context: + get_cluster_spec() + if __name__ == "__main__": unittest.main()