
Add support for online data streaming #160

Merged
yuecideng merged 31 commits into main from yueci/online-data-streaming on Mar 7, 2026

Conversation

@yuecideng
Contributor

@yuecideng yuecideng commented Mar 4, 2026

Description

TODO:

  • Add a torch Dataset-style implementation for ODS so the model-training side can consume it through a common interface.
  • Add shared-buffer initialization from config.
  • Add unit tests for the ODS.
  • Adapt the pipeline for RL training.

Type of change

  • New feature (non-breaking change which adds functionality)
  • Breaking change (existing functionality will not work without user modification)
  • Documentation update

Checklist

  • I have run the black . command to format the code base.
  • I have made corresponding changes to the documentation.
  • I have added tests that prove my fix is effective or that my feature works.
  • Dependencies have been updated, if applicable.

Copilot AI review requested due to automatic review settings March 4, 2026 02:37
@yuecideng yuecideng added dataset gym robot learning env and its related features labels Mar 4, 2026
@yuecideng yuecideng requested review from yangchen73 and yhnsu March 4, 2026 02:37
Contributor

Copilot AI left a comment


Pull request overview

This PR introduces online rollout data streaming support by migrating core sim/gym data structures (observations, actions, sensor outputs, episode buffers) from Python dicts to tensordict.TensorDict, adding an online data engine, and standardizing episode truncation via max_episode_steps.

Changes:

  • Add OnlineDataEngine and rollout-buffer plumbing for online episode data capture/streaming.
  • Switch environment observations/actions and sensor get_data() outputs to TensorDict, and refactor dataset recording to read from a rollout buffer.
  • Replace episode_length/action_length-style time limits with max_episode_steps (code, docs, and configs updated accordingly).

Reviewed changes

Copilot reviewed 41 out of 41 changed files in this pull request and generated 12 comments.

Show a summary per file
File Description
tests/sim/sensors/test_stereo.py Updates stereo sensor test expectations (removes dict-type assertion).
tests/sim/sensors/test_camera.py Updates camera test to expect TensorDict.
scripts/tutorials/sim/create_sensor.py Adjusts tutorial image extraction (non-inplace squeeze) and imports cv2.
pyproject.toml Adds tensordict dependency.
embodichain/utils/configclass.py Removes episode_length from configclass example snippet.
embodichain/lab/sim/types.py Changes EnvObs/EnvAction type aliases to TensorDict.
embodichain/lab/sim/sensors/stereo.py Adds TensorDict import (type alignment).
embodichain/lab/sim/sensors/contact_sensor.py Refactors contact sensor buffering to preallocated TensorDict + configurable max size.
embodichain/lab/sim/sensors/base_sensor.py Changes base sensor buffer and get_data() return type to TensorDict.
embodichain/lab/sim/robots/dexforce_w1/cfg.py Updates robot example to run multiple envs.
embodichain/lab/sim/objects/robot.py Returns proprioception as TensorDict; updates joint-id selection logic.
embodichain/lab/sim/objects/articulation.py Introduces active joint ids/names helpers and active DoF.
embodichain/lab/gym/utils/registration.py Adjusts TimeLimit wrapper behavior; removes extra wrapper injection.
embodichain/lab/gym/utils/gym_utils.py Supports TensorDict observation spaces; adds rollout-buffer initialization helper; parses max_episode_steps.
embodichain/lab/gym/envs/tasks/tableware/stack_blocks_two.py Removes demo action_length bookkeeping.
embodichain/lab/gym/envs/tasks/tableware/pour_water/pour_water.py Removes demo action_length; adjusts joint selection for EEF mimic joints.
embodichain/lab/gym/envs/tasks/tableware/blocks_ranking_rgb.py Removes demo action_length bookkeeping.
embodichain/lab/gym/envs/tasks/tableware/base_agent_env.py Removes demo action_length bookkeeping.
embodichain/lab/gym/envs/tasks/special/simple_task.py Removes demo action_length bookkeeping.
embodichain/lab/gym/envs/tasks/rl/push_cube.py Removes timeout truncation based on episode_length.
embodichain/lab/gym/envs/tasks/rl/basic/cart_pole.py Removes timeout truncation based on episode_length.
embodichain/lab/gym/envs/rl_env.py Removes episode_length default and base timeout truncation method.
embodichain/lab/gym/envs/managers/observations.py Maps joint indices through env active joint ids for normalization.
embodichain/lab/gym/envs/managers/datasets.py Switches dataset saving pipeline to use env rollout buffer + TensorDict frames.
embodichain/lab/gym/envs/embodied_env.py Adds rollout buffer lifecycle, active joint selection, and updated dataset-save flow.
embodichain/lab/gym/envs/base_env.py Adds max_episode_steps, TensorDict observations/info, active joint ids, and rollout-step hooks.
embodichain/lab/engine/data.py Adds multiprocessing online data streaming engine for rollouts.
embodichain/lab/engine/__init__.py Exposes OnlineDataEngine.
docs/source/tutorial/rl.rst Removes episode_length references from RL tutorial guidance.
docs/source/overview/gym/env.md Documents max_episode_steps, control_parts, active_joint_ids, rollout buffer flags.
configs/gym/stack_cups/cobot_magic_3cam.json Updates joint id selection values.
configs/gym/stack_blocks_two/cobot_magic_3cam.json Updates joint id selection values.
configs/gym/pour_water/gym_config_simple.json Adds max_episode_steps; adjusts dataset metadata payload.
configs/gym/pour_water/gym_config.json Adds max_episode_steps; moves control_parts into env; updates joint id selection.
configs/gym/match_object_container/cobot_magic_3cam.json Updates joint id selection values.
configs/gym/blocks_ranking_size/cobot_magic_3cam.json Updates joint id selection values.
configs/gym/blocks_ranking_rgb/cobot_magic_3cam.json Updates joint id selection values.
configs/gym/agent/rearrangement_agent/fast_gym_config.json Moves control_parts into env; adjusts dataset metadata payload.
configs/gym/agent/pour_water_agent/fast_gym_config.json Moves control_parts into env; adjusts dataset metadata payload.
configs/agents/rl/push_cube/gym_config.json Adds max_episode_steps; removes episode_length extension.
configs/agents/rl/basic/cart_pole/gym_config.json Adds max_episode_steps; removes episode_length extension.
Comments suppressed due to low confidence (1)

embodichain/lab/sim/sensors/contact_sensor.py:366

  • filter_by_user_ids() applies the mask over the full preallocated _data_buffer (length max_contact_num). Entries beyond _curr_contact_num are uninitialized/stale, so this can return garbage contacts. Filter only over the current valid slice (e.g., data = self.get_data() and mask that), and similarly ensure set_contact_point_visibility() uses only the valid contacts.


Comment on lines 166 to 175
@@ -167,8 +172,6 @@ def get_data(self, copy: bool = True) -> Dict[str, torch.Tensor]:
Returns:
The data collected by the sensor.
"""
if copy:
return {key: value.clone() for key, value in self._data_buffer.items()}
return self._data_buffer

Copilot AI Mar 4, 2026


The get_data docstring still documents a copy argument, but the method signature no longer accepts it. Update the docstring to match the API (or reintroduce the parameter if you still need copy semantics).

Comment on lines 229 to 230
if n_contact == 0:
self._data_buffer = {
"position": torch.empty((0, 3), device=self.device),
"normal": torch.empty((0, 3), device=self.device),
"friction": torch.empty((0, 3), device=self.device),
"impulse": torch.empty((0,), device=self.device),
"distance": torch.empty((0,), device=self.device),
"user_ids": torch.empty((0, 2), dtype=torch.int32, device=self.device),
"env_ids": torch.empty((0,), dtype=torch.int32, device=self.device),
}
return

Copilot AI Mar 4, 2026


When n_contact == 0, update() returns without resetting _curr_contact_num. If the previous step had contacts, get_data() will keep returning stale contact entries. Set self._curr_contact_num = 0 before returning (and optionally clear/zero the first entries if needed).
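The pattern this comment suggests can be sketched in a few lines. Plain Python lists stand in for the torch tensors, and the attribute names (`_data_buffer`, `_curr_contact_num`) mirror the review but are assumptions about the actual class:

```python
class ContactBufferSketch:
    """Hypothetical stand-in for a preallocated buffer with a valid-prefix counter."""

    def __init__(self, max_contact_num: int = 8):
        self._data_buffer = [None] * max_contact_num
        self._curr_contact_num = 0

    def update(self, contacts: list) -> None:
        # Reset the counter *before* the early return so an empty update
        # never leaves the previous step's contacts visible.
        self._curr_contact_num = len(contacts)
        if not contacts:
            return
        self._data_buffer[: len(contacts)] = contacts

    def get_data(self) -> list:
        # Only the valid prefix is exposed; rows past the counter are stale.
        return self._data_buffer[: self._curr_contact_num]
```

With the counter reset on the empty path, a step with zero contacts correctly yields an empty result instead of replaying the previous step's data.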

"black==24.3.0",
"fvcore",
"h5py",
"tensordict"

Copilot AI Mar 4, 2026


The new tensordict dependency is added without a version pin, meaning builds will always pull the latest available release from the package index. If this package is ever compromised upstream, new installs of this project could automatically execute malicious code as part of your environment. To reduce supply-chain risk, pin tensordict to a specific, vetted version (and update intentionally over time) rather than leaving it unconstrained.

Copilot AI review requested due to automatic review settings March 6, 2026 15:29
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 47 out of 47 changed files in this pull request and generated 10 comments.



Comment on lines 229 to 231
if n_contact == 0:
self._data_buffer = {
"position": torch.empty((0, 3), device=self.device),
"normal": torch.empty((0, 3), device=self.device),
"friction": torch.empty((0, 3), device=self.device),
"impulse": torch.empty((0,), device=self.device),
"distance": torch.empty((0,), device=self.device),
"user_ids": torch.empty((0, 2), dtype=torch.int32, device=self.device),
"env_ids": torch.empty((0,), dtype=torch.int32, device=self.device),
}
return


Copilot AI Mar 6, 2026


When n_contact == 0, update() returns early without resetting _curr_contact_num. If the previous update had contacts, get_data() will continue returning stale contact data. Set _curr_contact_num = 0 (and optionally clear/zero relevant buffer slices) before returning on the no-contact path.

Comment on lines 52 to 57
# Get data from the camera
data = self.camera.get_data()

# Check if data is a dictionary
assert isinstance(data, dict), "Camera data should be a dictionary"

# Check if all expected keys are present
for key in self.camera.SUPPORTED_DATA_TYPES:
assert key in data, f"Missing key in camera data: {key}"

Copilot AI Mar 6, 2026


test_stereo.py removes the type assertion entirely, while test_camera.py now asserts the return type is TensorDict. Given BaseSensor.get_data() now returns a TensorDict, this test should also assert the expected type to avoid silently accepting regressions (e.g., accidental list/None returns that still support in checks).

Comment on lines +570 to +571
# Use all joints of the robot.
self.active_joint_ids = list(range(robot.dof))

Copilot AI Mar 6, 2026


In the default branch (no control_parts / active_joint_ids specified), self.active_joint_ids = list(range(robot.dof)) includes mimic joints. Since Articulation.active_joint_ids already excludes mimic joints and other codepaths explicitly use remove_mimic=True, the default here should likely be robot.active_joint_ids to avoid sending actions to mimic joints / mis-sizing action spaces.

Suggested change
# Use all joints of the robot.
self.active_joint_ids = list(range(robot.dof))
# Use all active joints of the robot (excluding mimic joints).
self.active_joint_ids = list(robot.active_joint_ids)

Comment on lines 441 to 461
@@ -440,12 +449,18 @@ def get_info(self, **kwargs) -> Dict[str, Any]:
Returns:
The info dictionary.
"""
info = dict(elapsed_steps=self._elapsed_steps)
info = TensorDict(
dict(elapsed_steps=self._elapsed_steps),
batch_size=[self.num_envs],
device=self.device,
)

info.update(self.evaluate(**kwargs))
evaluate = self.evaluate(**kwargs)
if evaluate:
info.update(evaluate)
return info

Copilot AI Mar 6, 2026


get_info() now returns a TensorDict, but Gymnasium’s Env.step() contract expects a plain dict for info, and common wrappers may attempt to insert arbitrary Python values into it. Consider keeping info as a dict (or converting to dict right before returning from step) to preserve wrapper/interoperability.

Copilot AI review requested due to automatic review settings March 7, 2026 04:20
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 49 out of 49 changed files in this pull request and generated 8 comments.



Comment on lines +20 to +29
from typing import Sequence, Union
from tensordict import TensorDict


Array = Union[torch.Tensor, np.ndarray, Sequence]
Device = Union[str, torch.device]

EnvObs = Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]
EnvObs = TensorDict[str, Union[torch.Tensor, TensorDict[str, torch.Tensor]]]

EnvAction = Union[torch.Tensor, Dict[str, torch.Tensor]]
EnvAction = Union[torch.Tensor, TensorDict[str, torch.Tensor]]

Copilot AI Mar 7, 2026


EnvObs = TensorDict[...] and the other TensorDict[...] aliases are evaluated at import time. Unless TensorDict implements __class_getitem__ for those type parameters (version-dependent), this can raise TypeError: 'TensorDict' object is not subscriptable and break imports. Consider using from __future__ import annotations plus plain TensorDict/TensorDictBase, or use typing.TypeAlias with postponed evaluation.
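The failure mode is easy to reproduce without tensordict installed: subscripting a plain runtime class in a module-level alias is evaluated at import time and raises on any class lacking `__class_getitem__`. A stand-in class (not the real `tensordict.TensorDict`) illustrates the safe and unsafe forms:

```python
from typing import Union


class FakeTensorDict:
    """Stand-in for tensordict.TensorDict on a version without __class_getitem__."""


# Safe: plain aliases are just names; nothing is subscripted at import time.
EnvObs = FakeTensorDict
EnvAction = Union[object, FakeTensorDict]

# Unsafe: evaluating FakeTensorDict[str, object] needs __class_getitem__,
# which this class (like some tensordict versions) does not define.
try:
    _ = FakeTensorDict[str, object]
    subscript_ok = True
except TypeError:
    subscript_ok = False
```

Note that `from __future__ import annotations` only defers *annotations*; a module-level alias assignment like `EnvObs = TensorDict[...]` is ordinary code and still executes at import, so the plain-alias form is the robust one.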

Comment on lines 349 to +355
filter0_mask = torch.isin(self._data_buffer["user_ids"][:, 0], item_user_ids)
filter1_mask = torch.isin(self._data_buffer["user_ids"][:, 1], item_user_ids)
if self.cfg.filter_need_both_actor:
filter_mask = torch.logical_and(filter0_mask, filter1_mask)
else:
filter_mask = torch.logical_or(filter0_mask, filter1_mask)
return {
"position": self._data_buffer["position"][filter_mask],
"normal": self._data_buffer["normal"][filter_mask],
"friction": self._data_buffer["friction"][filter_mask],
"impulse": self._data_buffer["impulse"][filter_mask],
"distance": self._data_buffer["distance"][filter_mask],
"user_ids": self._data_buffer["user_ids"][filter_mask],
"env_ids": self._data_buffer["env_ids"][filter_mask],
}
return self._data_buffer[filter_mask]

Copilot AI Mar 7, 2026


filter_by_user_ids() computes the mask over the entire preallocated buffer (max_contact_num) instead of only the valid prefix [0:_curr_contact_num), so it can return uninitialized/stale rows that happen to match the user_id filter. Apply the filter on self.get_data() / self._data_buffer[: self._curr_contact_num] so only valid contacts are considered.

Comment on lines +436 to +447
# Sample row indices and chunk start offsets.
row_sample_idx = torch.randint(0, len(available), (batch_size,))
row_indices = available[row_sample_idx]

max_start = max_steps - chunk_size
start_indices = torch.randint(0, max_start + 1, (batch_size,))

time_offsets = torch.arange(chunk_size)
time_indices = start_indices[:, None] + time_offsets[None, :]

result = self.shared_buffer[row_indices[:, None], time_indices]


Copilot AI Mar 7, 2026


Chunks are sampled uniformly from [0, max_episode_steps), but many create_demo_action_list() implementations generate much shorter rollouts (e.g., 100 steps vs max_episode_steps=600), leaving the remainder of each buffer row as zeros/stale. Without tracking per-episode valid lengths or a mask/done flag in the shared buffer, sample_batch() will frequently return invalid timesteps. Consider storing an episode_len / valid_mask per buffer row and constrain sampling to valid ranges (or pad + return a mask).
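One way to address this, sketched with stdlib `random` and an assumed per-row `episode_lens` record (not the PR's actual bookkeeping), is to constrain each chunk's start index to the row's valid prefix:

```python
import random


def sample_chunk_starts(episode_lens, chunk_size, batch_size, rng=random):
    """Pick (row, start) pairs so every chunk lies inside its episode's
    valid prefix. `episode_lens` maps buffer row -> recorded step count;
    it is an assumed structure, not OnlineDataEngine's real API."""
    # Only rows whose episode is at least chunk_size long are eligible.
    eligible = [row for row, n in episode_lens.items() if n >= chunk_size]
    if not eligible:
        raise ValueError("no recorded episode is long enough for this chunk size")
    picks = []
    for _ in range(batch_size):
        row = rng.choice(eligible)
        # Start index bounded by the episode's own length, not max_episode_steps.
        start = rng.randrange(0, episode_lens[row] - chunk_size + 1)
        picks.append((row, start))
    return picks
```

Every sampled chunk then covers only recorded timesteps; short episodes are either excluded (as here) or would need padding plus a validity mask.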

Comment on lines 95 to 107
# check for user supplied max_episode_steps during gym.make calls
if frame.f_code.co_name == "make" and "max_episode_steps" in prev_frame_locals:
if prev_frame_locals["max_episode_steps"] is not None:
max_episode_steps = prev_frame_locals["max_episode_steps"]
# do some wrapper surgery to remove the previous timelimit wrapper
# with gymnasium 0.29.1, this will remove the timelimit wrapper and nothing else.
curr_env = env
while curr_env is not None:
if isinstance(curr_env, gym.wrappers.TimeLimit):
self.env = curr_env.env
break
self._max_episode_steps = max_episode_steps
self._max_episode_steps = self.base_env.max_episode_steps


Copilot AI Mar 7, 2026


TimeLimitWrapper.__init__ computes/overrides the max_episode_steps argument (including honoring a user override in gym.make), but then ignores it by unconditionally setting _max_episode_steps = self.base_env.max_episode_steps. If this wrapper is meant to enforce a caller-specified time limit, _max_episode_steps should be set from the resolved max_episode_steps value (or the wrapper should be removed if BaseEnv now owns the time-limit logic).

Comment on lines 26 to 28
from typing import Dict, Tuple, List, Sequence
from tensordict import TensorDict


Copilot AI Mar 7, 2026


TensorDict is imported but never used in this module. Remove the import to avoid linting failures and keep dependencies minimal.

Comment on lines +232 to 233
def get_proprioception(self) -> TensorDict[str, torch.Tensor]:
"""Gets robot proprioception information, primarily for agent state representation in robot learning scenarios.

Copilot AI Mar 7, 2026


This file does not use from __future__ import annotations, so the return annotation TensorDict[str, torch.Tensor] is evaluated at function definition time. If TensorDict is not subscriptable in the installed tensordict version, this will raise at import time. Either add postponed annotations (from __future__ import annotations) or avoid subscripting TensorDict in runtime annotations.

Copilot AI review requested due to automatic review settings March 7, 2026 13:49
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 59 out of 59 changed files in this pull request and generated 8 comments.



Copilot AI review requested due to automatic review settings March 7, 2026 14:28
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 62 out of 62 changed files in this pull request and generated 10 comments.



Comment on lines +463 to +487
logger.log_debug(f"Initializing episode for env_ids: {env_ids}", color="blue")
save_data = kwargs.get("save_data", True)

# Determine which environments to process
if env_ids is None:
env_ids_to_process = list(range(self.num_envs))
elif isinstance(env_ids, torch.Tensor):
env_ids_to_process = env_ids.cpu().tolist()
else:
env_ids_to_process = list(env_ids)
env_ids_to_process = list(range(self.num_envs)) if env_ids is None else env_ids

# Save dataset before clearing buffers for environments that are being reset
if save_data and self.dataset_manager:
if "save" in self.dataset_manager.available_modes:

# Filter to only save successful episodes
successful_env_ids = [
env_id
for env_id in env_ids_to_process
if (
self.episode_success_status.get(env_id, False)
or self._task_success[env_id].item()
)
]
successful_env_ids = self.episode_success_status | self._task_success

if successful_env_ids:
if successful_env_ids.any():

# Convert back to tensor if needed
successful_env_ids_tensor = torch.tensor(
successful_env_ids, device=self.device
)
self.dataset_manager.apply(
mode="save",
env_ids=successful_env_ids_tensor,
env_ids=successful_env_ids.nonzero(as_tuple=True)[0],
)

# Clear episode buffers and reset success status for environments being reset
for env_id in env_ids_to_process:
self.episode_obs_buffer[env_id].clear()
self.episode_action_buffer[env_id].clear()
self.episode_success_status[env_id] = False
if self.rollout_buffer is not None:
self.current_rollout_step = 0

self.episode_success_status[env_ids_to_process] = False

Copilot AI Mar 7, 2026


_initialize_episode() now computes successful_env_ids = self.episode_success_status | self._task_success across all envs and resets current_rollout_step = 0 globally. When env_ids is a subset (partial reset), this can (1) save episodes for envs that are not being reset and (2) wipe/overwrite rollout data for envs that are still running. Consider tracking per-env rollout step counters and only saving/clearing data for env_ids_to_process.
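The per-env counter idea can be sketched independently of the env class; the names here are assumptions, not the PR's API:

```python
class PerEnvRolloutSteps:
    """Hypothetical per-env step counters: a partial reset touches only
    the envs being reset, so other envs keep recording undisturbed."""

    def __init__(self, num_envs: int):
        self.steps = [0] * num_envs

    def after_step(self) -> None:
        # Every env advances one step per sim tick.
        self.steps = [s + 1 for s in self.steps]

    def reset(self, env_ids) -> None:
        # Only the reset envs lose their progress.
        for i in env_ids:
            self.steps[i] = 0
```

A single global `current_rollout_step`, by contrast, silently rewinds the write position of envs that were not part of the reset.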

Comment on lines +507 to +593
class TestOnlineDatasetDynamicChunk(unittest.TestCase):
"""Tests for OnlineDataset with ChunkSizeSampler chunk_size."""

def setUp(self) -> None:
self.engine = _make_fake_engine()

def test_uniform_sampler_item_mode_shape(self) -> None:
"""Item mode with UniformChunkSampler: batch_size dim is absent, time dim varies."""
LOW, HIGH = 5, 15
sampler = UniformChunkSampler(low=LOW, high=HIGH)
dataset = OnlineDataset(self.engine, chunk_size=sampler)
it = iter(dataset)
for _ in range(10):
sample = next(it)
# batch_size has one element — the chunk dimension.
self.assertEqual(len(sample.batch_size), 1)
chunk_dim = sample.batch_size[0]
self.assertGreaterEqual(chunk_dim, LOW)
self.assertLessEqual(chunk_dim, HIGH)

def test_gmm_sampler_item_mode_shape(self) -> None:
"""Item mode with GMMChunkSampler: chunk dim is clamped within [low, high]."""
LOW, HIGH = 4, 20
sampler = GMMChunkSampler(
means=[8.0, 16.0], stds=[2.0, 2.0], low=LOW, high=HIGH
)
dataset = OnlineDataset(self.engine, chunk_size=sampler)
it = iter(dataset)
for _ in range(10):
sample = next(it)
chunk_dim = sample.batch_size[0]
self.assertGreaterEqual(chunk_dim, LOW)
self.assertLessEqual(chunk_dim, HIGH)

def test_uniform_sampler_batch_mode_shape(self) -> None:
"""Batch mode: per-batch chunk size is consistent across all trajectories."""
BATCH = 3
LOW, HIGH = 5, 15
sampler = UniformChunkSampler(low=LOW, high=HIGH)
dataset = OnlineDataset(self.engine, chunk_size=sampler, batch_size=BATCH)
it = iter(dataset)
for _ in range(10):
batch = next(it)
self.assertEqual(len(batch.batch_size), 2)
self.assertEqual(batch.batch_size[0], BATCH)
chunk_dim = batch.batch_size[1]
self.assertGreaterEqual(chunk_dim, LOW)
self.assertLessEqual(chunk_dim, HIGH)

def test_dynamic_chunk_sizes_vary(self) -> None:
"""Consecutive samples from a uniform sampler produce different chunk sizes."""
LOW, HIGH = 5, 30
sampler = UniformChunkSampler(low=LOW, high=HIGH)
dataset = OnlineDataset(self.engine, chunk_size=sampler)
it = iter(dataset)
sizes = {next(it).batch_size[0] for _ in range(50)}
# With a range of 26 values, drawing 50 times should yield > 1 unique size.
assert (
len(sizes) >= 1
), "Expected multiple unique chunk sizes from uniform sampler"

def test_invalid_chunk_size_type_raises(self) -> None:
"""TypeError when chunk_size is not an int or ChunkSizeSampler."""
with self.assertRaises(TypeError):
OnlineDataset(self.engine, chunk_size="large") # type: ignore[arg-type]

def test_invalid_chunk_size_int_raises(self) -> None:
"""ValueError when chunk_size is an int < 1."""
with self.assertRaises(ValueError):
OnlineDataset(self.engine, chunk_size=0)

def test_custom_sampler_subclass(self) -> None:
"""A user-defined ChunkSizeSampler subclass is accepted and called."""

class FixedSampler(ChunkSizeSampler):
def __call__(self) -> int:
return 7

dataset = OnlineDataset(self.engine, chunk_size=FixedSampler())
sample = next(iter(dataset))
self.assertEqual(sample.batch_size[0], 7)


# ---------------------------------------------------------------------------

if __name__ == "__main__":
unittest.main()

Copilot AI Mar 7, 2026


This new test file uses unittest.TestCase classes and unittest.main(), but the repo’s testing conventions document pytest function style or plain class TestX: with setup_method/teardown_method instead (see CLAUDE.md). Consider rewriting these to the documented style for consistency.

Comment on lines +409 to +413
obs = TensorDict(
dict(robot=self.robot.get_proprioception()[:, self.active_joint_ids]),
batch_size=[self.num_envs],
device=self.device,
)

Copilot AI Mar 7, 2026


get_obs() indexes the TensorDict returned by robot.get_proprioception() as td[:, self.active_joint_ids], but that TensorDict has batch_size [num_envs] (1D). This tuple-indexing is likely invalid and will raise at runtime; instead, slice each proprioception tensor by joint ids (e.g. qpos/qvel/qf) and keep the TensorDict batch_size as [num_envs].

Comment on lines 90 to 96
# The robot agent instance.
robot: Robot = None

active_joint_ids: List[int] = []

# The sensors used in the environment.
sensors: Dict[str, BaseSensor] = {}

Copilot AI Mar 7, 2026


active_joint_ids is declared as a mutable class attribute ([]) on BaseEnv, which will be shared across instances. This can leak joint ids between environments/tests; initialize it per-instance (e.g. self.active_joint_ids = [] in __init__) instead.
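The leak is a standard Python pitfall, shown here with minimal stand-in classes (not the real BaseEnv):

```python
class SharedDefault:
    active_joint_ids = []  # one list object shared by every instance


class PerInstanceDefault:
    def __init__(self):
        self.active_joint_ids = []  # a fresh list per environment
```

Mutating `SharedDefault().active_joint_ids` changes the class-level list, so every later instance observes the mutation; the per-instance form does not.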

Comment on lines +456 to +460
# Edge case: the entire valid region is locked. Fall back to
# sampling from all valid rows to avoid a hard failure.
log_error(
"[OnlineDataEngine] All valid buffer rows are currently locked. "
"Cannot sample a batch at this time.",

Copilot AI Mar 7, 2026


The comment says it will "fall back to sampling from all valid rows", but log_error(..., RuntimeError) will always raise here (no fallback). Either implement an actual fallback strategy or update the comment/message to reflect that sampling fails when all rows are locked.

Suggested change
# Edge case: the entire valid region is locked. Fall back to
# sampling from all valid rows to avoid a hard failure.
log_error(
"[OnlineDataEngine] All valid buffer rows are currently locked. "
"Cannot sample a batch at this time.",
# Edge case: the entire valid region is locked. Sampling a batch
# is not possible in this state and will result in a hard failure.
log_error(
"[OnlineDataEngine] All valid buffer rows are currently locked. "
"Cannot sample a batch at this time; sampling fails because no "
"unlocked rows are available.",

.. literalinclude:: ../../../scripts/tutorials/gym/modular_env.py
:language: python
:start-at: @register_env("ModularEnv-v1", max_episode_steps=100, override=True)
:start-at: @register_env("ModularEnv-v1" override=True)

Copilot AI Mar 7, 2026


The :start-at: string is missing a comma: it currently won’t match the actual decorator line in modular_env.py, so the literalinclude will fail. Update it to @register_env("ModularEnv-v1", override=True).

Suggested change
:start-at: @register_env("ModularEnv-v1" override=True)
:start-at: @register_env("ModularEnv-v1", override=True)

Comment on lines +399 to +407
)
# TODO: Use a action manager to handle the action space consistency with RL.
if isinstance(action, TensorDict):
action_to_store = action["qpos"]
elif isinstance(action, torch.Tensor):
action_to_store = action
self.rollout_buffer["actions"][:, self.current_rollout_step, ...].copy_(
action_to_store.to(buffer_device), non_blocking=True
)

Copilot AI Mar 7, 2026


In _hook_after_sim_step, action_to_store is only set for TensorDict actions when the key 'qpos' exists. If the action TensorDict contains only 'qvel'/'qf' (supported by _step_action), this will raise KeyError/UnboundLocalError during rollout recording. Handle the other supported keys or choose a consistent action representation to store.
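A small sketch of the key-handling fix: a plain dict stands in for the TensorDict, and the key names (`qpos`, `qvel`, `qf`) come from the review comment, so treat them as assumptions about `_step_action`:

```python
SUPPORTED_ACTION_KEYS = ("qpos", "qvel", "qf")


def extract_action_tensor(action):
    """Return a single array-like to record, whichever supported key the
    action mapping carries; raise clearly when none is present."""
    if isinstance(action, dict):
        for key in SUPPORTED_ACTION_KEYS:
            if key in action:
                return action[key]
        raise KeyError(f"action has none of {SUPPORTED_ACTION_KEYS}")
    return action  # already a raw tensor/array
```

This avoids the `UnboundLocalError` path: every branch either returns a value or fails with a descriptive error before the buffer write.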

Comment on lines +168 to +192
# Draw many small batches and collect all sampled row indices.
# We cannot directly observe row indices from outside, but we can
# verify that each result slice is *not* identical to a locked row's
# data (which has a unique random fingerprint).
locked_obs = engine.shared_buffer["obs"][LOCK_START:LOCK_END] # [3, 50, 10]

for _ in range(20):
result = engine.sample_batch(batch_size=1, chunk_size=5)
sampled_obs_start = result["obs"][0, 0] # first timestep of first chunk
# Check that this does not exactly match any locked row's first timestep.
for r in range(LOCK_END - LOCK_START):
matched = torch.allclose(
sampled_obs_start, locked_obs[r, :5].mean(dim=-1, keepdim=True)
)
# The comparison above is a heuristic; the real guarantee is that
# available rows exclude locked ones. We use a direct index check:
# reconstruct which row could produce this exact obs by brute-force.
# Reconstructed check: verify available indices exclude locked rows.
all_rows = torch.arange(BUFFER_SIZE)
is_locked = (all_rows >= LOCK_START) & (all_rows < LOCK_END)
available = all_rows[~is_locked]
assert len(available) != 0, "available must be non-empty"
for row in locked_rows:
assert row not in available.tolist()


Copilot AI Mar 7, 2026


test_sample_batch_locks_respected does not actually validate that sample_batch() avoids locked rows. The assertions only prove that available = all_rows[~is_locked] excludes the locked range (a tautology) and the matched result is never checked. Consider instrumenting sample_batch to expose chosen row indices in tests or compare sampled chunks directly against locked rows’ data to ensure none match.
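The instrumentation idea can be sketched with stdlib `random`; returning the chosen rows is a hypothetical test hook, not `OnlineDataEngine.sample_batch`'s real signature:

```python
import random


def sample_batch_rows(num_rows, locked, batch_size, rng=random):
    """Return the row indices drawn for a batch so a test can assert
    that locked rows were never chosen (assumed helper, for testability)."""
    available = [r for r in range(num_rows) if r not in locked]
    if not available:
        raise RuntimeError("all valid buffer rows are locked")
    return [rng.choice(available) for _ in range(batch_size)]
```

With the indices exposed, the lock test becomes a direct membership assertion instead of a data-fingerprint heuristic.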

@yuecideng yuecideng merged commit c0b2c57 into main Mar 7, 2026
6 of 10 checks passed
@yuecideng yuecideng deleted the yueci/online-data-streaming branch March 7, 2026 15:04

Labels

dataset gym robot learning env and its related features


3 participants