Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
2ae7b23
Update TODO
mkolodner-sc Aug 12, 2025
b21d419
Merge branch 'main' of github.com:Snapchat/GiGL into main
mkolodner-sc Aug 14, 2025
4d71438
Update
mkolodner-sc Aug 15, 2025
3cd90bc
Fix
mkolodner-sc Aug 15, 2025
8a9932c
Update
mkolodner-sc Aug 15, 2025
edb3155
Update
mkolodner-sc Aug 15, 2025
6bf4c61
Merge branch 'mkolodner-sc/add_util_for_extracting_labels' into mkolo…
mkolodner-sc Aug 15, 2025
f00bb4a
Revert "Merge branch 'mkolodner-sc/add_util_for_extracting_labels' in…
mkolodner-sc Aug 15, 2025
8158dea
Revert "Revert "Merge branch 'mkolodner-sc/add_util_for_extracting_la…
mkolodner-sc Aug 15, 2025
abb23b8
Update
mkolodner-sc Aug 15, 2025
f7f643f
Update
mkolodner-sc Aug 15, 2025
65ae1c0
Update
mkolodner-sc Aug 15, 2025
bd77a35
Update
mkolodner-sc Aug 15, 2025
16808ab
Fix
mkolodner-sc Aug 15, 2025
cf40dde
Update
mkolodner-sc Aug 18, 2025
6d15f42
Update
mkolodner-sc Aug 18, 2025
d447c1c
Address comments
mkolodner-sc Aug 18, 2025
b0da420
Fix
mkolodner-sc Aug 18, 2025
d5d96b9
update
mkolodner-sc Aug 18, 2025
123ac0c
fmt changelog
mkolodner-sc Aug 18, 2025
11810a9
Simplify load_and_build_partitioned_dataset
mkolodner-sc Aug 19, 2025
dcdf5e7
Update
mkolodner-sc Aug 19, 2025
a37f561
Re-arrange protected functions
mkolodner-sc Aug 19, 2025
e6aeee8
Update
mkolodner-sc Aug 19, 2025
398e061
Update
mkolodner-sc Aug 19, 2025
d7490e6
Go back to old _load_and_build_partitioned_data
mkolodner-sc Aug 19, 2025
a4b8788
Update
mkolodner-sc Aug 19, 2025
cbabc53
Update
mkolodner-sc Aug 19, 2025
5f8b667
Sort feature keys
mkolodner-sc Aug 19, 2025
3a8e797
Update
mkolodner-sc Aug 19, 2025
40bae3c
Merge branch 'main' into mkolodner-sc/update_tfrecordloader_with_labels
mkolodner-sc Aug 19, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

- Added support for NodeAnchorSplitter

- Add support for loading, partitioning, and separate node labels from features
Comment thread
mkolodner-sc marked this conversation as resolved.

### Changed

### Deprecated
Expand Down
43 changes: 31 additions & 12 deletions python/gigl/common/data/dataloaders.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import time
from copy import deepcopy
from dataclasses import dataclass
from dataclasses import dataclass, field
from functools import partial
from typing import Callable, Optional, Sequence, Tuple, Union

Expand Down Expand Up @@ -29,12 +29,17 @@ class SerializedTFRecordInfo:
tfrecord_uri_prefix: Uri
# Feature names to load for the current entity
feature_keys: Sequence[str]
# a dict of feature name -> FeatureSpec (eg. FixedLenFeature, VarlenFeature, SparseFeature, RaggedFeature). If entity keys are not present, we insert them during tensor loading
# A dict of feature name -> FeatureSpec (eg. FixedLenFeature, VarlenFeature, SparseFeature, RaggedFeature).
# If entity keys are not present, we insert them during tensor loading. For example, if the FeatureSpecDict
# doesn't have the "node_id" identifier, we populate the feature_spce with a FixedLenFeature with shape=[], dtype=tf.int64.
# Note that entity label keys should also be included in the feature_spec if they are present.
feature_spec: FeatureSpecDict
# Feature dimension of current entity
feature_dim: int
# Entity ID Key for current entity. If this is a Node Entity, this must be a string. If this is an edge entity, this must be a Tuple[str, str] for the source and destination ids.
entity_key: Union[str, Tuple[str, str]]
# Name of the label columns for the current entity, defaults to an empty list.
label_keys: Sequence[str] = field(default_factory=list)
# The regex pattern to match the TFRecord files at the specified prefix
tfrecord_uri_pattern: str = ".*-of-.*\.tfrecord(\.gz)?$"

Expand Down Expand Up @@ -79,23 +84,31 @@ class TFDatasetOptions:
def _concatenate_features_by_names(
feature_key_to_tf_tensor: dict[str, tf.Tensor],
feature_keys: Sequence[str],
label_keys: Sequence[str],
) -> tf.Tensor:
"""
Concatenates feature tensors in the order specified by feature names.
Also concatenates labels to the end of the feature list if they are present using the corresponding label key

It is assumed that feature_names is a subset of the keys in feature_name_to_tf_tensor.

Args:
feature_key_to_tf_tensor (dict[str, tf.Tensor]): A dictionary mapping feature names to their corresponding tf tensors.
feature_keys (list[str]): A list of feature names specifying the order in which tensors should be concatenated.
feature_keys (Sequence[str]): A list of feature names specifying the order in which tensors should be concatenated.
label_keys (Sequence[str]): Name of the label columns for the current entity.

Returns:
tf.Tensor: A concatenated tensor of the features in the specified order.
tf.Tensor: A concatenated tensor of the features in the specified order, with the labels being concatenated at the end if it exists
"""

features: list[tf.Tensor] = []

for feature_key in feature_keys:
feature_iterable = list(feature_keys)

for label_key in label_keys:
feature_iterable.append(label_key)

for feature_key in feature_iterable:
tensor = feature_key_to_tf_tensor[feature_key]

# TODO(kmonte, xgao, zfan): We will need to add support for this if we're trying to scale up.
Expand Down Expand Up @@ -298,6 +311,7 @@ def load_as_torch_tensors(
"""
entity_key = serialized_tf_record_info.entity_key
feature_keys = serialized_tf_record_info.feature_keys
label_keys = serialized_tf_record_info.label_keys

# We make a deep copy of the feature spec dict so that future modifications don't redirect to the input

Expand Down Expand Up @@ -356,11 +370,16 @@ def load_as_torch_tensors(
if entity_type == FeatureTypes.NODE
else torch.empty(2, 0)
)
empty_feature = (
torch.empty(0, serialized_tf_record_info.feature_dim)
if feature_keys
else None
)
if label_keys and feature_keys:
empty_feature = torch.empty(
0, serialized_tf_record_info.feature_dim + len(label_keys)
Comment thread
mkolodner-sc marked this conversation as resolved.
)
elif label_keys and not feature_keys:
empty_feature = torch.empty(0, len(label_keys))
elif not label_keys and feature_keys:
empty_feature = torch.empty(0, serialized_tf_record_info.feature_dim)
else:
empty_feature = None
Comment thread
mkolodner-sc marked this conversation as resolved.
return empty_entity, empty_feature

dataset = TFRecordDataLoader._build_dataset_for_uris(
Expand All @@ -375,9 +394,9 @@ def load_as_torch_tensors(
feature_tensors = []
for idx, batch in enumerate(dataset):
id_tensors.append(proccess_id_tensor(batch))
if feature_keys:
if feature_keys or label_keys:
feature_tensors.append(
_concatenate_features_by_names(batch, feature_keys)
_concatenate_features_by_names(batch, feature_keys, label_keys)
)
num_entities_processed += (
id_tensors[-1].shape[0]
Expand Down
94 changes: 89 additions & 5 deletions python/gigl/distributed/dataset_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
DatasetFactory is responsible for building and returning a DistLinkPredictionDataset class or subclass. It does this by spawning a
process which initializes rpc + worker group, loads and builds a partitioned dataset, and shuts down the rpc + worker group.
"""
import gc
import time
from collections import abc
from collections.abc import Mapping
from distutils.util import strtobool
from typing import Literal, MutableMapping, Optional, Tuple, Type, Union

Expand All @@ -19,7 +20,7 @@
)

from gigl.common import Uri, UriFactory
from gigl.common.data.dataloaders import TFRecordDataLoader
from gigl.common.data.dataloaders import SerializedTFRecordInfo, TFRecordDataLoader
from gigl.common.data.load_torch_tensors import (
SerializedGraphMetadata,
TFDatasetOptions,
Expand All @@ -40,9 +41,9 @@
from gigl.distributed.utils.serialized_graph_metadata_translator import (
convert_pb_to_serialized_graph_metadata,
)
from gigl.src.common.types.graph_data import EdgeType
from gigl.src.common.types.graph_data import EdgeType, NodeType
from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper
from gigl.types.graph import DEFAULT_HOMOGENEOUS_EDGE_TYPE
from gigl.types.graph import DEFAULT_HOMOGENEOUS_EDGE_TYPE, FeaturePartitionData
from gigl.utils.data_splitters import (
HashedNodeAnchorLinkSplitter,
NodeAnchorLinkSplitter,
Expand All @@ -52,6 +53,39 @@
logger = Logger()


def _get_labels_from_features(
Comment thread
mkolodner-sc marked this conversation as resolved.
feature_and_label_tensor: torch.Tensor, label_dim: int
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Given a combined tensor of features and labels, returns the features and labels separately.
Args:
feature_and_label_tensor (torch.Tensor): Tensor of features and labels
label_dim (int): Dimension of the labels
Returns:
feature_tensor (torch.Tensor): Tensor of features
label_tensor (torch.Tensor): Tensor of labels
"""

if len(feature_and_label_tensor.shape) != 2:
raise ValueError(
f"Expected tensor to be 2D for extracting labels, but got shape {feature_and_label_tensor.shape}"
)

_, feature_and_label_dim = feature_and_label_tensor.shape

if label_dim > feature_and_label_dim:
raise ValueError(
f"Got invalid label dim {label_dim} for extracting labels from tensor of shape {feature_and_label_tensor.shape}"
)

feature_dim = feature_and_label_dim - label_dim

return (
feature_and_label_tensor[:, :feature_dim],
feature_and_label_tensor[:, feature_dim:],
)


@tf_on_cpu
def _load_and_build_partitioned_dataset(
Comment thread
mkolodner-sc marked this conversation as resolved.
serialized_graph_metadata: SerializedGraphMetadata,
Expand Down Expand Up @@ -112,7 +146,7 @@ def _load_and_build_partitioned_dataset(
"Cannot have loaded positive and negative labels when attempting to select self-supervised positive edges from edge index."
)
positive_label_edges: Union[torch.Tensor, dict[EdgeType, torch.Tensor]]
if isinstance(loaded_graph_tensors.edge_index, abc.Mapping):
if isinstance(loaded_graph_tensors.edge_index, Mapping):
# This assert is required while `select_ssl_positive_label_edges` exists out of any splitter. Once this is in transductive splitter,
# we can remove this assert.
assert isinstance(
Expand Down Expand Up @@ -192,6 +226,56 @@ def _load_and_build_partitioned_dataset(

partition_output = partitioner.partition()

Comment thread
mkolodner-sc marked this conversation as resolved.
node_labels: Optional[Union[torch.Tensor, dict[NodeType, torch.Tensor]]] = None
if isinstance(partition_output.partitioned_node_features, Mapping):
node_labels = {}
Comment thread
mkolodner-sc marked this conversation as resolved.
for (
node_type,
node_feature,
) in partition_output.partitioned_node_features.items():
# serialized_graph_metadata can be heterogeneous for a heterogeneous node features
if isinstance(serialized_graph_metadata.node_entity_info, Mapping):
label_dim = len(
serialized_graph_metadata.node_entity_info[node_type].label_keys
)
# serialized_graph_metadata can be homogeneous for a heterogeneous node features,
# since we inject positive and negative label types as edges
else:
label_dim = len(serialized_graph_metadata.node_entity_info.label_keys)
if label_dim > 0:
node_features, node_labels[node_type] = _get_labels_from_features(
node_feature.feats, label_dim=label_dim
)
partition_output.partitioned_node_features[
node_type
] = FeaturePartitionData(feats=node_features, ids=node_feature.ids)
del node_feature
gc.collect()

elif isinstance(partition_output.partitioned_node_features, FeaturePartitionData):
# serialized graph metadata must be homogeneous if partitioned node features is homogeneous
if not isinstance(
serialized_graph_metadata.node_entity_info, SerializedTFRecordInfo
):
raise ValueError(
f"Expected partitioned node features to be type SerializedTFRecordInfo, got {type(partition_output.partitioned_node_features)}"
)
Comment thread
mkolodner-sc marked this conversation as resolved.
label_dim = len(serialized_graph_metadata.node_entity_info.label_keys)
if label_dim > 0:
node_features, node_labels = _get_labels_from_features(
partition_output.partitioned_node_features.feats, label_dim=label_dim
)
partition_output.partitioned_node_features = FeaturePartitionData(
feats=node_features, ids=partition_output.partitioned_node_features.ids
)
gc.collect()
else:
raise ValueError(
f"Expected to have partitioned node features if labels are present, but got node features {partition_output.partitioned_node_features}"
)

# TODO (mkolodner-sc): Add node labels to the dataset
Comment thread
mkolodner-sc marked this conversation as resolved.

logger.info(
f"Initializing DistLinkPredictionDataset instance with edge direction {edge_dir}"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,14 @@ def __post_init__(self):
condensed_node_type_to_feature_dim_map[
CondensedNodeType(condensed_node_type)
] = node_metadata.feature_dim
node_feature_keys = list(node_metadata.feature_keys)
node_feature_keys = sorted(list(node_metadata.feature_keys))
label_keys = sorted(list(node_metadata.label_keys))
Comment thread
mkolodner-sc marked this conversation as resolved.
node_feature_schema = self.__build_feature_schema(
schema_uri=UriFactory.create_uri(node_metadata.schema_uri),
transform_fn_assets_uri=UriFactory.create_uri(
node_metadata.transform_fn_assets_uri
),
feature_keys=node_feature_keys,
feature_keys=node_feature_keys + label_keys,
)
condensed_node_type_to_feature_schema_map[
CondensedNodeType(condensed_node_type)
Expand Down Expand Up @@ -380,12 +381,10 @@ def condensed_node_type_to_feature_keys_map(
dict[CondensedNodeType, list[str]]: A mapping which stores the feature keys of each CondensedNodeTypes
"""
return {
condensed_node_type: list(
self.condensed_node_type_to_feature_schema_map[
condensed_node_type
].feature_spec.keys()
CondensedNodeType(condensed_node_type): sorted(
list(preprocessed_metadata.feature_keys)
)
for condensed_node_type in self.condensed_node_type_to_feature_schema_map
for condensed_node_type, preprocessed_metadata in self.preprocessed_metadata_pb.condensed_node_type_to_preprocessed_metadata.items()
}

@property
Expand Down
Loading