diff --git a/CHANGELOG.md b/CHANGELOG.md index 306cc320b..073e0f37e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). - Added support for NodeAnchorSplitter +- Added support for loading, partitioning, and separating node labels from features + ### Changed ### Deprecated diff --git a/python/gigl/common/data/dataloaders.py b/python/gigl/common/data/dataloaders.py index 26d86ecc6..581e831b0 100644 --- a/python/gigl/common/data/dataloaders.py +++ b/python/gigl/common/data/dataloaders.py @@ -1,6 +1,6 @@ import time from copy import deepcopy -from dataclasses import dataclass +from dataclasses import dataclass, field from functools import partial from typing import Callable, Optional, Sequence, Tuple, Union @@ -29,12 +29,17 @@ class SerializedTFRecordInfo: tfrecord_uri_prefix: Uri # Feature names to load for the current entity feature_keys: Sequence[str] - # a dict of feature name -> FeatureSpec (eg. FixedLenFeature, VarlenFeature, SparseFeature, RaggedFeature). If entity keys are not present, we insert them during tensor loading + # A dict of feature name -> FeatureSpec (eg. FixedLenFeature, VarlenFeature, SparseFeature, RaggedFeature). + # If entity keys are not present, we insert them during tensor loading. For example, if the FeatureSpecDict + # doesn't have the "node_id" identifier, we populate the feature_spec with a FixedLenFeature with shape=[], dtype=tf.int64. + # Note that entity label keys should also be included in the feature_spec if they are present. feature_spec: FeatureSpecDict # Feature dimension of current entity feature_dim: int # Entity ID Key for current entity. If this is a Node Entity, this must be a string. If this is an edge entity, this must be a Tuple[str, str] for the source and destination ids. entity_key: Union[str, Tuple[str, str]] + # Names of the label columns for the current entity, defaults to an empty list. 
+ label_keys: Sequence[str] = field(default_factory=list) # The regex pattern to match the TFRecord files at the specified prefix tfrecord_uri_pattern: str = ".*-of-.*\.tfrecord(\.gz)?$" @@ -79,23 +84,31 @@ class TFDatasetOptions: def _concatenate_features_by_names( feature_key_to_tf_tensor: dict[str, tf.Tensor], feature_keys: Sequence[str], + label_keys: Sequence[str], ) -> tf.Tensor: """ Concatenates feature tensors in the order specified by feature names. + Also concatenates labels to the end of the feature list if they are present, using the corresponding label keys. It is assumed that feature_names is a subset of the keys in feature_name_to_tf_tensor. Args: feature_key_to_tf_tensor (dict[str, tf.Tensor]): A dictionary mapping feature names to their corresponding tf tensors. - feature_keys (list[str]): A list of feature names specifying the order in which tensors should be concatenated. + feature_keys (Sequence[str]): A list of feature names specifying the order in which tensors should be concatenated. + label_keys (Sequence[str]): Names of the label columns for the current entity. Returns: - tf.Tensor: A concatenated tensor of the features in the specified order. + tf.Tensor: A concatenated tensor of the features in the specified order, with the labels concatenated at the end if they exist. """ features: list[tf.Tensor] = [] - for feature_key in feature_keys: + feature_iterable = list(feature_keys) + + for label_key in label_keys: + feature_iterable.append(label_key) + + for feature_key in feature_iterable: tensor = feature_key_to_tf_tensor[feature_key] # TODO(kmonte, xgao, zfan): We will need to add support for this if we're trying to scale up. 
@@ -298,6 +311,7 @@ def load_as_torch_tensors( """ entity_key = serialized_tf_record_info.entity_key feature_keys = serialized_tf_record_info.feature_keys + label_keys = serialized_tf_record_info.label_keys # We make a deep copy of the feature spec dict so that future modifications don't redirect to the input @@ -356,11 +370,16 @@ def load_as_torch_tensors( if entity_type == FeatureTypes.NODE else torch.empty(2, 0) ) - empty_feature = ( - torch.empty(0, serialized_tf_record_info.feature_dim) - if feature_keys - else None - ) + if label_keys and feature_keys: + empty_feature = torch.empty( + 0, serialized_tf_record_info.feature_dim + len(label_keys) + ) + elif label_keys and not feature_keys: + empty_feature = torch.empty(0, len(label_keys)) + elif not label_keys and feature_keys: + empty_feature = torch.empty(0, serialized_tf_record_info.feature_dim) + else: + empty_feature = None return empty_entity, empty_feature dataset = TFRecordDataLoader._build_dataset_for_uris( @@ -375,9 +394,9 @@ def load_as_torch_tensors( feature_tensors = [] for idx, batch in enumerate(dataset): id_tensors.append(proccess_id_tensor(batch)) - if feature_keys: + if feature_keys or label_keys: feature_tensors.append( - _concatenate_features_by_names(batch, feature_keys) + _concatenate_features_by_names(batch, feature_keys, label_keys) ) num_entities_processed += ( id_tensors[-1].shape[0] diff --git a/python/gigl/distributed/dataset_factory.py b/python/gigl/distributed/dataset_factory.py index e6e0358eb..0dbfe2e4d 100644 --- a/python/gigl/distributed/dataset_factory.py +++ b/python/gigl/distributed/dataset_factory.py @@ -2,8 +2,9 @@ DatasetFactory is responsible for building and returning a DistLinkPredictionDataset class or subclass. It does this by spawning a process which initializes rpc + worker group, loads and builds a partitioned dataset, and shuts down the rpc + worker group. 
""" +import gc import time -from collections import abc +from collections.abc import Mapping from distutils.util import strtobool from typing import Literal, MutableMapping, Optional, Tuple, Type, Union @@ -19,7 +20,7 @@ ) from gigl.common import Uri, UriFactory -from gigl.common.data.dataloaders import TFRecordDataLoader +from gigl.common.data.dataloaders import SerializedTFRecordInfo, TFRecordDataLoader from gigl.common.data.load_torch_tensors import ( SerializedGraphMetadata, TFDatasetOptions, @@ -40,9 +41,9 @@ from gigl.distributed.utils.serialized_graph_metadata_translator import ( convert_pb_to_serialized_graph_metadata, ) -from gigl.src.common.types.graph_data import EdgeType +from gigl.src.common.types.graph_data import EdgeType, NodeType from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper -from gigl.types.graph import DEFAULT_HOMOGENEOUS_EDGE_TYPE +from gigl.types.graph import DEFAULT_HOMOGENEOUS_EDGE_TYPE, FeaturePartitionData from gigl.utils.data_splitters import ( HashedNodeAnchorLinkSplitter, NodeAnchorLinkSplitter, @@ -52,6 +53,39 @@ logger = Logger() +def _get_labels_from_features( + feature_and_label_tensor: torch.Tensor, label_dim: int +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Given a combined tensor of features and labels, returns the features and labels separately. 
+ Args: + feature_and_label_tensor (torch.Tensor): Tensor of features and labels + label_dim (int): Dimension of the labels + Returns: + feature_tensor (torch.Tensor): Tensor of features + label_tensor (torch.Tensor): Tensor of labels + """ + + if len(feature_and_label_tensor.shape) != 2: + raise ValueError( + f"Expected tensor to be 2D for extracting labels, but got shape {feature_and_label_tensor.shape}" + ) + + _, feature_and_label_dim = feature_and_label_tensor.shape + + if label_dim > feature_and_label_dim: + raise ValueError( + f"Got invalid label dim {label_dim} for extracting labels from tensor of shape {feature_and_label_tensor.shape}" + ) + + feature_dim = feature_and_label_dim - label_dim + + return ( + feature_and_label_tensor[:, :feature_dim], + feature_and_label_tensor[:, feature_dim:], + ) + + @tf_on_cpu def _load_and_build_partitioned_dataset( serialized_graph_metadata: SerializedGraphMetadata, @@ -112,7 +146,7 @@ def _load_and_build_partitioned_dataset( "Cannot have loaded positive and negative labels when attempting to select self-supervised positive edges from edge index." ) positive_label_edges: Union[torch.Tensor, dict[EdgeType, torch.Tensor]] - if isinstance(loaded_graph_tensors.edge_index, abc.Mapping): + if isinstance(loaded_graph_tensors.edge_index, Mapping): # This assert is required while `select_ssl_positive_label_edges` exists out of any splitter. Once this is in transductive splitter, # we can remove this assert. 
assert isinstance( @@ -192,6 +226,56 @@ def _load_and_build_partitioned_dataset( partition_output = partitioner.partition() + node_labels: Optional[Union[torch.Tensor, dict[NodeType, torch.Tensor]]] = None + if isinstance(partition_output.partitioned_node_features, Mapping): + node_labels = {} + for ( + node_type, + node_feature, + ) in partition_output.partitioned_node_features.items(): + # serialized_graph_metadata can be heterogeneous for heterogeneous node features + if isinstance(serialized_graph_metadata.node_entity_info, Mapping): + label_dim = len( + serialized_graph_metadata.node_entity_info[node_type].label_keys + ) + # serialized_graph_metadata can be homogeneous for heterogeneous node features, + # since we inject positive and negative label types as edges + else: + label_dim = len(serialized_graph_metadata.node_entity_info.label_keys) + if label_dim > 0: + node_features, node_labels[node_type] = _get_labels_from_features( + node_feature.feats, label_dim=label_dim + ) + partition_output.partitioned_node_features[ + node_type + ] = FeaturePartitionData(feats=node_features, ids=node_feature.ids) + del node_feature + gc.collect() + + elif isinstance(partition_output.partitioned_node_features, FeaturePartitionData): + # serialized graph metadata must be homogeneous if partitioned node features is homogeneous + if not isinstance( + serialized_graph_metadata.node_entity_info, SerializedTFRecordInfo + ): + raise ValueError( + f"Expected node entity info to be type SerializedTFRecordInfo, got {type(serialized_graph_metadata.node_entity_info)}" + ) + label_dim = len(serialized_graph_metadata.node_entity_info.label_keys) + if label_dim > 0: + node_features, node_labels = _get_labels_from_features( + partition_output.partitioned_node_features.feats, label_dim=label_dim + ) + partition_output.partitioned_node_features = FeaturePartitionData( + feats=node_features, ids=partition_output.partitioned_node_features.ids + ) + gc.collect() + else: + raise 
ValueError( + f"Expected to have partitioned node features if labels are present, but got node features {partition_output.partitioned_node_features}" + ) + + # TODO (mkolodner-sc): Add node labels to the dataset + logger.info( f"Initializing DistLinkPredictionDataset instance with edge direction {edge_dir}" ) diff --git a/python/gigl/src/common/types/pb_wrappers/preprocessed_metadata.py b/python/gigl/src/common/types/pb_wrappers/preprocessed_metadata.py index 5cfe84412..0d751bde5 100644 --- a/python/gigl/src/common/types/pb_wrappers/preprocessed_metadata.py +++ b/python/gigl/src/common/types/pb_wrappers/preprocessed_metadata.py @@ -68,13 +68,14 @@ def __post_init__(self): condensed_node_type_to_feature_dim_map[ CondensedNodeType(condensed_node_type) ] = node_metadata.feature_dim - node_feature_keys = list(node_metadata.feature_keys) + node_feature_keys = sorted(list(node_metadata.feature_keys)) + label_keys = sorted(list(node_metadata.label_keys)) node_feature_schema = self.__build_feature_schema( schema_uri=UriFactory.create_uri(node_metadata.schema_uri), transform_fn_assets_uri=UriFactory.create_uri( node_metadata.transform_fn_assets_uri ), - feature_keys=node_feature_keys, + feature_keys=node_feature_keys + label_keys, ) condensed_node_type_to_feature_schema_map[ CondensedNodeType(condensed_node_type) @@ -380,12 +381,10 @@ def condensed_node_type_to_feature_keys_map( dict[CondensedNodeType, list[str]]: A mapping which stores the feature keys of each CondensedNodeTypes """ return { - condensed_node_type: list( - self.condensed_node_type_to_feature_schema_map[ - condensed_node_type - ].feature_spec.keys() + CondensedNodeType(condensed_node_type): sorted( + list(preprocessed_metadata.feature_keys) ) - for condensed_node_type in self.condensed_node_type_to_feature_schema_map + for condensed_node_type, preprocessed_metadata in self.preprocessed_metadata_pb.condensed_node_type_to_preprocessed_metadata.items() } @property diff --git 
a/python/tests/unit/common/data/dataloaders_test.py b/python/tests/unit/common/data/dataloaders_test.py index f4d34a075..58e66d6b3 100644 --- a/python/tests/unit/common/data/dataloaders_test.py +++ b/python/tests/unit/common/data/dataloaders_test.py @@ -14,17 +14,29 @@ TFDatasetOptions, TFRecordDataLoader, ) +from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper from gigl.src.data_preprocessor.lib.types import FeatureSpecDict +from gigl.src.mocking.lib.versioning import ( + MockedDatasetArtifactMetadata, + get_mocked_dataset_artifact_metadata, +) +from gigl.src.mocking.mocking_assets.mocked_datasets_for_pipeline_tests import ( + CORA_NODE_CLASSIFICATION_MOCKED_DATASET_INFO, +) _FEATURE_SPEC_WITH_ENTITY_KEY: FeatureSpecDict = { "node_id": tf.io.FixedLenFeature([], tf.int64), "feature_0": tf.io.FixedLenFeature([], tf.float32), "feature_1": tf.io.FixedLenFeature([], tf.float32), + "label_0": tf.io.FixedLenFeature([], tf.int64), + "label_1": tf.io.FixedLenFeature([], tf.int64), } _FEATURE_SPEC_WITHOUT_ENTITY_KEY: FeatureSpecDict = { "feature_0": tf.io.FixedLenFeature([], tf.float32), "feature_1": tf.io.FixedLenFeature([], tf.float32), + "label_0": tf.io.FixedLenFeature([], tf.int64), + "label_1": tf.io.FixedLenFeature([], tf.int64), } @@ -49,6 +61,12 @@ def _get_mock_node_examples() -> list[tf.train.Example]: "feature_1": tf.train.Feature( float_list=tf.train.FloatList(value=[i * 0.1]) ), + "label_0": tf.train.Feature( + int64_list=tf.train.Int64List(value=[i % 2]) + ), + "label_1": tf.train.Feature( + int64_list=tf.train.Int64List(value=[i % 3]) + ), } ) ) @@ -62,6 +80,7 @@ def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() self.data_dir = Path(self.temp_dir.name) + # Create standard examples without labels examples = _get_mock_node_examples() with tf.io.TFRecordWriter(str(self.data_dir / "100.tfrecord")) as writer: for example in examples: @@ -74,18 +93,20 @@ def tearDown(self): @parameterized.expand( [ param( - "No features", + 
"No features, no labels", feature_spec=_FEATURE_SPEC_WITH_ENTITY_KEY, feature_keys=[], feature_dim=0, + label_keys=[], expected_id_tensor=torch.tensor(range(100)), expected_feature_tensor=None, ), param( - "One feature", + "One feature, no labels", feature_spec=_FEATURE_SPEC_WITH_ENTITY_KEY, feature_keys=["feature_0"], - feature_dim=0, + feature_dim=1, + label_keys=[], expected_id_tensor=torch.tensor(range(100)), expected_feature_tensor=torch.tensor( range(100), dtype=torch.float32 @@ -93,10 +114,11 @@ def tearDown(self): * 10, ), param( - "Two features", + "Two features, no labels", feature_spec=_FEATURE_SPEC_WITH_ENTITY_KEY, feature_keys=["feature_0", "feature_1"], - feature_dim=0, + feature_dim=2, + label_keys=[], expected_id_tensor=torch.tensor(range(100)), expected_feature_tensor=torch.concat( ( @@ -109,10 +131,11 @@ def tearDown(self): ), ), param( - "Two features, no entity key in feature schema", + "Two features, no entity key in feature schema, no labels", feature_spec=_FEATURE_SPEC_WITHOUT_ENTITY_KEY, feature_keys=["feature_0", "feature_1"], - feature_dim=0, + feature_dim=2, + label_keys=[], expected_id_tensor=torch.tensor(range(100)), expected_feature_tensor=torch.concat( ( @@ -124,6 +147,134 @@ def tearDown(self): dim=1, ), ), + param( + "Two features with labels", + feature_spec=_FEATURE_SPEC_WITH_ENTITY_KEY, + feature_keys=["feature_0", "feature_1"], + feature_dim=2, + label_keys=["label_0"], + expected_id_tensor=torch.tensor(range(100)), + expected_feature_tensor=torch.concat( + ( + torch.tensor(range(100), dtype=torch.float32).reshape(100, 1) + * 10, # feature_0 + torch.tensor(range(100), dtype=torch.float32).reshape(100, 1) + * 0.1, # feature_1 + torch.tensor( + [i % 2 for i in range(100)], dtype=torch.float32 + ).reshape( + 100, 1 + ), # label_0 + ), + dim=1, + ), + ), + param( + "One feature with labels", + feature_spec=_FEATURE_SPEC_WITH_ENTITY_KEY, + feature_keys=["feature_0"], + feature_dim=1, + label_keys=["label_0"], + 
expected_id_tensor=torch.tensor(range(100)), + expected_feature_tensor=torch.concat( + ( + torch.tensor(range(100), dtype=torch.float32).reshape(100, 1) + * 10, # feature_0 + torch.tensor( + [i % 2 for i in range(100)], dtype=torch.float32 + ).reshape( + 100, 1 + ), # label_0 + ), + dim=1, + ), + ), + param( + "Only labels, no features", + feature_spec=_FEATURE_SPEC_WITH_ENTITY_KEY, + feature_keys=[], + feature_dim=0, + label_keys=["label_0"], + expected_id_tensor=torch.tensor(range(100)), + expected_feature_tensor=torch.tensor( + [i % 2 for i in range(100)], dtype=torch.float32 + ).reshape(100, 1), + ), + param( + "Two features with two labels", + feature_spec=_FEATURE_SPEC_WITH_ENTITY_KEY, + feature_keys=["feature_0", "feature_1"], + feature_dim=2, + label_keys=["label_0", "label_1"], + expected_id_tensor=torch.tensor(range(100)), + expected_feature_tensor=torch.concat( + ( + torch.tensor(range(100), dtype=torch.float32).reshape(100, 1) + * 10, # feature_0 + torch.tensor(range(100), dtype=torch.float32).reshape(100, 1) + * 0.1, # feature_1 + torch.tensor( + [i % 2 for i in range(100)], dtype=torch.float32 + ).reshape( + 100, 1 + ), # label_0 + torch.tensor( + [i % 3 for i in range(100)], dtype=torch.float32 + ).reshape( + 100, 1 + ), # label_1 + ), + dim=1, + ), + ), + param( + "One feature with two labels", + feature_spec=_FEATURE_SPEC_WITH_ENTITY_KEY, + feature_keys=["feature_0"], + feature_dim=1, + label_keys=["label_0", "label_1"], + expected_id_tensor=torch.tensor(range(100)), + expected_feature_tensor=torch.concat( + ( + torch.tensor(range(100), dtype=torch.float32).reshape(100, 1) + * 10, # feature_0 + torch.tensor( + [i % 2 for i in range(100)], dtype=torch.float32 + ).reshape( + 100, 1 + ), # label_0 + torch.tensor( + [i % 3 for i in range(100)], dtype=torch.float32 + ).reshape( + 100, 1 + ), # label_1 + ), + dim=1, + ), + ), + param( + "Only two labels, no features", + feature_spec=_FEATURE_SPEC_WITH_ENTITY_KEY, + feature_keys=[], + feature_dim=0, + 
label_keys=["label_0", "label_1"], + expected_id_tensor=torch.tensor(range(100)), + expected_feature_tensor=torch.concat( + ( + torch.tensor( + [i % 2 for i in range(100)], dtype=torch.float32 + ).reshape( + 100, 1 + ), # label_0 + torch.tensor( + [i % 3 for i in range(100)], dtype=torch.float32 + ).reshape( + 100, 1 + ), # label_1 + ), + dim=1, + ), + ), ] ) def test_load_as_torch_tensors( @@ -132,9 +283,11 @@ def test_load_as_torch_tensors( feature_spec: FeatureSpecDict, feature_keys: list[str], feature_dim: int, + label_keys: list[str], expected_id_tensor: torch.Tensor, expected_feature_tensor: Optional[torch.Tensor], ): + """Test TFRecordDataLoader's ability to load features and optionally labels.""" loader = TFRecordDataLoader(rank=0, world_size=1) node_ids, feature_tensor = loader.load_as_torch_tensors( serialized_tf_record_info=SerializedTFRecordInfo( @@ -143,13 +296,16 @@ def test_load_as_torch_tensors( feature_keys=feature_keys, feature_dim=feature_dim, entity_key="node_id", + label_keys=label_keys, tfrecord_uri_pattern="100.tfrecord", ), tf_dataset_options=TFDatasetOptions(deterministic=True), ) + # Verify entity IDs are loaded correctly assert_close(node_ids, expected_id_tensor) + # Verify feature tensor (which includes labels concatenated at the end if label_keys are specified) assert_close(feature_tensor, expected_feature_tensor) def test_build_dataset_for_uris(self): @@ -196,6 +352,24 @@ def test_build_dataset_for_uris(self): expected_features=torch.empty(0, 3), entity_key=("src_node_id", "dst_node_id"), ), + param( + "node_with_label_only", + feature_keys=[], + feature_dim=0, + expected_node_ids=torch.empty(0), + expected_features=torch.empty(0, 1), # 1 label + entity_key="node_id", + label_keys=["label"], + ), + param( + "node_with_features_and_label", + feature_keys=["foo_feature"], + feature_dim=1, + expected_node_ids=torch.empty(0), + expected_features=torch.empty(0, 2), # 1 feature + 1 label + entity_key="node_id", + label_keys=["label"], + ), ] 
) def test_load_empty_directory( @@ -206,6 +380,7 @@ def test_load_empty_directory( expected_node_ids: torch.Tensor, expected_features: Optional[torch.Tensor], entity_key: Union[str, Tuple[str, str]], + label_keys: list[str] = [], ): temp_dir = tempfile.TemporaryDirectory() self.addCleanup(temp_dir.cleanup) @@ -218,7 +393,7 @@ def test_load_empty_directory( feature_keys=feature_keys, feature_dim=feature_dim, entity_key=entity_key, - tfrecord_uri_pattern=".tfrecord", + label_keys=label_keys, ), tf_dataset_options=TFDatasetOptions(deterministic=True), ) @@ -264,3 +439,49 @@ def test_partition( [u.uri for u in uris], [str(path / f"{i:0>2}.tfrecord") for i in expected], ) + + def test_load_labels_from_pb(self): + mocked_dataset_artifact_metadata: MockedDatasetArtifactMetadata = ( + get_mocked_dataset_artifact_metadata()[ + CORA_NODE_CLASSIFICATION_MOCKED_DATASET_INFO.name + ] + ) + gbml_config_pb_wrapper = ( + GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri( + gbml_config_uri=mocked_dataset_artifact_metadata.frozen_gbml_config_uri + ) + ) + preprocessed_metadata_pb_wrapper = ( + gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper + ) + condensed_node_type = ( + gbml_config_pb_wrapper.graph_metadata_pb_wrapper.homogeneous_condensed_node_type + ) + node_metadata = preprocessed_metadata_pb_wrapper.preprocessed_metadata_pb.condensed_node_type_to_preprocessed_metadata[ + condensed_node_type + ] + loader = TFRecordDataLoader(rank=0, world_size=1) + _, feature_tensor = loader.load_as_torch_tensors( + serialized_tf_record_info=SerializedTFRecordInfo( + tfrecord_uri_prefix=UriFactory.create_uri( + node_metadata.tfrecord_uri_prefix + ), + feature_spec=preprocessed_metadata_pb_wrapper.condensed_node_type_to_feature_schema_map[ + condensed_node_type + ].feature_spec, + feature_keys=node_metadata.feature_keys, + feature_dim=node_metadata.feature_dim, + entity_key=node_metadata.node_id_key, + label_keys=node_metadata.label_keys, + tfrecord_uri_pattern=".*\.tfrecord$", + 
), + tf_dataset_options=TFDatasetOptions(deterministic=True), + ) + # Ensure we have loaded data + assert feature_tensor is not None + self.assertGreater(feature_tensor.size(0), 0) + # Ensure labels have been added as an additional dimension to the features + self.assertEqual( + feature_tensor.size(1), + node_metadata.feature_dim + len(node_metadata.label_keys), + ) diff --git a/python/tests/unit/distributed/dataset_factory_test.py b/python/tests/unit/distributed/dataset_factory_test.py index 1fc7a6294..325452be5 100644 --- a/python/tests/unit/distributed/dataset_factory_test.py +++ b/python/tests/unit/distributed/dataset_factory_test.py @@ -1,15 +1,20 @@ import unittest from collections import abc +import torch from parameterized import param, parameterized -from gigl.distributed.dataset_factory import build_dataset_from_task_config_uri +from gigl.distributed.dataset_factory import ( + _get_labels_from_features, + build_dataset_from_task_config_uri, +) from gigl.distributed.dist_context import DistributedContext from gigl.src.mocking.lib.versioning import get_mocked_dataset_artifact_metadata from gigl.src.mocking.mocking_assets.mocked_datasets_for_pipeline_tests import ( CORA_USER_DEFINED_NODE_ANCHOR_MOCKED_DATASET_INFO, ) from gigl.types.graph import DEFAULT_HOMOGENEOUS_NODE_TYPE +from tests.test_assets.distributed.utils import assert_tensor_equality # TODO(kmonte, mkolodner): Add more tests for heterogeneous datasets. 
@@ -61,6 +66,64 @@ def test_build_dataset_from_task_config_uri_homogeneous( dataset.val_node_ids.keys() == set([DEFAULT_HOMOGENEOUS_NODE_TYPE]) ) + @parameterized.expand( + [ + param( + "Basic test with label_dim=1", + feature_and_label_tensor=torch.tensor( + [ + [1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + [9.0, 10.0, 11.0, 12.0], + ] + ), + label_dim=1, + expected_features=torch.tensor( + [[1.0, 2.0, 3.0], [5.0, 6.0, 7.0], [9.0, 10.0, 11.0]] + ), + expected_labels=torch.tensor([[4.0], [8.0], [12.0]]), + ), + param( + "Test with label_dim=2", + feature_and_label_tensor=torch.tensor( + [[1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0, 10.0]] + ), + label_dim=2, + expected_features=torch.tensor([[1.0, 2.0, 3.0], [6.0, 7.0, 8.0]]), + expected_labels=torch.tensor([[4.0, 5.0], [9.0, 10.0]]), + ), + param( + "Test with single feature column", + feature_and_label_tensor=torch.tensor( + [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]] + ), + label_dim=1, + expected_features=torch.tensor([[1.0], [3.0], [5.0]]), + expected_labels=torch.tensor([[2.0], [4.0], [6.0]]), + ), + param( + "Test with no features and labels", + feature_and_label_tensor=torch.tensor([[3.0], [6.0]]), + label_dim=1, + expected_features=torch.empty((2, 0)), + expected_labels=torch.tensor([[3.0], [6.0]]), + ), + ] + ) + def test_get_labels_from_features( + self, + _, + feature_and_label_tensor: torch.Tensor, + label_dim: int, + expected_features: torch.Tensor, + expected_labels: torch.Tensor, + ): + features, labels = _get_labels_from_features( + feature_and_label_tensor, label_dim + ) + assert_tensor_equality(features, expected_features) + assert_tensor_equality(labels, expected_labels) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/unit/distributed/dataset_input_metadata_translator_test.py b/python/tests/unit/distributed/dataset_input_metadata_translator_test.py index 870f7d5f8..59af50114 100644 --- a/python/tests/unit/distributed/dataset_input_metadata_translator_test.py +++ 
b/python/tests/unit/distributed/dataset_input_metadata_translator_test.py @@ -152,10 +152,12 @@ def test_translator_correctness(self, _, mocked_dataset_info: MockedDatasetInfo) ].tfrecord_uri_prefix, ) self.assertEqual( - seralized_node_info.feature_keys, - preprocessed_metadata_pb_wrapper.condensed_node_type_to_feature_keys_map[ - condensed_node_type - ], + sorted(seralized_node_info.feature_keys), + ( + preprocessed_metadata_pb_wrapper.condensed_node_type_to_feature_keys_map[ + condensed_node_type + ] + ), ) self.assertEqual( seralized_node_info.feature_spec,