fix(scripts): inject dataset_index in inference smoke obs #342
`create_dummy_observation` was missing the `dataset_index` (or `dataset_repo_id`) selector that per-dataset Normalize/Unnormalize requires for >1-dataset checkpoints. The nightly regression's `Run Inference` step against a 2-dataset CI checkpoint tripped `_resolve_dataset_index`'s "missing selector" KeyError on the first `sample_actions` call. Pin to row 0 so multi-dataset checkpoints take the first dataset's stats in the smoke path; single-dataset checkpoints are unaffected (the zero fallback would resolve to the same value).
shuheng-liu left a comment
Review
Verdict: LGTM, minor non-blocking suggestion.
Summary
Minimal, well-scoped fix for #341. After #336 introduced per-dataset Normalize/Unnormalize, _resolve_dataset_index raises KeyError when a multi-dataset checkpoint sees a batch with neither dataset_index nor dataset_repo_id. scripts/inference.py's smoke path was the only inference caller missing the injection — grpc/server.py and eval.py already inject via their dataset_repo_id config field.
Correctness — verified
Cross-checked the "no behavior change for single-dataset" claim against src/opentau/policies/pretrained.py::_resolve_dataset_index:
- Old single-dataset path (no key in batch, `_stacked_num_datasets() <= 1`): returns `torch.zeros(batch_size or 1, dtype=torch.long, device=self._model_device())`.
- New explicit path (`"dataset_index" in batch`): `idx.to(dtype=torch.long, device=self._model_device())`.
Both produce a (1,) torch.long tensor on the policy's compute device — the claim holds. The "pin multi-dataset smoke runs to row 0" semantics is a reasonable convention for a "just exercise the forward pass" benchmark.
Also consistent with the existing pattern in src/opentau/scripts/export_to_onnx.py, which pre-injects torch.zeros(bsize, dtype=torch.long, ...) for the same reason (a traced wrapper bypassing _resolve_dataset_index).
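The selector resolution discussed in this section can be sketched roughly as follows. This is a torch-free stand-in, not the code in `pretrained.py`: plain lists replace tensors, the `repo_ids` parameter and the function name are hypothetical, and the order of the checks is an assumption.

```python
from typing import Any, List, Mapping, Sequence


def resolve_dataset_index_sketch(
    batch: Mapping[str, Any],
    repo_ids: Sequence[str],
    batch_size: int = 1,
) -> List[int]:
    """Hypothetical stand-in for the per-dataset selector logic."""
    if "dataset_index" in batch:
        # Explicit path: the caller already chose the stats row(s).
        return [int(i) for i in batch["dataset_index"]]
    if "dataset_repo_id" in batch:
        # Deployment-style path: map a repo id to its row (assumed behavior).
        row = list(repo_ids).index(batch["dataset_repo_id"])
        return [row] * batch_size
    if len(repo_ids) <= 1:
        # Single-dataset fallback: every sample uses row 0.
        return [0] * batch_size
    # Multi-dataset checkpoint with no selector: refuse to guess.
    raise KeyError("missing selector: need 'dataset_index' or 'dataset_repo_id'")


single, multi = ["org/a"], ["org/a", "org/b"]

# Single-dataset: the fallback and the explicit pin agree.
assert resolve_dataset_index_sketch({}, single) == resolve_dataset_index_sketch(
    {"dataset_index": [0]}, single
)

# Multi-dataset: an unpinned smoke batch trips the guard, a pinned one does not.
try:
    resolve_dataset_index_sketch({}, multi)
except KeyError as err:
    print("unpinned smoke batch:", err)
print("pinned smoke batch:", resolve_dataset_index_sketch({"dataset_index": [0]}, multi))
```

The point of the sketch is only that pinning row 0 is a no-op where the fallback already applied, and removes the ambiguity where it did not.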
Test
The regression test correctly:
- Uses a minimal `_Cfg` dataclass with only the attributes the function reads.
- Asserts presence + shape + dtype + value of `dataset_index`.
- Stays CPU/float32, so no GPU dependency.
The acknowledged limitation (no end-to-end through a 2-dataset policy) is the right trade-off: it would require a heavy fixture, and the unit test locks in the exact line that was wrong.
Style / conventions
- The 4-line comment is slightly long for a 1-line addition, but the WHY is genuinely non-obvious (`_resolve_dataset_index`'s "missing selector" guard + the single-dataset no-op invariant), so it earns its keep. Leave as-is.
- Dtype `torch.long` is correct: `_resolve_dataset_index` re-casts to long anyway, but providing it from the source skips that cast.
Minor suggestion (non-blocking)
create_dummy_observation's docstring Returns: block lists "camera observations, state, prompt, and padding flags" — consider tacking on "and dataset index" so the new key is mentioned. Trivial, fine to skip.
Risk
Low. Surgical (5 + 22 lines, single dict-key addition), no public API change, all CI green (CPU tests, pre-commit, model parallelism, checklist). The nightly Regression Tests workflow on main post-merge should confirm the Run Inference step recovers.
Generated by Claude Code
What this does
Fixes #341. The nightly Regression Tests workflow's `Run Inference` step has been failing against the 2-dataset CI checkpoint with `_resolve_dataset_index`'s "missing selector" KeyError.
After #336 landed per-dataset Normalize/Unnormalize, multi-dataset checkpoints require the batch to carry either `dataset_index` or `dataset_repo_id`. The gRPC server and eval script already inject one; `create_dummy_observation` (used only by the `src/opentau/scripts/inference.py` smoke test) did not.
This PR pins the smoke-path batch to dataset row 0. Single-dataset checkpoints are unaffected: the `_resolve_dataset_index` fallback would have resolved to the same `zeros((1,))`.
Real deployment callers (`grpc/server.py`, `eval.py`) continue to read their own `dataset_repo_id` config field; those paths already worked, and this PR doesn't touch them.
🐛 Bug
How it was tested
- Added `test_create_dummy_observation_includes_dataset_index` in `tests/utils/test_utils_utils.py` to lock in the regression.
- Ran `tests/utils/test_utils_utils.py` and `tests/policies/test_normalize_per_dataset.py` locally; all pass.
- Could not run the `Run Inference` step end-to-end locally without a GPU and a trained checkpoint; the unit test exercises the only line that was wrong, and the failure mode (`_resolve_dataset_index` KeyError) is purely batch-content-driven, so a fresh CPU forward through a 2-dataset policy would hit the same guard.
How to checkout & try? (for the reviewer)
To re-verify the original failure scenario, the next nightly `Regression Tests` run on `main` (post-merge) should show the `Run Inference` step passing again. Alternatively, re-trigger via `workflow_dispatch` on `.github/workflows/regression_test.yml`.
Checklist