feat: per-dataset validation loss reporting #169

Merged: shuheng-liu merged 2 commits into main from claude/per-dataset-validation-loss-SviqL on Apr 24, 2026

@shuheng-liu shuheng-liu commented Apr 22, 2026

What this does

Implements per-dataset validation loss reporting so training runs with multiple datasets can track each dataset's validation metrics independently rather than only the aggregate.

  • Adds WeightedDatasetMixture.get_per_dataset_dataloaders(), which returns one sequential DataLoader per underlying dataset (empty datasets skipped).
  • In train.py, the validation step now iterates each per-dataset loader in sequence, updating a per-dataset MetricsTracker and an aggregate tracker at the same time. Each step emits both Validation/<name>/Loss (plus MSE/CE/L1/Accuracy) and the existing Validation/Loss aggregate scalars.
  • Uses each DatasetConfig's repo_id / vqa identifier as the logged dataset name (with #<idx> disambiguation when duplicated), falling back to the old ClassName_i scheme when the config list cannot be lined up with the datasets (e.g. unit tests that construct a mixture directly).
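For reviewers, the naming rule in the last bullet can be sketched roughly as below. This is an illustration only, not the code in this PR: `DatasetConfigSketch` and `derive_dataset_names` are hypothetical names, and it assumes `#<idx>` refers to the dataset's position in the mixture and that `repo_id` takes precedence over `vqa`.

```python
from collections import Counter
from dataclasses import dataclass
from typing import List, Optional, Sequence


@dataclass
class DatasetConfigSketch:
    """Hypothetical stand-in for DatasetConfig with only the identifier fields."""

    repo_id: Optional[str] = None
    vqa: Optional[str] = None


def derive_dataset_names(
    datasets: Sequence[object],
    configs: Optional[Sequence[DatasetConfigSketch]],
) -> List[str]:
    """Return one logged name per dataset (assumed behavior, see lead-in)."""
    # Fallback: configs are missing or cannot be lined up with the datasets,
    # so use the old ClassName_i scheme.
    if configs is None or len(configs) != len(datasets):
        return [f"{type(ds).__name__}_{i}" for i, ds in enumerate(datasets)]

    # Prefer repo_id, then vqa; fall back per entry if neither is set.
    base = [
        cfg.repo_id or cfg.vqa or f"{type(ds).__name__}_{i}"
        for i, (ds, cfg) in enumerate(zip(datasets, configs))
    ]

    # Append "#<idx>" only to names that occur more than once.
    counts = Counter(base)
    return [
        f"{name}#{i}" if counts[name] > 1 else name
        for i, name in enumerate(base)
    ]
```

Under these assumptions, two entries sharing a `repo_id` get distinct suffixes, while unique names are logged unchanged.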

This changes the validation semantics from "sample len(concat) times via hierarchical weighted sampling with replacement" to "fully iterate each val subset once", which is the conventional choice for validation and is what makes per-dataset reporting meaningful.
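A minimal, framework-free sketch of the new validation flow is shown below. It is not the code in train.py: `RunningMean` is a hypothetical stand-in for MetricsTracker, each "loader" is modeled as an iterable of batches of per-sample losses, and the aggregate is assumed to be sample-weighted (the real aggregation may differ).

```python
from typing import Dict, Iterable, Sequence


class RunningMean:
    """Hypothetical stand-in for MetricsTracker: a sample-weighted mean."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, value: float, n: int = 1) -> None:
        self.total += value * n
        self.count += n

    @property
    def mean(self) -> float:
        return self.total / self.count if self.count else float("nan")


def validate_per_dataset(
    loaders: Dict[str, Iterable[Sequence[float]]],
) -> Dict[str, float]:
    """Fully iterate each validation subset once and collect scalars."""
    aggregate = RunningMean()
    scalars: Dict[str, float] = {}
    for name, loader in loaders.items():
        tracker = RunningMean()  # fresh tracker for each dataset
        for batch in loader:  # batch: per-sample losses (assumed shape)
            if not batch:
                continue
            batch_loss = sum(batch) / len(batch)
            # Update the per-dataset and aggregate trackers together.
            tracker.update(batch_loss, n=len(batch))
            aggregate.update(batch_loss, n=len(batch))
        scalars[f"Validation/{name}/Loss"] = tracker.mean
    scalars["Validation/Loss"] = aggregate.mean
    return scalars
```

Because every subset is iterated exactly once, each per-dataset mean is deterministic for a fixed model, which sampling with replacement would not guarantee.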

How it was tested

  • Verified the name-derivation logic in isolation across unique, duplicate, VQA, mismatched-length, and missing-config cases.
  • Ran python -c "import ast; ast.parse(...)" on both modified files to confirm they parse.
  • Existing tests/datasets/test_dataset_mixture.py assertions on len(mixture.dataset_names) still hold (the public shape is unchanged).
  • Full GPU validation in a real run is still pending (no GPU available in this environment).

How to check out & try? (for the reviewer)

pytest -sx tests/datasets/test_dataset_mixture.py

Then launch any training config with val_freq > 0 and multiple datasets in dataset_mixture.datasets, and confirm wandb shows both the aggregate Validation/* scalars and one Validation/<repo_id>/* group per dataset.

Checklist

  • I have added Google-style docstrings to important functions and ensured function parameters are typed.
  • My PR includes policy-related changes.
    • If the above is checked: I have run the GPU pytests (pytest -m "gpu") and regression tests.

https://claude.ai/code/session_019zJCa5a31uULbXPrFjtiTr

claude added 2 commits April 22, 2026 01:17
Validation previously used a single hierarchically-sampled DataLoader over
the mixture, yielding only aggregate metrics. Now each underlying dataset
gets its own sequential DataLoader and is validated fully, producing
per-dataset metrics (Loss, MSE/CE/L1 Loss, Accuracy) alongside the existing
aggregate Validation/* logs.

Adds WeightedDatasetMixture.get_per_dataset_dataloaders() and uses each
DatasetConfig's repo_id / vqa identifier for the logged dataset name so
wandb groupings are readable.

https://claude.ai/code/session_019zJCa5a31uULbXPrFjtiTr
- Pass step as a default arg to _make_val_tracker so ruff's B023 loop-variable
  check is satisfied; behavior is unchanged because the factory is only called
  within the same loop iteration.
- Apply ruff-format to the two touched files.
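
For context on the B023 fix, here is a generic illustration of the pattern, not the actual `_make_val_tracker` code. A closure defined in a loop looks up the loop variable when it is called, whereas a default argument captures the value when the function is defined. The function names below are hypothetical.

```python
from typing import Callable, List


def late_binding() -> List[int]:
    """Closures read `step` at call time, so all see the final value."""
    factories: List[Callable[[], int]] = []
    for step in range(3):
        factories.append(lambda: step)  # the pattern ruff B023 flags
    return [f() for f in factories]  # [2, 2, 2]


def default_arg_binding() -> List[int]:
    """A default argument freezes `step` at definition time."""
    factories: List[Callable[..., int]] = []
    for step in range(3):
        factories.append(lambda step=step: step)  # satisfies B023
    return [f() for f in factories]  # [0, 1, 2]
```

When the function is only called within the same iteration that defines it, both forms return the same value, which is why the commit describes the change as behavior-preserving.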

https://claude.ai/code/session_019zJCa5a31uULbXPrFjtiTr
@shuheng-liu shuheng-liu marked this pull request as ready for review April 22, 2026 01:44
@shuheng-liu shuheng-liu self-assigned this Apr 22, 2026
@shuheng-liu shuheng-liu added the feature New feature or request label Apr 22, 2026
Comment thread src/opentau/scripts/train.py
@shuheng-liu shuheng-liu merged commit 2e22da9 into main Apr 24, 2026
5 checks passed
@shuheng-liu shuheng-liu deleted the claude/per-dataset-validation-loss-SviqL branch April 24, 2026 21:31
shuheng-liu added a commit that referenced this pull request Apr 27, 2026
Conflict resolutions:

- src/opentau/scripts/train.py: kept HEAD's additions —
  ``_sync_deepspeed_gradient_accumulation_steps`` (#175) and the
  ``gradient_accumulation_steps`` entry in ``accelerator_kwargs``. main
  doesn't have this function, so taking HEAD is purely additive over the
  auto-merged surrounding edits from #176 (DDP throughput perf) and #169
  (per-dataset val loss).

- src/opentau/scripts/profile_step.py (add/add): kept HEAD's superset.
  Both branches added this file; HEAD additionally has the
  ``ATTENTION_IMPL`` / ``GRAD_CHECKPOINT`` env-var overrides and the
  ``MasterWeightOptimizer`` wrapping introduced in #187. main has no
  content beyond what HEAD already includes.

- tests/scripts/test_train.py (add/add): kept HEAD's superset. HEAD has
  the imports for ``logging``/``SimpleNamespace``/``accelerate`` plus the
  four ``test_*_deepspeed_*`` tests for
  ``_sync_deepspeed_gradient_accumulation_steps`` (#175). main's version
  only had the ``TestFindUnusedParamsFromEnv`` class which HEAD also has.

CPU tests pass: ``pytest tests/policies/test_pi05_mem.py tests/scripts/test_train.py -m 'not gpu'`` → 56 passed.