Fix the inference of transformer-based models trained with masked language modeling (#909)

* add the inference fix to ReplaceMaskedEmbeddings
* first solution for inference support
* updates based on PR comments
* Apply suggestions from code review

Co-authored-by: Gabriel Moreira <gmoreira@nvidia.com>

* Use merlin-dataloader package (#845)
* Use merlin-dataloader package
* remove torch.dataset in favor of merlin.loader.torch
* update dressipi notebook
* minor clean up
* Completely removes models DataLoader
* Installs merlin-dataloader in github actions
* Adds back the stop method
* dataloader can produce sparse tensors using value counts
* remove data files
* fix torch tests
* add missing target to dlrm test
* use loader.peek()
* add some comments to help understand horovod tests
* make sparse tensors optional
* cleanup
* fix spelling
* fix merge
* replace while loop with for loop in horovod test
* use loader context manager
* Update according to dataloader changes #80
* restore tox.ini
* restore gh workflow
* revert generator changes
* Restore documentation build (#916)

- Change Python 3.9.7 to 3.8.
- Update the versions of the GH actions.
- Update pre-commit config file to get flake8 from GitHub instead of GitLab.

* Support `tuple` return type from model `pre` and update test to use this (#890)
* Support `tuple` return type from `pre` arg to `evaluate`, `predict`
* Update CLM transformer test to use `pre` instead of Loader `transform`
* Update youtube dnn tests to use transform as model fit pre
* Add `pre` to ModelBlock fit/evaluate
* Revert "Add `pre` to ModelBlock fit/evaluate"

This reverts commit 1eef7b8.

* Raise exception if ragged/sparse tensors are passed at training time.
* Update model_test helper to avoid passing ragged tensors to `fit`
* Handle x and y in model_test
* Change process_lists param to False by default
* Convert to tuple in test loader
* Move order of ragged tensor assertion to before train_pre call
* expand dims in test_classification
* pass transform as pre in test in batch negatives
* Update continuous and retrieval tests
* Remove test of sequence predict functions with loader
* Update error message about ragged tensors for clarity
* Add explanation about why the input types are restricted
* Rename dataset to dataloader in model_test

Co-authored-by: rnyak <ronayak@hotmail.com>

* add assertion check to TransformerInferenceHiddenState

Co-authored-by: Gabriel Moreira <gmoreira@nvidia.com>
Co-authored-by: edknv <109497216+edknv@users.noreply.github.com>
Co-authored-by: mikemckiernan <mmckiernan@nvidia.com>
Co-authored-by: Oliver Holworthy <oholworthy@nvidia.com>
Co-authored-by: rnyak <ronayak@hotmail.com>
1 parent a8ab140, commit df84c81
Showing 6 changed files with 138 additions and 20 deletions.