Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/opentau/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,11 @@ def create_dummy_observation(cfg, device, dtype=torch.bfloat16) -> dict:
"prompt": ["Pick up yellow lego block and put it in the bin"],
"img_is_pad": torch.zeros((1, cfg.num_cams), dtype=torch.bool, device=device),
"action_is_pad": torch.zeros((1, cfg.action_chunk), dtype=torch.bool, device=device),
# Pin per-sample normalization to dataset row 0 so checkpoints trained
# with per-dataset stats on >1 datasets don't trip
# `_resolve_dataset_index`'s "missing selector" guard. A no-op for
# single-dataset policies (the fallback would resolve to the same row).
"dataset_index": torch.zeros((1,), dtype=torch.long, device=device),
}


Expand Down
22 changes: 22 additions & 0 deletions tests/utils/test_utils_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from opentau.utils.hub import HubMixin
from opentau.utils.utils import (
capture_timestamp_utc,
create_dummy_observation,
encode_accelerator_state_dict,
format_big_number,
get_channel_first_image_shape,
Expand Down Expand Up @@ -390,3 +391,24 @@ def test_another_invalid_shape_raises_value_error():
def test_encode_accelerator_state_dict(obj, expected):
output = encode_accelerator_state_dict(obj)
assert output == expected, f"Expected {expected}, but got {output} for input"


def test_create_dummy_observation_includes_dataset_index():
    """Regression test for #341.

    The dummy observation built for smoke tests has to ship a
    `dataset_index` entry; without it, multi-dataset checkpoints crash
    inside `_resolve_dataset_index`.
    """

    @dataclass
    class _Cfg:
        # Minimal stand-in config for `create_dummy_observation`.
        resolution: tuple = (224, 224)
        num_cams: int = 2
        max_state_dim: int = 32
        action_chunk: int = 10

    cpu = torch.device("cpu")
    obs = create_dummy_observation(_Cfg(), device=cpu, dtype=torch.float32)

    # The key must exist before its tensor can be inspected.
    assert "dataset_index" in obs

    # Expect a single int64 zero: one sample, pinned to dataset row 0.
    dataset_index = obs["dataset_index"]
    assert dataset_index.shape == (1,)
    assert dataset_index.dtype == torch.long
    assert torch.equal(dataset_index, torch.zeros(1, dtype=torch.long))
Loading