340 changes: 301 additions & 39 deletions python/gigl/distributed/distributed_neighborloader.py

Large diffs are not rendered by default.

30 changes: 17 additions & 13 deletions python/gigl/distributed/graph_store/compute.py
@@ -36,30 +36,34 @@ def init_compute_process(
cluster_info.compute_node_rank * cluster_info.num_processes_per_compute
+ local_rank
)
cluster_master_ip = cluster_info.storage_cluster_master_ip
logger.info(
f"Initializing compute process group {compute_cluster_rank} / {cluster_info.compute_cluster_world_size}. on {cluster_info.compute_cluster_master_ip}:{cluster_info.compute_cluster_master_port} with backend {compute_world_backend}."
f" OS rank: {os.environ['RANK']}, local client rank: {local_rank}"
)
torch.distributed.init_process_group(
backend=compute_world_backend,
world_size=cluster_info.compute_cluster_world_size,
rank=compute_cluster_rank,
init_method=f"tcp://{cluster_info.compute_cluster_master_ip}:{cluster_info.compute_cluster_master_port}",
)
logger.info(
f"Initializing RPC client for compute node {compute_cluster_rank} / {cluster_info.compute_cluster_world_size} on {cluster_info.cluster_master_ip}:{cluster_info.cluster_master_port}."
f"Initializing RPC client for compute node {compute_cluster_rank} / {cluster_info.compute_cluster_world_size} on {cluster_master_ip}:{cluster_info.rpc_master_port}."
f" OS rank: {os.environ['RANK']}, local compute rank: {local_rank}"
f" num_servers: {cluster_info.num_storage_nodes}, num_clients: {cluster_info.compute_cluster_world_size}"
)
# Initialize the GLT client before starting the Torch Distributed process group.
# Otherwise, we saw intermittent hangs when initializing the client.
glt.distributed.init_client(
num_servers=cluster_info.num_storage_nodes,
num_clients=cluster_info.compute_cluster_world_size,
client_rank=compute_cluster_rank,
master_addr=cluster_info.cluster_master_ip,
master_port=cluster_info.cluster_master_port,
master_addr=cluster_master_ip,
master_port=cluster_info.rpc_master_port,
client_group_name="gigl_client_rpc",
)

logger.info(
f"Initializing compute process group {compute_cluster_rank} / {cluster_info.compute_cluster_world_size}. on {cluster_info.compute_cluster_master_ip}:{cluster_info.compute_cluster_master_port} with backend {compute_world_backend}."
f" OS rank: {os.environ['RANK']}, local client rank: {local_rank}"
)
torch.distributed.init_process_group(
backend=compute_world_backend,
world_size=cluster_info.compute_cluster_world_size,
rank=compute_cluster_rank,
init_method=f"tcp://{cluster_info.compute_cluster_master_ip}:{cluster_info.compute_cluster_master_port}",
)


def shutdown_compute_proccess() -> None:
"""
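For context on the reordering above: the GLT RPC client must come up before the compute-side Torch Distributed process group, otherwise client init can hang intermittently. A minimal sketch of the resulting call order, assuming the `glt.distributed.init_client` signature used in this diff; all addresses, ports, and world sizes are placeholders:

```python
# Sketch only: the call *order* is the point. Placeholder values throughout.
import graphlearn_torch as glt
import torch

def init_compute_side(client_rank: int) -> None:
    # 1) Bring up the GLT RPC client first, pointed at the *storage* master.
    glt.distributed.init_client(
        num_servers=2,            # number of storage nodes (placeholder)
        num_clients=4,            # compute-cluster world size (placeholder)
        client_rank=client_rank,
        master_addr="10.0.0.10",  # storage_cluster_master_ip (placeholder)
        master_port=29500,        # rpc_master_port (placeholder)
        client_group_name="gigl_client_rpc",
    )
    # 2) Only then create the compute-side Torch Distributed process group.
    torch.distributed.init_process_group(
        backend="gloo",
        world_size=4,             # compute-cluster world size (placeholder)
        rank=client_rank,
        init_method="tcp://10.0.0.2:29400",  # compute master (placeholder)
    )
```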
40 changes: 26 additions & 14 deletions python/gigl/distributed/graph_store/storage_main.py
@@ -30,24 +30,32 @@ def _run_storage_process(
storage_world_backend: Optional[str],
) -> None:
register_dataset(dataset)
cluster_master_ip = cluster_info.storage_cluster_master_ip
logger.info(
f"Initializing storage node {storage_rank} / {cluster_info.num_storage_nodes} with backend {storage_world_backend} on {cluster_info.cluster_master_ip}:{torch_process_port}"
)
torch.distributed.init_process_group(
backend=storage_world_backend,
world_size=cluster_info.num_storage_nodes,
rank=storage_rank,
init_method=f"tcp://{cluster_info.cluster_master_ip}:{torch_process_port}",
f"Initializing GLT server for storage node process group {storage_rank} / {cluster_info.num_storage_nodes} on {cluster_master_ip}:{cluster_info.rpc_master_port}"
)
# Initialize the GLT server before starting the Torch Distributed process group.
# Otherwise, we saw intermittent hangs when initializing the server.
glt.distributed.init_server(
num_servers=cluster_info.num_storage_nodes,
server_rank=storage_rank,
dataset=dataset,
master_addr=cluster_info.cluster_master_ip,
master_port=cluster_info.cluster_master_port,
master_addr=cluster_master_ip,
master_port=cluster_info.rpc_master_port,
num_clients=cluster_info.compute_cluster_world_size,
)

init_method = f"tcp://{cluster_info.storage_cluster_master_ip}:{torch_process_port}"
logger.info(
f"Initializing storage node process group {storage_rank} / {cluster_info.num_storage_nodes} with backend {storage_world_backend} on {init_method}"
)
torch.distributed.init_process_group(
backend=storage_world_backend,
world_size=cluster_info.num_storage_nodes,
rank=storage_rank,
init_method=init_method,
)

logger.info(
f"Waiting for storage node {storage_rank} / {cluster_info.num_storage_nodes} to exit"
)
@@ -59,7 +67,7 @@ def storage_node_process(
storage_rank: int,
cluster_info: GraphStoreInfo,
task_config_uri: Uri,
is_inference: bool,
is_inference: bool = True,
tf_record_uri_pattern: str = ".*-of-.*\.tfrecord(\.gz)?$",
storage_world_backend: Optional[str] = None,
) -> None:
@@ -71,7 +79,7 @@
storage_rank (int): The rank of the storage node.
cluster_info (GraphStoreInfo): The cluster information.
task_config_uri (Uri): The task config URI.
is_inference (bool): Whether the process is an inference process.
is_inference (bool): Whether the process is an inference process. Defaults to True.
tf_record_uri_pattern (str): The TF Record URI pattern.
storage_world_backend (Optional[str]): The backend for the storage Torch Distributed process group.
"""
@@ -95,6 +103,7 @@ def storage_node_process(
_tfrecord_uri_pattern=tf_record_uri_pattern,
)
torch_process_port = get_free_ports_from_master_node(num_ports=1)[0]
torch.distributed.destroy_process_group()
server_processes = []
mp_context = torch.multiprocessing.get_context("spawn")
# TODO(kmonte): Enable more than one server process per machine
@@ -120,18 +129,21 @@
parser = argparse.ArgumentParser()
parser.add_argument("--task_config_uri", type=str, required=True)
parser.add_argument("--resource_config_uri", type=str, required=True)
parser.add_argument("--is_inference", action="store_true")
parser.add_argument("--job_name", type=str, required=True)
args = parser.parse_args()
logger.info(f"Running storage node with arguments: {args}")

is_inference = args.is_inference
torch.distributed.init_process_group()
torch.distributed.init_process_group(backend="gloo")
cluster_info = get_graph_store_info()
logger.info(f"Cluster info: {cluster_info}")
logger.info(
f"World size: {torch.distributed.get_world_size()}, rank: {torch.distributed.get_rank()}, OS world size: {os.environ['WORLD_SIZE']}, OS rank: {os.environ['RANK']}"
)
# Tear down the """"global""" process group so we can have a server-specific process group.
torch.distributed.destroy_process_group()
storage_node_process(
storage_rank=cluster_info.storage_node_rank,
cluster_info=cluster_info,
task_config_uri=UriFactory.create_uri(args.task_config_uri),
is_inference=is_inference,
)
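
The entrypoint above follows a bootstrap-then-split pattern: a throwaway "global" gloo group is created via the launcher's env vars so cluster metadata can be exchanged, then destroyed before the role-specific group comes up (a process can hold only one default process group at a time). A minimal sketch of that pattern, assuming the launcher sets `MASTER_ADDR`/`MASTER_PORT`/`RANK`/`WORLD_SIZE`; arguments are placeholders:

```python
import torch

def bootstrap_then_split(storage_rank: int, num_storage_nodes: int,
                         storage_master_ip: str, port: int) -> None:
    # 1) Throwaway "global" group via env:// rendezvous, used only to
    #    exchange cluster metadata (IPs, free ports) across all nodes.
    torch.distributed.init_process_group(backend="gloo")
    # ... collectives such as broadcast_object_list run here ...
    torch.distributed.destroy_process_group()

    # 2) Role-specific group (here: storage nodes only), with an explicit
    #    tcp:// init_method so its rendezvous is independent of the env vars.
    torch.distributed.init_process_group(
        backend="gloo",
        world_size=num_storage_nodes,
        rank=storage_rank,
        init_method=f"tcp://{storage_master_ip}:{port}",
    )
```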
31 changes: 30 additions & 1 deletion python/gigl/distributed/utils/neighborloader.py
@@ -1,7 +1,9 @@
"""Utils for Neighbor loaders."""
from collections import abc
from copy import deepcopy
from typing import Optional, TypeVar, Union
from dataclasses import dataclass
from enum import Enum
from typing import Literal, Optional, TypeVar, Union

import torch
from torch_geometric.data import Data, HeteroData
@@ -15,6 +17,33 @@
_GraphType = TypeVar("_GraphType", Data, HeteroData)


class SamplingClusterSetup(Enum):
"""
The setup of the sampling cluster.
"""

COLOCATED = "colocated"
GRAPH_STORE = "graph_store"


@dataclass(frozen=True)
class DatasetSchema:
"""
Shared metadata between the local and remote datasets.
"""

# Whether the dataset is labeled heterogeneous, e.g. one node type, one edge type, and "label" edges.
is_labeled_heterogeneous: bool
# List of all edge types in the graph.
edge_types: Optional[list[EdgeType]]
# Node feature info.
node_feature_info: Optional[Union[FeatureInfo, dict[NodeType, FeatureInfo]]]
# Edge feature info.
edge_feature_info: Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]]
# Edge direction.
edge_dir: Union[str, Literal["in", "out"]]


def patch_fanout_for_sampling(
edge_types: Optional[list[EdgeType]],
num_neighbors: Union[list[int], dict[EdgeType, list[int]]],
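A hypothetical construction of the new `DatasetSchema` for a plain homogeneous graph, based on the field comments above; the feature-info fields are left as `None` so nothing is assumed about `FeatureInfo`'s constructor, and the import path assumes both classes are exported from the module shown:

```python
from gigl.distributed.utils.neighborloader import DatasetSchema, SamplingClusterSetup

# Homogeneous graph: no typed edges, no feature metadata.
schema = DatasetSchema(
    is_labeled_heterogeneous=False,
    edge_types=None,
    node_feature_info=None,
    edge_feature_info=None,
    edge_dir="out",
)
setup = SamplingClusterSetup.GRAPH_STORE  # remote storage; COLOCATED keeps data local
```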
40 changes: 33 additions & 7 deletions python/gigl/distributed/utils/networking.py
@@ -115,6 +115,13 @@ def get_internal_ip_from_master_node(
) -> str:
"""
Get the internal IP address of the master node in a distributed setup.

Args:
_global_rank_override (Optional[int]): Override for the global rank,
useful for testing or if global rank is not accurately available.

Returns:
str: The internal IP address of the master node.
"""
return get_internal_ip_from_node(
node_rank=0, _global_rank_override=_global_rank_override
@@ -131,6 +138,12 @@ def get_internal_ip_from_node(

i.e. when using :py:obj:`gigl.distributed.dataset_factory`

Args:
node_rank (int): Rank of the node, to fetch the internal IP address of.
device (Optional[torch.device]): Device to use for communication. Defaults to None, which will use the default device.
_global_rank_override (Optional[int]): Override for the global rank,
useful for testing or if global rank is not accurately available.

Returns:
str: The internal IP address of the node.
"""
@@ -158,7 +171,9 @@ def get_internal_ip_from_node(
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.distributed.broadcast_object_list(ip_list, src=node_rank, device=device)
node_ip = ip_list[0]
logger.info(f"Rank {rank} received master node's internal IP: {node_ip}")
logger.info(
f"Rank {rank} received master node's internal IP: {node_ip} on device {device}"
)
assert node_ip is not None, "Could not retrieve master node's internal IP"
return node_ip
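
The IP exchange in `get_internal_ip_from_node` rides on `torch.distributed.broadcast_object_list`: every rank participates in the collective, but only the payload from `src` survives. A minimal sketch of that pattern, assuming an already-initialized process group:

```python
import socket
import torch

def broadcast_ip_from(node_rank: int) -> str:
    """Every rank calls this; only node_rank's IP survives the broadcast."""
    me = torch.distributed.get_rank()
    # Only the source rank fills in its own IP; the rest send a placeholder.
    ip_list = [socket.gethostbyname(socket.gethostname()) if me == node_rank else None]
    torch.distributed.broadcast_object_list(ip_list, src=node_rank)
    assert ip_list[0] is not None, "broadcast did not deliver an IP"
    return ip_list[0]
```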

@@ -230,12 +245,21 @@ def get_graph_store_info() -> GraphStoreInfo:
compute_cluster_master_ip = cluster_master_ip
storage_cluster_master_ip = get_internal_ip_from_node(node_rank=num_compute_nodes)

cluster_master_port, compute_cluster_master_port = get_free_ports_from_node(
num_ports=2, node_rank=0
)
storage_cluster_master_port = get_free_ports_from_node(
num_ports=1, node_rank=num_compute_nodes
)[0]
# Cluster master is by convention rank 0.
cluster_master_rank = 0
(
cluster_master_port,
compute_cluster_master_port,
) = get_free_ports_from_node(num_ports=2, node_rank=cluster_master_rank)

# Since we structure the cluster as [compute0, ..., computeN, storage0, ..., storageN], the storage master is the first storage node,
# and its rank is the number of compute nodes.
storage_master_rank = num_compute_nodes
(
storage_cluster_master_port,
storage_rpc_port,
storage_rpc_wait_port,
) = get_free_ports_from_node(num_ports=3, node_rank=storage_master_rank)

num_processes_per_compute = int(
os.environ.get(COMPUTE_CLUSTER_LOCAL_WORLD_SIZE_ENV_KEY, "1")
@@ -251,6 +275,8 @@
cluster_master_port=cluster_master_port,
storage_cluster_master_port=storage_cluster_master_port,
compute_cluster_master_port=compute_cluster_master_port,
rpc_master_port=storage_rpc_port,
rpc_wait_port=storage_rpc_wait_port,
)


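The port bookkeeping above leans on the cluster layout convention `[compute_0, ..., compute_{C-1}, storage_0, ..., storage_{S-1}]`: global rank 0 doubles as cluster and compute master, and the storage master sits at rank `num_compute_nodes`. A small sketch of that arithmetic, with placeholder sizes:

```python
# Rank-layout arithmetic implied above (sizes are placeholders).
num_compute_nodes = 2
num_storage_nodes = 2

cluster_master_rank = 0                   # rank 0 = first compute node
storage_master_rank = num_compute_nodes   # first storage node follows the compute block
storage_ranks = list(range(num_compute_nodes, num_compute_nodes + num_storage_nodes))
assert storage_master_rank == storage_ranks[0]
```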
7 changes: 7 additions & 0 deletions python/gigl/env/distributed.py
@@ -55,6 +55,13 @@ class GraphStoreInfo:
# https://snapchat.github.io/GiGL/docs/api/snapchat/research/gbml/gigl_resource_config_pb2/index.html#snapchat.research.gbml.gigl_resource_config_pb2.VertexAiGraphStoreConfig
num_processes_per_compute: int

# Port of the master node for the RPC communication.
# NOTE: This should be on the *storage* master node, not the compute master node.
rpc_master_port: int
# Port of the master node for the RPC wait communication.
# NOTE: This should be on the *storage* master node, not the compute master node.
rpc_wait_port: int

@property
def num_cluster_nodes(self) -> int:
return self.num_storage_nodes + self.num_compute_nodes
@@ -1,14 +1,17 @@
import collections
import os
import socket
import unittest
from typing import Optional
from unittest import mock

import torch
import torch.multiprocessing as mp
from torch_geometric.data import Data

from gigl.common import Uri
from gigl.common.logger import Logger
from gigl.distributed.distributed_neighborloader import DistNeighborLoader
from gigl.distributed.graph_store.compute import (
init_compute_process,
shutdown_compute_proccess,
@@ -108,10 +111,37 @@ def _run_client_process(
).get_node_ids()
_assert_sampler_input(cluster_info, simple_sampler_input, expected_sampler_input)

# Check that the edge types are correct
assert (
remote_dist_dataset.get_edge_types() == expected_edge_types
), f"Expected edge types {expected_edge_types}, got {remote_dist_dataset.get_edge_types()}"

torch.distributed.barrier()

# Test the DistNeighborLoader
loader = DistNeighborLoader(
dataset=remote_dist_dataset,
num_neighbors=[2, 2],
pin_memory_device=torch.device("cpu"),
input_nodes=sampler_input,
num_workers=2,
worker_concurrency=2,
)
count = 0
for datum in loader:
assert isinstance(datum, Data)
count += 1
torch.distributed.barrier()
logger.info(f"Rank {torch.distributed.get_rank()} loaded {count} batches")
# Verify that we sampled all nodes.
count_tensor = torch.tensor(count, dtype=torch.int64)
all_node_count = 0
for rank_expected_sampler_input in expected_sampler_input.values():
all_node_count += sum(len(nodes) for nodes in rank_expected_sampler_input)
torch.distributed.all_reduce(count_tensor, op=torch.distributed.ReduceOp.SUM)
assert (
count_tensor.item() == all_node_count
), f"Expected {all_node_count} total nodes, got {count_tensor.item()}"
shutdown_compute_proccess()


@@ -186,6 +216,7 @@ def _get_expected_input_nodes_by_rank(
],
}


Args:
num_nodes (int): The number of nodes in the graph.
cluster_info (GraphStoreInfo): The cluster information.
@@ -220,17 +251,22 @@ def test_graph_store_locally(self):
storage_cluster_master_port,
compute_cluster_master_port,
master_port,
) = get_free_ports(num_ports=4)
rpc_master_port,
rpc_wait_port,
) = get_free_ports(num_ports=6)
host_ip = socket.gethostbyname(socket.gethostname())
cluster_info = GraphStoreInfo(
num_storage_nodes=2,
num_compute_nodes=2,
num_processes_per_compute=2,
cluster_master_ip="localhost",
storage_cluster_master_ip="localhost",
compute_cluster_master_ip="localhost",
cluster_master_ip=host_ip,
storage_cluster_master_ip=host_ip,
compute_cluster_master_ip=host_ip,
cluster_master_port=cluster_master_port,
storage_cluster_master_port=storage_cluster_master_port,
compute_cluster_master_port=compute_cluster_master_port,
rpc_master_port=rpc_master_port,
rpc_wait_port=rpc_wait_port,
)

num_cora_nodes = 2708
@@ -244,7 +280,7 @@
with mock.patch.dict(
os.environ,
{
"MASTER_ADDR": "localhost",
"MASTER_ADDR": host_ip,
"MASTER_PORT": str(master_port),
"RANK": str(i),
"WORLD_SIZE": str(cluster_info.compute_cluster_world_size),
@@ -271,7 +307,7 @@
with mock.patch.dict(
os.environ,
{
"MASTER_ADDR": "localhost",
"MASTER_ADDR": host_ip,
"MASTER_PORT": str(master_port),
"RANK": str(i + cluster_info.num_compute_nodes),
"WORLD_SIZE": str(cluster_info.compute_cluster_world_size),
2 changes: 2 additions & 0 deletions python/tests/unit/env/distributed_test.py
@@ -27,6 +27,8 @@ def setUp(self) -> None:
cluster_master_port=1234,
storage_cluster_master_port=1235,
compute_cluster_master_port=1236,
rpc_master_port=1237,
rpc_wait_port=1238,
)

def test_num_cluster_nodes(self):