Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/opentau/configs/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,16 @@ class EvalConfig:

recording_root: str | None = None

# Which training-time dataset's normalization stats to use when calling
# `policy.select_action` on eval observations. ``None`` (default) falls
# through to the policy's `_resolve_dataset_index` single-dataset fallback
# (works for any policy trained on exactly one dataset). Set this to one
# of the strings in `policy.config.dataset_names` when running eval
# against a multi-dataset checkpoint, otherwise the inference call will
# raise a `KeyError`. Plumbed into each `select_action` call by
# `scripts/eval.py::rollout`.
dataset_repo_id: str | None = None

def __post_init__(self):
"""Validate evaluation configuration."""
if self.batch_size > self.n_episodes:
Expand Down
7 changes: 7 additions & 0 deletions src/opentau/configs/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@ class ServerConfig:
max_workers: int = 4
max_send_message_length_mb: int = 100
max_receive_message_length_mb: int = 100
# Which training-time dataset's normalization stats to use for inference
# requests. ``None`` (default) falls through to the policy's
# `_resolve_dataset_index` single-dataset fallback (works for any
# checkpoint trained on exactly one dataset). Set this to one of the
# strings in `policy.config.dataset_names` when serving a multi-dataset
# checkpoint, otherwise the inference call will raise `KeyError`.
dataset_repo_id: str | None = None

def __post_init__(self):
"""Validate server configuration parameters."""
Expand Down
13 changes: 13 additions & 0 deletions src/opentau/configs/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,19 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
use_amp: bool = False
pretrained_path: str | None = None

# When False, `_save_pretrained` strips normalize_*.buffer_* / unnormalize_*.buffer_*
# keys from the state_dict before writing model.safetensors. Reloading then requires
# the caller to pass `ds_meta=` (or `stats=`) to `make_policy` so the buffers can be
# repopulated; otherwise the inf-init assertion fires at first forward.
save_normalization_stats: bool = True

# Ordered list of dataset names this policy was trained on. Used by the per-sample
# Normalize/Unnormalize indexing path to map an inference-time
# `batch["dataset_repo_id"]` (str) into the leading dim of the stacked stats
# buffers. `None` only for policies constructed outside the standard
# `make_policy(ds_meta=...)` path (e.g. legacy single-stats fallbacks).
dataset_names: list[str] | None = None

# Deprecated: latency fields are no longer used. Kept for backward-compatible
# loading of old JSON configs. Must remain 0.0; non-zero values will raise.
cloud_vlm_latency_mean: float = 0.0
Expand Down
159 changes: 145 additions & 14 deletions src/opentau/datasets/dataset_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,60 @@

import numpy as np
import torch
from torch.utils.data import ConcatDataset, DataLoader, Sampler
from torch.utils.data import ConcatDataset, DataLoader, Dataset, Sampler

from opentau.configs.train import TrainPipelineConfig
from opentau.datasets.compute_stats import aggregate_stats
from opentau.datasets.lerobot_dataset import BaseDataset, DatasetMetadata
from opentau.datasets.standard_data_format_mapping import DATA_FEATURES_NAME_MAPPING


class _TaggedDataset(Dataset):
"""Wraps a ``BaseDataset`` so every sample carries its mixture-level identity.

The wrapped sample dict gains two keys:

- ``"dataset_repo_id"``: ``str`` — the deduplicated mixture-level name
(matches an entry in ``DatasetMixtureMetadata.dataset_names``).
Default PyTorch collate batches a per-sample ``str`` into a
``list[str]`` of length B.
- ``"dataset_index"``: ``torch.long`` scalar — the position of this
dataset in ``DatasetMixtureMetadata.dataset_names``. Default collate
stacks into a ``(B,)`` long tensor.

Policies read either key via ``PreTrainedPolicy._resolve_dataset_index``
and route per-sample into the stacked Normalize/Unnormalize buffers.

The wrapper exposes the underlying dataset's ``.meta`` so callers of
``WeightedDatasetMixture`` (e.g. metadata validation, per-dataset val
loaders) keep working unchanged. ``len`` delegates as well.
"""

def __init__(self, base: Dataset, dataset_repo_id: str, dataset_index: int):
self._base = base
self._dataset_repo_id = dataset_repo_id
# Pre-build the scalar long tensor once so __getitem__ doesn't
# allocate per call.
self._dataset_index_tensor = torch.tensor(int(dataset_index), dtype=torch.long)
# Preserve `.meta` so `WeightedDatasetMixture.__init__`'s validation
# (`hasattr(ds, "meta") and ds.meta is not None`) still works after
# wrapping. `Subset` / `random_split` produce wrappers without a
# `.meta`, so look one level deeper if needed.
meta = getattr(base, "meta", None)
if meta is None and hasattr(base, "dataset"):
meta = getattr(base.dataset, "meta", None)
self.meta = meta

def __len__(self) -> int:
return len(self._base)

def __getitem__(self, idx):
item = self._base[idx]
item["dataset_repo_id"] = self._dataset_repo_id
item["dataset_index"] = self._dataset_index_tensor
return item


def pad_vector(vector: np.ndarray, new_dim: int) -> np.ndarray:
"""Pad the last dimension of a vector to a target size with zeros.

Expand Down Expand Up @@ -114,21 +160,90 @@ def _apply_data_feature_name_mapping_overrides(


class DatasetMixtureMetadata:
"""A class to hold metadata for a mixture of datasets.

This is used to aggregate metadata from multiple datasets into a single object.
"""Per-dataset metadata for a mixture (no cross-dataset stat aggregation).

Each underlying dataset's stats are normalised into the standard data
format (feature renaming, state/action padding to ``cfg.max_state_dim`` /
``cfg.max_action_dim``, missing-camera zero placeholders) and kept
*separate* on ``self.per_dataset_stats``. The policy's Normalize /
Unnormalize layers stack these along a new leading dim and use a
per-sample ``dataset_index`` to select the right row.

Attributes:
per_dataset_stats: ``list[dict[str, dict[str, np.ndarray]]]`` ordered
to match ``dataset_names``. ``per_dataset_stats[i]`` is the
standardised stats dict for dataset ``dataset_names[i]``.
dataset_names: Ordered deduplicated mixture-level names (matches
``WeightedDatasetMixture._make_dataset_names`` output).
dataset_name_to_index: ``{name: i}`` reverse lookup for O(1)
inference-time resolution.
"""

def __init__(
self, cfg: TrainPipelineConfig, metadatas: List[DatasetMetadata], dataset_weights: List[float]
self,
cfg: TrainPipelineConfig,
metadatas: List[DatasetMetadata],
dataset_weights: List[float],
dataset_names: List[str] | None = None,
):
self.cfg = cfg
self._dataset_weights = list(dataset_weights)

# convert each metadata stats to the standard data format
for metadata in metadatas:
metadata.stats = self._to_standard_data_format(metadata.repo_id, metadata.stats)

self.stats = aggregate_stats([metadata.stats for metadata in metadatas], weights=dataset_weights)
# Per-dataset stats (no cross-dataset aggregation). Policies stack
# these along a leading axis and index per-sample.
self.per_dataset_stats: list[dict[str, dict[str, np.ndarray]]] = [m.stats for m in metadatas]

# Names default to repo_id when WeightedDatasetMixture didn't supply
# deduplicated names (e.g. tests that instantiate the metadata
# directly). Duplicates would break the str -> index mapping; surface
# them rather than silently keeping the last one.
if dataset_names is None:
dataset_names = [m.repo_id for m in metadatas]
if len(dataset_names) != len(metadatas):
raise ValueError(f"dataset_names ({len(dataset_names)}) must match metadatas ({len(metadatas)}).")
if len(set(dataset_names)) != len(dataset_names):
dups = [n for n in dataset_names if dataset_names.count(n) > 1]
raise ValueError(
f"dataset_names must be unique; got duplicates {sorted(set(dups))}. "
"Use WeightedDatasetMixture's `_make_dataset_names` which appends "
"`#N` suffixes for repeated repo ids."
)
self.dataset_names: list[str] = list(dataset_names)
self.dataset_name_to_index: dict[str, int] = {n: i for i, n in enumerate(self.dataset_names)}

def aggregated_action_stats(self) -> dict[str, np.ndarray]:
"""Single mixture-wide action stats (mean/std/min/max/count).

Backwards-compat helper for the rare consumers that genuinely need a
single set of action stats across the whole mixture — currently only
``fit_fast_tokenizer.py``, which fits one BPE codec over a global
action range. Most callers should consume ``per_dataset_stats`` /
``dataset_names`` directly.
"""
# Co-iterate stats and weights so a non-trailing dataset lacking
# ``actions`` keeps the weight alignment correct. Slicing
# ``self._dataset_weights[:N]`` would silently misalign if e.g.
# `per_dataset_stats[1]` lacked actions — the BPE codec would then
# fit a weighted mean using dataset 0's weight applied to dataset 2's
# action distribution.
filtered: list[tuple[dict[str, np.ndarray], float]] = [
({"actions": s["actions"]}, w)
for s, w in zip(self.per_dataset_stats, self._dataset_weights, strict=True)
if "actions" in s
]
if not filtered:
raise ValueError(
"No dataset in the mixture exposes 'actions' stats; aggregated_action_stats() is undefined."
)
agg = aggregate_stats(
[s for s, _ in filtered],
weights=[w for _, w in filtered],
)
return agg["actions"]

def _to_standard_data_format(
self, repo_id: str, stats: dict[str, dict[str, np.ndarray]]
Expand Down Expand Up @@ -372,16 +487,30 @@ def __init__(
)

self.cfg = cfg
self.datasets = datasets
self.dataset_weights = dataset_weights
# Common resample rate (Hz); None = mixed-frequency (native fps per dataset).
self.action_freq: Optional[float] = action_freq
self.dataset_names = self._make_dataset_names(cfg, datasets) # For logging
self.dataset_weights = dataset_weights
self.dataset_names = self._make_dataset_names(cfg, datasets) # For logging + tagging

# Validate meta presence on the UN-wrapped inputs (so the error
# message points to the real source) before wrapping.
if not all(hasattr(ds, "meta") and ds.meta is not None for ds in datasets):
raise ValueError("All datasets must have a 'meta' attribute with valid metadata.")

# Wrap every underlying dataset so __getitem__ injects
# `dataset_repo_id: str` and `dataset_index: torch.long`. The policy's
# Normalize/Unnormalize uses these to gather per-sample stats.
# `_TaggedDataset` preserves `.meta` so `get_per_dataset_dataloaders`
# and other consumers keep working.
self.datasets = [
_TaggedDataset(ds, name, idx)
for idx, (ds, name) in enumerate(zip(datasets, self.dataset_names, strict=True))
]

logging.info("Initializing WeightedDatasetMixture...")
self._log_dataset_info()

self.concatenated_dataset: ConcatDataset = ConcatDataset(datasets)
self.concatenated_dataset: ConcatDataset = ConcatDataset(self.datasets)
logging.info(f"Total length of concatenated dataset: {len(self.concatenated_dataset)}")

self.sample_weights: torch.Tensor = self._calculate_sample_weights()
Expand All @@ -394,10 +523,12 @@ def __init__(
)
logging.info("-" * 30)

# aggregate metadata
if not all(hasattr(ds, "meta") and ds.meta is not None for ds in datasets):
raise ValueError("All datasets must have a 'meta' attribute with valid metadata.")
self.meta = DatasetMixtureMetadata(cfg, [ds.meta for ds in datasets], dataset_weights)
self.meta = DatasetMixtureMetadata(
cfg,
[ds.meta for ds in datasets],
dataset_weights,
dataset_names=self.dataset_names,
)

@staticmethod
def _make_dataset_names(cfg: TrainPipelineConfig, datasets: List[BaseDataset]) -> List[str]:
Expand Down
40 changes: 38 additions & 2 deletions src/opentau/policies/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,17 +215,44 @@ def make_policy(
policy_cls = get_policy_class(cfg.type)

kwargs = {}
per_dataset_stats: list[dict[str, dict[str, np.ndarray]]] | None = None
dataset_names: list[str] | None = None

if ds_meta is not None:
features = dataset_to_policy_features(ds_meta.features)
kwargs["dataset_stats"] = ds_meta.stats
# `DatasetMixtureMetadata` exposes per-dataset stats + names; a bare
# `LeRobotDatasetMetadata` (single-dataset path, e.g. some scripts /
# tests) exposes only `.stats` — wrap it into a singleton list so the
# policy ctor sees a uniform shape.
if hasattr(ds_meta, "per_dataset_stats") and hasattr(ds_meta, "dataset_names"):
per_dataset_stats = list(ds_meta.per_dataset_stats)
dataset_names = list(ds_meta.dataset_names)
else:
per_dataset_stats = [ds_meta.stats]
dataset_names = [getattr(ds_meta, "repo_id", "default")]

if not features_already_set:
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}

if stats is not None:
kwargs["dataset_stats"] = stats
# External per-dataset stats override. Accept either a single
# dict-of-features (wrapped into a singleton list) or an already-
# listed `per_dataset_stats`.
if isinstance(stats, dict):
per_dataset_stats = [stats]
dataset_names = dataset_names or ["external"]
else:
per_dataset_stats = list(stats)
dataset_names = dataset_names or [f"external_{i}" for i in range(len(per_dataset_stats))]

if per_dataset_stats is not None:
kwargs["per_dataset_stats"] = per_dataset_stats
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Blocker — cfg.policy.type = "value" immediately raises TypeError.

The PR description says the value policy is "out of scope and unchanged", but the factory now unconditionally passes per_dataset_stats= and dataset_names= to every policy class. ValueFunction.__init__ (policies/value/modeling_value.py:124-138) still has the old signature:

def __init__(
    self,
    config: ValueConfig,
    dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):

So make_policy(cfg, ds_meta=...) with cfg.type='value' triggers:

TypeError: ValueFunction.__init__() got an unexpected keyword argument 'per_dataset_stats'

Even if that were fixed, ValueFunction.__init__ still calls Normalize(config.input_features, config.normalization_mapping, dataset_stats)dataset_stats is the 3rd positional arg, which is now per_dataset_stats: list[...]. Passing a dict there would either crash in create_stats_buffers (which iterates len(per_dataset_stats) as if it's a list) or worse, silently produce nonsense buffer shapes. And self.normalize_inputs(batch) inside forward is now missing the dataset_index positional arg.

Either: (a) update ValueFunction to the new API so "unchanged" becomes true, (b) carve value out of the per_dataset_stats kwarg passing in this factory and keep the legacy path, or (c) explicitly raise NotImplementedError when cfg.type=='value' lands in this branch so it fails loudly with a useful message instead of a generic TypeError.


Generated by Claude Code

kwargs["dataset_names"] = dataset_names
# Persist the ordered name list into the policy config so it survives
# save -> load. Inference-time `batch["dataset_repo_id"]` strings are
# resolved against this list.
cfg.dataset_names = list(dataset_names)

if execution_target is not None:
kwargs["execution_target"] = execution_target
Expand All @@ -243,6 +270,15 @@ def make_policy(

assert isinstance(policy, nn.Module)

# If the checkpoint was saved with save_normalization_stats=False the
# buffers loaded back as +inf — repopulate from the caller's stats if we
# have them, otherwise raise a clear error rather than letting the first
# forward fail mid-step.
if cfg.pretrained_path and per_dataset_stats is not None:
policy._inject_stats(per_dataset_stats, dataset_names=dataset_names)
elif cfg.pretrained_path:
policy._check_norm_stats_loaded()

# policy = torch.compile(policy, mode="reduce-overhead")

return policy
Loading
Loading