From 221e20079b5d20269e523fe4bcdb18930cfd80ed Mon Sep 17 00:00:00 2001 From: Karl Higley Date: Tue, 14 Mar 2023 10:29:54 -0400 Subject: [PATCH] Apply the `EMBEDDING` tag to the output of the embedding operators This will help us identify which continuous features are embeddings in the model input layers. --- .../dataloader/ops/embeddings/embedding_op.py | 4 ++-- tests/unit/dataloader/test_tf_embeddings.py | 24 +++++++++---------- .../unit/dataloader/test_torch_embeddings.py | 12 +++++----- 3 files changed, 20 insertions(+), 20 deletions(-) diff --git a/merlin/dataloader/ops/embeddings/embedding_op.py b/merlin/dataloader/ops/embeddings/embedding_op.py index 8dc496a7..819dafb0 100644 --- a/merlin/dataloader/ops/embeddings/embedding_op.py +++ b/merlin/dataloader/ops/embeddings/embedding_op.py @@ -106,7 +106,7 @@ def compute_output_schema( col_schemas.append( ColumnSchema( name=self.embedding_name, - tags=[Tags.CONTINUOUS], + tags=[Tags.CONTINUOUS, Tags.EMBEDDING], dtype=self._get_dtype(self.embeddings), is_list=True, is_ragged=False, @@ -189,7 +189,7 @@ def compute_output_schema( col_schemas.append( ColumnSchema( name=self.embedding_name, - tags=[Tags.CONTINUOUS], + tags=[Tags.CONTINUOUS, Tags.EMBEDDING], dtype=self.embeddings.dtype, is_list=True, is_ragged=False, diff --git a/tests/unit/dataloader/test_tf_embeddings.py b/tests/unit/dataloader/test_tf_embeddings.py index 84d83e86..de3b2120 100644 --- a/tests/unit/dataloader/test_tf_embeddings.py +++ b/tests/unit/dataloader/test_tf_embeddings.py @@ -46,11 +46,11 @@ def test_embedding_tf_np_mmap_dl_no_lookup(tmpdir, embedding_ids, np_embeddings_ dataset = dataset.repartition(10) schema = dataset.schema for col_name in cat_names: - schema[col_name] = schema[col_name].with_tags(Tags.CATEGORICAL) + schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL, Tags.EMBEDDING]) dataset.schema = schema for col_name in cat_names: - schema[col_name] = schema[col_name].with_tags(Tags.CATEGORICAL) + schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL, Tags.EMBEDDING]) dataset.schema = schema data_loader = Loader( dataset, @@ -91,11 +91,11 @@ def test_embedding_tf_np_mmap_dl_with_lookup(tmpdir, rev_embedding_ids, np_embed dataset = dataset.repartition(10) schema = dataset.schema for col_name in cat_names: - schema[col_name] = schema[col_name].with_tags(Tags.CATEGORICAL) + schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL, Tags.EMBEDDING]) dataset.schema = schema for col_name in cat_names: - schema[col_name] = schema[col_name].with_tags(Tags.CATEGORICAL) + schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL, Tags.EMBEDDING]) dataset.schema = schema data_loader = Loader( dataset, @@ -125,11 +125,11 @@ def test_embedding_tf_np_dl_no_lookup(tmpdir, embedding_ids, embeddings_from_dat dataset = dataset.repartition(10) schema = dataset.schema for col_name in cat_names: - schema[col_name] = schema[col_name].with_tags(Tags.CATEGORICAL) + schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL, Tags.EMBEDDING]) dataset.schema = schema for col_name in cat_names: - schema[col_name] = schema[col_name].with_tags(Tags.CATEGORICAL) + schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL, Tags.EMBEDDING]) dataset.schema = schema paths = sorted(glob.glob(f"{embeddings_from_dataframe}/*")) embeddings_ds = Dataset(paths) @@ -163,11 +163,11 @@ def test_embedding_tf_np_dl_with_lookup(tmpdir, rev_embedding_ids, embeddings_fr dataset = dataset.repartition(10) schema = dataset.schema for col_name in cat_names: - schema[col_name] = schema[col_name].with_tags(Tags.CATEGORICAL) + schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL, Tags.EMBEDDING]) dataset.schema = schema for col_name in cat_names: - schema[col_name] = schema[col_name].with_tags(Tags.CATEGORICAL) + schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL, Tags.EMBEDDING]) dataset.schema = schema paths = sorted(glob.glob(f"{embeddings_from_dataframe}/*")) embeddings_ds = Dataset(paths) @@ -202,11 +202,11 @@ def test_embedding_tf_dl_no_lookup(tmpdir, embedding_ids, embeddings_from_datafr dataset = dataset.repartition(10) schema = dataset.schema for col_name in cat_names: - schema[col_name] = schema[col_name].with_tags(Tags.CATEGORICAL) + schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL, Tags.EMBEDDING]) dataset.schema = schema for col_name in cat_names: - schema[col_name] = schema[col_name].with_tags(Tags.CATEGORICAL) + schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL, Tags.EMBEDDING]) dataset.schema = schema paths = sorted(glob.glob(f"{embeddings_from_dataframe}/*")) embeddings_ds = Dataset(paths) @@ -241,11 +241,11 @@ def test_embedding_tf_dl_with_lookup(tmpdir, rev_embedding_ids, embeddings_from_ dataset = dataset.repartition(10) schema = dataset.schema for col_name in cat_names: - schema[col_name] = schema[col_name].with_tags(Tags.CATEGORICAL) + schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL, Tags.EMBEDDING]) dataset.schema = schema for col_name in cat_names: - schema[col_name] = schema[col_name].with_tags(Tags.CATEGORICAL) + schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL, Tags.EMBEDDING]) dataset.schema = schema paths = sorted(glob.glob(f"{embeddings_from_dataframe}/*")) embeddings_ds = Dataset(paths) diff --git a/tests/unit/dataloader/test_torch_embeddings.py b/tests/unit/dataloader/test_torch_embeddings.py index a735b3fc..268e658d 100644 --- a/tests/unit/dataloader/test_torch_embeddings.py +++ b/tests/unit/dataloader/test_torch_embeddings.py @@ -48,7 +48,7 @@ def test_embedding_torch_np_mmap_dl_with_lookup( dataset = dataset.repartition(10) schema = dataset.schema for col_name in cat_names: - schema[col_name] = schema[col_name].with_tags(Tags.CATEGORICAL) + schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL, Tags.EMBEDDING]) dataset.schema = schema data_loader = Loader( @@ -81,7 +81,7 @@ def test_embedding_torch_np_mmap_dl_no_lookup(tmpdir, embedding_ids, np_embeddin dataset = dataset.repartition(10) schema = dataset.schema for col_name in cat_names: - schema[col_name] = schema[col_name].with_tags(Tags.CATEGORICAL) + schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL, Tags.EMBEDDING]) dataset.schema = schema data_loader = Loader( @@ -115,7 +115,7 @@ def test_embedding_torch_np_dl_with_lookup( dataset = dataset.repartition(10) schema = dataset.schema for col_name in cat_names: - schema[col_name] = schema[col_name].with_tags(Tags.CATEGORICAL) + schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL, Tags.EMBEDDING]) dataset.schema = schema paths = sorted(glob.glob(f"{embeddings_from_dataframe}/*")) embeddings_ds = Dataset(paths) @@ -150,7 +150,7 @@ def test_embedding_torch_np_dl_no_lookup(tmpdir, embedding_ids, embeddings_from_ dataset = dataset.repartition(10) schema = dataset.schema for col_name in cat_names: - schema[col_name] = schema[col_name].with_tags(Tags.CATEGORICAL) + schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL, Tags.EMBEDDING]) dataset.schema = schema paths = sorted(glob.glob(f"{embeddings_from_dataframe}/*")) embeddings_ds = Dataset(paths) @@ -184,7 +184,7 @@ def test_embedding_torch_dl_with_lookup(tmpdir, rev_embedding_ids, embeddings_fr dataset = dataset.repartition(10) schema = dataset.schema for col_name in cat_names: - schema[col_name] = schema[col_name].with_tags(Tags.CATEGORICAL) + schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL, Tags.EMBEDDING]) dataset.schema = schema paths = sorted(glob.glob(f"{embeddings_from_dataframe}/*")) embeddings_ds = Dataset(paths) @@ -218,7 +218,7 @@ def test_embedding_torch_dl_no_lookup(tmpdir, embedding_ids, embeddings_from_dat dataset = dataset.repartition(10) schema = dataset.schema for col_name in cat_names: - schema[col_name] = schema[col_name].with_tags(Tags.CATEGORICAL) + schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL, Tags.EMBEDDING]) dataset.schema = schema paths = sorted(glob.glob(f"{embeddings_from_dataframe}/*")) embeddings_ds = Dataset(paths)