From daefd68eadeb858cce725a4fa95088fa951aac0b Mon Sep 17 00:00:00 2001 From: Shuheng Liu Date: Thu, 26 Feb 2026 16:39:08 -0800 Subject: [PATCH 1/8] feat: add script to segmentize lerobot dataset --- .../scripts/segment_lerobot_dataset.py | 271 ++++++++++++++++++ .../datasets/test_segment_lerobot_dataset.py | 186 ++++++++++++ 2 files changed, 457 insertions(+) create mode 100644 src/opentau/scripts/segment_lerobot_dataset.py create mode 100644 tests/datasets/test_segment_lerobot_dataset.py diff --git a/src/opentau/scripts/segment_lerobot_dataset.py b/src/opentau/scripts/segment_lerobot_dataset.py new file mode 100644 index 00000000..187c7f4c --- /dev/null +++ b/src/opentau/scripts/segment_lerobot_dataset.py @@ -0,0 +1,271 @@ +#!/usr/bin/env python +# +# Copyright 2026 Tensor Auto Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Create a segmented LeRobot v2.1 dataset from a source episode. + +This script builds a brand-new dataset where each output episode corresponds to one +`[start, end)` frame segment from a selected source episode. + +Accepted input formats: LeRobot v2.0 and v2.1. +Output format: always LeRobot v2.1. + +Example: + python segment_lerobot_dataset.py ./input_dataset ./output_dataset \ + --episode-id 0 \ + --segment 0:100 \ + --segment 120:220 +""" + +import argparse +import math +import shutil +from copy import deepcopy +from pathlib import Path +from typing import Any, cast + +import numpy as np +import pyarrow as pa +import pyarrow.parquet as pq + +from opentau.datasets.compute_stats import compute_episode_stats +from opentau.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata +from opentau.datasets.utils import ( + EPISODES_PATH, + EPISODES_STATS_PATH, + TASKS_PATH, + append_jsonlines, + write_episode_stats, + write_json, +) + + +def _parse_segment(text: str) -> tuple[int, int]: + parts = text.split(":") + if len(parts) != 2: + raise argparse.ArgumentTypeError(f"Invalid segment '{text}'. Expected START:END with integer values.") + try: + start = int(parts[0]) + end = int(parts[1]) + except ValueError as exc: + raise argparse.ArgumentTypeError( + f"Invalid segment '{text}'. START and END must be integers." + ) from exc + if start < 0 or end < 0: + raise argparse.ArgumentTypeError(f"Invalid segment '{text}'. START and END must be non-negative.") + if end <= start: + raise argparse.ArgumentTypeError( + f"Invalid segment '{text}'. END must be strictly greater than START." + ) + return start, end + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Create a segmented LeRobot v2.1 dataset from one source episode.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("input_root", type=Path, help="Path to source LeRobot dataset root.") + parser.add_argument("output_root", type=Path, help="Path to output dataset root (must not exist).") + parser.add_argument( + "--episode-id", + type=int, + default=0, + help="Source episode index to segment.", + ) + parser.add_argument( + "--segment", + type=_parse_segment, + action="append", + required=True, + help="Segment as START:END in frame index domain. Repeat this argument for multiple segments.", + ) + return parser.parse_args() + + +def _to_numpy_for_stats(column: pa.ChunkedArray) -> np.ndarray: + # For list/fixed-size-list numeric columns, `to_pylist` keeps the nested shape, + # and `np.asarray` reconstructs the expected dense ndarray. + return np.asarray(column.to_pylist()) + + +def segment_dataset( + input_root: Path, + output_root: Path, + episode_id: int, + segments: list[tuple[int, int]], +) -> None: + input_root = input_root.resolve() + output_root = output_root.resolve() + + if not input_root.is_dir(): + raise ValueError(f"Input dataset root does not exist: {input_root}") + if output_root.exists(): + raise ValueError(f"Output dataset root already exists: {output_root}") + if not segments: + raise ValueError("At least one segment must be provided.") + + source_meta = LeRobotDatasetMetadata(repo_id=input_root.name, root=input_root) + source_version = str(source_meta._version) + if not (source_version.startswith("2.0") or source_version.startswith("2.1")): + raise ValueError( + "Only LeRobot dataset format v2.0 and v2.1 are supported as input by this script. " + f"Found codebase_version={source_version}." + ) + if episode_id not in source_meta.episodes: + raise ValueError(f"Episode {episode_id} not found in source dataset.") + + source_episode = source_meta.episodes[episode_id] + source_length = int(source_episode["length"]) + for start, end in segments: + if end > source_length: + raise ValueError( + f"Segment ({start}, {end}) is out of bounds for source episode length {source_length}." + ) + + source_parquet_path = input_root / source_meta.get_data_file_path(episode_id) + if not source_parquet_path.is_file(): + raise ValueError(f"Missing source parquet file for episode {episode_id}: {source_parquet_path}") + source_table = pq.read_table(source_parquet_path) + if source_table.num_rows != source_length: + raise ValueError( + f"Source metadata length ({source_length}) does not match parquet row count ({source_table.num_rows})." + ) + + output_root.mkdir(parents=True, exist_ok=False) + chunks_size = int(source_meta.chunks_size) + global_index_offset = 0 + total_frames = 0 + output_episodes: list[dict] = [] + + # Write tasks as-is so existing task_index values remain valid. + for task_index, task in sorted(source_meta.tasks.items(), key=lambda x: x[0]): + append_jsonlines({"task_index": task_index, "task": task}, output_root / TASKS_PATH) + + source_episode_stats = source_meta.episodes_stats.get(cast(Any, episode_id), {}) + visual_keys = [k for k, ft in source_meta.features.items() if ft["dtype"] in ["image", "video"]] + + for output_episode_index, (start, end) in enumerate(segments): + seg_len = end - start + seg_table = source_table.slice(start, seg_len) + + replacement_arrays = { + "episode_index": pa.array(np.full(seg_len, output_episode_index, dtype=np.int64)), + "frame_index": pa.array(np.arange(seg_len, dtype=np.int64)), + "index": pa.array(np.arange(global_index_offset, global_index_offset + seg_len, dtype=np.int64)), + } + for key, arr in replacement_arrays.items(): + col_idx = seg_table.schema.get_field_index(key) + if col_idx >= 0: + seg_table = seg_table.set_column(col_idx, key, arr) + + episode_chunk = output_episode_index // chunks_size + output_parquet_path = output_root / source_meta.data_path.format( + episode_chunk=episode_chunk, + episode_index=output_episode_index, + ) + output_parquet_path.parent.mkdir(parents=True, exist_ok=True) + pq.write_table(seg_table, output_parquet_path) + + # Build tasks list for this segment from task_index values present in rows. + task_indices = seg_table.column("task_index").to_pylist() + seen_task_indices = set() + episode_tasks = [] + for task_index in task_indices: + if task_index in seen_task_indices: + continue + seen_task_indices.add(task_index) + episode_tasks.append(source_meta.tasks[int(task_index)]) + + output_episodes.append( + { + "episode_index": output_episode_index, + "tasks": episode_tasks, + "length": seg_len, + } + ) + + stats_features = { + key: feature + for key, feature in source_meta.features.items() + if feature["dtype"] not in ["image", "video", "string"] and key in seg_table.column_names + } + stats_data: dict[str, list[str] | np.ndarray] = { + key: _to_numpy_for_stats(seg_table.column(key)) for key in stats_features + } + episode_stats = compute_episode_stats(stats_data, stats_features) + + # Keep visual keys in stats for downstream compatibility. + for key in visual_keys: + if key in source_episode_stats: + source_key_stats = cast(dict[str, np.ndarray], source_episode_stats[key]) + copied = {metric: np.array(val, copy=True) for metric, val in source_key_stats.items()} + if "count" in copied: + copied["count"] = np.array([seg_len], dtype=np.int64) + episode_stats[key] = copied + + write_episode_stats(output_episode_index, episode_stats, output_root) + + global_index_offset += seg_len + total_frames += seg_len + + # Copy source video episode for each requested segment episode, if any. + video_path_template = source_meta.video_path + if source_meta.video_keys and video_path_template is None: + raise ValueError("Source dataset declares video keys but has no video_path template in metadata.") + video_path_template_str = cast(str, video_path_template) if video_path_template is not None else "" + + for video_key in source_meta.video_keys: + src_video_path = input_root / source_meta.get_video_file_path(episode_id, video_key) + if not src_video_path.is_file(): + raise ValueError(f"Missing source video for key '{video_key}': {src_video_path}") + for output_episode_index in range(len(segments)): + episode_chunk = output_episode_index // chunks_size + dst_video_path = output_root / video_path_template_str.format( + episode_chunk=episode_chunk, + video_key=video_key, + episode_index=output_episode_index, + ) + dst_video_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(src_video_path, dst_video_path) + + for episode in output_episodes: + append_jsonlines(episode, output_root / EPISODES_PATH) + + total_episodes = len(segments) + info = deepcopy(source_meta.info) + info["codebase_version"] = CODEBASE_VERSION + info["total_episodes"] = total_episodes + info["total_frames"] = total_frames + info["total_chunks"] = int(math.ceil(total_episodes / chunks_size)) if total_episodes > 0 else 0 + info["total_videos"] = total_episodes * len(source_meta.video_keys) + info["splits"] = {"train": f"0:{total_episodes}"} + write_json(info, output_root / "meta" / "info.json") + + # Ensure expected meta files exist and are explicit outputs. + _ = output_root / EPISODES_STATS_PATH + + +def main() -> None: + args = parse_args() + segment_dataset( + input_root=args.input_root, + output_root=args.output_root, + episode_id=args.episode_id, + segments=args.segment, + ) + + +if __name__ == "__main__": + main() diff --git a/tests/datasets/test_segment_lerobot_dataset.py b/tests/datasets/test_segment_lerobot_dataset.py new file mode 100644 index 00000000..0683d331 --- /dev/null +++ b/tests/datasets/test_segment_lerobot_dataset.py @@ -0,0 +1,186 @@ +#!/usr/bin/env python + +# Copyright 2026 Tensor Auto Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pyarrow.parquet as pq + +from opentau.datasets.lerobot_dataset import LeRobotDatasetMetadata +from opentau.datasets.utils import load_episodes, load_episodes_stats, load_info, write_json, write_stats +from opentau.scripts.segment_lerobot_dataset import segment_dataset + + +def test_segment_lerobot_v21_dataset(tmp_path, empty_lerobot_dataset_factory): + input_root = tmp_path / "source_dataset" + output_root = tmp_path / "segmented_dataset" + + features = { + "state": {"dtype": "float32", "shape": (1,), "names": None}, + "actions": {"dtype": "float32", "shape": (1,), "names": None}, + } + dataset = empty_lerobot_dataset_factory(root=input_root, features=features, use_videos=False) + + for i in range(10): + dataset.add_frame( + { + "state": np.array([float(i)], dtype=np.float32), + "actions": np.array([float(i) + 0.5], dtype=np.float32), + "task": "pick object", + } + ) + dataset.save_episode() + + segment_dataset( + input_root=input_root, + output_root=output_root, + episode_id=0, + segments=[(2, 5), (5, 10)], + ) + + info = load_info(output_root) + episodes = load_episodes(output_root) + episodes_stats = load_episodes_stats(output_root) + output_meta = LeRobotDatasetMetadata(repo_id=output_root.name, root=output_root) + + assert info["codebase_version"] == "v2.1" + assert info["total_episodes"] == 2 + assert info["total_frames"] == 8 + assert info["total_videos"] == 0 + assert info["splits"] == {"train": "0:2"} + + assert episodes[0]["length"] == 3 + assert episodes[1]["length"] == 5 + assert len(episodes_stats) == 2 + assert int(episodes_stats[0]["state"]["count"][0]) == 3 + assert int(episodes_stats[1]["state"]["count"][0]) == 5 + + ep0_table = pq.read_table(output_root / output_meta.get_data_file_path(0)) + ep1_table = pq.read_table(output_root / output_meta.get_data_file_path(1)) + ep0 = ep0_table.to_pydict() + ep1 = ep1_table.to_pydict() + + assert ep0["frame_index"] == [0, 1, 2] + assert ep0["episode_index"] == [0, 0, 0] + assert ep0["index"] == [0, 1, 2] + assert [float(x) for x in ep0["state"]] == [2.0, 3.0, 4.0] + + assert ep1["frame_index"] == [0, 1, 2, 3, 4] + assert ep1["episode_index"] == [1, 1, 1, 1, 1] + assert ep1["index"] == [3, 4, 5, 6, 7] + assert [float(x) for x in ep1["state"]] == [5.0, 6.0, 7.0, 8.0, 9.0] + + +def test_segment_lerobot_v20_input_outputs_v21(tmp_path, empty_lerobot_dataset_factory): + input_root = tmp_path / "source_v20_dataset" + output_root = tmp_path / "segmented_from_v20" + + features = { + "state": {"dtype": "float32", "shape": (1,), "names": None}, + "actions": {"dtype": "float32", "shape": (1,), "names": None}, + } + dataset = empty_lerobot_dataset_factory(root=input_root, features=features, use_videos=False) + for i in range(8): + dataset.add_frame( + { + "state": np.array([float(i)], dtype=np.float32), + "actions": np.array([float(i) + 1.0], dtype=np.float32), + "task": "stack blocks", + } + ) + dataset.save_episode() + + # Convert source metadata to a v2.0-style dataset: + # - set codebase_version to 2.0 + # - write legacy global stats.json used by v2.0 loaders + source_info = load_info(input_root) + source_info["codebase_version"] = "v2.0" + write_json(source_info, input_root / "meta" / "info.json") + write_stats(dataset.meta.stats, input_root) + + segment_dataset( + input_root=input_root, + output_root=output_root, + episode_id=0, + segments=[(0, 3), (3, 8)], + ) + + out_info = load_info(output_root) + out_episodes = load_episodes(output_root) + out_stats = load_episodes_stats(output_root) + + assert out_info["codebase_version"] == "v2.1" + assert out_info["total_episodes"] == 2 + assert out_info["total_frames"] == 8 + assert out_episodes[0]["length"] == 3 + assert out_episodes[1]["length"] == 5 + assert len(out_stats) == 2 + + +def test_segment_lerobot_non_consecutive_and_overlapping_ranges(tmp_path, empty_lerobot_dataset_factory): + input_root = tmp_path / "source_edge_case_dataset" + output_root = tmp_path / "segmented_edge_case_dataset" + + features = { + "state": {"dtype": "float32", "shape": (1,), "names": None}, + "actions": {"dtype": "float32", "shape": (1,), "names": None}, + } + dataset = empty_lerobot_dataset_factory(root=input_root, features=features, use_videos=False) + for i in range(25): + dataset.add_frame( + { + "state": np.array([float(i)], dtype=np.float32), + "actions": np.array([float(i) + 10.0], dtype=np.float32), + "task": "edge-case task", + } + ) + dataset.save_episode() + + # Non-consecutive + overlapping segments: + # - (0, 10) and (5, 15) overlap on [5..9] + # - (18, 23) is non-consecutive with both + segment_dataset( + input_root=input_root, + output_root=output_root, + episode_id=0, + segments=[(0, 10), (18, 23), (5, 15)], + ) + + info = load_info(output_root) + episodes = load_episodes(output_root) + output_meta = LeRobotDatasetMetadata(repo_id=output_root.name, root=output_root) + + assert info["codebase_version"] == "v2.1" + assert info["total_episodes"] == 3 + assert info["total_frames"] == 25 + assert [episodes[i]["length"] for i in sorted(episodes)] == [10, 5, 10] + + ep0 = pq.read_table(output_root / output_meta.get_data_file_path(0)).to_pydict() + ep1 = pq.read_table(output_root / output_meta.get_data_file_path(1)).to_pydict() + ep2 = pq.read_table(output_root / output_meta.get_data_file_path(2)).to_pydict() + + # Local frame indexing resets per output episode. + assert ep0["frame_index"] == list(range(10)) + assert ep1["frame_index"] == list(range(5)) + assert ep2["frame_index"] == list(range(10)) + + # Global index remains contiguous across output episodes. + assert ep0["index"] == list(range(0, 10)) + assert ep1["index"] == list(range(10, 15)) + assert ep2["index"] == list(range(15, 25)) + + # Data slices match requested source windows. + assert [float(x) for x in ep0["state"]] == [float(i) for i in range(0, 10)] + assert [float(x) for x in ep1["state"]] == [float(i) for i in range(18, 23)] + assert [float(x) for x in ep2["state"]] == [float(i) for i in range(5, 15)] From b05378ce194502e5893218a4a2e03c92aadc3da6 Mon Sep 17 00:00:00 2001 From: Shuheng Liu Date: Thu, 26 Feb 2026 16:46:32 -0800 Subject: [PATCH 2/8] docs: add Google-style docs to episode segment utils --- .../scripts/segment_lerobot_dataset.py | 38 +++++++++++++++++++ .../datasets/test_segment_lerobot_dataset.py | 29 ++++++++++++-- 2 files changed, 64 insertions(+), 3 deletions(-) diff --git a/src/opentau/scripts/segment_lerobot_dataset.py b/src/opentau/scripts/segment_lerobot_dataset.py index 187c7f4c..5187a0bb 100644 --- a/src/opentau/scripts/segment_lerobot_dataset.py +++ b/src/opentau/scripts/segment_lerobot_dataset.py @@ -52,6 +52,17 @@ def _parse_segment(text: str) -> tuple[int, int]: + """Parse one CLI segment token. + + Args: + text: Segment string formatted as ``START:END``. + + Returns: + A 2-tuple ``(start, end)`` where ``start >= 0`` and ``end > start``. + + Raises: + argparse.ArgumentTypeError: If the value is malformed or out of range. + """ parts = text.split(":") if len(parts) != 2: raise argparse.ArgumentTypeError(f"Invalid segment '{text}'. Expected START:END with integer values.") @@ -72,6 +83,12 @@ def _parse_segment(text: str) -> tuple[int, int]: def parse_args() -> argparse.Namespace: + """Parse command-line arguments for dataset segmentation. + + Returns: + Parsed CLI namespace containing input/output roots, source episode id, + and the list of frame-range segments. + """ parser = argparse.ArgumentParser( description="Create a segmented LeRobot v2.1 dataset from one source episode.", formatter_class=argparse.ArgumentDefaultsHelpFormatter, @@ -95,6 +112,14 @@ def parse_args() -> argparse.Namespace: def _to_numpy_for_stats(column: pa.ChunkedArray) -> np.ndarray: + """Convert an Arrow chunked column to a NumPy array for stats. + + Args: + column: Arrow chunked column extracted from an episode parquet table. + + Returns: + NumPy representation of the column values. + """ # For list/fixed-size-list numeric columns, `to_pylist` keeps the nested shape, # and `np.asarray` reconstructs the expected dense ndarray. return np.asarray(column.to_pylist()) @@ -106,6 +131,18 @@ def segment_dataset( episode_id: int, segments: list[tuple[int, int]], ) -> None: + """Create a new segmented dataset from a source episode. + + Args: + input_root: Source LeRobot dataset directory (v2.0 or v2.1). + output_root: Destination directory for the new dataset. Must not exist. + episode_id: Source episode index to slice. + segments: List of ``(start, end)`` frame ranges in ``[start, end)`` form. + + Raises: + ValueError: If inputs are invalid, source files are missing, or segment + ranges are out of bounds. + """ input_root = input_root.resolve() output_root = output_root.resolve() @@ -258,6 +295,7 @@ def segment_dataset( def main() -> None: + """CLI entry point.""" args = parse_args() segment_dataset( input_root=args.input_root, diff --git a/tests/datasets/test_segment_lerobot_dataset.py b/tests/datasets/test_segment_lerobot_dataset.py index 0683d331..2e423bba 100644 --- a/tests/datasets/test_segment_lerobot_dataset.py +++ b/tests/datasets/test_segment_lerobot_dataset.py @@ -14,6 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from pathlib import Path +from typing import Any + import numpy as np import pyarrow.parquet as pq @@ -22,7 +25,13 @@ from opentau.scripts.segment_lerobot_dataset import segment_dataset -def test_segment_lerobot_v21_dataset(tmp_path, empty_lerobot_dataset_factory): +def test_segment_lerobot_v21_dataset(tmp_path: Path, empty_lerobot_dataset_factory: Any) -> None: + """Validate baseline segmentation behavior for v2.1 input. + + Args: + tmp_path: Temporary directory fixture provided by pytest. + empty_lerobot_dataset_factory: Fixture that creates a writable dataset. + """ input_root = tmp_path / "source_dataset" output_root = tmp_path / "segmented_dataset" @@ -82,7 +91,13 @@ def test_segment_lerobot_v21_dataset(tmp_path, empty_lerobot_dataset_factory): assert [float(x) for x in ep1["state"]] == [5.0, 6.0, 7.0, 8.0, 9.0] -def test_segment_lerobot_v20_input_outputs_v21(tmp_path, empty_lerobot_dataset_factory): +def test_segment_lerobot_v20_input_outputs_v21(tmp_path: Path, empty_lerobot_dataset_factory: Any) -> None: + """Ensure v2.0 input is accepted and output stays v2.1. + + Args: + tmp_path: Temporary directory fixture provided by pytest. + empty_lerobot_dataset_factory: Fixture that creates a writable dataset. + """ input_root = tmp_path / "source_v20_dataset" output_root = tmp_path / "segmented_from_v20" @@ -128,7 +143,15 @@ def test_segment_lerobot_v20_input_outputs_v21(tmp_path, empty_lerobot_dataset_f assert len(out_stats) == 2 -def test_segment_lerobot_non_consecutive_and_overlapping_ranges(tmp_path, empty_lerobot_dataset_factory): +def test_segment_lerobot_non_consecutive_and_overlapping_ranges( + tmp_path: Path, empty_lerobot_dataset_factory: Any +) -> None: + """Cover non-consecutive and overlapping segment ranges. + + Args: + tmp_path: Temporary directory fixture provided by pytest. + empty_lerobot_dataset_factory: Fixture that creates a writable dataset. + """ input_root = tmp_path / "source_edge_case_dataset" output_root = tmp_path / "segmented_edge_case_dataset" From 6933b5dfe43470f6ddfe19f7fb970bbe1c5ec3be Mon Sep 17 00:00:00 2001 From: Shuheng Liu Date: Thu, 26 Feb 2026 17:10:53 -0800 Subject: [PATCH 3/8] fix: properly segment videos and images --- .../scripts/segment_lerobot_dataset.py | 157 +++++++++++++++++- .../datasets/test_segment_lerobot_dataset.py | 107 ++++++++++++ 2 files changed, 258 insertions(+), 6 deletions(-) diff --git a/src/opentau/scripts/segment_lerobot_dataset.py b/src/opentau/scripts/segment_lerobot_dataset.py index 5187a0bb..08a8e04c 100644 --- a/src/opentau/scripts/segment_lerobot_dataset.py +++ b/src/opentau/scripts/segment_lerobot_dataset.py @@ -31,6 +31,7 @@ import argparse import math import shutil +import subprocess from copy import deepcopy from pathlib import Path from typing import Any, cast @@ -42,8 +43,8 @@ from opentau.datasets.compute_stats import compute_episode_stats from opentau.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata from opentau.datasets.utils import ( + DEFAULT_IMAGE_PATH, EPISODES_PATH, - EPISODES_STATS_PATH, TASKS_PATH, append_jsonlines, write_episode_stats, @@ -125,6 +126,127 @@ def _to_numpy_for_stats(column: pa.ChunkedArray) -> np.ndarray: return np.asarray(column.to_pylist()) +def _trim_video_segment(src_video_path: Path, dst_video_path: Path, start_frame: int, end_frame: int) -> None: + """Trim a source video to the requested frame interval. + + Args: + src_video_path: Source episode video path. + dst_video_path: Output path for the trimmed segment video. + start_frame: Inclusive start frame index. + end_frame: Exclusive end frame index. + + Raises: + RuntimeError: If ffmpeg is unavailable or the trim command fails. + """ + if shutil.which("ffmpeg") is None: + raise RuntimeError("ffmpeg is required to trim segmented videos but was not found in PATH.") + + # Trim by exact frame indices and reset timeline to start at zero. + vf = f"trim=start_frame={start_frame}:end_frame={end_frame},setpts=PTS-STARTPTS" + cmd = [ + "ffmpeg", + "-hide_banner", + "-loglevel", + "error", + "-y", + "-i", + str(src_video_path), + "-vf", + vf, + "-an", + str(dst_video_path), + ] + result = subprocess.run(cmd, capture_output=True, text=True) + if result.returncode != 0: + raise RuntimeError( + f"Failed to trim video segment {start_frame}:{end_frame} from '{src_video_path}'. " + f"ffmpeg stderr: {result.stderr.strip()}" + ) + + +def _copy_segment_images_and_rewrite_column( + image_cells: list[Any], + input_root: Path, + output_root: Path, + image_key: str, + output_episode_index: int, + source_episode_index: int, + source_segment_start: int, +) -> list[Any]: + """Copy image files for a segment and rewrite per-row image references. + + Args: + image_cells: Image column values from the sliced source table. + input_root: Source dataset root path. + output_root: Output dataset root path. + image_key: Feature key for this image stream. + output_episode_index: Output episode index receiving this segment. + source_episode_index: Source episode index for image path fallback. + source_segment_start: Start frame index of this segment in source episode. + + Returns: + New image column values with updated file paths for copied images. + + Raises: + FileNotFoundError: If a referenced source image file does not exist. + """ + rewritten_cells: list[Any] = [] + for frame_index, cell in enumerate(image_cells): + rel_dst = DEFAULT_IMAGE_PATH.format( + image_key=image_key, + episode_index=output_episode_index, + frame_index=frame_index, + ) + dst_path = output_root / rel_dst + dst_path.parent.mkdir(parents=True, exist_ok=True) + + if isinstance(cell, dict): + image_bytes = cell.get("bytes") + if isinstance(image_bytes, (bytes, bytearray)) and len(image_bytes) > 0: + dst_path.write_bytes(bytes(image_bytes)) + new_cell = dict(cell) + new_cell["path"] = str(dst_path) + rewritten_cells.append(new_cell) + continue + + src_path: Path | None = None + if isinstance(cell, str): + src_path = Path(cell) + elif isinstance(cell, dict): + path_val = cell.get("path") + if isinstance(path_val, str) and path_val: + src_path = Path(path_val) + + # Embedded-image rows may not require copying when path is empty. + if src_path is None: + rewritten_cells.append(cell) + continue + + if not src_path.is_absolute(): + src_path = input_root / src_path + if not src_path.is_file(): + # Fallback to canonical image location under input root. + source_frame_index = source_segment_start + frame_index + src_path = input_root / DEFAULT_IMAGE_PATH.format( + image_key=image_key, + episode_index=source_episode_index, + frame_index=source_frame_index, + ) + if not src_path.is_file(): + raise FileNotFoundError(f"Missing source image for key '{image_key}': {src_path}") + + shutil.copy2(src_path, dst_path) + + if isinstance(cell, str): + rewritten_cells.append(str(dst_path)) + else: + new_cell = dict(cell) + new_cell["path"] = str(dst_path) + rewritten_cells.append(new_cell) + + return rewritten_cells + + def segment_dataset( input_root: Path, output_root: Path, @@ -139,6 +261,12 @@ def segment_dataset( episode_id: Source episode index to slice. segments: List of ``(start, end)`` frame ranges in ``[start, end)`` form. + Notes: + For visual features (``dtype`` in ``{"image", "video"}``), per-episode + statistics (``min``, ``max``, ``mean``, ``std``) are inherited from the + source episode statistics and only the ``count`` is updated to the segment + length. They are not recomputed from the segmented visual data. + Raises: ValueError: If inputs are invalid, source files are missing, or segment ranges are out of bounds. @@ -207,6 +335,26 @@ def segment_dataset( if col_idx >= 0: seg_table = seg_table.set_column(col_idx, key, arr) + # For image-based datasets, copy only the segment frames and rewrite image references. + image_keys = [k for k, ft in source_meta.features.items() if ft["dtype"] == "image"] + for image_key in image_keys: + if image_key not in seg_table.column_names: + continue + col_idx = seg_table.schema.get_field_index(image_key) + image_cells = seg_table.column(image_key).to_pylist() + rewritten = _copy_segment_images_and_rewrite_column( + image_cells=image_cells, + input_root=input_root, + output_root=output_root, + image_key=image_key, + output_episode_index=output_episode_index, + source_episode_index=episode_id, + source_segment_start=start, + ) + seg_table = seg_table.set_column( + col_idx, image_key, pa.array(rewritten, type=seg_table.schema.field(image_key).type) + ) + episode_chunk = output_episode_index // chunks_size output_parquet_path = output_root / source_meta.data_path.format( episode_chunk=episode_chunk, @@ -267,7 +415,7 @@ def segment_dataset( src_video_path = input_root / source_meta.get_video_file_path(episode_id, video_key) if not src_video_path.is_file(): raise ValueError(f"Missing source video for key '{video_key}': {src_video_path}") - for output_episode_index in range(len(segments)): + for output_episode_index, (start, end) in enumerate(segments): episode_chunk = output_episode_index // chunks_size dst_video_path = output_root / video_path_template_str.format( episode_chunk=episode_chunk, @@ -275,7 +423,7 @@ def segment_dataset( episode_index=output_episode_index, ) dst_video_path.parent.mkdir(parents=True, exist_ok=True) - shutil.copy2(src_video_path, dst_video_path) + _trim_video_segment(src_video_path, dst_video_path, start, end) for episode in output_episodes: append_jsonlines(episode, output_root / EPISODES_PATH) @@ -290,9 +438,6 @@ def segment_dataset( info["splits"] = {"train": f"0:{total_episodes}"} write_json(info, output_root / "meta" / "info.json") - # Ensure expected meta files exist and are explicit outputs. - _ = output_root / EPISODES_STATS_PATH - def main() -> None: """CLI entry point.""" diff --git a/tests/datasets/test_segment_lerobot_dataset.py b/tests/datasets/test_segment_lerobot_dataset.py index 2e423bba..f0a8e592 100644 --- a/tests/datasets/test_segment_lerobot_dataset.py +++ b/tests/datasets/test_segment_lerobot_dataset.py @@ -16,6 +16,7 @@ from pathlib import Path from typing import Any +from unittest.mock import patch import numpy as np import pyarrow.parquet as pq @@ -25,6 +26,24 @@ from opentau.scripts.segment_lerobot_dataset import segment_dataset +def _extract_image_path(cell: Any) -> str | None: + """Extract image path from a parquet image cell. + + Args: + cell: A parquet image value (string path or dict with `path`/`bytes`). + + Returns: + Image path string if present, otherwise None. + """ + if isinstance(cell, str): + return cell + if isinstance(cell, dict): + path = cell.get("path") + if isinstance(path, str) and path: + return path + return None + + def test_segment_lerobot_v21_dataset(tmp_path: Path, empty_lerobot_dataset_factory: Any) -> None: """Validate baseline segmentation behavior for v2.1 input. @@ -207,3 +226,91 @@ def test_segment_lerobot_non_consecutive_and_overlapping_ranges( assert [float(x) for x in ep0["state"]] == [float(i) for i in range(0, 10)] assert [float(x) for x in ep1["state"]] == [float(i) for i in range(18, 23)] assert [float(x) for x in ep2["state"]] == [float(i) for i in range(5, 15)] + + +def test_segment_lerobot_copies_image_files_for_segments( + tmp_path: Path, empty_lerobot_dataset_factory: Any +) -> None: + """Ensure segmented datasets copy and rewrite image file references. + + Args: + tmp_path: Temporary directory fixture provided by pytest. + empty_lerobot_dataset_factory: Fixture that creates a writable dataset. + """ + input_root = tmp_path / "source_image_dataset" + output_root = tmp_path / "segmented_image_dataset" + image_key = "observation.images.camera" + + features = { + "state": {"dtype": "float32", "shape": (1,), "names": None}, + "actions": {"dtype": "float32", "shape": (1,), "names": None}, + image_key: {"dtype": "image", "shape": (3, 8, 8), "names": ["channel", "height", "width"]}, + } + dataset = empty_lerobot_dataset_factory(root=input_root, features=features, use_videos=False) + for i in range(8): + dataset.add_frame( + { + "state": np.array([float(i)], dtype=np.float32), + "actions": np.array([float(i) + 1.0], dtype=np.float32), + "observation.images.camera": np.full((8, 8, 3), i / 8.0, dtype=np.float32), + "task": "image task", + } + ) + dataset.save_episode() + + segment_dataset( + input_root=input_root, + output_root=output_root, + episode_id=0, + segments=[(1, 4), (4, 8)], + ) + + out_meta = LeRobotDatasetMetadata(repo_id=output_root.name, root=output_root) + ep0 = pq.read_table(output_root / out_meta.get_data_file_path(0)).to_pydict() + ep1 = pq.read_table(output_root / out_meta.get_data_file_path(1)).to_pydict() + + ep0_paths = [_extract_image_path(cell) for cell in ep0[image_key]] + ep1_paths = [_extract_image_path(cell) for cell in ep1[image_key]] + assert all(path is not None for path in ep0_paths) + assert all(path is not None for path in ep1_paths) + + for frame_idx, path in enumerate(ep0_paths): + assert path is not None + expected = output_root / f"images/{image_key}/episode_000000/frame_{frame_idx:06d}.png" + assert Path(path) == expected + assert expected.is_file() + + for frame_idx, path in enumerate(ep1_paths): + assert path is not None + expected = output_root / f"images/{image_key}/episode_000001/frame_{frame_idx:06d}.png" + assert Path(path) == expected + assert expected.is_file() + + +def test_trim_video_segment_uses_frame_range_filter(tmp_path: Path) -> None: + """Ensure ffmpeg trim command uses frame-range segmentation. + + Args: + tmp_path: Temporary directory fixture provided by pytest. + """ + src = tmp_path / "src.mp4" + dst = tmp_path / "dst.mp4" + src.write_bytes(b"fake") + + with ( + patch("opentau.scripts.segment_lerobot_dataset.shutil.which", return_value="/usr/bin/ffmpeg"), + patch("opentau.scripts.segment_lerobot_dataset.subprocess.run") as run_mock, + ): + run_mock.return_value.returncode = 0 + run_mock.return_value.stderr = "" + + from opentau.scripts.segment_lerobot_dataset import _trim_video_segment + + _trim_video_segment(src, dst, 5, 15) + + assert run_mock.call_count == 1 + cmd = run_mock.call_args.args[0] + assert "ffmpeg" in cmd[0] + assert "-vf" in cmd + vf_expr = cmd[cmd.index("-vf") + 1] + assert "trim=start_frame=5:end_frame=15" in vf_expr From 0d359bffbfa177619511f8c647e9f6b2b502bb44 Mon Sep 17 00:00:00 2001 From: Shuheng Liu Date: Thu, 26 Feb 2026 17:22:42 -0800 Subject: [PATCH 4/8] feat: fix timestamp index --- src/opentau/scripts/segment_lerobot_dataset.py | 10 ++++++++++ tests/datasets/test_segment_lerobot_dataset.py | 9 +++++++++ 2 files changed, 19 insertions(+) diff --git a/src/opentau/scripts/segment_lerobot_dataset.py b/src/opentau/scripts/segment_lerobot_dataset.py index 08a8e04c..6c06f276 100644 --- a/src/opentau/scripts/segment_lerobot_dataset.py +++ b/src/opentau/scripts/segment_lerobot_dataset.py @@ -335,6 +335,16 @@ def segment_dataset( if col_idx >= 0: seg_table = seg_table.set_column(col_idx, key, arr) + # Recompute timestamps from local frame_index to avoid subtraction drift. + if "timestamp" in seg_table.column_names: + ts_idx = seg_table.schema.get_field_index("timestamp") + recomputed_ts = np.arange(seg_len, dtype=np.float64) / float(source_meta.fps) + seg_table = seg_table.set_column( + ts_idx, + "timestamp", + pa.array(recomputed_ts, type=seg_table.schema.field("timestamp").type), + ) + # For image-based datasets, copy only the segment frames and rewrite image references. image_keys = [k for k, ft in source_meta.features.items() if ft["dtype"] == "image"] for image_key in image_keys: diff --git a/tests/datasets/test_segment_lerobot_dataset.py b/tests/datasets/test_segment_lerobot_dataset.py index f0a8e592..772381f1 100644 --- a/tests/datasets/test_segment_lerobot_dataset.py +++ b/tests/datasets/test_segment_lerobot_dataset.py @@ -102,11 +102,15 @@ def test_segment_lerobot_v21_dataset(tmp_path: Path, empty_lerobot_dataset_facto assert ep0["frame_index"] == [0, 1, 2] assert ep0["episode_index"] == [0, 0, 0] assert ep0["index"] == [0, 1, 2] + assert np.isclose(float(ep0["timestamp"][0]), 0.0) + assert np.allclose(np.diff(np.asarray(ep0["timestamp"], dtype=np.float64)), 1.0 / dataset.fps) assert [float(x) for x in ep0["state"]] == [2.0, 3.0, 4.0] assert ep1["frame_index"] == [0, 1, 2, 3, 4] assert ep1["episode_index"] == [1, 1, 1, 1, 1] assert ep1["index"] == [3, 4, 5, 6, 7] + assert np.isclose(float(ep1["timestamp"][0]), 0.0) + assert np.allclose(np.diff(np.asarray(ep1["timestamp"], dtype=np.float64)), 1.0 / dataset.fps) assert [float(x) for x in ep1["state"]] == [5.0, 6.0, 7.0, 8.0, 9.0] @@ -222,6 +226,11 @@ def test_segment_lerobot_non_consecutive_and_overlapping_ranges( assert ep1["index"] == list(range(10, 15)) assert ep2["index"] == list(range(15, 25)) + # Timestamps are rebased per output episode. + assert np.isclose(float(ep0["timestamp"][0]), 0.0) + assert np.isclose(float(ep1["timestamp"][0]), 0.0) + assert np.isclose(float(ep2["timestamp"][0]), 0.0) + # Data slices match requested source windows. assert [float(x) for x in ep0["state"]] == [float(i) for i in range(0, 10)] assert [float(x) for x in ep1["state"]] == [float(i) for i in range(18, 23)] From 1abe173223576a1a296a6c4070e148bf12ed8af3 Mon Sep 17 00:00:00 2001 From: Shuheng Liu Date: Mon, 2 Mar 2026 11:58:06 -0800 Subject: [PATCH 5/8] feat: allow skipping video stats when saving ep --- src/opentau/datasets/compute_stats.py | 55 +++++++++++++++++-------- src/opentau/datasets/lerobot_dataset.py | 12 ++++-- 2 files changed, 47 insertions(+), 20 deletions(-) diff --git a/src/opentau/datasets/compute_stats.py b/src/opentau/datasets/compute_stats.py index 13eb0c31..6ddfcf58 100644 --- a/src/opentau/datasets/compute_stats.py +++ b/src/opentau/datasets/compute_stats.py @@ -192,40 +192,61 @@ def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[st } -def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], features: dict) -> dict: +def compute_episode_stats( + episode_data: dict[str, list[str] | np.ndarray], + features: dict, + skip_video_stats: bool = False, +) -> dict: """Compute statistics for a single episode. - For image/video features, samples and downsamples images before computing stats. + For image/video features, samples and downsamples images before computing stats + (unless skip_video_stats is True, in which case placeholder stats are used). For other features, computes stats directly on the array data. Args: episode_data: Dictionary mapping feature names to their data (arrays or image paths). features: Dictionary of feature specifications with 'dtype' keys. + skip_video_stats: If True, do not compute real stats for image/video features; + instead use placeholder stats (min=0, max=1, mean=0.5, std=0.5, count from data) + so the output format remains valid. Returns: Dictionary mapping feature names to their statistics (min, max, mean, std, count). - Image statistics are normalized to [0, 1] range. + Image statistics are normalized to [0, 1] range (or placeholders when skip_video_stats). """ ep_stats = {} for key, data in episode_data.items(): if features[key]["dtype"] == "string": continue # HACK: we should receive np.arrays of strings elif features[key]["dtype"] in ["image", "video"]: - ep_ft_array = sample_images(data) # data is a list of image paths - axes_to_reduce = (0, 2, 3) # keep channel dim - keepdims = True + if skip_video_stats: + # Placeholder stats: shape (3, 1, 1) for min/max/mean/std, count from length + n_frames = len(data) if isinstance(data, list) else data.shape[0] + shape = features[key]["shape"] + # Expected shape for video is (C, H, W) e.g. (3, H, W) + c = shape[0] if len(shape) >= 3 else 3 + ep_stats[key] = { + "min": np.zeros((c, 1, 1), dtype=np.float64), + "max": np.ones((c, 1, 1), dtype=np.float64), + "mean": np.full((c, 1, 1), 0.5, dtype=np.float64), + "std": np.full((c, 1, 1), 0.5, dtype=np.float64), + "count": np.array([n_frames]), + } + else: + image_paths = data.tolist() if isinstance(data, np.ndarray) else data + ep_ft_array = sample_images(image_paths) # image_paths is list[str] + axes_to_reduce = (0, 2, 3) # keep channel dim + keepdims = True + ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims) + # normalize and remove batch dim for images + ep_stats[key] = { + k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items() + } else: - ep_ft_array = data # data is already a np.ndarray - axes_to_reduce = 0 # compute stats over the first axis - keepdims = data.ndim == 1 # keep as np.array - - ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims) - - # finally, we normalize and remove batch dim for images - if features[key]["dtype"] in ["image", "video"]: - ep_stats[key] = { - k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items() - } + ep_ft_array = data if isinstance(data, np.ndarray) else np.asarray(data) + axes_to_reduce = (0,) # compute stats over the first axis + keepdims = ep_ft_array.ndim == 1 # keep as np.array + ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims) return ep_stats diff --git a/src/opentau/datasets/lerobot_dataset.py b/src/opentau/datasets/lerobot_dataset.py index c136c32f..451fa749 100644 --- a/src/opentau/datasets/lerobot_dataset.py +++ b/src/opentau/datasets/lerobot_dataset.py @@ -1670,7 +1670,9 @@ def save_episode(self, episode_data: dict | None = None) -> None: self._wait_image_writer() self._save_episode_table(episode_buffer, episode_index) - ep_stats = compute_episode_stats(episode_buffer, self.features) + ep_stats = compute_episode_stats( + episode_buffer, self.features, skip_video_stats=getattr(self, "skip_video_stats", False) + ) if len(self.meta.video_keys) > 0: video_paths = self.encode_episode_videos(episode_index) @@ -1682,9 +1684,11 @@ def save_episode(self, episode_data: dict | None = None) -> None: ep_data_index, _ = get_episode_data_index(self.meta.episodes, [episode_index]) ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()} + timestamps = np.asarray(episode_buffer["timestamp"]).reshape(-1) + episode_indices = np.full(episode_length, episode_index) check_timestamps_sync( - episode_buffer["timestamp"], - episode_buffer["episode_index"], + timestamps, + episode_indices, ep_data_index_np, self.fps, self.tolerance_s, @@ -1870,6 +1874,7 @@ def create( image_resample_strategy: str = "nearest", vector_resample_strategy: str = "nearest", standardize: bool = True, + skip_video_stats: bool = False, ) -> "LeRobotDataset": """Create a LeRobot Dataset from scratch in order to record data.""" obj = cls.__new__(cls) @@ -1903,5 +1908,6 @@ def create( obj.image_resample_strategy = image_resample_strategy obj.vector_resample_strategy = vector_resample_strategy obj.standardize = standardize + obj.skip_video_stats = skip_video_stats obj.episode_data_index, obj.epi2idx = get_episode_data_index(obj.meta.episodes, obj.episodes) return obj From d5025b44a09fc5143eebac0261f710f03645b187 Mon Sep 17 00:00:00 2001 From: Shuheng Liu Date: Mon, 2 Mar 2026 12:18:00 -0800 Subject: [PATCH 6/8] feat: require JSON file for specifying episodes and segments --- .../scripts/segment_lerobot_dataset.py | 405 ++++++++++-------- .../datasets/test_segment_lerobot_dataset.py | 126 +++++- 2 files changed, 337 insertions(+), 194 deletions(-) diff --git a/src/opentau/scripts/segment_lerobot_dataset.py b/src/opentau/scripts/segment_lerobot_dataset.py index 6c06f276..f3629b3d 100644 --- a/src/opentau/scripts/segment_lerobot_dataset.py +++ b/src/opentau/scripts/segment_lerobot_dataset.py @@ -16,19 +16,17 @@ """Create a segmented LeRobot v2.1 dataset from a source episode. This script builds a brand-new dataset where each output episode corresponds to one -`[start, end)` frame segment from a selected source episode. +`[start, end)` frame segment from source episodes defined in a JSON plan. Accepted input formats: LeRobot v2.0 and v2.1. Output format: always LeRobot v2.1. Example: - python segment_lerobot_dataset.py ./input_dataset ./output_dataset \ - --episode-id 0 \ - --segment 0:100 \ - --segment 120:220 + python segment_lerobot_dataset.py ./input_dataset ./output_dataset ./segments.json """ import argparse +import json import math import shutil import subprocess @@ -52,43 +50,12 @@ ) -def _parse_segment(text: str) -> tuple[int, int]: - """Parse one CLI segment token. - - Args: - text: Segment string formatted as ``START:END``. - - Returns: - A 2-tuple ``(start, end)`` where ``start >= 0`` and ``end > start``. - - Raises: - argparse.ArgumentTypeError: If the value is malformed or out of range. - """ - parts = text.split(":") - if len(parts) != 2: - raise argparse.ArgumentTypeError(f"Invalid segment '{text}'. Expected START:END with integer values.") - try: - start = int(parts[0]) - end = int(parts[1]) - except ValueError as exc: - raise argparse.ArgumentTypeError( - f"Invalid segment '{text}'. START and END must be integers." - ) from exc - if start < 0 or end < 0: - raise argparse.ArgumentTypeError(f"Invalid segment '{text}'. START and END must be non-negative.") - if end <= start: - raise argparse.ArgumentTypeError( - f"Invalid segment '{text}'. END must be strictly greater than START." - ) - return start, end - - def parse_args() -> argparse.Namespace: """Parse command-line arguments for dataset segmentation. Returns: - Parsed CLI namespace containing input/output roots, source episode id, - and the list of frame-range segments. + Parsed CLI namespace containing input/output roots and a JSON path + mapping source episode ids to frame-range segments. """ parser = argparse.ArgumentParser( description="Create a segmented LeRobot v2.1 dataset from one source episode.", @@ -96,22 +63,77 @@ def parse_args() -> argparse.Namespace: ) parser.add_argument("input_root", type=Path, help="Path to source LeRobot dataset root.") parser.add_argument("output_root", type=Path, help="Path to output dataset root (must not exist).") - parser.add_argument( - "--episode-id", - type=int, - default=0, - help="Source episode index to segment.", - ) - parser.add_argument( - "--segment", - type=_parse_segment, - action="append", - required=True, - help="Segment as START:END in frame index domain. Repeat this argument for multiple segments.", - ) + parser.add_argument("segments_json", type=Path, help="Path to JSON segmentation plan.") return parser.parse_args() +def _load_segments_by_episode(segments_json: Path) -> dict[int, list[tuple[int, int]]]: + """Load and validate segmentation plan from JSON. + + Args: + segments_json: Path to JSON object mapping episode id strings to arrays + of ``[start, end]`` pairs. + + Returns: + Mapping from integer episode id to list of ``(start, end)`` tuples. + + Raises: + ValueError: If JSON content is missing or malformed. + """ + if not segments_json.is_file(): + raise ValueError(f"Segments JSON file does not exist: {segments_json}") + + with open(segments_json) as f: + data = json.load(f) + + if not isinstance(data, dict): + raise ValueError("Segments JSON must be an object mapping episode ids to segment lists.") + + segments_by_episode: dict[int, list[tuple[int, int]]] = {} + for episode_key, segment_list in data.items(): + if not isinstance(episode_key, str): + raise ValueError("Episode ids in segments JSON must be strings.") + try: + episode_id = int(episode_key) + except ValueError as exc: + raise ValueError(f"Episode id '{episode_key}' is not a valid integer string.") from exc + + if not isinstance(segment_list, list) or len(segment_list) == 0: + raise ValueError(f"Episode '{episode_key}' must map to a non-empty list of [start, end] pairs.") + + validated_segments: list[tuple[int, int]] = [] + for seg_idx, pair in enumerate(segment_list): + if not isinstance(pair, (list, tuple)) or len(pair) != 2: + raise ValueError( + f"Segment #{seg_idx} for episode '{episode_key}' must be a pair [start, end]." + ) + start_raw, end_raw = pair + if not isinstance(start_raw, int) or isinstance(start_raw, bool): + raise ValueError( + f"Segment start must be an integer for episode '{episode_key}', got {start_raw}." + ) + if not isinstance(end_raw, int) or isinstance(end_raw, bool): + raise ValueError( + f"Segment end must be an integer for episode '{episode_key}', got {end_raw}." + ) + start, end = int(start_raw), int(end_raw) + if start < 0 or end < 0: + raise ValueError( + f"Segment [{start}, {end}] for episode '{episode_key}' must be non-negative." + ) + if end <= start: + raise ValueError( + f"Segment [{start}, {end}] for episode '{episode_key}' must satisfy end > start." + ) + validated_segments.append((start, end)) + + segments_by_episode[episode_id] = validated_segments + + if not segments_by_episode: + raise ValueError("Segments JSON produced an empty segmentation plan.") + return segments_by_episode + + def _to_numpy_for_stats(column: pa.ChunkedArray) -> np.ndarray: """Convert an Arrow chunked column to a NumPy array for stats. @@ -205,6 +227,8 @@ def _copy_segment_images_and_rewrite_column( if isinstance(image_bytes, (bytes, bytearray)) and len(image_bytes) > 0: dst_path.write_bytes(bytes(image_bytes)) new_cell = dict(cell) + # Avoid duplicating image payload in parquet after materializing to disk. + new_cell["bytes"] = None new_cell["path"] = str(dst_path) rewritten_cells.append(new_cell) continue @@ -241,6 +265,7 @@ def _copy_segment_images_and_rewrite_column( rewritten_cells.append(str(dst_path)) else: new_cell = dict(cell) + new_cell["bytes"] = None new_cell["path"] = str(dst_path) rewritten_cells.append(new_cell) @@ -250,16 +275,15 @@ def _copy_segment_images_and_rewrite_column( def segment_dataset( input_root: Path, output_root: Path, - episode_id: int, - segments: list[tuple[int, int]], + segments_by_episode: dict[int, list[tuple[int, int]]], ) -> None: """Create a new segmented dataset from a source episode. Args: input_root: Source LeRobot dataset directory (v2.0 or v2.1). output_root: Destination directory for the new dataset. Must not exist. - episode_id: Source episode index to slice. - segments: List of ``(start, end)`` frame ranges in ``[start, end)`` form. + segments_by_episode: Mapping from source episode id to list of + ``(start, end)`` frame ranges in ``[start, end)`` form. Notes: For visual features (``dtype`` in ``{"image", "video"}``), per-episode @@ -278,8 +302,8 @@ def segment_dataset( raise ValueError(f"Input dataset root does not exist: {input_root}") if output_root.exists(): raise ValueError(f"Output dataset root already exists: {output_root}") - if not segments: - raise ValueError("At least one segment must be provided.") + if not segments_by_episode: + raise ValueError("At least one episode with segment ranges must be provided.") source_meta = LeRobotDatasetMetadata(repo_id=input_root.name, root=input_root) source_version = str(source_meta._version) @@ -288,25 +312,18 @@ def segment_dataset( "Only LeRobot dataset format v2.0 and v2.1 are supported as input by this script. " f"Found codebase_version={source_version}." ) - if episode_id not in source_meta.episodes: - raise ValueError(f"Episode {episode_id} not found in source dataset.") - - source_episode = source_meta.episodes[episode_id] - source_length = int(source_episode["length"]) - for start, end in segments: - if end > source_length: - raise ValueError( - f"Segment ({start}, {end}) is out of bounds for source episode length {source_length}." - ) - - source_parquet_path = input_root / source_meta.get_data_file_path(episode_id) - if not source_parquet_path.is_file(): - raise ValueError(f"Missing source parquet file for episode {episode_id}: {source_parquet_path}") - source_table = pq.read_table(source_parquet_path) - if source_table.num_rows != source_length: - raise ValueError( - f"Source metadata length ({source_length}) does not match parquet row count ({source_table.num_rows})." - ) + for source_episode_id, episode_segments in segments_by_episode.items(): + if source_episode_id not in source_meta.episodes: + raise ValueError(f"Episode {source_episode_id} not found in source dataset.") + if not episode_segments: + raise ValueError(f"Episode {source_episode_id} has no segment ranges.") + source_length = int(source_meta.episodes[source_episode_id]["length"]) + for start, end in episode_segments: + if not (0 <= start < end <= source_length): + raise ValueError( + f"Segment ({start}, {end}) is invalid for source episode length {source_length}. " + "Expected 0 <= start < end <= source_length." + ) output_root.mkdir(parents=True, exist_ok=False) chunks_size = int(source_meta.chunks_size) @@ -318,127 +335,145 @@ def segment_dataset( for task_index, task in sorted(source_meta.tasks.items(), key=lambda x: x[0]): append_jsonlines({"task_index": task_index, "task": task}, output_root / TASKS_PATH) - source_episode_stats = source_meta.episodes_stats.get(cast(Any, episode_id), {}) visual_keys = [k for k, ft in source_meta.features.items() if ft["dtype"] in ["image", "video"]] - - for output_episode_index, (start, end) in enumerate(segments): - seg_len = end - start - seg_table = source_table.slice(start, seg_len) - - replacement_arrays = { - "episode_index": pa.array(np.full(seg_len, output_episode_index, dtype=np.int64)), - "frame_index": pa.array(np.arange(seg_len, dtype=np.int64)), - "index": pa.array(np.arange(global_index_offset, global_index_offset + seg_len, dtype=np.int64)), - } - for key, arr in replacement_arrays.items(): - col_idx = seg_table.schema.get_field_index(key) - if col_idx >= 0: - seg_table = seg_table.set_column(col_idx, key, arr) - - # Recompute timestamps from local frame_index to avoid subtraction drift. - if "timestamp" in seg_table.column_names: - ts_idx = seg_table.schema.get_field_index("timestamp") - recomputed_ts = np.arange(seg_len, dtype=np.float64) / float(source_meta.fps) - seg_table = seg_table.set_column( - ts_idx, - "timestamp", - pa.array(recomputed_ts, type=seg_table.schema.field("timestamp").type), - ) - - # For image-based datasets, copy only the segment frames and rewrite image references. - image_keys = [k for k, ft in source_meta.features.items() if ft["dtype"] == "image"] - for image_key in image_keys: - if image_key not in seg_table.column_names: - continue - col_idx = seg_table.schema.get_field_index(image_key) - image_cells = seg_table.column(image_key).to_pylist() - rewritten = _copy_segment_images_and_rewrite_column( - image_cells=image_cells, - input_root=input_root, - output_root=output_root, - image_key=image_key, - output_episode_index=output_episode_index, - source_episode_index=episode_id, - source_segment_start=start, - ) - seg_table = seg_table.set_column( - col_idx, image_key, pa.array(rewritten, type=seg_table.schema.field(image_key).type) - ) - - episode_chunk = output_episode_index // chunks_size - output_parquet_path = output_root / source_meta.data_path.format( - episode_chunk=episode_chunk, - episode_index=output_episode_index, - ) - output_parquet_path.parent.mkdir(parents=True, exist_ok=True) - pq.write_table(seg_table, output_parquet_path) - - # Build tasks list for this segment from task_index values present in rows. - task_indices = seg_table.column("task_index").to_pylist() - seen_task_indices = set() - episode_tasks = [] - for task_index in task_indices: - if task_index in seen_task_indices: - continue - seen_task_indices.add(task_index) - episode_tasks.append(source_meta.tasks[int(task_index)]) - - output_episodes.append( - { - "episode_index": output_episode_index, - "tasks": episode_tasks, - "length": seg_len, - } - ) - - stats_features = { - key: feature - for key, feature in source_meta.features.items() - if feature["dtype"] not in ["image", "video", "string"] and key in seg_table.column_names - } - stats_data: dict[str, list[str] | np.ndarray] = { - key: _to_numpy_for_stats(seg_table.column(key)) for key in stats_features - } - episode_stats = compute_episode_stats(stats_data, stats_features) - - # Keep visual keys in stats for downstream compatibility. - for key in visual_keys: - if key in source_episode_stats: - source_key_stats = cast(dict[str, np.ndarray], source_episode_stats[key]) - copied = {metric: np.array(val, copy=True) for metric, val in source_key_stats.items()} - if "count" in copied: - copied["count"] = np.array([seg_len], dtype=np.int64) - episode_stats[key] = copied - - write_episode_stats(output_episode_index, episode_stats, output_root) - - global_index_offset += seg_len - total_frames += seg_len - - # Copy source video episode for each requested segment episode, if any. + image_keys = [k for k, ft in source_meta.features.items() if ft["dtype"] == "image"] video_path_template = source_meta.video_path if source_meta.video_keys and video_path_template is None: raise ValueError("Source dataset declares video keys but has no video_path template in metadata.") video_path_template_str = cast(str, video_path_template) if video_path_template is not None else "" + source_tables: dict[int, pa.Table] = {} + output_episode_index = 0 + + for episode_id, segments in segments_by_episode.items(): + source_episode_stats = source_meta.episodes_stats.get(cast(Any, episode_id), {}) + if episode_id not in source_tables: + source_parquet_path = input_root / source_meta.get_data_file_path(episode_id) + if not source_parquet_path.is_file(): + raise ValueError( + f"Missing source parquet file for episode {episode_id}: {source_parquet_path}" + ) + source_table = pq.read_table(source_parquet_path) + source_length = int(source_meta.episodes[episode_id]["length"]) + if source_table.num_rows != source_length: + raise ValueError( + f"Source metadata length ({source_length}) does not match parquet row count ({source_table.num_rows})." + ) + source_tables[episode_id] = source_table + source_table = source_tables[episode_id] + + for start, end in segments: + seg_len = end - start + seg_table = source_table.slice(start, seg_len) + + replacement_arrays = { + "episode_index": pa.array(np.full(seg_len, output_episode_index, dtype=np.int64)), + "frame_index": pa.array(np.arange(seg_len, dtype=np.int64)), + "index": pa.array( + np.arange(global_index_offset, global_index_offset + seg_len, dtype=np.int64) + ), + } + for key, arr in replacement_arrays.items(): + col_idx = seg_table.schema.get_field_index(key) + if col_idx >= 0: + seg_table = seg_table.set_column(col_idx, key, arr) + + # Recompute timestamps from local frame_index to avoid subtraction drift. + if "timestamp" in seg_table.column_names: + ts_idx = seg_table.schema.get_field_index("timestamp") + recomputed_ts = np.arange(seg_len, dtype=np.float64) / float(source_meta.fps) + seg_table = seg_table.set_column( + ts_idx, + "timestamp", + pa.array(recomputed_ts, type=seg_table.schema.field("timestamp").type), + ) + + # For image-based datasets, copy only the segment frames and rewrite image references. + for image_key in image_keys: + if image_key not in seg_table.column_names: + continue + col_idx = seg_table.schema.get_field_index(image_key) + image_cells = seg_table.column(image_key).to_pylist() + rewritten = _copy_segment_images_and_rewrite_column( + image_cells=image_cells, + input_root=input_root, + output_root=output_root, + image_key=image_key, + output_episode_index=output_episode_index, + source_episode_index=episode_id, + source_segment_start=start, + ) + seg_table = seg_table.set_column( + col_idx, image_key, pa.array(rewritten, type=seg_table.schema.field(image_key).type) + ) - for video_key in source_meta.video_keys: - src_video_path = input_root / source_meta.get_video_file_path(episode_id, video_key) - if not src_video_path.is_file(): - raise ValueError(f"Missing source video for key '{video_key}': {src_video_path}") - for output_episode_index, (start, end) in enumerate(segments): episode_chunk = output_episode_index // chunks_size - dst_video_path = output_root / video_path_template_str.format( + output_parquet_path = output_root / source_meta.data_path.format( episode_chunk=episode_chunk, - video_key=video_key, episode_index=output_episode_index, ) - dst_video_path.parent.mkdir(parents=True, exist_ok=True) - _trim_video_segment(src_video_path, dst_video_path, start, end) + output_parquet_path.parent.mkdir(parents=True, exist_ok=True) + pq.write_table(seg_table, output_parquet_path) + + # Build tasks list for this segment from task_index values present in rows. + task_indices = seg_table.column("task_index").to_pylist() + seen_task_indices = set() + episode_tasks = [] + for task_index in task_indices: + if task_index in seen_task_indices: + continue + seen_task_indices.add(task_index) + episode_tasks.append(source_meta.tasks[int(task_index)]) + + output_episodes.append( + { + "episode_index": output_episode_index, + "tasks": episode_tasks, + "length": seg_len, + } + ) + + stats_features = { + key: feature + for key, feature in source_meta.features.items() + if feature["dtype"] not in ["image", "video", "string"] and key in seg_table.column_names + } + stats_data: dict[str, list[str] | np.ndarray] = { + key: _to_numpy_for_stats(seg_table.column(key)) for key in stats_features + } + episode_stats = compute_episode_stats(stats_data, stats_features) + + # Keep visual keys in stats for downstream compatibility. + for key in visual_keys: + if key in source_episode_stats: + source_key_stats = cast(dict[str, np.ndarray], source_episode_stats[key]) + copied = {metric: np.array(val, copy=True) for metric, val in source_key_stats.items()} + if "count" in copied: + copied["count"] = np.array([seg_len], dtype=np.int64) + episode_stats[key] = copied + + write_episode_stats(output_episode_index, episode_stats, output_root) + + # Copy and trim source videos for this segment, if any. + for video_key in source_meta.video_keys: + src_video_path = input_root / source_meta.get_video_file_path(episode_id, video_key) + if not src_video_path.is_file(): + raise ValueError(f"Missing source video for key '{video_key}': {src_video_path}") + dst_video_path = output_root / video_path_template_str.format( + episode_chunk=episode_chunk, + video_key=video_key, + episode_index=output_episode_index, + ) + dst_video_path.parent.mkdir(parents=True, exist_ok=True) + _trim_video_segment(src_video_path, dst_video_path, start, end) + + global_index_offset += seg_len + total_frames += seg_len + output_episode_index += 1 for episode in output_episodes: append_jsonlines(episode, output_root / EPISODES_PATH) - total_episodes = len(segments) + total_episodes = output_episode_index info = deepcopy(source_meta.info) info["codebase_version"] = CODEBASE_VERSION info["total_episodes"] = total_episodes @@ -452,11 +487,11 @@ def segment_dataset( def main() -> None: """CLI entry point.""" args = parse_args() + segments_by_episode = _load_segments_by_episode(args.segments_json) segment_dataset( input_root=args.input_root, output_root=args.output_root, - episode_id=args.episode_id, - segments=args.segment, + segments_by_episode=segments_by_episode, ) diff --git a/tests/datasets/test_segment_lerobot_dataset.py b/tests/datasets/test_segment_lerobot_dataset.py index 772381f1..db6d6de1 100644 --- a/tests/datasets/test_segment_lerobot_dataset.py +++ b/tests/datasets/test_segment_lerobot_dataset.py @@ -20,10 +20,11 @@ import numpy as np import pyarrow.parquet as pq +import pytest from opentau.datasets.lerobot_dataset import LeRobotDatasetMetadata from opentau.datasets.utils import load_episodes, load_episodes_stats, load_info, write_json, write_stats -from opentau.scripts.segment_lerobot_dataset import segment_dataset +from opentau.scripts.segment_lerobot_dataset import _load_segments_by_episode, segment_dataset def _extract_image_path(cell: Any) -> str | None: @@ -73,8 +74,7 @@ def test_segment_lerobot_v21_dataset(tmp_path: Path, empty_lerobot_dataset_facto segment_dataset( input_root=input_root, output_root=output_root, - episode_id=0, - segments=[(2, 5), (5, 10)], + segments_by_episode={0: [(2, 5), (5, 10)]}, ) info = load_info(output_root) @@ -150,8 +150,7 @@ def test_segment_lerobot_v20_input_outputs_v21(tmp_path: Path, empty_lerobot_dat segment_dataset( input_root=input_root, output_root=output_root, - episode_id=0, - segments=[(0, 3), (3, 8)], + segments_by_episode={0: [(0, 3), (3, 8)]}, ) out_info = load_info(output_root) @@ -199,8 +198,7 @@ def test_segment_lerobot_non_consecutive_and_overlapping_ranges( segment_dataset( input_root=input_root, output_root=output_root, - episode_id=0, - segments=[(0, 10), (18, 23), (5, 15)], + segments_by_episode={0: [(0, 10), (18, 23), (5, 15)]}, ) info = load_info(output_root) @@ -270,8 +268,7 @@ def test_segment_lerobot_copies_image_files_for_segments( segment_dataset( input_root=input_root, output_root=output_root, - episode_id=0, - segments=[(1, 4), (4, 8)], + segments_by_episode={0: [(1, 4), (4, 8)]}, ) out_meta = LeRobotDatasetMetadata(repo_id=output_root.name, root=output_root) @@ -282,6 +279,9 @@ def test_segment_lerobot_copies_image_files_for_segments( ep1_paths = [_extract_image_path(cell) for cell in ep1[image_key]] assert all(path is not None for path in ep0_paths) assert all(path is not None for path in ep1_paths) + for cell in ep0[image_key] + ep1[image_key]: + if isinstance(cell, dict): + assert cell.get("bytes") in (None, b"") for frame_idx, path in enumerate(ep0_paths): assert path is not None @@ -296,6 +296,114 @@ def test_segment_lerobot_copies_image_files_for_segments( assert expected.is_file() +def test_load_segments_by_episode_from_json(tmp_path: Path) -> None: + """Validate JSON segmentation-plan parsing. + + Args: + tmp_path: Temporary directory fixture provided by pytest. + """ + plan_path = tmp_path / "segments.json" + plan_path.write_text('{"0": [[0, 3], [5, 8]], "2": [[10, 12]]}') + parsed = _load_segments_by_episode(plan_path) + assert parsed == {0: [(0, 3), (5, 8)], 2: [(10, 12)]} + + +def test_segment_lerobot_json_plan_with_two_source_episodes( + tmp_path: Path, empty_lerobot_dataset_factory: Any +) -> None: + """Ensure JSON plan can segment multiple source episodes in one run. + + Args: + tmp_path: Temporary directory fixture provided by pytest. + empty_lerobot_dataset_factory: Fixture that creates a writable dataset. + """ + input_root = tmp_path / "source_two_episode_dataset" + output_root = tmp_path / "segmented_two_episode_dataset" + + features = { + "state": {"dtype": "float32", "shape": (1,), "names": None}, + "actions": {"dtype": "float32", "shape": (1,), "names": None}, + } + dataset = empty_lerobot_dataset_factory(root=input_root, features=features, use_videos=False) + + # Episode 0 + for i in range(6): + dataset.add_frame( + { + "state": np.array([float(i)], dtype=np.float32), + "actions": np.array([float(i) + 100.0], dtype=np.float32), + "task": "ep0", + } + ) + dataset.save_episode() + + # Episode 1 + for i in range(6): + dataset.add_frame( + { + "state": np.array([float(i + 10)], dtype=np.float32), + "actions": np.array([float(i) + 200.0], dtype=np.float32), + "task": "ep1", + } + ) + dataset.save_episode() + + plan_path = tmp_path / "segments_two_episodes.json" + plan_path.write_text('{"0": [[1, 4]], "1": [[2, 6]]}') + segments_by_episode = _load_segments_by_episode(plan_path) + segment_dataset( + input_root=input_root, + output_root=output_root, + segments_by_episode=segments_by_episode, + ) + + info = load_info(output_root) + episodes = load_episodes(output_root) + out_meta = LeRobotDatasetMetadata(repo_id=output_root.name, root=output_root) + + assert info["total_episodes"] == 2 + assert [episodes[i]["length"] for i in sorted(episodes)] == [3, 4] + + ep0 = pq.read_table(output_root / out_meta.get_data_file_path(0)).to_pydict() + ep1 = pq.read_table(output_root / out_meta.get_data_file_path(1)).to_pydict() + assert [float(x) for x in ep0["state"]] == [1.0, 2.0, 3.0] + assert [float(x) for x in ep1["state"]] == [12.0, 13.0, 14.0, 15.0] + + +def test_segment_lerobot_rejects_invalid_segment_bounds( + tmp_path: Path, empty_lerobot_dataset_factory: Any +) -> None: + """Ensure segment bounds are validated in segment_dataset(). + + Args: + tmp_path: Temporary directory fixture provided by pytest. + empty_lerobot_dataset_factory: Fixture that creates a writable dataset. + """ + input_root = tmp_path / "source_invalid_segment_dataset" + output_root = tmp_path / "segmented_invalid_segment_dataset" + features = { + "state": {"dtype": "float32", "shape": (1,), "names": None}, + "actions": {"dtype": "float32", "shape": (1,), "names": None}, + } + dataset = empty_lerobot_dataset_factory(root=input_root, features=features, use_videos=False) + for i in range(5): + dataset.add_frame( + { + "state": np.array([float(i)], dtype=np.float32), + "actions": np.array([float(i)], dtype=np.float32), + "task": "invalid-bounds", + } + ) + dataset.save_episode() + + with pytest.raises(ValueError, match="Expected 0 <= start < end <= source_length"): + segment_dataset( + input_root=input_root, + output_root=output_root, + segments_by_episode={0: [(-1, 2)]}, + ) + + def test_trim_video_segment_uses_frame_range_filter(tmp_path: Path) -> None: """Ensure ffmpeg trim command uses frame-range segmentation. From 76d04c864beb0802343326329a1750045603bcce Mon Sep 17 00:00:00 2001 From: Shuheng Liu Date: Mon, 2 Mar 2026 12:36:20 -0800 Subject: [PATCH 7/8] docs: update documentation on usage of the segment script --- src/opentau/scripts/segment_lerobot_dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/opentau/scripts/segment_lerobot_dataset.py b/src/opentau/scripts/segment_lerobot_dataset.py index f3629b3d..ce997a05 100644 --- a/src/opentau/scripts/segment_lerobot_dataset.py +++ b/src/opentau/scripts/segment_lerobot_dataset.py @@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Create a segmented LeRobot v2.1 dataset from a source episode. +"""Create a segmented LeRobot v2.1 dataset from source episodes. This script builds a brand-new dataset where each output episode corresponds to one `[start, end)` frame segment from source episodes defined in a JSON plan. @@ -58,7 +58,7 @@ def parse_args() -> argparse.Namespace: mapping source episode ids to frame-range segments. """ parser = argparse.ArgumentParser( - description="Create a segmented LeRobot v2.1 dataset from one source episode.", + description="Create a segmented LeRobot v2.1 dataset from one or more source episodes.", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument("input_root", type=Path, help="Path to source LeRobot dataset root.") @@ -277,7 +277,7 @@ def segment_dataset( output_root: Path, segments_by_episode: dict[int, list[tuple[int, int]]], ) -> None: - """Create a new segmented dataset from a source episode. + """Create a new segmented dataset from one or more source episodes. Args: input_root: Source LeRobot dataset directory (v2.0 or v2.1). From 594e1afe49d35980f6951fbb319b0fb75d597328 Mon Sep 17 00:00:00 2001 From: Shuheng Liu Date: Mon, 2 Mar 2026 12:47:50 -0800 Subject: [PATCH 8/8] feat: add an example of segments.json --- configs/examples/segments.json | 18 ++++++++++++++++++ src/opentau/scripts/segment_lerobot_dataset.py | 2 ++ 2 files changed, 20 insertions(+) create mode 100644 configs/examples/segments.json diff --git a/configs/examples/segments.json b/configs/examples/segments.json new file mode 100644 index 00000000..114f3a2a --- /dev/null +++ b/configs/examples/segments.json @@ -0,0 +1,18 @@ +{ + "0": [ + [ + 5, + 15 + ], + [ + 10, + 20 + ] + ], + "1": [ + [ + 0, + 30 + ] + ] +} diff --git a/src/opentau/scripts/segment_lerobot_dataset.py b/src/opentau/scripts/segment_lerobot_dataset.py index ce997a05..b8185197 100644 --- a/src/opentau/scripts/segment_lerobot_dataset.py +++ b/src/opentau/scripts/segment_lerobot_dataset.py @@ -23,6 +23,8 @@ Example: python segment_lerobot_dataset.py ./input_dataset ./output_dataset ./segments.json + +An example of segments.json can be found in `configs/examples/segments.json`. """ import argparse