fix(scripts): inject dataset_index in inference smoke obs #342

Merged
shuheng-liu merged 1 commit into main from claude/laughing-aryabhata-cc41f5
May 27, 2026

@shuheng-liu
Member

What this does

Fixes #341. The nightly Regression Tests workflow's Run Inference step has been failing against the 2-dataset CI checkpoint with:

KeyError: 'Per-dataset normalization with 2 datasets requires either
`dataset_index` (LongTensor of shape (B,)) or `dataset_repo_id` (str or
list[str] of length B) in the batch.'

After #336 landed per-dataset Normalize/Unnormalize, multi-dataset checkpoints require the batch to carry either dataset_index or dataset_repo_id. The gRPC server and eval script already inject one; create_dummy_observation (used only by the smoke test in src/opentau/scripts/inference.py) did not.

This PR pins the smoke-path batch to dataset row 0:

  • Single-dataset checkpoints: no behavior change, because the _resolve_dataset_index fallback would have resolved to the same zeros((1,)).
  • Multi-dataset checkpoints: the smoke test runs with the first dataset's stats, which is what we want for a "just exercise the forward pass" benchmark.

Real deployment callers (grpc/server.py, eval.py) continue to read their own dataset_repo_id config field. Those paths already worked, and this PR doesn't touch them.

🐛 Bug

How it was tested

  • Added test_create_dummy_observation_includes_dataset_index in tests/utils/test_utils_utils.py to lock in the regression.
  • Re-ran tests/utils/test_utils_utils.py and tests/policies/test_normalize_per_dataset.py locally — all pass.
  • Cannot reproduce the CI Run Inference step end-to-end locally without a GPU and a trained checkpoint. The unit test exercises the only line that was wrong. The failure mode (the _resolve_dataset_index KeyError) depends only on batch contents, so a fresh CPU forward pass through a 2-dataset policy would hit the same guard.

How to check out & try? (for the reviewer)

pytest -sx tests/utils/test_utils_utils.py::test_create_dummy_observation_includes_dataset_index

To re-verify the original failure scenario, the next nightly Regression Tests run on main (post-merge) should show the Run Inference step passing again. Alternatively, re-trigger via workflow_dispatch on .github/workflows/regression_test.yml.

Checklist

  • I have added Google-style docstrings to important functions and ensured function parameters are typed.
  • My PR includes policy-related changes.
    • If the above is checked: I have run the GPU pytests (pytest -m "gpu") and regression tests.

`create_dummy_observation` was missing the `dataset_index` (or
`dataset_repo_id`) selector that per-dataset Normalize/Unnormalize
requires for >1-dataset checkpoints. The nightly regression's
`Run Inference` step against a 2-dataset CI checkpoint tripped
`_resolve_dataset_index`'s "missing selector" KeyError on the first
`sample_actions` call. Pin to row 0 so multi-dataset checkpoints
take the first dataset's stats in the smoke path; single-dataset
checkpoints are unaffected (the zero fallback would resolve to the
same value).
@shuheng-liu shuheng-liu self-assigned this May 27, 2026
@shuheng-liu shuheng-liu added bug Something isn't working and removed 🐛 Bug labels May 27, 2026
@shuheng-liu shuheng-liu marked this pull request as ready for review May 27, 2026 18:12
Member Author

@shuheng-liu shuheng-liu left a comment


Review

Verdict: LGTM, minor non-blocking suggestion.

Summary

Minimal, well-scoped fix for #341. After #336 introduced per-dataset Normalize/Unnormalize, _resolve_dataset_index raises KeyError when a multi-dataset checkpoint sees a batch with neither dataset_index nor dataset_repo_id. scripts/inference.py's smoke path was the only inference caller missing the injection — grpc/server.py and eval.py already inject via their dataset_repo_id config field.

Correctness — verified

Cross-checked the "no behavior change for single-dataset" claim against src/opentau/policies/pretrained.py::_resolve_dataset_index:

  • Old single-dataset path (no key in batch, _stacked_num_datasets() <= 1): returns torch.zeros(batch_size or 1, dtype=torch.long, device=self._model_device()).
  • New explicit path ("dataset_index" in batch): idx.to(dtype=torch.long, device=self._model_device()).

Both produce a (1,) torch.long tensor on the policy's compute device — the claim holds. The "pin multi-dataset smoke runs to row 0" semantics is a reasonable convention for a "just exercise the forward pass" benchmark.

Also consistent with the existing pattern in src/opentau/scripts/export_to_onnx.py, which pre-injects torch.zeros(bsize, dtype=torch.long, ...) for the same reason (a traced wrapper bypassing _resolve_dataset_index).
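The selector logic discussed here can be sketched roughly as follows. This is a simplified model, not the code in src/opentau/policies/pretrained.py: it uses plain Python lists in place of LongTensors so it runs without PyTorch, and the function signature, the `repo_ids` argument, the branch order, and the broadcasting of a single repo id string are all assumptions made for illustration.

```python
from typing import Any, Dict, List, Optional, Sequence


def resolve_dataset_index(
    batch: Dict[str, Any],
    repo_ids: Sequence[str],
    batch_size: Optional[int] = None,
) -> List[int]:
    """Return one dataset row per batch element (list stand-in for a LongTensor)."""
    # Explicit path: the caller already supplied row indices.
    if "dataset_index" in batch:
        return [int(i) for i in batch["dataset_index"]]

    # Repo-id path: map each repo id to its row in the stacked stats.
    # Broadcasting a single string to the batch size is an assumption.
    if "dataset_repo_id" in batch:
        ids = batch["dataset_repo_id"]
        if isinstance(ids, str):
            ids = [ids] * (batch_size or 1)
        return [repo_ids.index(r) for r in ids]

    # Single-dataset fallback: row 0 is the only possible answer.
    if len(repo_ids) <= 1:
        return [0] * (batch_size or 1)

    # Multi-dataset checkpoint with no selector: refuse to guess.
    raise KeyError(
        f"Per-dataset normalization with {len(repo_ids)} datasets requires "
        "`dataset_index` or `dataset_repo_id` in the batch."
    )


two_datasets = ["org/dataset_a", "org/dataset_b"]  # hypothetical repo ids

# Pre-fix smoke batch against a 2-dataset checkpoint: the guard trips.
try:
    resolve_dataset_index({}, two_datasets)
except KeyError as exc:
    print("guard tripped:", exc)

# Post-fix smoke batch: pinned to row 0.
print(resolve_dataset_index({"dataset_index": [0]}, two_datasets))  # [0]

# Single-dataset checkpoint with no key: the fallback gives the same row.
print(resolve_dataset_index({}, ["org/dataset_a"]))  # [0]
```

In this model the last two calls return the same value, which mirrors the "no behavior change for single-dataset" claim.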

Test

The regression test correctly:

  • Uses a minimal _Cfg dataclass with only the attributes the function reads.
  • Asserts presence + shape + dtype + value of dataset_index.
  • Stays CPU/float32 — no GPU dependency.

The acknowledged limitation (no end-to-end through a 2-dataset policy) is the right trade-off: it would require a heavy fixture, and the unit test locks in the exact line that was wrong.

Style / conventions

  • The 4-line comment is slightly long for a 1-line addition, but the WHY is genuinely non-obvious (_resolve_dataset_index's "missing selector" guard + the single-dataset no-op invariant), so it earns its keep. Leave as-is.
  • Dtype torch.long is correct — _resolve_dataset_index re-casts to long anyway, but providing it from the source skips that cast.

Minor suggestion (non-blocking)

create_dummy_observation's docstring Returns: block lists "camera observations, state, prompt, and padding flags" — consider tacking on "and dataset index" so the new key is mentioned. Trivial, fine to skip.

Risk

Low. Surgical (5 + 22 lines, single dict-key addition), no public API change, all CI green (CPU tests, pre-commit, model parallelism, checklist). The nightly Regression Tests workflow on main post-merge should confirm the Run Inference step recovers.


Generated by Claude Code

@WilliamYue37 WilliamYue37 self-requested a review May 27, 2026 18:19
@shuheng-liu shuheng-liu merged commit dd0942e into main May 27, 2026
19 checks passed
@shuheng-liu shuheng-liu deleted the claude/laughing-aryabhata-cc41f5 branch May 27, 2026 18:27
Labels

bug Something isn't working

Development

Successfully merging this pull request may close these issues.

Nightly regression: Run Inference fails with per-dataset normalization KeyError

2 participants