Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 13 additions & 16 deletions embodichain/lab/gym/envs/managers/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,22 +358,19 @@ def compute_semantic_mask(
else:
mask = obs["sensor"][entity_cfg.uid]["mask"]

left_robot_uids = torch.cat(
[
env.robot.get_user_ids(link_name)
for link_name in env.robot.link_names
if link_name.startswith("left_")
],
-1,
)
right_robot_uids = torch.cat(
[
env.robot.get_user_ids(link_name)
for link_name in env.robot.link_names
if link_name.startswith("right_")
],
-1,
)
left_link_indices = [
i
for i, link_name in enumerate(env.robot.link_names)
if link_name.startswith("left_")
]
left_robot_uids = env.robot.user_ids[:, left_link_indices]
Comment on lines +361 to +366
Copy link

Copilot AI Apr 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The updated left/right user ID indexing fixes a per-environment shape issue, but there is no unit test covering compute_semantic_mask. Please add a test (in tests/gym/envs/managers/test_observation_functors.py, where other observation functors are tested) that uses >1 env with distinct robot.user_ids per env and verifies the resulting left/right robot mask channels are computed per-env (not flattened/concatenated across envs).

Copilot uses AI. Check for mistakes.

right_link_indices = [
i
for i, link_name in enumerate(env.robot.link_names)
if link_name.startswith("right_")
]
right_robot_uids = env.robot.user_ids[:, right_link_indices]

mask_exp = mask.unsqueeze(-1)

Expand Down
146 changes: 146 additions & 0 deletions tests/gym/envs/managers/test_observation_functors.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,16 @@ def __init__(self, num_envs: int = 4, num_joints: int = 6):
self.num_joints = num_joints
self.device = torch.device("cpu")
self.joint_names = [f"joint_{i}" for i in range(num_joints)]
self.link_names = [
"base",
"left_shoulder",
"left_elbow",
"right_shoulder",
"right_elbow",
]
self._qpos = torch.zeros(num_envs, num_joints)
self._qvel = torch.zeros(num_envs, num_joints)
self.user_ids = torch.tensor([[0, 1, 2, 3, 4]] * num_envs, dtype=torch.int32)

# Mock body_data
self.body_data = Mock()
Expand Down Expand Up @@ -757,3 +765,141 @@ def test_reset_clears_cache(self):
# Should fetch again
functor(env, obs, entity_cfg=MagicMock(uid="robot"))
assert env.sim._robots["robot"].get_joint_drive.call_count == 2


class TestComputeSemanticMask:
    """Tests for the ``compute_semantic_mask`` observation functor.

    Every test feeds the functor the same synthetic 4x4 user-id mask,
    repeated once per environment::

        [[ 1,  2,  3,  4],
         [ 0,  0, 10, 10],
         [ 1,  0,  3, 10],
         [ 0,  0,  0,  0]]

    With the default mock robot (``user_ids == [[0, 1, 2, 3, 4]]`` per env
    and link names ``base, left_*, left_*, right_*, right_*``):

    * left link indices  ``[1, 2]`` -> uids 1, 2
    * right link indices ``[3, 4]`` -> uids 3, 4
    * the foreground object registered by the fixture has uid 10

    Output channel order, as used by the assertions below:
    BACKGROUND=0, FOREGROUND=1, ROBOT_LEFT=2, ROBOT_RIGHT=3.
    """

    # Single-environment user-id image shared by all tests.
    _MASK_ROWS = [
        [1, 2, 3, 4],
        [0, 0, 10, 10],
        [1, 0, 3, 10],
        [0, 0, 0, 0],
    ]

    def _make_env_and_obs(self, num_envs: int = 2):
        """Create a mock env and an obs dict holding the synthetic 4x4 mask.

        Args:
            num_envs: Number of environments; the mask is repeated this many
                times along the leading (batch) dimension.

        Returns:
            Tuple ``(env, obs)`` where ``obs["sensor"]["camera"]["mask"]`` is
            an int32 tensor of shape ``(num_envs, 4, 4)``.
        """
        env = MockEnv(num_envs=num_envs)

        # Create a foreground object with user_id 10
        fg_obj = MockRigidObject("fg_object", num_envs)
        fg_obj.get_user_ids = lambda: torch.full((num_envs,), 10, dtype=torch.int32)
        env.sim.add_rigid_object(fg_obj)

        single = torch.tensor(self._MASK_ROWS, dtype=torch.int32)
        mask = single.unsqueeze(0).repeat(num_envs, 1, 1)

        obs = {"sensor": {"camera": {"mask": mask}}}
        return env, obs

    @staticmethod
    def _run(env, obs, foreground_uids=None):
        """Call ``compute_semantic_mask`` on the ``camera`` sensor.

        Args:
            env: Mock environment from ``_make_env_and_obs``.
            obs: Observation dict from ``_make_env_and_obs``.
            foreground_uids: Asset uids to treat as foreground; defaults to
                ``["fg_object"]``.

        Returns:
            Whatever the functor returns (shape ``(B, H, W, 4)`` per
            ``test_returns_correct_shape``).
        """
        if foreground_uids is None:
            foreground_uids = ["fg_object"]
        return compute_semantic_mask(
            env,
            obs,
            entity_cfg=MagicMock(uid="camera"),
            foreground_uids=foreground_uids,
        )

    def _pixels_with(self, *uids: int) -> torch.Tensor:
        """Return a 4x4 bool tensor marking pixels whose uid is in ``uids``."""
        single = torch.tensor(self._MASK_ROWS, dtype=torch.int32)
        hit = torch.zeros(single.shape, dtype=torch.bool)
        for uid in uids:
            hit |= single == uid
        return hit

    def test_returns_correct_shape(self):
        """Test that compute_semantic_mask returns shape (B, H, W, 4)."""
        env, obs = self._make_env_and_obs(num_envs=2)

        result = self._run(env, obs)

        assert result.shape == (2, 4, 4, 4)

    def test_correct_channel_assignment(self):
        """Test that each semantic channel has 1s only in expected pixels."""
        num_envs = 2
        env, obs = self._make_env_and_obs(num_envs=num_envs)

        result = self._run(env, obs)

        expected_left = self._pixels_with(1, 2)  # (0,0), (0,1), (2,0)
        expected_right = self._pixels_with(3, 4)  # (0,2), (0,3), (2,2)
        expected_fg = self._pixels_with(10)  # (1,2), (1,3), (2,3)
        expected_bg = self._pixels_with(0)  # the 7 remaining pixels

        # Check every environment, not just env 0, so that a result that is
        # only correct for the first env cannot slip through.
        for env_idx in range(num_envs):
            bg = result[env_idx, :, :, 0]
            fg = result[env_idx, :, :, 1]
            left = result[env_idx, :, :, 2]
            right = result[env_idx, :, :, 3]

            assert torch.equal(left.bool(), expected_left)
            assert left.sum().item() == 3

            assert torch.equal(right.bool(), expected_right)
            assert right.sum().item() == 3

            assert torch.equal(fg.bool(), expected_fg)
            assert fg.sum().item() == 3

            assert torch.equal(bg.bool(), expected_bg)
            assert bg.sum().item() == 7

    def test_robot_user_ids_are_resolved_per_env(self):
        """Test that left/right channels use each env's own robot user ids."""
        env, obs = self._make_env_and_obs(num_envs=2)

        # Env 1 swaps the left and right link user ids relative to env 0, so
        # a correct per-env lookup must swap the two robot channels as well.
        # NOTE(review): assumes ``env.robot`` returns the same mock object on
        # every access and that the functor takes robot ids only from
        # ``env.robot.user_ids`` -- confirm against the functor body.
        env.robot.user_ids = torch.tensor(
            [[0, 1, 2, 3, 4], [0, 3, 4, 1, 2]], dtype=torch.int32
        )

        result = self._run(env, obs)

        pixels_12 = self._pixels_with(1, 2)
        pixels_34 = self._pixels_with(3, 4)

        # Env 0: left links own uids 1, 2 and right links own uids 3, 4.
        assert torch.equal(result[0, :, :, 2].bool(), pixels_12)
        assert torch.equal(result[0, :, :, 3].bool(), pixels_34)

        # Env 1: the assignment is reversed.
        assert torch.equal(result[1, :, :, 2].bool(), pixels_34)
        assert torch.equal(result[1, :, :, 3].bool(), pixels_12)

    def test_background_is_negation_of_foreground_and_robot(self):
        """Test that background == ~(left | right | foreground)."""
        num_envs = 2
        env, obs = self._make_env_and_obs(num_envs=num_envs)

        result = self._run(env, obs)

        for env_idx in range(num_envs):
            bg = result[env_idx, :, :, 0]
            fg = result[env_idx, :, :, 1]
            left = result[env_idx, :, :, 2]
            right = result[env_idx, :, :, 3]

            expected_bg = ~(left.bool() | right.bool() | fg.bool())
            torch.testing.assert_close(bg.bool(), expected_bg)

    def test_with_foreground_uids_not_in_assets(self):
        """Test that foreground UIDs not in asset_uids are silently ignored."""
        num_envs = 2
        env, obs = self._make_env_and_obs(num_envs=num_envs)

        result = self._run(
            env, obs, foreground_uids=["fg_object", "nonexistent_object"]
        )

        # Foreground should still only match uid 10 (from fg_object)
        expected_fg = self._pixels_with(10)
        for env_idx in range(num_envs):
            fg = result[env_idx, :, :, 1]
            assert torch.equal(fg.bool(), expected_fg)
            assert fg.sum().item() == 3

    def test_different_num_envs(self):
        """Test that compute_semantic_mask works with different batch sizes."""
        num_envs = 5
        env, obs = self._make_env_and_obs(num_envs=num_envs)

        result = self._run(env, obs)

        assert result.shape == (num_envs, 4, 4, 4)
        # All envs have identical mask data, so results should match
        for i in range(1, num_envs):
            assert result[i].equal(result[0])
Loading