Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 26 additions & 6 deletions src/opentau/configs/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
- Evaluation settings and parameters
"""

import warnings
from dataclasses import dataclass, field

import draccus
Expand Down Expand Up @@ -105,10 +106,16 @@ class DatasetConfig:
robot_type: str | None = None
control_mode: str | None = None

# Ratio of the dataset to be used for validation. Please specify a value.
# If `val_freq` is set to 0, a validation dataset will not be created and this value will be ignored.
# Defaults to 0.05.
val_split_ratio: float = 0.05
# DEPRECATED. Set `val_split_ratio` on `DatasetMixtureConfig` instead — the
# mixture-level value is the single source of truth and is applied uniformly
# to every dataset in the mixture. This per-dataset field is retained only
# so that pre-existing JSON configs continue to parse; setting it here has
# no effect on the actual split. The default is `None` (sentinel meaning
# "user did not set this") so that
# `DatasetMixtureConfig.__post_init__` can distinguish a real per-dataset
# override from the unset default and only emit a `DeprecationWarning` in
# the former case.
val_split_ratio: float | None = None

def __post_init__(self):
"""Validate dataset configuration and register custom mappings if provided."""
Expand Down Expand Up @@ -278,9 +285,22 @@ def __post_init__(self):
if not 0.0 <= value <= 1.0:
raise ValueError(f"`{name}` must be in [0, 1], got {value}.")

# set the val_split_ratio for all datasets in the mixture
# `DatasetConfig.val_split_ratio` is deprecated — the mixture-level
# value is the single source of truth (read by `factory.make_dataset`).
# The per-dataset field defaults to `None`; warn only when the user
# actually set a value there, since that's the case where their input
# is being silently ignored.
for dataset_cfg in self.datasets:
dataset_cfg.val_split_ratio = self.val_split_ratio
if dataset_cfg.val_split_ratio is not None:
warnings.warn(
"`DatasetConfig.val_split_ratio` is deprecated and ignored; "
"set `val_split_ratio` on `DatasetMixtureConfig` instead. "
f"Got dataset value {dataset_cfg.val_split_ratio} "
f"vs. mixture value {self.val_split_ratio}; the mixture "
"value will be used.",
DeprecationWarning,
stacklevel=2,
)
Comment thread
shuheng-liu marked this conversation as resolved.


@dataclass
Expand Down
4 changes: 2 additions & 2 deletions src/opentau/datasets/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def make_dataset(

A train and validation dataset are returned if `train_cfg.val_freq` is greater than 0.
The validation dataset is a subset of the train dataset, and is used for evaluation during training.
The validation dataset is created by splitting the train dataset into train and validation sets based on `cfg.val_split_ratio`.
The validation dataset is created by splitting the train dataset into train and validation sets based on `train_cfg.dataset_mixture.val_split_ratio`.

Args:
cfg (DatasetConfig): A DatasetConfig used to create a LeRobotDataset.
Expand Down Expand Up @@ -243,7 +243,7 @@ def make_dataset(
dataset.meta.stats[key][stats_type] = np.array(stats, dtype=np.float32)

if train_cfg.val_freq > 0:
val_size = int(len(dataset) * cfg.val_split_ratio)
val_size = int(len(dataset) * train_cfg.dataset_mixture.val_split_ratio)
train_size = len(dataset) - val_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
train_dataset.meta = copy.deepcopy(dataset.meta) # type: ignore[assignment]
Expand Down
39 changes: 39 additions & 0 deletions tests/configs/test_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings

import pytest

from opentau.configs.default import DatasetConfig, DatasetMixtureConfig
Expand Down Expand Up @@ -97,6 +99,43 @@ def test_invalid_vector_resample_strategy_raises_error():
DatasetMixtureConfig(vector_resample_strategy=strategy)


def test_val_split_ratio_no_warning_when_only_mixture_customized():
    """Customizing only the mixture-level `val_split_ratio` must stay silent.

    This is the path users are expected to take after the deprecation.
    Earlier, the per-dataset field defaulted to 0.05, so every child carried
    a value and any customized mixture raised a spurious
    `DeprecationWarning`.
    """
    with warnings.catch_warnings(record=True) as recorded:
        # Record every warning, including ones default filters would hide.
        warnings.simplefilter("always")
        children = [DatasetConfig(repo_id=repo_id) for repo_id in ("foo/bar", "baz/qux")]
        DatasetMixtureConfig(datasets=children, val_split_ratio=0.1)

    # Keep only the deprecation messages that mention the field under test;
    # unrelated warnings must not make this test fail.
    offending = []
    for entry in recorded:
        if not issubclass(entry.category, DeprecationWarning):
            continue
        text = str(entry.message)
        if "val_split_ratio" in text:
            offending.append(text)

    assert not offending, f"Unexpected val_split_ratio DeprecationWarning(s): {offending}"


def test_val_split_ratio_warns_when_child_overrides():
    """A `val_split_ratio` set on a child `DatasetConfig` must raise a DeprecationWarning."""
    with warnings.catch_warnings(record=True) as recorded:
        # Record every warning, including ones default filters would hide.
        warnings.simplefilter("always")
        child = DatasetConfig(repo_id="foo/bar", val_split_ratio=0.2)
        DatasetMixtureConfig(datasets=[child], val_split_ratio=0.1)

    # At least one recorded deprecation must mention the field.
    matching = [
        entry
        for entry in recorded
        if issubclass(entry.category, DeprecationWarning) and "val_split_ratio" in str(entry.message)
    ]
    assert matching


class TestDatasetConfigDataMapping:
"""Test class for DatasetConfig data mapping functionality."""

Expand Down
Loading