
Aggregate val loss per (dataset, control_mode) outside forward + parallelize validation #373

@shuheng-liu


Background

Every sample in our standard data format now carries dataset provenance. _TaggedDataset injects dataset_index (the norm-head row, dropout-immune) and dataset_repo_id (the deduplicated mixture-level name) into each item (src/opentau/datasets/dataset_mixture.py:83-135), and _emit_optional_keys (src/opentau/datasets/lerobot_dataset.py) emits the optional keys robot_type and control_mode. We can therefore compute per-(dataset, control mode) loss breakdowns from the batch itself, instead of relying on how the validation loop happens to be batched.
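To make the provenance claim concrete, here is a minimal sketch of what "provenance survives mixing" means. It is not the real _TaggedDataset. TaggedDatasetSketch, collate_to_lists, and the repo ids "example/arm_a" and "example/arm_b" are hypothetical. The sketch uses plain Python lists instead of a torch Dataset and collate function. The point is that after samples from different datasets are collated into one batch, each position still tells you its dataset_index, and control_mode where it is present.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Sequence


@dataclass
class TaggedDatasetSketch:
    """Hypothetical stand-in for _TaggedDataset: wraps dict samples and stamps
    each one with its provenance keys."""

    base: Sequence[Dict[str, Any]]
    dataset_index: int    # norm-head row for this dataset
    dataset_repo_id: str  # deduplicated, mixture-level dataset name

    def __len__(self) -> int:
        return len(self.base)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        item = dict(self.base[idx])  # shallow copy: never mutate the wrapped sample
        item["dataset_index"] = self.dataset_index
        item["dataset_repo_id"] = self.dataset_repo_id
        return item


def collate_to_lists(items: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Toy collate: each key maps to a per-sample list. A sample that lacks an
    optional key (e.g. control_mode) contributes None for it."""
    keys = sorted({k for item in items for k in item})
    return {k: [item.get(k) for item in items] for k in keys}


# Two made-up single-frame datasets mixed into one batch.
ds_a = TaggedDatasetSketch([{"action": [0.1], "control_mode": "joint"}], 0, "example/arm_a")
ds_b = TaggedDatasetSketch([{"action": [0.2]}], 1, "example/arm_b")  # no control_mode
batch = collate_to_lists([ds_a[0], ds_b[0]])
print(batch["dataset_index"], batch["control_mode"])
```

The toy collate keeps everything as Python lists. It does not show how the real collate handles these keys or missing optional keys.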

Two related shortcomings in the current validation path (src/opentau/scripts/train.py:844-950) motivate this issue.

Problem 1: per-dataset breakdown is an artifact of dataloader structure, not provenance

Today the only per-dataset granularity we get comes from iterating one separate dataloader per underlying dataset:

for ds_name, ds_loader in per_dataset_val_dataloaders.items():
    ds_tracker = per_dataset_trackers[ds_name]
    for batch in ds_loader:
        losses = policy.forward(batch)
        ...

(src/opentau/scripts/train.py:868-918, dataloaders built by WeightedDatasetMixture.get_per_dataset_dataloaders() at src/opentau/datasets/dataset_mixture.py:944-977.)

Consequences:

  • The breakdown granularity is per source dataset, never per (dataset, control mode) — even though control_mode ({"joint", "ee", "mixed"}) is exactly the axis we want to slice MSE/CE along, and it is already present in the batch.
  • The loss itself (policy.forward) returns a single scalar MSE/CE reduced over the whole batch (modeling_pi05.py:740, modeling_pi06.py:623, pi07/low_level/modeling_pi07_low_level.py:931). There is no per-sample / per-group disaggregation, so we cannot break a mixed batch down by provenance at all — the only reason we get any per-dataset numbers today is that each dataloader is homogeneous by construction.
  • Crucially, this disaggregation does not belong inside forward. The model shouldn't know about dataset taxonomy or own metric bookkeeping. Grouping by (dataset_index, control_mode) should happen in the training/eval loop, using the provenance keys already in the batch.
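The sketch below illustrates why a batch-reduced scalar is not enough, and what a per-sample loss enables. It uses plain Python and invented numbers. In PyTorch, the analogous unreduced loss would come from torch.nn.functional.mse_loss(pred, target, reduction="none") followed by a mean over the non-batch dimensions. CE would follow the same pattern with cross_entropy(..., reduction="none"). The function names are hypothetical. If a policy masks padded action dimensions, the per-sample mean would need to respect that mask, and this sketch does not model it.

```python
from typing import Dict, List, Sequence


def per_sample_mse(pred: Sequence[Sequence[float]], target: Sequence[Sequence[float]]) -> List[float]:
    """Unreduced MSE: one value per sample, averaged over that sample's action dims."""
    return [
        sum((p - t) ** 2 for p, t in zip(p_row, t_row)) / len(p_row)
        for p_row, t_row in zip(pred, target)
    ]


def scalar_mse(pred, target) -> float:
    """Batch-reduced MSE, i.e. what a forward that returns one scalar exposes.
    Equals the elementwise mean only because every row has the same length here."""
    losses = per_sample_mse(pred, target)
    return sum(losses) / len(losses)


def mean_by_dataset(losses: Sequence[float], dataset_index: Sequence[int]) -> Dict[int, float]:
    """Disaggregate per-sample losses using the batch's provenance key."""
    sums: Dict[int, float] = {}
    counts: Dict[int, int] = {}
    for loss, ds in zip(losses, dataset_index):
        sums[ds] = sums.get(ds, 0.0) + loss
        counts[ds] = counts.get(ds, 0) + 1
    return {ds: sums[ds] / counts[ds] for ds in sums}


pred = [[0.0, 0.0], [1.0, 1.0], [0.0, 2.0]]
target = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
dataset_index = [0, 0, 1]  # mixed batch: two samples from dataset 0, one from dataset 1

print(scalar_mse(pred, target))                                      # 1.0
print(mean_by_dataset(per_sample_mse(pred, target), dataset_index))  # {0: 0.5, 1: 2.0}
```

The scalar 1.0 alone cannot be split between datasets 0 and 1. The per-sample values can, and no forward needs to know what dataset_index means.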

Problem 2: validation badly under-parallelized across ranks

get_per_dataset_dataloaders() creates one DataLoader per dataset. accelerator.prepare() shards each one across ranks, and the loaders are then iterated one after another. With a heterogeneous mixture of many small validation subsets (some with only a single frame after val_split_ratio), this is pathological:

  • A dataset with fewer frames than world_size leaves most ranks with an empty/padded shard — they do a no-op forward (or wait) while one rank does the real work.
  • Because we loop over datasets sequentially (for ds_name, ds_loader in ...), the idle time accumulates. Every tiny dataset becomes its own under-filled collective round, and accelerator.gather_for_metrics plus the trailing accelerator.wait_for_everyone() (train.py:950) force all ranks to rendezvous on each one.
  • drop_last=False (dataset_mixture.py:973) is correct for not discarding val data, but combined with tiny datasets it guarantees ragged, under-utilized batches.

Net effect: validation wall-clock scales with the number of datasets rather than the total val frames, and most ranks sit idle.
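The "scales with the number of datasets" claim can be made concrete with a deliberately simplified cost model. It assumes that a forward step costs roughly the same whether its per-rank batch is full or not. It also assumes that short shards are padded so every rank takes the same number of steps. Rendezvous overhead is ignored, which would only make the sequential path look worse. The helper names and the 64-dataset mixture are made up for illustration. They are not measurements, and the acceptance criteria still call for a real wall-clock comparison.

```python
import math
from typing import Sequence


def sequential_rounds(frames: Sequence[int], world_size: int, batch_size: int) -> int:
    """Forward steps per rank when each dataset has its own sharded loader and
    loaders run one after another (the current path). Padding fills short shards."""
    per_round = world_size * batch_size
    return sum(math.ceil(n / per_round) for n in frames if n > 0)


def combined_rounds(frames: Sequence[int], world_size: int, batch_size: int) -> int:
    """Forward steps per rank for one pass over the concatenated val mixture."""
    return math.ceil(sum(frames) / (world_size * batch_size))


def utilization(frames: Sequence[int], world_size: int, batch_size: int, rounds: int) -> float:
    """Fraction of per-rank sample slots that hold real (non-padded) frames."""
    return sum(frames) / (rounds * world_size * batch_size)


frames = [1] * 64  # made-up mixture: 64 val subsets of one frame each
world_size, batch_size = 8, 4

seq = sequential_rounds(frames, world_size, batch_size)
comb = combined_rounds(frames, world_size, batch_size)
print(seq, utilization(frames, world_size, batch_size, seq))    # 64 0.03125
print(comb, utilization(frames, world_size, batch_size, comb))  # 2 1.0
```

Under these assumptions the sequential path needs one round per dataset. The combined pass needs rounds proportional to the total number of val frames.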

Proposed direction (for discussion, not prescriptive)

  1. Move per-group aggregation out of forward and into the loop, keyed on provenance. Have policies optionally return unreduced (per-sample) MSE/CE — or keep returning scalars but additionally expose a per-sample loss — and let the validation loop bucket them by (dataset_index, control_mode) from the batch, using dataset_index as the dropout-immune key (control_mode can be masked at train time; pair it with dataset_index or recover via compute_norm_key). Log Validation/{dataset}/{control_mode}/{MSE,CE} plus the existing mixture-weighted aggregate (_mixture_weighted_aggregate, train.py:932-946).
  2. Parallelize validation across all ranks regardless of per-dataset size. Instead of one sequential under-filled dataloader per dataset, run a single validation pass over the combined val mixture (so every rank's shard is full), and rely on the per-sample provenance keys for grouping rather than on homogeneous dataloaders. This decouples the breakdown from the batching and keeps every rank busy. (Open question: whether to keep get_per_dataset_dataloaders for any callers, or replace it.)
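Below is one possible shape for the loop-side bookkeeping in item 1 that also works with the single combined pass in item 2. It is a minimal sketch, not a proposal for the exact API. GroupedLossAccumulator, default_control_mode, and the dataset names are hypothetical. default_control_mode stands in for recovering a masked control_mode (for example via compute_norm_key, whose signature is not shown here). The accumulator stores sums and counts rather than means, so partial results from steps or ranks merge exactly.

One caveat applies if Accelerate's default even_batches padding is in effect, which repeats samples to even out the last batch. Summing locally and then reducing would count those duplicates. It may be safer to gather per-sample losses and provenance with accelerator.gather_for_metrics, which drops them, and to update the accumulator on the gathered values.

```python
from collections import defaultdict
from typing import Dict, Mapping, Optional, Sequence, Tuple

GroupKey = Tuple[int, str]  # (dataset_index, control_mode)


class GroupedLossAccumulator:
    """Hypothetical per-(dataset_index, control_mode) accumulator for the val loop.

    Keeps running sums and counts rather than means, so partial results from
    different ranks (or steps) can be merged exactly before averaging.
    """

    def __init__(self, default_control_mode: Mapping[int, str]):
        # Hypothetical per-dataset fallback for samples whose control_mode is masked.
        self.default_control_mode = dict(default_control_mode)
        self.sums: Dict[GroupKey, float] = defaultdict(float)
        self.counts: Dict[GroupKey, int] = defaultdict(int)

    def update(
        self,
        losses: Sequence[float],
        dataset_index: Sequence[int],
        control_mode: Sequence[Optional[str]],
    ) -> None:
        """Bucket per-sample losses by provenance; dataset_index is always present."""
        for loss, ds, mode in zip(losses, dataset_index, control_mode):
            key = (ds, mode if mode is not None else self.default_control_mode[ds])
            self.sums[key] += loss
            self.counts[key] += 1

    def merge(self, other: "GroupedLossAccumulator") -> None:
        """Fold in another accumulator's sums and counts (e.g. another rank's)."""
        for key, total in other.sums.items():
            self.sums[key] += total
            self.counts[key] += other.counts[key]

    def log_dict(self, dataset_names: Mapping[int, str], metric: str) -> Dict[str, float]:
        """Emit Validation/{dataset}/{control_mode}/{metric} entries in a stable order."""
        return {
            f"Validation/{dataset_names[ds]}/{mode}/{metric}": self.sums[(ds, mode)] / self.counts[(ds, mode)]
            for ds, mode in sorted(self.sums)
        }


fallback = {0: "joint", 1: "ee"}
rank0 = GroupedLossAccumulator(fallback)
rank0.update([0.5, 1.5], dataset_index=[0, 1], control_mode=["joint", None])  # masked mode
rank1 = GroupedLossAccumulator(fallback)
rank1.update([1.0, 2.5], dataset_index=[0, 1], control_mode=["joint", "ee"])
rank0.merge(rank1)
logs = rank0.log_dict({0: "arm_a", 1: "arm_b"}, "MSE")
print(logs)  # {'Validation/arm_a/joint/MSE': 0.75, 'Validation/arm_b/ee/MSE': 2.0}
```

Keying on dataset_index means the masked sample still lands in a well-defined group. Sorting the keys before logging keeps the metric order stable from step to step. The existing mixture-weighted aggregate would be logged alongside these entries, unchanged.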

Acceptance

  • Validation logs MSE/CE broken down by (dataset, control_mode), computed in the loop from batch provenance — not inside any forward.
  • The grouping is keyed on dataset_index (dropout-immune) so a mixed batch can be disaggregated correctly; no reliance on one-dataset-per-dataloader homogeneity.
  • Validation keeps all ranks busy: a mixture of many 1-frame datasets no longer leaves most ranks idle. Quantify the wall-clock improvement on a representative multi-dataset config.
  • Per-step val loss remains deterministic under a fixed seed (per CLAUDE.md hard rule #3).

References

  • Validation loop: src/opentau/scripts/train.py:844-950
  • Per-dataset dataloaders: src/opentau/datasets/dataset_mixture.py:944-977
  • Provenance tagging (dataset_index, dataset_repo_id): src/opentau/datasets/dataset_mixture.py:83-135
  • Optional keys (robot_type, control_mode): src/opentau/datasets/lerobot_dataset.py (_emit_optional_keys)
  • Loss returned from forward (scalar MSE/CE): modeling_pi05.py:740, modeling_pi06.py:623, pi07/low_level/modeling_pi07_low_level.py:931
  • Metrics plumbing: MetricsTracker / AverageMeter in src/opentau/utils/logging_utils.py; mixture aggregate _mixture_weighted_aggregate in src/opentau/scripts/train.py
