Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions python/gigl/src/common/types/pb_wrappers/gbml_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,9 @@ def should_populate_predictions_path(self) -> bool:
"""
Allows access to should_populate_predictions_path under GbmlConfig

This flag defaults to False, but can be set to True to populate the predictions path in the InferenceOutput for each entity type.
Comment thread
mkolodner-sc marked this conversation as resolved.
We default this to False since config populator currently does not populate the predictions path for link prediction tasks.

This flag is a temporary workaround to populate the extra embeddings for the same entity type

Returns:
Expand All @@ -481,3 +484,23 @@ def should_populate_predictions_path(self) -> bool:
)
)
)

@property
def should_populate_embeddings_path(self) -> bool:
    """
    Expose the ``should_populate_embeddings_path`` feature flag of the GbmlConfig.

    The flag is read from ``gbml_config_pb.feature_flags``. When it is absent, the
    value falls back to "True"; it can be set to False to skip populating embedding
    paths in InferenceOutput.
    The default is True because config populator currently populates the embeddings
    path by default for both link prediction and node classification tasks.

    Returns:
        bool: Whether to populate embeddings path in the InferenceOutput for each entity type
    """
    # Copy the feature flags into a plain dict so a missing key can fall back to a default.
    feature_flags = dict(self.gbml_config_pb.feature_flags)
    raw_flag = feature_flags.get("should_populate_embeddings_path", "True")
    # strtobool yields an int-like truth value; normalize it to a real bool.
    return bool(strtobool(raw_flag))
13 changes: 8 additions & 5 deletions python/gigl/src/config_populator/config_populator.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,18 +336,21 @@ def __populate_inference_metadata_pb(
gbml_config_pb=self.template_gbml_config
)
for node_type in inferencer_node_types:
embeddings_path = bq_constants.get_embeddings_table(
applied_task_identifier=self.applied_task_identifier,
node_type=node_type,
)
embeddings_path: Optional[str] = None
predictions_path: Optional[str] = None

if template_gbml_config_pb_wrapper.should_populate_embeddings_path:
embeddings_path = bq_constants.get_embeddings_table(
applied_task_identifier=self.applied_task_identifier,
node_type=node_type,
)

if (
self.task_metadata_pb_wrapper.task_metadata_type
== TaskMetadataType.NODE_BASED_TASK
or template_gbml_config_pb_wrapper.should_populate_predictions_path
):
# TODO: currently, we are overloading the predictions path to store extra embeddings.
# TODO: currently, we are overloading the predictions path to store extra embeddings for link prediction.
# consider extending InferenceOutput's definition for this purpose.
predictions_path = bq_constants.get_predictions_table(
applied_task_identifier=self.applied_task_identifier,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import unittest

from parameterized import param, parameterized

from gigl.common.logger import Logger
from gigl.src.common.types import AppliedTaskIdentifier
from gigl.src.common.types.pb_wrappers.dataset_metadata import DatasetMetadataPbWrapper
Expand Down Expand Up @@ -126,12 +128,27 @@ def test_sgs_config_population_is_accurate(
"",
)

def test_glt_config_population_is_accurate(self):
@parameterized.expand(
[
param(
should_populate_embeddings_path=True,
),
param(
should_populate_embeddings_path=False,
),
]
)
def test_glt_config_population_is_accurate(
self, should_populate_embeddings_path: bool
):
config_populator = ConfigPopulator()
frozen_gbml_config_pb = config_populator._populate_frozen_gbml_config_pb(
applied_task_identifier=self.applied_task_identifier,
template_gbml_config_pb=_TEMPLATE_CONFIG_FOR_GLT,
)
frozen_gbml_config_pb.feature_flags["should_populate_embeddings_path"] = str(
should_populate_embeddings_path
)
gbml_config_pb_wrapper = GbmlConfigPbWrapper(
gbml_config_pb=frozen_gbml_config_pb
)
Expand All @@ -150,6 +167,10 @@ def test_glt_config_population_is_accurate(self):

# GBML Config Pb Wrapper Checks
self.assertTrue(gbml_config_pb_wrapper.should_use_glt_backend)
self.assertEqual(
gbml_config_pb_wrapper.should_populate_embeddings_path,
should_populate_embeddings_path,
)
with self.assertRaises(ValueError):
# We should expect to throw an error when accessing the flattened graph metadata pb wrapper when it is not set
gbml_config_pb_wrapper.flattened_graph_metadata_pb_wrapper
Expand Down Expand Up @@ -178,12 +199,20 @@ def test_glt_config_population_is_accurate(self):
isinstance(inference_metadata_pb, inference_metadata_pb2.InferenceMetadata)
)
for node_type in inference_metadata_pb.node_type_to_inferencer_output_info_map:
self.assertNotEqual(
inference_metadata_pb.node_type_to_inferencer_output_info_map[
node_type
].embeddings_path,
"",
)
if should_populate_embeddings_path:
self.assertNotEqual(
inference_metadata_pb.node_type_to_inferencer_output_info_map[
node_type
].embeddings_path,
"",
)
else:
self.assertEqual(
inference_metadata_pb.node_type_to_inferencer_output_info_map[
node_type
].embeddings_path,
"",
)
if (
gbml_config_pb_wrapper.task_metadata_pb_wrapper.task_metadata_type
== TaskMetadataType.NODE_BASED_TASK
Expand Down