diff --git a/src/opentau/scripts/fit_fast_tokenizer.py b/src/opentau/scripts/fit_fast_tokenizer.py index 0de8dafb..f9a45e9b 100644 --- a/src/opentau/scripts/fit_fast_tokenizer.py +++ b/src/opentau/scripts/fit_fast_tokenizer.py @@ -41,13 +41,28 @@ resampled to ``mixture.action_freq`` (or to each dataset's native fps when ``action_freq is None`` -- mixed-frequency mixtures) and right-padded to ``max_action_dim``. - 4. Min-max-normalize each chunk to ``[-1, 1]`` using - ``mixture.meta.aggregated_action_stats()`` -- the same global - min/max the BPE codec is fit over. (The training policy itself - normalizes per-dataset via - ``Normalize({"ACTION": NormalizationMode.MIN_MAX})``; the - tokenizer collapses to one global range so the discrete vocab - is shared across the mixture.) + 4. Min-max-normalize each chunk to ``[-1, 1]`` using the per- + ``(robot_type, control_mode)`` norm-head stats by default + (``--per-head-norm``) -- the same stats the training policy + applies via ``Normalize({"ACTION": NormalizationMode.MIN_MAX})`` + with per-head stacked buffers (PR #347). Per-dataset raw stats + are zero-padded (matching ``pad_vector`` in + ``_to_standard_data_format``) then pooled across head members + with ``nanmin``/``nanmax`` (with ``±Inf`` masked first, like + ``aggregate_stats``). Failure modes mirror training: any + dataset whose stats fail to load -> ``RuntimeError`` (training + would crash too via ``_to_standard_data_format``); + ``require_non_empty_robot_type/control_mode`` -> same + ``ValueError`` as ``datasets.factory._validate_metadata_requirements``. + Pass ``--no-per-head-norm`` for the legacy global aggregate + path, which is what older fits (pre-#347) used; the global + path under-spreads each head's distribution and yields + shorter token sequences at fit time than the policy's chunks + produce at training, so token-length analysis on the global + fit systematically underestimates the truncation rate. 
+ Passing ``--use-mixture-dataloader`` falls back to global + normalization and logs a warning (the dataloader path + doesn't surface per-sample dataset_index here yet). 5. Call ``UniversalActionProcessor.fit(...)`` (DCT + Rust BpeTrainer) and ``save_pretrained`` the result. The upstream remote-code source ``processing_action_tokenizer.py`` is copied alongside so @@ -63,7 +78,8 @@ --chunk-size 50 \\ [--total-chunks 1000000] [--action-dim 32] \\ [--vocab-size 2048] [--scale 10] [--seed 0] [--num-workers 8] \\ - [--dataloader-batch-size 256] [--pilot] + [--dataloader-batch-size 256] [--pilot] \\ + [--no-per-head-norm] # legacy global normalization (default: per-head) The mixture JSON may use ``$ref`` includes -- see ``opentau.configs.refs.resolve_refs``. @@ -88,6 +104,7 @@ from opentau.configs.default import DatasetMixtureConfig from opentau.configs.refs import resolve_refs_to_tempfile +from opentau.datasets.dataset_mixture import compute_norm_key logger = logging.getLogger(__name__) @@ -201,6 +218,24 @@ def parse_args() -> argparse.Namespace: "fps convention." ), ) + p.add_argument( + "--per-head-norm", + action=argparse.BooleanOptionalAction, + default=True, + help=( + "Normalize each chunk to [-1, 1] using its own " + "(robot_type, control_mode) norm head's pooled min/max -- " + "matches what the policy does at training time after PR #347. " + "Pass --no-per-head-norm to fall back to the legacy global " + "aggregate (only correct for single-head mixtures). The " + "global path systematically under-spreads each head's " + "distribution, so fit-time token-length analysis " + "underestimates the actual policy-side truncation rate. " + "Only supported on the default (manual sampler) path; " + "passing --use-mixture-dataloader silently falls back to " + "global normalization with a warning." 
+ ), + ) p.add_argument( "--pilot", action="store_true", @@ -271,15 +306,29 @@ def _resolve_native_action_key( return "action" -def _load_dataset_stats( +def _load_metadata_stats_and_info( item: tuple[int, Any], -) -> tuple[int, str, dict | None, str | None]: - """Worker: read action min/max from a dataset's LeRobotDatasetMetadata.""" +) -> tuple[int, str, dict[str, Any], dict[str, np.ndarray] | None, str | None]: + """Shared worker: load a dataset's ``LeRobotDatasetMetadata``, return + ``(idx, repo_id, info_with_overrides, stats_or_None, err_or_None)``. + + ``info_with_overrides`` is a copy of ``meta.info`` with + ``DatasetConfig.{robot_type,control_mode}`` overrides applied -- mirrors + ``factory._apply_metadata_overrides`` so the caller can derive the same + norm key the training policy does. On full construction failure, the + returned ``info`` is empty. + """ idx, cfg = item + repo_id = cfg.repo_id or "" try: from opentau.datasets.lerobot_dataset import LeRobotDatasetMetadata meta = LeRobotDatasetMetadata(cfg.repo_id, root=cfg.root, revision=cfg.revision) + info = dict(getattr(meta, "info", {}) or {}) + if cfg.robot_type is not None: + info["robot_type"] = cfg.robot_type + if cfg.control_mode is not None: + info["control_mode"] = cfg.control_mode key = _resolve_native_action_key( cfg.repo_id, cfg.data_features_name_mapping, @@ -288,14 +337,16 @@ def _load_dataset_stats( if not meta.stats or key not in meta.stats: return ( idx, - cfg.repo_id, + repo_id, + info, None, f"key {key!r} missing from stats (keys={sorted(meta.stats or [])})", ) s = meta.stats[key] return ( idx, - cfg.repo_id, + repo_id, + info, { "min": np.asarray(s["min"], dtype=np.float64).ravel(), "max": np.asarray(s["max"], dtype=np.float64).ravel(), @@ -303,7 +354,15 @@ def _load_dataset_stats( None, ) except Exception as e: # noqa: BLE001 - return idx, cfg.repo_id or "", None, f"{type(e).__name__}: {e}" + return idx, repo_id, {}, None, f"{type(e).__name__}: {e}" + + +def _load_dataset_stats( + 
item: tuple[int, Any], +) -> tuple[int, str, dict | None, str | None]: + """Worker: read action min/max only (used by the legacy global-norm path).""" + idx, repo_id, _info, stats, err = _load_metadata_stats_and_info(item) + return idx, repo_id, stats, err def _aggregate_stats_manual( @@ -394,6 +453,344 @@ def _aggregate_stats_manual( return action_min, action_max, per_dataset +def _aggregate_stats_per_head( + mixture_cfg: DatasetMixtureConfig, action_dim: int, num_workers: int +) -> tuple[list[np.ndarray], list[np.ndarray], list[str], dict[str, np.ndarray]]: + """Aggregate per-``(robot_type, control_mode)`` action stats. + + Mirrors what ``DatasetMixtureMetadata._build_norm_heads`` does for the + actions stat at training time: each dataset gets the pooled min/max of + its norm head, so chunks normalized at fit time match what + ``Normalize({"ACTION": NormalizationMode.MIN_MAX})`` produces from the + same data at training time. The pooling for min/max is straight + ``nanmin``/``nanmax`` across head members -- count-weighted aggregation + (which ``aggregate_stats`` does in dataset_mixture) only matters for + ``mean``/``std`` and is a no-op for ``min``/``max``. + + To match the production ``Normalize`` path exactly, per-dataset stats are + zero-padded (not NaN-padded) to ``action_dim`` before pooling -- this + mirrors ``pad_vector`` (zero-pad) applied in + ``DatasetMixtureMetadata._to_standard_data_format``. Trailing slots + therefore pool to ``min=max=0``, which makes the production + ``(x - min) / (max - min + EPS) * 2 - 1`` evaluate to ``-1`` for the + zero-padded action suffix at both fit and training time. + + Failure modes (matching training-time behaviour): + + - If any dataset's action stats fail to load, raise ``RuntimeError``. + Training would also crash on such a dataset (``_to_standard_data_format`` + raises ``KeyError``), so silently producing a tokenizer that the policy + will refuse to consume is sneaky. 
Drop the offending dataset from the + mixture (or use ``--no-per-head-norm`` to fall back to the legacy + global path) before retrying. + - If ``mixture_cfg.require_non_empty_robot_type`` / + ``require_non_empty_control_mode`` is set and any dataset still has + an empty value after overrides, raise the same ``ValueError`` that + ``datasets.factory._validate_metadata_requirements`` raises. + + Args: + mixture_cfg: Mixture config. + action_dim: Padded action dim (chunks are padded to this width + before normalization, so the returned ``min``/``max`` arrays + are sized ``(action_dim,)``). + num_workers: ProcessPool size for parallel stats loading. + + Returns: + ``(per_ds_min, per_ds_max, per_ds_key, per_head_stats)``: + + - ``per_ds_min[i]``/``per_ds_max[i]``: ``(action_dim,)`` float32 + arrays for dataset ``i`` (the i-th entry in + ``mixture_cfg.datasets``). Datasets sharing a non-fallback + ``(robot_type, control_mode)`` get identical pooled arrays. + - ``per_ds_key[i]``: the norm key for dataset ``i`` (always a str + after the all-stats-required guard above). + - ``per_head_stats``: deduplicated head -> ``{"min": ..., "max": ...}`` + map (diagnostics + report). + """ + from concurrent.futures import ProcessPoolExecutor, as_completed + + n = len(mixture_cfg.datasets) + raw_min: list[np.ndarray | None] = [None] * n + raw_max: list[np.ndarray | None] = [None] * n + # Use list comprehensions (not `[{}] * n`) for mutable defaults so a future + # `per_ds_info[idx]["override_X"] = ...` doesn't silently fan out across + # all slots. The `[None] * n` / `[""] * n` patterns above are + # safe because their element types are immutable. 
+ per_ds_info: list[dict[str, Any]] = [{} for _ in range(n)] + per_ds_native_dim: list[int | None] = [None] * n + per_ds_repo: list[str] = [""] * n + failures: list[tuple[int, str, str]] = [] + + logger.info( + "Loading per-(rt, cm) action stats for %d datasets (workers=%d)", + n, + num_workers, + ) + t0 = time.perf_counter() + work = list(enumerate(mixture_cfg.datasets)) + with ProcessPoolExecutor(max_workers=num_workers) as ex: + futs = {ex.submit(_load_metadata_stats_and_info, item): item[0] for item in work} + for fut in as_completed(futs): + idx, repo_id, info, stats, err = fut.result() + per_ds_repo[idx] = repo_id + per_ds_info[idx] = info + if stats is None: + failures.append((idx, repo_id, err or "")) + continue + native_dim = int(stats["min"].shape[0]) + per_ds_native_dim[idx] = native_dim + # Zero-pad to action_dim -- matches `pad_vector` in + # `_to_standard_data_format`. Trailing slots get min=max=0, so the + # production `(0 - 0) / (EPS) * 2 - 1 = -1` matches our fit-time + # output bit-for-bit (in float64) for the padded suffix. + mn = np.zeros(action_dim, dtype=np.float64) + mx = np.zeros(action_dim, dtype=np.float64) + clip = min(action_dim, native_dim) + mn[:clip] = stats["min"][:clip] + mx[:clip] = stats["max"][:clip] + raw_min[idx] = mn + raw_max[idx] = mx + + if failures: + sample = ", ".join(f"{repo}: {err}" for _i, repo, err in failures[:5]) + raise RuntimeError( + f"Per-head normalization requires all datasets' action stats to load, " + f"but {len(failures)}/{n} failed. Training would also crash on these " + f"datasets (`_to_standard_data_format` raises on missing stats). " + f"Drop them from the mixture or fix their stats. Sample: {sample}" + ) + + # Match `_validate_metadata_requirements` in datasets.factory: if the + # mixture demands non-empty robot_type / control_mode, surface that at + # fit time so the operator doesn't burn 90s on a fit that the very next + # training launch refuses to start. 
+ require_robot = bool(getattr(mixture_cfg, "require_non_empty_robot_type", False)) + require_control = bool(getattr(mixture_cfg, "require_non_empty_control_mode", False)) + if require_robot or require_control: + bad: list[str] = [] + for i in range(n): + info = per_ds_info[i] + if require_robot and not (info.get("robot_type") or "").strip(): + bad.append(f"{per_ds_repo[i]}: robot_type is empty") + if require_control and not (info.get("control_mode") or "").strip(): + bad.append(f"{per_ds_repo[i]}: control_mode is empty") + if bad: + raise ValueError( + "DatasetMixtureConfig requires non-empty metadata fields, but the " + f"following {len(bad)} datasets are missing values after overrides:\n - " + + "\n - ".join(bad) + + "\nSet `DatasetConfig.robot_type` / `DatasetConfig.control_mode` " + "on the offending dataset(s) to provide an override." + ) + + # Derive norm keys (now that all stats loaded). Track fallback-fired + # datasets and surface them like `_build_norm_heads` does -- otherwise the + # operator silently gets singleton-per-dataset heads instead of pooled + # ones, which is almost never what they want. + per_ds_key: list[str] = [""] * n + fallback_datasets: list[str] = [] + for i in range(n): + info = per_ds_info[i] + key, fallback_fired = compute_norm_key( + info.get("robot_type"), + info.get("control_mode"), + per_ds_repo[i], + ) + per_ds_key[i] = key + if fallback_fired: + fallback_datasets.append(per_ds_repo[i]) + if fallback_datasets: + shown = fallback_datasets[:10] + suffix = f", ... and {len(fallback_datasets) - 10} more" if len(fallback_datasets) > 10 else "" + logger.warning( + "%d/%d datasets lack non-empty robot_type / control_mode and were " + "given a per-dataset fallback norm head (one singleton head each). " + "Set `DatasetConfig.robot_type` / `DatasetConfig.control_mode` to " + "pool them into shared heads. 
Affected: %s%s", + len(fallback_datasets), + n, + shown, + suffix, + ) + # Tighter signal for the specific divergence the deferred fallback-name + # dedup finding flagged: when a fallback-keyed repo_id appears more + # than once in the mixture, fit-time pools them under one shared key + # (since `compute_norm_key` returns the same string both times) while + # training keeps them as separate singleton heads (since + # `_make_dataset_names` deduplicates with `#N` suffixes). Flag the + # divergence explicitly so the operator can fix it before it ships. + from collections import Counter + + fallback_counts = Counter(fallback_datasets) + duplicates_in_fallback = {k: v for k, v in fallback_counts.items() if v > 1} + if duplicates_in_fallback: + dup_preview = dict(list(duplicates_in_fallback.items())[:10]) + logger.warning( + "%d fallback-keyed repo_id values appear in the mixture more " + "than once. Fit-time POOLS these under one shared head per " + "repo_id, but training (via `_make_dataset_names`'s `#N` dedup) " + "keeps them as separate singleton heads -- the fit-time " + "normalization will diverge from training. Set " + "`DatasetConfig.robot_type` / `DatasetConfig.control_mode` on " + "each entry, or deduplicate the mixture, to close the gap. " + "Duplicates (repo_id -> count): %s", + len(duplicates_in_fallback), + dup_preview, + ) + + # Restore the over-dim warning from the global path: datasets whose + # native action dim exceeds --action-dim will have high dims silently + # dropped. The mixture won't include them at training either if + # max_action_dim < native, but the warning is still load-bearing because + # the fit produces a tokenizer with no token coverage for the dropped + # dims at inference. 
+ over_dim = [ + (i, per_ds_native_dim[i], per_ds_repo[i]) + for i in range(n) + if per_ds_native_dim[i] is not None and per_ds_native_dim[i] > action_dim + ] + if over_dim: + max_over = max(d for _, d, _ in over_dim if d is not None) + logger.warning( + "%d/%d datasets have native action_dim > --action-dim=%d (max %d). " + "Their extra dims will be silently dropped. Confirm this matches " + "the production policy's max_action_dim (mismatch => the fitted " + "BPE doesn't cover those dims at inference). Sample: %s", + len(over_dim), + n, + action_dim, + max_over, + [r for _, _, r in over_dim[:5]], + ) + + per_ds_min, per_ds_max, per_head_stats = _pool_per_head_stats(raw_min, raw_max, per_ds_key, action_dim) + logger.info( + "Per-(rt, cm) stats aggregated in %.1fs: %d heads across %d datasets.", + time.perf_counter() - t0, + len(per_head_stats), + n, + ) + return per_ds_min, per_ds_max, per_ds_key, per_head_stats + + +def _pool_per_head_stats( + raw_min: list[np.ndarray | None], + raw_max: list[np.ndarray | None], + per_ds_key: list[str | None], + action_dim: int, +) -> tuple[list[np.ndarray], list[np.ndarray], dict[str, dict[str, np.ndarray]]]: + """Pure pooling: aggregate per-dataset raw (min, max) into per-head stats, + then broadcast back per-dataset so each dataset gets the stats of its + head. Fallback keys (a dataset's repo_id used in lieu of a real + ``(rt, cm)`` pair) are singletons by construction, so they get their + own stats verbatim. + + Pooling for ``min``/``max`` is ``nanmin``/``nanmax`` across the head's + members; for those two fields ``aggregate_stats``'s count-weighted + aggregation reduces to the unweighted nanmin/nanmax (the weights only + matter for ``mean``/``std``). ``aggregate_stats`` masks ``±Inf`` to + ``NaN`` first so they don't poison the reduction; we mirror that. 
+ + Preconditions (enforced by ``_aggregate_stats_per_head``): every entry + in ``per_ds_key`` is a non-None string, and every entry in + ``raw_min``/``raw_max`` is a non-None ``(action_dim,)`` array. + + Returns ``(per_ds_min, per_ds_max, per_head_stats)`` where the third + element is the deduplicated per-head stats map + ``{key: {"min": ..., "max": ...}}``. + """ + n = len(per_ds_key) + from collections import defaultdict + + key_to_indices: dict[str, list[int]] = defaultdict(list) + for i, k in enumerate(per_ds_key): + # Invariant from `_aggregate_stats_per_head`: every per_ds_key is a + # non-empty string and every raw_min/raw_max entry is set. + assert k is not None and raw_min[i] is not None and raw_max[i] is not None, ( + f"_pool_per_head_stats: row {i} has stale None (key={k!r}, " + f"raw_min={'None' if raw_min[i] is None else 'set'}). " + "Callers must populate every row before pooling." + ) + key_to_indices[k].append(i) + + per_head_stats: dict[str, dict[str, np.ndarray]] = {} + for key, indices in key_to_indices.items(): + stacked_min = np.stack([raw_min[i] for i in indices]) + stacked_max = np.stack([raw_max[i] for i in indices]) + # Mirror `aggregate_stats` (compute_stats.py:350): mask `±Inf` to + # `NaN` before reduction so a single Inf entry doesn't poison the + # pool (-Inf would still survive nanmin; +Inf would still survive + # nanmax). NaN is then skipped by nanmin/nanmax. + stacked_min = np.where(np.isfinite(stacked_min), stacked_min, np.nan) + stacked_max = np.where(np.isfinite(stacked_max), stacked_max, np.nan) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=RuntimeWarning) + mn = np.nanmin(stacked_min, axis=0) + mx = np.nanmax(stacked_max, axis=0) + # Defensive: a head where *every* member had Inf in some slot would + # leave NaN here. Fill with [-1, 1] so downstream `Normalize` doesn't + # see NaN. (Should not happen on healthy data; assert and continue.) 
+ nan_dims = np.where(~np.isfinite(mn) | ~np.isfinite(mx))[0] + if nan_dims.size > 0: + logger.warning( + "Norm head %r has dims with no finite stats after Inf-mask " + "+ nanmin/nanmax: %s. Filling with [-1, 1]; expect a model " + "that ignores those output dims.", + key, + nan_dims.tolist(), + ) + mn[nan_dims] = -1.0 + mx[nan_dims] = 1.0 + per_head_stats[key] = {"min": mn, "max": mx} + + per_ds_min: list[np.ndarray] = [] + per_ds_max: list[np.ndarray] = [] + for i in range(n): + head = per_head_stats[per_ds_key[i]] + per_ds_min.append(head["min"].astype(np.float32)) + per_ds_max.append(head["max"].astype(np.float32)) + return per_ds_min, per_ds_max, per_head_stats + + +def _normalize_chunks_per_head( + stacked: np.ndarray, + per_dataset_chunks: list[int], + per_ds_min: list[np.ndarray], + per_ds_max: list[np.ndarray], +) -> np.ndarray: + """Apply per-dataset min/max to chunks in dataset-config order. + + ``_sample_via_manual`` concatenates per-dataset chunk buckets in + mixture-config order, so a cumulative sum over ``per_dataset_chunks`` + tells us which slice of ``stacked`` belongs to each dataset. Each slice + gets normalized with its own ``(min, max)``; datasets that share a norm + head get identical stats by construction (see ``_aggregate_stats_per_head``). + """ + total = int(sum(per_dataset_chunks)) + # Defensive: a future refactor to `_sample_via_manual` that changes the + # concatenation order or drops chunks must not silently corrupt the BPE + # corpus by misaligning the per-dataset normalization windows. + assert total == stacked.shape[0], ( + f"_normalize_chunks_per_head: per_dataset_chunks sums to {total} but " + f"stacked has {stacked.shape[0]} rows. Sampler/normalizer drifted " + "out of sync; review `_sample_via_manual` concatenation order." 
+ ) + assert len(per_dataset_chunks) == len(per_ds_min) == len(per_ds_max), ( + "per_dataset_chunks, per_ds_min, per_ds_max must be 1:1; got " + f"{len(per_dataset_chunks)} / {len(per_ds_min)} / {len(per_ds_max)}." + ) + out = np.zeros_like(stacked) + offset = 0 + for i, count in enumerate(per_dataset_chunks): + if count <= 0: + continue + out[offset : offset + count] = _normalize_chunks( + stacked[offset : offset + count], per_ds_min[i], per_ds_max[i] + ) + offset += count + return out + + def _compute_budgets_weighted(mixture_cfg: DatasetMixtureConfig, total_chunks: int) -> list[int]: """Pure weight-proportional budget per dataset (no clamps). @@ -1016,6 +1413,23 @@ def main() -> int: args = parse_args() _setup_logging(args.log_level) + # The --use-mixture-dataloader path uses ``_extract_action_stats(mixture.meta, ...)`` + # which returns global aggregates; threading per-head normalization through + # the dataloader-drained chunks would need per-sample dataset_index from + # the batch and isn't done yet. Warn-and-fall-back so existing invocations + # that pass `--use-mixture-dataloader` without explicitly setting + # `--per-head-norm` keep working; the per-head default still applies on + # the manual sampler path (which is what most fits use). + if args.use_mixture_dataloader and args.per_head_norm: + logger.warning( + "--per-head-norm is not yet implemented on the --use-mixture-dataloader " + "path; falling back to global aggregated_action_stats() for this run. " + "Drop --use-mixture-dataloader to use the default manual sampler " + "(which does support per-head), or pass --no-per-head-norm explicitly " + "to silence this warning." 
+ ) + args.per_head_norm = False + out_dir = args.out_dir if args.pilot: out_dir = out_dir / "pilot" @@ -1048,6 +1462,10 @@ def main() -> int: build_time = 0.0 sample_errors: dict[int, str] = {} per_dataset_chunks: list[int] = [0] * len(mixture_cfg.datasets) + action_min: np.ndarray | None = None + action_max: np.ndarray | None = None + per_ds_norm_keys: list[str] | None = None + per_head_stats_report: dict[str, dict[str, list[float]]] | None = None if args.use_mixture_dataloader: # Slow path: full WeightedDatasetMixture dataloader. logger.info( @@ -1089,11 +1507,25 @@ def main() -> int: # Use mixture-computed stats. action_min, action_max = _extract_action_stats(mixture.meta, args.action_dim) else: - # Fast path: aggregate stats manually, sample chunks per-dataset with - # scipy interp for FPS resampling. - action_min, action_max, _ = _aggregate_stats_manual(mixture_cfg, args.action_dim, args.num_workers) - logger.info("Global action_min: %s", np.round(action_min, 3).tolist()) - logger.info("Global action_max: %s", np.round(action_max, 3).tolist()) + # Fast path: aggregate stats manually (per-head or global), sample + # chunks per-dataset with scipy interp for FPS resampling. 
+ if args.per_head_norm: + per_ds_min, per_ds_max, per_ds_norm_keys, per_head_stats = _aggregate_stats_per_head( + mixture_cfg, args.action_dim, args.num_workers + ) + per_head_stats_report = { + key: { + "min": [round(float(x), 6) for x in s["min"].tolist()], + "max": [round(float(x), 6) for x in s["max"].tolist()], + } + for key, s in per_head_stats.items() + } + else: + action_min, action_max, _ = _aggregate_stats_manual( + mixture_cfg, args.action_dim, args.num_workers + ) + logger.info("Global action_min: %s", np.round(action_min, 3).tolist()) + logger.info("Global action_max: %s", np.round(action_max, 3).tolist()) t_drain0 = time.perf_counter() raw_chunks, per_dataset_chunks, sample_errors = _sample_via_manual( mixture_cfg, @@ -1124,12 +1556,16 @@ def main() -> int: for i, c in enumerate(raw_chunks): clip = min(c.shape[1], dim) stacked[i, :, :clip] = c[:, :clip] - all_chunks = _normalize_chunks(stacked, action_min, action_max) + if args.per_head_norm: + all_chunks = _normalize_chunks_per_head(stacked, per_dataset_chunks, per_ds_min, per_ds_max) + else: + all_chunks = _normalize_chunks(stacked, action_min, action_max) drain_time = time.perf_counter() - t_drain0 logger.info( - "Manual path: %d chunks normalized in %.1fs", + "Manual path: %d chunks normalized in %.1fs (norm=%s)", all_chunks.shape[0], drain_time, + "per_head" if args.per_head_norm else "global", ) fit_time = 0.0 @@ -1168,10 +1604,14 @@ def main() -> int: "dataloader_batch_size": args.dataloader_batch_size, "num_workers": args.num_workers, "sampler": "mixture_dataloader" if args.use_mixture_dataloader else "manual_parquet_interp", + "normalization": "per_robot_type_control_mode" if args.per_head_norm else "global", "per_dataset_chunks": per_dataset_chunks, "sample_errors": sample_errors, - "global_action_min": action_min.tolist(), - "global_action_max": action_max.tolist(), + "global_action_min": action_min.tolist() if action_min is not None else None, + "global_action_max": action_max.tolist() if 
action_max is not None else None, + "per_dataset_norm_keys": per_ds_norm_keys, + "per_head_action_stats": per_head_stats_report, + "n_norm_heads": (len(per_head_stats_report) if per_head_stats_report is not None else None), "timings_seconds": { "build_mixture": build_time, "drain": drain_time, diff --git a/tests/scripts/test_fit_fast_tokenizer.py b/tests/scripts/test_fit_fast_tokenizer.py index f9ceb1fe..ee75198b 100644 --- a/tests/scripts/test_fit_fast_tokenizer.py +++ b/tests/scripts/test_fit_fast_tokenizer.py @@ -28,7 +28,12 @@ import numpy as np from opentau.configs.default import DatasetConfig -from opentau.scripts.fit_fast_tokenizer import _sample_chunks_for_dataset_manual +from opentau.scripts.fit_fast_tokenizer import ( + _normalize_chunks, + _normalize_chunks_per_head, + _pool_per_head_stats, + _sample_chunks_for_dataset_manual, +) from tests.fixtures.constants import DEFAULT_FPS, DUMMY_REPO_ID @@ -81,3 +86,296 @@ def test_none_action_freq_matches_explicit_native_fps(self, tmp_path, lerobot_da # bit-equal output. A future divergence would surface here immediately. for c_none, c_native in zip(chunks_none, chunks_native, strict=True): np.testing.assert_array_equal(c_none, c_native) + + +class TestPoolPerHeadStats: + """``_pool_per_head_stats`` pools per-dataset action stats per norm key. + + These cases pin the contract that ``_aggregate_stats_per_head`` relies on + -- the script normalizes each chunk with the *pooled* head's + (min, max), and that pool MUST match what the training policy's + ``DatasetMixtureMetadata._build_norm_heads`` computes from the same raw + stats. If this drifts the fit normalization stops matching the + training-time normalization and the BPE chunk-length distribution at + inference will diverge from what the fit script saw -- exactly the + failure mode that motivated this code path. 
+ """ + + def test_two_datasets_one_head_pool_min_max(self): + """Two datasets sharing a (rt, cm) get the elementwise nanmin/nanmax.""" + action_dim = 4 + # Real per-head fits zero-pad trailing slots (matches `pad_vector` in + # `_to_standard_data_format`); the pure pooler is dim-agnostic, so this + # test uses real numbers throughout and relies on a separate test for + # the production zero-pad behaviour. + raw_min = [ + np.array([-1.0, -2.0, -3.0, -1.5], dtype=np.float64), + np.array([-0.5, -3.0, -2.0, -1.0], dtype=np.float64), + ] + raw_max = [ + np.array([1.0, 2.0, 3.0, 1.5], dtype=np.float64), + np.array([2.0, 1.5, 4.0, 1.0], dtype=np.float64), + ] + per_ds_key = ["robotA::ee", "robotA::ee"] + + per_ds_min, per_ds_max, per_head_stats = _pool_per_head_stats( + raw_min, raw_max, per_ds_key, action_dim + ) + + # One head, two datasets -- both datasets get the same pooled (min, max). + assert len(per_head_stats) == 1 + assert "robotA::ee" in per_head_stats + np.testing.assert_array_equal(per_ds_min[0], per_ds_min[1]) + np.testing.assert_array_equal(per_ds_max[0], per_ds_max[1]) + # Pooled values: elementwise nanmin / nanmax across the two datasets. + np.testing.assert_array_equal(per_ds_min[0], np.array([-1.0, -3.0, -3.0, -1.5], dtype=np.float32)) + np.testing.assert_array_equal(per_ds_max[0], np.array([2.0, 2.0, 4.0, 1.5], dtype=np.float32)) + + def test_distinct_heads_kept_separate(self): + """Different (rt, cm) pairs do NOT pool, even if dims overlap.""" + action_dim = 3 + raw_min = [ + np.array([-1.0, -1.0, -1.0], dtype=np.float64), + np.array([-10.0, -10.0, -10.0], dtype=np.float64), + ] + raw_max = [ + np.array([1.0, 1.0, 1.0], dtype=np.float64), + np.array([10.0, 10.0, 10.0], dtype=np.float64), + ] + per_ds_key = ["robotA::ee", "robotB::joint"] + + per_ds_min, per_ds_max, per_head_stats = _pool_per_head_stats( + raw_min, raw_max, per_ds_key, action_dim + ) + + assert len(per_head_stats) == 2 + # Each dataset keeps its own stats verbatim (singleton heads). 
+ np.testing.assert_array_equal(per_ds_min[0], np.array([-1.0, -1.0, -1.0], dtype=np.float32)) + np.testing.assert_array_equal(per_ds_min[1], np.array([-10.0, -10.0, -10.0], dtype=np.float32)) + np.testing.assert_array_equal(per_ds_max[0], np.array([1.0, 1.0, 1.0], dtype=np.float32)) + np.testing.assert_array_equal(per_ds_max[1], np.array([10.0, 10.0, 10.0], dtype=np.float32)) + + def test_inf_masked_before_pool(self): + """``±Inf`` in raw stats are masked to NaN before nanmin/nanmax. + + Mirrors ``aggregate_stats`` (compute_stats.py:350) -- without the + mask, a single ``+Inf`` poisons nanmax for the whole head and + ``-Inf`` poisons nanmin. Both would make the production + ``(x - min) / (max - min)`` evaluate to 0 in float32 (the divisor + wins), breaking the chunk's normalization silently. + """ + action_dim = 2 + raw_min = [ + np.array([-1.0, -np.inf], dtype=np.float64), # one corrupted dim + np.array([-2.0, -2.0], dtype=np.float64), + ] + raw_max = [ + np.array([np.inf, 1.0], dtype=np.float64), # corrupted dim + np.array([2.0, 2.0], dtype=np.float64), + ] + per_ds_key = ["robotA::ee", "robotA::ee"] + + per_ds_min, per_ds_max, _ = _pool_per_head_stats(raw_min, raw_max, per_ds_key, action_dim) + # The Inf-corrupted entries get masked to NaN; the other peer's finite + # values dominate the pool. 
+ np.testing.assert_array_equal(per_ds_min[0], np.array([-2.0, -2.0], dtype=np.float32)) + np.testing.assert_array_equal(per_ds_max[0], np.array([2.0, 2.0], dtype=np.float32)) + + def test_all_inf_dim_falls_back_to_minus_one_one(self): + """A dim where every head member is ±Inf gets the [-1, 1] fallback.""" + action_dim = 2 + raw_min = [ + np.array([-1.0, -np.inf], dtype=np.float64), + np.array([-2.0, -np.inf], dtype=np.float64), + ] + raw_max = [ + np.array([1.0, np.inf], dtype=np.float64), + np.array([2.0, np.inf], dtype=np.float64), + ] + per_ds_key = ["robotA::ee", "robotA::ee"] + + per_ds_min, per_ds_max, _ = _pool_per_head_stats(raw_min, raw_max, per_ds_key, action_dim) + # Dim 0 pools normally; dim 1 is Inf across all members -> [-1, 1] fill. + np.testing.assert_array_equal(per_ds_min[0], np.array([-2.0, -1.0], dtype=np.float32)) + np.testing.assert_array_equal(per_ds_max[0], np.array([2.0, 1.0], dtype=np.float32)) + + +class TestNormalizeChunksPerHead: + """``_normalize_chunks_per_head`` normalizes each dataset's slice + independently using that dataset's (min, max). + + ``_sample_via_manual`` concatenates per-dataset chunk buckets in + mixture-config order, so the helper relies on ``per_dataset_chunks`` + (the count per dataset) to know which row range belongs to each + dataset. This test pins that contract: chunks from dataset 0 get + normalized with ``per_ds_min[0]/max[0]``, chunks from dataset 1 with + ``per_ds_min[1]/max[1]``, etc., regardless of how datasets with zero + chunks are interleaved. + """ + + def test_per_dataset_slice_normalization(self): + action_dim = 2 + chunk_size = 3 + # 3 datasets, the middle one had 0 chunks (degenerate, common + # outcome of very-small total_chunks budgets after rounding). + per_dataset_chunks = [2, 0, 1] + # Dataset 0: chunks all equal to its raw max, so after [-1, 1] + # normalization they should land at +1 on every dim. + # Dataset 2: chunks all equal to its raw min, so should land at -1. 
+ per_ds_min = [ + np.array([-2.0, -2.0], dtype=np.float32), # ds 0 + np.array([-1.0, -1.0], dtype=np.float32), # ds 1 (unused, 0 chunks) + np.array([-5.0, -10.0], dtype=np.float32), # ds 2 + ] + per_ds_max = [ + np.array([2.0, 2.0], dtype=np.float32), + np.array([1.0, 1.0], dtype=np.float32), + np.array([5.0, 10.0], dtype=np.float32), + ] + + # Build the stacked chunks accordingly. + stacked = np.zeros((3, chunk_size, action_dim), dtype=np.float32) + stacked[0:2] = 2.0 # ds 0 at raw_max + stacked[2] = np.array([-5.0, -10.0], dtype=np.float32) # ds 2 at raw_min + + out = _normalize_chunks_per_head(stacked, per_dataset_chunks, per_ds_min, per_ds_max) + + # First two rows -> all +1 (ds 0 at its raw_max) + np.testing.assert_allclose(out[0:2], 1.0, atol=1e-7) + # Third row -> all -1 (ds 2 at its raw_min) + np.testing.assert_allclose(out[2], -1.0, atol=1e-7) + + def test_matches_global_when_all_datasets_share_one_head(self): + """Per-head normalization with one shared (rt, cm) === global.""" + rng = np.random.default_rng(0) + action_dim = 4 + n_chunks = 7 + chunk_size = 5 + stacked = rng.uniform(-1, 1, size=(n_chunks, chunk_size, action_dim)).astype(np.float32) + + # Two datasets that share the same pooled stats. + per_dataset_chunks = [3, 4] + shared_min = np.array([-2.0, -2.0, -2.0, -2.0], dtype=np.float32) + shared_max = np.array([2.0, 2.0, 2.0, 2.0], dtype=np.float32) + + out_per_head = _normalize_chunks_per_head( + stacked, per_dataset_chunks, [shared_min, shared_min], [shared_max, shared_max] + ) + out_global = _normalize_chunks(stacked, shared_min, shared_max) + + # Bit-exact: both paths apply the same affine transform. 
+ np.testing.assert_array_equal(out_per_head, out_global) + + def test_chunk_count_mismatch_raises_assert(self): + """Sampler/normalizer drift is caught at the boundary, not silently.""" + action_dim = 2 + stacked = np.zeros((3, 4, action_dim), dtype=np.float32) + per_dataset_chunks = [2, 2] # sum=4 but stacked has 3 rows + per_ds_min = [np.array([-1.0, -1.0], dtype=np.float32)] * 2 + per_ds_max = [np.array([1.0, 1.0], dtype=np.float32)] * 2 + + import pytest + + with pytest.raises(AssertionError, match="sums to"): + _normalize_chunks_per_head(stacked, per_dataset_chunks, per_ds_min, per_ds_max) + + +class TestNormalizeEquivalenceVsProduction: + """``_normalize_chunks_per_head`` matches production ``Normalize`` to float32 rounding. + + Pins the invariant the PR claims: a chunk normalized at fit time produces + the same values, to float32 rounding, that the policy feeds to the FAST tokenizer at training + time. The production path is ``Normalize({"ACTION": MIN_MAX}).forward`` + with per-head stacked stats buffers (PR #347) and per-sample + ``dataset_index`` lookups. Our manual path applies the same per-dataset + (min, max) directly. Any drift between these two -- e.g. forgetting the + ``* 2 - 1`` shift, swapping EPS conventions, or accidentally truncating + the trailing-dim slot -- means the BPE corpus the fit operates on no + longer matches what training sees, and the published token-length + distribution becomes uncalibrated guidance. + """ + + def test_per_head_normalize_matches_production_normalize(self): + """Synthesize 2 datasets, 2 heads; verify per-dataset output matches + ``Normalize`` driven by per-sample ``dataset_index``.
+ """ + import torch + + from opentau.configs.types import FeatureType, NormalizationMode, PolicyFeature + from opentau.policies.normalize import Normalize + + action_dim = 4 + chunk_size = 3 + # Dataset 0: head "robotA::ee", min=[-2, -2, 0, 0], max=[2, 2, 0, 0] + # Dataset 1: head "robotB::joint", min=[-1, -5, -10, 0], max=[3, 5, 10, 0] + # Trailing zeros simulate `pad_vector` zero-pad in + # `_to_standard_data_format`. Production `(x - 0) / (0 - 0 + EPS) * 2 - 1` + # evaluates to -1 for those slots, and we must too. + per_ds_min_f32 = [ + np.array([-2.0, -2.0, 0.0, 0.0], dtype=np.float32), + np.array([-1.0, -5.0, -10.0, 0.0], dtype=np.float32), + ] + per_ds_max_f32 = [ + np.array([2.0, 2.0, 0.0, 0.0], dtype=np.float32), + np.array([3.0, 5.0, 10.0, 0.0], dtype=np.float32), + ] + + # Per-dataset chunk counts -- mixed sizes to exercise the offset arithmetic. + per_dataset_chunks = [3, 4] + n_total = sum(per_dataset_chunks) + + rng = np.random.default_rng(42) + # Real signal on dims 0-1 for ds 0, dims 0-2 for ds 1; zero on padded tail. + stacked = np.zeros((n_total, chunk_size, action_dim), dtype=np.float32) + stacked[0:3, :, 0:2] = rng.uniform(-2, 2, size=(3, chunk_size, 2)).astype(np.float32) + stacked[3:7, :, 0:3] = rng.uniform(-5, 5, size=(4, chunk_size, 3)).astype(np.float32) + + out_manual = _normalize_chunks_per_head(stacked, per_dataset_chunks, per_ds_min_f32, per_ds_max_f32) + + # Build the production `Normalize` layer with per-dataset stats buffers. + # The contract: pass `per_dataset_stats` as a list aligned to + # `dataset_names`; `Normalize.forward(batch, dataset_index)` then + # picks each sample's stats via `dataset_index[i]`. 
+ features = {"action": PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))} + norm_map = {FeatureType.ACTION: NormalizationMode.MIN_MAX} + per_dataset_stats = [ + { + "action": { + "min": torch.from_numpy(per_ds_min_f32[0]), + "max": torch.from_numpy(per_ds_max_f32[0]), + } + }, + { + "action": { + "min": torch.from_numpy(per_ds_min_f32[1]), + "max": torch.from_numpy(per_ds_max_f32[1]), + } + }, + ] + normalize = Normalize( + features=features, + norm_map=norm_map, + per_dataset_stats=per_dataset_stats, + dataset_names=["dsA", "dsB"], + ) + # Production processes one chunk at a time; flatten to (N*T, D) so + # each sample carries its own dataset_index and the buffer gather + # exactly mirrors training. + n, t, d = stacked.shape + dataset_index = torch.from_numpy( + np.concatenate([np.full(per_dataset_chunks[k] * t, k, dtype=np.int64) for k in range(2)]) + ) + batch = {"action": torch.from_numpy(stacked.reshape(n * t, d))} + out_prod = normalize(batch, dataset_index)["action"].numpy().reshape(n, t, d) + + # Both paths apply `(x - min) / (max - min + EPS) * 2 - 1` with EPS=1e-8. + # Production runs the whole expression in float32; the manual path + # computes in float64 then casts to float32 at the end, so low-bit + # rounding differs by ~1 ULP (1.19e-7) on a few entries. The DCT scale + # the BPE codec sees is O(1) so this is well below the threshold that + # would change the BPE token-id sequence. Padded slots become -1 in + # both (zero data, zero stats, +EPS divisor -> -1). + np.testing.assert_allclose(out_manual, out_prod, rtol=0, atol=2e-7) + # Sanity: the padded suffix actually is -1 in both outputs. + np.testing.assert_allclose(out_manual[0:3, :, 2:4], -1.0, atol=1e-7) + np.testing.assert_allclose(out_manual[3:7, :, 3:4], -1.0, atol=1e-7)