Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions python/gigl/distributed/dist_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1462,8 +1462,8 @@ def partition_edge_index_and_edge_features(
],
Tuple[
dict[EdgeType, GraphPartitionData],
dict[EdgeType, FeaturePartitionData],
dict[EdgeType, PartitionBook],
Optional[dict[EdgeType, FeaturePartitionData]],
Optional[dict[EdgeType, PartitionBook]],
],
]:
"""
Expand All @@ -1474,8 +1474,10 @@ def partition_edge_index_and_edge_features(
Returns:
Union[
Tuple[GraphPartitionData, FeaturePartitionData, PartitionBook],
Tuple[dict[EdgeType, GraphPartitionData], dict[EdgeType, FeaturePartitionData], dict[EdgeType, PartitionBook]],
]: Partitioned Graph Data, Feature Data, and corresponding edge partition book, is a dictionary if heterogeneous
Tuple[dict[EdgeType, GraphPartitionData], Optional[dict[EdgeType, FeaturePartitionData]], Optional[dict[EdgeType, PartitionBook]]],
]: Partitioned Graph Data, Feature Data, and corresponding edge partition book, is a dictionary if heterogeneous.
The second and third elements of this tuple are only present if there are edge features to partition, and are None
otherwise.
"""

self._assert_and_get_rpc_setup()
Expand Down Expand Up @@ -1545,25 +1547,26 @@ def partition_edge_index_and_edge_features(
for edge_type, num_edges in self._num_edges.items()
}

# If partitioned_edge_features or edge_partition_book is empty, we return None. This is becauuse we assert
# that any registered edge feature is non-empty, so if we encounter an empty dictionary in this case, this means
# we never registered edge features and we can safely return None here.
if self._is_input_homogeneous:
logger.info(
f"Partitioned {to_homogeneous(formatted_num_edges)} edges for homogeneous dataset"
)
return (
to_homogeneous(partitioned_edge_index),
to_homogeneous(partitioned_edge_features)
if len(partitioned_edge_features) > 0
else None,
to_homogeneous(edge_partition_book)
if len(edge_partition_book) > 0
if partitioned_edge_features
else None,
to_homogeneous(edge_partition_book) if edge_partition_book else None,
)
else:
logger.info(f"Partitioned {self._num_edges} edges per edge type")
return (
partitioned_edge_index,
partitioned_edge_features,
edge_partition_book,
partitioned_edge_features if partitioned_edge_features else None,
edge_partition_book if edge_partition_book else None,
Comment thread
mkolodner-sc marked this conversation as resolved.
)

def partition_labels(
Expand Down
16 changes: 7 additions & 9 deletions python/gigl/distributed/dist_range_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,8 +353,8 @@ def partition_edge_index_and_edge_features(
],
tuple[
dict[EdgeType, GraphPartitionData],
dict[EdgeType, FeaturePartitionData],
dict[EdgeType, PartitionBook],
Optional[dict[EdgeType, FeaturePartitionData]],
Optional[dict[EdgeType, PartitionBook]],
],
]:
"""
Expand All @@ -368,7 +368,7 @@ def partition_edge_index_and_edge_features(
Returns:
Union[
Tuple[GraphPartitionData, Optional[FeaturePartitionData], Optional[PartitionBook]],
Tuple[dict[EdgeType, GraphPartitionData], dict[EdgeType, FeaturePartitionData], dict[EdgeType, PartitionBook]],
Tuple[dict[EdgeType, GraphPartitionData], Optional[dict[EdgeType, FeaturePartitionData]], Optional[dict[EdgeType, PartitionBook]]],
]: Partitioned Graph Data, Feature Data, and corresponding edge partition book, is a dictionary if heterogeneous.
"""

Expand Down Expand Up @@ -436,16 +436,14 @@ def partition_edge_index_and_edge_features(
return (
to_homogeneous(partitioned_edge_index),
to_homogeneous(partitioned_edge_features)
if len(partitioned_edge_features) > 0
else None,
to_homogeneous(edge_partition_book)
if len(edge_partition_book) > 0
if partitioned_edge_features
else None,
to_homogeneous(edge_partition_book) if edge_partition_book else None,
)
else:
logger.info(f"Partitioned {formatted_num_edges} edges per edge type")
return (
partitioned_edge_index,
partitioned_edge_features,
edge_partition_book,
partitioned_edge_features if partitioned_edge_features else None,
edge_partition_book if edge_partition_book else None,
)
2 changes: 1 addition & 1 deletion python/gigl/types/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class PartitionOutput:
node_partition_book: Union[PartitionBook, dict[NodeType, PartitionBook]]

# Edge partition book
edge_partition_book: Union[PartitionBook, dict[EdgeType, PartitionBook]]
edge_partition_book: Optional[Union[PartitionBook, dict[EdgeType, PartitionBook]]]

# Partitioned edge index on current rank. This field will always be populated after partitioning. However, we may set this
# field to None during dataset.build() in order to minimize the peak memory usage, and as a result type this as Optional.
Expand Down
83 changes: 56 additions & 27 deletions python/tests/unit/distributed/distributed_partitioner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,9 @@ def _assert_graph_outputs(
is_heterogeneous: bool,
should_assign_edges_by_src_node: bool,
output_node_partition_book: Union[PartitionBook, dict[NodeType, PartitionBook]],
output_edge_partition_book: Union[PartitionBook, dict[EdgeType, PartitionBook]],
output_edge_partition_book: Optional[
Union[PartitionBook, dict[EdgeType, PartitionBook]]
],
output_edge_index: Union[
GraphPartitionData, dict[EdgeType, GraphPartitionData]
],
Expand Down Expand Up @@ -136,35 +138,49 @@ def _assert_graph_outputs(
)
# To unify logic between homogeneous and heterogeneous cases, we define an iterable which we'll loop over.
# Each iteration contains an EdgeType, an edge partition book, and a graph consisting of edge indices and ids.
entity_iterable: Iterable[
entity_iterable: list[
Tuple[EdgeType, Optional[PartitionBook], GraphPartitionData]
]
is_edge_partition_book_heterogeneous = isinstance(
output_edge_partition_book, abc.Mapping
)
if is_edge_partition_book_heterogeneous and isinstance(
output_edge_index, abc.Mapping
):
entity_iterable = [
(
edge_type,
output_edge_partition_book[edge_type]
if edge_type in output_edge_partition_book
else None,
output_edge_index[edge_type],
] = []
if isinstance(output_edge_index, abc.Mapping):
if isinstance(output_edge_partition_book, abc.Mapping):
for edge_type in MOCKED_HETEROGENEOUS_EDGE_TYPES:
entity_iterable.append(
(
edge_type,
output_edge_partition_book[edge_type]
if edge_type in output_edge_partition_book
else None,
output_edge_index[edge_type],
)
)
elif output_edge_partition_book is None:
for edge_type in MOCKED_HETEROGENEOUS_EDGE_TYPES:
entity_iterable.append(
(
edge_type,
None,
output_edge_index[edge_type],
)
)
else:
raise ValueError(
f"The output edge partition book of type {type(output_edge_partition_book)} is not compatible with the output edge index of type {type(output_edge_index)}."
)
for edge_type in MOCKED_HETEROGENEOUS_EDGE_TYPES
]
elif not is_edge_partition_book_heterogeneous and isinstance(
output_edge_index, GraphPartitionData
):
entity_iterable = [
(USER_TO_USER_EDGE_TYPE, output_edge_partition_book, output_edge_index)
]
else:
raise ValueError(
f"The output edge partition book of type {type(output_edge_partition_book)} and the output graph of type {type(output_edge_index)} are not compatible."
)
if isinstance(output_edge_partition_book, (PartitionBook, torch.Tensor)):
entity_iterable = [
(
USER_TO_USER_EDGE_TYPE,
output_edge_partition_book,
output_edge_index,
)
]
elif output_edge_partition_book is None:
entity_iterable = [(USER_TO_USER_EDGE_TYPE, None, output_edge_index)]
else:
raise ValueError(
f"The output edge partition book of type {type(output_edge_partition_book)} is not compatible with the output edge index of type {type(output_edge_index)}."
)

for edge_type, edge_partition_book, graph in entity_iterable:
node_partition_book: PartitionBook
Expand Down Expand Up @@ -742,12 +758,25 @@ def test_partitioning_correctness(
input_data_strategy
== InputDataStrategy.REGISTER_MINIMAL_ENTITIES_SEPARATELY
):
self.assertIsNone(partition_output.edge_partition_book)
self.assertIsNone(partition_output.partitioned_edge_features)
self.assertIsNone(partition_output.partitioned_node_features)
self.assertIsNone(partition_output.partitioned_node_labels)
self.assertIsNone(partition_output.partitioned_positive_labels)
self.assertIsNone(partition_output.partitioned_negative_labels)
else:
assert (
partition_output.edge_partition_book is not None
), f"Must create edge partition book for strategy {input_data_strategy.value}"
if isinstance(partition_output.edge_partition_book, abc.Mapping):
for (
edge_type,
partition_book,
) in partition_output.edge_partition_book.items():
assert partition_book is not None
else:
assert partition_output.edge_partition_book is not None

assert (
partition_output.partitioned_node_features is not None
), f"Must partition node features for strategy {input_data_strategy.value}"
Expand Down