Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
1135b66
fix(datasets): guard get_safe_version with is_valid_version on SHA re…
shuheng-liu May 14, 2026
ad609f1
fix(datasets): fall back to inferred features when info.json cast fails
shuheng-liu May 14, 2026
0e34b65
fix(datasets): use HF Hub filelock instead of rank-coordinated metada…
shuheng-liu May 14, 2026
40f2a34
fix(datasets): always barrier in load_or_compute_speed_percentiles
shuheng-liu May 14, 2026
25d50cc
build(deps): add py-spy to dev extras for deadlock debugging
shuheng-liu May 14, 2026
808bd4d
Revert "build(deps): add py-spy to dev extras for deadlock debugging"
shuheng-liu May 14, 2026
db438c4
Revert "fix(datasets): use HF Hub filelock instead of rank-coordinate…
shuheng-liu May 14, 2026
65dfad7
fix(datasets): skip episodes with unresolvable task labels in speed b…
shuheng-liu May 14, 2026
59e7d08
fix(datasets): also fall back on ValueError Keys mismatch, not just T…
shuheng-liu May 14, 2026
747e4cc
fix(datasets): fall back on any typed-Dataset schema error, truncate log
shuheng-liu May 14, 2026
c9810cc
fix(datasets): stream parquet to mmap'd Arrow file instead of full in…
shuheng-liu May 14, 2026
0d12744
fix(datasets): write Arrow IPC stream format for Dataset.from_file mmap
shuheng-liu May 14, 2026
4bf9277
fix(datasets): use load_dataset(parquet) for genuine mmap, drop hand-…
shuheng-liu May 14, 2026
f5cc691
fix(datasets): download episode files directly, skip slow snapshot_do…
shuheng-liu May 14, 2026
0f6f249
fix(datasets): route whole-repo pulls to snapshot_download, skip on-d…
shuheng-liu May 14, 2026
1e28204
test(datasets): mock hf_hub_download in lerobot_dataset_factory
shuheng-liu May 14, 2026
9e1f8ec
refactor(datasets): address #304 review feedback
shuheng-liu May 14, 2026
b6bc39d
docs(datasets): clarify load_hf_dataset schema-inference docstring
shuheng-liu May 14, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 112 additions & 50 deletions src/opentau/datasets/lerobot_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,22 +87,21 @@
import json
import logging
import math
import re
import shutil
import traceback
from abc import abstractmethod
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Callable

import datasets
import numpy as np
import packaging.version
import PIL.Image
import pyarrow.dataset as pa_ds
import torch
import torch.nn.functional as F # noqa: N812
import torch.utils
from datasets import Dataset, DatasetInfo, concatenate_datasets
from datasets import concatenate_datasets, load_dataset
from einops import rearrange
from huggingface_hub import HfApi, hf_hub_download, snapshot_download
from huggingface_hub.constants import REPOCARD_NAME
Expand Down Expand Up @@ -219,6 +218,12 @@ def wrapped(self, idx):

CODEBASE_VERSION = "v2.1"

# Thread-pool width for the per-file `hf_hub_download` fan-out in `download_files`.
# The work is network-I/O-bound (the GIL is released during each request), so a
# width well above the core count is fine; 16 keeps enough round-trips in flight
# without hammering the Hub.
_DOWNLOAD_MAX_WORKERS = 16

# Set of repo_ids for which we've already emitted the "missing control_mode" warning.
# Keyed at module level so duplicates are suppressed across multiple LeRobotDataset
# instances within a single process (e.g., train + val constructed for the same repo).
Expand Down Expand Up @@ -1255,6 +1260,13 @@ def __init__(
# episode_data_index["from"/"to"] (built in self.episodes order in
# get_episode_data_index). Mismatched order would silently return
# rows from the wrong episode for callers that pass an unsorted list.
#
# `self.episodes` is backfilled with the full episode list further down
# when it is None (so downstream indexing always has a concrete list),
# which destroys the "no subset requested" signal. Capture it now so
# `download_episodes` can still distinguish a whole-repo pull from a
# subset pull.
self._episodes_were_specified = episodes is not None
self.episodes = sorted(episodes) if episodes is not None else None
self.tolerance_s = tolerance_s
self.skip_timestamp_check = skip_timestamp_check
Expand Down Expand Up @@ -1358,7 +1370,8 @@ def __init__(
assert all((self.root / fpath).is_file() for fpath in self.get_episodes_file_paths())
self.hf_dataset = self.load_hf_dataset()
except (AssertionError, FileNotFoundError, NotADirectoryError):
self.revision = get_safe_version(self.repo_id, self.revision)
if is_valid_version(self.revision):
self.revision = get_safe_version(self.repo_id, self.revision)
self.download_episodes(download_videos)
self.hf_dataset = self.load_hf_dataset()

Expand Down Expand Up @@ -1537,20 +1550,73 @@ def pull_from_repo(
ignore_patterns=ignore_patterns,
)

@on_accelerate_main_proc(local=True, _sync=True)
def download_files(self, files: list[str]) -> None:
"""Fetch the files from `files` that are missing on disk, one `hf_hub_download` each.

`snapshot_download(allow_patterns=<thousands of explicit paths>)` spends
many minutes GIL-held and I/O-idle inside `filter_repo_objects`, whose
fnmatch loop is O(repo_files x patterns) — long enough to trip the NCCL
watchdog while sibling ranks wait at the `_sync` broadcast.
`hf_hub_download` targets each file by exact path, skipping the filter
entirely; a thread pool overlaps the network round-trips (the GIL is
released during I/O).
"""
if not files:
return
# `hf_hub_download` issues a network metadata request per file even
# when the file is already in `local_dir`. Calling it for an
# already-complete episode set burns one HF API request per file and
# trips the 3000 req / 5 min rate limit (429). Skip files already on
# disk — when the selected episodes were pre-downloaded this is a
# no-op that makes zero requests.
missing = [f for f in files if not (self.root / f).is_file()]
if not missing:
return
logging.info(
"%s: %d/%d episode files absent on disk, downloading them",
self.repo_id,
len(missing),
len(files),
)

def _fetch(fpath: str) -> None:
hf_hub_download(
repo_id=self.repo_id,
filename=fpath,
repo_type="dataset",
revision=self.revision,
local_dir=self.root,
)

with ThreadPoolExecutor(max_workers=_DOWNLOAD_MAX_WORKERS) as pool:
# list() forces the lazy map so any per-file failure propagates.
list(pool.map(_fetch, missing))

def download_episodes(self, download_videos: bool = True) -> None:
"""Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this
will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present
in 'local_dir', they won't be downloaded again.
dataset will be downloaded. Already-present files in 'local_dir' are not re-downloaded.
"""
# TODO(rcadene, aliberts): implement faster transfer
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
files = None
ignore_patterns = None if download_videos else "videos/"
if self.episodes is not None:
files = self.get_episodes_file_paths()

self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
if not self._episodes_were_specified:
# Whole-dataset download: snapshot_download with no allow_patterns
# has nothing to filter, so there is no filter_repo_objects blowup,
# and it lists the repo tree in O(1) API calls instead of one
# metadata request per file. `self.episodes` has been backfilled to
# the full list by now, so branch on the construction-time flag.
ignore_patterns = None if download_videos else "videos/"
self.pull_from_repo(ignore_patterns=ignore_patterns)
return
# Episode subset: download exactly the needed files directly. Passing
# the explicit per-episode path list to snapshot_download as
# allow_patterns triggers a pathologically slow filter_repo_objects
# scan (see download_files).
files = self.get_episodes_file_paths()
if not download_videos:
files = [f for f in files if not f.startswith("videos/")]
self.download_files(files)

def get_episodes_file_paths(self) -> list[str]:
"""Get file paths for all selected episodes.
Expand All @@ -1576,44 +1642,40 @@ def get_episodes_file_paths(self) -> list[str]:
return fpaths

def load_hf_dataset(self) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
# Derive the parquet glob from the meta data_path template so that
# datasets with a non-default `info["data_path"]` (deeper nesting,
# flat layout, etc.) keep working. Default template is
# "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
# which yields the glob "data/chunk-*/episode_*.parquet". Assumes the
# template uses simple `{name}` / `{name:fmt}` placeholders and no
# literal `{{`/`}}` escapes — true for every in-repo writer.
glob_pattern = re.sub(r"\{[^}]+\}", "*", self.meta.data_path)
paths = sorted(self.root.glob(glob_pattern))
if not paths:
raise FileNotFoundError(f"No parquet files matching {glob_pattern!r} under {self.root}")
features = get_hf_features_from_features(self.meta.features)
# Read parquet directly via pyarrow.dataset and wrap the resulting
# pa.Table in a HF Dataset. Going through `load_dataset("parquet", ...)`
# or `Dataset.from_parquet(...)` both route through ParquetDatasetBuilder,
# which rewrites the parquet bytes into an uncompressed Arrow cache at
# $HF_HOME/datasets/parquet/ — 1-5x the source size (compression-dependent)
# and one cache entry per distinct (paths, filter) combo. Issue #277 has
# the empirical numbers; verified on physical-intelligence/libero.
#
# Trade-off: `to_table(filter=...)` materializes the filtered rows into
# RAM rather than mmapping a disk-backed Arrow cache. RAM cost scales
# with `len(filtered rows) × avg-row-size`; concretely:
# ~350 MB for physical-intelligence/libero with episodes=[0..9],
# ~46 GB for humanoid-everyday-A-overlay with episodes=None (full corpus).
# Narrow `episodes=` picks are fine; an episodes=None load on a multi-GB
# image-heavy repo will OOM on a small dev box — pass a manageable
# subset, or restore a mmap'd Arrow cache via tmp pa.ipc files if RAM
# ever becomes the binding constraint.
#
# The `Dataset(table, info=DatasetInfo(features=features))` constructor
# signature has been stable since datasets 2.x; the project pin is
# `datasets>=2.19.0`, so we're well inside the supported window.
pa_dataset = pa_ds.dataset(list(map(str, paths)), format="parquet")
filter_expr = pa_ds.field("episode_index").isin(self.episodes) if self.episodes is not None else None
table = pa_dataset.to_table(filter=filter_expr)
hf_dataset = Dataset(table, info=DatasetInfo(features=features))
"""hf_dataset contains all the observations, states, actions, rewards, etc.

Loads the per-episode parquet files via `load_dataset("parquet", ...)`,
which builds a memory-mapped Arrow cache under `$HF_HOME/datasets/`. The
cache costs disk (~1-5x the parquet, compression-dependent — see #277)
and nothing prunes it, so a multi-hundred-GB mixture needs that much
extra disk provisioned. In exchange the loaded dataset is genuinely
memory-mapped (resident pages are file-backed and reclaimable), so RAM
stays bounded by the OS page cache rather than the dataset size. That is
essential here: 8 ranks each load the full mixture, and the
multi-hundred-GB video repos would otherwise OOM the node. A hand-rolled
`pa_ds.to_table()` + `Dataset(table)` (or streaming to a self-written
Arrow IPC file + `Dataset.from_file`) was tried and both materialised
into anonymous RAM instead of mmapping — `load_dataset` routes through
HF's ParquetDatasetBuilder, whose Arrow cache `Dataset.from_file` *does*
memory-map correctly.

Schema is inferred from the parquet files themselves; it is intentionally
not validated against `meta/info.json`. So a parquet/`info.json` mismatch
now loads silently — the parquet's own schema wins and `info.json` is no
longer authoritative. A mismatch *between* the parquet files of a single
dataset still fails, but as a `load_dataset` concatenation error rather
than the old explicit feature-cast error.

Files are passed in sorted-episode order so the hf_dataset row layout
stays aligned with `episode_data_index["from"/"to"]` — see the
`sorted(episodes)` note in __init__. `self.episodes` is always a concrete
list here: __init__ backfills it with the full episode list before this
method is ever called.
"""
files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
if not files:
raise FileNotFoundError(f"No parquet files for {self.repo_id} under {self.root}")
hf_dataset = load_dataset("parquet", data_files=files, split="train")
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset

Expand Down
108 changes: 71 additions & 37 deletions src/opentau/datasets/speed_percentiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@
# ``_CONTROL_MODE_WARNED`` pattern in ``lerobot_dataset.py``.
_READONLY_WARNED: set[str] = set()

# Module-level set of unresolved episode task labels we've already warned
# about, so episodes.jsonl / tasks.jsonl drift in a dataset only logs once
# per distinct bad label per process instead of once per episode per rank.
_UNRESOLVED_TASK_WARNED: set[str] = set()


def compute_task_percentiles(
episode_lengths_per_task: dict[int, list[int]],
Expand Down Expand Up @@ -143,13 +148,34 @@ def episode_to_task_index_from_episodes(
``tasks[0]`` — the codebase assumes an N-to-1 episode-to-task
relationship even though the field is structurally a list. Episodes
with an empty / missing ``tasks`` list are skipped.

Episodes whose ``tasks[0]`` is absent from ``task_to_task_index`` are
also skipped (with a deduped warning) rather than raising. This guards
against ``episodes.jsonl`` / ``tasks.jsonl`` drift in dataset metadata
— a `tasks.jsonl` with stale/incomplete entries, or an
``episodes.jsonl`` that stores integer task indices instead of task
strings. Skipped episodes are still trained on; they just fall back to
``SPARSE_TASK_BUCKET`` in the downstream speed-bucket lookup, which
already tolerates a missing ``episode_to_task_index`` entry.
"""
out: dict[int, int] = {}
for ep_idx, ep_info in episodes.items():
tasks = ep_info.get("tasks") or []
if not tasks:
continue
out[ep_idx] = task_to_task_index[tasks[0]]
task_idx = task_to_task_index.get(tasks[0])
if task_idx is None:
key = str(tasks[0])
if key not in _UNRESOLVED_TASK_WARNED:
_UNRESOLVED_TASK_WARNED.add(key)
logging.warning(
"Episode task label %r is not present in tasks.jsonl; episode(s) "
"using it fall back to the sparse speed bucket. This indicates "
"episodes.jsonl / tasks.jsonl drift in the dataset metadata.",
tasks[0],
)
continue
out[ep_idx] = task_idx
return out


Expand Down Expand Up @@ -271,39 +297,47 @@ def load_or_compute_speed_percentiles(
distributed = acc is not None and acc.num_processes > 1
is_main_or_solo = (not distributed) or acc.is_main_process

if path.is_file():
# `episode_to_task_index` already drops episodes with empty tasks,
# so its length is the per-task episode count we want to compare
# against the on-disk sum.
return _read_persisted(path, len(episode_to_task_index), warn=is_main_or_solo)

by_task = _group_lengths_by_task(episode_lengths, episode_to_task_index)
index_to_task = {idx: task for task, idx in task_to_task_index.items()}
percentiles = compute_task_percentiles(by_task)
rows = [
{
"task_index": task_idx,
"task": index_to_task.get(task_idx, ""),
"n_episodes": len(by_task.get(task_idx, [])),
"percentiles": percentiles[task_idx],
}
for task_idx in sorted(percentiles)
]

if is_main_or_solo:
try:
_atomic_write_jsonlines(rows, path)
except (OSError, PermissionError) as e:
root_key = str(root)
if root_key not in _READONLY_WARNED:
_READONLY_WARNED.add(root_key)
logging.warning(
"Could not write speed percentiles to %s (%s); using in-memory "
"values for this run. The compute will repeat on every load until "
"the file can be written.",
path,
e,
)
if distributed:
acc.wait_for_everyone()
return percentiles
# NB: the barrier at the end of this function must run on every code path,
# not just the compute path. Otherwise a rank that arrives *after* rank 0
# has finished writing the file takes the early-return branch (file now
# exists), skips the barrier, and silently desyncs the collective counter
# for every subsequent collective in the mixture-load loop — manifesting
# as a NCCL hang at a much later (and entirely unrelated) sync point.
try:
if path.is_file():
# `episode_to_task_index` already drops episodes with empty tasks,
# so its length is the per-task episode count we want to compare
# against the on-disk sum.
return _read_persisted(path, len(episode_to_task_index), warn=is_main_or_solo)

by_task = _group_lengths_by_task(episode_lengths, episode_to_task_index)
index_to_task = {idx: task for task, idx in task_to_task_index.items()}
percentiles = compute_task_percentiles(by_task)
rows = [
{
"task_index": task_idx,
"task": index_to_task.get(task_idx, ""),
"n_episodes": len(by_task.get(task_idx, [])),
"percentiles": percentiles[task_idx],
}
for task_idx in sorted(percentiles)
]

if is_main_or_solo:
try:
_atomic_write_jsonlines(rows, path)
except (OSError, PermissionError) as e:
root_key = str(root)
if root_key not in _READONLY_WARNED:
_READONLY_WARNED.add(root_key)
logging.warning(
"Could not write speed percentiles to %s (%s); using in-memory "
"values for this run. The compute will repeat on every load until "
"the file can be written.",
path,
e,
)
return percentiles
finally:
if distributed:
acc.wait_for_everyone()
23 changes: 23 additions & 0 deletions tests/datasets/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,29 @@ def test_dataset_no_episodes_loads_all(tmp_path, lerobot_dataset_factory):
_assert_episode_row_alignment(dataset)


def test_download_files_skips_present_files(tmp_path, lerobot_dataset_factory):
"""download_files must not call hf_hub_download for files already on disk.

This is the core of the 429-avoidance fix: a pre-downloaded episode set
should make download_files a no-op with zero Hub requests. Constructing
the dataset already places every selected-episode file on disk, so a
second download_files pass over the same paths must fetch nothing.
"""
dataset = lerobot_dataset_factory(
root=tmp_path / "test",
repo_id=DUMMY_REPO_ID,
total_episodes=10,
total_frames=400,
episodes=[2, 5, 6],
)
files = dataset.get_episodes_file_paths()
assert files, "expected a non-empty file list for the test to be meaningful"
assert all((dataset.root / f).is_file() for f in files), "fixture should pre-place all files"
with patch("opentau.datasets.lerobot_dataset.hf_hub_download") as mock_hf_hub_download:
dataset.download_files(files)
mock_hf_hub_download.assert_not_called()


def test_add_frame_missing_task(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
Expand Down
Loading
Loading