In [15]:
import argparse

import gymnasium as gym
import numpy as np

from mani_skill.envs.sapien_env import BaseEnv
from mani_skill.utils.wrappers import RecordEpisode
from mani_skill.utils.structs.pose import Pose
from mani_skill.utils.geometry import rotation_conversions as rot_utils

In [29]:
# Args
env_id = "TestBench-v1"

num_envs = 1
seed = 10
quiet = False
render_mode = "human"
parallel_in_single_scene = False
obs_mode = "none"
reward_mode = None
control_mode = "pd_ee_target_delta_pose"
shader_dir = "default"
sim_backend = "auto"

pause = False

In [30]:
np.set_printoptions(suppress=True, precision=3)
verbose = True

env: BaseEnv = gym.make(
        env_id,
        obs_mode=obs_mode,
        reward_mode=reward_mode,
        control_mode=control_mode,
        render_mode=render_mode,
        shader_dir=shader_dir,
        num_envs=num_envs,
        sim_backend=sim_backend,
        parallel_in_single_scene=parallel_in_single_scene,
    )

In [31]:
print("Observation space", env.observation_space)
print("Action space", env.action_space)
print("Control mode", env.unwrapped.control_mode)
print("Reward mode", env.unwrapped.reward_mode)

Observation space Dict()
Action space Box(-1.0, 1.0, (6,), float32)
Control mode pd_ee_target_delta_pose
Reward mode normalized_dense


In [32]:
obs, _ = env.reset(seed=seed)
env.action_space.seed(seed)

[10]

In [33]:
if render_mode is not None:
    viewer = env.render()
    viewer.paused = pause
    env.render()

In [34]:
print(env.agent.controller.controllers['arm'].ee_pose)
print(env.agent.controller.controllers['arm'].ee_pose_at_base)
print(env.agent.controller.controllers['arm']._target_pose)


start_pose = env.agent.controller.controllers['arm'].ee_pose_at_base
BASE_pose = env.agent.robot.get_links()[0].pose

print(BASE_pose)

Pose(raw_pose=tensor([[-0.3423,  0.0039,  0.4254,  0.4899,  0.5024,  0.5114,  0.4960]]))
Pose(raw_pose=tensor([[0.4577, 0.0039, 0.4254, 0.4899, 0.5024, 0.5114, 0.4960]]))
Pose(raw_pose=tensor([[0.4577, 0.0039, 0.4254, 0.4899, 0.5024, 0.5114, 0.4960]]))
Pose(raw_pose=tensor([[-8.0000e-01, -1.3097e-10,  7.4506e-09,  1.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00]]))


In [22]:
def interpolate_pose(pose1, pose2, alpha: float):
    """
    Interpolates between two Pose objects given an interpolation factor alpha (0 <= alpha <= 1).
    
    :param pose1: The starting Pose object.
    :param pose2: The ending Pose object.
    :param alpha: Interpolation factor (0 = pose1, 1 = pose2).
    :return: A new Pose object that is the interpolated result.
    """
    # Linearly interpolate positions
    interp_position = (1 - alpha) * pose1.p + alpha * pose2.p
    
    # Perform SLERP for quaternion interpolation
    dot_product = torch.sum(pose1.q * pose2.q, dim=-1, keepdim=True)
    if torch.all(dot_product < 0):
        pose2.q = -pose2.q
        dot_product = -dot_product

    dot_product = torch.clamp(dot_product, -1.0, 1.0)

    theta_0 = torch.acos(dot_product)
    sin_theta_0 = torch.sin(theta_0)

    if torch.all(sin_theta_0 > 1e-6):
        sin_theta = torch.sin(alpha * theta_0)
        sin_theta_1 = torch.sin((1 - alpha) * theta_0)
        interp_quat = (sin_theta_1 / sin_theta_0) * pose1.q + (sin_theta / sin_theta_0) * pose2.q
    else:
        interp_quat = pose1.q

    # Normalize the resulting quaternion to avoid numerical drift
    interp_quat = torch.nn.functional.normalize(interp_quat)

    # Create and return the new interpolated Pose
    return Pose.create_from_pq(interp_position, interp_quat)


def generate_trajectory(poses, durations, control_frequency):
    """
    Generates a smooth trajectory to control the end-effector.
    
    :param poses: List of poses, where each pose is [x, y, z, qw, qx, qy, qz]
    :param durations: List of durations between consecutive poses
    :param control_frequency: Scalar frequency to determine the number of internal poses per second
    :return: List of interpolated poses (smooth trajectory)
    """
    trajectory = []
    num_segments = len(poses) - 1

    for i in range(num_segments):
        start_pose = poses[i]
        end_pose = poses[i + 1]
        duration = durations[i]
        
        num_steps = int(duration * control_frequency)
        for step in range(num_steps):
            alpha = step / num_steps
            interp_pose = interpolate_pose(start_pose, end_pose, alpha)
            trajectory.append(interp_pose)
    
    # Add the last pose to ensure the trajectory ends exactly at the final pose
    trajectory.append(poses[-1])
    
    return trajectory


In [35]:
import torch
p = torch.tensor([0.3, -0.3, 0.0])
q = rot_utils.axis_angle_to_quaternion(torch.tensor([0, 0, torch.pi/2]))
target_pose = Pose.create_from_pq(p, q)
print(target_pose.shape)

torch.Size([1, 7])


In [36]:
traj = generate_trajectory([start_pose, target_pose], [2], env.agent.controller.control_freq)

In [37]:
print("start:", start_pose.p, start_pose.q)
for ttt in traj:
    print(ttt.p, ttt.q)
print("end:", target_pose.p, target_pose.q)


start: tensor([[0.4577, 0.0039, 0.4254]]) tensor([[0.4899, 0.5024, 0.5114, 0.4960]])
tensor([[0.4577, 0.0039, 0.4254]]) tensor([[0.4899, 0.5024, 0.5114, 0.4960]])
tensor([[ 0.4538, -0.0037,  0.4147]]) tensor([[0.5000, 0.4926, 0.5013, 0.5060]])
tensor([[ 0.4498, -0.0113,  0.4041]]) tensor([[0.5099, 0.4825, 0.4911, 0.5157]])
tensor([[ 0.4459, -0.0189,  0.3935]]) tensor([[0.5196, 0.4722, 0.4807, 0.5253]])
tensor([[ 0.4419, -0.0265,  0.3828]]) tensor([[0.5291, 0.4618, 0.4700, 0.5347]])
tensor([[ 0.4380, -0.0341,  0.3722]]) tensor([[0.5384, 0.4512, 0.4592, 0.5438]])
tensor([[ 0.4340, -0.0417,  0.3616]]) tensor([[0.5474, 0.4404, 0.4482, 0.5527]])
tensor([[ 0.4301, -0.0493,  0.3509]]) tensor([[0.5563, 0.4294, 0.4370, 0.5614]])
tensor([[ 0.4262, -0.0569,  0.3403]]) tensor([[0.5649, 0.4182, 0.4257, 0.5699]])
tensor([[ 0.4222, -0.0645,  0.3296]]) tensor([[0.5733, 0.4069, 0.4142, 0.5782]])
tensor([[ 0.4183, -0.0720,  0.3190]]) tensor([[0.5814, 0.3954, 0.4025, 0.5862]])
tensor([[ 0.4143, -0.0796, 

In [38]:
def compute_delta_pose(pose1: Pose, pose2: Pose) -> (torch.Tensor, torch.Tensor):
    """
    Computes the delta position and delta rotation (in Euler angles) between two poses.
    
    :param pose1: The initial Pose object.
    :param pose2: The final Pose object.
    :return: A tuple containing delta position (torch.Tensor) and delta rotation (in Euler angles, torch.Tensor).
    """
    # Delta Position
    delta_position = pose2.p - pose1.p
    
    # Compute relative rotation (delta rotation)
    quat1 = pose1.q
    quat2 = pose2.q
    
    # Calculate the relative quaternion (pose1 to pose2)
    relative_quat = quaternion_multiply(quat2, quaternion_conjugate(quat1))
    
    # Convert the relative quaternion to Euler angles
    relative_rotation = rot_utils.quaternion_to_axis_angle(relative_quat)
    delta_rotation = torch.tensor(relative_rotation, device=pose1.device)
    
    return delta_position, delta_rotation

def quaternion_conjugate(quat):
    """
    Computes the conjugate of a quaternion.
    
    :param quat: The quaternion tensor [w, x, y, z].
    :return: The conjugated quaternion tensor.
    """
    conjugate = quat.clone()
    conjugate[..., 1:] *= -1  # Negate the vector part
    return conjugate

def quaternion_multiply(quat1, quat2):
    """
    Multiplies two quaternions.
    
    :param quat1: First quaternion tensor [w, x, y, z].
    :param quat2: Second quaternion tensor [w, x, y, z].
    :return: Resulting quaternion after multiplication.
    """
    w1, x1, y1, z1 = quat1[..., 0], quat1[..., 1], quat1[..., 2], quat1[..., 3]
    w2, x2, y2, z2 = quat2[..., 0], quat2[..., 1], quat2[..., 2], quat2[..., 3]
    
    w = w1*w2 - x1*x2 - y1*y2 - z1*z2
    x = w1*x2 + x1*w2 + y1*z2 - z1*y2
    y = w1*y2 + y1*w2 + z1*x2 - x1*z2
    z = w1*z2 + z1*w2 + x1*y2 - y1*x2
    
    return torch.stack((w, x, y, z), dim=-1)




In [39]:
actions = []
for i in range(len(traj)-1):
    delta_pose = compute_delta_pose(traj[i], traj[i+1])
    print(delta_pose)

(tensor([[-0.0039, -0.0076, -0.0106]]), tensor([[ 0.0004, -0.0400, -0.0002]]))
(tensor([[-0.0039, -0.0076, -0.0106]]), tensor([[ 0.0004, -0.0400, -0.0002]]))
(tensor([[-0.0039, -0.0076, -0.0106]]), tensor([[ 0.0004, -0.0400, -0.0002]]))
(tensor([[-0.0039, -0.0076, -0.0106]]), tensor([[ 0.0004, -0.0400, -0.0002]]))
(tensor([[-0.0039, -0.0076, -0.0106]]), tensor([[ 0.0004, -0.0400, -0.0002]]))
(tensor([[-0.0039, -0.0076, -0.0106]]), tensor([[ 0.0004, -0.0400, -0.0002]]))
(tensor([[-0.0039, -0.0076, -0.0106]]), tensor([[ 0.0004, -0.0400, -0.0002]]))
(tensor([[-0.0039, -0.0076, -0.0106]]), tensor([[ 0.0004, -0.0400, -0.0002]]))
(tensor([[-0.0039, -0.0076, -0.0106]]), tensor([[ 0.0004, -0.0400, -0.0002]]))
(tensor([[-0.0039, -0.0076, -0.0106]]), tensor([[ 0.0004, -0.0400, -0.0002]]))
(tensor([[-0.0039, -0.0076, -0.0106]]), tensor([[ 0.0004, -0.0400, -0.0002]]))
(tensor([[-0.0039, -0.0076, -0.0106]]), tensor([[ 0.0004, -0.0400, -0.0002]]))
(tensor([[-0.0039, -0.0076, -0.0106]]), tensor([[ 0.

  delta_rotation = torch.tensor(relative_rotation, device=pose1.device)


In [40]:
for i in range(len(traj)-1):
    delta_pose = compute_delta_pose(traj[i], traj[i+1])
    action = torch.cat([delta_pose[0], delta_pose[1]], dim=1)
    
    _ = env.step(action)
    
    print("Expect pose:", traj[i+1].p, traj[i+1].q)
    print("Actual pose:", env.agent.controller.controllers['arm'].ee_pose_at_base.p, env.agent.controller.controllers['arm'].ee_pose_at_base.q)
    
    env.render()

  delta_rotation = torch.tensor(relative_rotation, device=pose1.device)


Expect pose: tensor([[ 0.4538, -0.0037,  0.4147]]) tensor([[0.5000, 0.4926, 0.5013, 0.5060]])
Actual pose: tensor([[0.4576, 0.0037, 0.4250]]) tensor([[0.4895, 0.5028, 0.5118, 0.4957]])
Expect pose: tensor([[ 0.4498, -0.0113,  0.4041]]) tensor([[0.5099, 0.4825, 0.4911, 0.5157]])
Actual pose: tensor([[0.4573, 0.0032, 0.4243]]) tensor([[0.4889, 0.5034, 0.5124, 0.4950]])
Expect pose: tensor([[ 0.4459, -0.0189,  0.3935]]) tensor([[0.5196, 0.4722, 0.4807, 0.5253]])
Actual pose: tensor([[0.4570, 0.0026, 0.4235]]) tensor([[0.4881, 0.5041, 0.5132, 0.4942]])
Expect pose: tensor([[ 0.4419, -0.0265,  0.3828]]) tensor([[0.5291, 0.4618, 0.4700, 0.5347]])
Actual pose: tensor([[0.4567, 0.0020, 0.4226]]) tensor([[0.4872, 0.5050, 0.5140, 0.4934]])
Expect pose: tensor([[ 0.4380, -0.0341,  0.3722]]) tensor([[0.5384, 0.4512, 0.4592, 0.5438]])
Actual pose: tensor([[0.4563, 0.0012, 0.4216]]) tensor([[0.4863, 0.5059, 0.5149, 0.4924]])
Expect pose: tensor([[ 0.4340, -0.0417,  0.3616]]) tensor([[0.5474, 0.4404,

: 