In [None]:
"""
Convert from robomimic hdf5 to tensordict format
"""

import os

import h5py
import numpy as np
import torch
from tensordict import TensorDict

save_dataset = True
base_save_path = "/scr2/shared/pref/datasets/robomimic"
task = "lift"
# NOTE(review): `type` shadows the builtin; the name is kept so later notebook
# cells that read it keep working. Prefer a different name in new code.
type = "mh"
hdf5_type = "image_dense"

data_path = f"/scr/matthewh6/robomimic/robomimic/datasets/{task}/{type}/{hdf5_type}_v15.hdf5"
# Single source of truth for the output file, so the log message and the
# actual write location can never disagree.
save_path = f"{base_save_path}/{task}/{task}_{type}_{hdf5_type}.pt"

# Low-dimensional observation keys concatenated (in this order) into "obs".
# The width of "object" varies by task (e.g. 10 for lift), so it is not
# hard-coded anywhere.
OBS_KEYS = ("robot0_eef_pos", "robot0_eef_quat", "robot0_gripper_qpos", "object")


def _demo_index(name):
    """Return the integer index of a demo group name such as ``"demo_12"``."""
    return int(name.split("_")[1])


def _as_rows(arr):
    """Reshape ``arr`` to 2-D ``(len(arr), features)``.

    Unlike ``reshape(len(arr), -1)`` this also works for zero-length arrays,
    where numpy cannot infer the ``-1`` dimension.
    """
    return arr.reshape(len(arr), int(np.prod(arr.shape[1:])))


with h5py.File(data_path, "r") as f:
    data = f["data"]

    total_len = 0
    actions = []
    episodes = []
    images = []
    obs = []
    rewards = []

    # Iterate demos in numeric order (demo_2 before demo_10).
    for demo in sorted(data.keys(), key=_demo_index):
        demo_data = data[demo]
        demo_obs = demo_data["obs"]
        demo_len = len(demo_data["rewards"])

        # `[:]` reads each dataset fully into memory as a numpy array, so
        # nothing below depends on the HDF5 file staying open.
        actions.append(demo_data["actions"][:])
        # Per-step episode id, taken from the demo group's name.
        episodes.append(torch.full((demo_len,), _demo_index(demo)))
        images.append(demo_obs["agentview_image"][:])
        obs.append(np.concatenate(
            [_as_rows(demo_obs[key][:]) for key in OBS_KEYS], axis=1
        ))
        rewards.append(demo_data["rewards"][:])

        total_len += demo_len

if not actions:
    raise ValueError(f"no demos found under 'data' in {data_path}")

# Concatenate all demos along the time axis; the leading dimension of every
# entry is the total number of steps across demos.
tensordict = TensorDict({
    "action": torch.cat([torch.from_numpy(a).float() for a in actions]),
    "episode": torch.cat(episodes),
    "image": torch.cat([torch.from_numpy(img) for img in images]),
    "obs": torch.cat([torch.from_numpy(o).float() for o in obs]),
    "reward": torch.cat([torch.from_numpy(r).float() for r in rewards]),
}, batch_size=torch.Size([]))

# Every field must have exactly one row per step; a mismatch means some demo's
# datasets disagree in length and the episode ids would be misaligned.
for key in ("action", "episode", "image", "obs", "reward"):
    if tensordict[key].shape[0] != total_len:
        raise ValueError(
            f"{key} has {tensordict[key].shape[0]} rows, expected {total_len}"
        )

print(tensordict)

if save_dataset:
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    print(f"saving to {save_path}")
    torch.save(tensordict, save_path)


agentview_image: (48, 84, 84, 3)
object: (48, 10)
robot0_eef_pos: (48, 3)
robot0_eef_quat: (48, 4)
robot0_eef_quat_site: (48, 4)
robot0_eye_in_hand_image: (48, 84, 84, 3)
robot0_gripper_qpos: (48, 2)
robot0_gripper_qvel: (48, 2)
robot0_joint_acc: (48, 7)
robot0_joint_pos: (48, 7)
robot0_joint_pos_cos: (48, 7)
robot0_joint_pos_sin: (48, 7)
robot0_joint_vel: (48, 7)
agentview_image: (55, 84, 84, 3)
object: (55, 10)
robot0_eef_pos: (55, 3)
robot0_eef_quat: (55, 4)
robot0_eef_quat_site: (55, 4)
robot0_eye_in_hand_image: (55, 84, 84, 3)
robot0_gripper_qpos: (55, 2)
robot0_gripper_qvel: (55, 2)
robot0_joint_acc: (55, 7)
robot0_joint_pos: (55, 7)
robot0_joint_pos_cos: (55, 7)
robot0_joint_pos_sin: (55, 7)
robot0_joint_vel: (55, 7)
agentview_image: (46, 84, 84, 3)
object: (46, 10)
robot0_eef_pos: (46, 3)
robot0_eef_quat: (46, 4)
robot0_eef_quat_site: (46, 4)
robot0_eye_in_hand_image: (46, 84, 84, 3)
robot0_gripper_qpos: (46, 2)
robot0_gripper_qvel: (46, 2)
robot0_joint_acc: (46, 7)
robot0_join