In [None]:
from collections import defaultdict

import mediapy as media
import mujoco
import mujoco.viewer
import numpy as np
import tqdm
from dm_control import composer

from aloha_sim import task_suite
from aloha_sim.aloha import Aloha
from aloha_sim.motion_planner import Planner
from aloha_sim.tasks.base.aloha2_task import (
    SIM_GRIPPER_QPOS_CLOSE,
    SIM_GRIPPER_QPOS_OPEN,
    qpos_to_ctrl,
)

In [None]:
# Key into task_suite.TASK_FACTORIES selecting which task this notebook drives.
task_name: str = "MarkerRemoveLid"

In [None]:
_DT = 0.02
_IMAGE_SIZE = (480, 848)
_ALOHA_CAMERAS = {
    "overhead_cam": _IMAGE_SIZE,
    "worms_eye_cam": _IMAGE_SIZE,
    "wrist_cam_left": _IMAGE_SIZE,
    "wrist_cam_right": _IMAGE_SIZE,
}

In [None]:
# Look up the task class and its task-specific constructor arguments, then
# instantiate it with the notebook's camera and timing settings.
task_class, kwargs = task_suite.TASK_FACTORIES[task_name]
task = task_class(
    **kwargs,
    cameras=_ALOHA_CAMERAS,
    control_timestep=_DT,  # one env.step() == _DT seconds of sim time
    update_interval=1,
    image_observation_delay_secs=0.0,
)

# Wrap the task in a composer environment. The fixed seed makes the sequence
# of episode initializations reproducible across kernel restarts.
env = composer.Environment(
    task=task,
    random_state=np.random.RandomState(0),
    time_limit=float("inf"),  # the environment never ends an episode itself
    recompile_mjcf_every_episode=False,
    strip_singleton_obs_buffer_dim=True,
    delayed_observation_padding=composer.ObservationPadding.INITIAL_VALUE,
)

# Start an episode so the scene can be rendered in the next cell.
time_step = env.reset()

In [None]:
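# NOTE: `joint_limits` is not defined in any cell of this notebook. This line is
# an unused sketch for sampling a random start configuration and would raise a
# NameError if uncommented as-is.
# random_qpos = env.random_state.uniform(joint_limits[:, 0], joint_limits[:, 1])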

In [None]:
# Render the current physics state after the reset and display it inline.
scene_image = env.physics.render()
media.show_image(scene_image)

In [None]:
# Body pairs passed to Aloha as `disable_collisions` from the start (presumably
# excluded from the planner's collision checks):
# - the marker and its cap,
# - the marker and the cap against the table,
# - each gripper's two finger links against each other.
_INITIAL_DISABLED_COLLISIONS = {
    ("marker//unnamed_body_0", "cap//unnamed_body_0"),
    ("table", "marker//unnamed_body_0"),
    ("table", "cap//unnamed_body_0"),
    ("left\\left_finger_link", "left\\right_finger_link"),
    ("right\\left_finger_link", "right\\right_finger_link"),
}
aloha = Aloha(env.physics.model.ptr, disable_collisions=_INITIAL_DISABLED_COLLISIONS)

In [None]:
# One list of images per camera. The dict is re-created here so that re-running
# this cell starts a fresh recording instead of appending to the previous one.
frames = defaultdict(list)

# Start a new episode. This draws from env.random_state again, so the scene is
# presumably not the one rendered after the first env.reset() above.
env.reset()

# Top-down end-effector orientation as Euler angles in the `seq` convention:
# pitch pi/2 and yaw pi (previously written as the truncated 1.57 / 3.14).
seq = "XYZ"
ee_quat = np.zeros(4)
mujoco.mju_euler2Quat(ee_quat, [0, np.pi / 2, np.pi], seq)

# Yaw the grasp to match the marker. Column 2 of the geom's row-major rotation
# matrix is the geom's local z-axis expressed in world coordinates. Its
# x-component is the cosine of that axis's angle to the world x-axis.
marker_geom = env.physics.data.geom("marker//unnamed_geom_0")
marker_axis = marker_geom.xmat.reshape((3, 3))[:, 2]
# Clip before arccos: round-off can push a unit-vector component slightly past
# +/-1, which would make arccos return NaN.
marker_yaw = -np.arccos(np.clip(marker_axis[0], -1.0, 1.0))
# NOTE(review): arccos is unsigned, so marker_yaw always lies in [-pi, 0],
# whichever way the marker is turned about world z.
# np.arctan2(marker_axis[1], marker_axis[0]) would give a signed angle —
# confirm which convention the grasp relies on.
marker_quat = np.zeros(4)
mujoco.mju_euler2Quat(marker_quat, [0, 0.0, marker_yaw], seq)

# Grasp orientation: the marker yaw composed on the left of the top-down
# end-effector orientation.
quat = np.zeros(4)
mujoco.mju_mulQuat(quat, marker_quat, ee_quat)

rotation = np.zeros(9)
mujoco.mju_quat2Mat(rotation, quat)

marker = env.physics.data.body("marker/")

# Target gripper pose as a 4x4 homogeneous transform: the grasp orientation
# placed at the marker body's world position.
pose = np.eye(4)
pose[:3, :3] = rotation.reshape((3, 3))
pose[:3, 3] = marker.xpos

# Left-multiplied onto `pose` below, so this is a +0.1 translation along the
# world z-axis (the hover height used before and after the grasp).
approach_offset = np.eye(4)
approach_offset[2, 3] = 0.1

# Right-multiplied onto `pose` below, so this is a +0.05 translation along the
# pose's own z-axis.
object_center = np.eye(4)
object_center[2, 3] = 0.05

def step_env(env, trajectory, frame_buffer=None, cameras=None):
    """Executes a joint-space trajectory and records camera images.

    Each waypoint is converted with `qpos_to_ctrl` and sent as one `env.step`
    action. After every step, each camera's image from the returned observation
    is appended to that camera's list in `frame_buffer`.

    Args:
      env: The composer environment to step.
      trajectory: Sequence of joint configurations (presumably qpos vectors),
        one per control step.
      frame_buffer: Mapping from camera name to a list of frames. It must
        already hold a list for every recorded camera, or be a
        `defaultdict(list)`. Defaults to the notebook-level `frames`, looked up
        at call time so that re-creating `frames` is picked up.
      cameras: Iterable of camera names to record. Defaults to the
        notebook-level `_ALOHA_CAMERAS`.
    """
    if frame_buffer is None:
        frame_buffer = frames
    if cameras is None:
        cameras = _ALOHA_CAMERAS
    # tqdm takes the total from len(trajectory) when available.
    for qpos in tqdm.tqdm(trajectory):
        time_step = env.step(qpos_to_ctrl(qpos))
        for camera_name in cameras:
            frame_buffer[camera_name].append(time_step.observation[camera_name])


# Number of control steps for which the gripper-close command is held.
_GRIPPER_CLOSE_STEPS = 10


def _execute(trajectory, stage):
    """Fails with a descriptive message if planning for `stage` failed, else runs it."""
    assert trajectory is not None, f"Motion planning failed: {stage}"
    step_env(env, trajectory)


# Pre-approach: OMPL plan to the hover pose above the marker.
trajectory = aloha.plan_to_pose(
    env.physics.data.ptr,
    approach_offset @ pose @ object_center,
    Planner.OMPL,
)
_execute(trajectory, "pre-approach")

# Approach: Cartesian motion from the hover pose along the world z-axis to the
# grasp pose. Despite its name, `retreat_offset` is the identity, i.e. no extra
# offset at the grasp height.
retreat_offset = np.eye(4)
retreat_offset[2, 3] = 0.0

trajectory = aloha.plan_to_pose(
    env.physics.data.ptr,
    retreat_offset @ pose @ object_center,
    Planner.CARTESIAN,
)
_execute(trajectory, "approach")

# Grasp: let the right finger links touch the marker and the table, then hold
# the close-gripper command for a few control steps.
aloha.disable_collisions(
    {
        ("right\\left_finger_link", "marker//unnamed_body_0"),
        ("right\\right_finger_link", "marker//unnamed_body_0"),
        ("table", "right\\right_finger_link"),
        ("table", "right\\left_finger_link"),
    },
)

step_env(env, [aloha.close_gripper(env.physics.data.ptr)] * _GRIPPER_CLOSE_STEPS)

# Lift: Cartesian motion back up to the hover pose, presumably carrying the
# grasped marker.
trajectory = aloha.plan_to_pose(
    env.physics.data.ptr,
    approach_offset @ pose @ object_center,
    Planner.CARTESIAN,
)
_execute(trajectory, "lift")

# Re-enable the finger/table pairs now that the gripper is back at hover
# height. The finger/marker pairs stay disabled while the marker is held.
aloha.enable_collisions(
    {("table", "right\\right_finger_link"), ("table", "right\\left_finger_link")},
)

# Six arm-joint values shared by both arms in the reset configuration. Each
# arm's block in RESET_QPOS is these six values followed by two gripper values.
_RESET_ARM_QPOS = [0.0, -0.96, 1.16, 0.0, -0.3, 0.0]

# 16-dim reset target: first arm with its gripper open, second arm with its
# gripper closed. Presumably the order is left then right, which keeps the
# right gripper closed on the grasped marker — TODO confirm the ordering.
RESET_QPOS = np.asarray(
    _RESET_ARM_QPOS
    + [SIM_GRIPPER_QPOS_OPEN, SIM_GRIPPER_QPOS_OPEN]
    + _RESET_ARM_QPOS
    + [SIM_GRIPPER_QPOS_CLOSE, SIM_GRIPPER_QPOS_CLOSE],
)

# Plan a joint-space path back to the reset configuration and execute it.
trajectory = aloha.plan_to_qpos(env.physics.data.ptr, RESET_QPOS)
assert trajectory is not None
step_env(env, trajectory)

In [None]:
# Play back the recorded camera streams. step_env records one frame per
# env.step(), i.e. every _DT seconds of simulated time, so 1 / _DT fps plays
# the motion at simulated real-time speed. A hard-coded 1 / 0.01 would play it
# at twice that speed.
media.show_videos(frames, fps=1.0 / _DT)