Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 77 additions & 1 deletion python/gigl/common/utils/vertex_ai_context.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
"""Utility functions to be used by machines running on Vertex AI."""

import json
import os
import subprocess
import time
from dataclasses import dataclass
from typing import Optional

import omegaconf
from google.cloud.aiplatform_v1.types import CustomJobSpec

from gigl.common import GcsUri
from gigl.common.logger import Logger
from gigl.common.services.vertex_ai import LEADER_WORKER_INTERNAL_IP_FILE_PATH_ENV_KEY
from gigl.common.utils.gcs import GcsUtils
from gigl.distributed import DistributedContext
from gigl.env.distributed import DistributedContext

logger = Logger()

Expand Down Expand Up @@ -156,6 +162,76 @@ def connect_worker_pool() -> DistributedContext:
)


@dataclass(frozen=True)
class TaskInfo:
    """Information about the current task running on this node.

    Mirrors the ``task`` object of the Vertex AI ``CLUSTER_SPEC`` document:
    https://cloud.google.com/vertex-ai/docs/training/distributed-training#cluster-variables
    """

    # The type of worker pool this task is running in (e.g., "workerpool0").
    type: str
    # The zero-based index of the task within its worker pool.
    index: int
    # Hyperparameter tuning trial identifier (if applicable).
    trial: Optional[str] = None


@dataclass(frozen=True)
class ClusterSpec:
    """Represents the cluster specification for a Vertex AI custom job.
    See the docs for more info:
    https://cloud.google.com/vertex-ai/docs/training/distributed-training#cluster-variables
    """

    cluster: dict[str, list[str]]  # Worker pool names mapped to their replica lists
    environment: str  # The environment string (e.g., "cloud")
    task: TaskInfo  # Information about the current task
    # The CustomJobSpec for the current job
    # See the docs for more info:
    # https://cloud.google.com/vertex-ai/docs/reference/rest/v1/CustomJobSpec
    job: Optional[CustomJobSpec] = None

    # We use a custom method for parsing, because CustomJobSpec is a protobuf message.
    @classmethod
    def from_json(cls, json_str: str) -> "ClusterSpec":
        """Instantiates ClusterSpec from a JSON string.

        Args:
            json_str: The JSON document from the CLUSTER_SPEC environment variable.

        Returns:
            ClusterSpec: Parsed specification with fields of the declared types
            (plain dict / TaskInfo, not lazy config objects), so dataclass
            equality against hand-built instances behaves as expected.

        Raises:
            ValueError: If the JSON document is not a JSON object.
            json.JSONDecodeError: If ``json_str`` is not valid JSON.
            KeyError: If a required key ("cluster", "environment", "task") is missing.
        """
        cluster_spec_json = json.loads(json_str)
        if not isinstance(cluster_spec_json, dict):
            raise ValueError(
                f"Expected a JSON object for CLUSTER_SPEC, got "
                f"{type(cluster_spec_json).__name__}."
            )
        # "job" is a protobuf message, so it is constructed separately from the
        # plain-data fields.
        raw_job = cluster_spec_json.pop("job", None)
        job_spec = CustomJobSpec(**raw_job) if raw_job is not None else None
        return cls(
            cluster={
                pool: list(replicas)
                for pool, replicas in cluster_spec_json["cluster"].items()
            },
            environment=cluster_spec_json["environment"],
            task=TaskInfo(**cluster_spec_json["task"]),
            job=job_spec,
        )


def get_cluster_spec() -> ClusterSpec:
    """
    Parse the cluster specification from the CLUSTER_SPEC environment variable.
    Based on the spec given at:
    https://cloud.google.com/vertex-ai/docs/training/distributed-training#cluster-variables

    Returns:
        ClusterSpec: Parsed cluster specification data.

    Raises:
        ValueError: If not running in a Vertex AI job or CLUSTER_SPEC is not found.
        json.JSONDecodeError: If CLUSTER_SPEC contains invalid JSON.
    """
    # Guard: CLUSTER_SPEC is only meaningful inside a Vertex AI custom job.
    if not is_currently_running_in_vertex_ai_job():
        raise ValueError("Not running in a Vertex AI job.")

    # An unset *or empty* variable is treated as missing.
    spec_json = os.environ.get("CLUSTER_SPEC")
    if not spec_json:
        raise ValueError("CLUSTER_SPEC not found in environment variables.")

    return ClusterSpec.from_json(spec_json)


def _get_leader_worker_internal_ip_file_path() -> str:
"""
Get the file path to the leader worker's internal IP address.
Expand Down
23 changes: 5 additions & 18 deletions python/gigl/distributed/dist_context.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,6 @@
from dataclasses import dataclass
from gigl.env.distributed import DistributedContext


@dataclass(frozen=True)
class DistributedContext:
"""
GiGL Distributed Context
"""

# TODO (mkolodner-sc): Investigate adding local rank and local world size

# Main Worker's IP Address for RPC communication
main_worker_ip_address: str

# Rank of machine
global_rank: int

# Total number of machines
global_world_size: int
# TODO (mkolodner-sc): Deprecate this file.
# DistributedContext now lives in gigl.env.distributed; it is re-exported here
# so existing `from gigl.distributed.dist_context import DistributedContext`
# imports keep working until this module is removed.
__all__ = [
    "DistributedContext",
]
21 changes: 21 additions & 0 deletions python/gigl/env/distributed.py
Comment thread
kmontemayor2-sc marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""Information about distributed environments."""

from dataclasses import dataclass


@dataclass(frozen=True)
class DistributedContext:
    """
    GiGL Distributed Context.

    Immutable (frozen) description of this machine's position in a
    multi-machine job: who the main worker is, this machine's rank, and the
    total machine count.
    """

    # TODO (mkolodner-sc): Investigate adding local rank and local world size

    # Main Worker's IP Address for RPC communication
    main_worker_ip_address: str

    # Rank of machine (zero-based; rank 0 is the main worker)
    global_rank: int

    # Total number of machines
    global_world_size: int
109 changes: 105 additions & 4 deletions python/tests/unit/common/utils/vertex_ai_context_test.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
import json
import os
import unittest
from unittest.mock import call, patch

from google.cloud.aiplatform_v1.types import CustomJobSpec

from gigl.common import GcsUri
from gigl.common.services.vertex_ai import LEADER_WORKER_INTERNAL_IP_FILE_PATH_ENV_KEY
from gigl.common.utils.vertex_ai_context import (
DistributedContext,
ClusterSpec,
TaskInfo,
connect_worker_pool,
get_cluster_spec,
get_host_name,
get_leader_hostname,
get_leader_port,
Expand All @@ -15,7 +20,6 @@
get_world_size,
is_currently_running_in_vertex_ai_job,
)
from gigl.distributed import DistributedContext


class TestVertexAIContext(unittest.TestCase):
Expand Down Expand Up @@ -76,7 +80,7 @@ def test_throws_if_not_on_vai(self):
},
)
def test_connect_worker_pool_leader(self, mock_upload, mock_sleep, mock_subprocess):
distributed_context: DistributedContext = connect_worker_pool()
distributed_context = connect_worker_pool()
self.assertEqual(distributed_context.main_worker_ip_address, "127.0.0.1")
self.assertEqual(distributed_context.global_rank, 0)
self.assertEqual(distributed_context.global_world_size, 2)
Expand All @@ -102,7 +106,7 @@ def test_connect_worker_pool_worker(
self, mock_upload, mock_read, mock_sleep, mock_subprocess, mock_ping_host
):
mock_ping_host.side_effect = [False, True]
distributed_context: DistributedContext = connect_worker_pool()
distributed_context = connect_worker_pool()
self.assertEqual(distributed_context.main_worker_ip_address, "127.0.0.1")
self.assertEqual(distributed_context.global_rank, 1)
self.assertEqual(distributed_context.global_world_size, 2)
Expand All @@ -113,6 +117,103 @@ def test_connect_worker_pool_worker(
]
)

def test_parse_cluster_spec_success(self):
"""Test successful parsing of cluster specification."""
cluster_spec_json = json.dumps(
{
"cluster": {
"workerpool0": ["replica0-0", "replica0-1"],
"workerpool1": ["replica1-0"],
"workerpool2": ["replica2-0"],
"workerpool3": ["replica3-0", "replica3-1"],
},
"task": {"type": "workerpool0", "index": 1, "trial": "trial-123"},
"environment": "cloud",
"job": {
"worker_pool_specs": [
{"machine_spec": {"machine_type": "n1-standard-4"}}
]
},
}
)

with patch.dict(
os.environ, self.VAI_JOB_ENV | {"CLUSTER_SPEC": cluster_spec_json}
):
cluster_spec = get_cluster_spec()
expected_cluster_spec = ClusterSpec(
cluster={
"workerpool0": ["replica0-0", "replica0-1"],
"workerpool1": ["replica1-0"],
"workerpool2": ["replica2-0"],
"workerpool3": ["replica3-0", "replica3-1"],
},
environment="cloud",
task=TaskInfo(type="workerpool0", index=1, trial="trial-123"),
job=CustomJobSpec(
worker_pool_specs=[
{"machine_spec": {"machine_type": "n1-standard-4"}}
]
),
)
self.assertEqual(cluster_spec, expected_cluster_spec)

def test_parse_cluster_spec_success_without_job(self):
"""Test successful parsing of cluster specification."""
cluster_spec_json = json.dumps(
{
"cluster": {
"workerpool0": ["replica0-0", "replica0-1"],
"workerpool1": ["replica1-0"],
"workerpool2": ["replica2-0"],
"workerpool3": ["replica3-0", "replica3-1"],
},
"task": {"type": "workerpool0", "index": 1, "trial": "trial-123"},
"environment": "cloud",
}
)

with patch.dict(
os.environ, self.VAI_JOB_ENV | {"CLUSTER_SPEC": cluster_spec_json}
):
cluster_spec = get_cluster_spec()
expected_cluster_spec = ClusterSpec(
cluster={
"workerpool0": ["replica0-0", "replica0-1"],
"workerpool1": ["replica1-0"],
"workerpool2": ["replica2-0"],
"workerpool3": ["replica3-0", "replica3-1"],
},
environment="cloud",
task=TaskInfo(type="workerpool0", index=1, trial="trial-123"),
)

self.assertEqual(cluster_spec, expected_cluster_spec)

def test_parse_cluster_spec_not_on_vai(self):
"""Test that function raises ValueError when not running in Vertex AI."""
with self.assertRaises(ValueError) as context:
get_cluster_spec()
self.assertIn("Not running in a Vertex AI job", str(context.exception))

def test_parse_cluster_spec_missing_cluster_spec(self):
"""Test that function raises ValueError when CLUSTER_SPEC is missing."""
with patch.dict(os.environ, self.VAI_JOB_ENV):
with self.assertRaises(ValueError) as context:
get_cluster_spec()
self.assertIn(
"CLUSTER_SPEC not found in environment variables",
str(context.exception),
)

def test_parse_cluster_spec_invalid_json(self):
"""Test that function raises JSONDecodeError for invalid JSON."""
with patch.dict(
os.environ, self.VAI_JOB_ENV | {"CLUSTER_SPEC": "invalid json"}
):
with self.assertRaises(json.JSONDecodeError) as context:
get_cluster_spec()


# Allow running this test module directly (e.g. `python vertex_ai_context_test.py`).
if __name__ == "__main__":
    unittest.main()