feat: per-dataset validation loss reporting#169
Merged
Conversation
Validation previously used a single hierarchically-sampled DataLoader over the mixture, yielding only aggregate metrics. Now each underlying dataset gets its own sequential DataLoader and is validated fully, producing per-dataset metrics (Loss, MSE/CE/L1 Loss, Accuracy) alongside the existing aggregate Validation/* logs. Adds WeightedDatasetMixture.get_per_dataset_dataloaders() and uses each DatasetConfig's repo_id / vqa identifier for the logged dataset name so wandb groupings are readable. https://claude.ai/code/session_019zJCa5a31uULbXPrFjtiTr
- Pass step as a default arg to _make_val_tracker so ruff's B023 loop-variable check is satisfied; behavior is unchanged because the factory is only called within the same loop iteration. - Apply ruff-format to the two touched files. https://claude.ai/code/session_019zJCa5a31uULbXPrFjtiTr
akshay18iitg
requested changes
Apr 24, 2026
akshay18iitg
approved these changes
Apr 24, 2026
This was referenced Apr 25, 2026
shuheng-liu
added a commit
that referenced
this pull request
Apr 27, 2026
Conflict resolutions: - src/opentau/scripts/train.py: kept HEAD's additions — ``_sync_deepspeed_gradient_accumulation_steps`` (#175) and the ``gradient_accumulation_steps`` entry in ``accelerator_kwargs``. main doesn't have this function, so taking HEAD is purely additive over the auto-merged surrounding edits from #176 (DDP throughput perf) and #169 (per-dataset val loss). - src/opentau/scripts/profile_step.py (add/add): kept HEAD's superset. Both branches added this file; HEAD additionally has the ``ATTENTION_IMPL`` / ``GRAD_CHECKPOINT`` env-var overrides and the ``MasterWeightOptimizer`` wrapping introduced in #187. main has no content beyond what HEAD already includes. - tests/scripts/test_train.py (add/add): kept HEAD's superset. HEAD has the imports for ``logging``/``SimpleNamespace``/``accelerate`` plus the four ``test_*_deepspeed_*`` tests for ``_sync_deepspeed_gradient_accumulation_steps`` (#175). main's version only had the ``TestFindUnusedParamsFromEnv`` class which HEAD also has. CPU tests pass: ``pytest tests/policies/test_pi05_mem.py tests/scripts/test_train.py -m 'not gpu'`` → 56 passed.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What this does
Implements per-dataset validation loss reporting so training runs with multiple datasets can track each dataset's validation metrics independently rather than only the aggregate.
WeightedDatasetMixture.get_per_dataset_dataloaders(), which returns one sequentialDataLoaderper underlying dataset (empty datasets skipped).train.py, the validation step now iterates each per-dataset loader in sequence, updating a per-datasetMetricsTrackerand an aggregate tracker at the same time. Each step emits bothValidation/<name>/Loss(plus MSE/CE/L1/Accuracy) and the existingValidation/Lossaggregate scalars.DatasetConfig'srepo_id/vqaidentifier as the logged dataset name (with#<idx>disambiguation when duplicated), falling back to the oldClassName_ischeme when the config list cannot be lined up with the datasets (e.g. unit tests that construct a mixture directly).This changes the validation semantics from "sample
len(concat)times via hierarchical weighted sampling with replacement" to "fully iterate each val subset once", which is the conventional choice for validation and is what makes per-dataset reporting meaningful.How it was tested
python -c "import ast; ast.parse(...)"on both modified files.tests/datasets/test_dataset_mixture.pyassertions onlen(mixture.dataset_names)still hold (the public shape is unchanged).How to checkout & try? (for the reviewer)
Then launch any training config with
val_freq > 0and multiple datasets indataset_mixture.datasets, and confirm wandb shows both the aggregateValidation/*scalars and oneValidation/<repo_id>/*group per dataset.Checklist
https://claude.ai/code/session_019zJCa5a31uULbXPrFjtiTr