In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os

import mujoco
import dexhub
from dart_physics import mink
import numpy as np

from pathlib import Path


from tqdm import tqdm

from scipy.spatial.transform import Rotation as R

from src.common.geometry import np_matrix_to_rotation_6d
from src.common.files import get_processed_path

import zarr
from datetime import datetime

In [3]:
def fk(model, data, ctrl):
    """
    Compute forward kinematics of both arm attachment bodies for a control vector.

    The commanded arm joint positions in ``ctrl`` are written into ``data.qpos``:
    ``ctrl[0:7] -> qpos[0:7]`` and ``ctrl[8:15] -> qpos[9:16]``. ``ctrl[7]`` and
    entries from ``ctrl[15]`` on are not copied, and every other qpos entry keeps
    whatever value ``data`` already held.

    NOTE: ``data`` is modified in place (qpos plus the derived pose and velocity
    quantities). Callers that need the previous state must restore it.

    Args:
        model: MuJoCo model matching ``data``.
        data: MuJoCo data buffer; mutated in place.
        ctrl: control vector with at least 15 entries.

    Returns:
        (l_ee, r_ee, l_vel, r_vel):
        - ``l_ee``/``r_ee``: world poses of ``l_robot/attachment`` and
          ``r_robot/attachment`` as produced by ``mink.SE3(...).as_matrix()``.
        - ``l_vel``/``r_vel``: copies of those bodies' ``cvel`` rows, computed
          for the written qpos and the existing ``data.qvel``.
    """
    data.qpos[:7] = ctrl[:7]
    data.qpos[9:16] = ctrl[8:15]

    # mj_kinematics only updates poses (xpos/xquat/...). cvel is produced by
    # mj_comVel, which needs the COM quantities computed by mj_comPos.
    # Without these two calls, the returned velocities would be stale values
    # from whatever pipeline stage last ran on `data`.
    mujoco.mj_kinematics(model, data)
    mujoco.mj_comPos(model, data)
    mujoco.mj_comVel(model, data)

    l_frame = data.body("l_robot/attachment")
    r_frame = data.body("r_robot/attachment")

    l_ee = mink.SE3.from_rotation_and_translation(
        rotation=mink.SO3(l_frame.xquat), translation=l_frame.xpos
    ).as_matrix()
    r_ee = mink.SE3.from_rotation_and_translation(
        rotation=mink.SO3(r_frame.xquat), translation=r_frame.xpos
    ).as_matrix()

    # body(...).cvel views data.cvel. Copy it so later updates to `data` do not
    # silently change the velocities handed back to the caller.
    l_vel = l_frame.cvel.copy()
    r_vel = r_frame.cvel.copy()

    return l_ee, r_ee, l_vel, r_vel

In [4]:
# MuJoCo body names of the two task objects, looked up via data.body(...)
# when reading the per-timestep part poses.
pegname, holename = (
    "bimanual_insertion_peg/bimanual_insertion_peg",
    "bimanual_insertion_hole3/bimanual_insertion_hole3",
)

In [None]:
# All raw DexHub recordings for this task. Path.glob yields entries in
# filesystem order, which is arbitrary, so sort them. This makes the episode
# order (and therefore episode_ends / pickle_file) reproducible across runs
# and machines.
traj_paths = sorted(
    (Path(os.environ["DATA_DIR_RAW"]) / "raw/dexhub/sim/bimanual_insertion").glob(
        "*.dex"
    )
)

len(traj_paths)

In [None]:
# Reference only: the target zarr layout and attribute set, kept as a string
# (presumably copied from another conversion script). It is NOT executed.
# Names it uses (all_data, initialize_zarr_store, chunksize, noop_threshold,
# trange, args) are not defined in this notebook. As the cell's last
# expression, Jupyter simply echoes the string as output. This notebook writes
# only a subset of these keys (see the data_dict cell).
"""
    # Define the full shapes for each dataset
    full_data_shapes = [
        # These are of length: number of timesteps
        ("robot_state", all_data["robot_state"].shape, np.float32),
        ("color_image1", all_data["color_image1"].shape, np.uint8),
        ("color_image2", all_data["color_image2"].shape, np.uint8),
        ("action/delta", all_data["action/delta"].shape, np.float32),
        ("action/pos", all_data["action/pos"].shape, np.float32),
        ("parts_poses", all_data["parts_poses"].shape, np.float32),
        ("reward", all_data["reward"].shape, np.float32),
        ("skill", all_data["skill"].shape, np.float32),
        ("augment_states", all_data["augment_states"].shape, np.float32),
        # These are of length: number of episodes
        ("episode_ends", (len(all_data["episode_ends"]),), np.uint32),
        ("task", (len(all_data["task"]),), str),
        ("success", (len(all_data["success"]),), np.uint8),
        ("pickle_file", (len(all_data["pickle_file"]),), str),
    ]

    # Initialize Zarr store with full dimensions
    z = initialize_zarr_store(output_path, full_data_shapes, chunksize=chunksize)

    # Write the data to the Zarr store
    it = tqdm(all_data)
    for name in it:
        it.set_description(f"Writing data to zarr: {name}")
        dataset = z[name]
        data = all_data[name]

        for i in trange(0, len(data), chunksize, desc="Writing chunks", leave=False):
            dataset[i : i + chunksize] = data[i : i + chunksize]

    # Update final metadata
    z.attrs["time_finished"] = datetime.now().astimezone().isoformat()
    z.attrs["noop_threshold"] = noop_threshold
    z.attrs["chunksize"] = chunksize
    z.attrs["rotation_mode"] = "rot_6d"
    z.attrs["n_episodes"] = len(z["episode_ends"])
    z.attrs["n_timesteps"] = len(z["action/delta"])
    z.attrs["mean_episode_length"] = round(
        len(z["action/delta"]) / len(z["episode_ends"])
    )
    z.attrs["calculated_pos_action_from_delta"] = True
    z.attrs["randomize_order"] = args.randomize_order
    z.attrs["random_seed"] = args.random_seed
    z.attrs["demo_source"] = args.source
    z.attrs["controller"] = args.controller
    z.attrs["domain"] = args.domain if args.domain == "real" else "sim"
    z.attrs["task"] = args.task
    z.attrs["randomness"] = args.randomness
    z.attrs["demo_outcome"] = args.demo_outcome
    z.attrs["suffix"] = args.suffix
"""

In [None]:
def quat_wxyz_to_xyzw(quat):
    """
    Reorder quaternion components from scalar-first (w, x, y, z) to
    scalar-last (x, y, z, w).

    Accepts a single quaternion of shape (4,) or a batch of shape (..., 4).
    The roll is applied along the last axis only; without ``axis=-1``,
    ``np.roll`` would flatten a batch and shift components across rows.

    Args:
        quat: array-like whose last axis holds (w, x, y, z).

    Returns:
        A new ndarray with the last axis ordered (x, y, z, w).
    """
    return np.roll(quat, -1, axis=-1)


# Sanity check: (w, x, y, z) = (1, 2, 3, 4) -> (x, y, z, w) = (2, 3, 4, 1)
quat = np.array([1, 2, 3, 4])
assert np.array_equal(quat_wxyz_to_xyzw(quat), [2, 3, 4, 1])
quat_wxyz_to_xyzw(quat)

In [8]:
# Initialize the storage
data_dict = {
    # Per timestep
    "robot_state": np.zeros((0, 32), dtype=np.float32),
    # "color_image1": np.zeros((0, 480, 640, 3), dtype=np.uint8),
    # "color_image2": np.zeros((0, 480, 640, 3), dtype=np.uint8),
    # "action/delta": np.zeros((0, 20), dtype=np.float32),
    "action/pos": np.zeros((0, 20), dtype=np.float32),
    "parts_poses": np.zeros((0, 14), dtype=np.float32),
    # "reward": np.zeros((0, 1), dtype=np.float32),
    # Per episode
    "episode_ends": np.zeros((0,), dtype=np.uint32),
    "task": np.zeros((0,), dtype=str),
    "success": np.zeros((0,), dtype=np.uint8),
    "pickle_file": np.zeros((0,), dtype=str),
}

In [None]:
# Convert every DexHub trajectory into:
# - per-timestep arrays in data_dict: action/pos, robot_state, parts_poses;
# - per-episode metadata in data_dict: episode_ends, task, success,
#   pickle_file;
# - init_poses: the first 18 qpos entries of each episode's initial state.


def _pose_matrix(frame):
    """Return the pose of a ``data.body(...)`` view as ``mink.SE3(...).as_matrix()``."""
    return mink.SE3.from_rotation_and_translation(
        rotation=mink.SO3(frame.xquat), translation=frame.xpos
    ).as_matrix()


def _pos_rot6d(ee):
    """Split a pose matrix into [position (3,), np_matrix_to_rotation_6d(rotation)]."""
    return [ee[:3, 3], np_matrix_to_rotation_6d(ee[:3, :3])]


def _arm_state(data, body, finger_ids):
    """Return [position, 6D rotation, cvel, gripper width] for one arm body.

    The gripper width is ``data.qpos[finger_ids[0]] + data.qpos[finger_ids[1]]``.
    """
    frame = data.body(body)
    width = data.qpos[finger_ids[0]] + data.qpos[finger_ids[1]]
    return np.concatenate(_pos_rot6d(_pose_matrix(frame)) + [frame.cvel, [width]])


def _body_pose_xyzw(data, body):
    """Return [position (3), quaternion reordered to (x, y, z, w) (4)] of a body."""
    frame = data.body(body)
    return np.concatenate([frame.xpos, quat_wxyz_to_xyzw(frame.xquat)])


# Start from empty copies of the accumulators (keys and dtypes come from the
# data_dict cell). Re-running this cell then rebuilds the dataset instead of
# appending duplicate episodes, matching the init_poses reset below.
data_dict = {key: arr[:0] for key, arr in data_dict.items()}
init_poses = []

# Rows are collected in lists and stacked once at the end, instead of
# re-copying the whole array with np.vstack on every timestep.
step_rows = {"action/pos": [], "robot_state": [], "parts_poses": []}
episode_ends = []
pickle_files = []

for traj_path in tqdm(traj_paths):
    traj = dexhub.load(traj_path)

    model = dexhub.get_sim(traj)
    data = mujoco.MjData(model)

    data.qpos = traj.data[0].obs.mj_qpos
    data.qvel = traj.data[0].obs.mj_qvel
    mujoco.mj_forward(model, data)

    init_poses.append(data.qpos[:18].tolist())

    for i in range(len(traj.data)):
        step = traj.data[i]

        # --- Action: pose of the commanded joint configuration ---
        # fk() writes the commanded arm joints into `data` (the same buffer as
        # the simulation state). The observed state is restored below before
        # any state quantity is read.
        ctrl = step.act.mj_ctrl
        l_ee, r_ee, _, _ = fk(model, data, ctrl)
        # Per arm: position, 6D rotation, gripper command (ctrl[7] / ctrl[15]).
        step_rows["action/pos"].append(
            np.concatenate(
                _pos_rot6d(l_ee) + [[ctrl[7]]] + _pos_rot6d(r_ee) + [[ctrl[15]]]
            )
        )

        # --- State: observed configuration ---
        data.qpos = step.obs.mj_qpos
        data.qvel = step.obs.mj_qvel
        # mj_forward already runs kinematics and computes cvel, so no separate
        # mj_kinematics call is needed.
        mujoco.mj_forward(model, data)

        step_rows["robot_state"].append(
            np.concatenate(
                [
                    _arm_state(data, "l_robot/attachment", (7, 8)),
                    _arm_state(data, "r_robot/attachment", (16, 17)),
                ]
            )
        )
        step_rows["parts_poses"].append(
            np.concatenate(
                [_body_pose_xyzw(data, pegname), _body_pose_xyzw(data, holename)]
            )
        )

    # All per-timestep streams must stay aligned.
    n_steps = len(step_rows["action/pos"])
    assert n_steps == len(step_rows["robot_state"]) == len(step_rows["parts_poses"])

    episode_ends.append(n_steps)
    # Record this episode's own source file.
    pickle_files.append(traj_path.name)

# Materialize the per-timestep arrays with the dtypes declared in data_dict.
# reshape keeps the (0, width) shape when there are no trajectories.
for key, rows in step_rows.items():
    width = data_dict[key].shape[1]
    data_dict[key] = np.asarray(rows, dtype=data_dict[key].dtype).reshape(-1, width)

# Per-episode arrays. Explicit dtypes keep episode_ends/success at their
# declared unsigned types; np.hstack with Python ints would promote them to
# int64.
n_episodes = len(episode_ends)
data_dict["episode_ends"] = np.asarray(episode_ends, dtype=data_dict["episode_ends"].dtype)
data_dict["task"] = np.array(["bimanual_insertion"] * n_episodes, dtype=str)
data_dict["success"] = np.ones(n_episodes, dtype=data_dict["success"].dtype)
data_dict["pickle_file"] = np.array(pickle_files, dtype=str)

In [None]:
# Convert into a Zarr store and save
output_path = get_processed_path(
    domain="sim",
    controller="dexhub",
    task="bimanual_insertion",
    demo_outcome="success",
    demo_source="teleop",
    randomness="low",
)

print(f"Saving to {output_path}")

In [11]:
# Write every accumulated array to a fresh zarr store (mode="w" replaces any
# existing store at output_path), then record dataset metadata as attributes.
z = zarr.open(output_path, mode="w")

# Use `arr` rather than `data` so the MjData handle from the conversion cell
# is not clobbered.
for name, arr in data_dict.items():
    z.create_dataset(name, data=arr, dtype=arr.dtype)

n_episodes = len(z["episode_ends"])
n_timesteps = len(z["action/pos"])

# Set all attributes in one update rather than one write per key.
z.attrs.update(
    {
        "time_finished": datetime.now().astimezone().isoformat(),
        "rotation_mode": "rot_6d",
        "n_episodes": n_episodes,
        "n_timesteps": n_timesteps,
        # Guard the division: an empty conversion (no trajectories) has 0 episodes.
        "mean_episode_length": round(n_timesteps / n_episodes) if n_episodes else 0,
        "calculated_pos_action_from_delta": False,
        "randomize_order": False,
        "random_seed": None,
        "demo_source": "teleop",
        "controller": "dexhub",
        "domain": "sim",
        "task": "bimanual_insertion",
        "randomness": "low",
        "demo_outcome": "success",
        "suffix": None,
    }
)

In [12]:
# Save init_poses as a numpy array
init_poses = np.array(init_poses)

init_poses_path = output_path.parent / "init_poses.npy"

np.save(init_poses_path, init_poses)

In [None]:
# Invariant: every per-timestep array is stored as float32.
assert all(
    data_dict[key].dtype == np.float32
    for key in ("robot_state", "action/pos", "parts_poses")
), {key: data_dict[key].dtype for key in data_dict}

data_dict["action/pos"].dtype

In [None]:
# Look at the data
path = get_processed_path(
    domain="sim",
    controller="dexhub",
    task="bimanual_insertion",
    demo_outcome="success",
    demo_source="teleop",
    randomness="low",
)

path

In [None]:
# Re-open the store read-only (this rebinds `z`, replacing the write handle
# from the export cell) and display its summary.
z = zarr.open(path, mode="r")

z.info

In [None]:
# Invariants:
# - all per-timestep arrays share one length;
# - the last episode end equals that length.
assert (
    len(z["robot_state"]) == len(z["action/pos"]) == len(z["parts_poses"]) == z["episode_ends"][-1]
), (len(z["robot_state"]), len(z["action/pos"]), len(z["parts_poses"]), z["episode_ends"][-1])

z["robot_state"].shape, z["action/pos"].shape, z["parts_poses"].shape, z["episode_ends"][-1]