feat(datasets): bucket speed by per-task percentile rank, not duration#302
Conversation
Per-(dataset, task) decile rank of episode length-in-frames replaces the
previous "duration in seconds, rounded to 10s" scheme. The label set
({0, 10, ..., 100}) is unchanged, but the meaning shifts from "this
episode is N seconds long" to "this is a fast/slow example of *this*
task" — a 30s pick-and-place no longer looks the same as a 30s
hand-over.
Per-task percentile boundaries are persisted to
meta/speed_percentiles.jsonl (computed once, skipped if file exists).
Tasks with fewer than 10 distinct episode lengths are flagged sparse
and bucket every episode to 50 (median-equivalent neutral default),
covering both small-N and degenerate all-equal-length cases.
Eval-side EnvMetadataConfig.speed validator loosened from "positive
multiple of 10" to "[0, 100] step 10" so the lowest bucket is
expressible at LIBERO eval time.
shuheng-liu
left a comment
There was a problem hiding this comment.
Review
Nice change — moving from "absolute episode duration in seconds" to "per-task percentile rank" is a clear semantic upgrade for what the policy ends up seeing as the Speed: prefix, and the on-disk {0, 10, ..., 100} label set is preserved so nothing downstream breaks. Tests are comprehensive (boundary walk, sparse fallback, multi-task lists, round-trip persistence) and the libero smoke check is a nice belt-and-braces sanity test. Below are the issues I'd want addressed before merge, roughly in priority order.
Issues
1. Race condition writing meta/speed_percentiles.jsonl under distributed training. LeRobotDataset.__init__ runs on every rank, and load_or_compute_speed_percentiles does path.is_file() → write_jsonlines(...) with no rank guard. write_jsonlines opens the file with "w" (truncate) and writes line-by-line — not atomic. In the typical first-run-on-multi-GPU case, all ranks see the file missing, all compute, all write, and you can get a partially-written or interleaved file; a parallel reader that gets past is_file() while another rank is mid-write will fail to parse.
The codebase already has the idiomatic pattern for this — see lerobot_dataset.py:378-388 and lerobot_dataset.py:1319-1331:
acc = get_proc_accelerator()
if acc is not None and acc.num_processes > 1:
if acc.is_main_process:
# do the write
acc.wait_for_everyone()
else:
# do the writeEither gate the write that way inside load_or_compute_speed_percentiles (cleanest), or do an atomic write_jsonlines → os.replace (write to path.with_suffix(".jsonl.tmp"), rename). Both fix it; the rank-guard version matches the rest of the file.
2. The "file existence is the sole gate" rule has a subtle first-write hazard. Because the percentile file is computed from whatever episodes dict is passed in on first load, anything that filters episodes before init (a debug run with episodes_filter, a small-N CI fixture, a partial download, …) will silently freeze those filtered episodes' percentiles for the dataset root forever. A later full-dataset run reads the stale file and gets wrong buckets, with no warning. Two cheap mitigations:
- Persist
n_episodes_total(sum across rows or a header line) alongside the per-task entries, and recompute when the on-disk total doesn't matchlen(episodes). Cheap, catches the most common foot-gun. - Or, at minimum, log an INFO-level "loaded N task percentiles from " so users have a chance to notice the file is older than they expect.
The "delete the file to recompute" escape hatch is fine as documented behavior, but pairing it with a mismatch warning would save real debugging time.
3. SPEED_BUCKET_SECONDS is now a misnomer. It's still named _SECONDS in envs/configs.py even though the unit is now "percentile-rank step", not "seconds". You updated the comment but the symbol name will mislead the next reader. Suggest SPEED_BUCKET_STEP (and update the one referring file). Drive-by but worth doing while the touched window is still open.
Minor / nits
src/opentau/datasets/lerobot_dataset.py:1421-1427— whentasksis empty, you callbucket_episode_length(self.episode_lengths[ep], None)which always returnsSPARSE_TASK_BUCKET. Equivalent to just assigningSPARSE_TASK_BUCKETdirectly; one less indirection for a reader to trace._episode_lengths_per_taskreadsint(ep_info["length"])while the consumer pre-fill inlerobot_dataset.py:1424-1427readsself.episode_lengths[ep]. They're the same numbers today, but they're separately sourced; worth a one-liner asserting / commenting the invariant, or just funnel both throughself.episode_lengths._READONLY_WARNEDas module-level mutable state is fine in practice, but the test has tomonkeypatch.setattr(sp, "_READONLY_WARNED", set())to reset it — a small smell. Afunctools.lru_cache(maxsize=None)'d helper or a class-levelseton a small dataclass would be cleaner. Not blocking.- The persisted JSONL has no schema-version field; if you ever change the row shape (e.g. add
fps, change percentile points) old files will silently load as the new shape. Adding a single{"version": 1}first row, or aversionkey per row, gives you an upgrade path. Optional. bucket_episode_lengthwith an empty-listpercentiles(whichcompute_task_percentilesnever produces, but a hand-edited file could) returns0silently rather than falling back to the sparse bucket. Defending against this is probably not worth the code, but worth knowing the failure mode.
Test coverage
- 34 cases on the new module is good. The boundary walk + per-task independence + round-trip + read-only fallback + multi-task
tasks[0]cases cover the contract well. - One gap: there's no test for the concurrent-write case (point 1). After fixing, I'd want at least a smoke test that two simultaneous
load_or_compute_speed_percentilescalls on the same root don't corrupt the file — even if just aconcurrent.futurestwo-worker test. - The
TestLiberoSnapshotsmoke test is a nice touch but conditional on a local cache; OK.
CI
Pre-commit ✅, check-checklist ✅, review ✅. CPU tests still running at the time of this review — worth confirming green before flipping out of draft.
Overall direction is right and the implementation is mostly clean — main blocker is the distributed-write race (1), and (2) is a usability landmine that's cheap to defuse.
Generated by Claude Code
Per #302 review: 1. Distributed-write race in load_or_compute_speed_percentiles — rank-gate the write via get_proc_accelerator() (mirroring the pattern at lerobot_dataset.py:1306-1326) and back it with an atomic tmp-file + os.replace. Defends against partial-write corruption both within a single distributed run and across independent processes that share a dataset root. 2. Stale-file warning — sum the on-disk per-task n_episodes and compare with the current load. A mismatch logs a WARNING (the file is still trusted per spec; the warning just cuts debug time when a small-N first run accidentally freezes the distribution). 3. Rename SPEED_BUCKET_SECONDS -> SPEED_BUCKET_STEP — the unit is no longer seconds, it's a percentile-rank index. 4. Drop the dead `bucket_episode_length(..., None)` branch in LeRobotDataset.__init__'s pre-fill loop; collapse to a direct SPARSE_TASK_BUCKET assignment. 5. Funnel both the percentile compute and the per-episode pre-fill through `self.episode_lengths`. New `episode_to_task_index_from_episodes` helper builds the `{ep_idx: task_idx}` mapping once. 6. Concurrent-write smoke test (8 thread workers) verifies the atomic-rename path produces a consistent file.
|
Thanks for the review — pushed 40d16f0 addressing all three primary issues plus the two flagged nits. Reply per item: 1. Distributed-write race — fixed
2. Stale-file first-write hazard — WARNING added; no auto-recompute
I left auto-recompute on mismatch out — per spec "skip if file exists" is the explicit contract, and silently recomputing would surprise anyone who hand-edited the file (e.g. for ablations). The WARNING is the cheap mitigation you suggested. 3.
|
shuheng-liu
left a comment
There was a problem hiding this comment.
Re-review of 40d16f0
Thanks for the quick turnaround. Walked through the fixup commit end-to-end against the previous review:
Verified ✓
- (1) Distributed write race —
load_or_compute_speed_percentilesnow rank-gates the write throughget_proc_accelerator()and synchronizes withwait_for_everyone(). Mirrors the patterns atlerobot_dataset.py:378-388/:1319-1331. Non-main ranks correctly compute the same percentiles in-memory (deterministic) and return them — slightly redundant compute but functionally right and clearer than blocking on a file read. Belt-and-braces atomic_atomic_write_jsonlineson top is a good move for the cross-process case. - (2) Stale-file warning —
_read_persistedsumsn_episodesand warns on mismatch. Documenting "still trusted, delete to recompute" is the right call.row.get("n_episodes", 0)defensive default is fine. Testtest_stale_file_logs_warning_but_still_used+test_no_stale_warning_when_totals_matchpin both directions. - (3)
SPEED_BUCKET_SECONDS→SPEED_BUCKET_STEP— confirmed via grep, fully renamed inenvs/configs.pyconstant + docstring + validator message + field doc. No stragglers anywhere in the tree. - (Nit) Dead
bucket_episode_length(..., None)branch — collapsed to directSPARSE_TASK_BUCKETassignment. Reader can now see the sparse-bucket fallback in one place. - (Nit) Dual-sourcing of episode lengths — both paths now flow through
self.episode_lengthsvia the newepisode_to_task_index_from_episodes+_group_lengths_by_taskhelpers. Clean.
I'm fine with the leave-as-is decisions on _READONLY_WARNED (matches _CONTROL_MODE_WARNED / _SKIP_TIMESTAMP_WARNED) and schema versioning (not paying the toll until there's an evolution).
New issues from this round
A. CI is flaky on this commit — one CPU Tests run passed, the other failed. Look at https://github.com/TensorAuto/OpenTau/actions/runs/25831516459/job/75897290298 (failed) vs https://github.com/TensorAuto/OpenTau/actions/runs/25831515012/job/75897285365 (passed) — same SHA, started 3 s apart, opposite outcomes. Worth pulling the failing logs and confirming this isn't test_concurrent_writes_produce_valid_file itself (see B). At minimum, please get a clean green run before flipping out of draft.
B. _atomic_write_jsonlines has a tmp-path collision under contention.
def _atomic_write_jsonlines(rows: list[dict], path: Path) -> None:
tmp_path = path.with_suffix(path.suffix + ".tmp") # ← same name for every writer
write_jsonlines(rows, tmp_path)
os.replace(tmp_path, path)All concurrent writers share speed_percentiles.jsonl.tmp. Possible interleaving (your own test_concurrent_writes_produce_valid_file is precisely the workload that can hit this):
- T_a calls
write_jsonlines(rows, tmp)→ opens"w"(truncates), starts writing. - T_b enters the same path.
jsonlines.open(tmp, "w")truncates again under T_a. - Writes interleave; tmp ends up with mixed/garbage content.
- T_a's
os.replace(tmp, path)succeeds →pathnow contains the garbage. - T_b's
os.replace(tmp, path)→FileNotFoundErrorbecause step 4 moved it. (Caught by your(OSError, PermissionError)handler — silent, but the on-diskpathis already corrupted.) - Later thread reads
pathvia_read_persisted→load_jsonlinesblows up on the malformed JSONL.
The Python GIL makes this unlikely for a single-row file (most work is held inside one bytecode chunk), but it's not impossible, and it gets more likely as the per-task row count grows. The concurrent test you added (8 workers, 4 threads) is exactly the right shape to expose it — and that's plausibly the source of the failing CPU Tests run in (A).
Fix is one line: make the tmp path unique per writer.
import uuid
tmp_path = path.with_suffix(f"{path.suffix}.{uuid.uuid4().hex}.tmp")(os.getpid() + threading.get_ident() works too, but UUID is simpler and covers the "two processes" case for free.) Then tighten the test to also assert that path parses as well-formed JSONL after the storm (load_jsonlines(path) round-trip; you almost have this already).
C. Stale-file WARNING fires on every rank. _read_persisted is called from __init__ on every rank, and the logging.warning(...) has no rank guard. In an 8-GPU run a stale file produces 8 identical warning lines, which is the kind of thing that conditions everyone to ignore the warning. The local idiom is if not _SKIP_TIMESTAMP_WARNED and (acc is None or acc.is_main_process): (lerobot_dataset.py:1391). Cheap to add. Could just gate the logging.warning block on (acc is None or acc.is_main_process).
Minor
- The function-level docstring of
load_or_compute_speed_percentilessays "other ranks block onAccelerator.wait_for_everyone()until the file appears, then read it" — they actually return their in-memorypercentilesrather than reading the just-written file. Functionally equivalent, but the doc is misleading. One-line fix in the docstring. current_total = sum(len(v) for v in by_task.values())— equivalent tolen(episode_to_task_index), which would skip a defaultdict materialization step. Drive-by.
Tests
- 39 cases — good. Coverage of the new code paths is solid. After fixing (B), bump the concurrent test to also do a
load_jsonlines(tmp_path / SPEED_PERCENTILES_PATH)round-trip after the storm to catch the corruption mode I described above (rather than only checkingtask_index/n_episodes/no leftover.tmp).
To unblock: (B) is the only thing I'd treat as a real blocker — it's reachable from the test the commit itself adds. (A) needs investigation but might just be (B) manifesting. (C) and the docstring fix are tidy-ups.
Generated by Claude Code
Per re-review of #302: (B) _atomic_write_jsonlines used a shared tmp path (`<file>.tmp`), which two concurrent writers could truncate simultaneously before either rename. Switch to a per-writer UUID-suffixed tmp (`<file>.<uuid>.tmp`) so writers can't stomp each other; clean up orphaned tmp files in the finally block. This is the root cause of the failing CPU test on the prior commit — the new test_concurrent_writes_produce_valid_file is exactly the workload that hit it. (C) Stale-file WARNING in _read_persisted now takes an explicit `warn` parameter; load_or_compute_speed_percentiles gates it on `acc.is_main_process` so an N-rank run produces one warning, not N. Mirrors the _SKIP_TIMESTAMP_WARNED idiom at lerobot_dataset.py:1391. Minor fixups: - Fix the load_or_compute_speed_percentiles docstring: non-main ranks return their in-memory copy, they don't re-read the file. - Replace `sum(len(v) for v in by_task.values())` with `len(episode_to_task_index)` — same number, no defaultdict materialization for the read-from-disk path. - Tighten test_concurrent_writes_produce_valid_file: 5 tasks (not 1) so a torn write would surface as a wrong row count or percentile mismatch, and assert the on-disk content equals the in-memory result.
|
Thanks for catching the tmp-path collision — confirmed it's exactly what tanked the CPU Tests run. Pushed a7097f6 with all four items addressed: (B)
|
shuheng-liu
left a comment
There was a problem hiding this comment.
Re-review of a7097f6
All three items from the previous round are addressed cleanly, and CI is fully green on this SHA (both CPU Tests runs ✅, pre-commit ✅, check-checklist ✅, review ✅) — no more flakiness, which independently confirms (B) was the root cause.
Verified ✓
- (B) Tmp-path collision —
tmp_path = path.with_suffix(f"{path.suffix}.{uuid.uuid4().hex}.tmp")gives each writer its own staging file;try/finallywithcontextlib.suppress(OSError)cleans orphans if anything raises. Nice belt-and-braces. The 30× local sequential run + green CI is a convincing repro→fix loop. - (C) Stale-warning rank-spam —
_read_persistednow takes a keyword-onlywarn: boolandload_or_compute_speed_percentilesthreadsis_main_or_soloin. Computed once at the top, reused for both the warn-gate and the write-gate — clean. - Docstring fix — the distributed paragraph now correctly says non-main ranks return their in-memory copy rather than re-reading. Accurate.
current_totalsimplification — replacingsum(len(v) for v in by_task.values())withlen(episode_to_task_index)is exact (sinceepisode_to_task_indexalready drops empty-tasks episodes), and the read-from-disk fast path no longer needs to materialize_group_lengths_by_taskat all. Nice micro-win.
Tightened test is the right shape
test_concurrent_writes_produce_valid_file now exercises 16 jobs through 8 threads across 5 distinct tasks, asserts from_disk == first (round-trip), full row count, and globs for any *.tmp leftover under the new naming. A torn write would now surface as a percentile mismatch in the round-trip equality check rather than silently passing — that's the right defensive shape.
Minor (non-blocking) observations
- The
finally-block cleanup usestmp_path.exists()+tmp_path.unlink()— there's a TOCTOU betweenexists()andunlink()but it's not exploitable here (only this writer knows the UUID). Could simplify towith contextlib.suppress(FileNotFoundError, OSError): tmp_path.unlink()and skip theexists()check, but the current form is also fine. len(episode_to_task_index)vs the on-disksum(n_episodes): if a hand-edited percentile file omits or duplicates a task row, the warning will fire correctly, but the comparison is now strictly "row-count agreement" rather than "per-task agreement". That's a documented relaxation given the "delete to recompute" escape hatch — calling it out, not asking to change it.
LGTM — this looks ready to flip out of draft. Nothing else from me unless you spot something I missed.
Generated by Claude Code
What this does
The
speedoptional key emitted byLeRobotDataset.__getitem__was a global "duration in seconds, rounded to multiples of 10" bucket (introduced in #295):The label is the same number across every task in every dataset — so a 30-second pick-and-place looks the same as a 30-second hand-over even though one is a slow example of a fast task and the other is a fast example of a slow task.
Replace with a per-(dataset, task) decile rank of episode length-in-frames. Per task, compute
[p5, p15, ..., p95]of the lengths and bucket each episode against its own task's distribution:length < p5→0(fastest)p_X <= length < p_Y→(X+5) * 10(tie at boundary lands in upper bucket vianp.searchsorted(side='right'))length >= p95→100(slowest)The label set
{0, 10, 20, ..., 100}is unchanged, so:_emit_optional_keyscontract (long tensor, multiple of 10) still holds,"Speed: <int>"is unchanged,EnvMetadataConfig.speedkeeps the same range (with one validator tweak below).The semantic shift: the policy now sees "this is a fast example of this task" rather than "this episode is N seconds long".
Persistence
Per-task percentiles are computed once per dataset and persisted to
meta/speed_percentiles.jsonlnext tometa/episodes.jsonl/meta/tasks.jsonl:{"task_index": 0, "task": "pick up the red block", "n_episodes": 47, "percentiles": [120.0, 145.0, ..., 340.0]} {"task_index": 3, "task": "open the drawer", "n_episodes": 2, "percentiles": null}Existence of the file is the sole gate — staleness is accepted by design (delete the file to force a recompute). On read-only roots (HF snapshot caches) the dict is returned in-memory and a one-time WARNING is logged.
Edge cases
percentiles: nullin the file → bucket50(median-equivalent neutral default) at runtime.taskslists → silently usetasks[0]per the codebase's standing N-to-1 episode→task assumption.EnvMetadataConfig.speed = 0— the existing validator rejected non-positive values; loosened to[0, 100]step 10 so the lowest bucket is expressible at LIBERO eval time.WeightedDatasetMixture— no mixture-level changes needed; each constituent dataset persists its own percentile file (a task in dataset A has its own length distribution).Files changed
src/opentau/datasets/speed_percentiles.py—compute_task_percentiles,bucket_episode_length,load_or_compute_speed_percentiles, plus the constants (SPEED_PERCENTILES,SPEED_BUCKET_LABELS,MIN_EPISODES_FOR_PERCENTILES = 10,SPARSE_TASK_BUCKET = 50).src/opentau/datasets/lerobot_dataset.py— drop the oldspeed_duration_bucket_shelper;LeRobotDataset.__init__callsload_or_compute_speed_percentilesand pre-fillsself.speed_raw_by_episode;__getitem__is now a single dict lookup.src/opentau/envs/configs.py— loosenEnvMetadataConfig.speedvalidator to[0, 100]step 10; update thespeed:andSPEED_BUCKET_SECONDSdocstrings to point at the new module.docs/source/concepts.rst— rewrite thespeed/speed_is_padblock.tests/datasets/test_speed_percentiles.py(34 cases); remove the oldTestSpeedDurationBucket; tighten thetest_attach_metadataassertion to also capspeed <= 100; extendtest_configsparametrize lists; replace hardcodedtorch.tensor([500])withtorch.tensor([50])across pi07 policy tests.How it was tested
pre-commit run --files <all changed files>— all hooks pass (ruff lint+format, pyupgrade, bandit, typos, License Header, gitleaks).pytest tests/datasets/test_speed_percentiles.py -x— 34 passed.pytest tests/datasets/ tests/envs/ tests/policies/test_pi07_low_level.py tests/policies/test_pi07_high_level_planner.py tests/policies/test_pi07_paligemma_low_level_planner.py tests/policies/test_pi07_paligemma_high_level_planner.py tests/scripts/test_attach_metadata.py -m "not gpu" --ignore=tests/envs/test_factory.py -n auto— 542 passed, 7 skipped (GPU fixtures), 0 failed. The two skippedtest_factory.pytests fail withModuleNotFoundError: No module named 'libero'because libero is an optional extra; pre-existing, unrelated.physical-intelligence/libero(1693 episodes, 40 tasks, all well populated): every task gets 10 ascending percentile boundaries with meaningful task-relative ranges (e.g. task 2 p95=387 vs task 4 p95=304 — exactly what the per-task framing should capture). Bucket distributions are roughly uniform across the 11 buckets per task (~3-5 episodes per bucket out of ~40).Policy code consumes
batch["speed"]only by stringifying it into the metadata prompt, so it's unit-agnostic — no policy-side changes needed.How to checkout & try? (for the reviewer)
Verify per-task bucketing on a real LeRobot dataset (loads
physical-intelligence/liberofrom your HF cache):Then run a smoke training step to confirm the
__getitem__path and the persisted file:opentau-train --accelerate-config configs/examples/accelerate_ddp_config.yaml \ --config_path=configs/examples/pi05_training_config.json ls <dataset-root>/meta/speed_percentiles.jsonl # written by the runChecklist
Note: Before submitting this PR, please read the contributor guideline.