From fa1e2b14f42ce992e0ab60c658335da81ce6be84 Mon Sep 17 00:00:00 2001 From: chenjian Date: Wed, 8 Apr 2026 12:16:48 +0800 Subject: [PATCH 1/3] update --- embodichain/lab/sim/objects/robot.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/embodichain/lab/sim/objects/robot.py b/embodichain/lab/sim/objects/robot.py index fc38dae6..e959f9aa 100644 --- a/embodichain/lab/sim/objects/robot.py +++ b/embodichain/lab/sim/objects/robot.py @@ -498,7 +498,12 @@ def compute_fk( f"Joint positions shape mismatch. Expected {solver.dof} joints, got {qpos.shape[1]}." ) - result_matrix = solver.get_fk(qpos=qpos) + if qpos.device != self.device: + qpos_ = qpos.to(self.device) + else: + qpos_ = qpos + + result_matrix = solver.get_fk(qpos=qpos_) base_pose = self.get_link_pose( link_name=solver.root_link_name, env_ids=local_env_ids, to_matrix=True @@ -633,8 +638,13 @@ def compute_batch_fk( f"Joint positions shape mismatch. Expected {solver.dof} joints, got {qpos.shape[1]}." ) - n_batch = qpos.shape[1] - qpos_batch = qpos.reshape(-1, solver.dof) + if qpos.device != self.device: + qpos_ = qpos.to(self.device) + else: + qpos_ = qpos + + n_batch = qpos_.shape[1] + qpos_batch = qpos_.reshape(-1, solver.dof) xpos_batch = solver.get_fk(qpos=qpos_batch) # get xpos from link root From 8fcf4e409b5f4b07340d27cb4a39acff7723d472 Mon Sep 17 00:00:00 2001 From: chenjian Date: Wed, 8 Apr 2026 14:39:42 +0800 Subject: [PATCH 2/3] update --- embodichain/lab/gym/envs/embodied_env.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/embodichain/lab/gym/envs/embodied_env.py b/embodichain/lab/gym/envs/embodied_env.py index da3ae9b7..3e699620 100644 --- a/embodichain/lab/gym/envs/embodied_env.py +++ b/embodichain/lab/gym/envs/embodied_env.py @@ -644,16 +644,20 @@ def _step_action(self, action: EnvAction) -> EnvAction: # Support multiple control modes simultaneously if "qpos" in action: self.robot.set_qpos( - qpos=action["qpos"], joint_ids=self.active_joint_ids + qpos=action["qpos"].to(self.device), joint_ids=self.active_joint_ids ) if "qvel" in action: self.robot.set_qvel( - qvel=action["qvel"], joint_ids=self.active_joint_ids + qvel=action["qvel"].to(self.device), joint_ids=self.active_joint_ids ) if "qf" in action: - self.robot.set_qf(qf=action["qf"], joint_ids=self.active_joint_ids) + self.robot.set_qf( + qf=action["qf"].to(self.device), joint_ids=self.active_joint_ids + ) elif isinstance(action, torch.Tensor): - self.robot.set_qpos(qpos=action, joint_ids=self.active_joint_ids) + self.robot.set_qpos( + qpos=action.to(self.device), joint_ids=self.active_joint_ids + ) else: logger.log_error(f"Unsupported action type: {type(action)}") From eec6b41c565f04e1a09ad0b0ed7ea9c5dbd70d78 Mon Sep 17 00:00:00 2001 From: chenjian Date: Wed, 8 Apr 2026 14:41:38 +0800 Subject: [PATCH 3/3] update --- embodichain/lab/sim/objects/robot.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/embodichain/lab/sim/objects/robot.py b/embodichain/lab/sim/objects/robot.py index e959f9aa..e6dac158 100644 --- a/embodichain/lab/sim/objects/robot.py +++ b/embodichain/lab/sim/objects/robot.py @@ -498,11 +498,7 @@ def compute_fk( f"Joint positions shape mismatch. Expected {solver.dof} joints, got {qpos.shape[1]}." ) - if qpos.device != self.device: - qpos_ = qpos.to(self.device) - else: - qpos_ = qpos - + qpos_ = qpos.to(self.device) result_matrix = solver.get_fk(qpos=qpos_) base_pose = self.get_link_pose( @@ -638,11 +634,7 @@ def compute_batch_fk( f"Joint positions shape mismatch. Expected {solver.dof} joints, got {qpos.shape[1]}." ) - if qpos.device != self.device: - qpos_ = qpos.to(self.device) - else: - qpos_ = qpos - + qpos_ = qpos.to(self.device) n_batch = qpos_.shape[1] qpos_batch = qpos_.reshape(-1, solver.dof) xpos_batch = solver.get_fk(qpos=qpos_batch)