From 4233f92c13c273d6f25fca7b8e9b03d00b60d4ac Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Sat, 10 Jan 2026 00:07:40 +0000 Subject: [PATCH 1/5] Setup DistNeighborloader for graph store sampling --- .../distributed/distributed_neighborloader.py | 299 +++++++++++++++--- .../gigl/distributed/graph_store/compute.py | 32 +- .../distributed/graph_store/storage_main.py | 47 ++- python/gigl/distributed/utils/networking.py | 33 +- python/gigl/env/distributed.py | 7 + .../graph_store_integration_test.py | 46 ++- python/tests/unit/env/distributed_test.py | 2 + 7 files changed, 382 insertions(+), 84 deletions(-) diff --git a/python/gigl/distributed/distributed_neighborloader.py b/python/gigl/distributed/distributed_neighborloader.py index 3be9662f0..fd094d9ad 100644 --- a/python/gigl/distributed/distributed_neighborloader.py +++ b/python/gigl/distributed/distributed_neighborloader.py @@ -1,9 +1,14 @@ from collections import Counter, abc -from typing import Optional, Tuple, Union +from dataclasses import dataclass +from typing import Literal, Optional, Tuple, Union import torch from graphlearn_torch.channel import SampleMessage -from graphlearn_torch.distributed import DistLoader, MpDistSamplingWorkerOptions +from graphlearn_torch.distributed import ( + DistLoader, + MpDistSamplingWorkerOptions, + RemoteDistSamplingWorkerOptions, +) from graphlearn_torch.sampler import NodeSamplerInput, SamplingConfig, SamplingType from torch_geometric.data import Data, HeteroData from torch_geometric.typing import EdgeType @@ -13,6 +18,7 @@ from gigl.distributed.constants import DEFAULT_MASTER_INFERENCE_PORT from gigl.distributed.dist_context import DistributedContext from gigl.distributed.dist_dataset import DistDataset +from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset from gigl.distributed.utils.neighborloader import ( labeled_to_homogeneous, patch_fanout_for_sampling, @@ -26,6 +32,7 @@ from gigl.types.graph import ( DEFAULT_HOMOGENEOUS_EDGE_TYPE, DEFAULT_HOMOGENEOUS_NODE_TYPE, + FeatureInfo, ) logger = Logger() @@ -34,13 +41,27 @@ DEFAULT_NUM_CPU_THREADS = 2 +# Shared metadata between the local and remote datasets. +@dataclass(frozen=True) +class _DatasetMetadata: + is_labeled_heterogeneous: bool + node_feature_info: Optional[Union[FeatureInfo, dict[NodeType, FeatureInfo]]] + edge_feature_info: Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]] + edge_dir: Union[str, Literal["in", "out"]] + + class DistNeighborLoader(DistLoader): def __init__( self, - dataset: DistDataset, + dataset: Union[DistDataset, RemoteDistDataset], num_neighbors: Union[list[int], dict[EdgeType, list[int]]], input_nodes: Optional[ - Union[torch.Tensor, Tuple[NodeType, torch.Tensor]] + Union[ + torch.Tensor, + Tuple[NodeType, torch.Tensor], + list[torch.Tensor], + Tuple[NodeType, list[torch.Tensor]], + ] ] = None, num_workers: int = 1, batch_size: int = 1, @@ -62,7 +83,7 @@ def __init__( https://pytorch-geometric.readthedocs.io/en/2.5.2/_modules/torch_geometric/distributed/dist_neighbor_loader.html#DistNeighborLoader Args: - dataset (DistDataset): The dataset to sample from. + dataset (DistDataset | RemoteDistDataset): The dataset to sample from. Must be a "RemoteDistDataset" if using Graph Store mode. num_neighbors (list[int] or dict[Tuple[str, str, str], list[int]]): The number of neighbors to sample for each node in each iteration. If an entry is set to `-1`, all neighbors will be included. 
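As a quick, hypothetical illustration of the fanout formats described above (the edge types below are made up and are not taken from these patches), `num_neighbors` can be either a single list applied to every edge type or a per-edge-type mapping:

# Same fanout for every edge type: sample 10 neighbors at hop 1 and 5 at hop 2.
num_neighbors = [10, 5]

# Per-edge-type fanout for a heterogeneous graph.
num_neighbors = {
    ("user", "follows", "user"): [10, 5],
    ("user", "likes", "item"): [20, 10],
}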
@@ -71,12 +92,15 @@ def __init__( context (deprecated - will be removed soon) (DistributedContext): Distributed context information of the current process. local_process_rank (deprecated - will be removed soon) (int): Required if context provided. The local rank of the current process within a node. local_process_world_size (deprecated - will be removed soon)(int): Required if context provided. The total number of processes within a node. - input_nodes (torch.Tensor or Tuple[str, torch.Tensor]): The - indices of seed nodes to start sampling from. + input_nodes (Tenor | Tuple[NodeType, Tenor] | list[Tenor] | Tuple[NodeType, list[Tenor]]): + The nodes to start sampling from. It is of type `torch.LongTensor` for homogeneous graphs. If set to `None` for homogeneous settings, all nodes will be considered. In heterogeneous graphs, this flag must be passed in as a tuple that holds the node type and node indices. (default: `None`) + For Graph Store mode, this must be a tuple of (NodeType, list[Tenor]) or list[Tenor]. + Where each Tensor in the list is the node ids to sample from, for each server. + e.g. [[10, 20], [30, 40]] means sample from nodes 10 and 20 on server 0, and nodes 30 and 40 on server 1. num_workers (int): How many workers to use (subprocesses to spwan) for distributed neighbor sampling of the current process. (default: ``1``). batch_size (int, optional): how many samples per batch to load @@ -188,10 +212,219 @@ def __init__( local_process_rank=local_rank ) ) + + # Determines if the node ids passed in are heterogeneous or homogeneous. + self._is_labeled_heterogeneous = False + if isinstance(dataset, DistDataset): + input_data, worker_options, dataset_metadata = self._setup_for_colocated( + input_nodes, + dataset, + local_rank, + local_world_size, + device, + master_ip_address, + node_rank, + node_world_size, + process_start_gap_seconds, + num_workers, + worker_concurrency, + channel_size, + num_cpu_threads, + ) + else: # RemoteDistDataset + input_data, worker_options, dataset_metadata = self._setup_for_graph_store( + input_nodes, + dataset, + num_workers, + ) + + self._is_labeled_heterogeneous = dataset_metadata.is_labeled_heterogeneous + self._node_feature_info = dataset_metadata.node_feature_info + self._edge_feature_info = dataset_metadata.edge_feature_info + + num_neighbors = patch_fanout_for_sampling( + list(dataset_metadata.edge_feature_info.keys()) + if isinstance(dataset_metadata.edge_feature_info, dict) + else None, + num_neighbors, + ) + + sampling_config = SamplingConfig( + sampling_type=SamplingType.NODE, + num_neighbors=num_neighbors, + batch_size=batch_size, + shuffle=shuffle, + drop_last=drop_last, + with_edge=True, + collect_features=True, + with_neg=False, + with_weight=False, + edge_dir=dataset_metadata.edge_dir, + seed=None, # it's actually optional - None means random. + ) + + if should_cleanup_distributed_context and torch.distributed.is_initialized(): + logger.info( + f"Cleaning up process group as it was initialized inside {self.__class__.__name__}.__init__." + ) + torch.distributed.destroy_process_group() + + if isinstance(dataset, DistDataset): + super().__init__( + dataset if isinstance(dataset, DistDataset) else None, + input_data, + sampling_config, + device, + worker_options, + ) + else: + # For Graph Store mode, we need to start the communcation between compute and storage nodes sequentially, by compute node. + # E.g. intialize connections between compute node 0 and storage nodes 0, 1, 2, 3, then compute node 1 and storage nodes 0, 1, 2, 3, etc. 
+            # Note that each compute node may have multiple connections to each storage node, once per compute process.
+            # E.g. if there are 4 gpus per compute node, then there will be 4 connections to each storage node.
+            # We need to do this because if we don't, then there is a race condition when initializing the samplers on the storage nodes [1].
+            # Since the lock is per *server* (e.g. per storage node), if we try to start one connection from compute node 0 and another from compute node 1 at the same time,
+            # then we deadlock and fail.
+            # [1]: https://github.com/alibaba/graphlearn-for-pytorch/blob/main/graphlearn_torch/python/distributed/dist_server.py#L129-L167
+            node_rank = dataset.cluster_info.compute_node_rank
+            for target_node_rank in range(dataset.cluster_info.num_compute_nodes):
+                if node_rank == target_node_rank:
+                    super().__init__(
+                        dataset if isinstance(dataset, DistDataset) else None,
+                        input_data,
+                        sampling_config,
+                        device,
+                        worker_options,
+                    )
+                    print(f"node_rank {node_rank} initialized the dist loader")
+                torch.distributed.barrier()
+            torch.distributed.barrier()
+
+    def _setup_for_graph_store(
+        self,
+        input_nodes: Optional[
+            Union[
+                torch.Tensor,
+                Tuple[NodeType, torch.Tensor],
+                list[torch.Tensor],
+                Tuple[NodeType, list[torch.Tensor]],
+            ]
+        ],
+        dataset: RemoteDistDataset,
+        num_workers: int,
+    ) -> tuple[NodeSamplerInput, RemoteDistSamplingWorkerOptions, _DatasetMetadata]:
+        if input_nodes is None:
+            raise ValueError(
+                f"When using Graph Store mode, input nodes must be provided, received {input_nodes}"
+            )
+        elif isinstance(input_nodes, torch.Tensor):
+            raise ValueError(
+                f"When using Graph Store mode, input nodes must be of type (list[torch.Tensor] | (NodeType, list[torch.Tensor])), received {type(input_nodes)}"
+            )
+        elif isinstance(input_nodes, tuple) and isinstance(
+            input_nodes[1], torch.Tensor
+        ):
+            raise ValueError(
+                f"When using Graph Store mode, input nodes must be of type (list[torch.Tensor] | (NodeType, list[torch.Tensor])), received {type(input_nodes)} ({type(input_nodes[0])}, {type(input_nodes[1])})"
+            )
+
+        is_labeled_heterogeneous = False
+        node_feature_info = dataset.get_node_feature_info()
+        edge_feature_info = dataset.get_edge_feature_info()
+        node_rank = dataset.cluster_info.compute_node_rank
+
+        # Get sampling ports for compute-storage connections.
+        sampling_ports = dataset.get_free_ports_on_storage_cluster(
+            num_ports=dataset.cluster_info.num_processes_per_compute
+        )
+        sampling_port = sampling_ports[node_rank]
+
+        worker_options = RemoteDistSamplingWorkerOptions(
+            server_rank=list(range(dataset.cluster_info.num_storage_nodes)),
+            num_workers=num_workers,
+            worker_devices=[torch.device("cpu") for i in range(num_workers)],
+            master_addr=dataset.cluster_info.storage_cluster_master_ip,
+            master_port=sampling_port,
+            worker_key=f"compute_rank_{node_rank}",
+        )
         logger.info(
-            f"Dataset Building started on {node_rank} of {node_world_size} nodes, using following node as main: {master_ip_address}"
+            f"Rank {torch.distributed.get_rank()}! init for sampling rpc: {f'tcp://{dataset.cluster_info.storage_cluster_master_ip}:{sampling_port}'}"
+        )
+
+        # Set up input data for the dataloader. 
+ + # Determine nodes list and fallback input_type based on input_nodes structure + if isinstance(input_nodes, list): + nodes = input_nodes + fallback_input_type = None + require_edge_feature_info = False + elif isinstance(input_nodes, tuple) and isinstance(input_nodes[1], list): + nodes = input_nodes[1] + fallback_input_type = input_nodes[0] + require_edge_feature_info = True + else: + raise ValueError( + f"When using Graph Store mode, input nodes must be of type (list[torch.Tensor] | (NodeType, list[torch.Tensor])), received {type(input_nodes)}" + ) + + # Determine input_type based on edge_feature_info + if isinstance(edge_feature_info, dict): + if len(edge_feature_info) == 0: + raise ValueError( + "When using Graph Store mode, edge feature info must be provided for heterogeneous graphs." + ) + elif ( + len(edge_feature_info) == 1 + and DEFAULT_HOMOGENEOUS_EDGE_TYPE in edge_feature_info + ): + input_type: NodeType | None = DEFAULT_HOMOGENEOUS_NODE_TYPE + else: + input_type = fallback_input_type + elif require_edge_feature_info: + raise ValueError( + "When using Graph Store mode, edge feature info must be provided for heterogeneous graphs." + ) + else: + input_type = None + + input_data = [ + NodeSamplerInput(node=node, input_type=input_type) for node in nodes + ] + + return ( + input_data, + worker_options, + _DatasetMetadata( + is_labeled_heterogeneous=is_labeled_heterogeneous, + node_feature_info=node_feature_info, + edge_feature_info=edge_feature_info, + edge_dir=dataset.get_edge_dir(), + ), ) + def _setup_for_colocated( + self, + input_nodes: Optional[ + Union[ + torch.Tensor, + Tuple[NodeType, torch.Tensor], + list[torch.Tensor], + Tuple[NodeType, list[torch.Tensor]], + ] + ], + dataset: DistDataset, + local_rank: int, + local_world_size: int, + device: torch.device, + master_ip_address: str, + node_rank: int, + node_world_size: int, + process_start_gap_seconds: float, + num_workers: int, + worker_concurrency: int, + channel_size: str, + num_cpu_threads: Optional[int], + ) -> tuple[NodeSamplerInput, MpDistSamplingWorkerOptions, _DatasetMetadata]: if input_nodes is None: if dataset.node_ids is None: raise ValueError( @@ -202,9 +435,15 @@ def __init__( f"input_nodes must be provided for heterogeneous datasets, received node_ids of type: {dataset.node_ids.keys()}" ) input_nodes = dataset.node_ids - - # Determines if the node ids passed in are heterogeneous or homogeneous. - self._is_labeled_heterogeneous = False + if isinstance(input_nodes, list): + raise ValueError( + f"When using Colocated mode, input nodes must be of type (torch.Tensor | (NodeType, torch.Tensor)), received {type(input_nodes)}" + ) + elif isinstance(input_nodes, tuple) and isinstance(input_nodes[1], list): + raise ValueError( + f"When using Colocated mode, input nodes must be of type (torch.Tensor | (NodeType, torch.Tensor)), received {type(input_nodes)} ({type(input_nodes[0])}, {type(input_nodes[1])})" + ) + is_labeled_heterogeneous = False if isinstance(input_nodes, torch.Tensor): node_ids = input_nodes @@ -216,7 +455,7 @@ def __init__( and DEFAULT_HOMOGENEOUS_NODE_TYPE in dataset.node_ids ): node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE - self._is_labeled_heterogeneous = True + is_labeled_heterogeneous = True else: raise ValueError( f"For heterogeneous datasets, input_nodes must be a tuple of (node_type, node_ids) OR if it is a labeled homogeneous dataset, input_nodes may be a torch.Tensor. 
Received node types: {dataset.node_ids.keys()}" @@ -229,19 +468,12 @@ def __init__( dataset.node_ids, abc.Mapping ), "Dataset must be heterogeneous if provided input nodes are a tuple." - num_neighbors = patch_fanout_for_sampling( - dataset.get_edge_types(), num_neighbors - ) - curr_process_nodes = shard_nodes_by_process( input_nodes=node_ids, local_process_rank=local_rank, local_process_world_size=local_world_size, ) - self._node_feature_info = dataset.node_feature_info - self._edge_feature_info = dataset.edge_feature_info - input_data = NodeSamplerInput(node=curr_process_nodes, input_type=node_type) # Sets up processes and torch device for initializing the GLT DistNeighborLoader, setting up RPC and worker groups to minimize @@ -305,28 +537,17 @@ def __init__( pin_memory=device.type == "cuda", ) - sampling_config = SamplingConfig( - sampling_type=SamplingType.NODE, - num_neighbors=num_neighbors, - batch_size=batch_size, - shuffle=shuffle, - drop_last=drop_last, - with_edge=True, - collect_features=True, - with_neg=False, - with_weight=False, - edge_dir=dataset.edge_dir, - seed=None, # it's actually optional - None means random. + return ( + input_data, + worker_options, + _DatasetMetadata( + is_labeled_heterogeneous=is_labeled_heterogeneous, + node_feature_info=dataset.node_feature_info, + edge_feature_info=dataset.edge_feature_info, + edge_dir=dataset.edge_dir, + ), ) - if should_cleanup_distributed_context and torch.distributed.is_initialized(): - logger.info( - f"Cleaning up process group as it was initialized inside {self.__class__.__name__}.__init__." - ) - torch.distributed.destroy_process_group() - - super().__init__(dataset, input_data, sampling_config, device, worker_options) - def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]: data = super()._collate_fn(msg) data = set_missing_features( diff --git a/python/gigl/distributed/graph_store/compute.py b/python/gigl/distributed/graph_store/compute.py index 36a3b66dd..6039eaef9 100644 --- a/python/gigl/distributed/graph_store/compute.py +++ b/python/gigl/distributed/graph_store/compute.py @@ -1,8 +1,8 @@ import os from typing import Optional -import graphlearn_torch as glt import torch +from graphlearn_torch.distributed.dist_client import init_client, shutdown_client from gigl.common.logger import Logger from gigl.env.distributed import GraphStoreInfo @@ -36,6 +36,21 @@ def init_compute_process( cluster_info.compute_node_rank * cluster_info.num_processes_per_compute + local_rank ) + cluster_master_ip = cluster_info.storage_cluster_master_ip + logger.info( + f"Initializing RPC client for compute node {compute_cluster_rank} / {cluster_info.compute_cluster_world_size} on {cluster_master_ip}:{cluster_info.rpc_master_port}." + f" OS rank: {os.environ['RANK']}, local compute rank: {local_rank}" + f" num_servers: {cluster_info.num_storage_nodes}, num_clients: {cluster_info.compute_cluster_world_size}" + ) + init_client( + num_servers=cluster_info.num_storage_nodes, + num_clients=cluster_info.compute_cluster_world_size, + client_rank=compute_cluster_rank, + master_addr=cluster_master_ip, + master_port=cluster_info.rpc_master_port, + client_group_name="gigl_client_rpc", + ) + logger.info( f"Initializing compute process group {compute_cluster_rank} / {cluster_info.compute_cluster_world_size}. on {cluster_info.compute_cluster_master_ip}:{cluster_info.compute_cluster_master_port} with backend {compute_world_backend}." 
f" OS rank: {os.environ['RANK']}, local client rank: {local_rank}" @@ -46,19 +61,6 @@ def init_compute_process( rank=compute_cluster_rank, init_method=f"tcp://{cluster_info.compute_cluster_master_ip}:{cluster_info.compute_cluster_master_port}", ) - logger.info( - f"Initializing RPC client for compute node {compute_cluster_rank} / {cluster_info.compute_cluster_world_size} on {cluster_info.cluster_master_ip}:{cluster_info.cluster_master_port}." - f" OS rank: {os.environ['RANK']}, local compute rank: {local_rank}" - f" num_servers: {cluster_info.num_storage_nodes}, num_clients: {cluster_info.compute_cluster_world_size}" - ) - glt.distributed.init_client( - num_servers=cluster_info.num_storage_nodes, - num_clients=cluster_info.compute_cluster_world_size, - client_rank=compute_cluster_rank, - master_addr=cluster_info.cluster_master_ip, - master_port=cluster_info.cluster_master_port, - client_group_name="gigl_client_rpc", - ) def shutdown_compute_proccess() -> None: @@ -70,5 +72,5 @@ def shutdown_compute_proccess() -> None: Args: None """ - glt.distributed.shutdown_client() + shutdown_client() torch.distributed.destroy_process_group() diff --git a/python/gigl/distributed/graph_store/storage_main.py b/python/gigl/distributed/graph_store/storage_main.py index 0cfdef957..222613af8 100644 --- a/python/gigl/distributed/graph_store/storage_main.py +++ b/python/gigl/distributed/graph_store/storage_main.py @@ -7,8 +7,11 @@ import os from typing import Optional -import graphlearn_torch as glt import torch +from graphlearn_torch.distributed.dist_server import ( + init_server, + wait_and_shutdown_server, +) from gigl.common import Uri, UriFactory from gigl.common.logger import Logger @@ -30,28 +33,34 @@ def _run_storage_process( storage_world_backend: Optional[str], ) -> None: register_dataset(dataset) + cluster_master_ip = cluster_info.storage_cluster_master_ip logger.info( - f"Initializing storage node {storage_rank} / {cluster_info.num_storage_nodes} with backend {storage_world_backend} on {cluster_info.cluster_master_ip}:{torch_process_port}" + f"Initializing GLT server for storage node process group {storage_rank} / {cluster_info.num_storage_nodes} on {cluster_master_ip}:{cluster_info.rpc_master_port}" ) - torch.distributed.init_process_group( - backend=storage_world_backend, - world_size=cluster_info.num_storage_nodes, - rank=storage_rank, - init_method=f"tcp://{cluster_info.cluster_master_ip}:{torch_process_port}", - ) - glt.distributed.init_server( + init_server( num_servers=cluster_info.num_storage_nodes, server_rank=storage_rank, dataset=dataset, - master_addr=cluster_info.cluster_master_ip, - master_port=cluster_info.cluster_master_port, + master_addr=cluster_master_ip, + master_port=cluster_info.rpc_master_port, num_clients=cluster_info.compute_cluster_world_size, ) + init_method = f"tcp://{cluster_info.storage_cluster_master_ip}:{torch_process_port}" + logger.info( + f"Initializing storage node process group {storage_rank} / {cluster_info.num_storage_nodes} with backend {storage_world_backend} on {init_method}" + ) + torch.distributed.init_process_group( + backend=storage_world_backend, + world_size=cluster_info.num_storage_nodes, + rank=storage_rank, + init_method=init_method, + ) + logger.info( f"Waiting for storage node {storage_rank} / {cluster_info.num_storage_nodes} to exit" ) - glt.distributed.wait_and_shutdown_server() + wait_and_shutdown_server() logger.info(f"Storage node {storage_rank} exited") @@ -59,7 +68,7 @@ def storage_node_process( storage_rank: int, cluster_info: 
GraphStoreInfo, task_config_uri: Uri, - is_inference: bool, + is_inference: bool = True, tf_record_uri_pattern: str = ".*-of-.*\.tfrecord(\.gz)?$", storage_world_backend: Optional[str] = None, ) -> None: @@ -71,7 +80,7 @@ def storage_node_process( storage_rank (int): The rank of the storage node. cluster_info (GraphStoreInfo): The cluster information. task_config_uri (Uri): The task config URI. - is_inference (bool): Whether the process is an inference process. + is_inference (bool): Whether the process is an inference process. Defaults to True. tf_record_uri_pattern (str): The TF Record URI pattern. storage_world_backend (Optional[str]): The backend for the storage Torch Distributed process group. """ @@ -95,6 +104,7 @@ def storage_node_process( _tfrecord_uri_pattern=tf_record_uri_pattern, ) torch_process_port = get_free_ports_from_master_node(num_ports=1)[0] + torch.distributed.destroy_process_group() server_processes = [] mp_context = torch.multiprocessing.get_context("spawn") # TODO(kmonte): Enable more than one server process per machine @@ -120,18 +130,21 @@ def storage_node_process( parser = argparse.ArgumentParser() parser.add_argument("--task_config_uri", type=str, required=True) parser.add_argument("--resource_config_uri", type=str, required=True) - parser.add_argument("--is_inference", action="store_true") + parser.add_argument("--job_name", type=str, required=True) args = parser.parse_args() logger.info(f"Running storage node with arguments: {args}") is_inference = args.is_inference - torch.distributed.init_process_group() + torch.distributed.init_process_group(backend="gloo") cluster_info = get_graph_store_info() + logger.info(f"Cluster info: {cluster_info}") + logger.info( + f"World size: {torch.distributed.get_world_size()}, rank: {torch.distributed.get_rank()}, OS world size: {os.environ['WORLD_SIZE']}, OS rank: {os.environ['RANK']}" + ) # Tear down the """"global""" process group so we can have a server-specific process group. torch.distributed.destroy_process_group() storage_node_process( storage_rank=cluster_info.storage_node_rank, cluster_info=cluster_info, task_config_uri=UriFactory.create_uri(args.task_config_uri), - is_inference=is_inference, ) diff --git a/python/gigl/distributed/utils/networking.py b/python/gigl/distributed/utils/networking.py index 7d2ba46b9..e9d33921c 100644 --- a/python/gigl/distributed/utils/networking.py +++ b/python/gigl/distributed/utils/networking.py @@ -115,6 +115,13 @@ def get_internal_ip_from_master_node( ) -> str: """ Get the internal IP address of the master node in a distributed setup. + + Args: + _global_rank_override (Optional[int]): Override for the global rank, + useful for testing or if global rank is not accurately available. + + Returns: + str: The internal IP address of the master node. """ return get_internal_ip_from_node( node_rank=0, _global_rank_override=_global_rank_override @@ -131,6 +138,12 @@ def get_internal_ip_from_node( i.e. when using :py:obj:`gigl.distributed.dataset_factory` + Args: + node_rank (int): Rank of the node, to fetch the internal IP address of. + device (Optional[torch.device]): Device to use for communication. Defaults to None, which will use the default device. + _global_rank_override (Optional[int]): Override for the global rank, + useful for testing or if global rank is not accurately available. + Returns: str: The internal IP address of the node. 
""" @@ -155,7 +168,8 @@ def get_internal_ip_from_node( # Other nodes will receive the master's IP via broadcast ip_list = [None] - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + device = torch.device("cpu" if not torch.cuda.is_available() else "cuda") + torch.distributed.broadcast_object_list(ip_list, src=node_rank, device=device) node_ip = ip_list[0] logger.info(f"Rank {rank} received master node's internal IP: {node_ip}") @@ -230,12 +244,15 @@ def get_graph_store_info() -> GraphStoreInfo: compute_cluster_master_ip = cluster_master_ip storage_cluster_master_ip = get_internal_ip_from_node(node_rank=num_compute_nodes) - cluster_master_port, compute_cluster_master_port = get_free_ports_from_node( - num_ports=2, node_rank=0 - ) - storage_cluster_master_port = get_free_ports_from_node( - num_ports=1, node_rank=num_compute_nodes - )[0] + ( + cluster_master_port, + compute_cluster_master_port, + ) = get_free_ports_from_node(num_ports=2, node_rank=0) + ( + storage_cluster_master_port, + storage_rpc_port, + storage_rpc_wait_port, + ) = get_free_ports_from_node(num_ports=3, node_rank=num_compute_nodes) num_processes_per_compute = int( os.environ.get(COMPUTE_CLUSTER_LOCAL_WORLD_SIZE_ENV_KEY, "1") @@ -251,6 +268,8 @@ def get_graph_store_info() -> GraphStoreInfo: cluster_master_port=cluster_master_port, storage_cluster_master_port=storage_cluster_master_port, compute_cluster_master_port=compute_cluster_master_port, + rpc_master_port=storage_rpc_port, + rpc_wait_port=storage_rpc_wait_port, ) diff --git a/python/gigl/env/distributed.py b/python/gigl/env/distributed.py index 3c3bc6465..990569059 100644 --- a/python/gigl/env/distributed.py +++ b/python/gigl/env/distributed.py @@ -55,6 +55,13 @@ class GraphStoreInfo: # https://snapchat.github.io/GiGL/docs/api/snapchat/research/gbml/gigl_resource_config_pb2/index.html#snapchat.research.gbml.gigl_resource_config_pb2.VertexAiGraphStoreConfig num_processes_per_compute: int + # Port of the master node for the RPC communication. + # NOTE: This should be on the *storage* master node, not the compute master node. + rpc_master_port: int + # Port of the master node for the RPC wait communication. + # NOTE: This should be on the *storage* master node, not the compute master node. 
+ rpc_wait_port: int + @property def num_cluster_nodes(self) -> int: return self.num_storage_nodes + self.num_compute_nodes diff --git a/python/tests/integration/distributed/graph_store/graph_store_integration_test.py b/python/tests/integration/distributed/graph_store/graph_store_integration_test.py index c267b7fed..5a1cc8369 100644 --- a/python/tests/integration/distributed/graph_store/graph_store_integration_test.py +++ b/python/tests/integration/distributed/graph_store/graph_store_integration_test.py @@ -1,13 +1,16 @@ import collections import os +import socket import unittest from unittest import mock import torch import torch.multiprocessing as mp +from torch_geometric.data import Data from gigl.common import Uri from gigl.common.logger import Logger +from gigl.distributed.distributed_neighborloader import DistNeighborLoader from gigl.distributed.graph_store.compute import ( init_compute_process, shutdown_compute_proccess, @@ -103,7 +106,32 @@ def _run_client_process( mp_sharing_dict=None, ).get_node_ids() _assert_sampler_input(cluster_info, simple_sampler_input, expected_sampler_input) + torch.distributed.barrier() + # Test the DistNeighborLoader + loader = DistNeighborLoader( + dataset=remote_dist_dataset, + num_neighbors=[2, 2], + pin_memory_device=torch.device("cpu"), + input_nodes=sampler_input, + num_workers=2, + worker_concurrency=2, + ) + count = 0 + for datum in loader: + assert isinstance(datum, Data) + count += 1 + torch.distributed.barrier() + logger.info(f"Rank {torch.distributed.get_rank()} loaded {count} batches") + # Verify that we sampled all nodes. + count_tensor = torch.tensor(count, dtype=torch.int64) + all_node_count = 0 + for rank_expected_sampler_input in expected_sampler_input.values(): + all_node_count += sum(len(nodes) for nodes in rank_expected_sampler_input) + torch.distributed.all_reduce(count_tensor, op=torch.distributed.ReduceOp.SUM) + assert ( + count_tensor.item() == all_node_count + ), f"Expected {all_node_count} total nodes, got {count_tensor.item()}" shutdown_compute_proccess() @@ -176,6 +204,7 @@ def _get_expected_input_nodes_by_rank( ], } + Args: num_nodes (int): The number of nodes in the graph. cluster_info (GraphStoreInfo): The cluster information. 
@@ -212,17 +241,22 @@ def test_graph_store_locally(self): storage_cluster_master_port, compute_cluster_master_port, master_port, - ) = get_free_ports(num_ports=4) + rpc_master_port, + rpc_wait_port, + ) = get_free_ports(num_ports=6) + host_ip = socket.gethostbyname(socket.gethostname()) cluster_info = GraphStoreInfo( num_storage_nodes=2, num_compute_nodes=2, num_processes_per_compute=2, - cluster_master_ip="localhost", - storage_cluster_master_ip="localhost", - compute_cluster_master_ip="localhost", + cluster_master_ip=host_ip, + storage_cluster_master_ip=host_ip, + compute_cluster_master_ip=host_ip, cluster_master_port=cluster_master_port, storage_cluster_master_port=storage_cluster_master_port, compute_cluster_master_port=compute_cluster_master_port, + rpc_master_port=rpc_master_port, + rpc_wait_port=rpc_wait_port, ) num_cora_nodes = 2708 @@ -236,7 +270,7 @@ def test_graph_store_locally(self): with mock.patch.dict( os.environ, { - "MASTER_ADDR": "localhost", + "MASTER_ADDR": host_ip, "MASTER_PORT": str(master_port), "RANK": str(i), "WORLD_SIZE": str(cluster_info.compute_cluster_world_size), @@ -262,7 +296,7 @@ def test_graph_store_locally(self): with mock.patch.dict( os.environ, { - "MASTER_ADDR": "localhost", + "MASTER_ADDR": host_ip, "MASTER_PORT": str(master_port), "RANK": str(i + cluster_info.num_compute_nodes), "WORLD_SIZE": str(cluster_info.compute_cluster_world_size), diff --git a/python/tests/unit/env/distributed_test.py b/python/tests/unit/env/distributed_test.py index 793ffdc42..d1ba90391 100644 --- a/python/tests/unit/env/distributed_test.py +++ b/python/tests/unit/env/distributed_test.py @@ -27,6 +27,8 @@ def setUp(self) -> None: cluster_master_port=1234, storage_cluster_master_port=1235, compute_cluster_master_port=1236, + rpc_master_port=1237, + rpc_wait_port=1238, ) def test_num_cluster_nodes(self): From fa0bd9adb269abb4ca483533882f33207d0d55ff Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Sat, 10 Jan 2026 00:19:32 +0000 Subject: [PATCH 2/5] cleanup --- .../distributed/distributed_neighborloader.py | 22 +++++-------------- .../gigl/distributed/graph_store/compute.py | 8 ++++--- .../distributed/graph_store/storage_main.py | 11 +++++----- .../gigl/distributed/utils/neighborloader.py | 19 +++++++++++++++- 4 files changed, 34 insertions(+), 26 deletions(-) diff --git a/python/gigl/distributed/distributed_neighborloader.py b/python/gigl/distributed/distributed_neighborloader.py index fd094d9ad..bcf4c2d5e 100644 --- a/python/gigl/distributed/distributed_neighborloader.py +++ b/python/gigl/distributed/distributed_neighborloader.py @@ -1,6 +1,5 @@ from collections import Counter, abc -from dataclasses import dataclass -from typing import Literal, Optional, Tuple, Union +from typing import Optional, Tuple, Union import torch from graphlearn_torch.channel import SampleMessage @@ -20,6 +19,7 @@ from gigl.distributed.dist_dataset import DistDataset from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset from gigl.distributed.utils.neighborloader import ( + DatasetMetadata, labeled_to_homogeneous, patch_fanout_for_sampling, set_missing_features, @@ -32,7 +32,6 @@ from gigl.types.graph import ( DEFAULT_HOMOGENEOUS_EDGE_TYPE, DEFAULT_HOMOGENEOUS_NODE_TYPE, - FeatureInfo, ) logger = Logger() @@ -41,15 +40,6 @@ DEFAULT_NUM_CPU_THREADS = 2 -# Shared metadata between the local and remote datasets. 
-@dataclass(frozen=True) -class _DatasetMetadata: - is_labeled_heterogeneous: bool - node_feature_info: Optional[Union[FeatureInfo, dict[NodeType, FeatureInfo]]] - edge_feature_info: Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]] - edge_dir: Union[str, Literal["in", "out"]] - - class DistNeighborLoader(DistLoader): def __init__( self, @@ -312,7 +302,7 @@ def _setup_for_graph_store( ], dataset: RemoteDistDataset, num_workers: int, - ) -> tuple[NodeSamplerInput, RemoteDistSamplingWorkerOptions, _DatasetMetadata]: + ) -> tuple[NodeSamplerInput, RemoteDistSamplingWorkerOptions, DatasetMetadata]: if input_nodes is None: raise ValueError( f"When using Graph Store mode, input nodes must be provided, received {input_nodes}" @@ -394,7 +384,7 @@ def _setup_for_graph_store( return ( input_data, worker_options, - _DatasetMetadata( + DatasetMetadata( is_labeled_heterogeneous=is_labeled_heterogeneous, node_feature_info=node_feature_info, edge_feature_info=edge_feature_info, @@ -424,7 +414,7 @@ def _setup_for_colocated( worker_concurrency: int, channel_size: str, num_cpu_threads: Optional[int], - ) -> tuple[NodeSamplerInput, MpDistSamplingWorkerOptions, _DatasetMetadata]: + ) -> tuple[NodeSamplerInput, MpDistSamplingWorkerOptions, DatasetMetadata]: if input_nodes is None: if dataset.node_ids is None: raise ValueError( @@ -540,7 +530,7 @@ def _setup_for_colocated( return ( input_data, worker_options, - _DatasetMetadata( + DatasetMetadata( is_labeled_heterogeneous=is_labeled_heterogeneous, node_feature_info=dataset.node_feature_info, edge_feature_info=dataset.edge_feature_info, diff --git a/python/gigl/distributed/graph_store/compute.py b/python/gigl/distributed/graph_store/compute.py index 6039eaef9..cd556d941 100644 --- a/python/gigl/distributed/graph_store/compute.py +++ b/python/gigl/distributed/graph_store/compute.py @@ -1,8 +1,8 @@ import os from typing import Optional +import graphlearn_torch as glt import torch -from graphlearn_torch.distributed.dist_client import init_client, shutdown_client from gigl.common.logger import Logger from gigl.env.distributed import GraphStoreInfo @@ -42,7 +42,9 @@ def init_compute_process( f" OS rank: {os.environ['RANK']}, local compute rank: {local_rank}" f" num_servers: {cluster_info.num_storage_nodes}, num_clients: {cluster_info.compute_cluster_world_size}" ) - init_client( + # Initialize the GLT client before starting the Torch Distributed process group. + # Otherwise, we saw intermittent hangs when initializing the client. 
+ glt.distributed.init_client( num_servers=cluster_info.num_storage_nodes, num_clients=cluster_info.compute_cluster_world_size, client_rank=compute_cluster_rank, @@ -72,5 +74,5 @@ def shutdown_compute_proccess() -> None: Args: None """ - shutdown_client() + glt.distributed.shutdown_client() torch.distributed.destroy_process_group() diff --git a/python/gigl/distributed/graph_store/storage_main.py b/python/gigl/distributed/graph_store/storage_main.py index 222613af8..3dc8040a5 100644 --- a/python/gigl/distributed/graph_store/storage_main.py +++ b/python/gigl/distributed/graph_store/storage_main.py @@ -7,11 +7,8 @@ import os from typing import Optional +import graphlearn_torch as glt import torch -from graphlearn_torch.distributed.dist_server import ( - init_server, - wait_and_shutdown_server, -) from gigl.common import Uri, UriFactory from gigl.common.logger import Logger @@ -37,7 +34,9 @@ def _run_storage_process( logger.info( f"Initializing GLT server for storage node process group {storage_rank} / {cluster_info.num_storage_nodes} on {cluster_master_ip}:{cluster_info.rpc_master_port}" ) - init_server( + # Initialize the GLT server before starting the Torch Distributed process group. + # Otherwise, we saw intermittent hangs when initializing the server. + glt.distributed.init_server( num_servers=cluster_info.num_storage_nodes, server_rank=storage_rank, dataset=dataset, @@ -60,7 +59,7 @@ def _run_storage_process( logger.info( f"Waiting for storage node {storage_rank} / {cluster_info.num_storage_nodes} to exit" ) - wait_and_shutdown_server() + glt.distributed.wait_and_shutdown_server() logger.info(f"Storage node {storage_rank} exited") diff --git a/python/gigl/distributed/utils/neighborloader.py b/python/gigl/distributed/utils/neighborloader.py index 29a09d71a..9aa587cfd 100644 --- a/python/gigl/distributed/utils/neighborloader.py +++ b/python/gigl/distributed/utils/neighborloader.py @@ -1,7 +1,8 @@ """Utils for Neighbor loaders.""" from collections import abc from copy import deepcopy -from typing import Optional, TypeVar, Union +from dataclasses import dataclass +from typing import Literal, Optional, TypeVar, Union import torch from torch_geometric.data import Data, HeteroData @@ -15,6 +16,22 @@ _GraphType = TypeVar("_GraphType", Data, HeteroData) +@dataclass(frozen=True) +class DatasetMetadata: + """ + Shared metadata between the local and remote datasets. + """ + + # If the dataset is labeled heterogeneous. E.g. one node type, one edge type, and "label" edges. + is_labeled_heterogeneous: bool + # Node feature info. + node_feature_info: Optional[Union[FeatureInfo, dict[NodeType, FeatureInfo]]] + # Edge feature info. + edge_feature_info: Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]] + # Edge direction. 
+ edge_dir: Union[str, Literal["in", "out"]] + + def patch_fanout_for_sampling( edge_types: Optional[list[EdgeType]], num_neighbors: Union[list[int], dict[EdgeType, list[int]]], From 27a0e375d008ccb4ecbcb4672f7437aa61bb9ef8 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 13 Jan 2026 23:40:54 +0000 Subject: [PATCH 3/5] address comments --- .../distributed/distributed_neighborloader.py | 90 +++++++++++++++---- .../gigl/distributed/utils/neighborloader.py | 12 ++- python/gigl/distributed/utils/networking.py | 17 ++-- 3 files changed, 97 insertions(+), 22 deletions(-) diff --git a/python/gigl/distributed/distributed_neighborloader.py b/python/gigl/distributed/distributed_neighborloader.py index bcf4c2d5e..f633e65f5 100644 --- a/python/gigl/distributed/distributed_neighborloader.py +++ b/python/gigl/distributed/distributed_neighborloader.py @@ -19,7 +19,8 @@ from gigl.distributed.dist_dataset import DistDataset from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset from gigl.distributed.utils.neighborloader import ( - DatasetMetadata, + DatasetSchema, + SamplingClusterSetup, labeled_to_homogeneous, patch_fanout_for_sampling, set_missing_features, @@ -73,7 +74,8 @@ def __init__( https://pytorch-geometric.readthedocs.io/en/2.5.2/_modules/torch_geometric/distributed/dist_neighbor_loader.html#DistNeighborLoader Args: - dataset (DistDataset | RemoteDistDataset): The dataset to sample from. Must be a "RemoteDistDataset" if using Graph Store mode. + dataset (DistDataset | RemoteDistDataset): The dataset to sample from. + If this is a `RemoteDistDataset`, then we assumed to be in "Graph Store" mode. num_neighbors (list[int] or dict[Tuple[str, str, str], list[int]]): The number of neighbors to sample for each node in each iteration. If an entry is set to `-1`, all neighbors will be included. @@ -82,7 +84,7 @@ def __init__( context (deprecated - will be removed soon) (DistributedContext): Distributed context information of the current process. local_process_rank (deprecated - will be removed soon) (int): Required if context provided. The local rank of the current process within a node. local_process_world_size (deprecated - will be removed soon)(int): Required if context provided. The total number of processes within a node. - input_nodes (Tenor | Tuple[NodeType, Tenor] | list[Tenor] | Tuple[NodeType, list[Tenor]]): + input_nodes (Tensor | Tuple[NodeType, Tensor] | list[Tensor] | Tuple[NodeType, list[Tensor]]): The nodes to start sampling from. It is of type `torch.LongTensor` for homogeneous graphs. If set to `None` for homogeneous settings, all nodes will be considered. @@ -91,6 +93,8 @@ def __init__( For Graph Store mode, this must be a tuple of (NodeType, list[Tenor]) or list[Tenor]. Where each Tensor in the list is the node ids to sample from, for each server. e.g. [[10, 20], [30, 40]] means sample from nodes 10 and 20 on server 0, and nodes 30 and 40 on server 1. + If a Graph Store input (e.g. list[Tensor]) is provided to colocated mode, or colocated input (e.g. Tensor) is provided to Graph Store mode, + then an error will be raised. num_workers (int): How many workers to use (subprocesses to spwan) for distributed neighbor sampling of the current process. (default: ``1``). batch_size (int, optional): how many samples per batch to load @@ -194,7 +198,11 @@ def __init__( local_process_rank, local_process_world_size, ) # delete deprecated vars so we don't accidentally use them. 
- + if isinstance(dataset, RemoteDistDataset): + sampling_cluster_setup = SamplingClusterSetup.GRAPH_STORE + else: + sampling_cluster_setup = SamplingClusterSetup.COLOCATED + logger.info(f"Sampling cluster setup: {sampling_cluster_setup.value}") device = ( pin_memory_device if pin_memory_device @@ -204,8 +212,10 @@ def __init__( ) # Determines if the node ids passed in are heterogeneous or homogeneous. - self._is_labeled_heterogeneous = False - if isinstance(dataset, DistDataset): + if self._sampling_cluster_setup == SamplingClusterSetup.COLOCATED: + assert isinstance( + dataset, DistDataset + ), "When using colocated mode, dataset must be a DistDataset." input_data, worker_options, dataset_metadata = self._setup_for_colocated( input_nodes, dataset, @@ -221,7 +231,10 @@ def __init__( channel_size, num_cpu_threads, ) - else: # RemoteDistDataset + else: # Graph Store mode + assert isinstance( + dataset, RemoteDistDataset + ), "When using Graph Store mode, dataset must be a RemoteDistDataset." input_data, worker_options, dataset_metadata = self._setup_for_graph_store( input_nodes, dataset, @@ -233,10 +246,10 @@ def __init__( self._edge_feature_info = dataset_metadata.edge_feature_info num_neighbors = patch_fanout_for_sampling( - list(dataset_metadata.edge_feature_info.keys()) + edge_types=list(dataset_metadata.edge_feature_info.keys()) if isinstance(dataset_metadata.edge_feature_info, dict) else None, - num_neighbors, + num_neighbors=num_neighbors, ) sampling_config = SamplingConfig( @@ -259,7 +272,7 @@ def __init__( ) torch.distributed.destroy_process_group() - if isinstance(dataset, DistDataset): + if self._sampling_cluster_setup == SamplingClusterSetup.COLOCATED: super().__init__( dataset if isinstance(dataset, DistDataset) else None, input_data, @@ -271,11 +284,47 @@ def __init__( # For Graph Store mode, we need to start the communcation between compute and storage nodes sequentially, by compute node. # E.g. intialize connections between compute node 0 and storage nodes 0, 1, 2, 3, then compute node 1 and storage nodes 0, 1, 2, 3, etc. # Note that each compute node may have multiple connections to each storage node, once per compute process. - # E.g. if there are 4 gpus per compute node, then there will be 4 connections to each storage node. + # It's important to distinguish "compute node" (e.g. physical compute machine) from "compute process" (e.g. process running on the compute node). + # Since in practice we have multiple compute processes per compute node, and each compute process needs to initialize the connection to the storage nodes. + # E.g. if there are 4 gpus per compute node, then there will be 4 connections from each compute node to each storage node. # We need to this because if we don't, then there is a race condition when initalizing the samplers on the storage nodes [1] # Where since the lock is per *server* (e.g. per storage node), if we try to start one connection from compute node 0, and compute node 1 # Then we deadlock and fail. + # Specifically, the race condition happens in `DistLoader.__init__` when it initializes the sampling producers on the storage nodes. [2] # [1]: https://github.com/alibaba/graphlearn-for-pytorch/blob/main/graphlearn_torch/python/distributed/dist_server.py#L129-L167 + # [2]: https://github.com/alibaba/graphlearn-for-pytorch/blob/88ff111ac0d9e45c6c9d2d18cfc5883dca07e9f9/graphlearn_torch/python/distributed/dist_loader.py#L187-L193 + + # See below for a connection setup. 
+ # ╔═══════════════════════════════════════════════════════════════════════════════════════╗ + # ║ COMPUTE TO STORAGE NODE CONNECTIONS ║ + # ╚═══════════════════════════════════════════════════════════════════════════════════════╝ + + # COMPUTE NODES STORAGE NODES + # ═════════════ ═════════════ + + # ┌──────────────────────┐ (1) ┌───────────────┐ + # │ COMPUTE NODE 0 │ │ │ + # │ ┌────┬────┬────┬────┤ ══════════════════════════════════│ STORAGE 0 │ + # │ │GPU │GPU │GPU │GPU │ ╱ │ │ + # │ │ 0 │ 1 │ 2 │ 3 │ ════════════════════╲ ╱ └───────────────┘ + # │ └────┴────┴────┴────┤ (2) ╲ ╱ + # └──────────────────────┘ ╲ ╱ + # ╳ + # (3) ╱ ╲ (4) + # ┌──────────────────────┐ ╱ ╲ ┌───────────────┐ + # │ COMPUTE NODE 1 │ ╱ ╲ │ │ + # │ ┌────┬────┬────┬────┤ ═════════════════╱ ═│ STORAGE 1 │ + # │ │GPU │GPU │GPU │GPU │ │ │ + # │ │ 0 │ 1 │ 2 │ 3 │ ══════════════════════════════════│ │ + # │ └────┴────┴────┴────┤ └───────────────┘ + # └──────────────────────┘ + + # ┌─────────────────────────────────────────────────────────────────────────────┐ + # │ (1) Compute Node 0 → Storage 0 (4 connections, one per GPU) │ + # │ (2) Compute Node 0 → Storage 1 (4 connections, one per GPU) │ + # │ (3) Compute Node 1 → Storage 0 (4 connections, one per GPU) │ + # │ (4) Compute Node 1 → Storage 1 (4 connections, one per GPU) │ + # └─────────────────────────────────────────────────────────────────────────────┘ node_rank = dataset.cluster_info.compute_node_rank for target_node_rank in range(dataset.cluster_info.num_compute_nodes): if node_rank == target_node_rank: @@ -302,7 +351,7 @@ def _setup_for_graph_store( ], dataset: RemoteDistDataset, num_workers: int, - ) -> tuple[NodeSamplerInput, RemoteDistSamplingWorkerOptions, DatasetMetadata]: + ) -> tuple[NodeSamplerInput, RemoteDistSamplingWorkerOptions, DatasetSchema]: if input_nodes is None: raise ValueError( f"When using Graph Store mode, input nodes must be provided, received {input_nodes}" @@ -367,7 +416,7 @@ def _setup_for_graph_store( len(edge_feature_info) == 1 and DEFAULT_HOMOGENEOUS_EDGE_TYPE in edge_feature_info ): - input_type: NodeType | None = DEFAULT_HOMOGENEOUS_NODE_TYPE + input_type: Optional[NodeType] = DEFAULT_HOMOGENEOUS_NODE_TYPE else: input_type = fallback_input_type elif require_edge_feature_info: @@ -377,6 +426,15 @@ def _setup_for_graph_store( else: input_type = None + if ( + input_type is not None + and isinstance(node_feature_info, dict) + and input_type not in node_feature_info.keys() + ): + raise ValueError( + f"Input type {input_type} is not in node node types: {node_feature_info.keys()}" + ) + input_data = [ NodeSamplerInput(node=node, input_type=input_type) for node in nodes ] @@ -384,7 +442,7 @@ def _setup_for_graph_store( return ( input_data, worker_options, - DatasetMetadata( + DatasetSchema( is_labeled_heterogeneous=is_labeled_heterogeneous, node_feature_info=node_feature_info, edge_feature_info=edge_feature_info, @@ -414,7 +472,7 @@ def _setup_for_colocated( worker_concurrency: int, channel_size: str, num_cpu_threads: Optional[int], - ) -> tuple[NodeSamplerInput, MpDistSamplingWorkerOptions, DatasetMetadata]: + ) -> tuple[NodeSamplerInput, MpDistSamplingWorkerOptions, DatasetSchema]: if input_nodes is None: if dataset.node_ids is None: raise ValueError( @@ -530,7 +588,7 @@ def _setup_for_colocated( return ( input_data, worker_options, - DatasetMetadata( + DatasetSchema( is_labeled_heterogeneous=is_labeled_heterogeneous, node_feature_info=dataset.node_feature_info, edge_feature_info=dataset.edge_feature_info, diff --git 
a/python/gigl/distributed/utils/neighborloader.py b/python/gigl/distributed/utils/neighborloader.py index 9aa587cfd..f7aae1176 100644 --- a/python/gigl/distributed/utils/neighborloader.py +++ b/python/gigl/distributed/utils/neighborloader.py @@ -2,6 +2,7 @@ from collections import abc from copy import deepcopy from dataclasses import dataclass +from enum import Enum from typing import Literal, Optional, TypeVar, Union import torch @@ -16,8 +17,17 @@ _GraphType = TypeVar("_GraphType", Data, HeteroData) +class SamplingClusterSetup(Enum): + """ + The setup of the sampling cluster. + """ + + COLOCATED = "colocated" + GRAPH_STORE = "graph_store" + + @dataclass(frozen=True) -class DatasetMetadata: +class DatasetSchema: """ Shared metadata between the local and remote datasets. """ diff --git a/python/gigl/distributed/utils/networking.py b/python/gigl/distributed/utils/networking.py index e9d33921c..e2e1f320e 100644 --- a/python/gigl/distributed/utils/networking.py +++ b/python/gigl/distributed/utils/networking.py @@ -168,11 +168,12 @@ def get_internal_ip_from_node( # Other nodes will receive the master's IP via broadcast ip_list = [None] - device = torch.device("cpu" if not torch.cuda.is_available() else "cuda") - + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") torch.distributed.broadcast_object_list(ip_list, src=node_rank, device=device) node_ip = ip_list[0] - logger.info(f"Rank {rank} received master node's internal IP: {node_ip}") + logger.info( + f"Rank {rank} received master node's internal IP: {node_ip} on device {device}" + ) assert node_ip is not None, "Could not retrieve master node's internal IP" return node_ip @@ -244,15 +245,21 @@ def get_graph_store_info() -> GraphStoreInfo: compute_cluster_master_ip = cluster_master_ip storage_cluster_master_ip = get_internal_ip_from_node(node_rank=num_compute_nodes) + # Cluster master is by convention rank 0. + cluster_master_rank = 0 ( cluster_master_port, compute_cluster_master_port, - ) = get_free_ports_from_node(num_ports=2, node_rank=0) + ) = get_free_ports_from_node(num_ports=2, node_rank=cluster_master_rank) + + # Since we structure the cluster as [compute0, ..., computeN, storage0, ..., storageN], the storage master is the first storage node. + # And it's rank is the number of compute nodes. + storage_master_rank = num_compute_nodes ( storage_cluster_master_port, storage_rpc_port, storage_rpc_wait_port, - ) = get_free_ports_from_node(num_ports=3, node_rank=num_compute_nodes) + ) = get_free_ports_from_node(num_ports=3, node_rank=storage_master_rank) num_processes_per_compute = int( os.environ.get(COMPUTE_CLUSTER_LOCAL_WORLD_SIZE_ENV_KEY, "1") From 922cfd2facb458288cb54689a4d4df74e7082a18 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Wed, 14 Jan 2026 17:53:33 +0000 Subject: [PATCH 4/5] actually get edge types --- .../distributed/distributed_neighborloader.py | 22 +++++++++++++------ .../graph_store/remote_dist_dataset.py | 12 ++++++++++ .../distributed/graph_store/storage_utils.py | 14 ++++++++++++ .../gigl/distributed/utils/neighborloader.py | 2 ++ 4 files changed, 43 insertions(+), 7 deletions(-) diff --git a/python/gigl/distributed/distributed_neighborloader.py b/python/gigl/distributed/distributed_neighborloader.py index f633e65f5..1706628ae 100644 --- a/python/gigl/distributed/distributed_neighborloader.py +++ b/python/gigl/distributed/distributed_neighborloader.py @@ -199,10 +199,10 @@ def __init__( local_process_world_size, ) # delete deprecated vars so we don't accidentally use them. 
if isinstance(dataset, RemoteDistDataset): - sampling_cluster_setup = SamplingClusterSetup.GRAPH_STORE + self._sampling_cluster_setup = SamplingClusterSetup.GRAPH_STORE else: - sampling_cluster_setup = SamplingClusterSetup.COLOCATED - logger.info(f"Sampling cluster setup: {sampling_cluster_setup.value}") + self._sampling_cluster_setup = SamplingClusterSetup.COLOCATED + logger.info(f"Sampling cluster setup: {self._sampling_cluster_setup.value}") device = ( pin_memory_device if pin_memory_device @@ -245,13 +245,14 @@ def __init__( self._node_feature_info = dataset_metadata.node_feature_info self._edge_feature_info = dataset_metadata.edge_feature_info + logger.info(f"num_neighbors before patch: {num_neighbors}") num_neighbors = patch_fanout_for_sampling( - edge_types=list(dataset_metadata.edge_feature_info.keys()) - if isinstance(dataset_metadata.edge_feature_info, dict) - else None, + edge_types=dataset_metadata.edge_types, num_neighbors=num_neighbors, ) - + logger.info( + f"num_neighbors: {num_neighbors}, edge_types: {dataset_metadata.edge_types}" + ) sampling_config = SamplingConfig( sampling_type=SamplingType.NODE, num_neighbors=num_neighbors, @@ -444,6 +445,7 @@ def _setup_for_graph_store( worker_options, DatasetSchema( is_labeled_heterogeneous=is_labeled_heterogeneous, + edge_types=dataset.get_edge_types(), node_feature_info=node_feature_info, edge_feature_info=edge_feature_info, edge_dir=dataset.get_edge_dir(), @@ -585,11 +587,17 @@ def _setup_for_colocated( pin_memory=device.type == "cuda", ) + if isinstance(dataset.graph, dict): + edge_types = list(dataset.graph.keys()) + else: + edge_types = None + return ( input_data, worker_options, DatasetSchema( is_labeled_heterogeneous=is_labeled_heterogeneous, + edge_types=edge_types, node_feature_info=dataset.node_feature_info, edge_feature_info=dataset.edge_feature_info, edge_dir=dataset.edge_dir, diff --git a/python/gigl/distributed/graph_store/remote_dist_dataset.py b/python/gigl/distributed/graph_store/remote_dist_dataset.py index 18f4177bc..4a534a864 100644 --- a/python/gigl/distributed/graph_store/remote_dist_dataset.py +++ b/python/gigl/distributed/graph_store/remote_dist_dataset.py @@ -9,6 +9,7 @@ from gigl.distributed.graph_store.storage_utils import ( get_edge_dir, get_edge_feature_info, + get_edge_types, get_node_feature_info, get_node_ids_for_rank, ) @@ -213,3 +214,14 @@ def get_free_ports_on_storage_cluster(self, num_ports: int) -> list[int]: torch.distributed.broadcast_object_list(ports, src=0) logger.info(f"Compute rank {compute_cluster_rank} received free ports: {ports}") return ports + + def get_edge_types(self) -> Optional[list[EdgeType]]: + """Get the edge types from the registered dataset. + + Returns: + The edge types. + """ + return request_server( + 0, + get_edge_types, + ) diff --git a/python/gigl/distributed/graph_store/storage_utils.py b/python/gigl/distributed/graph_store/storage_utils.py index ef0b055ec..331eebd7c 100644 --- a/python/gigl/distributed/graph_store/storage_utils.py +++ b/python/gigl/distributed/graph_store/storage_utils.py @@ -141,3 +141,17 @@ def get_node_ids_for_rank( f"Node ids must be a torch.Tensor or a dict[NodeType, torch.Tensor], got {type(_dataset.node_ids)}" ) return shard_nodes_by_process(nodes, rank, world_size) + + +def get_edge_types() -> Optional[list[EdgeType]]: + """Get the edge types from the registered dataset. + + Returns: + The edge types. 
+ """ + if _dataset is None: + raise _NO_DATASET_ERROR + if isinstance(_dataset.graph, dict): + return list(_dataset.graph.keys()) + else: + return None diff --git a/python/gigl/distributed/utils/neighborloader.py b/python/gigl/distributed/utils/neighborloader.py index f7aae1176..54b687c98 100644 --- a/python/gigl/distributed/utils/neighborloader.py +++ b/python/gigl/distributed/utils/neighborloader.py @@ -34,6 +34,8 @@ class DatasetSchema: # If the dataset is labeled heterogeneous. E.g. one node type, one edge type, and "label" edges. is_labeled_heterogeneous: bool + # List of all edge types in the graph. + edge_types: Optional[list[EdgeType]] # Node feature info. node_feature_info: Optional[Union[FeatureInfo, dict[NodeType, FeatureInfo]]] # Edge feature info. From 94b521f242bc7b3353365c8f4f1568f7f38c9f54 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Thu, 15 Jan 2026 21:10:04 +0000 Subject: [PATCH 5/5] address comments --- .../distributed/distributed_neighborloader.py | 29 +++++-------------- 1 file changed, 7 insertions(+), 22 deletions(-) diff --git a/python/gigl/distributed/distributed_neighborloader.py b/python/gigl/distributed/distributed_neighborloader.py index 1706628ae..1125f1f9f 100644 --- a/python/gigl/distributed/distributed_neighborloader.py +++ b/python/gigl/distributed/distributed_neighborloader.py @@ -275,7 +275,7 @@ def __init__( if self._sampling_cluster_setup == SamplingClusterSetup.COLOCATED: super().__init__( - dataset if isinstance(dataset, DistDataset) else None, + dataset, # Pass in the dataset for colocated mode. input_data, sampling_config, device, @@ -330,7 +330,7 @@ def __init__( for target_node_rank in range(dataset.cluster_info.num_compute_nodes): if node_rank == target_node_rank: super().__init__( - dataset if isinstance(dataset, DistDataset) else None, + None, # Pass in None for Graph Store mode. input_data, sampling_config, device, @@ -371,6 +371,7 @@ def _setup_for_graph_store( is_labeled_heterogeneous = False node_feature_info = dataset.get_node_feature_info() edge_feature_info = dataset.get_edge_feature_info() + edge_types = dataset.get_edge_types() node_rank = dataset.cluster_info.compute_node_rank # Get sampling ports for compute-storage connections. @@ -408,34 +409,18 @@ def _setup_for_graph_store( ) # Determine input_type based on edge_feature_info - if isinstance(edge_feature_info, dict): - if len(edge_feature_info) == 0: - raise ValueError( - "When using Graph Store mode, edge feature info must be provided for heterogeneous graphs." - ) - elif ( - len(edge_feature_info) == 1 - and DEFAULT_HOMOGENEOUS_EDGE_TYPE in edge_feature_info - ): + if isinstance(edge_types, list): + if edge_types == [DEFAULT_HOMOGENEOUS_EDGE_TYPE]: input_type: Optional[NodeType] = DEFAULT_HOMOGENEOUS_NODE_TYPE else: input_type = fallback_input_type elif require_edge_feature_info: raise ValueError( - "When using Graph Store mode, edge feature info must be provided for heterogeneous graphs." + "When using Graph Store mode, edge types must be provided for heterogeneous graphs." 
) else: input_type = None - if ( - input_type is not None - and isinstance(node_feature_info, dict) - and input_type not in node_feature_info.keys() - ): - raise ValueError( - f"Input type {input_type} is not in node node types: {node_feature_info.keys()}" - ) - input_data = [ NodeSamplerInput(node=node, input_type=input_type) for node in nodes ] @@ -445,7 +430,7 @@ def _setup_for_graph_store( worker_options, DatasetSchema( is_labeled_heterogeneous=is_labeled_heterogeneous, - edge_types=dataset.get_edge_types(), + edge_types=edge_types, node_feature_info=node_feature_info, edge_feature_info=edge_feature_info, edge_dir=dataset.get_edge_dir(),
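For reference, the sketch below shows how a compute process might drive the Graph Store sampling path end to end, modeled on `_run_client_process` in the integration test above. The helper name `run_compute_process`, the positional call to `init_compute_process`, and the `cluster_info`, `local_rank`, `remote_dist_dataset`, and `sampler_input` values are assumptions for illustration only.

import torch
from torch_geometric.data import Data

from gigl.distributed.distributed_neighborloader import DistNeighborLoader
from gigl.distributed.graph_store.compute import (
    init_compute_process,
    shutdown_compute_proccess,
)


def run_compute_process(
    cluster_info, local_rank, remote_dist_dataset, sampler_input
) -> int:
    # Join the compute/storage RPC mesh before building any loaders
    # (assumed call signature; see init_compute_process above).
    init_compute_process(cluster_info, local_rank)

    # Passing a RemoteDistDataset puts the loader in Graph Store mode, so
    # input_nodes must be a list of seed-node tensors, one per storage server.
    loader = DistNeighborLoader(
        dataset=remote_dist_dataset,
        num_neighbors=[2, 2],
        input_nodes=sampler_input,
        pin_memory_device=torch.device("cpu"),
        num_workers=2,
        worker_concurrency=2,
    )

    num_batches = 0
    for batch in loader:
        # Homogeneous graphs come back as torch_geometric.data.Data batches.
        assert isinstance(batch, Data)
        num_batches += 1

    shutdown_compute_proccess()
    return num_batches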