In [5]:
import cv2
import os
import h5py
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import savgol_filter


DT=0.02
JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
STATE_NAMES = JOINT_NAMES + ["gripper"]


def load_hdf5(dataset_dir, dataset_name, skip_frames=0):
    dataset_path = os.path.join(dataset_dir, dataset_name)
    if not os.path.isfile(dataset_path):
        print(f'Dataset does not exist at \n{dataset_path}\n')
        exit()

    with h5py.File(dataset_path, 'r') as root:
        compressed = root.attrs.get('compress', False)
        
        # Apply skip_frames when loading qpos, qvel, and action
        qpos = root['/observations/qpos'][()][::skip_frames + 1]
        qvel = root['/observations/qvel'][()][::skip_frames + 1]
        action = root['/action'][()][::skip_frames + 1]
        
        image_dict = dict()
        for cam_name in root[f'/observations/images/'].keys():
            image_dict[cam_name] = root[f'/observations/images/{cam_name}'][()][::skip_frames + 1]

        if compressed:
            compress_len = root['/compress_len'][()]

    if compressed:
        for cam_id, cam_name in enumerate(image_dict.keys()):
            # un-pad and uncompress
            padded_compressed_image_list = image_dict[cam_name]
            image_list = []
            for frame_id, padded_compressed_image in enumerate(padded_compressed_image_list): # [:1000] to save memory
                compressed_image = padded_compressed_image
                image = cv2.imdecode(compressed_image, 1)
                image_list.append(image)
            image_dict[cam_name] = image_list

    return qpos, qvel, action, image_dict


In [6]:
'''
dataset_directory/
├── images/
│   ├── image1.png
│   ├── image2.png
│   ├── image3.png
│   └── ... (more image files)
└── labels.csv

labels.csv
image1.png, label_1, label_2, ..., label_14
image2.png, label_1, label_2, ..., label_14
'''
import os

dataset_dst = "./datasets/kitting_vision_ik/"
# Create dataset directory if it does not exist
if not os.path.exists(dataset_dst):
    os.makedirs(dataset_dst)


In [17]:
from PIL import Image
from torchvision import transforms
import os

# Set up your paths and transformation
dataset_dir = "./datasets/kitting/"
image_dir = os.path.join(dataset_dst, 'images')
if not os.path.exists(image_dir):
    os.makedirs(image_dir)
cam_name = 'cam_low'
transform = transforms.Resize((48, 64))
skip_frames = 1

# Define the source points (clicked points on the original video frames)
src_points = np.array([
    [114, 101],  # Left top
    [516, 111],  # Right top
    [9, 477],    # Left bottom
    [613, 476]   # Right bottom
], dtype='float32')

# Define the destination points (you can change these as needed)
dst_width = 640  # Width of the output video
dst_height = 480  # Height of the output video

dst_points = np.array([
    [0, 0],              # Left top
    [dst_width - 1, 0],  # Right top
    [0, dst_height - 1], # Left bottom
    [dst_width - 1, dst_height - 1]  # Right bottom
], dtype='float32')

# Compute the homography matrix
homography_matrix, _ = cv2.findHomography(src_points, dst_points)

count = 0

for dataset_name in os.listdir(dataset_dir):
    if not dataset_name.endswith('.hdf5'):
        continue

    print(f'Processing {dataset_name}...')

    # Load the dataset
    qpos, qvel, action, image_dict = load_hdf5(dataset_dir, dataset_name, skip_frames=0)

    # Save the images and labels
    images = image_dict[cam_name]
    for i, image in enumerate(images):
        # Skip for every skip_frames
        if i % (skip_frames + 1) != 0:
            continue

        # Convert the NumPy array to a PIL Image and apply the transformation
                # Apply the homography transformation
        image = cv2.warpPerspective(image, homography_matrix, (dst_width, dst_height))
        image = Image.fromarray(image)
        image = transform(image)

        # Save the transformed image
        image_name = f'{dataset_name[:-5]}_{i}.png'
        image_path = os.path.join(image_dir, image_name)
        image.save(image_path)

        # Save the labels in a csv file
        label_path = os.path.join(dataset_dst, 'labels.csv')
        with open(label_path, 'a') as f:
            f.write(f'{image_name},')
            for j in range(len(qpos[i])):
                f.write(f'{qpos[i][j]},')
            f.write('\n')

    # if count == 2:
    #     break
    # count += 1

Processing episode_14.hdf5...
Processing episode_25.hdf5...
Processing episode_18.hdf5...
Processing episode_13.hdf5...
Processing episode_5.hdf5...
Processing episode_30.hdf5...
Processing episode_8.hdf5...
Processing episode_48.hdf5...
Processing episode_9.hdf5...
Processing episode_22.hdf5...
Processing episode_10.hdf5...
Processing episode_27.hdf5...
Processing episode_20.hdf5...
Processing episode_43.hdf5...
Processing episode_40.hdf5...
Processing episode_31.hdf5...
Processing episode_44.hdf5...
Processing episode_11.hdf5...
Processing episode_23.hdf5...
Processing episode_16.hdf5...
Processing episode_42.hdf5...
Processing episode_0.hdf5...
Processing episode_29.hdf5...
Processing episode_32.hdf5...
Processing episode_21.hdf5...
Processing episode_24.hdf5...
Processing episode_2.hdf5...
Processing episode_28.hdf5...
Processing episode_34.hdf5...
Processing episode_15.hdf5...
Processing episode_37.hdf5...
Processing episode_46.hdf5...
Processing episode_19.hdf5...
Processing epis