In [102]:
import pickle
import numpy as np
import torch
import matplotlib.pyplot as plt
import os
import rotations


In [103]:
def load_pickle(file_path):
    """Unpickle and return the object stored at ``file_path``.

    Any exception raised while opening or unpickling the file is reported on
    stdout and ``None`` is returned instead, so callers must check for ``None``
    before using the result.

    Warning: ``pickle.load`` can execute arbitrary code embedded in the file;
    only call this on files you created or otherwise trust.
    """
    try:
        with open(file_path, "rb") as fh:
            loaded = pickle.load(fh)
    except Exception as e:
        print(f"An error occurred: {e}")
        loaded = None
    return loaded

In [104]:
def discretize_action_to_control_mode_E2E(action):
    """Map a continuous action in [-1, 1] to a discrete control primitive.

    The action is rescaled to ``action_norm = (action + 1) / 2`` (so -1 ~ 1
    maps to 0 ~ 1) and assigned to one of six equal-width bins
    ``[k/6, (k+1)/6)``; the last bin is closed on the right so that
    ``action == 1`` is accepted.

    ===  ==========================  ==============  ============  =======
    bin  primitive                   friction_state  control_mode  pos_idx
    ===  ==========================  ==============  ============  =======
    0    slide up on right finger     1              0             0
    1    slide down on right finger   1              1             1
    2    slide up on left finger     -1              2             1
    3    slide down on left finger   -1              3             0
    4    rotate clockwise             0              4             0
    5    rotate anticlockwise         0              5             1
    ===  ==========================  ==============  ============  =======

    friction_state 1 is described as "left finger high friction"; the meaning
    of -1 and 0 is not documented here (presumably the mirrored / neutral
    settings — TODO confirm).

    Args:
        action: scalar action (Python/numpy float or single-element tensor),
            expected to lie in [-1, 1].

    Returns:
        Tuple ``(friction_state, control_mode, pos_idx)``.

    Raises:
        AssertionError: if ``action`` is outside [-1, 1] or is NaN. Raised
            explicitly (not via ``assert``) so the check survives ``python -O``.
    """
    # (friction_state, control_mode, pos_idx) for bins 0..5.
    bins = (
        (1, 0, 0),   # slide up on right finger (left finger high friction)
        (1, 1, 1),   # slide down on right finger
        (-1, 2, 1),  # slide up on left finger
        (-1, 3, 0),  # slide down on left finger
        (0, 4, 0),   # rotate clockwise
        (0, 5, 1),   # rotate anticlockwise
    )
    action_norm = (action + 1) / 2
    for k in range(5):
        # Half-open bin [k/6, (k+1)/6); thresholds are the same floats as the
        # literals 1 / 6, 2 / 6, ... so bin edges are unchanged.
        if (k + 1) / 6 > action_norm >= k / 6:
            return bins[k]
    # Everything left over must be in the closed top bin [5/6, 1]. Values below
    # 0, above 1, or NaN fail every comparison above and are rejected here;
    # with a bare ``assert`` they would silently become bin 5 under -O.
    if not (1 >= action_norm >= 5 / 6):
        raise AssertionError(f"Check: {action_norm}")
    return bins[5]

In [105]:
def plot(action, discretize):
    """Scatter-plot every column of a batch of actions against its row index.

    One subplot is drawn per action column, side by side in a 12x5 figure.
    The first column is drawn with marker size 10, the others with size 5.

    Args:
        action: 2-D array-like of actions, indexed as ``action[row, column]``
            (presumably shape ``(n_steps, n_dims)`` with n_dims 2 or 3 for the
            demonstration datasets — TODO confirm). Converted with
            ``np.asarray``, so numpy arrays, nested lists and CPU tensors work.
            The input is never modified.
        discretize: if truthy, the second column (index 1) is replaced by the
            ``control_mode`` returned by
            ``discretize_action_to_control_mode_E2E`` before plotting.

    Raises:
        ValueError: if ``action`` is not 2-D or has no columns.
    """
    actions = np.asarray(action)
    if actions.ndim != 2 or actions.shape[1] == 0:
        raise ValueError(f"expected a 2-D array of actions with >= 1 column, got shape {actions.shape}")
    n_steps, n_cols = actions.shape

    ordinals = ("First", "Second", "Third")
    steps = np.arange(n_steps)
    # squeeze=False keeps `axes` 2-D even when there is a single column.
    fig, axes = plt.subplots(1, n_cols, figsize=(12, 5), squeeze=False)

    for col, ax in enumerate(axes[0]):
        values = actions[:, col]
        if col == 1 and discretize:
            # Show which discrete primitive each continuous value selects.
            values = np.array([discretize_action_to_control_mode_E2E(v)[1] for v in values])
        ordinal = ordinals[col] if col < len(ordinals) else f"#{col + 1}"
        ax.scatter(steps, values, label=f"True Action {ordinal} Element", color="orange",
                   s=10 if col == 0 else 5)
        ax.set_title(f"{ordinal} Element of Actions")
        ax.set_xlabel("Iteration")
        ax.set_ylabel("Value")
        ax.legend()

    fig.tight_layout()
    plt.show()

In [106]:
# Directory holding the demonstration pickles. Override it with the DEMO_DIR
# environment variable so the notebook is not tied to one machine's layout.
DEMO_DIR = os.environ.get("DEMO_DIR", "/Users/qiyangyan/Desktop/Diffusion/Demonstration")
# Alternative datasets in the same directory (previously selected here):
#   VFF-bigSteps-testingDataset
#   bigSteps_10000demos_testingDataset
#   bigSteps_10000demos_slide_endIndicator_testingDataset
DATASET_NAME = "bigSteps_10000demos_slide_endIndicator"
pickle_file_path = os.path.join(DEMO_DIR, DATASET_NAME)
data = load_pickle(pickle_file_path)
# load_pickle returns None on failure; stop here instead of failing in a later
# cell with an unrelated "'NoneType' object is not subscriptable" error.
if data is None:
    raise RuntimeError(f"Could not load demonstrations from {pickle_file_path!r}")

In [107]:
# No plt.figure() here: plot() creates its own figure, so an extra one would
# only add an empty figure to the cell output.
print(len(data['actions']))  # number of action rows in the dataset
plot(data['actions'], discretize=True)

In [108]:
# Same actions with the second column left continuous (no discretization).
plot(action=data["actions"], discretize=False)