8 changes: 4 additions & 4 deletions .github/workflows/on-pr-comment.yml
@@ -20,19 +20,19 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v4

- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: '3.x'

- name: Install PyYAML
run: pip install PyYAML

- name: Generate help text
id: parse_commands
run: python .github/scripts/get_help_text.py

- name: Post help comment
uses: snapchat/gigl/.github/actions/comment-on-pr@main
with:
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

### Added

- Added support for HashedNodeSplitter

### Changed

### Deprecated
277 changes: 230 additions & 47 deletions python/gigl/utils/data_splitters.py
@@ -71,6 +71,38 @@ def should_convert_labels_to_edges(self):
...


class NodeSplitter(Protocol):
    """Protocol that must be satisfied by any callable used to split nodes directly.

    Args:
        node_ids: The node IDs to split on. A 1D tensor (1 x N) for homogeneous graphs,
            or a mapping from node type to 1D tensors for heterogeneous graphs.
    Returns:
        The train (1 x X), val (1 x Y), and test (1 x Z) nodes, where X + Y + Z = N.
    """

@overload
def __call__(
self,
node_ids: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
...

@overload
def __call__(
self,
node_ids: Mapping[NodeType, torch.Tensor],
) -> Mapping[NodeType, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
...

def __call__(
self, *args, **kwargs
) -> Union[
Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
Mapping[NodeType, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
]:
...
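For illustration, here is a minimal, hypothetical splitter that satisfies this protocol. It is not part of the PR; `trivial_splitter` is an invented name, shown only to make the two call shapes concrete:

```python
from collections.abc import Mapping

import torch

# Hypothetical example: a degenerate NodeSplitter that places every node in train.
def trivial_splitter(node_ids):
    if isinstance(node_ids, Mapping):
        # Heterogeneous: one (train, val, test) tuple per node type.
        return {
            node_type: (ids, ids.new_empty(0), ids.new_empty(0))
            for node_type, ids in node_ids.items()
        }
    # Homogeneous: a single (train, val, test) tuple.
    empty = node_ids.new_empty(0)
    return node_ids, empty, empty
```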


def _fast_hash(x: torch.Tensor) -> torch.Tensor:
"""Fast hash function.

@@ -128,6 +160,7 @@ class HashedNodeAnchorLinkSplitter:
    NOTE: This splitter must be called while a Torch distributed process group is initialized,
    i.e. `torch.distributed.init_process_group` must be called before using this splitter.

    This inter-process communication is needed to determine the global maximum and minimum hashed node IDs across all machines.

    In node-based splitting, a node may only ever live in one split. For example, if one
    node has two label edges, *both* of those edges will be placed into the same split.
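As a hedged sketch of that requirement (the single-process gloo group, port, edge tensor, and constructor defaults here are illustrative assumptions, not prescribed by this PR):

```python
# Illustrative sketch: initialize a (single-process) group before splitting.
import torch
import torch.distributed as dist

from gigl.utils.data_splitters import HashedNodeAnchorLinkSplitter

dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
)

splitter = HashedNodeAnchorLinkSplitter(sampling_direction="out")
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])  # 2 x E label edges
train, val, test = splitter(edge_index)  # each anchor node lands in one split

dist.destroy_process_group()
```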
@@ -166,8 +199,8 @@ def __init__(
`gigl.distributed.build_dataset` convert all labels into edges, and will infer positive and negative edge types based on
`supervision_edge_types`.
"""
-        _check_sampling_direction(sampling_direction)
-        _check_val_test_percentage(num_val, num_test)
+        _assert_sampling_direction(sampling_direction)
+        _assert_valid_split_ratios(num_val, num_test)

self._sampling_direction = sampling_direction
self._num_val = num_val
@@ -238,6 +271,12 @@ def __call__(
Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
Mapping[NodeType, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
]:
# Validate distributed process group
if not torch.distributed.is_initialized():
            raise RuntimeError(
                f"{type(self).__name__} requires a Torch distributed process group, but none was found. "
                "Please initialize a process group (`torch.distributed.init_process_group`) before using this splitter."
)
if isinstance(edge_index, torch.Tensor):
if self._labeled_edge_types != [DEFAULT_HOMOGENEOUS_EDGE_TYPE]:
logger.warning(
@@ -314,59 +353,22 @@ def __call__(
# To a tensor of the non-zero counts, e.g. `[[1], [3]]`
# and the `squeeze` converts that to a 1d tensor (`[1, 3]`).
nodes_to_select = torch.nonzero(node_id_count).squeeze()
+             if nodes_to_select.dim() == 0:
+                 nodes_to_select = nodes_to_select.unsqueeze(0)
              # node_id_count no longer needed, so we can clean up its memory.
              del node_id_count
              gc.collect()

              hash_values = self._hash_function(nodes_to_select)  # 1 x M
-             # Now, we want to normalize the hash values to the [0, 1) range so we can select them easily into splits.
-             # We want to do this *globally*, e.g. across all processes,
-             # so that we can ensure that the same nodes are selected for the same split across all processes.
-             # If we don't do this, then if we have `[0, 1, 2, 3, 4]` on one process and `[4, 5, 6, 7]` on another,
-             # with the identity hash, `4` may end up in Test on one rank and Train on another.
-             min_hash_value, max_hash_value = map(
-                 torch.Tensor.item, hash_values.aminmax()
+             # Create train, val, test splits using distributed coordination
+             train, val, test = _create_distributed_splits_from_hash(
+                 nodes_to_select, hash_values, self._num_val, self._num_test
              )
-             if torch.distributed.is_initialized():
-                 all_max_and_mins = [
-                     torch.zeros(2, dtype=torch.int64)
-                     for _ in range(torch.distributed.get_world_size())
-                 ]
-                 torch.distributed.all_gather(
-                     all_max_and_mins,
-                     torch.tensor([max_hash_value, min_hash_value], dtype=torch.int64),
-                 )
-                 global_max_hash_value = max_hash_value
-                 global_min_hash_value = min_hash_value
-                 for max_and_min in all_max_and_mins:
-                     global_max_hash_value = max(
-                         global_max_hash_value, max_and_min[0].item()
-                     )
-                     global_min_hash_value = min(
-                         global_min_hash_value, max_and_min[1].item()
-                     )
-             else:
-                 raise RuntimeError(
-                     f"{type(self).__name__} requires a Torch distributed process group, but none was found. Please initialize a process group (`torch.distributed.init_process_group`) before using this splitter."
-                 )
-             hash_values = (
-                 hash_values - global_min_hash_value
-             ) / global_max_hash_value  # Normalize the hash values to [0, 1)
-
-             # Now that we've normalized the hash values, we can select the train, val, and test nodes.
-             test_inds = hash_values >= 1 - self._num_test  # 1 x M
-             val_inds = (
-                 hash_values >= 1 - self._num_test - self._num_val
-             ) & ~test_inds  # 1 x M
-             del hash_values
-             gc.collect()
-             train_inds = ~test_inds & ~val_inds  # 1 x M
-             train = nodes_to_select[train_inds]  # 1 x num_train_nodes
-             val = nodes_to_select[val_inds]  # 1 x num_val_nodes
-             test = nodes_to_select[test_inds]  # 1 x num_test_nodes
              splits[anchor_node_type] = (train, val, test)
              # We no longer need the nodes to select, so we can clean up their memory.
-             del nodes_to_select, train_inds, val_inds, test_inds
+             del nodes_to_select
              gc.collect()
if len(splits) == 0:
raise ValueError(
@@ -383,6 +385,175 @@ def should_convert_labels_to_edges(self):
return self._should_convert_labels_to_edges


class HashedNodeSplitter:
    """Selects train, val, and test nodes directly from the provided node IDs.

    NOTE: This splitter must be called while a Torch distributed process group is initialized,
    i.e. `torch.distributed.init_process_group` must be called before using this splitter.

    This inter-process communication is needed to determine the global maximum and minimum hashed node IDs across all machines.

    In node-based splitting, each node is placed into exactly one split based on its hash value.
    This is simpler than edge-based splitting, as it does not require extracting anchor nodes from edges.

    Additionally, the HashedNodeSplitter does not de-dup repeated node IDs: if repeated node IDs
    are passed in, the same number of repeats appear in the output, all placed into the same split.
    This differs from the HashedNodeAnchorLinkSplitter, which does de-dup the repeated source or
    destination nodes that appear from the labeled edges.

    Args:
        node_ids: The node IDs to split. Either a 1D tensor for homogeneous graphs,
            or a mapping from node types to 1D tensors for heterogeneous graphs.
    Returns:
        The train, val, and test node splits, as tensors or mappings depending on the input format.
    """

def __init__(
self,
num_val: float = 0.1,
num_test: float = 0.1,
hash_function: Callable[[torch.Tensor], torch.Tensor] = _fast_hash,
):
"""Initializes the HashedNodeSplitter.

Args:
num_val (float): The percentage of nodes to use for validation. Defaults to 0.1 (10%).
num_test (float): The percentage of nodes to use for testing. Defaults to 0.1 (10%).
hash_function (Callable[[torch.Tensor], torch.Tensor]): The hash function to use. Defaults to `_fast_hash`.
"""
_assert_valid_split_ratios(num_val, num_test)

self._num_val = num_val
self._num_test = num_test
self._hash_function = hash_function

@overload
def __call__(
self,
node_ids: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
...

@overload
def __call__(
self,
node_ids: Mapping[NodeType, torch.Tensor],
) -> Mapping[NodeType, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
...

def __call__(
self,
node_ids: Union[torch.Tensor, Mapping[NodeType, torch.Tensor]],
) -> Union[
Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
Mapping[NodeType, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
]:
# Validate distributed process group
if not torch.distributed.is_initialized():
            raise RuntimeError(
                f"{type(self).__name__} requires a Torch distributed process group, but none was found. "
                "Please initialize a process group (`torch.distributed.init_process_group`) before using this splitter."
)
if isinstance(node_ids, Mapping):
is_heterogeneous = True
node_ids_dict = node_ids
else:
is_heterogeneous = False
node_ids_dict = {DEFAULT_HOMOGENEOUS_NODE_TYPE: node_ids}

splits: dict[NodeType, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = {}

for node_type, nodes_to_split in node_ids_dict.items():
_check_node_ids(nodes_to_split)

hash_values = self._hash_function(nodes_to_split) # 1 x M

# Create train, val, test splits using distributed coordination
train, val, test = _create_distributed_splits_from_hash(
nodes_to_split, hash_values, self._num_val, self._num_test
)

splits[node_type] = (train, val, test)

# Clean up memory
del hash_values
gc.collect()

if len(splits) == 0:
raise ValueError("No node IDs provided for splitting")

if is_heterogeneous:
return splits
else:
return splits[DEFAULT_HOMOGENEOUS_NODE_TYPE]
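To make the homogeneous/heterogeneous call forms and the no-de-dup note concrete, here is a hedged usage sketch. The node IDs and the string node types are made-up values, and it assumes a process group is already initialized as shown earlier:

```python
# Illustrative sketch: both call forms of HashedNodeSplitter.
# Assumes torch.distributed.init_process_group was already called (see above).
import torch

from gigl.utils.data_splitters import HashedNodeSplitter

splitter = HashedNodeSplitter(num_val=0.1, num_test=0.1)

# Homogeneous: ID 7 is repeated, so both copies land in the same split.
train, val, test = splitter(torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 7]))

# Heterogeneous: a mapping in, a mapping of (train, val, test) tuples out.
splits = splitter({"user": torch.arange(100), "item": torch.arange(50)})
user_train, user_val, user_test = splits["user"]
```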


def _create_distributed_splits_from_hash(
nodes_to_select: torch.Tensor,
hash_values: torch.Tensor,
num_val: float,
num_test: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Creates train, val, test splits from hash values using distributed coordination.

This function performs the complete splitting workflow:
1. Normalizes hash values globally across all processes
2. Creates train/val/test splits from the normalized hash values

    Note that 0 < num_val + num_test < 1 is assumed (validated by callers).

    Args:
        nodes_to_select (torch.Tensor): The nodes to split. 1 x M
        hash_values (torch.Tensor): Raw hash values for the nodes. 1 x M
        num_val (float): Percentage of nodes for validation.
        num_test (float): Percentage of nodes for testing.

Returns:
Tuple of (train_nodes, val_nodes, test_nodes).

Raises:
RuntimeError: If no distributed process group is found.
"""

# Ensure hash values and nodes_to_select are on the same device
hash_values = hash_values.to(nodes_to_select.device)

# Normalize hash values globally across all processes
min_hash_value, max_hash_value = map(torch.Tensor.item, hash_values.aminmax())

all_max_and_mins = [
torch.zeros(2, dtype=torch.int64)
for _ in range(torch.distributed.get_world_size())
]
torch.distributed.all_gather(
all_max_and_mins,
torch.tensor([max_hash_value, min_hash_value], dtype=torch.int64),
)
global_max_hash_value = max_hash_value
global_min_hash_value = min_hash_value
for max_and_min in all_max_and_mins:
global_max_hash_value = max(global_max_hash_value, max_and_min[0].item())
global_min_hash_value = min(global_min_hash_value, max_and_min[1].item())

normalized_hash_values = (
hash_values - global_min_hash_value
) / global_max_hash_value

# Create splits from normalized hash values
test_inds = normalized_hash_values >= 1 - num_test # 1 x M
val_inds = (normalized_hash_values >= 1 - num_test - num_val) & ~test_inds # 1 x M
train_inds = ~test_inds & ~val_inds # 1 x M

# Apply masks to select nodes
train = nodes_to_select[train_inds] # 1 x num_train_nodes
val = nodes_to_select[val_inds] # 1 x num_val_nodes
test = nodes_to_select[test_inds] # 1 x num_test_nodes

return train, val, test
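As a sanity check on the thresholding above, here is a hedged single-process walkthrough using an identity hash; all values are made up, and the single-process gloo group stands in for a real multi-rank job:

```python
# Illustrative walkthrough of the threshold logic with an identity "hash".
import torch
import torch.distributed as dist

dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
)

nodes = torch.arange(10)      # made-up node IDs 0..9
hash_values = nodes.clone()   # identity hash: global min 0, global max 9
train, val, test = _create_distributed_splits_from_hash(
    nodes, hash_values, num_val=0.2, num_test=0.2
)
# Normalized hash of node k is k/9; test takes >= 0.8, val takes [0.6, 0.8):
# train = [0..5], val = [6, 7], test = [8, 9]
```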


def get_labels_for_anchor_nodes(
dataset: Dataset,
node_ids: torch.Tensor,
@@ -505,14 +676,14 @@ def _get_padded_labels(
return labels


- def _check_sampling_direction(sampling_direction: str):
+ def _assert_sampling_direction(sampling_direction: str):
      if sampling_direction not in ["in", "out"]:
          raise ValueError(
              f"Invalid sampling direction {sampling_direction}. Expected 'in' or 'out'."
          )


- def _check_val_test_percentage(
+ def _assert_valid_split_ratios(
      val_percentage: Union[float, int], test_percentage: Union[float, int]
  ):
"""Checks that the val and test percentages make sense, e.g. we can still have train nodes, and they are non-negative."""
@@ -550,6 +721,18 @@ def _check_edge_index(edge_index: torch.Tensor):
raise ValueError("Expected a dense tensor. Received a sparse tensor.")


def _check_node_ids(node_ids: torch.Tensor):
    """Checks that node_ids is a dense, non-empty 1D tensor."""
if len(node_ids.shape) != 1:
raise ValueError(
f"Expected node_ids to be a 1D tensor. Received a tensor of shape: {node_ids.shape}."
)
if node_ids.is_sparse:
raise ValueError("Expected a dense tensor. Received a sparse tensor.")
if node_ids.numel() == 0:
raise ValueError("Expected non-empty node_ids tensor.")
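For illustration, a hedged sketch of inputs this check accepts and rejects (tensors are made up):

```python
# Illustrative: what _check_node_ids accepts and rejects.
import torch

_check_node_ids(torch.tensor([3, 1, 4]))            # OK: dense, non-empty, 1D
# _check_node_ids(torch.zeros(2, 3))                # ValueError: not 1D
# _check_node_ids(torch.tensor([1.0]).to_sparse())  # ValueError: sparse
# _check_node_ids(torch.tensor([]))                 # ValueError: empty
```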


def select_ssl_positive_label_edges(
edge_index: torch.Tensor, positive_label_percentage: float
) -> torch.Tensor: