[lhotse] Added support for re-weighting datasets with temperature on the fly.#15200

Merged
pzelasko merged 16 commits into NVIDIA-NeMo:main from XuesongYang:xueyang/pr-reweight-otf
Mar 12, 2026

Conversation

@XuesongYang
Collaborator

@XuesongYang XuesongYang commented Dec 17, 2025

Summary

Add on-the-fly temperature-based dataset reweighting to the Lhotse dataloader, allowing users to control sampling distributions across nested input_cfg groups without manually recalculating weights each time datasets are added or removed.

The feature applies the formula w_i^τ / Σ w_j^τ at each nesting level, where τ is configurable per level. Validation is strict: reweight_temperature must be either a scalar (broadcast to all levels) or a list whose length exactly matches the input_cfg nesting depth.

ŵᵢ = wᵢ^τ / Σⱼ wⱼ^τ.

  • τ = 1.0: Preserves original weight ratios (neutral)
  • τ = 0.0: Equalizes all datasets regardless of original weights
  • 0 < τ < 1.0: Over-samples smaller datasets relative to larger ones
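
The three regimes listed above can be checked numerically with a minimal sketch of the formula. The function name `temperature_reweight_sketch` and the positive-weight check are assumptions made for illustration; this is not the PR's `temperature_reweighting()` implementation.

```python
from typing import List, Sequence


def temperature_reweight_sketch(weights: Sequence[float], temperature: float) -> List[float]:
    """Return w_i**temperature / sum_j(w_j**temperature) for each weight."""
    if not weights:
        return []
    if any(w <= 0 for w in weights):
        # Assumption for this sketch: only strictly positive weights are accepted.
        raise ValueError("weights must be positive")
    scaled = [w ** temperature for w in weights]
    total = sum(scaled)
    return [s / total for s in scaled]


weights = [8.0, 2.0]
neutral = temperature_reweight_sketch(weights, 1.0)  # original 4:1 ratio is kept
equal = temperature_reweight_sketch(weights, 0.0)    # both datasets get the same share
flatter = temperature_reweight_sketch(weights, 0.5)  # ratio shrinks from 4:1 to 2:1
```

With τ = 0.5 the smaller dataset's share rises from 0.2 to one third, which is the over-sampling effect described in the last bullet.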

Changes

New functionality (cutset.py)

  • temperature_reweighting(weights, temperature): applies temperature scaling and normalizes weights.
  • count_input_cfg_levels(config): computes maximum input_cfg nesting depth.
  • Validation in read_dataset_config(): scalar is broadcast with a logged warning; list length mismatch raises ValueError.
  • parse_and_combine_datasets() pops the current-level temperature and passes remaining values to child groups via propagate_attrs.
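
The depth counting, validation, and per-level hand-off described above could look roughly like the sketch below. It assumes that a nested group is a dict carrying its own `input_cfg` list. The helper names `count_levels`, `normalize_temperature`, and `split_temperature` are hypothetical and do not reproduce the code in cutset.py.

```python
from typing import Any, Dict, List, Sequence, Tuple, Union

Config = Dict[str, Any]


def count_levels(input_cfg: Sequence[Config]) -> int:
    """Maximum nesting depth of input_cfg lists; a flat list has depth 1."""
    child_depths = [
        count_levels(entry["input_cfg"])
        for entry in input_cfg
        if isinstance(entry.get("input_cfg"), (list, tuple))
    ]
    return 1 + max(child_depths, default=0)


def normalize_temperature(
    temperature: Union[float, Sequence[float]], depth: int
) -> List[float]:
    """Broadcast a scalar to every level, or require an exact-length list."""
    if isinstance(temperature, (int, float)):
        return [float(temperature)] * depth
    temps = [float(t) for t in temperature]
    if len(temps) != depth:
        raise ValueError(
            f"reweight_temperature has {len(temps)} entries, "
            f"but input_cfg has {depth} nesting levels"
        )
    return temps


def split_temperature(temps: Sequence[float]) -> Tuple[float, List[float]]:
    """Take the value for the current level; the rest is passed to child groups."""
    return temps[0], list(temps[1:])


# Hypothetical two-level config: one nested group plus one top-level dataset.
cfg = [
    {"type": "group", "weight": 0.7, "input_cfg": [
        {"name": "dataset_a", "weight": 3.0},
        {"name": "dataset_b", "weight": 1.0},
    ]},
    {"name": "dataset_c", "weight": 0.3},
]
depth = count_levels(cfg)
broadcast = normalize_temperature(0.5, depth)
current, remaining = split_temperature(normalize_temperature([1.0, 0.0], depth))
```

Here `depth` is 2, the scalar 0.5 is broadcast to both levels, and the list `[1.0, 0.0]` is split into 1.0 for the top level and `[0.0]` for the nested group.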

Config (dataloader.py)

  • Added reweight_temperature field to LhotseDataLoadingConfig (default None)

Tests

  • Unit tests for temperature_reweighting() and count_input_cfg_levels()
  • Integration tests for dataloader with various temperature configurations
  • Validation tests for scalar/list normalization behavior
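
A stochastic check of this kind can be made reproducible by seeding the random generator and comparing empirical sampling frequencies against the reweighted probabilities. The sketch below uses only the Python standard library as a stand-in for the dataloader; the helper names are hypothetical and this is not one of the PR's tests.

```python
import random
from collections import Counter
from typing import List, Sequence


def reweight(weights: Sequence[float], temperature: float) -> List[float]:
    """w_i**temperature / sum_j(w_j**temperature)."""
    scaled = [w ** temperature for w in weights]
    total = sum(scaled)
    return [s / total for s in scaled]


def empirical_frequencies(probs: Sequence[float], num_draws: int, seed: int) -> List[float]:
    """Draw dataset indices with a seeded RNG and return the observed fractions."""
    rng = random.Random(seed)
    draws = rng.choices(range(len(probs)), weights=probs, k=num_draws)
    counts = Counter(draws)
    return [counts[i] / num_draws for i in range(len(probs))]


expected = reweight([9.0, 1.0], 0.5)  # sqrt(9) : sqrt(1) = 3 : 1
observed = empirical_frequencies(expected, num_draws=20_000, seed=0)
max_gap = max(abs(o - e) for o, e in zip(observed, expected))
```

A tolerance of a few percentage points on `max_gap` is comfortable at this sample size, since the standard deviation of each observed fraction is roughly 0.003.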

Documentation (docs/source/audio/configs.rst)

  • New "Dataset Reweighting with Temperature" section with formula, configuration examples (scalar and list), nesting depth explanation, and a multi-task-group walkthrough.

Example configs

  • magpietts_lhotse.yaml / magpietts_lhotse_moe.yaml: added reweight_temperature: 1.0 (neutral default) with comment showing list alternative
  • Example that preserves top-level ratios but equalizes within sub-groups:
reweight_temperature: [1.0, 0.0]  # Level 1: preserve, Level 2: equalize
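
To see what `[1.0, 0.0]` does in practice, the sketch below computes effective per-dataset sampling probabilities for a hypothetical two-level layout. It assumes that the probability of a dataset is the product of its group's normalized weight and its own normalized weight inside the group. The group names, dataset names, and weights are invented for illustration.

```python
from typing import Dict, List, Sequence, Tuple


def reweight(weights: Sequence[float], temperature: float) -> List[float]:
    """w_i**temperature / sum_j(w_j**temperature)."""
    scaled = [w ** temperature for w in weights]
    total = sum(scaled)
    return [s / total for s in scaled]


# Hypothetical layout: group name -> (group weight, {dataset name: dataset weight}).
Groups = Dict[str, Tuple[float, Dict[str, float]]]
groups: Groups = {
    "asr": (0.75, {"large": 9.0, "small": 1.0}),
    "tts": (0.25, {"a": 1.0, "b": 3.0}),
}


def effective_probs(groups: Groups, temperatures: Sequence[float]) -> Dict[str, float]:
    """Apply one temperature per nesting level and multiply the two levels."""
    top_t, inner_t = temperatures
    names = list(groups)
    top = reweight([groups[n][0] for n in names], top_t)
    out: Dict[str, float] = {}
    for name, p_group in zip(names, top):
        datasets = groups[name][1]
        inner = reweight(list(datasets.values()), inner_t)
        for ds, p in zip(datasets, inner):
            out[f"{name}/{ds}"] = p_group * p
    return out


probs = effective_probs(groups, [1.0, 0.0])
```

Under these assumptions the groups keep their 0.75 / 0.25 split, while the datasets inside each group share that mass equally (0.375 each under `asr`, 0.125 each under `tts`), regardless of their original inner weights.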

Copilot AI review requested due to automatic review settings December 17, 2025 00:30
Comment thread examples/tts/conf/magpietts/magpietts_lhotse.yaml Outdated
Contributor

Copilot AI left a comment

Pull request overview

This PR introduces temperature-based dataset reweighting to the Lhotse dataloader, enabling dynamic adjustment of sampling probabilities across datasets through a configurable temperature parameter. This eliminates the need for manual weight recalculation when combining datasets with different sizes.

Key changes:

  • New temperature_reweighting() function that applies temperature scaling to weights using the formula (w_i ^ temp) / sum(w_j ^ temp)
  • New reweight_temperature configuration option that supports hierarchical temperature application across nested dataset groups
  • Comprehensive test suite with 19 tests covering various input types and edge cases

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| nemo/collections/common/data/lhotse/cutset.py | Adds the temperature_reweighting() function and integrates it into the dataset loading pipeline with temperature propagation through nested groups |
| examples/tts/conf/magpietts/magpietts_lhotse.yaml | Adds example configuration demonstrating hierarchical temperature usage with [1.0, 0.0] |
| tests/collections/common/test_lhotse_temperature_reweighting.py | Adds comprehensive unit and integration tests for the temperature reweighting functionality |

Comment thread examples/tts/conf/magpietts/magpietts_lhotse.yaml Outdated
Comment thread nemo/collections/common/data/lhotse/cutset.py Outdated
Comment thread nemo/collections/common/data/lhotse/cutset.py Outdated
Comment thread tests/collections/common/test_lhotse_temperature_reweighting.py
Comment thread tests/collections/common/test_lhotse_temperature_reweighting.py Fixed
Comment thread nemo/collections/common/data/lhotse/cutset.py Outdated
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.


Comment thread tests/collections/common/test_lhotse_temperature_reweighting.py
Comment thread nemo/collections/common/data/lhotse/cutset.py
Comment thread tests/collections/common/test_lhotse_temperature_reweighting.py
Comment thread examples/tts/conf/magpietts/magpietts_lhotse.yaml Outdated
Comment thread nemo/collections/common/data/lhotse/cutset.py
shuffle: true
num_workers: 6
pin_memory: true
reweight_temperature: [1.0, 0.0] # Temperature for re-weighting datasets. 1 is a neutral value. Lower temperature over-samples smaller datasets, and vice versa.
Collaborator

I think you also need to add this to dataloader.py in LhotseDataLoadingConfig otherwise it won't get propagated to cutset.py functions. We should also have a test that uses this option through get_lhotse_dataloader_from_config to check it works.

Collaborator Author

good point. indeed, my test using get_lhotse_dataloader_from_config failed to pass this new param to propagate_attrs.

Fixed now.

Comment thread nemo/collections/common/data/lhotse/dataloader.py Fixed
@XuesongYang XuesongYang requested a review from Copilot January 13, 2026 05:23
XuesongYang and others added 11 commits March 11, 2026 11:48
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
    Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>

Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
…onfig. Otherwise, it would not be passed to propagate_attrs.

Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
…t length could be shorter or longer than the max depth of recursion group. added tests.

Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
…smatches

Replace silent auto-correction (extend/trim with warning) with ValueError
when reweight_temperature list length does not match input_cfg nesting depth.
Only scalar (broadcast) and exact-length list are now accepted.

Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
…reweight temperature.

Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
…ction

Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
@pzelasko pzelasko enabled auto-merge (squash) March 11, 2026 18:54
@pzelasko pzelasko merged commit e137894 into NVIDIA-NeMo:main Mar 12, 2026
133 checks passed
nune-tadevosyan pushed a commit to nune-tadevosyan/NeMo that referenced this pull request Mar 13, 2026
…the fly. (NVIDIA-NeMo#15200)

* [lhotse] reweighting datasets with temperature on the fly based on weights predefined in train_ds.dataset.input_cfg YAML configs. This feature saves the effort of flattening the dataset distribution every time new datasets are added.

Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com>

* Update tests/collections/common/test_lhotse_temperature_reweighting.py

Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>

* Apply suggestion from @XuesongYang

    Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com>

Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>

* Apply suggestion from @XuesongYang

    Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com>

Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>

* added safeguard and updated unit tests.

Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>

* Apply suggestion from @XuesongYang

    Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>

Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>

* added tests to ensure dataloader applied reweighting as expected.

Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>

* bugfix: added reweighting_temperature into default LhotseDataLoadingConfig. Otherwise, it would not be passed to propagate_attrs.

Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>

* added support of flexible formats of temperatures (scalar, list). List length could be shorter or longer than the max depth of recursion group. added tests.

Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>

* added documentation

Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>

* Update nemo/collections/common/data/lhotse/cutset.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>

* Enforce strict reweight_temperature validation: reject list length mismatches

Replace silent auto-correction (extend/trim with warning) with ValueError
when reweight_temperature list length does not match input_cfg nesting depth.
Only scalar (broadcast) and exact-length list are now accepted.

Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>

* fixed type hint for config_list in parse_and_combine_datasets().

Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>

* added deterministic_rng fixture to 5 stochastic integration tests of reweight temperature.

Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>

* Refactor: extract reweight_temperature validation into standalone function

Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>

---------

Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com>
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Co-authored-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
@XuesongYang XuesongYang deleted the xueyang/pr-reweight-otf branch March 24, 2026 03:41
7 participants