diff --git a/src/opentau/datasets/lerobot_dataset.py b/src/opentau/datasets/lerobot_dataset.py index 846206d9..c8140845 100644 --- a/src/opentau/datasets/lerobot_dataset.py +++ b/src/opentau/datasets/lerobot_dataset.py @@ -1704,7 +1704,20 @@ def save_episode(self, episode_data: dict | None = None) -> None: # (standardization, dataset mixing) never encounters missing keys. for key in deferred: shape = self.features[key]["shape"] - c = shape[0] if len(shape) >= 3 else 3 + names = self.features[key].get("names") or [] + # Locate channel axis; LeRobot v2.1's default image convention is + # (H, W, C) with names ["height", "width", "channel"], but some + # callers declare (C, H, W). Look it up by name, falling back to + # the CHW convention that aggregate_stats's (3, 1, 1) assertion + # ultimately expects. + if "channel" in names: + c = shape[names.index("channel")] + elif "channels" in names: + c = shape[names.index("channels")] + elif len(shape) >= 3: + c = shape[0] + else: + c = 3 ep_stats[key] = { "min": np.zeros((c, 1, 1), dtype=np.float64), "max": np.ones((c, 1, 1), dtype=np.float64), diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 61d1db25..52ad7ff5 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -582,6 +582,50 @@ def test_deferred_video_attach_video(tmp_path, empty_lerobot_dataset_factory): assert result_path == expected_path +def test_deferred_video_multi_episode_hwc_convention(tmp_path, empty_lerobot_dataset_factory): + """Saving two episodes in sequence with deferred video keys declared in + LeRobot v2.1's (H, W, C) convention must not crash. + + Regression test for a bug in the placeholder-stats fallback: it assumed + shape[0] was the channel axis, which is true for (C, H, W) but produces + shape (H, 1, 1) for (H, W, C) — aggregate_stats then rejected it on the + second save_episode because its shape-check requires (3, 1, 1) for + features with 'image' in the key. + """ + features = { + "state": {"dtype": "float32", "shape": (2,), "names": None}, + "observation.images.top": { + "dtype": "video", + "shape": (96, 128, 3), + "names": ["height", "width", "channel"], + "info": None, + }, + } + dataset = empty_lerobot_dataset_factory( + root=tmp_path / "hwc_test", + features=features, + deferred_video_keys={"observation.images.top"}, + ) + for ep in range(2): + for i in range(3): + dataset.add_frame( + { + "state": np.array([float(i), float(i + ep)], dtype=np.float32), + "task": "Dummy task", + } + ) + dataset.save_episode() + + assert dataset.meta.total_episodes == 2 + # Placeholder stats for the deferred key must use the channel axis (3), + # not the height axis (96). + stats = dataset.meta.stats["observation.images.top"] + assert stats["min"].shape == (3, 1, 1) + assert stats["max"].shape == (3, 1, 1) + assert stats["mean"].shape == (3, 1, 1) + assert stats["std"].shape == (3, 1, 1) + + def test_deferred_video_invalid_key(tmp_path, empty_lerobot_dataset_factory): """Creating a dataset with invalid deferred video keys should raise ValueError.""" features = {