Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions python/gigl/distributed/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
__all__ = [
"get_available_device",
"get_free_ports_from_master_node",
"get_free_ports_from_node",
"get_free_port",
"get_internal_ip_from_all_ranks",
"get_internal_ip_from_master_node",
"get_internal_ip_from_node",
"get_process_group_name",
"init_neighbor_loader_worker",
]
Expand All @@ -20,6 +22,8 @@
from .networking import (
get_free_port,
get_free_ports_from_master_node,
get_free_ports_from_node,
get_internal_ip_from_all_ranks,
get_internal_ip_from_master_node,
get_internal_ip_from_node,
)
65 changes: 47 additions & 18 deletions python/gigl/distributed/utils/networking.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,21 +44,39 @@ def get_free_ports(num_ports: int) -> list[int]:


def get_free_ports_from_master_node(
Comment thread
kmontemayor2-sc marked this conversation as resolved.
num_ports=1, _global_rank_override: Optional[int] = None
num_ports: int, _global_rank_override: Optional[int] = None
) -> list[int]:
"""
Get free ports from master node, that can be used for communication between workers.
Args:
num_ports (int): Number of free ports to find.
_global_rank_override (Optional[int]): Override for the global rank,
useful for testing or if global rank is not accurately available.
"""
return get_free_ports_from_node(
num_ports, node_rank=0, _global_rank_override=_global_rank_override
)


def get_free_ports_from_node(
num_ports: int,
node_rank: int,
_global_rank_override: Optional[int] = None,
) -> list[int]:
"""
Get free ports from a node, that can be used for communication between workers.
Args:
num_ports (int): Number of free ports to find.
node_rank (int): Rank of the node, default is 0.
_global_rank_override (Optional[int]): Override for the global rank,
useful for testing or if global rank is not accurately available.
Returns:
list[int]: A list of free port numbers on the master node.
list[int]: A list of free port numbers on the node.
"""
# Ensure that the distributed environment is initialized
assert (
torch.distributed.is_initialized()
), "Distributed environment must be initialized to communicate free ports on master"
), "Distributed environment must be initialized to communicate free ports on a node"
assert num_ports >= 1, "num_ports must be >= 1"

rank = (
Expand All @@ -67,17 +85,16 @@ def get_free_ports_from_master_node(
else _global_rank_override
)
logger.info(
f"Rank {rank} is requesting {num_ports} free ports from rank 0 (master)"
f"Rank {rank} is requesting {num_ports} free ports from rank {node_rank} (node)"
)
ports: list[int]
if rank == 0:
if rank == node_rank:
ports = get_free_ports(num_ports)
logger.info(f"Rank {rank} found free ports: {ports}")
else:
ports = [0] * num_ports

# Broadcast from master from rank 0 to all other ranks
torch.distributed.broadcast_object_list(ports, src=0)
torch.distributed.broadcast_object_list(ports, src=node_rank)
logger.info(f"Rank {rank} received ports: {ports}")
return ports

Expand All @@ -87,12 +104,24 @@ def get_internal_ip_from_master_node(
) -> str:
"""
Get the internal IP address of the master node in a distributed setup.
"""
return get_internal_ip_from_node(
node_rank=0, _global_rank_override=_global_rank_override
)


def get_internal_ip_from_node(
node_rank: int,
_global_rank_override: Optional[int] = None,
) -> str:
"""
Get the internal IP address of the node in a distributed setup.
This is useful for setting up RPC communication between workers where the default torch.distributed env:// setup is not enough.

i.e. when using :py:obj:`gigl.distributed.dataset_factory`

Returns:
str: The internal IP address of the master node.
str: The internal IP address of the node.
"""
assert (
torch.distributed.is_initialized()
Expand All @@ -104,23 +133,23 @@ def get_internal_ip_from_master_node(
else _global_rank_override
)
logger.info(
f"Rank {rank} is requesting internal ip address of master node from rank 0 (master)"
f"Rank {rank} is requesting internal ip address of node from rank {node_rank}"
)

master_ip_list: list[Optional[str]] = []
if rank == 0:
ip_list: list[Optional[str]] = []
if rank == node_rank:
# Master node, return its own internal IP
master_ip_list = [socket.gethostbyname(socket.gethostname())]
ip_list = [socket.gethostbyname(socket.gethostname())]
else:
# Other nodes will receive the master's IP via broadcast
master_ip_list = [None]
ip_list = [None]

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.distributed.broadcast_object_list(master_ip_list, src=0, device=device)
master_ip = master_ip_list[0]
logger.info(f"Rank {rank} received master internal IP: {master_ip}")
assert master_ip is not None, "Could not retrieve master node's internal IP"
return master_ip
torch.distributed.broadcast_object_list(ip_list, src=node_rank, device=device)
node_ip = ip_list[0]
logger.info(f"Rank {rank} received master internal IP: {node_ip}")
assert node_ip is not None, "Could not retrieve master node's internal IP"
return node_ip


def get_internal_ip_from_all_ranks() -> list[str]:
Expand Down
163 changes: 158 additions & 5 deletions python/tests/unit/distributed/utils/networking_test.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
import subprocess
import unittest
from unittest.mock import patch

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from parameterized import param, parameterized

from gigl.distributed.utils import (
get_free_port,
get_free_ports_from_master_node,
get_free_ports_from_node,
get_internal_ip_from_master_node,
get_internal_ip_from_node,
)
from tests.test_assets.distributed.utils import get_process_group_init_method


def _test_fetching_free_ports_in_dist_context(
Expand Down Expand Up @@ -50,6 +53,61 @@ def _test_fetching_free_ports_in_dist_context(
dist.destroy_process_group()


def _test_fetching_free_ports_from_node(
    rank: int,
    world_size: int,
    init_process_group_init_method: str,
    num_ports: int,
    master_node_rank: int,
    ports: list[int],
):
    """Spawned per-rank entrypoint verifying ``get_free_ports_from_node``.

    Checks that only ``master_node_rank`` allocates ports (the allocator is
    mocked, and raises on every other rank) and that all ranks receive the
    same broadcast port list.

    Args:
        rank (int): Global rank of this spawned process.
        world_size (int): Total number of spawned processes.
        init_process_group_init_method (str): init_method URL for the gloo group.
        num_ports (int): Number of ports each rank should end up with.
        master_node_rank (int): Rank expected to allocate and broadcast the ports.
        ports (list[int]): Canned ports the mocked allocator returns on the master.
    """
    # Initialize distributed process group
    dist.init_process_group(
        backend="gloo",
        init_method=init_process_group_init_method,
        world_size=world_size,
        rank=rank,
    )
    try:
        if rank == master_node_rank:
            # The designated node should call the allocator; stub it so the
            # broadcast payload is deterministic.
            with patch(
                "gigl.distributed.utils.networking.get_free_ports", return_value=ports
            ):
                free_ports: list[int] = get_free_ports_from_node(
                    num_ports=num_ports, node_rank=master_node_rank
                )
        else:
            # Non-master ranks must never allocate ports themselves.
            with patch(
                "gigl.distributed.utils.networking.get_free_ports",
                side_effect=Exception("Should not be called on non-master node"),
            ):
                free_ports = get_free_ports_from_node(
                    num_ports=num_ports, node_rank=master_node_rank
                )
        assert len(free_ports) == num_ports
        # Check that all ranks see the same ports broadcasted from
        # master_node_rank. all_gather_object overwrites the placeholder
        # entries with the gathered python objects, so plain None placeholders
        # suffice (tensor placeholders are not needed for object collectives).
        gathered_ports_across_ranks = [None] * world_size
        dist.all_gather_object(gathered_ports_across_ranks, free_ports)
        assert (
            len(gathered_ports_across_ranks) == world_size
        ), f"Expected {world_size} ports, but got {len(gathered_ports_across_ranks)}"
        ports_gathered_at_rank_0 = gathered_ports_across_ranks[0]
        assert (
            len(ports_gathered_at_rank_0) == num_ports
        ), "returned number of ports to match requested number of ports"
        assert all(
            port >= 0 for port in ports_gathered_at_rank_0
        ), "All ports should be non-negative integers"
        assert all(
            ports_gathered_at_rank_k == ports_gathered_at_rank_0
            for ports_gathered_at_rank_k in gathered_ports_across_ranks
        ), f"All ranks should receive the same ports from rank {master_node_rank} (master node)"
    finally:
        dist.destroy_process_group()


def _test_get_internal_ip_from_master_node_in_dist_context(
rank: int, world_size: int, init_process_group_init_method: str, expected_ip: str
):
Expand All @@ -72,6 +130,39 @@ def _test_get_internal_ip_from_master_node_in_dist_context(
dist.destroy_process_group()


def _test_get_internal_ip_from_node(
    rank: int,
    world_size: int,
    init_process_group_init_method: str,
    expected_ip: str,
    master_node_rank: int,
):
    """Spawned per-rank entrypoint verifying ``get_internal_ip_from_node``.

    Ensures hostname resolution happens only on ``master_node_rank`` (it is
    patched to raise everywhere else) and that every rank observes
    ``expected_ip`` after the broadcast.
    """
    # Initialize distributed process group
    dist.init_process_group(
        backend="gloo",
        init_method=init_process_group_init_method,
        world_size=world_size,
        rank=rank,
    )
    print(
        f"Rank {rank} initialized process group with init method: {init_process_group_init_method}"
    )
    try:
        if rank != master_node_rank:
            # Any resolution attempt on a non-master rank is a bug; make it
            # fail loudly via the mocked resolver.
            with patch(
                "gigl.distributed.utils.networking.socket.gethostbyname",
                side_effect=Exception("Should not be called on non-master node"),
            ):
                master_ip = get_internal_ip_from_node(node_rank=master_node_rank)
        else:
            master_ip = get_internal_ip_from_node(node_rank=master_node_rank)
        assert (
            master_ip == expected_ip
        ), f"Expected master IP to be {expected_ip}, but got {master_ip}"
    finally:
        dist.destroy_process_group()


class TestDistributedNetworkingUtils(unittest.TestCase):
def tearDown(self):
if dist.is_initialized():
Expand Down Expand Up @@ -102,14 +193,47 @@ def tearDown(self):
def test_get_free_ports_from_master_node_two_ranks(
self, _name, num_ports, world_size
):
port = get_free_port()
init_process_group_init_method = f"tcp://127.0.0.1:{port}"
init_process_group_init_method = get_process_group_init_method()
mp.spawn(
fn=_test_fetching_free_ports_in_dist_context,
args=(world_size, init_process_group_init_method, num_ports),
nprocs=world_size,
)

@parameterized.expand(
[
param(
"Test fetching 2 ports for world_size = 2 with master_node_rank = 0",
num_ports=2,
world_size=2,
master_node_rank=0,
ports=[1, 2],
),
param(
"Test fetching 2 ports for world_size = 2 with master_node_rank = 1",
num_ports=2,
world_size=2,
master_node_rank=1,
ports=[3, 4],
),
]
)
def test_get_free_ports_from_master_node_two_ranks_custom_master_node_rank(
self, _name, num_ports, world_size, master_node_rank, ports
):
init_process_group_init_method = get_process_group_init_method()
mp.spawn(
fn=_test_fetching_free_ports_from_node,
args=(
world_size,
init_process_group_init_method,
num_ports,
master_node_rank,
ports,
),
nprocs=world_size,
)

def test_get_free_ports_from_master_fails_if_process_group_not_initialized(self):
with self.assertRaises(
AssertionError,
Expand All @@ -118,8 +242,7 @@ def test_get_free_ports_from_master_fails_if_process_group_not_initialized(self)
get_free_ports_from_master_node(num_ports=1)

def test_get_internal_ip_from_master_node(self):
port = get_free_port()
init_process_group_init_method = f"tcp://127.0.0.1:{port}"
init_process_group_init_method = get_process_group_init_method()
expected_host_ip = subprocess.check_output(["hostname", "-i"]).decode().strip()
world_size = 2
mp.spawn(
Expand All @@ -128,6 +251,36 @@ def test_get_internal_ip_from_master_node(self):
nprocs=world_size,
)

@parameterized.expand(
[
param(
"Getting internal IP from master node with master_node_rank = 0",
world_size=2,
master_node_rank=0,
),
param(
"Getting internal IP from master node with master_node_rank = 1",
world_size=2,
master_node_rank=1,
),
]
)
def test_get_internal_ip_from_master_node_with_master_node_rank(
self, _, world_size, master_node_rank
):
init_process_group_init_method = get_process_group_init_method()
expected_host_ip = subprocess.check_output(["hostname", "-i"]).decode().strip()
mp.spawn(
fn=_test_get_internal_ip_from_node,
args=(
world_size,
init_process_group_init_method,
expected_host_ip,
master_node_rank,
),
nprocs=world_size,
)

def test_get_internal_ip_from_master_node_fails_if_process_group_not_initialized(
self,
):
Expand Down