[lhotse] Added support for re-weighting datasets with temperature on the fly. #15200
Pull request overview
This PR introduces temperature-based dataset reweighting to the Lhotse dataloader, enabling dynamic adjustment of sampling probabilities across datasets through a configurable temperature parameter. This eliminates the need for manual weight recalculation when combining datasets with different sizes.
Key changes:
- New `temperature_reweighting()` function that applies temperature scaling to weights using the formula `(w_i ^ temp) / sum(w_j ^ temp)`
- New `reweight_temperature` configuration option that supports hierarchical temperature application across nested dataset groups
- Comprehensive test suite with 19 tests covering various input types and edge cases
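To make the formula in the first bullet concrete, here is a minimal sketch of temperature scaling over a flat list of weights. The function name `temperature_reweight_sketch` is hypothetical and this is not the code added to `cutset.py`; the requirement that weights be strictly positive is an assumption made for the illustration.

```python
from typing import List, Sequence


def temperature_reweight_sketch(weights: Sequence[float], temperature: float) -> List[float]:
    """Return w_i ** temperature / sum_j(w_j ** temperature) for each weight.

    Hypothetical helper for illustration only; it assumes positive weights.
    """
    if not weights:
        raise ValueError("weights must be non-empty")
    if any(w <= 0 for w in weights):
        raise ValueError("this sketch assumes strictly positive weights")
    scaled = [w ** temperature for w in weights]
    total = sum(scaled)
    return [s / total for s in scaled]


if __name__ == "__main__":
    weights = [900.0, 90.0, 10.0]
    for tau in (1.0, 0.5, 0.0):
        probs = temperature_reweight_sketch(weights, tau)
        # tau = 1.0 keeps the original ratios, tau = 0.0 gives a uniform
        # distribution, and values in between move toward uniform.
        print(tau, [round(p, 3) for p in probs])
```

With these example weights, the smallest dataset's share grows as the temperature drops from 1.0 toward 0.0, which matches the behaviour described for the `reweight_temperature` option.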
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| `nemo/collections/common/data/lhotse/cutset.py` | Adds the `temperature_reweighting()` function and integrates it into the dataset loading pipeline with temperature propagation through nested groups |
| `examples/tts/conf/magpietts/magpietts_lhotse.yaml` | Adds example configuration demonstrating hierarchical temperature usage with `[1.0, 0.0]` |
| `tests/collections/common/test_lhotse_temperature_reweighting.py` | Adds comprehensive unit and integration tests for the temperature reweighting functionality |
Pull request overview
Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.
    shuffle: true
    num_workers: 6
    pin_memory: true
    reweight_temperature: [1.0, 0.0]  # Temperature for re-weighting datasets. 1 is a neutral value. Lower temperature over-samples smaller datasets, and vice versa.
I think you also need to add this to dataloader.py in LhotseDataLoadingConfig otherwise it won't get propagated to cutset.py functions. We should also have a test that uses this option through get_lhotse_dataloader_from_config to check it works.
good point. indeed, my test using get_lhotse_dataloader_from_config failed to pass this new param to propagate_attrs.
Fixed now.
Force-pushed from 8c81ce9 to f524347.
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
…onfig. Otherwise, it would not be passed to propagate_attrs. Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
…t length could be shorter or longer than the max depth of recursion group. added tests. Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
…smatches Replace silent auto-correction (extend/trim with warning) with ValueError when reweight_temperature list length does not match input_cfg nesting depth. Only scalar (broadcast) and exact-length list are now accepted. Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
…reweight temperature. Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
…ction Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Force-pushed from 4283a0f to 2d1d1ce.
…the fly. (NVIDIA-NeMo#15200)

* [lhotse] reweighting datasets with temperature on the fly based on weights predefined in train_ds.dataset.input_cfg YAML configs. This feature would save the effort of flattening the dataset distribution every time when adding new datasets.
  Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com>
* Update tests/collections/common/test_lhotse_temperature_reweighting.py
  Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
* Apply suggestion from @XuesongYang
  Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com>
  Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
* Apply suggestion from @XuesongYang
  Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com>
  Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
* added safeguard and updated unit tests.
  Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
* Apply suggestion from @XuesongYang
  Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
  Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
* added tests to ensure dataloader applied reweighting as expected.
  Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
* bugfix: added reweighting_temperature into default LhotseDataLoadingConfig. Otherwise, it would not be passed to propagate_attrs.
  Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
* added support of flexible formats of temperatures (scalar, list). List length could be shorter or longer than the max depth of recursion group. added tests.
  Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
* added documentation
  Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
* Update nemo/collections/common/data/lhotse/cutset.py
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
  Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
* Enforce strict reweight_temperature validation: reject list length mismatches
  Replace silent auto-correction (extend/trim with warning) with ValueError when reweight_temperature list length does not match input_cfg nesting depth. Only scalar (broadcast) and exact-length list are now accepted.
  Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
* fixed type hint for config_list in parse_and_combine_datasets().
  Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
* added deterministic_rng fixture to 5 stochastic integration tests of reweight temperature.
  Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
* Refactor: extract reweight_temperature validation into standalone function
  Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>

---------

Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com>
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Co-authored-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
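The strict validation rule described in the commit log (scalar is broadcast, list must match the nesting depth exactly, anything else raises `ValueError`) could look roughly like the sketch below. The function names `count_levels_sketch` and `resolve_temperatures_sketch`, and the plain-dict config layout with `weight` and `input_cfg` keys, are assumptions for illustration and do not reproduce the functions in `cutset.py`.

```python
import logging
from typing import Any, Dict, List, Optional, Sequence, Union

Config = Dict[str, Any]


def count_levels_sketch(groups: List[Config]) -> int:
    """Maximum nesting depth of `input_cfg` lists; a flat list counts as 1."""
    deepest_child = 0
    for group in groups:
        child = group.get("input_cfg")
        if isinstance(child, list):
            deepest_child = max(deepest_child, count_levels_sketch(child))
    return 1 + deepest_child


def resolve_temperatures_sketch(
    temperature: Optional[Union[float, Sequence[float]]],
    input_cfg: List[Config],
) -> Optional[List[float]]:
    """Return one temperature per nesting level, or None if reweighting is off."""
    if temperature is None:
        return None
    depth = count_levels_sketch(input_cfg)
    if isinstance(temperature, (int, float)):
        # Scalar: broadcast to every level and log a warning.
        logging.warning("Broadcasting temperature %s to %d level(s).", temperature, depth)
        return [float(temperature)] * depth
    temps = [float(t) for t in temperature]
    if len(temps) != depth:
        raise ValueError(
            f"Got {len(temps)} temperature(s) but input_cfg has {depth} level(s)."
        )
    return temps


# Hypothetical two-level config: one nested group and one plain dataset.
example_cfg = [
    {"weight": 3.0, "input_cfg": [{"weight": 900.0}, {"weight": 100.0}]},
    {"weight": 1.0},
]

if __name__ == "__main__":
    print(resolve_temperatures_sketch(0.5, example_cfg))
    print(resolve_temperatures_sketch([1.0, 0.0], example_cfg))
```

Raising on a length mismatch, rather than silently extending or trimming the list, means a misconfigured temperature list fails at config-load time instead of quietly changing the sampling distribution.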
Summary

Add on-the-fly temperature-based dataset reweighting to the Lhotse dataloader, allowing users to control sampling distributions across nested `input_cfg` groups without manually recalculating weights each time datasets are added or removed.

The feature applies the formula `w_i^τ / Σ w_j^τ` at each nesting level, where `τ` is configurable per level. Validation is strict: `reweight_temperature` must be either a scalar (broadcast to all levels) or a list whose length exactly matches the `input_cfg` nesting depth.

$$\hat{w}_i = \frac{w_i^{\tau}}{\sum_j w_j^{\tau}}$$

- `τ = 1.0`: Preserves original weight ratios (neutral)
- `τ = 0.0`: Equalizes all datasets regardless of original weights
- `0 < τ < 1.0`: Over-samples smaller datasets relative to larger ones

Changes

New functionality (`cutset.py`)

- `temperature_reweighting(weights, temperature)`: applies temperature scaling and normalizes weights.
- `count_input_cfg_levels(config)`: computes maximum `input_cfg` nesting depth.
- `read_dataset_config()`: scalar is broadcast with a logged warning; list length mismatch raises `ValueError`.
- `parse_and_combine_datasets()` pops the current-level temperature and passes remaining values to child groups via `propagate_attrs`.

Config (`dataloader.py`)

- Added `reweight_temperature` field to `LhotseDataLoadingConfig` (default `None`)

Tests

- `temperature_reweighting()` and `count_input_cfg_levels()`

Documentation (`docs/source/audio/configs.rst`)

Example configs

- `magpietts_lhotse.yaml` / `magpietts_lhotse_moe.yaml`: added `reweight_temperature: 1.0` (neutral default) with comment showing list alternative
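To illustrate how a per-level list such as `[1.0, 0.0]` might play out across nested groups, the sketch below takes the first temperature for the current level and hands the rest to child groups, then reports the resulting probability of each leaf dataset. The function `leaf_probabilities_sketch`, the `name` key, and the assumption that a leaf's effective probability is the product of the normalized weights along its path are all hypothetical; the actual propagation goes through `propagate_attrs` in `cutset.py` and may differ in detail.

```python
from typing import Any, Dict, List, Sequence

Config = Dict[str, Any]


def leaf_probabilities_sketch(
    groups: List[Config],
    temperatures: Sequence[float],
    parent_prob: float = 1.0,
) -> Dict[str, float]:
    """Map each leaf dataset name to its effective sampling probability.

    Assumes `temperatures` has one entry per nesting level and that all
    weights are positive.
    """
    current, remaining = temperatures[0], temperatures[1:]
    scaled = [group["weight"] ** current for group in groups]
    total = sum(scaled)
    result: Dict[str, float] = {}
    for group, s in zip(groups, scaled):
        prob = parent_prob * s / total
        if "input_cfg" in group:
            # Child groups only see the temperatures for deeper levels.
            result.update(leaf_probabilities_sketch(group["input_cfg"], remaining, prob))
        else:
            result[group["name"]] = prob
    return result


# Hypothetical config: two language groups, each with two datasets.
nested_cfg = [
    {
        "weight": 3.0,
        "input_cfg": [
            {"name": "en_large", "weight": 900.0},
            {"name": "en_small", "weight": 100.0},
        ],
    },
    {
        "weight": 1.0,
        "input_cfg": [
            {"name": "de_a", "weight": 50.0},
            {"name": "de_b", "weight": 50.0},
        ],
    },
]

if __name__ == "__main__":
    # Level 1 keeps the 3:1 group ratio; level 2 equalizes datasets in each group.
    print(leaf_probabilities_sketch(nested_cfg, [1.0, 0.0]))
```

Under these assumptions, `[1.0, 0.0]` leaves the 3:1 split between the two groups untouched while sampling the datasets inside each group uniformly, so `en_large` and `en_small` each end up at 0.375 and the two `de` datasets at 0.125.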