Skip to content

Commit

Permalink
Fix SyntheticData.read_schema with proto text files (#191)
Browse files Browse the repository at this point in the history
As pointed out #184 (comment)
there was some issues with the SyntheticData.read_schema method. Fix
and add a basic unittest that would have caught this
  • Loading branch information
benfred committed Feb 22, 2022
1 parent 3133717 commit 4f96693
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
10 changes: 6 additions & 4 deletions merlin_models/data/synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,13 @@ def from_schema(

@classmethod
def read_schema(cls, path: Union[str, Path]) -> Schema:
_schema_path = (
os.path.join(str(path), "schema.json") if os.path.isdir(str(path)) else str(path)
)
path = str(path)
_schema_path = os.path.join(path, "schema.json") if os.path.isdir(path) else path

if _schema_path.endswith(".pb") or _schema_path.endswith(".pbtxt"):
TensorflowMetadata.from_from_proto_text(_schema_path).to_merlin_schema()
return TensorflowMetadata.from_proto_text_file(
os.path.dirname(_schema_path), os.path.basename(_schema_path)
).to_merlin_schema()

return tensorflow_metadata_json_to_schema(_schema_path)

Expand Down
8 changes: 8 additions & 0 deletions tests/data/testing/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#
import pytest
from merlin.schema import Tags
from merlin.schema.io.tensorflow_metadata import TensorflowMetadata

from merlin_models.data.synthetic import SyntheticData
from merlin_models.utils.schema import filter_dict_by_schema
Expand Down Expand Up @@ -76,3 +77,10 @@ def test_torch_tensors_generation_cpu():
for val in filter_dict_by_schema(tensors, schema.select_by_tag(Tags.CATEGORICAL)).values():
assert val.dtype == torch.int64
assert val.max() < 52000


def test_synthetic_read_proto_text(tmpdir):
schema = SyntheticData("music_streaming").schema
TensorflowMetadata.from_merlin_schema(schema).to_proto_text_file(tmpdir, "schema.pbtxt")
reloaded = SyntheticData.read_schema(tmpdir / "schema.pbtxt")
assert schema == reloaded

0 comments on commit 4f96693

Please sign in to comment.