Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 85 additions & 1 deletion examples/visualization/reach_target_pose.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@
simulation_app = app_launcher.app


from isaaclab.utils.math import (
compute_pose_error,
euler_xyz_from_quat,
subtract_frame_transforms,
)

import autosim_examples # noqa: F401
from autosim import make_pipeline
from autosim.utils.debug_util import visualize_reach_target_poses
Expand Down Expand Up @@ -112,6 +118,74 @@ def _apply_live_poses(*, poses_path: str, pipeline) -> None:
]


def _target_object_names(pipeline) -> list[str]:
return list(pipeline._env_extra_info.object_reach_target_poses.keys())


def _snapshot_object_poses_w(*, env, object_names: list[str]) -> dict[str, list[float]]:
poses_w: dict[str, list[float]] = {}
for obj_name in object_names:
if obj_name not in env.scene.keys():
print(f"[reach_target_pose] Skip missing scene object: {obj_name}")
continue
pose_w = env.scene[obj_name].data.root_pose_w[0]
poses_w[obj_name] = [float(v) for v in pose_w.detach().cpu().tolist()]
return poses_w


def _report_pose_drift(*, poses_before: dict[str, list[float]], poses_after: dict[str, list[float]]) -> None:
"""Print world-frame and object-frame relative pose change for each target object."""
print("[reach_target_pose] Object pose drift after 20 zero-action steps:")
for obj_name in poses_before:
if obj_name not in poses_after:
print(f" - {obj_name}: missing pose after steps")
continue

pose_before = poses_before[obj_name]
pose_after = poses_after[obj_name]
pos_b = torch.tensor(pose_before[:3]).view(1, 3)
quat_b = torch.tensor(pose_before[3:]).view(1, 4)
pos_a = torch.tensor(pose_after[:3]).view(1, 3)
quat_a = torch.tensor(pose_after[3:]).view(1, 4)

world_pos_delta = (pos_a - pos_b).squeeze(0)
world_pos_norm = float(torch.linalg.norm(world_pos_delta))
world_pos_err, world_rot_err = compute_pose_error(pos_b, quat_b, pos_a, quat_a)
world_rot_deg = float(torch.rad2deg(torch.linalg.norm(world_rot_err)).item())

rel_pos, rel_quat = subtract_frame_transforms(pos_b, quat_b, pos_a, quat_a)
rel_pos_norm = float(torch.linalg.norm(rel_pos).item())
_, rel_rot_axis_angle = compute_pose_error(
torch.zeros_like(pos_b),
torch.tensor([1.0, 0.0, 0.0, 0.0]).view(1, 4),
rel_pos,
rel_quat,
)
rel_rot_deg = float(torch.rad2deg(torch.linalg.norm(rel_rot_axis_angle)).item())
rel_roll, rel_pitch, rel_yaw = euler_xyz_from_quat(rel_quat)
rel_roll_deg = float(torch.rad2deg(rel_roll).item())
rel_pitch_deg = float(torch.rad2deg(rel_pitch).item())
rel_yaw_deg = float(torch.rad2deg(rel_yaw).item())

print(f" - {obj_name}:")
print(f" pose_before_w: {pose_before}")
print(f" pose_after_w: {pose_after}")
print(
" world_delta: "
f"pos=[{world_pos_delta[0]:.6f}, {world_pos_delta[1]:.6f}, {world_pos_delta[2]:.6f}] "
f"(norm={world_pos_norm:.6f} m), rot={world_rot_deg:.4f} deg"
)
print(
" relative_delta (in pose_before / object frame): "
f"pos=[{rel_pos[0, 0]:.6f}, {rel_pos[0, 1]:.6f}, {rel_pos[0, 2]:.6f}] "
f"(norm={rel_pos_norm:.6f} m), rot={rel_rot_deg:.4f} deg"
)
print(
" relative_rot_object_frame (XYZ Euler, deg): "
f"x(roll)={rel_roll_deg:.4f}, y(pitch)={rel_pitch_deg:.4f}, z(yaw)={rel_yaw_deg:.4f}"
)


def _export_env_extra_poses_to_json(*, out_path: str, pipeline) -> None:
"""Export current env_extra_info reach targets to JSON."""
env_extra_info = pipeline._env_extra_info
Expand Down Expand Up @@ -144,11 +218,21 @@ def main():
_apply_live_poses(poses_path=debug_path, pipeline=pipeline)
except Exception as e:
print(f"[reach_target_pose] Failed to apply exported debug poses: {e}")
visualize_reach_target_poses(pipeline._env_extra_info, pipeline._env)

target_objects = _target_object_names(pipeline)
poses_before_step = _snapshot_object_poses_w(env=pipeline._env, object_names=target_objects)

last_mtime = os.path.getmtime(debug_path)
last_poll_t = 0.0

for _ in range(20):
pipeline._env.step(torch.zeros(pipeline._env.action_space.shape, device=pipeline._env.device))

poses_after_step = _snapshot_object_poses_w(env=pipeline._env, object_names=target_objects)
_report_pose_drift(poses_before=poses_before_step, poses_after=poses_after_step)

visualize_reach_target_poses(pipeline._env_extra_info, pipeline._env)

while simulation_app.is_running():
pipeline._env.sim.render()

Expand Down
1 change: 1 addition & 0 deletions source/autosim/autosim/core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def initialize(self) -> None:

# load the environment and extra information
self._env: ManagerBasedEnv = self.load_env()
self._env.reset()
self._env_extra_info: EnvExtraInfo = self.get_env_extra_info()
self._env_id = 0

Expand Down
Loading