diff --git a/src/opentau/utils/utils.py b/src/opentau/utils/utils.py index a8c96cfa..f7e93f13 100644 --- a/src/opentau/utils/utils.py +++ b/src/opentau/utils/utils.py @@ -402,6 +402,11 @@ def create_dummy_observation(cfg, device, dtype=torch.bfloat16) -> dict: "prompt": ["Pick up yellow lego block and put it in the bin"], "img_is_pad": torch.zeros((1, cfg.num_cams), dtype=torch.bool, device=device), "action_is_pad": torch.zeros((1, cfg.action_chunk), dtype=torch.bool, device=device), + # Pin per-sample normalization to dataset row 0 so checkpoints trained + # with per-dataset stats on >1 datasets don't trip + # `_resolve_dataset_index`'s "missing selector" guard. A no-op for + # single-dataset policies (the fallback would resolve to the same). + "dataset_index": torch.zeros((1,), dtype=torch.long, device=device), } diff --git a/tests/utils/test_utils_utils.py b/tests/utils/test_utils_utils.py index 88bf72e3..5b131be8 100644 --- a/tests/utils/test_utils_utils.py +++ b/tests/utils/test_utils_utils.py @@ -25,6 +25,7 @@ from opentau.utils.hub import HubMixin from opentau.utils.utils import ( capture_timestamp_utc, + create_dummy_observation, encode_accelerator_state_dict, format_big_number, get_channel_first_image_shape, @@ -390,3 +391,24 @@ def test_another_invalid_shape_raises_value_error(): def test_encode_accelerator_state_dict(obj, expected): output = encode_accelerator_state_dict(obj) assert output == expected, f"Expected {expected}, but got {output} for input" + + +def test_create_dummy_observation_includes_dataset_index(): + """Regression for #341: the smoke-test observation must carry a + `dataset_index` so multi-dataset checkpoints don't crash inside + `_resolve_dataset_index`. + """ + + @dataclass + class _Cfg: + resolution: tuple = (224, 224) + num_cams: int = 2 + max_state_dim: int = 32 + action_chunk: int = 10 + + obs = create_dummy_observation(_Cfg(), device=torch.device("cpu"), dtype=torch.float32) + + assert "dataset_index" in obs + assert obs["dataset_index"].shape == (1,) + assert obs["dataset_index"].dtype == torch.long + assert torch.equal(obs["dataset_index"], torch.zeros(1, dtype=torch.long))