diff --git a/src/opentau/datasets/dataset_mixture.py b/src/opentau/datasets/dataset_mixture.py index ceadab7c..c2899c1c 100644 --- a/src/opentau/datasets/dataset_mixture.py +++ b/src/opentau/datasets/dataset_mixture.py @@ -62,6 +62,7 @@ import functools import logging +from collections import Counter from typing import List, Optional import numpy as np @@ -363,10 +364,11 @@ def _make_dataset_names(cfg: TrainPipelineConfig, datasets: List[BaseDataset]) - """Derive human-readable names for each dataset in the mixture. Uses each ``DatasetConfig``'s ``repo_id`` or ``vqa`` identifier when - available (ordered to match ``cfg.dataset_mixture.datasets``). Duplicate - identifiers are disambiguated with a ``#`` suffix. Falls back to the - dataset class name plus index when the config list cannot be lined up - with ``datasets`` (e.g. in tests that construct a mixture directly). + available (ordered to match ``cfg.dataset_mixture.datasets``). Duplicates + get a per-name sequential ``#`` suffix (so ``['A','B','A']`` becomes + ``['A#0','B','A#1']``). Falls back to the dataset class name plus index + when the config list cannot be lined up with ``datasets`` (e.g. in tests + that construct a mixture directly). """ dataset_cfgs = getattr(getattr(cfg, "dataset_mixture", None), "datasets", None) if dataset_cfgs is None or len(dataset_cfgs) != len(datasets): @@ -375,7 +377,17 @@ def _make_dataset_names(cfg: TrainPipelineConfig, datasets: List[BaseDataset]) - raw_names = [ (dc.repo_id or dc.vqa or type(ds).__name__) for dc, ds in zip(dataset_cfgs, datasets, strict=True) ] - return [f"{name}#{idx}" if raw_names.count(name) > 1 else name for idx, name in enumerate(raw_names)] + counts = Counter(raw_names) + seen: dict[str, int] = {} + out: list[str] = [] + for name in raw_names: + if counts[name] > 1: + i = seen.get(name, 0) + out.append(f"{name}#{i}") + seen[name] = i + 1 + else: + out.append(name) + return out def _log_dataset_info(self) -> None: """Log information about all datasets in the mixture.""" diff --git a/src/opentau/scripts/train.py b/src/opentau/scripts/train.py index 8bf99eda..09bbfd51 100644 --- a/src/opentau/scripts/train.py +++ b/src/opentau/scripts/train.py @@ -118,6 +118,47 @@ def update_policy( return train_metrics +_VAL_METRIC_KEYS: tuple[str, ...] = ("loss", "mse_loss", "ce_loss", "l1_loss", "accuracy") + + +def _mixture_weighted_aggregate( + per_dataset_trackers: dict[str, MetricsTracker], + name_to_weight: dict[str, float], + metric_keys: tuple[str, ...] = _VAL_METRIC_KEYS, +) -> dict[str, float]: + """Mixture-weighted average of per-dataset validation metrics. + + Weights are taken from ``name_to_weight`` and renormalized over only the + names present in ``per_dataset_trackers`` (empty datasets are skipped + upstream by ``WeightedDatasetMixture.get_per_dataset_dataloaders`` and so + will be missing from the trackers). When the renormalization total is 0 + -- empty trackers, or all selected datasets have weight 0 -- every metric + is returned as ``0.0``. + + Args: + per_dataset_trackers: One ``MetricsTracker`` per non-empty validation + dataset, keyed by dataset name. + name_to_weight: Mapping from dataset name to its mixture weight (need + not be normalized; need not be a strict subset/superset of the + tracker keys, but must contain every tracker key). + metric_keys: The metric attribute names to aggregate. + + Returns: + Dict mapping each ``metric_keys`` entry to its weighted average. + """ + weights = {name: name_to_weight[name] for name in per_dataset_trackers} + total = sum(weights.values()) + if total <= 0: + return dict.fromkeys(metric_keys, 0.0) + + per_dataset_dicts = { + name: tracker.to_dict(use_avg=True) for name, tracker in per_dataset_trackers.items() + } + return { + k: sum((w / total) * per_dataset_dicts[name][k] for name, w in weights.items()) for k in metric_keys + } + + def _find_unused_params_from_env() -> bool: """Parse the ``FIND_UNUSED_PARAMS`` env var into a bool. @@ -359,7 +400,6 @@ def _make_val_tracker(current_step: int = step) -> MetricsTracker: initial_step=current_step, ) - agg_tracker = _make_val_tracker() per_dataset_trackers: dict[str, MetricsTracker] = { name: _make_val_tracker() for name in per_dataset_val_dataloaders } @@ -408,22 +448,13 @@ def _make_val_tracker(current_step: int = step) -> MetricsTracker: ) if accelerator.is_main_process: - for tracker in (ds_tracker, agg_tracker): - tracker.loss = loss - tracker.mse_loss = mse_loss - tracker.ce_loss = ce_loss - tracker.l1_loss = l1_loss - tracker.accuracy = accuracy + ds_tracker.loss = loss + ds_tracker.mse_loss = mse_loss + ds_tracker.ce_loss = ce_loss + ds_tracker.l1_loss = l1_loss + ds_tracker.accuracy = accuracy if accelerator.is_main_process: - logging.info(f"Validation/aggregate {agg_tracker}") - agg_dict = agg_tracker.to_dict(use_avg=True) - accelerator.log({"Validation/Loss": agg_dict["loss"]}, step=step) - accelerator.log({"Validation/MSE Loss": agg_dict["mse_loss"]}, step=step) - accelerator.log({"Validation/CE Loss": agg_dict["ce_loss"]}, step=step) - accelerator.log({"Validation/L1 Loss": agg_dict["l1_loss"]}, step=step) - accelerator.log({"Validation/Accuracy": agg_dict["accuracy"]}, step=step) - for ds_name, ds_tracker in per_dataset_trackers.items(): logging.info(f"Validation/{ds_name} {ds_tracker}") ds_dict = ds_tracker.to_dict(use_avg=True) @@ -433,6 +464,20 @@ def _make_val_tracker(current_step: int = step) -> MetricsTracker: accelerator.log({f"Validation/{ds_name}/L1 Loss": ds_dict["l1_loss"]}, step=step) accelerator.log({f"Validation/{ds_name}/Accuracy": ds_dict["accuracy"]}, step=step) + # Mixture-weighted aggregate across the per-dataset trackers, so the + # overall scalar reflects the training mixture rather than being + # implicitly dominated by whichever val subset has the most batches. + name_to_weight = dict( + zip(val_dataset.dataset_names, val_dataset.dataset_weights, strict=True) + ) + agg = _mixture_weighted_aggregate(per_dataset_trackers, name_to_weight) + logging.info(f"Validation/aggregate {agg}") + accelerator.log({"Validation/Loss": agg["loss"]}, step=step) + accelerator.log({"Validation/MSE Loss": agg["mse_loss"]}, step=step) + accelerator.log({"Validation/CE Loss": agg["ce_loss"]}, step=step) + accelerator.log({"Validation/L1 Loss": agg["l1_loss"]}, step=step) + accelerator.log({"Validation/Accuracy": agg["accuracy"]}, step=step) + # This barrier is probably necessary to ensure # other processes wait for the main process to finish saving accelerator.wait_for_everyone() diff --git a/tests/datasets/test_dataset_mixture.py b/tests/datasets/test_dataset_mixture.py index 07d0e1d1..a7c8b80c 100644 --- a/tests/datasets/test_dataset_mixture.py +++ b/tests/datasets/test_dataset_mixture.py @@ -16,6 +16,7 @@ # limitations under the License. import logging +from types import SimpleNamespace from unittest.mock import patch import numpy as np @@ -539,6 +540,78 @@ def test_integration_logging_behavior(self, train_pipeline_config, datasets_fact assert "DataLoader created successfully" in caplog.text +class TestMakeDatasetNames: + """Tests for ``WeightedDatasetMixture._make_dataset_names``. + + The static method does not need a real ``TrainPipelineConfig`` / + ``DatasetConfig`` -- it only reads ``cfg.dataset_mixture.datasets`` and the + ``repo_id`` / ``vqa`` attributes -- so we use ``SimpleNamespace`` stubs to + keep these tests focused on the naming logic. + """ + + @staticmethod + def _cfg(*ds_cfgs: SimpleNamespace) -> SimpleNamespace: + return SimpleNamespace(dataset_mixture=SimpleNamespace(datasets=list(ds_cfgs))) + + @staticmethod + def _ds_cfg(repo_id: str | None = None, vqa: str | None = None) -> SimpleNamespace: + return SimpleNamespace(repo_id=repo_id, vqa=vqa) + + class _Stub: + """Minimal stand-in for ``BaseDataset`` -- we only read ``type(ds).__name__``.""" + + def test_all_unique_repo_ids(self): + cfg = self._cfg(self._ds_cfg("lerobot/a"), self._ds_cfg("lerobot/b")) + datasets = [self._Stub(), self._Stub()] + assert WeightedDatasetMixture._make_dataset_names(cfg, datasets) == ["lerobot/a", "lerobot/b"] + + def test_repeated_repo_id_uses_per_name_sequential_suffix(self): + # The previous implementation used the GLOBAL index (yielding + # ['A#0', 'B', 'A#2']); this test pins the per-name sequential behaviour. + cfg = self._cfg( + self._ds_cfg("lerobot/a"), + self._ds_cfg("lerobot/b"), + self._ds_cfg("lerobot/a"), + ) + datasets = [self._Stub(), self._Stub(), self._Stub()] + names = WeightedDatasetMixture._make_dataset_names(cfg, datasets) + assert names == ["lerobot/a#0", "lerobot/b", "lerobot/a#1"] + + def test_all_identical(self): + cfg = self._cfg(self._ds_cfg("x"), self._ds_cfg("x"), self._ds_cfg("x")) + datasets = [self._Stub(), self._Stub(), self._Stub()] + assert WeightedDatasetMixture._make_dataset_names(cfg, datasets) == ["x#0", "x#1", "x#2"] + + def test_vqa_and_repo_id_mix(self): + cfg = self._cfg( + self._ds_cfg(repo_id="lerobot/a"), + self._ds_cfg(vqa="vqa-set-1"), + self._ds_cfg(vqa="vqa-set-1"), + ) + datasets = [self._Stub(), self._Stub(), self._Stub()] + names = WeightedDatasetMixture._make_dataset_names(cfg, datasets) + assert names == ["lerobot/a", "vqa-set-1#0", "vqa-set-1#1"] + + def test_neither_repo_id_nor_vqa_falls_back_to_classname(self): + cfg = self._cfg(self._ds_cfg(), self._ds_cfg()) + datasets = [self._Stub(), self._Stub()] + names = WeightedDatasetMixture._make_dataset_names(cfg, datasets) + # Both share the class name, so they collide and get suffixed. + assert names == ["_Stub#0", "_Stub#1"] + + def test_mismatched_cfg_length_falls_back_to_classname_index(self): + cfg = self._cfg(self._ds_cfg("lerobot/a")) + datasets = [self._Stub(), self._Stub()] + names = WeightedDatasetMixture._make_dataset_names(cfg, datasets) + assert names == ["_Stub_0", "_Stub_1"] + + def test_no_dataset_mixture_attribute_falls_back_to_classname_index(self): + cfg = SimpleNamespace() # no .dataset_mixture + datasets = [self._Stub(), self._Stub()] + names = WeightedDatasetMixture._make_dataset_names(cfg, datasets) + assert names == ["_Stub_0", "_Stub_1"] + + class TestDatasetMixtureOptionalKeyDropProbs: """Tests for the optional-key dropout probability fields on DatasetMixtureConfig.""" diff --git a/tests/scripts/test_train.py b/tests/scripts/test_train.py index bc1751fa..de5f0b0f 100644 --- a/tests/scripts/test_train.py +++ b/tests/scripts/test_train.py @@ -24,7 +24,8 @@ import pytest -from opentau.scripts.train import _find_unused_params_from_env +from opentau.scripts.train import _find_unused_params_from_env, _mixture_weighted_aggregate +from opentau.utils.logging_utils import AverageMeter, MetricsTracker class TestFindUnusedParamsFromEnv: @@ -69,5 +70,83 @@ def test_unknown_values_parse_as_false(self, monkeypatch): assert _find_unused_params_from_env() is False, f"Expected {value!r} to parse as False, got True" +def _make_tracker( + loss: float, mse: float = 0.0, ce: float = 0.0, l1: float = 0.0, acc: float = 0.0 +) -> MetricsTracker: + """Build a ``MetricsTracker`` with one update per metric. + + A single assignment per attribute means each ``AverageMeter`` ends up with + ``avg == val``, which is exactly what we need to exercise the weighted + aggregation in isolation. + """ + tracker = MetricsTracker( + batch_size=8, + metrics={ + "loss": AverageMeter("val_total_loss", ":.3f"), + "mse_loss": AverageMeter("val_mse_loss", ":.3f"), + "ce_loss": AverageMeter("val_ce_loss", ":.3f"), + "l1_loss": AverageMeter("val_l1_loss", ":.3f"), + "accuracy": AverageMeter("val_accuracy", ":.3f"), + }, + ) + tracker.loss = loss + tracker.mse_loss = mse + tracker.ce_loss = ce + tracker.l1_loss = l1 + tracker.accuracy = acc + return tracker + + +class TestMixtureWeightedAggregate: + """``_mixture_weighted_aggregate`` collapses per-dataset trackers using mixture weights.""" + + def test_equal_weights_is_simple_mean(self): + trackers = {"a": _make_tracker(loss=1.0), "b": _make_tracker(loss=3.0)} + weights = {"a": 1.0, "b": 1.0} + agg = _mixture_weighted_aggregate(trackers, weights) + assert agg["loss"] == pytest.approx(2.0) + + def test_unequal_weights_renormalize(self): + # weights [3, 1] -> 0.75 * 1.0 + 0.25 * 5.0 = 2.0 + trackers = {"a": _make_tracker(loss=1.0), "b": _make_tracker(loss=5.0)} + weights = {"a": 3.0, "b": 1.0} + agg = _mixture_weighted_aggregate(trackers, weights) + assert agg["loss"] == pytest.approx(2.0) + + def test_renormalizes_over_present_keys_only(self): + # ``name_to_weight`` includes a name that is missing from + # ``per_dataset_trackers`` (e.g. an empty val subset). The aggregate + # should ignore it and renormalize over the present keys. + trackers = {"a": _make_tracker(loss=1.0), "b": _make_tracker(loss=3.0)} + weights = {"a": 1.0, "b": 1.0, "c_empty": 100.0} + agg = _mixture_weighted_aggregate(trackers, weights) + assert agg["loss"] == pytest.approx(2.0) + + def test_aggregates_all_metric_keys(self): + trackers = { + "a": _make_tracker(loss=1.0, mse=2.0, ce=3.0, l1=4.0, acc=0.1), + "b": _make_tracker(loss=5.0, mse=6.0, ce=7.0, l1=8.0, acc=0.5), + } + weights = {"a": 1.0, "b": 3.0} + agg = _mixture_weighted_aggregate(trackers, weights) + # 0.25 * a + 0.75 * b + assert agg["loss"] == pytest.approx(0.25 * 1.0 + 0.75 * 5.0) + assert agg["mse_loss"] == pytest.approx(0.25 * 2.0 + 0.75 * 6.0) + assert agg["ce_loss"] == pytest.approx(0.25 * 3.0 + 0.75 * 7.0) + assert agg["l1_loss"] == pytest.approx(0.25 * 4.0 + 0.75 * 8.0) + assert agg["accuracy"] == pytest.approx(0.25 * 0.1 + 0.75 * 0.5) + + def test_empty_trackers_returns_zeros(self): + agg = _mixture_weighted_aggregate({}, {}) + assert agg == {"loss": 0.0, "mse_loss": 0.0, "ce_loss": 0.0, "l1_loss": 0.0, "accuracy": 0.0} + + def test_all_zero_weights_returns_zeros(self): + trackers = {"a": _make_tracker(loss=1.0), "b": _make_tracker(loss=2.0)} + weights = {"a": 0.0, "b": 0.0} + agg = _mixture_weighted_aggregate(trackers, weights) + # Avoids div-by-zero; behaviour matches "no signal to average". + assert agg["loss"] == 0.0 + + if __name__ == "__main__": pytest.main([__file__, "-v"])