In [5]:
"""
Convert from robomimic hdf5 to tensordict format
"""

import os

import h5py
import numpy as np
import torch
from tensordict import TensorDict

# ---- Configuration ----------------------------------------------------------
# Set to True to write the converted dataset to disk after conversion.
save_dataset = False
base_save_path = "/scr/shared/datasets/robot_pref"
task = "lift"
# NOTE(review): `type` shadows the Python builtin. The name is kept so that any
# other code referring to it keeps working; consider renaming it file-wide.
type = "ph"
hdf5_type = "image"

data_path = f"/scr/matthewh6/robomimic_old/robomimic/datasets/{task}/{type}/{hdf5_type}_v15.hdf5"

with h5py.File(data_path, 'r') as f:
    data = f["data"]

    # Per-demo pieces, accumulated in demo order and concatenated at the end.
    total_len = 0
    actions = []
    episodes = []
    images = []
    obs = []
    rewards = []

    # Demo groups are named "<prefix>_<index>"; sort numerically by the index
    # so that e.g. "demo_10" comes after "demo_2".
    for demo in sorted(data.keys(), key=lambda x: int(x.split('_')[1])):

        demo_data = data[demo]
        demo_obs = demo_data["obs"]
        demo_len = len(demo_data["rewards"])

        actions.append(demo_data["actions"][:])
        # One episode id per step, taken from the numeric suffix of the demo name.
        episodes.append(torch.full((demo_len,), int(demo.split('_')[1])))
        images.append(demo_obs["agentview_image"][:])

        # Print every observation key and its shape (debug aid).
        for key in demo_obs.keys():
            print(f"{key}: {demo_obs[key].shape}")

        # The flat observation vector is the concatenation of
        # ["robot0_eef_pos", "robot0_eef_quat", "robot0_gripper_qpos", "object"].
        # The width of "object" varies by task, so it is read from the dataset
        # itself instead of being hard-coded.
        object_obs = demo_obs["object"][:]
        obs.append(np.concatenate([
            demo_obs["robot0_eef_pos"][:].reshape(-1, 3),
            demo_obs["robot0_eef_quat"][:].reshape(-1, 4),
            demo_obs["robot0_gripper_qpos"][:].reshape(-1, 2),
            object_obs.reshape(-1, object_obs.shape[-1]),
        ], axis=1))

        rewards.append(demo_data["rewards"][:])

        total_len += demo_len

    # Convert numpy arrays to tensors and concatenate all demos along dim 0.
    # Actions, observations and rewards are cast to float32; images keep the
    # dtype stored in the HDF5 file.
    tensordict = TensorDict({
        "action": torch.cat([torch.from_numpy(a).float() for a in actions]),
        "episode": torch.cat(episodes),
        "image": torch.cat([torch.from_numpy(img) for img in images]),
        "obs": torch.cat([torch.from_numpy(o).float() for o in obs]),
        "reward": torch.cat([torch.from_numpy(r).float() for r in rewards])
    }, batch_size=torch.Size([]))

    print(tensordict)

    if save_dataset:
        # Build the path once so the logged path is the path actually written.
        # NOTE(review): the file name does not include `hdf5_type`, so image and
        # non-image conversions of the same task/type share one output file.
        save_path = f"{base_save_path}/{task}/{task}_{type}.pt"
        # torch.save does not create missing parent directories.
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        print(f"saving to {save_path}")
        torch.save(tensordict, save_path)


agentview_image: (59, 84, 84, 3)
object: (59, 10)
robot0_eef_pos: (59, 3)
robot0_eef_quat: (59, 4)
robot0_eef_quat_site: (59, 4)
robot0_eye_in_hand_image: (59, 84, 84, 3)
robot0_gripper_qpos: (59, 2)
robot0_gripper_qvel: (59, 2)
robot0_joint_acc: (59, 7)
robot0_joint_pos: (59, 7)
robot0_joint_pos_cos: (59, 7)
robot0_joint_pos_sin: (59, 7)
robot0_joint_vel: (59, 7)
agentview_image: (58, 84, 84, 3)
object: (58, 10)
robot0_eef_pos: (58, 3)
robot0_eef_quat: (58, 4)
robot0_eef_quat_site: (58, 4)
robot0_eye_in_hand_image: (58, 84, 84, 3)
robot0_gripper_qpos: (58, 2)
robot0_gripper_qvel: (58, 2)
robot0_joint_acc: (58, 7)
robot0_joint_pos: (58, 7)
robot0_joint_pos_cos: (58, 7)
robot0_joint_pos_sin: (58, 7)
robot0_joint_vel: (58, 7)
agentview_image: (57, 84, 84, 3)
object: (57, 10)
robot0_eef_pos: (57, 3)
robot0_eef_quat: (57, 4)
robot0_eef_quat_site: (57, 4)
robot0_eye_in_hand_image: (57, 84, 84, 3)
robot0_gripper_qpos: (57, 2)
robot0_gripper_qvel: (57, 2)
robot0_joint_acc: (57, 7)
robot0_join