Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions python/gigl/distributed/dist_partitioner.py
Comment thread
swong3-sc marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,12 @@ def register_node_ids(

self._assert_and_get_rpc_setup()

# Check if node data has already been registered
if self._node_ids is not None:
raise ValueError(
"Node IDs have already been registered. Cannot re-register node data."
)

logger.info("Registering Nodes ...")
input_node_ids = self._convert_node_entity_to_heterogeneous_format(
input_node_entity=node_ids
Expand Down Expand Up @@ -431,6 +437,12 @@ def register_edge_index(

self._assert_and_get_rpc_setup()

# Check if edge data has already been registered
if self._edge_index is not None:
raise ValueError(
"Edge indices have already been registered. Cannot re-register edge data."
)

logger.info("Registering Edge Indices ...")

input_edge_index = self._convert_edge_entity_to_heterogeneous_format(
Expand Down Expand Up @@ -507,6 +519,12 @@ def register_node_features(

self._assert_and_get_rpc_setup()

# Check if node features have already been registered
if self._node_feat is not None:
raise ValueError(
"Node features have already been registered. Cannot re-register node feature data."
)

logger.info("Registering Node Features ...")

input_node_features = self._convert_node_entity_to_heterogeneous_format(
Expand Down Expand Up @@ -546,6 +564,12 @@ def register_node_labels(

self._assert_and_get_rpc_setup()

# Check if node labels have already been registered
if self._node_labels is not None:
raise ValueError(
"Node labels have already been registered. Cannot re-register node label data."
)

logger.info("Registering Node Labels ...")

input_node_labels = self._convert_node_entity_to_heterogeneous_format(
Expand Down Expand Up @@ -578,6 +602,12 @@ def register_edge_features(

self._assert_and_get_rpc_setup()

# Check if edge features have already been registered
if self._edge_feat is not None:
raise ValueError(
"Edge features have already been registered. Cannot re-register edge feature data."
)

logger.info("Registering Edge Features ...")

input_edge_features = self._convert_edge_entity_to_heterogeneous_format(
Expand Down Expand Up @@ -613,6 +643,18 @@ def register_labels(

self._assert_and_get_rpc_setup()

# Check if labels have already been registered
if is_positive:
if self._positive_label_edge_index is not None:
raise ValueError(
"Positive labels have already been registered. Cannot re-register positive label data."
)
else:
if self._negative_label_edge_index is not None:
raise ValueError(
"Negative labels have already been registered. Cannot re-register negative label data."
)

input_label_edge_index = self._convert_edge_entity_to_heterogeneous_format(
input_edge_entity=label_edge_index
)
Expand Down
226 changes: 226 additions & 0 deletions python/tests/unit/distributed/distributed_partitioner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1153,6 +1153,232 @@ def test_partitioning_invalid_node_ids(
with self.assertRaises(ValueError):
partitioner.partition_node_features_and_labels(node_pb)

def test_node_ids_re_registration(self) -> None:
"""Test that re-registering node IDs raises an error."""
master_port = glt.utils.get_free_port(self._master_ip_address)

init_worker_group(world_size=1, rank=0, group_name=get_process_group_name(0))
init_rpc(
master_addr=self._master_ip_address,
master_port=master_port,
num_rpc_threads=4,
)

partitioner = DistPartitioner(should_assign_edges_by_src_node=True)

# First registration should work
node_ids = torch.tensor([0, 1, 2])
partitioner.register_node_ids(node_ids=node_ids)

# Second registration should raise an error
with self.assertRaisesRegex(
ValueError, "Node IDs have already been registered"
):
partitioner.register_node_ids(node_ids=node_ids)

def test_edge_index_re_registration(self) -> None:
"""Test that re-registering edge indices raises an error."""
master_port = glt.utils.get_free_port(self._master_ip_address)

init_worker_group(world_size=1, rank=0, group_name=get_process_group_name(0))
init_rpc(
master_addr=self._master_ip_address,
master_port=master_port,
num_rpc_threads=4,
)

partitioner = DistPartitioner(should_assign_edges_by_src_node=True)

# In order to set the _is_input_homogeneous flag to True
partitioner.register_node_ids(torch.tensor([0, 1, 2]))

# First registration should work
edge_index = torch.tensor([[0, 1], [1, 2]])
partitioner.register_edge_index(edge_index=edge_index)

# Second registration should raise an error
with self.assertRaisesRegex(
ValueError, "Edge indices have already been registered"
):
partitioner.register_edge_index(edge_index=edge_index)

def test_node_features_re_registration(self) -> None:
"""Test that re-registering node features raises an error."""
master_port = glt.utils.get_free_port(self._master_ip_address)

init_worker_group(world_size=1, rank=0, group_name=get_process_group_name(0))
init_rpc(
master_addr=self._master_ip_address,
master_port=master_port,
num_rpc_threads=4,
)

partitioner = DistPartitioner(should_assign_edges_by_src_node=True)

# First registration should work
node_features = torch.ones(3, 5)
partitioner.register_node_features(node_features=node_features)

# Second registration should raise an error
with self.assertRaisesRegex(
ValueError, "Node features have already been registered"
):
partitioner.register_node_features(node_features=node_features)

def test_node_labels_re_registration(self) -> None:
"""Test that re-registering node labels raises an error."""
master_port = glt.utils.get_free_port(self._master_ip_address)

init_worker_group(world_size=1, rank=0, group_name=get_process_group_name(0))
init_rpc(
master_addr=self._master_ip_address,
master_port=master_port,
num_rpc_threads=4,
)

partitioner = DistPartitioner(should_assign_edges_by_src_node=True)

# First registration should work
node_labels = torch.tensor([[0, 1], [1, 0], [0, 1]])
partitioner.register_node_labels(node_labels=node_labels)

# Second registration should raise an error
with self.assertRaisesRegex(
ValueError, "Node labels have already been registered"
):
partitioner.register_node_labels(node_labels=node_labels)

def test_edge_features_re_registration(self) -> None:
"""Test that re-registering edge features raises an error."""
master_port = glt.utils.get_free_port(self._master_ip_address)

init_worker_group(world_size=1, rank=0, group_name=get_process_group_name(0))
init_rpc(
master_addr=self._master_ip_address,
master_port=master_port,
num_rpc_threads=4,
)

partitioner = DistPartitioner(should_assign_edges_by_src_node=True)

# In order to set the _is_input_homogeneous flag to True
partitioner.register_node_features(torch.ones(3, 5))

# First registration should work
edge_features = torch.ones(2, 10)
partitioner.register_edge_features(edge_features=edge_features)

# Second registration should raise an error
with self.assertRaisesRegex(
ValueError, "Edge features have already been registered"
):
partitioner.register_edge_features(edge_features=edge_features)

def test_positive_labels_re_registration(self) -> None:
"""Test that re-registering labels raises an error."""
master_port = glt.utils.get_free_port(self._master_ip_address)

init_worker_group(world_size=1, rank=0, group_name=get_process_group_name(0))
init_rpc(
master_addr=self._master_ip_address,
master_port=master_port,
num_rpc_threads=4,
)

# Positive labels test
partitioner = DistPartitioner(should_assign_edges_by_src_node=True)

# In order to set the _is_input_homogeneous flag to True
partitioner.register_node_ids(torch.tensor([0, 1, 2]))

# First registration should work
pos_labels = torch.tensor([[0, 1], [1, 2]])
partitioner.register_labels(label_edge_index=pos_labels, is_positive=True)

# # Second registration should raise an error
with self.assertRaisesRegex(
ValueError, "Positive labels have already been registered"
):
partitioner.register_labels(label_edge_index=pos_labels, is_positive=True)

def test_negative_labels_re_registration(self) -> None:
"""Test that re-registering labels raises an error."""
master_port = glt.utils.get_free_port(self._master_ip_address)

init_worker_group(world_size=1, rank=0, group_name=get_process_group_name(0))
init_rpc(
master_addr=self._master_ip_address,
master_port=master_port,
num_rpc_threads=4,
)

# Negative labels test
partitioner = DistPartitioner(should_assign_edges_by_src_node=True)

# In order to set the _is_input_homogeneous flag to True
partitioner.register_node_ids(torch.tensor([0, 1, 2]))

neg_labels = torch.tensor([[0, 1], [1, 2]])
partitioner.register_labels(label_edge_index=neg_labels, is_positive=False)

# Second registration should raise an error
with self.assertRaisesRegex(
ValueError, "Negative labels have already been registered"
):
partitioner.register_labels(label_edge_index=neg_labels, is_positive=False)

def test_heterogeneous_re_registration(self) -> None:
"""Test re-registration prevention for heterogeneous data."""
master_port = glt.utils.get_free_port(self._master_ip_address)

init_worker_group(world_size=1, rank=0, group_name=get_process_group_name(0))
init_rpc(
master_addr=self._master_ip_address,
master_port=master_port,
num_rpc_threads=4,
)

partitioner = DistPartitioner(should_assign_edges_by_src_node=True)

# Heterogeneous node IDs test
node_ids = {
USER_NODE_TYPE: torch.tensor([0, 1, 2]),
ITEM_NODE_TYPE: torch.tensor([0, 1, 2]),
}
partitioner.register_node_ids(node_ids=node_ids)

# Second registration should raise an error
with self.assertRaisesRegex(
ValueError, "Node IDs have already been registered"
):
partitioner.register_node_ids(node_ids=node_ids)

# Heterogeneous edge indices test
partitioner2 = DistPartitioner(should_assign_edges_by_src_node=True)

edge_index = {USER_TO_USER_EDGE_TYPE: torch.tensor([[0, 1], [1, 2]])}
partitioner2.register_edge_index(edge_index=edge_index)

# Second registration should raise an error
with self.assertRaisesRegex(
ValueError, "Edge indices have already been registered"
):
partitioner2.register_edge_index(edge_index=edge_index)

# Heterogeneous node labels test
partitioner3 = DistPartitioner(should_assign_edges_by_src_node=True)
node_labels = {
USER_NODE_TYPE: torch.tensor([[0, 1], [1, 0]]),
ITEM_NODE_TYPE: torch.tensor([[1, 0], [0, 1], [1, 1]]),
}
partitioner3.register_node_labels(node_labels=node_labels)

# Second registration should raise an error
with self.assertRaisesRegex(
ValueError, "Node labels have already been registered"
):
partitioner3.register_node_labels(node_labels=node_labels)


if __name__ == "__main__":
unittest.main()