Merged
18 changes: 11 additions & 7 deletions python/gigl/utils/data_splitters.py
@@ -281,14 +281,18 @@ def __call__(
)
node_ids_by_node_type[anchor_node_type].append(anchor_nodes)
# Second, we go through all node types and split them.
# Note the approach here (with `torch.argsort`) isn't the quickest
# we could avoid calling `torch.argsort` and do something like:
# Note: We could use `torch.argsort` here after normalizing the hash values,
# instead of generating masks directly based on comparing the hash values to the split percentages.
# e.g.:
# hash_values = ...
# train_mask = hash_values < train_percentage
# train = nodes_to_select[train_mask]
# That approach is about 2x faster (30s -> 15s on 1B nodes),
# but with this `argsort` approach we can be more exact with the number of nodes per split.
# The memory usage seems the same across both approaches.
# sorted_indices = torch.argsort(hash_values)
# test_indices = sorted_indices[:int(hash_values.size(0) * self._num_test)]
# test = nodes_to_select[test_indices]
# ...
# This approach would let us be exact in the number of nodes per split,
# and let us accept integer inputs as num_test and num_val.
# BUT it is slower: it takes ~30s instead of ~15s on 1B edges.
# The memory usage is the same across both approaches.

# De-dupe this way instead of using `unique` to avoid the overhead of sorting.
# This approach goes from ~60s to ~30s on 1B edges.
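
The trade-off described in the hash-split comment above (comparing normalized hash values to the split percentages versus sorting by hash and slicing) can be sketched as follows. This is a minimal illustration, not the code in `data_splitters.py`: it uses plain Python lists and `hashlib` instead of torch tensors so it runs without dependencies, and `normalized_hash`, `split_by_threshold`, and `split_by_rank` are hypothetical names introduced here.

```python
import hashlib


def normalized_hash(node_id: int) -> float:
    """Map a node id to a deterministic value in [0, 1).

    Hypothetical helper: the real splitter's hash function is not shown
    in the diff, so SHA-256 is used here purely for illustration.
    """
    digest = hashlib.sha256(str(node_id).encode()).digest()
    # 48 bits fit exactly in a float, so the result is strictly below 1.0.
    return int.from_bytes(digest[:6], "big") / 2**48


def split_by_threshold(node_ids, val_frac, test_frac):
    """Mask-style split: compare each hash to the cumulative percentages.

    Single pass and no sort, but split sizes are only approximately
    val_frac / test_frac of the input.
    """
    train_frac = 1.0 - val_frac - test_frac
    train, val, test = [], [], []
    for node_id in node_ids:
        h = normalized_hash(node_id)
        if h < train_frac:
            train.append(node_id)
        elif h < train_frac + val_frac:
            val.append(node_id)
        else:
            test.append(node_id)
    return train, val, test


def split_by_rank(node_ids, num_val, num_test):
    """Argsort-style split: order nodes by hash, then slice exact counts.

    Needs a sort, but the split sizes are exact and can be given as integers.
    """
    order = sorted(range(len(node_ids)), key=lambda i: normalized_hash(node_ids[i]))
    ranked = [node_ids[i] for i in order]
    test = ranked[:num_test]
    val = ranked[num_test:num_test + num_val]
    train = ranked[num_test + num_val:]
    return train, val, test


nodes = list(range(1000))
t_train, t_val, t_test = split_by_threshold(nodes, val_frac=0.1, test_frac=0.1)
r_train, r_val, r_test = split_by_rank(nodes, num_val=100, num_test=100)
print(len(t_train), len(t_val), len(t_test))  # approximately 800 / 100 / 100
print(len(r_train), len(r_val), len(r_test))  # exactly 800 / 100 / 100
```

Both variants are deterministic for a given set of node ids; the threshold variant trades exact split sizes for skipping the sort, which is the speed difference the comment reports.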