fix: mixture-weighted validation aggregate + per-name dup suffix #189
Merged
Restores meaningful overall Validation/Loss after PR #169 by computing the aggregate as a mixture-weighted mean of per-dataset metrics rather than implicitly weighting by val-subset size. Also fixes disambiguation of duplicate repo_ids to use per-name sequential suffix (['A','B','A'] -> ['A#0','B','A#1'] instead of ['A#0','B','A#2']).
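The per-name suffix rule can be sketched as follows. This is a minimal illustration rather than the repository's `_make_dataset_names`: the function name `disambiguate_names` is hypothetical, and the sketch assumes that only names occurring more than once receive a suffix, as the `['A','B','A']` example suggests.

```python
from collections import Counter


def disambiguate_names(names):
    """Append '#k' to repeated names, counting k separately for each name.

    Hypothetical helper for illustration; unique names are left untouched.
    """
    totals = Counter(names)  # how often each name occurs overall
    seen = Counter()         # how many copies of each name were emitted so far
    result = []
    for name in names:
        if totals[name] > 1:
            result.append(f"{name}#{seen[name]}")
            seen[name] += 1
        else:
            result.append(name)
    return result


print(disambiguate_names(["A", "B", "A"]))  # ['A#0', 'B', 'A#1']
```

Using a counter keyed by name, rather than the position in the list, is what keeps the suffixes sequential per name: the second `A` gets `#1` regardless of how many other datasets sit between the two.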
akshay18iitg approved these changes on Apr 27, 2026
shuheng-liu added a commit that referenced this pull request on Apr 27, 2026
Conflict resolutions:
- src/opentau/scripts/train.py: auto-merged.
- tests/scripts/test_train.py: union resolution. HEAD added the ``test_*_deepspeed_*`` tests (#175) and main added the ``TestMixtureWeightedAggregate`` class (#189). Both exercise different helpers so they coexist; the import block was combined.
What this does
Follow-up to #169. Two small fixes around per-dataset validation reporting:
- `Validation/Loss`. PR #169 (feat: per-dataset validation loss reporting) kept the aggregate scalar but, because it accumulated each batch into a parallel `agg_tracker` with `n=1`, the meaning silently changed: large val subsets dominate. This drops the parallel tracker and computes the aggregate from the per-dataset trackers using `WeightedDatasetMixture.dataset_weights`, so the overall `Validation/{Loss,MSE Loss,CE Loss,L1 Loss,Accuracy}` reflects the training mixture again. Per-dataset scalars are unchanged.
- `repo_id` collides. When two `DatasetConfig`s share the same `repo_id` (e.g. same HF repo, different `episodes` slices), `_make_dataset_names` previously used the global index, so `['A','B','A']` became `['A#0','B','A#2']`. Now it uses a per-name counter: `['A#0','B','A#1']`.

Helper extracted to top-level `_mixture_weighted_aggregate(per_dataset_trackers, name_to_weight, metric_keys)` so it can be unit-tested. Renormalizes weights over the keys actually present in the trackers (matches `get_per_dataset_dataloaders` skipping empty datasets), and returns `0.0` for every metric when there is no signal (empty trackers / all-zero weights).

Out of scope: `_get_worker_name_mapping_overrides` still keys off bare `repo_id` and would collide between two datasets that share `repo_id` but have different `data_features_name_mapping`. That is a feature-mapping bug rather than a reporting one and is left for a separate PR.

How it was tested
- `tests/scripts/test_train.py::TestMixtureWeightedAggregate` covers equal weights / unequal weights / renormalizing over present keys only / all five metrics / empty-trackers / all-zero-weights edge cases.
- `tests/datasets/test_dataset_mixture.py::TestMakeDatasetNames` pins the disambiguation behaviour: all-unique, repeated `repo_id`, all-identical, mixed `vqa`/`repo_id`, missing `repo_id`+`vqa` (class-name fallback), mismatched cfg length, and absent `dataset_mixture` attribute.
- `pre-commit run` clean on all touched files.
- `['A#0','B','A#1']` ordering.

How to checkout & try? (for the reviewer)
Then launch any training config with `val_freq > 0` and at least two datasets with different `dataset_weights` in `dataset_mixture.datasets`. Confirm in wandb that `Validation/Loss` now matches `sum(w_i * Validation/<name_i>/Loss) / sum(w_i)` instead of being dominated by the largest val subset. To exercise the duplicate-name path, point two `DatasetConfig` entries at the same `repo_id` (with different `episodes`) and confirm the per-dataset W&B groups appear as `<repo_id>#0` and `<repo_id>#1`.

Checklist
Generated by Claude Code
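The aggregate described under "What this does", and the `sum(w_i * ...) / sum(w_i)` check from the reviewer instructions, can be sketched as below. This is a simplified stand-in, not the code in `src/opentau/scripts/train.py`: the real helper takes per-dataset tracker objects, whereas this sketch assumes each dataset's metrics are already a plain dict of averaged floats. The name `mixture_weighted_aggregate_sketch` and the dict layout are hypothetical.

```python
def mixture_weighted_aggregate_sketch(per_dataset_means, name_to_weight, metric_keys):
    """Mixture-weighted mean of per-dataset metrics.

    per_dataset_means: {dataset_name: {metric_key: mean_value}} (assumed layout).
    name_to_weight:    {dataset_name: mixture weight}.
    Weights are renormalized over the datasets actually present, and every
    metric is 0.0 when there is nothing to average.
    """
    # Keep only weights of datasets that produced validation metrics.
    present = {name: name_to_weight.get(name, 0.0) for name in per_dataset_means}
    total = sum(present.values())
    if total <= 0.0:
        # Empty trackers or all-zero weights: no signal.
        return {key: 0.0 for key in metric_keys}
    return {
        key: sum(w * per_dataset_means[name][key] for name, w in present.items()) / total
        for key in metric_keys
    }


means = {"a": {"loss": 1.0}, "b": {"loss": 3.0}}
weights = {"a": 3.0, "b": 1.0, "c": 4.0}  # "c" produced no validation data
print(mixture_weighted_aggregate_sketch(means, weights, ["loss"]))  # {'loss': 1.5}
```

The size of each validation subset never enters the computation, so a large subset with a small mixture weight cannot dominate the aggregate.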