In [1]:
%load_ext autoreload
%autoreload 2


In [14]:
import mujoco
import dexhub
from loop_rate_limiters import RateLimiter
import mink
import numpy as np
from ipdb import set_trace as bp

from pathlib import Path

from scipy.spatial.transform import Rotation as R

from src.common.geometry import np_matrix_to_rotation_6d

In [12]:
def fk(model, data, ctrl):
    """
    Compute forward kinematics for the robot.
    """

    data.qpos[:7] = ctrl[:7]
    data.qpos[9:16] = ctrl[8:15]

    mujoco.mj_kinematics(model, data)

    l_frame = data.body("l_robot/attachment")
    r_frame = data.body("r_robot/attachment")

    l_ee = mink.SE3.from_rotation_and_translation(
        rotation=mink.SO3(l_frame.xquat), translation=l_frame.xpos
    ).as_matrix()
    r_ee = mink.SE3.from_rotation_and_translation(
        rotation=mink.SO3(r_frame.xquat), translation=r_frame.xpos
    ).as_matrix()

    l_vel = data.body("l_robot/attachment").cvel
    r_vel = data.body("r_robot/attachment").cvel

    return l_ee, r_ee, l_vel, r_vel

In [8]:
traj_paths = list(
    Path("/Users/larsankile/code/dexhub-api/my_data/place_plate").glob("*.dex")
)

len(traj_paths)

36

In [9]:
traj = dexhub.load(traj_paths[0])

model = dexhub.get_sim(traj)
data = mujoco.MjData(model)

In [None]:
# Dataset structure desired
"""
    # Define the full shapes for each dataset
    full_data_shapes = [
        # These are of length: number of timesteps
        ("robot_state", all_data["robot_state"].shape, np.float32),
        ("color_image1", all_data["color_image1"].shape, np.uint8),
        ("color_image2", all_data["color_image2"].shape, np.uint8),
        ("action/delta", all_data["action/delta"].shape, np.float32),
        ("action/pos", all_data["action/pos"].shape, np.float32),
        ("parts_poses", all_data["parts_poses"].shape, np.float32),
        ("reward", all_data["reward"].shape, np.float32),
        ("skill", all_data["skill"].shape, np.float32),
        ("augment_states", all_data["augment_states"].shape, np.float32),
        # These are of length: number of episodes
        ("episode_ends", (len(all_data["episode_ends"]),), np.uint32),
        ("task", (len(all_data["task"]),), str),
        ("success", (len(all_data["success"]),), np.uint8),
        ("pickle_file", (len(all_data["pickle_file"]),), str),
    ]

    # Initialize Zarr store with full dimensions
    z = initialize_zarr_store(output_path, full_data_shapes, chunksize=chunksize)

    # Write the data to the Zarr store
    it = tqdm(all_data)
    for name in it:
        it.set_description(f"Writing data to zarr: {name}")
        dataset = z[name]
        data = all_data[name]

        for i in trange(0, len(data), chunksize, desc="Writing chunks", leave=False):
            dataset[i : i + chunksize] = data[i : i + chunksize]

    # Update final metadata
    z.attrs["time_finished"] = datetime.now().astimezone().isoformat()
    z.attrs["noop_threshold"] = noop_threshold
    z.attrs["chunksize"] = chunksize
    z.attrs["rotation_mode"] = "rot_6d"
    z.attrs["n_episodes"] = len(z["episode_ends"])
    z.attrs["n_timesteps"] = len(z["action/delta"])
    z.attrs["mean_episode_length"] = round(
        len(z["action/delta"]) / len(z["episode_ends"])
    )
    z.attrs["calculated_pos_action_from_delta"] = True
    z.attrs["randomize_order"] = args.randomize_order
    z.attrs["random_seed"] = args.random_seed
    z.attrs["demo_source"] = args.source
    z.attrs["controller"] = args.controller
    z.attrs["domain"] = args.domain if args.domain == "real" else "sim"
    z.attrs["task"] = args.task
    z.attrs["randomness"] = args.randomness
    z.attrs["demo_outcome"] = args.demo_outcome
    z.attrs["suffix"] = args.suffix
"""

In [None]:
# Initialize the storage
data_dict = {
    "robot_state": np.zeros((1, 24), dtype=np.float32),
    "color_image1": np.zeros((1, 480, 640, 3), dtype=np.uint8),
    "color_image2": np.zeros((1, 480, 640, 3), dtype=np.uint8),
    "action/delta": np.zeros((1, 7), dtype=np.float32),
    "action/pos": np.zeros((1, 7), dtype=np.float32),
    "parts_poses": np.zeros((1, 10, 7), dtype=np.float32),
    "reward": np.zeros((1, 1), dtype=np.float32),
}

In [18]:
traj = dexhub.load(traj_paths[0])

data.qpos = traj.data[0].obs.mj_qpos
data.qvel = traj.data[0].obs.mj_qvel

pegname = "bimanual_insertion_peg/bimanual_insertion_peg"
holename = "bimanual_insertion_hole3/bimanual_insertion_hole3"

mujoco.mj_forward(model, data)

fk_model = model
fk_data = data


for i in range(len(traj.data)):

    # Compute forward kinematics
    action_qpos = traj.data[i].act.mj_ctrl
    l_ee, r_ee, _, _ = fk(fk_model, fk_data, action_qpos)

    l_pos, r_pos = l_ee[:3, 3], r_ee[:3, 3]
    l_mat, r_mat = l_ee[:3, :3], r_ee[:3, :3]

    l_6d, r_6d = np_matrix_to_rotation_6d(l_mat), np_matrix_to_rotation_6d(r_mat)

    l_gripper, r_gripper = -1 if action_qpos[7] > 0.02 else 1, (
        -1 if action_qpos[15] > 0.02 else 1
    )

    # Concatenate the position, the 6D rotation, and gripper action
    action_pos = np.concatenate(
        [l_pos, l_6d, np.array([l_gripper]), r_pos, r_6d, np.array([r_gripper])],
        axis=-1,
    )

    # Get robot state for both grippers by doing FK on the current state
    qpos = data.qpos
    qvel = data.qvel

    l_ee, r_ee = fk(fk_model, fk_data, qpos)

    l_quat_state = mink.SO3(l_mat).to_quaternion()
    r_quat_state = mink.SO3(r_mat).to_quaternion()

    l_lin_vel = data.body("left_hand").cvel[:3]
    r_lin_vel = data.body("right_hand").cvel[:3]

    l_ang_vel = data.body("left_hand").cvel[3:]
    r_ang_vel = data.body("right_hand").cvel[3:]

    l_gripper_width = data.actuator("l_gripper").length
    r_gripper_width = data.actuator("r_gripper").length

    # Combine all states
    robot_state = np.concatenate(
        [
            l_pos_state,
            l_quat_state,
            l_lin_vel,
            l_ang_vel,
            [l_gripper_width],
            r_pos_state,
            r_quat_state,
            r_lin_vel,
            r_ang_vel,
            [r_gripper_width],
        ]
    )

    # Update the data_dict
    data_dict["robot_state"] = np.vstack([data_dict["robot_state"], robot_state])
    data_dict["action/pos"] = np.vstack([data_dict["action/pos"], action_pos])

    # Compute action delta (difference from previous action)
    if i > 0:
        action_delta = action_pos - data_dict["action/pos"][-2]
        data_dict["action/delta"] = np.vstack([data_dict["action/delta"], action_delta])
    else:
        data_dict["action/delta"] = np.vstack(
            [data_dict["action/delta"], np.zeros_like(action_pos)]
        )

    if (i + 1) % 10 == 0:
        # Print relative pose matrix between the peg and the hole
        peg_mat = mink.SE3.from_rotation_and_translation(
            rotation=mink.SO3(data.body(pegname).xquat),
            translation=data.body(pegname).xpos,
        ).as_matrix()
        hole_mat = mink.SE3.from_rotation_and_translation(
            rotation=mink.SO3(data.body(holename).xquat),
            translation=data.body(holename).xpos,
        ).as_matrix()

    for _ in range(10):

        mujoco.mj_step(model, data)

        viewer.sync()
        rate.sleep()

(20,)


RuntimeError: No active exception to reraise