diff --git a/python/gigl/utils/data_splitters.py b/python/gigl/utils/data_splitters.py index 78c23cba2..3e69516db 100644 --- a/python/gigl/utils/data_splitters.py +++ b/python/gigl/utils/data_splitters.py @@ -281,14 +281,18 @@ def __call__( ) node_ids_by_node_type[anchor_node_type].append(anchor_nodes) # Second, we go through all node types and split them. - # Note the approach here (with `torch.argsort`) isn't the quickest - # we could avoid calling `torch.argsort` and do something like: + # Note: We could use `torch.argsort` here after normalizing the hash values, + # instead of generating masks directly based on comparing the hash values to the split percentages. + # e.g.: # hash_values = ... - # train_mask = hash_values < train_percentage - # train = nodes_to_select[train_mask] - # That approach is about 2x faster (30s -> 15s on 1B nodes), - # but with this `argsort` approach we can be more exact with the number of nodes per split. - # The memory usage seems the same across both approaches. + # sorted_indices = torch.argsort(hash_values) + # test_mask = sorted_indices[:int(hash_values.size(0) * self._num_test)] + # test = nodes_to_select[test_mask] + # ... + # This approach would let us be exact in the number of nodes per split, + # and let us accept integer inputs as num_test and num_val. + # BUT, it is slower, it takes ~30s instead of ~15s on 1B edges. + # The memory usage is the same across both approaches. # De-dupe this way instead of using `unique` to avoid the overhead of sorting. # This approach, goes from ~60s to ~30s on 1B edges.