diff --git a/python/gigl/distributed/dist_partitioner.py b/python/gigl/distributed/dist_partitioner.py
index bf1e8260b..7f86f18df 100644
--- a/python/gigl/distributed/dist_partitioner.py
+++ b/python/gigl/distributed/dist_partitioner.py
@@ -368,6 +368,12 @@ def register_node_ids(
 
         self._assert_and_get_rpc_setup()
 
+        # Check if node data has already been registered
+        if self._node_ids is not None:
+            raise ValueError(
+                "Node IDs have already been registered. Cannot re-register node data."
+            )
+
         logger.info("Registering Nodes ...")
         input_node_ids = self._convert_node_entity_to_heterogeneous_format(
             input_node_entity=node_ids
@@ -431,6 +437,12 @@ def register_edge_index(
 
         self._assert_and_get_rpc_setup()
 
+        # Check if edge data has already been registered
+        if self._edge_index is not None:
+            raise ValueError(
+                "Edge indices have already been registered. Cannot re-register edge data."
+            )
+
         logger.info("Registering Edge Indices ...")
 
         input_edge_index = self._convert_edge_entity_to_heterogeneous_format(
@@ -507,6 +519,12 @@ def register_node_features(
 
         self._assert_and_get_rpc_setup()
 
+        # Check if node features have already been registered
+        if self._node_feat is not None:
+            raise ValueError(
+                "Node features have already been registered. Cannot re-register node feature data."
+            )
+
         logger.info("Registering Node Features ...")
 
         input_node_features = self._convert_node_entity_to_heterogeneous_format(
@@ -546,6 +564,12 @@ def register_node_labels(
 
         self._assert_and_get_rpc_setup()
 
+        # Check if node labels have already been registered
+        if self._node_labels is not None:
+            raise ValueError(
+                "Node labels have already been registered. Cannot re-register node label data."
+            )
+
         logger.info("Registering Node Labels ...")
 
         input_node_labels = self._convert_node_entity_to_heterogeneous_format(
@@ -578,6 +602,12 @@ def register_edge_features(
 
         self._assert_and_get_rpc_setup()
 
+        # Check if edge features have already been registered
+        if self._edge_feat is not None:
+            raise ValueError(
+                "Edge features have already been registered. Cannot re-register edge feature data."
+            )
+
         logger.info("Registering Edge Features ...")
 
         input_edge_features = self._convert_edge_entity_to_heterogeneous_format(
@@ -613,6 +643,18 @@ def register_labels(
 
         self._assert_and_get_rpc_setup()
 
+        # Check if labels have already been registered
+        if is_positive:
+            if self._positive_label_edge_index is not None:
+                raise ValueError(
+                    "Positive labels have already been registered. Cannot re-register positive label data."
+                )
+        else:
+            if self._negative_label_edge_index is not None:
+                raise ValueError(
+                    "Negative labels have already been registered. Cannot re-register negative label data."
+                )
+
         input_label_edge_index = self._convert_edge_entity_to_heterogeneous_format(
             input_edge_entity=label_edge_index
         )
diff --git a/python/tests/unit/distributed/distributed_partitioner_test.py b/python/tests/unit/distributed/distributed_partitioner_test.py
index 8b9f3a908..62ce3074a 100644
--- a/python/tests/unit/distributed/distributed_partitioner_test.py
+++ b/python/tests/unit/distributed/distributed_partitioner_test.py
@@ -1153,6 +1153,232 @@ def test_partitioning_invalid_node_ids(
         with self.assertRaises(ValueError):
             partitioner.partition_node_features_and_labels(node_pb)
 
+    def test_node_ids_re_registration(self) -> None:
+        """Test that re-registering node IDs raises an error."""
+        master_port = glt.utils.get_free_port(self._master_ip_address)
+
+        init_worker_group(world_size=1, rank=0, group_name=get_process_group_name(0))
+        init_rpc(
+            master_addr=self._master_ip_address,
+            master_port=master_port,
+            num_rpc_threads=4,
+        )
+
+        partitioner = DistPartitioner(should_assign_edges_by_src_node=True)
+
+        # First registration should work
+        node_ids = torch.tensor([0, 1, 2])
+        partitioner.register_node_ids(node_ids=node_ids)
+
+        # Second registration should raise an error
+        with self.assertRaisesRegex(
+            ValueError, "Node IDs have already been registered"
+        ):
+            partitioner.register_node_ids(node_ids=node_ids)
+
+    def test_edge_index_re_registration(self) -> None:
+        """Test that re-registering edge indices raises an error."""
+        master_port = glt.utils.get_free_port(self._master_ip_address)
+
+        init_worker_group(world_size=1, rank=0, group_name=get_process_group_name(0))
+        init_rpc(
+            master_addr=self._master_ip_address,
+            master_port=master_port,
+            num_rpc_threads=4,
+        )
+
+        partitioner = DistPartitioner(should_assign_edges_by_src_node=True)
+
+        # In order to set the _is_input_homogeneous flag to True
+        partitioner.register_node_ids(torch.tensor([0, 1, 2]))
+
+        # First registration should work
+        edge_index = torch.tensor([[0, 1], [1, 2]])
+        partitioner.register_edge_index(edge_index=edge_index)
+
+        # Second registration should raise an error
+        with self.assertRaisesRegex(
+            ValueError, "Edge indices have already been registered"
+        ):
+            partitioner.register_edge_index(edge_index=edge_index)
+
+    def test_node_features_re_registration(self) -> None:
+        """Test that re-registering node features raises an error."""
+        master_port = glt.utils.get_free_port(self._master_ip_address)
+
+        init_worker_group(world_size=1, rank=0, group_name=get_process_group_name(0))
+        init_rpc(
+            master_addr=self._master_ip_address,
+            master_port=master_port,
+            num_rpc_threads=4,
+        )
+
+        partitioner = DistPartitioner(should_assign_edges_by_src_node=True)
+
+        # First registration should work
+        node_features = torch.ones(3, 5)
+        partitioner.register_node_features(node_features=node_features)
+
+        # Second registration should raise an error
+        with self.assertRaisesRegex(
+            ValueError, "Node features have already been registered"
+        ):
+            partitioner.register_node_features(node_features=node_features)
+
+    def test_node_labels_re_registration(self) -> None:
+        """Test that re-registering node labels raises an error."""
+        master_port = glt.utils.get_free_port(self._master_ip_address)
+
+        init_worker_group(world_size=1, rank=0, group_name=get_process_group_name(0))
+        init_rpc(
+            master_addr=self._master_ip_address,
+            master_port=master_port,
+            num_rpc_threads=4,
+        )
+
+        partitioner = DistPartitioner(should_assign_edges_by_src_node=True)
+
+        # First registration should work
+        node_labels = torch.tensor([[0, 1], [1, 0], [0, 1]])
+        partitioner.register_node_labels(node_labels=node_labels)
+
+        # Second registration should raise an error
+        with self.assertRaisesRegex(
+            ValueError, "Node labels have already been registered"
+        ):
+            partitioner.register_node_labels(node_labels=node_labels)
+
+    def test_edge_features_re_registration(self) -> None:
+        """Test that re-registering edge features raises an error."""
+        master_port = glt.utils.get_free_port(self._master_ip_address)
+
+        init_worker_group(world_size=1, rank=0, group_name=get_process_group_name(0))
+        init_rpc(
+            master_addr=self._master_ip_address,
+            master_port=master_port,
+            num_rpc_threads=4,
+        )
+
+        partitioner = DistPartitioner(should_assign_edges_by_src_node=True)
+
+        # In order to set the _is_input_homogeneous flag to True
+        partitioner.register_node_features(torch.ones(3, 5))
+
+        # First registration should work
+        edge_features = torch.ones(2, 10)
+        partitioner.register_edge_features(edge_features=edge_features)
+
+        # Second registration should raise an error
+        with self.assertRaisesRegex(
+            ValueError, "Edge features have already been registered"
+        ):
+            partitioner.register_edge_features(edge_features=edge_features)
+
+    def test_positive_labels_re_registration(self) -> None:
+        """Test that re-registering labels raises an error."""
+        master_port = glt.utils.get_free_port(self._master_ip_address)
+
+        init_worker_group(world_size=1, rank=0, group_name=get_process_group_name(0))
+        init_rpc(
+            master_addr=self._master_ip_address,
+            master_port=master_port,
+            num_rpc_threads=4,
+        )
+
+        # Positive labels test
+        partitioner = DistPartitioner(should_assign_edges_by_src_node=True)
+
+        # In order to set the _is_input_homogeneous flag to True
+        partitioner.register_node_ids(torch.tensor([0, 1, 2]))
+
+        # First registration should work
+        pos_labels = torch.tensor([[0, 1], [1, 2]])
+        partitioner.register_labels(label_edge_index=pos_labels, is_positive=True)
+
+        # # Second registration should raise an error
+        with self.assertRaisesRegex(
+            ValueError, "Positive labels have already been registered"
+        ):
+            partitioner.register_labels(label_edge_index=pos_labels, is_positive=True)
+
+    def test_negative_labels_re_registration(self) -> None:
+        """Test that re-registering labels raises an error."""
+        master_port = glt.utils.get_free_port(self._master_ip_address)
+
+        init_worker_group(world_size=1, rank=0, group_name=get_process_group_name(0))
+        init_rpc(
+            master_addr=self._master_ip_address,
+            master_port=master_port,
+            num_rpc_threads=4,
+        )
+
+        # Negative labels test
+        partitioner = DistPartitioner(should_assign_edges_by_src_node=True)
+
+        # In order to set the _is_input_homogeneous flag to True
+        partitioner.register_node_ids(torch.tensor([0, 1, 2]))
+
+        neg_labels = torch.tensor([[0, 1], [1, 2]])
+        partitioner.register_labels(label_edge_index=neg_labels, is_positive=False)
+
+        # Second registration should raise an error
+        with self.assertRaisesRegex(
+            ValueError, "Negative labels have already been registered"
+        ):
+            partitioner.register_labels(label_edge_index=neg_labels, is_positive=False)
+
+    def test_heterogeneous_re_registration(self) -> None:
+        """Test re-registration prevention for heterogeneous data."""
+        master_port = glt.utils.get_free_port(self._master_ip_address)
+
+        init_worker_group(world_size=1, rank=0, group_name=get_process_group_name(0))
+        init_rpc(
+            master_addr=self._master_ip_address,
+            master_port=master_port,
+            num_rpc_threads=4,
+        )
+
+        partitioner = DistPartitioner(should_assign_edges_by_src_node=True)
+
+        # Heterogeneous node IDs test
+        node_ids = {
+            USER_NODE_TYPE: torch.tensor([0, 1, 2]),
+            ITEM_NODE_TYPE: torch.tensor([0, 1, 2]),
+        }
+        partitioner.register_node_ids(node_ids=node_ids)
+
+        # Second registration should raise an error
+        with self.assertRaisesRegex(
+            ValueError, "Node IDs have already been registered"
+        ):
+            partitioner.register_node_ids(node_ids=node_ids)
+
+        # Heterogeneous edge indices test
+        partitioner2 = DistPartitioner(should_assign_edges_by_src_node=True)
+
+        edge_index = {USER_TO_USER_EDGE_TYPE: torch.tensor([[0, 1], [1, 2]])}
+        partitioner2.register_edge_index(edge_index=edge_index)
+
+        # Second registration should raise an error
+        with self.assertRaisesRegex(
+            ValueError, "Edge indices have already been registered"
+        ):
+            partitioner2.register_edge_index(edge_index=edge_index)
+
+        # Heterogeneous node labels test
+        partitioner3 = DistPartitioner(should_assign_edges_by_src_node=True)
+        node_labels = {
+            USER_NODE_TYPE: torch.tensor([[0, 1], [1, 0]]),
+            ITEM_NODE_TYPE: torch.tensor([[1, 0], [0, 1], [1, 1]]),
+        }
+        partitioner3.register_node_labels(node_labels=node_labels)
+
+        # Second registration should raise an error
+        with self.assertRaisesRegex(
+            ValueError, "Node labels have already been registered"
+        ):
+            partitioner3.register_node_labels(node_labels=node_labels)
+
 
 if __name__ == "__main__":
     unittest.main()