Skip to content

Commit

Permalink
Use merlin-dataloader package (#845)
Browse files Browse the repository at this point in the history
* Use merlin-dataloader package

* remove torch.dataset in favor of merlin.loader.torch

* update dressipi notebook

* minor clean up

* Completely removes models DataLoader

* Installs merlin-dataloader in github actions

* Adds back the stop method

* dataloader can produce sparse tensors using value counts

* remove data files

* fix torch tests

* add missing target to dlrm test

* use loader.peek()

* add some comments to help understand horovod tests

* make sparse tensors optional

* cleanup

* fix spelling

* fix merge

* replace while loop with for loop in horovod test

* use loader context manager

* Update according to dataloader changes #80

* restore tox.ini

* restore gh workflow

* revert generator changes
  • Loading branch information
edknv committed Dec 9, 2022
1 parent e08a72c commit 60a9ca1
Show file tree
Hide file tree
Showing 37 changed files with 271 additions and 1,557 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/tensorflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ jobs:
fi
pip install "pandas>=1.2.0,<1.4.0dev0"
pip install "NVTabular@git+https://github.com/NVIDIA-Merlin/NVTabular.git@$branch"
pip install "merlin-dataloader@git+https://github.com/NVIDIA-Merlin/dataloader.git@$branch"
pip install "merlin-core@git+https://github.com/NVIDIA-Merlin/core.git@$branch"
- name: Install dependencies
run: |
Expand Down Expand Up @@ -108,6 +109,7 @@ jobs:
fi
pip install "pandas>=1.2.0,<1.4.0dev0"
pip install "NVTabular@git+https://github.com/NVIDIA-Merlin/NVTabular.git@$branch"
pip install "merlin-dataloader@git+https://github.com/NVIDIA-Merlin/dataloader.git@$branch"
pip install "merlin-core@git+https://github.com/NVIDIA-Merlin/core.git@$branch"
- name: Install dependencies
run: |
Expand Down
2 changes: 1 addition & 1 deletion examples/05-Retrieval-Model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1147,7 +1147,7 @@
}
],
"source": [
"eval_loader = mm.Loader(valid, batch_size=1024, transform=mm.ToTarget(schema, \"item_id\"))\n",
"eval_loader = mm.Loader(valid, batch_size=1024).map(mm.ToTarget(schema, \"item_id\"))\n",
"\n",
"metrics = topk_model.evaluate(eval_loader, return_dict=True)\n",
"metrics"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -968,8 +968,8 @@
"metadata": {},
"outputs": [],
"source": [
"loader = mm.Loader(train, batch_size=BATCH_SIZE, transform=mm.ToTarget(train.schema, \"purchase_id_first\", one_hot=True), shuffle = False)\n",
"val_loader = mm.Loader(valid, batch_size=BATCH_SIZE, transform=mm.ToTarget(train.schema, \"purchase_id_first\", one_hot=True), shuffle=False)"
"loader = mm.Loader(train, batch_size=BATCH_SIZE, shuffle=False).map(mm.ToTarget(train.schema, \"purchase_id_first\", one_hot=True))\n",
"val_loader = mm.Loader(valid, batch_size=BATCH_SIZE, shuffle=False).map(mm.ToTarget(train.schema, \"purchase_id_first\", one_hot=True))"
]
},
{
Expand Down
17 changes: 16 additions & 1 deletion merlin/datasets/synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

import merlin.io
from merlin.models.utils import schema_utils
from merlin.schema import Schema, Tags
from merlin.schema import ColumnSchema, Schema, Tags
from merlin.schema.io.tensorflow_metadata import TensorflowMetadata

LOG = logging.getLogger("merlin-models")
Expand Down Expand Up @@ -116,6 +116,21 @@ def generate_data(
else:
raise ValueError(f"Unknown input type: {type(input)}")

for col in schema.column_names:
if not schema[col].is_list:
continue
new_properties = schema[col].properties
new_properties["value_count"] = {"min": min_session_length}
if max_session_length:
new_properties["value_count"]["max"] = max_session_length
schema[col] = ColumnSchema(
name=schema[col].name,
tags=schema[col].tags,
properties=new_properties,
dtype=schema[col].dtype,
is_list=True,
)

df = generate_user_item_interactions(
schema, num_rows, min_session_length, max_session_length, device=device
)
Expand Down

0 comments on commit 60a9ca1

Please sign in to comment.