diff --git a/src/opentau/datasets/lerobot_dataset.py b/src/opentau/datasets/lerobot_dataset.py index d4453d52..a117fe2d 100644 --- a/src/opentau/datasets/lerobot_dataset.py +++ b/src/opentau/datasets/lerobot_dataset.py @@ -216,6 +216,14 @@ def wrapped(self, idx): # instances within a single process (e.g., train + val constructed for the same repo). _CONTROL_MODE_WARNED: set[str] = set() +# ``skip_timestamp_check`` is a mixture-wide decision (with optional per-dataset +# override) that produces an identical warning for every dataset in the mixture. +# For a 392-dataset pretraining run on 8 ranks, the naive per-dataset emission +# floods the run log with ~3K identical lines. This flag makes the warning fire +# once per process; combined with the rank-0 gate at the call site, that's once +# per run rather than 8 × num_datasets. +_SKIP_TIMESTAMP_WARNED: bool = False + def suppress_control_mode_warning(repo_id: str) -> None: """Mark ``repo_id`` as already-warned so the missing-``control_mode`` warning @@ -1346,12 +1354,24 @@ def __init__( # Check timestamps # If transform is set, with_transform will decode all columns of a row before returning the desired column(s). if self.skip_timestamp_check: - logging.warning( - "Skipping timestamp sync check for %s (skip_timestamp_check=True). " - "Frame-to-frame spacing is NOT verified against 1/fps; downstream " - "delta_timestamps lookups may sample unintended frames.", - self.repo_id, - ) + # ``skip_timestamp_check`` is a mixture-wide decision and the + # message is identical for every dataset, so emit it once per + # process and only on the main rank. Naive per-dataset / per-rank + # logging floods the run log with ``num_processes`` × + # ``num_datasets`` copies of the same line (392 × 8 ≈ 3K for a + # wide pretraining mixture). Falls through to logging when no + # Accelerator is set (single-process dev / tests). + global _SKIP_TIMESTAMP_WARNED + acc = get_proc_accelerator() + if not _SKIP_TIMESTAMP_WARNED and (acc is None or acc.is_main_process): + _SKIP_TIMESTAMP_WARNED = True + logging.warning( + "skip_timestamp_check=True is in effect for one or more " + "datasets in this mixture (e.g. %s). Frame-to-frame " + "spacing is NOT verified against 1/fps; downstream " + "delta_timestamps lookups may sample unintended frames.", + self.repo_id, + ) else: no_transform_ds = self.hf_dataset.with_transform(None).with_format("numpy") logging.info("Checking timestamps synchronization...") diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index f79bd3ff..bb987ad3 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -1027,6 +1027,35 @@ def test_control_mode_warning_emitted_once_per_repo(tmp_path, lerobot_dataset_fa assert len(matching) == 1 +def test_skip_timestamp_warning_emitted_once_per_process(tmp_path, lerobot_dataset_factory, caplog): + """`skip_timestamp_check=True` warns exactly once per process across multiple + LeRobotDataset instances — locks in the `_SKIP_TIMESTAMP_WARNED` dedup so a + heterogeneous mixture of N datasets emits 1 line, not N.""" + from opentau.datasets import lerobot_dataset as ld_mod + + original = ld_mod._SKIP_TIMESTAMP_WARNED + ld_mod._SKIP_TIMESTAMP_WARNED = False + try: + with caplog.at_level(logging.WARNING): + ds_a = lerobot_dataset_factory( + root=tmp_path / "skip_warn_a", + repo_id="warn-once/skip-ts-a", + skip_timestamp_check=True, + ) + ds_b = lerobot_dataset_factory( + root=tmp_path / "skip_warn_b", + repo_id="warn-once/skip-ts-b", + skip_timestamp_check=True, + ) + + assert ds_a.skip_timestamp_check is True + assert ds_b.skip_timestamp_check is True + matching = [r for r in caplog.records if "skip_timestamp_check=True" in r.getMessage()] + assert len(matching) == 1 + finally: + ld_mod._SKIP_TIMESTAMP_WARNED = original + + def test_robot_type_and_control_mode_in_meta_info(tmp_path, lerobot_dataset_factory, info_factory): """robot_type and control_mode are surfaced as optional fields in the standard data format. Verify the underlying meta/info.json values that