In [None]:
import cv2
import os
import h5py
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import savgol_filter


# Time-step constant; presumably seconds per recorded step — TODO confirm units.
DT = 0.02

# Names of the six arm joints, in the order they are listed here.
JOINT_NAMES = [
    "waist",
    "shoulder",
    "elbow",
    "forearm_roll",
    "wrist_angle",
    "wrist_rotate",
]

# Joint names followed by the gripper entry.
STATE_NAMES = [*JOINT_NAMES, "gripper"]


def load_hdf5(dataset_dir: str, dataset_name: str, skip_frames: int = 0):
    """Load one episode from an HDF5 file, optionally subsampled along axis 0.

    Every ``skip_frames + 1``-th entry along the first axis is kept for
    ``qpos``, ``qvel``, ``action`` and each camera stream, so the default
    ``skip_frames=0`` keeps every entry.

    Args:
        dataset_dir: Directory that contains the episode file.
        dataset_name: File name of the episode, joined onto ``dataset_dir``.
        skip_frames: Number of entries dropped between two kept entries.
            Must be non-negative.

    Returns:
        A tuple ``(qpos, qvel, action, image_dict)``. The first three are
        arrays read from ``/observations/qpos``, ``/observations/qvel`` and
        ``/action``. ``image_dict`` maps each camera name found under
        ``/observations/images/`` to its frames: the stored array when the
        file's ``compress`` attribute is falsy or absent, otherwise a list of
        images decoded with ``cv2.imdecode``.

    Raises:
        ValueError: If ``skip_frames`` is negative.
        FileNotFoundError: If the dataset file does not exist.
    """
    if skip_frames < 0:
        raise ValueError(f'skip_frames must be >= 0, got {skip_frames}')

    dataset_path = os.path.join(dataset_dir, dataset_name)
    if not os.path.isfile(dataset_path):
        # Raise rather than call exit(): exit() is an interactive helper that
        # may not stop execution inside a notebook kernel, and an exception
        # lets the caller decide how to react.
        raise FileNotFoundError(f'Dataset does not exist at \n{dataset_path}\n')

    step = skip_frames + 1
    with h5py.File(dataset_path, 'r') as root:
        compressed = root.attrs.get('compress', False)

        # Slice the datasets directly so only the kept entries are read,
        # instead of loading everything and discarding most of it afterwards.
        qpos = root['/observations/qpos'][::step]
        qvel = root['/observations/qvel'][::step]
        action = root['/action'][::step]

        image_dict = {
            cam_name: cam_data[::step]
            for cam_name, cam_data in root['/observations/images/'].items()
        }

    if compressed:
        # Buffers are handed to the decoder exactly as stored; no un-padding
        # is applied here. Only values of existing keys are replaced, so
        # iterating over the dict while assigning is safe.
        for cam_name, encoded_frames in image_dict.items():
            image_dict[cam_name] = [
                cv2.imdecode(encoded_frame, cv2.IMREAD_COLOR)
                for encoded_frame in encoded_frames
            ]

    return qpos, qvel, action, image_dict
