diff --git a/embodichain/lab/gym/envs/managers/observations.py b/embodichain/lab/gym/envs/managers/observations.py
index a71e2827..50724cea 100644
--- a/embodichain/lab/gym/envs/managers/observations.py
+++ b/embodichain/lab/gym/envs/managers/observations.py
@@ -358,22 +358,19 @@ def compute_semantic_mask(
     else:
         mask = obs["sensor"][entity_cfg.uid]["mask"]
 
-    left_robot_uids = torch.cat(
-        [
-            env.robot.get_user_ids(link_name)
-            for link_name in env.robot.link_names
-            if link_name.startswith("left_")
-        ],
-        -1,
-    )
-    right_robot_uids = torch.cat(
-        [
-            env.robot.get_user_ids(link_name)
-            for link_name in env.robot.link_names
-            if link_name.startswith("right_")
-        ],
-        -1,
-    )
+    left_link_indices = [
+        i
+        for i, link_name in enumerate(env.robot.link_names)
+        if link_name.startswith("left_")
+    ]
+    left_robot_uids = env.robot.user_ids[:, left_link_indices]
+
+    right_link_indices = [
+        i
+        for i, link_name in enumerate(env.robot.link_names)
+        if link_name.startswith("right_")
+    ]
+    right_robot_uids = env.robot.user_ids[:, right_link_indices]
 
     mask_exp = mask.unsqueeze(-1)
 
diff --git a/tests/gym/envs/managers/test_observation_functors.py b/tests/gym/envs/managers/test_observation_functors.py
index 3dc08a28..a9238d90 100644
--- a/tests/gym/envs/managers/test_observation_functors.py
+++ b/tests/gym/envs/managers/test_observation_functors.py
@@ -31,8 +31,16 @@ def __init__(self, num_envs: int = 4, num_joints: int = 6):
         self.num_joints = num_joints
         self.device = torch.device("cpu")
         self.joint_names = [f"joint_{i}" for i in range(num_joints)]
+        self.link_names = [
+            "base",
+            "left_shoulder",
+            "left_elbow",
+            "right_shoulder",
+            "right_elbow",
+        ]
         self._qpos = torch.zeros(num_envs, num_joints)
         self._qvel = torch.zeros(num_envs, num_joints)
+        self.user_ids = torch.tensor([[0, 1, 2, 3, 4]] * num_envs, dtype=torch.int32)
 
         # Mock body_data
         self.body_data = Mock()
@@ -757,3 +765,141 @@ def test_reset_clears_cache(self):
         # Should fetch again
         functor(env, obs, entity_cfg=MagicMock(uid="robot"))
         assert env.sim._robots["robot"].get_joint_drive.call_count == 2
+
+
+class TestComputeSemanticMask:
+    """Tests for compute_semantic_mask functor."""
+
+    # Layout of the synthetic 4x4 mask used in tests:
+    #   [[ 1,  2,  3,  4],
+    #    [ 0,  0, 10, 10],
+    #    [ 1,  0,  3, 10],
+    #    [ 0,  0,  0,  0]]
+    #
+    # Robot user_ids = [[0, 1, 2, 3, 4]] per env
+    #   left_link_indices = [1, 2] → uids 1, 2
+    #   right_link_indices = [3, 4] → uids 3, 4
+    # Foreground object uid = 10
+
+    def _make_env_and_obs(self, num_envs: int = 2):
+        """Create a mock env with a synthetic 4x4 semantic mask."""
+        env = MockEnv(num_envs=num_envs)
+
+        # Create a foreground object with user_id 10
+        fg_obj = MockRigidObject("fg_object", num_envs)
+        fg_obj.get_user_ids = lambda: torch.full((num_envs,), 10, dtype=torch.int32)
+        env.sim.add_rigid_object(fg_obj)
+
+        single = torch.tensor(
+            [[1, 2, 3, 4], [0, 0, 10, 10], [1, 0, 3, 10], [0, 0, 0, 0]],
+            dtype=torch.int32,
+        )
+        mask = single.unsqueeze(0).repeat(num_envs, 1, 1)
+
+        obs = {"sensor": {"camera": {"mask": mask}}}
+        return env, obs
+
+    def test_returns_correct_shape(self):
+        """Test that compute_semantic_mask returns shape (B, H, W, 4)."""
+        env, obs = self._make_env_and_obs(num_envs=2)
+
+        result = compute_semantic_mask(
+            env,
+            obs,
+            entity_cfg=MagicMock(uid="camera"),
+            foreground_uids=["fg_object"],
+        )
+
+        assert result.shape == (2, 4, 4, 4)
+
+    def test_correct_channel_assignment(self):
+        """Test that each semantic channel has 1s only in expected pixels."""
+        env, obs = self._make_env_and_obs(num_envs=2)
+
+        result = compute_semantic_mask(
+            env,
+            obs,
+            entity_cfg=MagicMock(uid="camera"),
+            foreground_uids=["fg_object"],
+        )
+
+        # SemanticMask channels: BACKGROUND=0, FOREGROUND=1, ROBOT_LEFT=2, ROBOT_RIGHT=3
+        bg = result[0, :, :, 0]
+        fg = result[0, :, :, 1]
+        left = result[0, :, :, 2]
+        right = result[0, :, :, 3]
+
+        # Left robot (uids 1, 2): pixels (0,0), (0,1), (2,0)
+        assert left[0, 0] == 1
+        assert left[0, 1] == 1
+        assert left[2, 0] == 1
+        assert left.sum().item() == 3
+
+        # Right robot (uids 3, 4): pixels (0,2), (0,3), (2,2)
+        assert right[0, 2] == 1
+        assert right[0, 3] == 1
+        assert right[2, 2] == 1
+        assert right.sum().item() == 3
+
+        # Foreground (uid 10): pixels (1,2), (1,3), (2,3)
+        assert fg[1, 2] == 1
+        assert fg[1, 3] == 1
+        assert fg[2, 3] == 1
+        assert fg.sum().item() == 3
+
+        # Background: 7 pixels with uid 0
+        assert bg.sum().item() == 7
+
+    def test_background_is_negation_of_foreground_and_robot(self):
+        """Test that background == ~(left | right | foreground)."""
+        env, obs = self._make_env_and_obs(num_envs=2)
+
+        result = compute_semantic_mask(
+            env,
+            obs,
+            entity_cfg=MagicMock(uid="camera"),
+            foreground_uids=["fg_object"],
+        )
+
+        bg = result[0, :, :, 0]
+        fg = result[0, :, :, 1]
+        left = result[0, :, :, 2]
+        right = result[0, :, :, 3]
+
+        expected_bg = ~(left.bool() | right.bool() | fg.bool())
+        torch.testing.assert_close(bg.bool(), expected_bg)
+
+    def test_with_foreground_uids_not_in_assets(self):
+        """Test that foreground UIDs not in asset_uids are silently ignored."""
+        env, obs = self._make_env_and_obs(num_envs=2)
+
+        result = compute_semantic_mask(
+            env,
+            obs,
+            entity_cfg=MagicMock(uid="camera"),
+            foreground_uids=["fg_object", "nonexistent_object"],
+        )
+
+        # Foreground should still only match uid 10 (from fg_object)
+        fg = result[0, :, :, 1]
+        assert fg[1, 2] == 1
+        assert fg[1, 3] == 1
+        assert fg[2, 3] == 1
+        assert fg.sum().item() == 3
+
+    def test_different_num_envs(self):
+        """Test that compute_semantic_mask works with different batch sizes."""
+        num_envs = 5
+        env, obs = self._make_env_and_obs(num_envs=num_envs)
+
+        result = compute_semantic_mask(
+            env,
+            obs,
+            entity_cfg=MagicMock(uid="camera"),
+            foreground_uids=["fg_object"],
+        )
+
+        assert result.shape == (num_envs, 4, 4, 4)
+        # All envs have identical mask data, so results should match
+        for i in range(1, num_envs):
+            assert result[i].equal(result[0])