Conversation
Pull request overview
This PR introduces online rollout data streaming support by migrating core sim/gym data structures (observations, actions, sensor outputs, episode buffers) from Python dicts to tensordict.TensorDict, adding an online data engine, and standardizing episode truncation via max_episode_steps.
Changes:
- Add `OnlineDataEngine` and rollout-buffer plumbing for online episode data capture/streaming.
- Switch environment observations/actions and sensor `get_data()` outputs to `TensorDict`, and refactor dataset recording to read from a rollout buffer.
- Replace `episode_length`/`action_length`-style time limits with `max_episode_steps` (code, docs, and configs updated accordingly).
Reviewed changes
Copilot reviewed 41 out of 41 changed files in this pull request and generated 12 comments.
| File | Description |
|---|---|
| tests/sim/sensors/test_stereo.py | Updates stereo sensor test expectations (removes dict-type assertion). |
| tests/sim/sensors/test_camera.py | Updates camera test to expect TensorDict. |
| scripts/tutorials/sim/create_sensor.py | Adjusts tutorial image extraction (non-inplace squeeze) and imports cv2. |
| pyproject.toml | Adds tensordict dependency. |
| embodichain/utils/configclass.py | Removes episode_length from configclass example snippet. |
| embodichain/lab/sim/types.py | Changes EnvObs/EnvAction type aliases to TensorDict. |
| embodichain/lab/sim/sensors/stereo.py | Adds TensorDict import (type alignment). |
| embodichain/lab/sim/sensors/contact_sensor.py | Refactors contact sensor buffering to preallocated TensorDict + configurable max size. |
| embodichain/lab/sim/sensors/base_sensor.py | Changes base sensor buffer and get_data() return type to TensorDict. |
| embodichain/lab/sim/robots/dexforce_w1/cfg.py | Updates robot example to run multiple envs. |
| embodichain/lab/sim/objects/robot.py | Returns proprioception as TensorDict; updates joint-id selection logic. |
| embodichain/lab/sim/objects/articulation.py | Introduces active joint ids/names helpers and active DoF. |
| embodichain/lab/gym/utils/registration.py | Adjusts TimeLimit wrapper behavior; removes extra wrapper injection. |
| embodichain/lab/gym/utils/gym_utils.py | Supports TensorDict observation spaces; adds rollout-buffer initialization helper; parses max_episode_steps. |
| embodichain/lab/gym/envs/tasks/tableware/stack_blocks_two.py | Removes demo action_length bookkeeping. |
| embodichain/lab/gym/envs/tasks/tableware/pour_water/pour_water.py | Removes demo action_length; adjusts joint selection for EEF mimic joints. |
| embodichain/lab/gym/envs/tasks/tableware/blocks_ranking_rgb.py | Removes demo action_length bookkeeping. |
| embodichain/lab/gym/envs/tasks/tableware/base_agent_env.py | Removes demo action_length bookkeeping. |
| embodichain/lab/gym/envs/tasks/special/simple_task.py | Removes demo action_length bookkeeping. |
| embodichain/lab/gym/envs/tasks/rl/push_cube.py | Removes timeout truncation based on episode_length. |
| embodichain/lab/gym/envs/tasks/rl/basic/cart_pole.py | Removes timeout truncation based on episode_length. |
| embodichain/lab/gym/envs/rl_env.py | Removes episode_length default and base timeout truncation method. |
| embodichain/lab/gym/envs/managers/observations.py | Maps joint indices through env active joint ids for normalization. |
| embodichain/lab/gym/envs/managers/datasets.py | Switches dataset saving pipeline to use env rollout buffer + TensorDict frames. |
| embodichain/lab/gym/envs/embodied_env.py | Adds rollout buffer lifecycle, active joint selection, and updated dataset-save flow. |
| embodichain/lab/gym/envs/base_env.py | Adds max_episode_steps, TensorDict observations/info, active joint ids, and rollout-step hooks. |
| embodichain/lab/engine/data.py | Adds multiprocessing online data streaming engine for rollouts. |
| embodichain/lab/engine/\_\_init\_\_.py | Exposes OnlineDataEngine. |
| docs/source/tutorial/rl.rst | Removes episode_length references from RL tutorial guidance. |
| docs/source/overview/gym/env.md | Documents max_episode_steps, control_parts, active_joint_ids, rollout buffer flags. |
| configs/gym/stack_cups/cobot_magic_3cam.json | Updates joint id selection values. |
| configs/gym/stack_blocks_two/cobot_magic_3cam.json | Updates joint id selection values. |
| configs/gym/pour_water/gym_config_simple.json | Adds max_episode_steps; adjusts dataset metadata payload. |
| configs/gym/pour_water/gym_config.json | Adds max_episode_steps; moves control_parts into env; updates joint id selection. |
| configs/gym/match_object_container/cobot_magic_3cam.json | Updates joint id selection values. |
| configs/gym/blocks_ranking_size/cobot_magic_3cam.json | Updates joint id selection values. |
| configs/gym/blocks_ranking_rgb/cobot_magic_3cam.json | Updates joint id selection values. |
| configs/gym/agent/rearrangement_agent/fast_gym_config.json | Moves control_parts into env; adjusts dataset metadata payload. |
| configs/gym/agent/pour_water_agent/fast_gym_config.json | Moves control_parts into env; adjusts dataset metadata payload. |
| configs/agents/rl/push_cube/gym_config.json | Adds max_episode_steps; removes episode_length extension. |
| configs/agents/rl/basic/cart_pole/gym_config.json | Adds max_episode_steps; removes episode_length extension. |
Comments suppressed due to low confidence (1)
embodichain/lab/sim/sensors/contact_sensor.py:366
`filter_by_user_ids()` applies the mask over the full preallocated `_data_buffer` (length `max_contact_num`). Entries beyond `_curr_contact_num` are uninitialized/stale, so this can return garbage contacts. Filter only over the current valid slice (e.g., `data = self.get_data()` and mask that), and similarly ensure `set_contact_point_visibility()` uses only the valid contacts.
```diff
@@ -167,8 +172,6 @@ def get_data(self, copy: bool = True) -> Dict[str, torch.Tensor]:
         Returns:
             The data collected by the sensor.
         """
-        if copy:
-            return {key: value.clone() for key, value in self._data_buffer.items()}
-        return self._data_buffer
```
The get_data docstring still documents a copy argument, but the method signature no longer accepts it. Update the docstring to match the API (or reintroduce the parameter if you still need copy semantics).
```python
if n_contact == 0:
    self._data_buffer = {
        "position": torch.empty((0, 3), device=self.device),
        "normal": torch.empty((0, 3), device=self.device),
        "friction": torch.empty((0, 3), device=self.device),
        "impulse": torch.empty((0,), device=self.device),
        "distance": torch.empty((0,), device=self.device),
        "user_ids": torch.empty((0, 2), dtype=torch.int32, device=self.device),
        "env_ids": torch.empty((0,), dtype=torch.int32, device=self.device),
    }
    return
```
When n_contact == 0, update() returns without resetting _curr_contact_num. If the previous step had contacts, get_data() will keep returning stale contact entries. Set self._curr_contact_num = 0 before returning (and optionally clear/zero the first entries if needed).
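A minimal sketch of the suggested fix. The class and attribute names below are illustrative stand-ins modeled on the snippet above, not the actual `ContactSensor` implementation:

```python
import torch

class ContactBufferSketch:
    """Illustrative preallocated contact buffer; names mirror the review comment."""

    def __init__(self, max_contact_num: int = 8) -> None:
        self._data_buffer = torch.zeros((max_contact_num, 3))
        self._curr_contact_num = 0

    def update(self, contacts: torch.Tensor) -> None:
        n_contact = contacts.shape[0]
        if n_contact == 0:
            # The fix: reset the valid count so stale rows are never exposed.
            self._curr_contact_num = 0
            return
        self._data_buffer[:n_contact] = contacts
        self._curr_contact_num = n_contact

    def get_data(self) -> torch.Tensor:
        # Only the valid prefix of the preallocated buffer is ever returned.
        return self._data_buffer[: self._curr_contact_num]
```

With the reset in place, a contact step followed by a no-contact step yields an empty `get_data()` instead of replaying the previous step's contacts.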
```toml
"black==24.3.0",
"fvcore",
"h5py",
"tensordict"
```
The new tensordict dependency is added without a version pin, meaning builds will always pull the latest available release from the package index. If this package is ever compromised upstream, new installs of this project could automatically execute malicious code as part of your environment. To reduce supply-chain risk, pin tensordict to a specific, vetted version (and update intentionally over time) rather than leaving it unconstrained.
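A pinned entry might look like the following. The version shown is a placeholder, not a vetted recommendation; substitute the release you have actually reviewed:

```toml
# pyproject.toml dependency list (version below is illustrative)
dependencies = [
    "tensordict==0.4.0",
]
```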
Pull request overview
Copilot reviewed 47 out of 47 changed files in this pull request and generated 10 comments.
```python
if n_contact == 0:
    self._data_buffer = {
        "position": torch.empty((0, 3), device=self.device),
        "normal": torch.empty((0, 3), device=self.device),
        "friction": torch.empty((0, 3), device=self.device),
        "impulse": torch.empty((0,), device=self.device),
        "distance": torch.empty((0,), device=self.device),
        "user_ids": torch.empty((0, 2), dtype=torch.int32, device=self.device),
        "env_ids": torch.empty((0,), dtype=torch.int32, device=self.device),
    }
    return
```
When n_contact == 0, update() returns early without resetting _curr_contact_num. If the previous update had contacts, get_data() will continue returning stale contact data. Set _curr_contact_num = 0 (and optionally clear/zero relevant buffer slices) before returning on the no-contact path.
```python
# Get data from the camera
data = self.camera.get_data()

# Check if data is a dictionary
assert isinstance(data, dict), "Camera data should be a dictionary"

# Check if all expected keys are present
for key in self.camera.SUPPORTED_DATA_TYPES:
    assert key in data, f"Missing key in camera data: {key}"
```
test_stereo.py removes the type assertion entirely, while test_camera.py now asserts the return type is TensorDict. Given BaseSensor.get_data() now returns a TensorDict, this test should also assert the expected type to avoid silently accepting regressions (e.g., accidental list/None returns that still support in checks).
```python
# Use all joints of the robot.
self.active_joint_ids = list(range(robot.dof))
```
In the default branch (no control_parts / active_joint_ids specified), self.active_joint_ids = list(range(robot.dof)) includes mimic joints. Since Articulation.active_joint_ids already excludes mimic joints and other codepaths explicitly use remove_mimic=True, the default here should likely be robot.active_joint_ids to avoid sending actions to mimic joints / mis-sizing action spaces.
Suggested change:

```diff
-    # Use all joints of the robot.
-    self.active_joint_ids = list(range(robot.dof))
+    # Use all active joints of the robot (excluding mimic joints).
+    self.active_joint_ids = list(robot.active_joint_ids)
```
```diff
@@ -440,12 +449,18 @@ def get_info(self, **kwargs) -> Dict[str, Any]:
         Returns:
             The info dictionary.
         """
-        info = dict(elapsed_steps=self._elapsed_steps)
+        info = TensorDict(
+            dict(elapsed_steps=self._elapsed_steps),
+            batch_size=[self.num_envs],
+            device=self.device,
+        )

-        info.update(self.evaluate(**kwargs))
+        evaluate = self.evaluate(**kwargs)
+        if evaluate:
+            info.update(evaluate)
         return info
```
get_info() now returns a TensorDict, but Gymnasium’s Env.step() contract expects a plain dict for info, and common wrappers may attempt to insert arbitrary Python values into it. Consider keeping info as a dict (or converting to dict right before returning from step) to preserve wrapper/interoperability.
Pull request overview
Copilot reviewed 49 out of 49 changed files in this pull request and generated 8 comments.
```diff
 from typing import Sequence, Union
+from tensordict import TensorDict

 Array = Union[torch.Tensor, np.ndarray, Sequence]
 Device = Union[str, torch.device]

-EnvObs = Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]
+EnvObs = TensorDict[str, Union[torch.Tensor, TensorDict[str, torch.Tensor]]]

-EnvAction = Union[torch.Tensor, Dict[str, torch.Tensor]]
+EnvAction = Union[torch.Tensor, TensorDict[str, torch.Tensor]]
```
EnvObs = TensorDict[...] and the other TensorDict[...] aliases are evaluated at import time. Unless TensorDict implements __class_getitem__ for those type parameters (version-dependent), this can raise TypeError: 'TensorDict' object is not subscriptable and break imports. Consider using from __future__ import annotations plus plain TensorDict/TensorDictBase, or use typing.TypeAlias with postponed evaluation.
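The failure mode and the postponed-annotations fix can be demonstrated with a stand-in class that, like the tensordict versions the comment worries about, does not implement `__class_getitem__` (`FakeTensorDict` is purely illustrative):

```python
from __future__ import annotations  # postpone evaluation of all annotations

class FakeTensorDict:  # stand-in with no __class_getitem__ support
    pass

# A module-level alias is an ordinary expression and IS evaluated at import
# time, so this line would raise TypeError and is left commented out:
# EnvObs = FakeTensorDict[str, int]

# An annotation, by contrast, is stored as a string and never evaluated,
# so the subscripted form is harmless here:
def get_obs() -> FakeTensorDict[str, int]:
    return FakeTensorDict()

obs = get_obs()
```

This is why the comment distinguishes the alias assignments (which break at import) from plain annotations under `from __future__ import annotations` (which do not).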
```diff
 filter0_mask = torch.isin(self._data_buffer["user_ids"][:, 0], item_user_ids)
 filter1_mask = torch.isin(self._data_buffer["user_ids"][:, 1], item_user_ids)
 if self.cfg.filter_need_both_actor:
     filter_mask = torch.logical_and(filter0_mask, filter1_mask)
 else:
     filter_mask = torch.logical_or(filter0_mask, filter1_mask)
-return {
-    "position": self._data_buffer["position"][filter_mask],
-    "normal": self._data_buffer["normal"][filter_mask],
-    "friction": self._data_buffer["friction"][filter_mask],
-    "impulse": self._data_buffer["impulse"][filter_mask],
-    "distance": self._data_buffer["distance"][filter_mask],
-    "user_ids": self._data_buffer["user_ids"][filter_mask],
-    "env_ids": self._data_buffer["env_ids"][filter_mask],
-}
+return self._data_buffer[filter_mask]
```
filter_by_user_ids() computes the mask over the entire preallocated buffer (max_contact_num) instead of only the valid prefix [0:_curr_contact_num), so it can return uninitialized/stale rows that happen to match the user_id filter. Apply the filter on self.get_data() / self._data_buffer[: self._curr_contact_num] so only valid contacts are considered.
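A hedged sketch of the suggested fix. The field layout and `_curr_contact_num` follow the snippets above; the free-function signature is illustrative, not the sensor's real API:

```python
import torch

def filter_valid_contacts(
    user_ids: torch.Tensor,      # [max_contact_num, 2] preallocated buffer field
    curr_contact_num: int,       # number of valid rows this step
    item_user_ids: torch.Tensor,
    need_both: bool,
) -> torch.Tensor:
    """Return a mask over the valid prefix only, so stale rows never match."""
    valid = user_ids[:curr_contact_num]
    mask0 = torch.isin(valid[:, 0], item_user_ids)
    mask1 = torch.isin(valid[:, 1], item_user_ids)
    return mask0 & mask1 if need_both else mask0 | mask1
```

The mask's length equals `curr_contact_num`, so indexing the same valid slice with it can never surface uninitialized rows from the tail of the buffer.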
```python
# Sample row indices and chunk start offsets.
row_sample_idx = torch.randint(0, len(available), (batch_size,))
row_indices = available[row_sample_idx]

max_start = max_steps - chunk_size
start_indices = torch.randint(0, max_start + 1, (batch_size,))

time_offsets = torch.arange(chunk_size)
time_indices = start_indices[:, None] + time_offsets[None, :]

result = self.shared_buffer[row_indices[:, None], time_indices]
```
Chunks are sampled uniformly from [0, max_episode_steps), but many create_demo_action_list() implementations generate much shorter rollouts (e.g., 100 steps vs max_episode_steps=600), leaving the remainder of each buffer row as zeros/stale. Without tracking per-episode valid lengths or a mask/done flag in the shared buffer, sample_batch() will frequently return invalid timesteps. Consider storing an episode_len / valid_mask per buffer row and constrain sampling to valid ranges (or pad + return a mask).
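One way to sketch the suggested bookkeeping, assuming a per-row `episode_len` tensor is stored alongside the shared buffer (the names and signature are illustrative):

```python
import torch

def sample_valid_starts(
    episode_len: torch.Tensor,  # [buffer_size] valid steps recorded per row
    row_indices: torch.Tensor,  # rows already chosen for this batch
    chunk_size: int,
) -> torch.Tensor:
    """Sample chunk start offsets constrained to each row's valid range."""
    max_start = (episode_len[row_indices] - chunk_size).clamp(min=0)
    # torch.randint has no per-element upper bound, so scale uniform floats
    # instead: floor(rand * (max_start + 1)) lands in [0, max_start].
    return (torch.rand(row_indices.shape[0]) * (max_start + 1).float()).long()
```

With this, a 100-step rollout stored in a 600-step buffer row can only yield chunks from its first 100 steps; the zero-padded tail is never sampled.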
```diff
 # check for user supplied max_episode_steps during gym.make calls
 if frame.f_code.co_name == "make" and "max_episode_steps" in prev_frame_locals:
     if prev_frame_locals["max_episode_steps"] is not None:
         max_episode_steps = prev_frame_locals["max_episode_steps"]
         # do some wrapper surgery to remove the previous timelimit wrapper
         # with gymnasium 0.29.1, this will remove the timelimit wrapper and nothing else.
         curr_env = env
         while curr_env is not None:
             if isinstance(curr_env, gym.wrappers.TimeLimit):
                 self.env = curr_env.env
                 break
-self._max_episode_steps = max_episode_steps
+self._max_episode_steps = self.base_env.max_episode_steps
```
TimeLimitWrapper.__init__ computes/overrides the max_episode_steps argument (including honoring a user override in gym.make), but then ignores it by unconditionally setting _max_episode_steps = self.base_env.max_episode_steps. If this wrapper is meant to enforce a caller-specified time limit, _max_episode_steps should be set from the resolved max_episode_steps value (or the wrapper should be removed if BaseEnv now owns the time-limit logic).
```python
from typing import Dict, Tuple, List, Sequence
from tensordict import TensorDict
```
TensorDict is imported but never used in this module. Remove the import to avoid linting failures and keep dependencies minimal.
```python
def get_proprioception(self) -> TensorDict[str, torch.Tensor]:
    """Gets robot proprioception information, primarily for agent state representation in robot learning scenarios.
```
This file does not use from __future__ import annotations, so the return annotation TensorDict[str, torch.Tensor] is evaluated at function definition time. If TensorDict is not subscriptable in the installed tensordict version, this will raise at import time. Either add postponed annotations (from __future__ import annotations) or avoid subscripting TensorDict in runtime annotations.
Pull request overview
Copilot reviewed 59 out of 59 changed files in this pull request and generated 8 comments.
Pull request overview
Copilot reviewed 62 out of 62 changed files in this pull request and generated 10 comments.
```diff
 logger.log_debug(f"Initializing episode for env_ids: {env_ids}", color="blue")
 save_data = kwargs.get("save_data", True)

 # Determine which environments to process
-if env_ids is None:
-    env_ids_to_process = list(range(self.num_envs))
-elif isinstance(env_ids, torch.Tensor):
-    env_ids_to_process = env_ids.cpu().tolist()
-else:
-    env_ids_to_process = list(env_ids)
+env_ids_to_process = list(range(self.num_envs)) if env_ids is None else env_ids

 # Save dataset before clearing buffers for environments that are being reset
 if save_data and self.dataset_manager:
     if "save" in self.dataset_manager.available_modes:

         # Filter to only save successful episodes
-        successful_env_ids = [
-            env_id
-            for env_id in env_ids_to_process
-            if (
-                self.episode_success_status.get(env_id, False)
-                or self._task_success[env_id].item()
-            )
-        ]
+        successful_env_ids = self.episode_success_status | self._task_success

-        if successful_env_ids:
+        if successful_env_ids.any():

-            # Convert back to tensor if needed
-            successful_env_ids_tensor = torch.tensor(
-                successful_env_ids, device=self.device
-            )
             self.dataset_manager.apply(
                 mode="save",
-                env_ids=successful_env_ids_tensor,
+                env_ids=successful_env_ids.nonzero(as_tuple=True)[0],
             )

 # Clear episode buffers and reset success status for environments being reset
-for env_id in env_ids_to_process:
-    self.episode_obs_buffer[env_id].clear()
-    self.episode_action_buffer[env_id].clear()
-    self.episode_success_status[env_id] = False
+if self.rollout_buffer is not None:
+    self.current_rollout_step = 0
+
+self.episode_success_status[env_ids_to_process] = False
```
_initialize_episode() now computes successful_env_ids = self.episode_success_status | self._task_success across all envs and resets current_rollout_step = 0 globally. When env_ids is a subset (partial reset), this can (1) save episodes for envs that are not being reset and (2) wipe/overwrite rollout data for envs that are still running. Consider tracking per-env rollout step counters and only saving/clearing data for env_ids_to_process.
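A minimal sketch of per-env step counters; `num_envs` and the reset hook are illustrative stand-ins for the env's actual state, not the repository's API:

```python
import torch

num_envs = 4
# One rollout-step counter per environment instead of a single global int.
current_rollout_step = torch.zeros(num_envs, dtype=torch.long)

def partial_reset(env_ids: torch.Tensor) -> None:
    """Reset counters only for the environments being re-initialized."""
    current_rollout_step[env_ids] = 0

current_rollout_step += 3            # all envs advance three steps
partial_reset(torch.tensor([1, 3]))  # envs 1 and 3 reset; 0 and 2 keep progress
```

The same per-env indexing pattern applies to saving: only `env_ids_to_process` rows would be checked for success and flushed, leaving still-running envs untouched.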
```python
class TestOnlineDatasetDynamicChunk(unittest.TestCase):
    """Tests for OnlineDataset with ChunkSizeSampler chunk_size."""

    def setUp(self) -> None:
        self.engine = _make_fake_engine()

    def test_uniform_sampler_item_mode_shape(self) -> None:
        """Item mode with UniformChunkSampler: batch_size dim is absent, time dim varies."""
        LOW, HIGH = 5, 15
        sampler = UniformChunkSampler(low=LOW, high=HIGH)
        dataset = OnlineDataset(self.engine, chunk_size=sampler)
        it = iter(dataset)
        for _ in range(10):
            sample = next(it)
            # batch_size has one element — the chunk dimension.
            self.assertEqual(len(sample.batch_size), 1)
            chunk_dim = sample.batch_size[0]
            self.assertGreaterEqual(chunk_dim, LOW)
            self.assertLessEqual(chunk_dim, HIGH)

    def test_gmm_sampler_item_mode_shape(self) -> None:
        """Item mode with GMMChunkSampler: chunk dim is clamped within [low, high]."""
        LOW, HIGH = 4, 20
        sampler = GMMChunkSampler(
            means=[8.0, 16.0], stds=[2.0, 2.0], low=LOW, high=HIGH
        )
        dataset = OnlineDataset(self.engine, chunk_size=sampler)
        it = iter(dataset)
        for _ in range(10):
            sample = next(it)
            chunk_dim = sample.batch_size[0]
            self.assertGreaterEqual(chunk_dim, LOW)
            self.assertLessEqual(chunk_dim, HIGH)

    def test_uniform_sampler_batch_mode_shape(self) -> None:
        """Batch mode: per-batch chunk size is consistent across all trajectories."""
        BATCH = 3
        LOW, HIGH = 5, 15
        sampler = UniformChunkSampler(low=LOW, high=HIGH)
        dataset = OnlineDataset(self.engine, chunk_size=sampler, batch_size=BATCH)
        it = iter(dataset)
        for _ in range(10):
            batch = next(it)
            self.assertEqual(len(batch.batch_size), 2)
            self.assertEqual(batch.batch_size[0], BATCH)
            chunk_dim = batch.batch_size[1]
            self.assertGreaterEqual(chunk_dim, LOW)
            self.assertLessEqual(chunk_dim, HIGH)

    def test_dynamic_chunk_sizes_vary(self) -> None:
        """Consecutive samples from a uniform sampler produce different chunk sizes."""
        LOW, HIGH = 5, 30
        sampler = UniformChunkSampler(low=LOW, high=HIGH)
        dataset = OnlineDataset(self.engine, chunk_size=sampler)
        it = iter(dataset)
        sizes = {next(it).batch_size[0] for _ in range(50)}
        # With a range of 26 values, drawing 50 times should yield > 1 unique size.
        assert (
            len(sizes) >= 1
        ), "Expected multiple unique chunk sizes from uniform sampler"

    def test_invalid_chunk_size_type_raises(self) -> None:
        """TypeError when chunk_size is not an int or ChunkSizeSampler."""
        with self.assertRaises(TypeError):
            OnlineDataset(self.engine, chunk_size="large")  # type: ignore[arg-type]

    def test_invalid_chunk_size_int_raises(self) -> None:
        """ValueError when chunk_size is an int < 1."""
        with self.assertRaises(ValueError):
            OnlineDataset(self.engine, chunk_size=0)

    def test_custom_sampler_subclass(self) -> None:
        """A user-defined ChunkSizeSampler subclass is accepted and called."""

        class FixedSampler(ChunkSizeSampler):
            def __call__(self) -> int:
                return 7

        dataset = OnlineDataset(self.engine, chunk_size=FixedSampler())
        sample = next(iter(dataset))
        self.assertEqual(sample.batch_size[0], 7)


# ---------------------------------------------------------------------------

if __name__ == "__main__":
    unittest.main()
```
This new test file uses unittest.TestCase classes and unittest.main(), but the repo’s testing conventions document pytest function style or plain class TestX: with setup_method/teardown_method instead (see CLAUDE.md). Consider rewriting these to the documented style for consistency.
```python
obs = TensorDict(
    dict(robot=self.robot.get_proprioception()[:, self.active_joint_ids]),
    batch_size=[self.num_envs],
    device=self.device,
)
```
get_obs() indexes the TensorDict returned by robot.get_proprioception() as td[:, self.active_joint_ids], but that TensorDict has batch_size [num_envs] (1D). This tuple-indexing is likely invalid and will raise at runtime; instead, slice each proprioception tensor by joint ids (e.g. qpos/qvel/qf) and keep the TensorDict batch_size as [num_envs].
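The suggested per-field slicing can be sketched as follows, using a plain dict of tensors to keep the example dependency-light; the real code would slice the TensorDict's leaf tensors the same way:

```python
import torch

# Stand-in proprioception with batch dimension num_envs; keys are illustrative.
num_envs, dof = 2, 7
proprio = {"qpos": torch.zeros(num_envs, dof), "qvel": torch.zeros(num_envs, dof)}
active_joint_ids = [0, 1, 2, 5]

# Slice each leaf tensor along the joint dimension instead of tuple-indexing
# the container itself; the batch dimension stays intact.
robot_obs = {key: value[:, active_joint_ids] for key, value in proprio.items()}
```

Each field keeps shape `[num_envs, len(active_joint_ids)]`, which is what the observation space expects.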
```python
# The robot agent instance.
robot: Robot = None

active_joint_ids: List[int] = []

# The sensors used in the environment.
sensors: Dict[str, BaseSensor] = {}
```
active_joint_ids is declared as a mutable class attribute ([]) on BaseEnv, which will be shared across instances. This can leak joint ids between environments/tests; initialize it per-instance (e.g. self.active_joint_ids = [] in __init__) instead.
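The pitfall is easy to reproduce with a toy class (both classes below are illustrative, not the repository's):

```python
class SharedDefault:
    active_joint_ids = []  # class attribute: one list shared by every instance

class PerInstanceDefault:
    def __init__(self) -> None:
        self.active_joint_ids = []  # fresh list per instance

a, b = SharedDefault(), SharedDefault()
a.active_joint_ids.append(1)
leaked = b.active_joint_ids        # [1]: the mutation leaked across instances

c, d = PerInstanceDefault(), PerInstanceDefault()
c.active_joint_ids.append(1)
isolated = d.active_joint_ids      # []: unaffected
```

Appending via one instance mutates the single shared class-level list, which is exactly how joint ids could leak between environments or tests.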
embodichain/agents/engine/data.py (outdated)
```python
# Edge case: the entire valid region is locked. Fall back to
# sampling from all valid rows to avoid a hard failure.
log_error(
    "[OnlineDataEngine] All valid buffer rows are currently locked. "
    "Cannot sample a batch at this time.",
```
The comment says it will "fall back to sampling from all valid rows", but log_error(..., RuntimeError) will always raise here (no fallback). Either implement an actual fallback strategy or update the comment/message to reflect that sampling fails when all rows are locked.
Suggested change:

```diff
-# Edge case: the entire valid region is locked. Fall back to
-# sampling from all valid rows to avoid a hard failure.
-log_error(
-    "[OnlineDataEngine] All valid buffer rows are currently locked. "
-    "Cannot sample a batch at this time.",
+# Edge case: the entire valid region is locked. Sampling a batch
+# is not possible in this state and will result in a hard failure.
+log_error(
+    "[OnlineDataEngine] All valid buffer rows are currently locked. "
+    "Cannot sample a batch at this time; sampling fails because no "
+    "unlocked rows are available.",
```
docs/source/tutorial/modular_env.rst (outdated)
```diff
 .. literalinclude:: ../../../scripts/tutorials/gym/modular_env.py
    :language: python
-   :start-at: @register_env("ModularEnv-v1", max_episode_steps=100, override=True)
+   :start-at: @register_env("ModularEnv-v1" override=True)
```
The :start-at: string is missing a comma: it currently won’t match the actual decorator line in modular_env.py, so the literalinclude will fail. Update it to @register_env("ModularEnv-v1", override=True).
Suggested change:

```diff
-   :start-at: @register_env("ModularEnv-v1" override=True)
+   :start-at: @register_env("ModularEnv-v1", override=True)
```
```python
)
# TODO: Use a action manager to handle the action space consistency with RL.
if isinstance(action, TensorDict):
    action_to_store = action["qpos"]
elif isinstance(action, torch.Tensor):
    action_to_store = action
self.rollout_buffer["actions"][:, self.current_rollout_step, ...].copy_(
    action_to_store.to(buffer_device), non_blocking=True
)
```
In _hook_after_sim_step, action_to_store is only set for TensorDict actions when the key 'qpos' exists. If the action TensorDict contains only 'qvel'/'qf' (supported by _step_action), this will raise KeyError/UnboundLocalError during rollout recording. Handle the other supported keys or choose a consistent action representation to store.
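A sketch of one way to handle the other keys. A plain dict stands in for the TensorDict here, and the key tuple is an assumption based on the `_step_action` keys named in the comment, not a confirmed contract:

```python
import torch

SUPPORTED_ACTION_KEYS = ("qpos", "qvel", "qf")  # assumed order, illustrative only

def resolve_action_to_store(action):
    """Pick a tensor to record from a raw tensor or a keyed action container."""
    if isinstance(action, torch.Tensor):
        return action
    for key in SUPPORTED_ACTION_KEYS:
        if key in action.keys():
            return action[key]
    raise ValueError(f"No supported action key in {list(action.keys())}")
```

An explicit `ValueError` replaces the silent `KeyError`/`UnboundLocalError` the comment warns about when none of the expected keys is present.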
```python
# Draw many small batches and collect all sampled row indices.
# We cannot directly observe row indices from outside, but we can
# verify that each result slice is *not* identical to a locked row's
# data (which has a unique random fingerprint).
locked_obs = engine.shared_buffer["obs"][LOCK_START:LOCK_END]  # [3, 50, 10]

for _ in range(20):
    result = engine.sample_batch(batch_size=1, chunk_size=5)
    sampled_obs_start = result["obs"][0, 0]  # first timestep of first chunk
    # Check that this does not exactly match any locked row's first timestep.
    for r in range(LOCK_END - LOCK_START):
        matched = torch.allclose(
            sampled_obs_start, locked_obs[r, :5].mean(dim=-1, keepdim=True)
        )
# The comparison above is a heuristic; the real guarantee is that
# available rows exclude locked ones. We use a direct index check:
# reconstruct which row could produce this exact obs by brute-force.
# Reconstructed check: verify available indices exclude locked rows.
all_rows = torch.arange(BUFFER_SIZE)
is_locked = (all_rows >= LOCK_START) & (all_rows < LOCK_END)
available = all_rows[~is_locked]
assert len(available) != 0, "available must be non-empty"
for row in locked_rows:
    assert row not in available.tolist()
```
test_sample_batch_locks_respected does not actually validate that sample_batch() avoids locked rows. The assertions only prove that available = all_rows[~is_locked] excludes the locked range (a tautology) and the matched result is never checked. Consider instrumenting sample_batch to expose chosen row indices in tests or compare sampled chunks directly against locked rows’ data to ensure none match.
Description
TODO:
Type of change
Checklist
Run the `black .` command to format the code base.