diff --git a/embodichain/toolkits/graspkit/pg_grasp/antipodal_generator.py b/embodichain/toolkits/graspkit/pg_grasp/antipodal_generator.py index f6389ff8..658f4f88 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/antipodal_generator.py +++ b/embodichain/toolkits/graspkit/pg_grasp/antipodal_generator.py @@ -583,7 +583,7 @@ def get_grasp_poses( approach_direction: torch.Tensor, visualize_collision: bool = False, visualize_pose: bool = False, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[bool, torch.Tensor, float]: """Get grasp pose given approach direction. Uses the antipodal point pairs stored in ``self._hit_point_pairs`` @@ -603,19 +603,20 @@ def get_grasp_poses( after computation. Returns: - A tuple ``(best_grasp_pose, best_open_length)`` where - ``best_grasp_pose`` is a ``(4, 4)`` homogeneous matrix and - ``best_open_length`` is a scalar. + is_success (bool): Whether a valid grasp pose is found. + best_grasp_pose (torch.Tensor): If a valid grasp pose is found, a tensor of shape (4, 4) representing the homogeneous transformation matrix of the best grasp pose in the world frame. Otherwise, an identity matrix. + best_open_length (float): If a valid grasp pose is found, a scalar representing the optimal gripper opening length. Otherwise, a zero tensor. Raises: RuntimeError: If :meth:`generate` or :meth:`annotate` has not been called yet. """ if self._hit_point_pairs is None: - raise RuntimeError( + logger.log_warning( "No antipodal point pairs available. " "Call generate() or annotate() first." ) + return False, torch.eye(4, device=self.device), 0.0 origin_points = self._hit_point_pairs[:, 0, :] hit_points = self._hit_point_pairs[:, 1, :] origin_points_ = self._apply_transform(origin_points, object_pose) @@ -632,6 +633,10 @@ def get_grasp_poses( valid_mask = ( positive_angle - torch.pi / 2 ).abs() <= self.cfg.max_deviation_angle + if valid_mask.sum() == 0: + logger.log_warning("No valid antipodal pairs after angle filtering.") + return False, torch.eye(4, device=self.device), 0.0 + valid_grasp_x = grasp_x[valid_mask] valid_centers = centers[valid_mask] @@ -650,6 +655,9 @@ def get_grasp_poses( is_visual=visualize_collision, collision_threshold=0.0, ) + if is_colliding.logical_not().sum() == 0: + logger.log_warning("No valid antipodal pairs after angle filtering.") + return False, torch.eye(4, device=self.device), 0.0 # get best grasp pose valid_grasp_poses = valid_grasp_poses[~is_colliding] valid_open_lengths = valid_open_lengths[~is_colliding] @@ -674,7 +682,7 @@ def get_grasp_poses( grasp_pose=best_grasp_pose, open_length=best_open_length.item(), ) - return best_grasp_pose, best_open_length + return True, best_grasp_pose, best_open_length @staticmethod def _grasp_pose_from_approach_direction( diff --git a/scripts/tutorials/grasp/grasp_generator.py b/scripts/tutorials/grasp/grasp_generator.py index bab09c03..16143215 100644 --- a/scripts/tutorials/grasp/grasp_generator.py +++ b/scripts/tutorials/grasp/grasp_generator.py @@ -271,11 +271,20 @@ def get_grasp_traj(sim: SimulationManager, robot: Robot, grasp_xpos: torch.Tenso ) obj_poses = mug.get_local_pose(to_matrix=True) grasp_xpos_list = [] - for obj_pose in obj_poses: - grasp_pose, _ = grasp_generator.get_grasp_poses( + + rest_xpos = robot.compute_fk( + qpos=robot.get_qpos("arm"), name="arm", to_matrix=True + )[0] + for i, obj_pose in enumerate(obj_poses): + is_success, grasp_pose, open_length = grasp_generator.get_grasp_poses( obj_pose, approach_direction, visualize_pose=False ) - grasp_xpos_list.append(grasp_pose.unsqueeze(0)) + if is_success: + grasp_xpos_list.append(grasp_pose.unsqueeze(0)) + else: + logger.log_warning(f"No valid grasp pose found for {i}-th object.") + grasp_xpos_list.append(rest_xpos.unsqueeze(0)) + grasp_xpos = torch.cat(grasp_xpos_list, dim=0) cost_time = time.time() - start_time logger.log_info(f"Get grasp pose cost time: {cost_time:.2f} seconds")