In [1]:
import cv2
import os
import h5py
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import savgol_filter


# Time-step constant; it is not referenced by any cell visible in this notebook.
# NOTE(review): presumably seconds per recorded step — confirm.
DT=0.02
# Names of the six arm joints, in the order used to label qpos columns in the plots.
JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
# Per-arm state layout: the joints followed by the gripper. The plotting code
# labels columns as the left arm's states followed by the right arm's.
STATE_NAMES = JOINT_NAMES + ["gripper"]
# HDF5 episode file that the later cells load and segment.
EVAL_FILE = "episode_0.hdf5"

def load_hdf5(dataset_dir, dataset_name, skip_frames=0):
    """Load one recorded episode from an HDF5 file, optionally subsampling frames.

    Args:
        dataset_dir: Directory containing the file ("" for the working directory).
        dataset_name: File name of the episode, including its extension.
        skip_frames: Number of frames dropped between two kept frames; every
            (skip_frames + 1)-th frame is returned. Must be >= 0.

    Returns:
        A tuple (qpos, qvel, action, image_dict). The first three are the
        subsampled '/observations/qpos', '/observations/qvel' and '/action'
        datasets. image_dict maps each camera name under
        '/observations/images/' to its subsampled frames; when the file's
        'compress' attribute is truthy, each camera's frames are returned as
        a list of images decoded with cv2.imdecode.

    Raises:
        FileNotFoundError: If the dataset file does not exist.
        ValueError: If skip_frames is negative.
    """
    dataset_path = os.path.join(dataset_dir, dataset_name)
    if not os.path.isfile(dataset_path):
        # Raise instead of calling exit(): inside a notebook exit() shuts the
        # kernel down, and callers cannot catch it as an ordinary error.
        raise FileNotFoundError(f'Dataset does not exist at \n{dataset_path}\n')
    if skip_frames < 0:
        # A negative value would produce a zero or negative slice step.
        raise ValueError(f'skip_frames must be >= 0, got {skip_frames}')

    step = skip_frames + 1
    with h5py.File(dataset_path, 'r') as root:
        compressed = root.attrs.get('compress', False)

        # Read each dataset fully, then keep every `step`-th frame.
        qpos = root['/observations/qpos'][()][::step]
        qvel = root['/observations/qvel'][()][::step]
        action = root['/action'][()][::step]

        image_dict = {
            cam_name: root[f'/observations/images/{cam_name}'][()][::step]
            for cam_name in root['/observations/images/'].keys()
        }

    if compressed:
        # Each stored frame is an encoded buffer; it is handed to the decoder
        # as stored. Flag 1 requests a 3-channel colour image.
        image_dict = {
            cam_name: [cv2.imdecode(encoded_frame, 1) for encoded_frame in encoded_frames]
            for cam_name, encoded_frames in image_dict.items()
        }

    return qpos, qvel, action, image_dict

def process_dataset(dataset_dir, dataset_name, ismirror):
    """Load one episode through load_hdf5 and discard it; returns None.

    This is a placeholder: the loaded arrays are not used and `ismirror`
    is accepted but ignored.
    """
    episode = load_hdf5(dataset_dir, dataset_name)
    _qpos, _qvel, _action, _image_dict = episode

def visualize_joints(qpos_list, command_list, plot_path=None, ylim=None, label_overwrite=None):
    """Plot recorded joint states against commands, one subplot per state dimension.

    Args:
        qpos_list: Array-like of shape (timesteps, dims) with the recorded states.
        command_list: Array-like whose column d is drawn on top of state column d.
        plot_path: File the figure is saved to. When None the figure is shown
            with plt.show() instead of being saved.
        ylim: Optional y-limits applied to every subplot via Axes.set_ylim.
        label_overwrite: Optional (state_label, command_label) pair replacing
            the default legend labels 'State' and 'Command'.
    """
    if label_overwrite:
        label1, label2 = label_overwrite
    else:
        label1, label2 = 'State', 'Command'

    qpos = np.array(qpos_list) # ts, dim
    command = np.array(command_list)
    _, num_dim = qpos.shape
    h, w = 2, num_dim
    # squeeze=False keeps `axs` two-dimensional even when there is a single
    # subplot, so indexing works for any number of dimensions.
    fig, axs = plt.subplots(num_dim, 1, figsize=(w, h * num_dim), squeeze=False)
    axes = axs[:, 0]

    all_names = [name + '_left' for name in STATE_NAMES] + [name + '_right' for name in STATE_NAMES]
    for dim_idx, ax in enumerate(axes):
        # Fall back to a generic name when there are more columns than known names.
        dim_name = all_names[dim_idx] if dim_idx < len(all_names) else f'dim_{dim_idx}'
        ax.plot(qpos[:, dim_idx], label=label1)
        ax.plot(command[:, dim_idx], label=label2)
        ax.set_title(f'Joint {dim_idx}: {dim_name}')
        ax.legend()
        if ylim:
            ax.set_ylim(ylim)

    plt.tight_layout()
    if plot_path is None:
        plt.show()
    else:
        plt.savefig(plot_path)
        print(f'Saved qpos plot to: {plot_path}')
    # Close this specific figure rather than whichever one happens to be current.
    plt.close(fig)

In [2]:
'''
This cell works well enough, but there are some false negatives for gripping, for example, example_0.
Perhaps consider using the area under the curve as a metric for gripping, and thresholding on that.
'''

import numpy as np
import os
from scipy.signal import savgol_filter

# Magnitude that the smoothed frame-to-frame change of a gripper's qpos must
# exceed to be reported as a release (positive change) or a grab (negative change).
last_grab_release_thres = 0.012
window_size = 20  # Frame window for peak/dip filtering

def find_significant_peaks(signal, threshold, window_size):
    """Return indices of windowed maxima where `signal` rises above `threshold`.

    The signal is scanned left to right. At the first sample greater than
    `threshold`, the largest value among the next `window_size` samples
    (starting at that sample, clipped to the end of the signal) is recorded,
    and scanning resumes after that window so one event is reported once.
    """
    total = len(signal)
    found = []
    pos = 0
    while pos < total:
        if not signal[pos] > threshold:
            pos += 1
            continue
        stop = min(pos + window_size, total)
        found.append(pos + np.argmax(signal[pos:stop]))
        pos = stop  # jump past the window that was just consumed
    return found

def find_significant_dips(signal, threshold, window_size):
    """Return indices of windowed minima where `signal` drops below `-threshold`.

    Scanning left to right, the first sample below `-threshold` opens a window
    of up to `window_size` samples; the index of the smallest value in that
    window is recorded and the rest of the window is skipped.
    """
    total = len(signal)
    located = []
    resume_at = 0
    for idx in range(total):
        if idx < resume_at:
            # Still inside a window that has already produced a dip.
            continue
        if not signal[idx] < -threshold:
            continue
        resume_at = min(idx + window_size, total)
        located.append(idx + np.argmin(signal[idx:resume_at]))
    return located

# Containers that this cell creates but does not fill.
odd_right = []
odd_left = []

split_indices = {}
DT_f = 7 # Frames added after each detected event when placing a segment boundary


def _first_event_after(events, start, description):
    """Return the first entry of `events` that is strictly greater than `start`.

    Raises:
        ValueError: If no entry qualifies. Failing here keeps later code from
            running with an undefined boundary on a fresh kernel, or with a
            stale one left over from a previous run of the cell.
    """
    for event in events:
        if event > start:
            return event
    raise ValueError(
        f"No {description} found after frame {start}; detected events: {list(events)}"
    )


qpos, qvel, action, image_dict = load_hdf5("", EVAL_FILE, skip_frames=2)

# qpos columns hold the left arm's states followed by the right arm's.
left_gripper_idx = STATE_NAMES.index("gripper")
right_gripper_idx = left_gripper_idx + len(STATE_NAMES)
qpos_right = np.array(qpos[:, right_gripper_idx])  # ts
qpos_left = np.array(qpos[:, left_gripper_idx])  # ts

# Frame-to-frame change of each gripper position (one element shorter than qpos).
qpos_diff_right = np.diff(qpos_right, axis=0)
qpos_diff_left = np.diff(qpos_left, axis=0)

# Apply Savitzky-Golay filter to smooth the differentials
qpos_diff_smoothed_right = savgol_filter(qpos_diff_right, window_length=5, polyorder=2, axis=0)
qpos_diff_smoothed_left = savgol_filter(qpos_diff_left, window_length=5, polyorder=2, axis=0)

# Positive spikes are treated as releases, negative spikes as grabs.
right_sharp_increases_release = find_significant_peaks(qpos_diff_smoothed_right, last_grab_release_thres, window_size)
right_sharp_decreases_grab = find_significant_dips(qpos_diff_smoothed_right, last_grab_release_thres, window_size)
left_sharp_increases_release = find_significant_peaks(qpos_diff_smoothed_left, last_grab_release_thres, window_size)
left_sharp_decreases_grab = find_significant_dips(qpos_diff_smoothed_left, last_grab_release_thres, window_size)

if not right_sharp_decreases_grab:
    raise ValueError("No right-hand grab detected; cannot segment the episode")

# First Grab
first_end = right_sharp_decreases_grab[0] + DT_f

# First release after first grab
second_end = _first_event_after(
    right_sharp_increases_release, first_end, "right-hand release") + DT_f

# Left Hand First Grab after Right Hand First Release (second_end)
left_first_grab = _first_event_after(
    left_sharp_decreases_grab, second_end, "left-hand grab") + DT_f

# Left Hand First Release after Left Hand First Grab
left_first_release = _first_event_after(
    left_sharp_increases_release, left_first_grab, "left-hand release") + DT_f

# Right Hand Last Grab after Left Hand Release
right_hand_last_grab = right_sharp_decreases_grab[-1] + DT_f
if right_hand_last_grab < left_first_release:
    print("Error: Right Hand Last Grab before Left Hand Release")
    print("Significant Sharp Decreases Grab:", right_sharp_decreases_grab)

# Right Hand Last Release (the list is non-empty: second_end was found in it)
right_hand_last_release = right_sharp_increases_release[-1] + DT_f
if right_hand_last_release < right_hand_last_grab:
    print("Error: Right Hand Last Release before Right Hand Last Grab")
    print("Significant Sharp Increases Release:", right_sharp_increases_release)

# Consecutive (start, stop) frame ranges; the last stop is clamped to the
# length of the smoothed differential.
first_seg = (0, first_end)
second_seg = (first_end, second_end)
third_seg = (second_end, left_first_grab)
fourth_seg = (left_first_grab, left_first_release)
fifth_seg = (left_first_release, right_hand_last_grab)
sixth_seg = (right_hand_last_grab, min(right_hand_last_release, qpos_diff_smoothed_right.shape[0]))
print("All Segments:", first_seg, second_seg, third_seg, fourth_seg, fifth_seg, sixth_seg)


All Segments: (0, np.int64(145)) (np.int64(145), np.int64(250)) (np.int64(250), np.int64(402)) (np.int64(402), np.int64(497)) (np.int64(497), np.int64(686)) (np.int64(686), 733)


In [3]:
# Load the evaluation episode with the same arguments used for segmentation
# (every third frame), so the segment boundaries index into this qpos array.
qpos, qvel, action, image_dict = load_hdf5(dataset_dir="", dataset_name=EVAL_FILE, skip_frames=2)

In [4]:
# Cut qpos into one array per segment; each *_seg is a (start, stop) frame pair.
(first_qpos, second_qpos, third_qpos,
 fourth_qpos, fifth_qpos, sixth_qpos) = [
    qpos[start:stop]
    for start, stop in (first_seg, second_seg, third_seg, fourth_seg, fifth_seg, sixth_seg)
]


In [25]:
# Set up the inference model and the list of segments to evaluate.
import importlib
from visionik import infer
# Reload the module first so that the class imported on the next line comes
# from the current source of visionik.infer without restarting the kernel.
importlib.reload(infer)
from visionik.infer import VisionIKInference

# Number of frames sampled from each segment when comparing against inference.
sample_per_seq = 7
# NOTE(review): machine-specific absolute path; this cell only runs where the checkpoint exists at this location.
model_checkpoint_path = "/home/weixun/testing/vision-ik/testing/models/truncated_shufflenet/more_distortion_less_freq/checkpoint_epoch_444_step_880000.pth"
model = VisionIKInference(model_checkpoint_path)
# Pairs of (segment qpos, GIF path); the GIF path is what gets passed to model.infer_gif.
tuple_list = [(first_qpos, "first_seg_first_frame_out.gif"), 
              (second_qpos, "second_seg_first_frame_out.gif"),
                (third_qpos, "third_seg_first_frame_out.gif"),
                (fourth_qpos, "fourth_seg_first_frame_out.gif"),
                (fifth_qpos, "fifth_seg_first_frame_out.gif"),
                (sixth_qpos, "sixth_seg_first_frame_out.gif")
              ]

def infer_and_visualize(qpos_list, gif_path, plot_dir="./eval_plots/", ylim=(-1, 1), label_overwrite=None):
    """Run inference on a segment GIF and plot the result against the recorded qpos.

    Uses the notebook-level names `model` (its `infer_gif` method is called
    with `gif_path`) and `sample_per_seq` (number of frames sampled from the
    segment for the comparison).

    Args:
        qpos_list: Array-like of shape (timesteps, dims) with the segment's states.
        gif_path: Path of the GIF passed to `model.infer_gif`. The segment name
            used in the title and output file is the GIF's file name with a
            trailing "_first_frame_out.gif" removed, or with its extension
            removed when that suffix is absent.
        plot_dir: Directory the PNG is written to; created if missing.
        ylim: Optional y-limits applied to every subplot.
        label_overwrite: Optional (state_label, inference_label) pair replacing
            the default legend labels 'State' and 'Inference'.

    Raises:
        ValueError: If `qpos_list` contains no timesteps.
    """
    os.makedirs(plot_dir, exist_ok=True)

    # Derive the segment name from the file name itself instead of slicing a
    # fixed number of characters off the path.
    gif_suffix = "_first_frame_out.gif"
    gif_name = os.path.basename(gif_path)
    if gif_name.endswith(gif_suffix):
        seg_name = gif_name[:-len(gif_suffix)]
    else:
        seg_name = os.path.splitext(gif_name)[0]

    # Convert up front so list-of-lists input supports the index sampling below.
    qpos = np.array(qpos_list) # ts, dim
    num_ts, num_dim = qpos.shape
    if num_ts == 0:
        raise ValueError(f"Cannot sample frames from an empty segment ({seg_name})")

    # Uniformly sample `sample_per_seq` frame indices; the last frame is always included.
    t_axis = [int(i * (num_ts - 1) / (sample_per_seq - 1)) for i in range(sample_per_seq - 1)]
    t_axis.append(num_ts - 1)
    qpos_y = qpos[t_axis]

    if label_overwrite:
        label1, label2 = label_overwrite
    else:
        label1, label2 = 'State', 'Inference'
    # Distinct legend entry so the sampled points are not confused with the full trace.
    sampled_label = f'{label1} (sampled)'

    # Performing inference on the GIF
    infer_outputs = np.array(model.infer_gif(gif_path))

    h, w = 2, num_dim
    # squeeze=False keeps `axs` two-dimensional even with a single subplot.
    fig, axs = plt.subplots(num_dim, 1, figsize=(w, h * num_dim), squeeze=False)

    all_names = [name + '_left' for name in STATE_NAMES] + [name + '_right' for name in STATE_NAMES]
    for dim_idx, ax in enumerate(axs[:, 0]):
        # Fall back to a generic name when there are more columns than known names.
        dim_name = all_names[dim_idx] if dim_idx < len(all_names) else f'dim_{dim_idx}'
        ax.plot(qpos[:, dim_idx], label=label1)
        ax.scatter(t_axis, qpos_y[:, dim_idx], label=sampled_label)
        ax.scatter(t_axis, infer_outputs[:, dim_idx], label=label2)
        ax.set_title(f'Joint {dim_idx}: {dim_name}')
        ax.legend()
        if ylim:
            ax.set_ylim(ylim)

    # Adding the super title
    fig.suptitle(f'Inference Results for {seg_name}', fontsize=16, fontweight='bold')
    plt.tight_layout(rect=[0, 0, 1, 0.98])  # Adjust the rect to fit the suptitle

    plot_path = os.path.join(plot_dir, f"infer_{seg_name}.png")
    plt.savefig(plot_path)
    print(f'Saved qpos plot to: {plot_path}')
    # Close this specific figure rather than whichever one happens to be current.
    plt.close(fig)

# Run inference and save a comparison plot for every (segment qpos, GIF path) pair.
for seg, infer_gif in tuple_list:
    infer_and_visualize(qpos_list=seg, gif_path=infer_gif)

Saved qpos plot to: ./eval_plots/infer_first_seg.png
Saved qpos plot to: ./eval_plots/infer_second_seg.png
Saved qpos plot to: ./eval_plots/infer_third_seg.png
Saved qpos plot to: ./eval_plots/infer_fourth_seg.png
Saved qpos plot to: ./eval_plots/infer_fifth_seg.png
Saved qpos plot to: ./eval_plots/infer_sixth_seg.png
