Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion src/opentau/datasets/lerobot_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1704,7 +1704,20 @@ def save_episode(self, episode_data: dict | None = None) -> None:
# (standardization, dataset mixing) never encounters missing keys.
for key in deferred:
shape = self.features[key]["shape"]
c = shape[0] if len(shape) >= 3 else 3
names = self.features[key].get("names") or []
# Locate channel axis; LeRobot v2.1's default image convention is
# (H, W, C) with names ["height", "width", "channel"], but some
# callers declare (C, H, W). Look it up by name, falling back to
# the CHW convention that aggregate_stats's (3, 1, 1) assertion
# ultimately expects.
if "channel" in names:
c = shape[names.index("channel")]
elif "channels" in names:
c = shape[names.index("channels")]
elif len(shape) >= 3:
c = shape[0]
else:
c = 3
ep_stats[key] = {
"min": np.zeros((c, 1, 1), dtype=np.float64),
"max": np.ones((c, 1, 1), dtype=np.float64),
Expand Down
44 changes: 44 additions & 0 deletions tests/datasets/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,50 @@ def test_deferred_video_attach_video(tmp_path, empty_lerobot_dataset_factory):
assert result_path == expected_path


def test_deferred_video_multi_episode_hwc_convention(tmp_path, empty_lerobot_dataset_factory):
    """Two consecutive episode saves must succeed for an (H, W, C) deferred video key.

    The deferred video feature is declared with shape (96, 128, 3) and names
    ["height", "width", "channel"]. The placeholder statistics written for
    that key must be sized by the channel axis, i.e. (3, 1, 1), rather than
    by the leading height axis, which would give (96, 1, 1). This guards
    against a regression where the placeholder fallback treated shape[0] as
    the channel count and the second save_episode call then failed.
    """
    video_key = "observation.images.top"
    video_feature = {
        "dtype": "video",
        "shape": (96, 128, 3),
        "names": ["height", "width", "channel"],
        "info": None,
    }
    features = {
        "state": {"dtype": "float32", "shape": (2,), "names": None},
        video_key: video_feature,
    }
    dataset = empty_lerobot_dataset_factory(
        root=tmp_path / "hwc_test",
        features=features,
        deferred_video_keys={video_key},
    )

    # Record two episodes of three frames each; only the non-deferred
    # "state" feature is supplied, the video key is attached later.
    for episode_idx in range(2):
        for frame_idx in range(3):
            state = np.array(
                [float(frame_idx), float(frame_idx + episode_idx)],
                dtype=np.float32,
            )
            dataset.add_frame({"state": state, "task": "Dummy task"})
        dataset.save_episode()

    assert dataset.meta.total_episodes == 2

    # Every placeholder statistic for the deferred key must be shaped by the
    # channel axis (3), not the height axis (96).
    placeholder_stats = dataset.meta.stats[video_key]
    for stat_name in ("min", "max", "mean", "std"):
        assert placeholder_stats[stat_name].shape == (3, 1, 1)


def test_deferred_video_invalid_key(tmp_path, empty_lerobot_dataset_factory):
"""Creating a dataset with invalid deferred video keys should raise ValueError."""
features = {
Expand Down
Loading