Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions src/opentau/datasets/dataset_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@

import functools
import logging
from collections import Counter
from typing import List, Optional

import numpy as np
Expand Down Expand Up @@ -363,10 +364,11 @@ def _make_dataset_names(cfg: TrainPipelineConfig, datasets: List[BaseDataset]) -
"""Derive human-readable names for each dataset in the mixture.

Uses each ``DatasetConfig``'s ``repo_id`` or ``vqa`` identifier when
available (ordered to match ``cfg.dataset_mixture.datasets``). Duplicate
identifiers are disambiguated with a ``#<idx>`` suffix. Falls back to the
dataset class name plus index when the config list cannot be lined up
with ``datasets`` (e.g. in tests that construct a mixture directly).
available (ordered to match ``cfg.dataset_mixture.datasets``). Duplicates
get a per-name sequential ``#<i>`` suffix (so ``['A','B','A']`` becomes
``['A#0','B','A#1']``). Falls back to the dataset class name plus index
when the config list cannot be lined up with ``datasets`` (e.g. in tests
that construct a mixture directly).
"""
dataset_cfgs = getattr(getattr(cfg, "dataset_mixture", None), "datasets", None)
if dataset_cfgs is None or len(dataset_cfgs) != len(datasets):
Expand All @@ -375,7 +377,17 @@ def _make_dataset_names(cfg: TrainPipelineConfig, datasets: List[BaseDataset]) -
raw_names = [
(dc.repo_id or dc.vqa or type(ds).__name__) for dc, ds in zip(dataset_cfgs, datasets, strict=True)
]
return [f"{name}#{idx}" if raw_names.count(name) > 1 else name for idx, name in enumerate(raw_names)]
counts = Counter(raw_names)
seen: dict[str, int] = {}
out: list[str] = []
for name in raw_names:
if counts[name] > 1:
i = seen.get(name, 0)
out.append(f"{name}#{i}")
seen[name] = i + 1
else:
out.append(name)
return out

def _log_dataset_info(self) -> None:
"""Log information about all datasets in the mixture."""
Expand Down
75 changes: 60 additions & 15 deletions src/opentau/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,47 @@ def update_policy(
return train_metrics


# Metric names aggregated across validation datasets by default.
_VAL_METRIC_KEYS: tuple[str, ...] = ("loss", "mse_loss", "ce_loss", "l1_loss", "accuracy")


def _mixture_weighted_aggregate(
    per_dataset_trackers: dict[str, "MetricsTracker"],
    name_to_weight: dict[str, float],
    metric_keys: tuple[str, ...] = _VAL_METRIC_KEYS,
) -> dict[str, float]:
    """Mixture-weighted average of per-dataset validation metrics.

    Weights are looked up in ``name_to_weight`` and renormalized over only the
    names present in ``per_dataset_trackers``, so a dataset that has a weight
    but no tracker is simply ignored. When the renormalization total is not
    positive -- no trackers at all, or the selected weights sum to 0 or less --
    every metric is returned as ``0.0``.

    Datasets whose weight is exactly 0 are left out of the sum entirely, so a
    NaN/inf average on a zero-weight dataset cannot turn the aggregate into
    NaN (``0 * nan`` is ``nan``).

    Args:
        per_dataset_trackers: One tracker per validation dataset, keyed by
            dataset name. Each tracker must provide ``to_dict(use_avg=True)``
            returning a mapping that contains every entry of ``metric_keys``.
        name_to_weight: Mapping from dataset name to its mixture weight. It
            need not be normalized and may contain extra names, but it must
            contain every key of ``per_dataset_trackers``.
        metric_keys: The metric names to aggregate.

    Returns:
        Dict mapping each ``metric_keys`` entry to its weighted average.

    Raises:
        KeyError: If a tracker name has no entry in ``name_to_weight``.
    """
    missing = [name for name in per_dataset_trackers if name not in name_to_weight]
    if missing:
        raise KeyError(f"No mixture weight for validation dataset(s): {missing}")

    weights = {name: name_to_weight[name] for name in per_dataset_trackers}
    total = sum(weights.values())
    if total <= 0:
        # Nothing to average over; also avoids dividing by zero below.
        return dict.fromkeys(metric_keys, 0.0)

    # Normalized share per dataset; zero-weight datasets contribute nothing
    # and are dropped so non-finite averages cannot leak into the result.
    shares = {name: w / total for name, w in weights.items() if w != 0}
    averages = {name: per_dataset_trackers[name].to_dict(use_avg=True) for name in shares}
    return {k: sum(share * averages[name][k] for name, share in shares.items()) for k in metric_keys}


def _find_unused_params_from_env() -> bool:
"""Parse the ``FIND_UNUSED_PARAMS`` env var into a bool.

Expand Down Expand Up @@ -359,7 +400,6 @@ def _make_val_tracker(current_step: int = step) -> MetricsTracker:
initial_step=current_step,
)

agg_tracker = _make_val_tracker()
per_dataset_trackers: dict[str, MetricsTracker] = {
name: _make_val_tracker() for name in per_dataset_val_dataloaders
}
Expand Down Expand Up @@ -408,22 +448,13 @@ def _make_val_tracker(current_step: int = step) -> MetricsTracker:
)

if accelerator.is_main_process:
for tracker in (ds_tracker, agg_tracker):
tracker.loss = loss
tracker.mse_loss = mse_loss
tracker.ce_loss = ce_loss
tracker.l1_loss = l1_loss
tracker.accuracy = accuracy
ds_tracker.loss = loss
ds_tracker.mse_loss = mse_loss
ds_tracker.ce_loss = ce_loss
ds_tracker.l1_loss = l1_loss
ds_tracker.accuracy = accuracy

if accelerator.is_main_process:
logging.info(f"Validation/aggregate {agg_tracker}")
agg_dict = agg_tracker.to_dict(use_avg=True)
accelerator.log({"Validation/Loss": agg_dict["loss"]}, step=step)
accelerator.log({"Validation/MSE Loss": agg_dict["mse_loss"]}, step=step)
accelerator.log({"Validation/CE Loss": agg_dict["ce_loss"]}, step=step)
accelerator.log({"Validation/L1 Loss": agg_dict["l1_loss"]}, step=step)
accelerator.log({"Validation/Accuracy": agg_dict["accuracy"]}, step=step)

for ds_name, ds_tracker in per_dataset_trackers.items():
logging.info(f"Validation/{ds_name} {ds_tracker}")
ds_dict = ds_tracker.to_dict(use_avg=True)
Expand All @@ -433,6 +464,20 @@ def _make_val_tracker(current_step: int = step) -> MetricsTracker:
accelerator.log({f"Validation/{ds_name}/L1 Loss": ds_dict["l1_loss"]}, step=step)
accelerator.log({f"Validation/{ds_name}/Accuracy": ds_dict["accuracy"]}, step=step)

# Mixture-weighted aggregate across the per-dataset trackers, so the
# overall scalar reflects the training mixture rather than being
# implicitly dominated by whichever val subset has the most batches.
name_to_weight = dict(
zip(val_dataset.dataset_names, val_dataset.dataset_weights, strict=True)
)
agg = _mixture_weighted_aggregate(per_dataset_trackers, name_to_weight)
logging.info(f"Validation/aggregate {agg}")
accelerator.log({"Validation/Loss": agg["loss"]}, step=step)
accelerator.log({"Validation/MSE Loss": agg["mse_loss"]}, step=step)
accelerator.log({"Validation/CE Loss": agg["ce_loss"]}, step=step)
accelerator.log({"Validation/L1 Loss": agg["l1_loss"]}, step=step)
accelerator.log({"Validation/Accuracy": agg["accuracy"]}, step=step)

# This barrier is probably necessary to ensure
# other processes wait for the main process to finish saving
accelerator.wait_for_everyone()
Expand Down
73 changes: 73 additions & 0 deletions tests/datasets/test_dataset_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# limitations under the License.

import logging
from types import SimpleNamespace
from unittest.mock import patch

import numpy as np
Expand Down Expand Up @@ -539,6 +540,78 @@ def test_integration_logging_behavior(self, train_pipeline_config, datasets_fact
assert "DataLoader created successfully" in caplog.text


class TestMakeDatasetNames:
    """Naming behaviour of ``WeightedDatasetMixture._make_dataset_names``.

    The method under test is driven purely through ``SimpleNamespace`` stubs:
    a config object exposing ``dataset_mixture.datasets`` and per-dataset
    configs exposing ``repo_id`` / ``vqa``. That keeps these tests independent
    of the real ``TrainPipelineConfig`` / ``DatasetConfig`` classes.
    """

    class _Stub:
        """Placeholder dataset; the tests only rely on its class name."""

    @staticmethod
    def _ds_cfg(repo_id: str | None = None, vqa: str | None = None) -> SimpleNamespace:
        return SimpleNamespace(repo_id=repo_id, vqa=vqa)

    @staticmethod
    def _cfg(*ds_cfgs: SimpleNamespace) -> SimpleNamespace:
        mixture = SimpleNamespace(datasets=list(ds_cfgs))
        return SimpleNamespace(dataset_mixture=mixture)

    @classmethod
    def _stubs(cls, count: int) -> list:
        """Return ``count`` fresh placeholder datasets."""
        return [cls._Stub() for _ in range(count)]

    def test_all_unique_repo_ids(self):
        cfg = self._cfg(self._ds_cfg("lerobot/a"), self._ds_cfg("lerobot/b"))
        names = WeightedDatasetMixture._make_dataset_names(cfg, self._stubs(2))
        assert names == ["lerobot/a", "lerobot/b"]

    def test_repeated_repo_id_uses_per_name_sequential_suffix(self):
        # Guards the per-name counter: an implementation that suffixes with the
        # GLOBAL position would produce ['A#0', 'B', 'A#2'] instead.
        repo_ids = ("lerobot/a", "lerobot/b", "lerobot/a")
        cfg = self._cfg(*(self._ds_cfg(repo_id) for repo_id in repo_ids))
        names = WeightedDatasetMixture._make_dataset_names(cfg, self._stubs(3))
        assert names == ["lerobot/a#0", "lerobot/b", "lerobot/a#1"]

    def test_all_identical(self):
        cfg = self._cfg(*(self._ds_cfg("x") for _ in range(3)))
        names = WeightedDatasetMixture._make_dataset_names(cfg, self._stubs(3))
        assert names == ["x#0", "x#1", "x#2"]

    def test_vqa_and_repo_id_mix(self):
        cfg = self._cfg(
            self._ds_cfg(repo_id="lerobot/a"),
            self._ds_cfg(vqa="vqa-set-1"),
            self._ds_cfg(vqa="vqa-set-1"),
        )
        names = WeightedDatasetMixture._make_dataset_names(cfg, self._stubs(3))
        assert names == ["lerobot/a", "vqa-set-1#0", "vqa-set-1#1"]

    def test_neither_repo_id_nor_vqa_falls_back_to_classname(self):
        cfg = self._cfg(self._ds_cfg(), self._ds_cfg())
        names = WeightedDatasetMixture._make_dataset_names(cfg, self._stubs(2))
        # Both entries resolve to the same class name, so they are suffixed.
        assert names == ["_Stub#0", "_Stub#1"]

    def test_mismatched_cfg_length_falls_back_to_classname_index(self):
        # One dataset config but two datasets: the lists cannot be paired up.
        cfg = self._cfg(self._ds_cfg("lerobot/a"))
        names = WeightedDatasetMixture._make_dataset_names(cfg, self._stubs(2))
        assert names == ["_Stub_0", "_Stub_1"]

    def test_no_dataset_mixture_attribute_falls_back_to_classname_index(self):
        bare_cfg = SimpleNamespace()  # deliberately has no .dataset_mixture
        names = WeightedDatasetMixture._make_dataset_names(bare_cfg, self._stubs(2))
        assert names == ["_Stub_0", "_Stub_1"]


class TestDatasetMixtureOptionalKeyDropProbs:
"""Tests for the optional-key dropout probability fields on DatasetMixtureConfig."""

Expand Down
81 changes: 80 additions & 1 deletion tests/scripts/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@

import pytest

from opentau.scripts.train import _find_unused_params_from_env
from opentau.scripts.train import _find_unused_params_from_env, _mixture_weighted_aggregate
from opentau.utils.logging_utils import AverageMeter, MetricsTracker


class TestFindUnusedParamsFromEnv:
Expand Down Expand Up @@ -69,5 +70,83 @@ def test_unknown_values_parse_as_false(self, monkeypatch):
assert _find_unused_params_from_env() is False, f"Expected {value!r} to parse as False, got True"


def _make_tracker(
    loss: float, mse: float = 0.0, ce: float = 0.0, l1: float = 0.0, acc: float = 0.0
) -> MetricsTracker:
    """Create a ``MetricsTracker`` whose metrics have each been set exactly once.

    Every metric receives a single assignment, so each meter's average is
    expected to equal the assigned value (the aggregation tests depend on
    this), which lets those tests reason about the weighted average directly.

    Args:
        loss: Value assigned to the ``loss`` metric.
        mse: Value assigned to the ``mse_loss`` metric.
        ce: Value assigned to the ``ce_loss`` metric.
        l1: Value assigned to the ``l1_loss`` metric.
        acc: Value assigned to the ``accuracy`` metric.

    Returns:
        The populated tracker.
    """
    # (metric key, meter display name, value to assign), in declaration order.
    spec = (
        ("loss", "val_total_loss", loss),
        ("mse_loss", "val_mse_loss", mse),
        ("ce_loss", "val_ce_loss", ce),
        ("l1_loss", "val_l1_loss", l1),
        ("accuracy", "val_accuracy", acc),
    )
    meters = {key: AverageMeter(meter_name, ":.3f") for key, meter_name, _ in spec}
    tracker = MetricsTracker(batch_size=8, metrics=meters)
    for key, _, value in spec:
        setattr(tracker, key, value)
    return tracker


class TestMixtureWeightedAggregate:
    """Checks that ``_mixture_weighted_aggregate`` combines trackers by mixture weight."""

    def test_equal_weights_is_simple_mean(self):
        result = _mixture_weighted_aggregate(
            {"a": _make_tracker(loss=1.0), "b": _make_tracker(loss=3.0)},
            {"a": 1.0, "b": 1.0},
        )
        assert result["loss"] == pytest.approx(2.0)

    def test_unequal_weights_renormalize(self):
        # Weights [3, 1] normalize to [0.75, 0.25]: 0.75 * 1.0 + 0.25 * 5.0 = 2.0
        result = _mixture_weighted_aggregate(
            {"a": _make_tracker(loss=1.0), "b": _make_tracker(loss=5.0)},
            {"a": 3.0, "b": 1.0},
        )
        assert result["loss"] == pytest.approx(2.0)

    def test_renormalizes_over_present_keys_only(self):
        # The weight table names a dataset ("c_empty") that has no tracker,
        # e.g. an empty val subset. Its weight must not take part in the
        # normalization, so the result is the plain mean of "a" and "b".
        name_to_weight = {"a": 1.0, "b": 1.0, "c_empty": 100.0}
        per_dataset = {"a": _make_tracker(loss=1.0), "b": _make_tracker(loss=3.0)}
        result = _mixture_weighted_aggregate(per_dataset, name_to_weight)
        assert result["loss"] == pytest.approx(2.0)

    def test_aggregates_all_metric_keys(self):
        per_dataset = {
            "a": _make_tracker(loss=1.0, mse=2.0, ce=3.0, l1=4.0, acc=0.1),
            "b": _make_tracker(loss=5.0, mse=6.0, ce=7.0, l1=8.0, acc=0.5),
        }
        result = _mixture_weighted_aggregate(per_dataset, {"a": 1.0, "b": 3.0})
        # Weights [1, 3] normalize to 0.25 for "a" and 0.75 for "b".
        expected = {
            "loss": 0.25 * 1.0 + 0.75 * 5.0,
            "mse_loss": 0.25 * 2.0 + 0.75 * 6.0,
            "ce_loss": 0.25 * 3.0 + 0.75 * 7.0,
            "l1_loss": 0.25 * 4.0 + 0.75 * 8.0,
            "accuracy": 0.25 * 0.1 + 0.75 * 0.5,
        }
        for key, want in expected.items():
            assert result[key] == pytest.approx(want)

    def test_empty_trackers_returns_zeros(self):
        result = _mixture_weighted_aggregate({}, {})
        all_zero = dict.fromkeys(("loss", "mse_loss", "ce_loss", "l1_loss", "accuracy"), 0.0)
        assert result == all_zero

    def test_all_zero_weights_returns_zeros(self):
        per_dataset = {"a": _make_tracker(loss=1.0), "b": _make_tracker(loss=2.0)}
        result = _mixture_weighted_aggregate(per_dataset, {"a": 0.0, "b": 0.0})
        # A zero weight total must not divide by zero; it reports 0.0 instead.
        assert result["loss"] == 0.0


if __name__ == "__main__":
    # Running this module directly hands it to pytest in verbose mode.
    pytest_args = [__file__, "-v"]
    pytest.main(pytest_args)
Loading