From a9555dac287d86453acdfacda8431f2eeb1b7fe1 Mon Sep 17 00:00:00 2001 From: Shuheng Liu Date: Wed, 27 May 2026 10:44:32 -0700 Subject: [PATCH] fix(scripts): inject dataset_index in inference smoke obs (#341) `create_dummy_observation` was missing the `dataset_index` (or `dataset_repo_id`) selector that per-dataset Normalize/Unnormalize requires for >1-dataset checkpoints. The nightly regression's `Run Inference` step against a 2-dataset CI checkpoint tripped `_resolve_dataset_index`'s "missing selector" KeyError on the first `sample_actions` call. Pin to row 0 so multi-dataset checkpoints take the first dataset's stats in the smoke path; single-dataset checkpoints are unaffected (the zero fallback would resolve to the same value). --- src/opentau/utils/utils.py | 5 +++++ tests/utils/test_utils_utils.py | 22 ++++++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/src/opentau/utils/utils.py b/src/opentau/utils/utils.py index a8c96cfa..f7e93f13 100644 --- a/src/opentau/utils/utils.py +++ b/src/opentau/utils/utils.py @@ -402,6 +402,11 @@ def create_dummy_observation(cfg, device, dtype=torch.bfloat16) -> dict: "prompt": ["Pick up yellow lego block and put it in the bin"], "img_is_pad": torch.zeros((1, cfg.num_cams), dtype=torch.bool, device=device), "action_is_pad": torch.zeros((1, cfg.action_chunk), dtype=torch.bool, device=device), + # Pin per-sample normalization to dataset row 0 so checkpoints trained + # with per-dataset stats on >1 datasets don't trip + # `_resolve_dataset_index`'s "missing selector" guard. A no-op for + # single-dataset policies (the fallback would resolve to the same). + "dataset_index": torch.zeros((1,), dtype=torch.long, device=device), } diff --git a/tests/utils/test_utils_utils.py b/tests/utils/test_utils_utils.py index 88bf72e3..5b131be8 100644 --- a/tests/utils/test_utils_utils.py +++ b/tests/utils/test_utils_utils.py @@ -25,6 +25,7 @@ from opentau.utils.hub import HubMixin from opentau.utils.utils import ( capture_timestamp_utc, + create_dummy_observation, encode_accelerator_state_dict, format_big_number, get_channel_first_image_shape, @@ -390,3 +391,24 @@ def test_another_invalid_shape_raises_value_error(): def test_encode_accelerator_state_dict(obj, expected): output = encode_accelerator_state_dict(obj) assert output == expected, f"Expected {expected}, but got {output} for input" + + +def test_create_dummy_observation_includes_dataset_index(): + """Regression for #341: the smoke-test observation must carry a + `dataset_index` so multi-dataset checkpoints don't crash inside + `_resolve_dataset_index`. + """ + + @dataclass + class _Cfg: + resolution: tuple = (224, 224) + num_cams: int = 2 + max_state_dim: int = 32 + action_chunk: int = 10 + + obs = create_dummy_observation(_Cfg(), device=torch.device("cpu"), dtype=torch.float32) + + assert "dataset_index" in obs + assert obs["dataset_index"].shape == (1,) + assert obs["dataset_index"].dtype == torch.long + assert torch.equal(obs["dataset_index"], torch.zeros(1, dtype=torch.long))