Skip to content

feat(datasets): bucket speed by per-task percentile rank, not duration#302

Merged
shuheng-liu merged 3 commits into
mainfrom
claude/funny-bell-07f90c
May 14, 2026
Merged

feat(datasets): bucket speed by per-task percentile rank, not duration#302
shuheng-liu merged 3 commits into
mainfrom
claude/funny-bell-07f90c

Conversation

@shuheng-liu
Copy link
Copy Markdown
Member

What this does

The speed optional key emitted by LeRobotDataset.__getitem__ was a global "duration in seconds, rounded to multiples of 10" bucket (introduced in #295):

duration_s = self.episode_lengths[ep_idx] / self.fps
item["speed_raw"] = int(round(duration_s / 10) * 10)

The label is the same number across every task in every dataset — so a 30-second pick-and-place looks the same as a 30-second hand-over even though one is a slow example of a fast task and the other is a fast example of a slow task.

Replace with a per-(dataset, task) decile rank of episode length-in-frames. Per task, compute [p5, p15, ..., p95] of the lengths and bucket each episode against its own task's distribution:

  • length < p50 (fastest)
  • p_X <= length < p_Y(X+5) * 10 (tie at boundary lands in upper bucket via np.searchsorted(side='right'))
  • length >= p95100 (slowest)

The label set {0, 10, 20, ..., 100} is unchanged, so:

  • the _emit_optional_keys contract (long tensor, multiple of 10) still holds,
  • the pi07 policy prefix string "Speed: <int>" is unchanged,
  • EnvMetadataConfig.speed keeps the same range (with one validator tweak below).

The semantic shift: the policy now sees "this is a fast example of this task" rather than "this episode is N seconds long".

Persistence

Per-task percentiles are computed once per dataset and persisted to meta/speed_percentiles.jsonl next to meta/episodes.jsonl / meta/tasks.jsonl:

{"task_index": 0, "task": "pick up the red block", "n_episodes": 47, "percentiles": [120.0, 145.0, ..., 340.0]}
{"task_index": 3, "task": "open the drawer", "n_episodes": 2, "percentiles": null}

Existence of the file is the sole gate — staleness is accepted by design (delete the file to force a recompute). On read-only roots (HF snapshot caches) the dict is returned in-memory and a one-time WARNING is logged.

Edge cases

  • Tasks with < 10 distinct episode lengths (covers 0/1/2-9-episode tasks and the all-equal-length degenerate case) → percentiles: null in the file → bucket 50 (median-equivalent neutral default) at runtime.
  • Episodes with multi-task tasks lists → silently use tasks[0] per the codebase's standing N-to-1 episode→task assumption.
  • Eval-time EnvMetadataConfig.speed = 0 — the existing validator rejected non-positive values; loosened to [0, 100] step 10 so the lowest bucket is expressible at LIBERO eval time.
  • WeightedDatasetMixture — no mixture-level changes needed; each constituent dataset persists its own percentile file (a task in dataset A has its own length distribution).

Files changed

  • New module src/opentau/datasets/speed_percentiles.pycompute_task_percentiles, bucket_episode_length, load_or_compute_speed_percentiles, plus the constants (SPEED_PERCENTILES, SPEED_BUCKET_LABELS, MIN_EPISODES_FOR_PERCENTILES = 10, SPARSE_TASK_BUCKET = 50).
  • src/opentau/datasets/lerobot_dataset.py — drop the old speed_duration_bucket_s helper; LeRobotDataset.__init__ calls load_or_compute_speed_percentiles and pre-fills self.speed_raw_by_episode; __getitem__ is now a single dict lookup.
  • src/opentau/envs/configs.py — loosen EnvMetadataConfig.speed validator to [0, 100] step 10; update the speed: and SPEED_BUCKET_SECONDS docstrings to point at the new module.
  • docs/source/concepts.rst — rewrite the speed / speed_is_pad block.
  • Tests — new tests/datasets/test_speed_percentiles.py (34 cases); remove the old TestSpeedDurationBucket; tighten the test_attach_metadata assertion to also cap speed <= 100; extend test_configs parametrize lists; replace hardcoded torch.tensor([500]) with torch.tensor([50]) across pi07 policy tests.

How it was tested

  • pre-commit run --files <all changed files> — all hooks pass (ruff lint+format, pyupgrade, bandit, typos, License Header, gitleaks).
  • pytest tests/datasets/test_speed_percentiles.py -x — 34 passed.
  • pytest tests/datasets/ tests/envs/ tests/policies/test_pi07_low_level.py tests/policies/test_pi07_high_level_planner.py tests/policies/test_pi07_paligemma_low_level_planner.py tests/policies/test_pi07_paligemma_high_level_planner.py tests/scripts/test_attach_metadata.py -m "not gpu" --ignore=tests/envs/test_factory.py -n auto — 542 passed, 7 skipped (GPU fixtures), 0 failed. The two skipped test_factory.py tests fail with ModuleNotFoundError: No module named 'libero' because libero is an optional extra; pre-existing, unrelated.
  • Real-data check on physical-intelligence/libero (1693 episodes, 40 tasks, all well populated): every task gets 10 ascending percentile boundaries with meaningful task-relative ranges (e.g. task 2 p95=387 vs task 4 p95=304 — exactly what the per-task framing should capture). Bucket distributions are roughly uniform across the 11 buckets per task (~3-5 episodes per bucket out of ~40).

Policy code consumes batch["speed"] only by stringifying it into the metadata prompt, so it's unit-agnostic — no policy-side changes needed.

How to checkout & try? (for the reviewer)

gh pr checkout <PR-number>
uv sync --extra dev
pytest tests/datasets/test_speed_percentiles.py -x
pytest tests/datasets/ tests/envs/ tests/policies/test_pi07_low_level.py \
       tests/policies/test_pi07_high_level_planner.py \
       tests/policies/test_pi07_paligemma_low_level_planner.py \
       tests/policies/test_pi07_paligemma_high_level_planner.py \
       tests/scripts/test_attach_metadata.py \
       -m "not gpu" --ignore=tests/envs/test_factory.py -n auto

Verify per-task bucketing on a real LeRobot dataset (loads physical-intelligence/libero from your HF cache):

python -c "
import json
from collections import defaultdict
from opentau.datasets.speed_percentiles import compute_task_percentiles, bucket_episode_length

ep_path = '<HF_CACHE>/lerobot/physical-intelligence/libero/meta/episodes.jsonl'
by_task = defaultdict(list)
with open(ep_path) as f:
    for line in f:
        d = json.loads(line)
        by_task[d['tasks'][0]].append(d['length'])
as_idx = {i: by_task[t] for i, t in enumerate(by_task)}
out = compute_task_percentiles(as_idx)
for idx in list(as_idx)[:5]:
    pcts = out[idx]
    bins = [bucket_episode_length(L, pcts) for L in as_idx[idx]]
    print(f'task {idx}: n={len(as_idx[idx])}, p5..p95={[round(p,1) for p in pcts]}')
    print(f'         bucket histogram = {dict(sorted({b: bins.count(b) for b in set(bins)}.items()))}')
"

Then run a smoke training step to confirm the __getitem__ path and the persisted file:

opentau-train --accelerate-config configs/examples/accelerate_ddp_config.yaml \
              --config_path=configs/examples/pi05_training_config.json
ls <dataset-root>/meta/speed_percentiles.jsonl  # written by the run

Checklist

  • I have added Google-style docstrings to important functions and ensured function parameters are typed.
  • My PR includes policy-related changes.
    • If the above is checked: I have run the GPU pytests (pytest -m "gpu") and regression tests.

Note: Before submitting this PR, please read the contributor guideline.

Per-(dataset, task) decile rank of episode length-in-frames replaces the
previous "duration in seconds, rounded to 10s" scheme. The label set
({0, 10, ..., 100}) is unchanged, but the meaning shifts from "this
episode is N seconds long" to "this is a fast/slow example of *this*
task" — a 30s pick-and-place no longer looks the same as a 30s
hand-over.

Per-task percentile boundaries are persisted to
meta/speed_percentiles.jsonl (computed once, skipped if file exists).
Tasks with fewer than 10 distinct episode lengths are flagged sparse
and bucket every episode to 50 (median-equivalent neutral default),
covering both small-N and degenerate all-equal-length cases.

Eval-side EnvMetadataConfig.speed validator loosened from "positive
multiple of 10" to "[0, 100] step 10" so the lowest bucket is
expressible at LIBERO eval time.
@shuheng-liu shuheng-liu added the feature New feature or request label May 13, 2026
@shuheng-liu shuheng-liu self-assigned this May 13, 2026
Copy link
Copy Markdown
Member Author

@shuheng-liu shuheng-liu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Review

Nice change — moving from "absolute episode duration in seconds" to "per-task percentile rank" is a clear semantic upgrade for what the policy ends up seeing as the Speed: prefix, and the on-disk {0, 10, ..., 100} label set is preserved so nothing downstream breaks. Tests are comprehensive (boundary walk, sparse fallback, multi-task lists, round-trip persistence) and the libero smoke check is a nice belt-and-braces sanity test. Below are the issues I'd want addressed before merge, roughly in priority order.

Issues

1. Race condition writing meta/speed_percentiles.jsonl under distributed training. LeRobotDataset.__init__ runs on every rank, and load_or_compute_speed_percentiles does path.is_file()write_jsonlines(...) with no rank guard. write_jsonlines opens the file with "w" (truncate) and writes line-by-line — not atomic. In the typical first-run-on-multi-GPU case, all ranks see the file missing, all compute, all write, and you can get a partially-written or interleaved file; a parallel reader that gets past is_file() while another rank is mid-write will fail to parse.

The codebase already has the idiomatic pattern for this — see lerobot_dataset.py:378-388 and lerobot_dataset.py:1319-1331:

acc = get_proc_accelerator()
if acc is not None and acc.num_processes > 1:
    if acc.is_main_process:
        # do the write
    acc.wait_for_everyone()
else:
    # do the write

Either gate the write that way inside load_or_compute_speed_percentiles (cleanest), or do an atomic write_jsonlinesos.replace (write to path.with_suffix(".jsonl.tmp"), rename). Both fix it; the rank-guard version matches the rest of the file.

2. The "file existence is the sole gate" rule has a subtle first-write hazard. Because the percentile file is computed from whatever episodes dict is passed in on first load, anything that filters episodes before init (a debug run with episodes_filter, a small-N CI fixture, a partial download, …) will silently freeze those filtered episodes' percentiles for the dataset root forever. A later full-dataset run reads the stale file and gets wrong buckets, with no warning. Two cheap mitigations:

  • Persist n_episodes_total (sum across rows or a header line) alongside the per-task entries, and recompute when the on-disk total doesn't match len(episodes). Cheap, catches the most common foot-gun.
  • Or, at minimum, log an INFO-level "loaded N task percentiles from " so users have a chance to notice the file is older than they expect.

The "delete the file to recompute" escape hatch is fine as documented behavior, but pairing it with a mismatch warning would save real debugging time.

3. SPEED_BUCKET_SECONDS is now a misnomer. It's still named _SECONDS in envs/configs.py even though the unit is now "percentile-rank step", not "seconds". You updated the comment but the symbol name will mislead the next reader. Suggest SPEED_BUCKET_STEP (and update the one referring file). Drive-by but worth doing while the touched window is still open.

Minor / nits

  • src/opentau/datasets/lerobot_dataset.py:1421-1427 — when tasks is empty, you call bucket_episode_length(self.episode_lengths[ep], None) which always returns SPARSE_TASK_BUCKET. Equivalent to just assigning SPARSE_TASK_BUCKET directly; one less indirection for a reader to trace.
  • _episode_lengths_per_task reads int(ep_info["length"]) while the consumer pre-fill in lerobot_dataset.py:1424-1427 reads self.episode_lengths[ep]. They're the same numbers today, but they're separately sourced; worth a one-liner asserting / commenting the invariant, or just funnel both through self.episode_lengths.
  • _READONLY_WARNED as module-level mutable state is fine in practice, but the test has to monkeypatch.setattr(sp, "_READONLY_WARNED", set()) to reset it — a small smell. A functools.lru_cache(maxsize=None)'d helper or a class-level set on a small dataclass would be cleaner. Not blocking.
  • The persisted JSONL has no schema-version field; if you ever change the row shape (e.g. add fps, change percentile points) old files will silently load as the new shape. Adding a single {"version": 1} first row, or a version key per row, gives you an upgrade path. Optional.
  • bucket_episode_length with an empty-list percentiles (which compute_task_percentiles never produces, but a hand-edited file could) returns 0 silently rather than falling back to the sparse bucket. Defending against this is probably not worth the code, but worth knowing the failure mode.

Test coverage

  • 34 cases on the new module is good. The boundary walk + per-task independence + round-trip + read-only fallback + multi-task tasks[0] cases cover the contract well.
  • One gap: there's no test for the concurrent-write case (point 1). After fixing, I'd want at least a smoke test that two simultaneous load_or_compute_speed_percentiles calls on the same root don't corrupt the file — even if just a concurrent.futures two-worker test.
  • The TestLiberoSnapshot smoke test is a nice touch but conditional on a local cache; OK.

CI

Pre-commit ✅, check-checklist ✅, review ✅. CPU tests still running at the time of this review — worth confirming green before flipping out of draft.


Overall direction is right and the implementation is mostly clean — main blocker is the distributed-write race (1), and (2) is a usability landmine that's cheap to defuse.


Generated by Claude Code

Per #302 review:

1. Distributed-write race in load_or_compute_speed_percentiles —
   rank-gate the write via get_proc_accelerator() (mirroring the
   pattern at lerobot_dataset.py:1306-1326) and back it with an
   atomic tmp-file + os.replace. Defends against partial-write
   corruption both within a single distributed run and across
   independent processes that share a dataset root.

2. Stale-file warning — sum the on-disk per-task n_episodes and
   compare with the current load. A mismatch logs a WARNING (the
   file is still trusted per spec; the warning just cuts debug
   time when a small-N first run accidentally freezes the
   distribution).

3. Rename SPEED_BUCKET_SECONDS -> SPEED_BUCKET_STEP — the unit is
   no longer seconds, it's a percentile-rank index.

4. Drop the dead `bucket_episode_length(..., None)` branch in
   LeRobotDataset.__init__'s pre-fill loop; collapse to a direct
   SPARSE_TASK_BUCKET assignment.

5. Funnel both the percentile compute and the per-episode
   pre-fill through `self.episode_lengths`. New
   `episode_to_task_index_from_episodes` helper builds the
   `{ep_idx: task_idx}` mapping once.

6. Concurrent-write smoke test (8 thread workers) verifies the
   atomic-rename path produces a consistent file.
@shuheng-liu
Copy link
Copy Markdown
Member Author

Thanks for the review — pushed 40d16f0 addressing all three primary issues plus the two flagged nits. Reply per item:

1. Distributed-write race — fixed

load_or_compute_speed_percentiles now rank-gates the write via get_proc_accelerator() (mirrors the pattern at lerobot_dataset.py:1306-1326), and on top of that does an atomic tmp-file + os.replace in _atomic_write_jsonlines. Belt-and-braces: rank-gating handles the in-run case, atomic-rename handles the independent-process case (two concurrent training jobs sharing a dataset root). The new test_concurrent_writes_produce_valid_file exercises 8 thread workers against the same tmp_path and asserts the on-disk file is parseable, complete, has no leftover .tmp, and all workers agree.

2. Stale-file first-write hazard — WARNING added; no auto-recompute

_read_persisted now sums the on-disk per-task n_episodes and compares to the current load's total. Mismatch logs a WARNING pointing at the file and suggesting deletion to recompute. New tests test_stale_file_logs_warning_but_still_used and test_no_stale_warning_when_totals_match pin the behavior.

I left auto-recompute on mismatch out — per spec "skip if file exists" is the explicit contract, and silently recomputing would surprise anyone who hand-edited the file (e.g. for ablations). The WARNING is the cheap mitigation you suggested.

3. SPEED_BUCKET_SECONDS rename — done

Renamed to SPEED_BUCKET_STEP in envs/configs.py (the constant, docstring, validator error message, and the field's speed: docstring reference). Verified grep -rn 'SPEED_BUCKET_SECONDS' returns empty.

Nits — addressed

  • Dead bucket_episode_length(..., None) branch: collapsed to a direct SPARSE_TASK_BUCKET assignment in the pre-fill loop; the loop is now driven by self.episode_to_task_index.get(ep).
  • Dual sourcing of episode lengths: new load_or_compute_speed_percentiles signature takes episode_lengths and episode_to_task_index explicitly (built once via episode_to_task_index_from_episodes), so the percentile compute and the per-episode pre-fill go through the same self.episode_lengths dict. No more int(ep_info["length"]) vs. self.episode_lengths[ep] divergence.

Nits — left as-is (called out for transparency)

  • _READONLY_WARNED module-level set: kept. Matches the existing _CONTROL_MODE_WARNED and _SKIP_TIMESTAMP_WARNED patterns at lerobot_dataset.py:219, 227; switching this one to functools.lru_cache or a dataclass would create local-vs-global stylistic drift. The test still uses monkeypatch.setattr to reset it, which is fine for this small case. Happy to revisit if you'd rather unify the three at once.
  • Schema version field: skipped. The cost of adding "version": 1 per row now is paying the format-evolution toll without an evolution in sight — when we actually change the row shape we can pay it then with an explicit migration path. Reviewer marked this "Optional"; I'd rather add it when there's a concrete v2.
  • Empty-percentiles-list silently returns 0: confirmed as a known failure mode for hand-edited files only. Not defending against it.

Test surface

  • tests/datasets/test_speed_percentiles.py: 39 cases (was 34 — added 5: episode→task index helper, multi-task/empty-tasks variants, stale-file warning, no-warning-on-match, concurrent writes).
  • pytest tests/datasets/ tests/envs/ tests/policies/test_pi07_low_level.py tests/policies/test_pi07_high_level_planner.py tests/policies/test_pi07_paligemma_low_level_planner.py tests/policies/test_pi07_paligemma_high_level_planner.py tests/scripts/test_attach_metadata.py -m "not gpu" --ignore=tests/envs/test_factory.py -n auto — 547 passed, 7 skipped, 0 failed.
  • Pre-commit clean.

Copy link
Copy Markdown
Member Author

@shuheng-liu shuheng-liu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Re-review of 40d16f0

Thanks for the quick turnaround. Walked through the fixup commit end-to-end against the previous review:

Verified ✓

  • (1) Distributed write raceload_or_compute_speed_percentiles now rank-gates the write through get_proc_accelerator() and synchronizes with wait_for_everyone(). Mirrors the patterns at lerobot_dataset.py:378-388 / :1319-1331. Non-main ranks correctly compute the same percentiles in-memory (deterministic) and return them — slightly redundant compute but functionally right and clearer than blocking on a file read. Belt-and-braces atomic _atomic_write_jsonlines on top is a good move for the cross-process case.
  • (2) Stale-file warning_read_persisted sums n_episodes and warns on mismatch. Documenting "still trusted, delete to recompute" is the right call. row.get("n_episodes", 0) defensive default is fine. Test test_stale_file_logs_warning_but_still_used + test_no_stale_warning_when_totals_match pin both directions.
  • (3) SPEED_BUCKET_SECONDSSPEED_BUCKET_STEP — confirmed via grep, fully renamed in envs/configs.py constant + docstring + validator message + field doc. No stragglers anywhere in the tree.
  • (Nit) Dead bucket_episode_length(..., None) branch — collapsed to direct SPARSE_TASK_BUCKET assignment. Reader can now see the sparse-bucket fallback in one place.
  • (Nit) Dual-sourcing of episode lengths — both paths now flow through self.episode_lengths via the new episode_to_task_index_from_episodes + _group_lengths_by_task helpers. Clean.

I'm fine with the leave-as-is decisions on _READONLY_WARNED (matches _CONTROL_MODE_WARNED / _SKIP_TIMESTAMP_WARNED) and schema versioning (not paying the toll until there's an evolution).

New issues from this round

A. CI is flaky on this commit — one CPU Tests run passed, the other failed. Look at https://github.com/TensorAuto/OpenTau/actions/runs/25831516459/job/75897290298 (failed) vs https://github.com/TensorAuto/OpenTau/actions/runs/25831515012/job/75897285365 (passed) — same SHA, started 3 s apart, opposite outcomes. Worth pulling the failing logs and confirming this isn't test_concurrent_writes_produce_valid_file itself (see B). At minimum, please get a clean green run before flipping out of draft.

B. _atomic_write_jsonlines has a tmp-path collision under contention.

def _atomic_write_jsonlines(rows: list[dict], path: Path) -> None:
    tmp_path = path.with_suffix(path.suffix + ".tmp")   # ← same name for every writer
    write_jsonlines(rows, tmp_path)
    os.replace(tmp_path, path)

All concurrent writers share speed_percentiles.jsonl.tmp. Possible interleaving (your own test_concurrent_writes_produce_valid_file is precisely the workload that can hit this):

  1. T_a calls write_jsonlines(rows, tmp) → opens "w" (truncates), starts writing.
  2. T_b enters the same path. jsonlines.open(tmp, "w") truncates again under T_a.
  3. Writes interleave; tmp ends up with mixed/garbage content.
  4. T_a's os.replace(tmp, path) succeeds → path now contains the garbage.
  5. T_b's os.replace(tmp, path)FileNotFoundError because step 4 moved it. (Caught by your (OSError, PermissionError) handler — silent, but the on-disk path is already corrupted.)
  6. Later thread reads path via _read_persistedload_jsonlines blows up on the malformed JSONL.

The Python GIL makes this unlikely for a single-row file (most work is held inside one bytecode chunk), but it's not impossible, and it gets more likely as the per-task row count grows. The concurrent test you added (8 workers, 4 threads) is exactly the right shape to expose it — and that's plausibly the source of the failing CPU Tests run in (A).

Fix is one line: make the tmp path unique per writer.

import uuid
tmp_path = path.with_suffix(f"{path.suffix}.{uuid.uuid4().hex}.tmp")

(os.getpid() + threading.get_ident() works too, but UUID is simpler and covers the "two processes" case for free.) Then tighten the test to also assert that path parses as well-formed JSONL after the storm (load_jsonlines(path) round-trip; you almost have this already).

C. Stale-file WARNING fires on every rank. _read_persisted is called from __init__ on every rank, and the logging.warning(...) has no rank guard. In an 8-GPU run a stale file produces 8 identical warning lines, which is the kind of thing that conditions everyone to ignore the warning. The local idiom is if not _SKIP_TIMESTAMP_WARNED and (acc is None or acc.is_main_process): (lerobot_dataset.py:1391). Cheap to add. Could just gate the logging.warning block on (acc is None or acc.is_main_process).

Minor

  • The function-level docstring of load_or_compute_speed_percentiles says "other ranks block on Accelerator.wait_for_everyone() until the file appears, then read it" — they actually return their in-memory percentiles rather than reading the just-written file. Functionally equivalent, but the doc is misleading. One-line fix in the docstring.
  • current_total = sum(len(v) for v in by_task.values()) — equivalent to len(episode_to_task_index), which would skip a defaultdict materialization step. Drive-by.

Tests

  • 39 cases — good. Coverage of the new code paths is solid. After fixing (B), bump the concurrent test to also do a load_jsonlines(tmp_path / SPEED_PERCENTILES_PATH) round-trip after the storm to catch the corruption mode I described above (rather than only checking task_index/n_episodes/no leftover .tmp).

To unblock: (B) is the only thing I'd treat as a real blocker — it's reachable from the test the commit itself adds. (A) needs investigation but might just be (B) manifesting. (C) and the docstring fix are tidy-ups.


Generated by Claude Code

Per re-review of #302:

(B) _atomic_write_jsonlines used a shared tmp path (`<file>.tmp`),
which two concurrent writers could truncate simultaneously before
either rename. Switch to a per-writer UUID-suffixed tmp
(`<file>.<uuid>.tmp`) so writers can't stomp each other; clean up
orphaned tmp files in the finally block. This is the root cause
of the failing CPU test on the prior commit — the new
test_concurrent_writes_produce_valid_file is exactly the workload
that hit it.

(C) Stale-file WARNING in _read_persisted now takes an explicit
`warn` parameter; load_or_compute_speed_percentiles gates it on
`acc.is_main_process` so an N-rank run produces one warning, not N.
Mirrors the _SKIP_TIMESTAMP_WARNED idiom at lerobot_dataset.py:1391.

Minor fixups:
- Fix the load_or_compute_speed_percentiles docstring: non-main
  ranks return their in-memory copy, they don't re-read the file.
- Replace `sum(len(v) for v in by_task.values())` with
  `len(episode_to_task_index)` — same number, no defaultdict
  materialization for the read-from-disk path.
- Tighten test_concurrent_writes_produce_valid_file: 5 tasks (not
  1) so a torn write would surface as a wrong row count or
  percentile mismatch, and assert the on-disk content equals the
  in-memory result.
@shuheng-liu
Copy link
Copy Markdown
Member Author

Thanks for catching the tmp-path collision — confirmed it's exactly what tanked the CPU Tests run. Pushed a7097f6 with all four items addressed:

(B) _atomic_write_jsonlines tmp-path collision — fixed (root cause of A)

Switched to a per-writer UUID-suffixed tmp path: path.with_suffix(f"{path.suffix}.{uuid.uuid4().hex}.tmp"). Now each concurrent writer has its own staging file; nothing to truncate-stomp. Wrapped in try/finally to unlink orphaned tmps if os.replace fails. Validated by running test_concurrent_writes_produce_valid_file 30× in a row — all pass; the original failure mode was 100% reproducible against the old shared-tmp implementation.

I also tightened the test per your suggestion: 5 distinct tasks (not 1) so a torn write would surface as a wrong row count or a percentile mismatch, plus an explicit equality check between the in-memory result and load_jsonlines(path) after the storm. Also expanded the leftover-tmp glob to match the new UUID-suffixed pattern.

(C) Stale-file WARNING fires per rank — fixed

_read_persisted now takes an explicit warn: bool; load_or_compute_speed_percentiles gates it on acc.is_main_process. Eight-rank run → one WARNING line. Mirrors the _SKIP_TIMESTAMP_WARNED idiom at lerobot_dataset.py:1391. The read-only WARNING in the compute branch was already main-only by virtue of being inside if is_main_or_solo:.

Minor fixups — done

  • Docstring — rewrote the distributed paragraph: "every rank computes the percentiles in-memory (the result is deterministic from the inputs), but only the main process writes the file; ... Each rank returns its own in-memory copy rather than re-reading the just-written file." No more "then read it" misleading.
  • current_total simplification — replaced sum(len(v) for v in by_task.values()) with len(episode_to_task_index). The read-from-disk path now skips the _group_lengths_by_task defaultdict materialization entirely.

CI status

  • test_concurrent_writes_produce_valid_file × 30 sequential runs locally → 30 pass.
  • Full CPU surface: pytest tests/datasets/ tests/envs/ tests/policies/test_pi07_low_level.py tests/policies/test_pi07_high_level_planner.py tests/policies/test_pi07_paligemma_low_level_planner.py tests/policies/test_pi07_paligemma_high_level_planner.py tests/scripts/test_attach_metadata.py -m "not gpu" --ignore=tests/envs/test_factory.py -n auto → 547 passed, 7 skipped, 0 failed.
  • Pre-commit clean.

Will watch the CPU Tests check on this commit before flipping out of draft.

Copy link
Copy Markdown
Member Author

@shuheng-liu shuheng-liu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Re-review of a7097f6

All three items from the previous round are addressed cleanly, and CI is fully green on this SHA (both CPU Tests runs ✅, pre-commit ✅, check-checklist ✅, review ✅) — no more flakiness, which independently confirms (B) was the root cause.

Verified ✓

  • (B) Tmp-path collisiontmp_path = path.with_suffix(f"{path.suffix}.{uuid.uuid4().hex}.tmp") gives each writer its own staging file; try/finally with contextlib.suppress(OSError) cleans orphans if anything raises. Nice belt-and-braces. The 30× local sequential run + green CI is a convincing repro→fix loop.
  • (C) Stale-warning rank-spam_read_persisted now takes a keyword-only warn: bool and load_or_compute_speed_percentiles threads is_main_or_solo in. Computed once at the top, reused for both the warn-gate and the write-gate — clean.
  • Docstring fix — the distributed paragraph now correctly says non-main ranks return their in-memory copy rather than re-reading. Accurate.
  • current_total simplification — replacing sum(len(v) for v in by_task.values()) with len(episode_to_task_index) is exact (since episode_to_task_index already drops empty-tasks episodes), and the read-from-disk fast path no longer needs to materialize _group_lengths_by_task at all. Nice micro-win.

Tightened test is the right shape

test_concurrent_writes_produce_valid_file now exercises 16 jobs through 8 threads across 5 distinct tasks, asserts from_disk == first (round-trip), full row count, and globs for any *.tmp leftover under the new naming. A torn write would now surface as a percentile mismatch in the round-trip equality check rather than silently passing — that's the right defensive shape.

Minor (non-blocking) observations

  • The finally-block cleanup uses tmp_path.exists() + tmp_path.unlink() — there's a TOCTOU between exists() and unlink() but it's not exploitable here (only this writer knows the UUID). Could simplify to with contextlib.suppress(FileNotFoundError, OSError): tmp_path.unlink() and skip the exists() check, but the current form is also fine.
  • len(episode_to_task_index) vs the on-disk sum(n_episodes): if a hand-edited percentile file omits or duplicates a task row, the warning will fire correctly, but the comparison is now strictly "row-count agreement" rather than "per-task agreement". That's a documented relaxation given the "delete to recompute" escape hatch — calling it out, not asking to change it.

LGTM — this looks ready to flip out of draft. Nothing else from me unless you spot something I missed.


Generated by Claude Code

@shuheng-liu shuheng-liu marked this pull request as ready for review May 14, 2026 01:56
@shuheng-liu shuheng-liu merged commit 758d473 into main May 14, 2026
7 checks passed
@shuheng-liu shuheng-liu deleted the claude/funny-bell-07f90c branch May 14, 2026 01:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

feature New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant