Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions embodichain/lab/gym/envs/embodied_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,16 +644,20 @@ def _step_action(self, action: EnvAction) -> EnvAction:
# Support multiple control modes simultaneously
if "qpos" in action:
self.robot.set_qpos(
qpos=action["qpos"], joint_ids=self.active_joint_ids
qpos=action["qpos"].to(self.device), joint_ids=self.active_joint_ids
)
if "qvel" in action:
self.robot.set_qvel(
qvel=action["qvel"], joint_ids=self.active_joint_ids
qvel=action["qvel"].to(self.device), joint_ids=self.active_joint_ids
)
if "qf" in action:
self.robot.set_qf(qf=action["qf"], joint_ids=self.active_joint_ids)
self.robot.set_qf(
qf=action["qf"].to(self.device), joint_ids=self.active_joint_ids
)
elif isinstance(action, torch.Tensor):
self.robot.set_qpos(qpos=action, joint_ids=self.active_joint_ids)
self.robot.set_qpos(
qpos=action.to(self.device), joint_ids=self.active_joint_ids
)
else:
logger.log_error(f"Unsupported action type: {type(action)}")

Expand Down
8 changes: 5 additions & 3 deletions embodichain/lab/sim/objects/robot.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,8 @@ def compute_fk(
f"Joint positions shape mismatch. Expected {solver.dof} joints, got {qpos.shape[1]}."
)

result_matrix = solver.get_fk(qpos=qpos)
qpos_ = qpos.to(self.device)
result_matrix = solver.get_fk(qpos=qpos_)

base_pose = self.get_link_pose(
link_name=solver.root_link_name, env_ids=local_env_ids, to_matrix=True
Expand Down Expand Up @@ -633,8 +634,9 @@ def compute_batch_fk(
f"Joint positions shape mismatch. Expected {solver.dof} joints, got {qpos.shape[1]}."
Copy link

Copilot AI Apr 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In compute_batch_fk, the shape check uses qpos.shape[2] != solver.dof but the error message reports qpos.shape[1]. This will misreport the number of joints; it should report the last dimension (shape[2]) to match the validation.

Suggested change
f"Joint positions shape mismatch. Expected {solver.dof} joints, got {qpos.shape[1]}."
f"Joint positions shape mismatch. Expected {solver.dof} joints, got {qpos.shape[2]}."

Copilot uses AI. Check for mistakes.
Copy link

Copilot AI Apr 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In compute_batch_fk, the dof check compares qpos.shape[2] to solver.dof, but the error message reports got {qpos.shape[1]} (batch dimension) instead of the joint dimension. This can be misleading when debugging shape issues; the message should report qpos.shape[2] (or the actual joint-axis size being validated).

Suggested change
f"Joint positions shape mismatch. Expected {solver.dof} joints, got {qpos.shape[1]}."
f"Joint positions shape mismatch. Expected {solver.dof} joints, got {qpos.shape[2]}."

Copilot uses AI. Check for mistakes.
)

n_batch = qpos.shape[1]
qpos_batch = qpos.reshape(-1, solver.dof)
qpos_ = qpos.to(self.device)
n_batch = qpos_.shape[1]
qpos_batch = qpos_.reshape(-1, solver.dof)
xpos_batch = solver.get_fk(qpos=qpos_batch)

# get xpos from link root
Expand Down
Loading