In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os

import mujoco
import dexhub
import mink
import numpy as np

from pathlib import Path


from tqdm import tqdm

from scipy.spatial.transform import Rotation as R

from src.common.geometry import np_matrix_to_rotation_6d
from src.common.files import get_processed_path

import zarr
from datetime import datetime

In [3]:
def fk(model, data, ctrl):
    """
    Compute forward kinematics for the robot.
    """

    data.qpos[:7] = ctrl[:7]
    data.qpos[9:16] = ctrl[8:15]

    mujoco.mj_kinematics(model, data)

    l_frame = data.body("l_robot/attachment")
    r_frame = data.body("r_robot/attachment")

    l_ee = mink.SE3.from_rotation_and_translation(
        rotation=mink.SO3(l_frame.xquat), translation=l_frame.xpos
    ).as_matrix()
    r_ee = mink.SE3.from_rotation_and_translation(
        rotation=mink.SO3(r_frame.xquat), translation=r_frame.xpos
    ).as_matrix()

    l_vel = data.body("l_robot/attachment").cvel
    r_vel = data.body("r_robot/attachment").cvel

    return l_ee, r_ee, l_vel, r_vel

In [4]:
pegname = "bimanual_insertion_peg/bimanual_insertion_peg"
holename = "bimanual_insertion_hole3/bimanual_insertion_hole3"

In [5]:
traj_paths = list(
    (Path(os.environ["DATA_DIR_RAW"]) / "raw/dexhub/sim/bimanual_insertion").glob("*.dex")
)

len(traj_paths)

36

In [6]:
# Dataset structure desired
"""
    # Define the full shapes for each dataset
    full_data_shapes = [
        # These are of length: number of timesteps
        ("robot_state", all_data["robot_state"].shape, np.float32),
        ("color_image1", all_data["color_image1"].shape, np.uint8),
        ("color_image2", all_data["color_image2"].shape, np.uint8),
        ("action/delta", all_data["action/delta"].shape, np.float32),
        ("action/pos", all_data["action/pos"].shape, np.float32),
        ("parts_poses", all_data["parts_poses"].shape, np.float32),
        ("reward", all_data["reward"].shape, np.float32),
        ("skill", all_data["skill"].shape, np.float32),
        ("augment_states", all_data["augment_states"].shape, np.float32),
        # These are of length: number of episodes
        ("episode_ends", (len(all_data["episode_ends"]),), np.uint32),
        ("task", (len(all_data["task"]),), str),
        ("success", (len(all_data["success"]),), np.uint8),
        ("pickle_file", (len(all_data["pickle_file"]),), str),
    ]

    # Initialize Zarr store with full dimensions
    z = initialize_zarr_store(output_path, full_data_shapes, chunksize=chunksize)

    # Write the data to the Zarr store
    it = tqdm(all_data)
    for name in it:
        it.set_description(f"Writing data to zarr: {name}")
        dataset = z[name]
        data = all_data[name]

        for i in trange(0, len(data), chunksize, desc="Writing chunks", leave=False):
            dataset[i : i + chunksize] = data[i : i + chunksize]

    # Update final metadata
    z.attrs["time_finished"] = datetime.now().astimezone().isoformat()
    z.attrs["noop_threshold"] = noop_threshold
    z.attrs["chunksize"] = chunksize
    z.attrs["rotation_mode"] = "rot_6d"
    z.attrs["n_episodes"] = len(z["episode_ends"])
    z.attrs["n_timesteps"] = len(z["action/delta"])
    z.attrs["mean_episode_length"] = round(
        len(z["action/delta"]) / len(z["episode_ends"])
    )
    z.attrs["calculated_pos_action_from_delta"] = True
    z.attrs["randomize_order"] = args.randomize_order
    z.attrs["random_seed"] = args.random_seed
    z.attrs["demo_source"] = args.source
    z.attrs["controller"] = args.controller
    z.attrs["domain"] = args.domain if args.domain == "real" else "sim"
    z.attrs["task"] = args.task
    z.attrs["randomness"] = args.randomness
    z.attrs["demo_outcome"] = args.demo_outcome
    z.attrs["suffix"] = args.suffix
""";

In [7]:
def quat_wxyz_to_xyzw(quat):
    return np.roll(quat, -1)

# Test out quat conversion
quat = np.array([1, 2, 3, 4])
quat_wxyz_to_xyzw(quat)

array([2, 3, 4, 1])

In [8]:
# Initialize the storage
data_dict = {
    # Per timestep
    "robot_state": np.zeros((0, 32), dtype=np.float32),
    # "color_image1": np.zeros((0, 480, 640, 3), dtype=np.uint8),
    # "color_image2": np.zeros((0, 480, 640, 3), dtype=np.uint8),
    # "action/delta": np.zeros((0, 20), dtype=np.float32),
    "action/pos": np.zeros((0, 20), dtype=np.float32),
    "parts_poses": np.zeros((0, 14), dtype=np.float32),
    # "reward": np.zeros((0, 1), dtype=np.float32),
    # Per episode
    "episode_ends": np.zeros((0,), dtype=np.uint32),
    "task": np.zeros((0,), dtype=str),
    "success": np.zeros((0,), dtype=np.uint8),
    "pickle_file": np.zeros((0,), dtype=str),
}

In [21]:
init_poses = []

for traj_path in tqdm(traj_paths):

    traj = dexhub.load(traj_path)

    model = dexhub.get_sim(traj)
    data = mujoco.MjData(model)

    data.qpos = traj.data[0].obs.mj_qpos
    data.qvel = traj.data[0].obs.mj_qvel

    mujoco.mj_forward(model, data)

    fk_model = model
    fk_data = data

    for i in range(len(traj.data)):
        # === These are the per timestep data ===

        # Compute forward kinematics
        action_qpos = traj.data[i].act.mj_ctrl
        l_ee, r_ee, _, _ = fk(fk_model, fk_data, action_qpos)

        l_pos, r_pos = l_ee[:3, 3], r_ee[:3, 3]
        l_mat, r_mat = l_ee[:3, :3], r_ee[:3, :3]

        l_6d, r_6d = np_matrix_to_rotation_6d(l_mat), np_matrix_to_rotation_6d(r_mat)

        # Concatenate the position, the 6D rotation, and gripper action
        action_pos = np.concatenate(
            [l_pos, l_6d, [action_qpos[7]], r_pos, r_6d, [action_qpos[15]]],
            axis=-1,
        )
        data_dict["action/pos"] = np.vstack([data_dict["action/pos"], action_pos], dtype=np.float32)

        data.qpos = traj.data[i].obs.mj_qpos

        init_poses.append(data.qpos[:18].tolist())

        break

        data.qvel = traj.data[i].obs.mj_qvel

        mujoco.mj_forward(model, data)
        mujoco.mj_kinematics(model, data)

        l_frame = data.body("l_robot/attachment")
        r_frame = data.body("r_robot/attachment")

        l_ee = mink.SE3.from_rotation_and_translation(
            rotation=mink.SO3(l_frame.xquat), translation=l_frame.xpos
        ).as_matrix()
        r_ee = mink.SE3.from_rotation_and_translation(
            rotation=mink.SO3(r_frame.xquat), translation=r_frame.xpos
        ).as_matrix()

        l_vel = data.body("l_robot/attachment").cvel
        r_vel = data.body("r_robot/attachment").cvel

        l_pos_state, r_pos_state = l_ee[:3, 3], r_ee[:3, 3]

        l_mat, r_mat = l_ee[:3, :3], r_ee[:3, :3]

        l_rot_6d, r_rot_6d = np_matrix_to_rotation_6d(l_mat), np_matrix_to_rotation_6d(r_mat)

        l_gripper_width = data.qpos[7] + data.qpos[8]
        r_gripper_width = data.qpos[16] + data.qpos[17]

        # Combine all states
        robot_state = np.concatenate(
            [
                l_pos_state,
                l_rot_6d,
                l_vel,
                [l_gripper_width],
                r_pos_state,
                r_rot_6d,
                r_vel,
                [r_gripper_width],
            ]
        )

        data_dict["robot_state"] = np.vstack([data_dict["robot_state"], robot_state], dtype=np.float32)

        # Get the parts poses
        peg_pos, peg_quat_xyzw = data.body(pegname).xpos, quat_wxyz_to_xyzw(data.body(pegname).xquat)
        hole_pos, hole_quat_xyzw = data.body(holename).xpos, quat_wxyz_to_xyzw(data.body(holename).xquat)

        peg_pose = np.concatenate([peg_pos, peg_quat_xyzw])
        hole_pose = np.concatenate([hole_pos, hole_quat_xyzw])

        parts_poses = np.concatenate([peg_pose, hole_pose], axis=-1)

        data_dict["parts_poses"] = np.vstack([data_dict["parts_poses"], parts_poses], dtype=np.float32)

    continue
    # === These are the per episode data ===
    data_dict["episode_ends"] = np.hstack([data_dict["episode_ends"], len(data_dict["action/pos"])])
    data_dict["task"] = np.hstack([data_dict["task"], "bimanual_insertion"])
    data_dict["success"] = np.hstack([data_dict["success"], 1])
    data_dict["pickle_file"] = np.hstack([data_dict["pickle_file"], traj_paths[0].name])


  0%|          | 0/36 [00:00<?, ?it/s]

100%|██████████| 36/36 [00:04<00:00,  7.60it/s]


In [22]:
# Save init_poses as a numpy array
init_poses = np.array(init_poses)
np.save("init_poses.npy", init_poses)

In [10]:
# Convert into a Zarr store and save
output_path = get_processed_path(
    domain="sim",
    controller="dexhub",
    task="bimanual_insertion",
    demo_outcome="success",
    demo_source="teleop",
    randomness="low",
)

print(f"Saving to {output_path}")

Saving to /data/scratch/ankile/furniture-data/processed/dexhub/sim/bimanual_insertion/teleop/low/success.zarr


In [11]:
z = zarr.open(output_path, mode="w")

for name, data in data_dict.items():
    z.create_dataset(name, data=data, dtype=data.dtype)

z.attrs["time_finished"] = datetime.now().astimezone().isoformat()
z.attrs["rotation_mode"] = "rot_6d"
z.attrs["n_episodes"] = len(z["episode_ends"])
z.attrs["n_timesteps"] = len(z["action/pos"])
z.attrs["mean_episode_length"] = round(len(z["action/pos"]) / len(z["episode_ends"]))
z.attrs["calculated_pos_action_from_delta"] = False
z.attrs["randomize_order"] = False
z.attrs["random_seed"] = None
z.attrs["demo_source"] = "teleop"
z.attrs["controller"] = "dexhub"
z.attrs["domain"] = "sim"
z.attrs["task"] = "bimanual_insertion"
z.attrs["randomness"] = "low"
z.attrs["demo_outcome"] = "success"
z.attrs["suffix"] = None

In [12]:
data_dict["action/pos"].dtype

dtype('float32')

In [13]:
# Look at the data
path = get_processed_path(
    domain="sim",
    controller="dexhub",
    task="bimanual_insertion",
    demo_outcome="success",
    demo_source="teleop",
    randomness="low",
)

path

PosixPath('/data/scratch/ankile/furniture-data/processed/dexhub/sim/bimanual_insertion/teleop/low/success.zarr')

In [14]:
z = zarr.open(path, mode="r")

z.info

0,1
Name,/
Type,zarr.hierarchy.Group
Read-only,True
Store type,zarr.storage.DirectoryStore
No. members,7
No. arrays,6
No. groups,1
Arrays,"episode_ends, parts_poses, pickle_file, robot_state, success, task"
Groups,action


In [15]:
z["robot_state"][10, :10] - z["action/pos"][10, :10]

array([-0.00956386, -0.00956407,  0.00487675, -0.02865148,  0.04399419,
        0.00840636,  0.04741105, -0.00892806, -0.05337405,  0.43563274],
      dtype=float32)

In [16]:
z["action/pos"][10]

array([-0.19278528,  0.15126579,  0.09587907, -0.81279373, -0.5769143 ,
        0.08084682, -0.5253764 ,  0.6659689 , -0.52958953,  0.03933182,
        0.2077984 ,  0.08537333,  0.13556954, -0.19469947,  0.9681319 ,
       -0.15752076,  0.96826553,  0.16404867, -0.18854676,  0.04      ],
      dtype=float32)

In [17]:
for i in range(1000):
    print(z["robot_state"][i, 0] - z["action/pos"][i, 0])

-0.0043601543
-0.0035610795
-0.004005924
-0.0031517297
-0.0035942793
-0.004630834
-0.006363511
-0.008619547
-0.007060185
-0.008117691
-0.009563863
-0.0077426583
-0.008978903
-0.01046145
-0.008459359
-0.009804711
-0.011125281
-0.008975983
-0.010225624
-0.010866448
-0.010883823
-0.008779809
-0.009740949
-0.010104373
-0.010390535
-0.008411497
-0.008613393
-0.008511946
-0.0069556534
-0.0071816593
-0.0073530525
-0.006032273
-0.007231936
-0.007574737
-0.008534446
-0.007116154
-0.0078077465
-0.008815125
-0.009187907
-0.007646024
-0.008878246
-0.009454861
-0.00789699
-0.008736044
-0.0086232275
-0.008840069
-0.0074281394
-0.007309079
-0.0075397044
-0.0063937455
-0.0063559115
-0.007194951
-0.0061293542
-0.006475717
-0.006893426
-0.008007124
-0.007086456
-0.008391827
-0.00850217
-0.0077409446
-0.007783264
-0.009297825
-0.008860767
-0.009221181
-0.009880237
-0.0096766725
-0.009778045
-0.009990193
-0.010548338
-0.010593645
-0.011840925
-0.012536399
-0.01324065
-0.013159469
-0.014436975
-0.014907323

In [18]:
for i in range(1000):
    print(z["robot_state"][i, 9:12])

[-0.03067682  0.54095155  0.44993806]
[0.04372392 0.43310976 0.54129213]
[0.09488807 0.27355695 0.6079379 ]
[0.10347182 0.19552581 0.5743725 ]
[0.10251652 0.20565672 0.52702254]
[0.18438353 0.21738699 0.49193597]
[0.28305972 0.27344784 0.4050034 ]
[0.40200487 0.3075995  0.34623873]
[0.4689043  0.28571874 0.29562125]
[0.43746594 0.3091012  0.2375193 ]
[0.47496456 0.33029664 0.29637992]
[0.44859916 0.34192166 0.32127285]
[0.43824238 0.3736095  0.2569302 ]
[0.49737677 0.44928822 0.2720001 ]
[0.5224702 0.4001646 0.2863754]
[0.44932774 0.35294905 0.29877365]
[0.42305422 0.299343   0.4361223 ]
[0.34861696 0.31929088 0.5396438 ]
[0.28551036 0.2731845  0.52161944]
[0.24037857 0.25762615 0.64393646]
[0.22463645 0.29609123 0.7119892 ]
[0.20813698 0.31844878 0.7314626 ]
[0.20707916 0.31319296 0.6301957 ]
[0.18835695 0.32348153 0.71177685]
[0.1437662  0.33559474 0.7343927 ]
[0.11549884 0.32649586 0.67047125]
[0.11342949 0.30326325 0.5955471 ]
[0.08506606 0.30798203 0.5649035 ]
[0.0606151  0.293737