
fit_fast_tokenizer: dedup duplicate repo_id in fallback heads to match training #349

@shuheng-liu


Context

Deferred follow-up from #348 (per-(robot_type, control_mode) normalization in fit_fast_tokenizer). Tracked here so the PR can land without bundling unrelated plumbing.

Problem

When a dataset entry in the mixture has an empty (robot_type, control_mode) after overrides, compute_norm_key falls back to using cfg.repo_id as the head's key. If two such entries share the same repo_id, fit time and training time assign them to heads differently:

  • Fit time (current _aggregate_stats_per_head in src/opentau/scripts/fit_fast_tokenizer.py): both rows compute the same fallback key (the raw repo_id), so they are POOLED into one shared head.
  • Training time (DatasetMixtureMetadata._build_norm_heads in src/opentau/datasets/dataset_mixture.py:340-348): the fallback uses the per-dataset name produced by _make_dataset_names, which suffixes repeated names (X becomes X#0, X#1, ...), so the fallback keys differ and the rows stay as separate singleton heads.

Net: the BPE corpus the tokenizer is fit on doesn't match what the policy normalizes at runtime for that subset of rows.
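To make the divergence concrete, here is a minimal, self-contained model of the two code paths. It is not the opentau code: fallback_norm_key, dedup_names, and group_heads are hypothetical helpers. fallback_norm_key assumes compute_norm_key returns a "robot_type/control_mode" key when both are set and the passed name otherwise, dedup_names follows the #N suffixing described for _make_dataset_names, and the "ur5"/"joint" values are illustrative only.

```python
from collections import Counter, defaultdict


def fallback_norm_key(robot_type: str, control_mode: str, name: str) -> str:
    # Hypothetical stand-in for compute_norm_key: use (robot_type, control_mode)
    # when both are set, otherwise fall back to the supplied name.
    if robot_type and control_mode:
        return f"{robot_type}/{control_mode}"
    return name


def dedup_names(raw_names: list[str]) -> list[str]:
    # "#N" suffixing: every occurrence of a repeated name gets an index,
    # names that appear once are left untouched.
    counts = Counter(raw_names)
    seen: dict[str, int] = {}
    out: list[str] = []
    for name in raw_names:
        if counts[name] > 1:
            i = seen.get(name, 0)
            out.append(f"{name}#{i}")
            seen[name] = i + 1
        else:
            out.append(name)
    return out


def group_heads(keys: list[str]) -> dict[str, list[int]]:
    # Map each normalization key to the row indices that share its head.
    heads: dict[str, list[int]] = defaultdict(list)
    for row, key in enumerate(keys):
        heads[key].append(row)
    return dict(heads)


# Rows are (repo_id, robot_type, control_mode); the two "X" rows hit the fallback.
rows = [("X", "", ""), ("X", "", ""), ("Y", "ur5", "joint")]
repo_ids = [repo for repo, _, _ in rows]

# Fit time today: fallback is the raw repo_id.
fit_keys = [fallback_norm_key(rt, cm, repo) for repo, rt, cm in rows]
# Training: fallback is the deduplicated per-dataset name.
train_keys = [fallback_norm_key(rt, cm, name)
              for (_, rt, cm), name in zip(rows, dedup_names(repo_ids))]

print(group_heads(fit_keys))    # {'X': [0, 1], 'ur5/joint': [2]}
print(group_heads(train_keys))  # {'X#0': [0], 'X#1': [1], 'ur5/joint': [2]}
```

The two "X" rows share one head at fit time but get one head each in training, which is exactly the mismatch between the BPE corpus and runtime normalization.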

Mitigation already in place (PR #348, commit c755cb8)

A Counter-based warning at fit time flags duplicate repo_id values that fell through to the fallback key, so an operator hitting this configuration sees an explicit signal:

N fallback-keyed repo_id values appear in the mixture more than once. Fit-time POOLS these under one shared head per repo_id, but training (via _make_dataset_names's #N dedup) keeps them as separate singleton heads...

This is enough to keep the misalignment from shipping silently, but doesn't actually fix the underlying divergence.
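A check of this kind can be sketched as follows, assuming the fit script knows, for each row, its repo_id and whether that row fell through to the fallback key. The function name warn_on_pooled_fallbacks is hypothetical, and the logged message is paraphrased rather than the exact c755cb8 text.

```python
import logging
from collections import Counter

logger = logging.getLogger(__name__)


def warn_on_pooled_fallbacks(repo_ids: list[str], used_fallback: list[bool]) -> list[str]:
    """Return (and log) fallback-keyed repo_ids that appear more than once.

    Hypothetical helper: used_fallback[i] is True when row i had an empty
    (robot_type, control_mode) and was therefore keyed by its repo_id.
    """
    # Count only rows that actually hit the fallback.
    counts = Counter(r for r, fb in zip(repo_ids, used_fallback) if fb)
    dups = sorted(r for r, n in counts.items() if n > 1)
    if dups:
        logger.warning(
            "%d fallback-keyed repo_id values appear in the mixture more than once; "
            "fit time pools them but training keeps separate heads: %s",
            len(dups), dups,
        )
    return dups


print(warn_on_pooled_fallbacks(["X", "X", "Y"], [True, True, True]))  # ['X']
```

Because it only reports the condition, the pooled head is still fit; the divergence itself remains until the keys are deduplicated.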

Proposed fix

Replicate _make_dataset_names's dedup at fit time inside _aggregate_stats_per_head so the fallback key passed to compute_norm_key is the deduplicated name, not the raw repo_id. Roughly:

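```python
from collections import Counter

# Raw name per mixture entry: repo_id, else vqa, else a placeholder.
raw_names = [(dc.repo_id or dc.vqa or "<no-name>") for dc in mixture_cfg.datasets]
counts = Counter(raw_names)
seen: dict[str, int] = {}
deduplicated_names: list[str] = []
for name in raw_names:
    if counts[name] > 1:
        # Repeated name: suffix every occurrence with its index (X#0, X#1, ...).
        i = seen.get(name, 0)
        deduplicated_names.append(f"{name}#{i}")
        seen[name] = i + 1
    else:
        # Unique name: keep it unchanged.
        deduplicated_names.append(name)
```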

Then pass deduplicated_names[i] (not per_ds_repo[i]) as the third arg to compute_norm_key. This makes the fit-time fallback keys match training's exactly.
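Wrapped into a function, the fit-time key computation after the fix might look like the minimal sketch below. Row, norm_key, and fit_time_keys are hypothetical: Row stands in for the per-dataset config (only the fields named in this issue), norm_key is an assumed stand-in for compute_norm_key, and the statistics accumulation that _aggregate_stats_per_head also performs is omitted.

```python
from collections import Counter
from dataclasses import dataclass
from typing import Optional


@dataclass
class Row:
    # Hypothetical per-dataset record; field names follow the issue text.
    repo_id: Optional[str]
    vqa: Optional[str] = None
    robot_type: str = ""
    control_mode: str = ""


def norm_key(robot_type: str, control_mode: str, fallback: str) -> str:
    # Assumed behavior of compute_norm_key, not the real implementation.
    return f"{robot_type}/{control_mode}" if robot_type and control_mode else fallback


def fit_time_keys(rows: list[Row]) -> list[str]:
    # Dedup over the WHOLE mixture so the "#N" indices line up with training.
    raw_names = [r.repo_id or r.vqa or "<no-name>" for r in rows]
    counts = Counter(raw_names)
    seen: dict[str, int] = {}
    names: list[str] = []
    for name in raw_names:
        if counts[name] > 1:
            i = seen.get(name, 0)
            names.append(f"{name}#{i}")
            seen[name] = i + 1
        else:
            names.append(name)
    # Pass the deduplicated name, not the raw repo_id, as the fallback.
    return [norm_key(r.robot_type, r.control_mode, n) for r, n in zip(rows, names)]


mixture = [
    Row("X"), Row("X"),  # fallback rows sharing a repo_id
    Row("A", robot_type="ur5", control_mode="joint"),
    Row("B", robot_type="ur5", control_mode="joint"),  # pools with A by design
]
print(fit_time_keys(mixture))  # ['X#0', 'X#1', 'ur5/joint', 'ur5/joint']
```

Counting every entry, rather than only rows that hit the fallback, matters if _make_dataset_names counts every entry as the snippet does: a repo_id that appears once as a fallback row and once with a non-empty (robot_type, control_mode) would otherwise be keyed "X" at fit time but "X#0" or "X#1" in training.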

Why deferred

The dedup logic currently lives in _make_dataset_names, a static method on WeightedDatasetMixture that takes a TrainPipelineConfig and a list of BaseDataset. Lifting just the dedup part is straightforward (the snippet above needs neither argument), but "the right thing" is probably to factor the dedup out of _make_dataset_names into a helper that both call sites share. That refactor felt out of scope for #348's fit-time normalization correctness fixes, so it is tracked here.

Acceptance

  • Fit-time fallback keys match training's (a synthetic test with duplicate repo_id shows both paths producing the same (per_ds_key, per_norm_key_stats) shape).
  • The c755cb8 dup-detection warning still fires (for visibility), but the underlying divergence is gone.
