In [64]:
import numpy as np
import pickle
import matplotlib.pyplot as plt

In [65]:
def load_pickle(file_path):
    """Read the pickle file at ``file_path`` and return the deserialized object.

    The file is opened in binary mode and closed again before returning.
    Unpickling can execute arbitrary code, so only use this on trusted files.
    """
    with open(file_path, mode='rb') as handle:
        contents = pickle.load(handle)
    return contents

In [66]:
def discretize_action_to_control_mode_E2E(action):
    """Map a continuous action in [-1, 1] onto one of six discrete control modes.

    The action is first rescaled to [0, 1] via ``(action + 1) / 2`` and then
    placed into one of six equal-width bins. Each bin is closed on the left
    and open on the right, except the last one, which also includes 1.

    Mode labels (carried over from the original debug messages):
        0: slide up on right finger      3: slide down on left finger
        1: slide down on right finger    4: rotate clockwise
        2: slide up on left finger       5: rotate anticlockwise

    Args:
        action: scalar expected to lie in [-1, 1].

    Returns:
        Tuple ``(friction_state, control_mode, pos_idx)``.
        ``friction_state`` is 1 for modes 0-1 (noted in the original as
        "left finger high friction"), -1 for modes 2-3 and 0 for modes 4-5.

    Raises:
        AssertionError: if the rescaled action does not fall in any bin,
            i.e. it is outside [0, 1].
    """
    action_norm = (action + 1) / 2

    # One row per control mode 0..4: (exclusive upper bound, friction_state, pos_idx).
    # The row position is the control mode; mode 5 is handled after the loop.
    bins = (
        (1 / 6, 1, 0),
        (2 / 6, 1, 1),
        (3 / 6, -1, 1),
        (4 / 6, -1, 0),
        (5 / 6, 0, 0),
    )

    lower = 0
    for control_mode, (upper, friction_state, pos_idx) in enumerate(bins):
        if lower <= action_norm < upper:
            return friction_state, control_mode, pos_idx
        lower = upper

    # Anything left must be in the last bin; out-of-range values fail here.
    assert 1 >= action_norm >= 5 / 6, f"Check: {action_norm}"
    return 0, 5, 1

In [67]:
# Demonstration pickles to merge: one file per seed, all in the same directory.
demo_dir = "/Users/qiyangyan/Desktop/TrainingFiles/Real4/Real4/demonstration"
file_paths = [f"{demo_dir}/cubecylinder_3k5_demos_seed={seed}" for seed in range(3)]
# Previously listed alternative (disabled):
#     # "/Users/qiyangyan/Desktop/Diffusion/Demonstration/VFF-bigSteps-10000demos"

# file_paths = [
#     "/Users/qiyangyan/Desktop/TrainingFiles/Real4/Real4/demonstration/three_cylinder_3k5_demos_seed=0",
#     "/Users/qiyangyan/Desktop/TrainingFiles/Real4/Real4/demonstration/three_cylinder_3k5_demos_seed=1",
#     "/Users/qiyangyan/Desktop/TrainingFiles/Real4/Real4/demonstration/three_cylinder_3k5_demos_seed=2"
# ]

# Merge Data

In [68]:
# Concatenate the per-key sequences of every demonstration file into one
# dict of lists, printing how many terminal steps each file contributes.
merged_data = {}
terminal_counts = 0
last_goal = np.zeros(9)  # not read anywhere in this cell
for file_path in file_paths:
    data = load_pickle(file_path)

    # The first file decides which keys exist; later files must use the same keys.
    if not merged_data:
        merged_data = {key: [] for key in data}

    terminal_counts = len([item for item in data['terminals'] if item == 1])
    print(terminal_counts)

    # Assumes each value in the pickle is a list-like sequence of rows.
    for key, values in data.items():
        merged_data[key].extend(values)

3628
3508
3510


In [69]:
# Convert each merged list into a NumPy array and report the resulting shape.
for key, values in merged_data.items():
    stacked = np.array(values)
    merged_data[key] = stacked
    print(key, stacked.shape)

observations (82521, 24)
next_observations (93138, 24)
desired_goals (82521, 11)
sampled_desired_goals (82521, 11)
actions (82521, 2)
rewards (0,)
terminals (82521,)


# Preprocess the data

In [70]:
# Number of episodes (counted by terminal flags) that go into the training split.
training_episode = 10000

# Work on independent copies of the merged arrays. A plain ``dict.copy()`` is
# shallow: ``data`` and ``merged_data`` would share the same array objects, so
# later in-place edits (e.g. rewriting ``data['terminals']``) would also change
# ``merged_data`` and make re-running the preprocessing cells unreliable.
data = {key: np.array(value) for key, value in merged_data.items()}
print(sum(1 for item in data['terminals'] if item == 1))

10646


## Index episode boundaries and flag episodes to remove

In [71]:
# Record where every episode starts and ends (an episode ends on a step whose
# terminal flag is 1) and flag episodes that use control mode 4 or 5 (the
# rotation modes, per the discretizer's comments) on a non-terminal step.
num_episode = int(np.count_nonzero(np.asarray(data['terminals']) == 1))
episode_info = {
    "episode_start_idx": [0],
    "episode_end_idx": [],
    "remove_episode": np.zeros(num_episode),
}

last_step = len(data['terminals']) - 1
for i, action in enumerate(data['actions']):
    control_mode = discretize_action_to_control_mode_E2E(action[1])[1]
    is_terminal = data['terminals'][i] == 1

    if is_terminal:
        episode_info['episode_end_idx'].append(i)
        # A new episode begins right after every terminal step except the very last row.
        if i < last_step:
            episode_info['episode_start_idx'].append(i + 1)

    # On a terminal step this already points at the following episode, but the
    # flag below is only ever written on non-terminal steps, where it is the
    # index of the episode containing step i.
    episode_idx = len(episode_info['episode_start_idx']) - 1
    if control_mode in (4, 5) and not is_terminal:
        episode_info['remove_episode'][episode_idx] = 1

print(len(episode_info['episode_start_idx']))
print(len(episode_info['episode_end_idx']))

num_remove = int(np.count_nonzero(episode_info['remove_episode'] == 1))
print(num_remove)


10646
10646
4


## Remove the flagged episodes

In [72]:
# Delete the rows of every flagged episode. Episodes are visited from last to
# first so that removing one does not shift the stored indices of the earlier
# ones that are still to be processed.
flags = episode_info['remove_episode']
for i in range(len(flags) - 1, -1, -1):
    if flags[i] != 1:
        continue

    start_idx = episode_info['episode_start_idx'][i]
    end_idx = episode_info['episode_end_idx'][i]
    print("remove episode", i, start_idx, end_idx)

    # The same row range is cut from every key along axis 0.
    # NOTE(review): the shapes printed earlier show 'next_observations' with a
    # different row count than the other keys - confirm these indices line up.
    doomed_rows = slice(start_idx, end_idx + 1)
    for key in data.keys():
        data[key] = np.delete(data[key], doomed_rows, axis=0)
    print(len(data['terminals']))

remove episode 8248 63964 63975
82509
remove episode 7180 55744 55755
82497
remove episode 3653 28449 28460
82485
remove episode 1534 12007 12018
82473


## Use goal to reset terminals

In [73]:
# Rewrite the terminal flags from the goals: step i becomes terminal (1) when
# the first goal component differs between step i and step i+1, otherwise 0.
# The final step has no successor and keeps whatever flag it already had.
print(int(np.count_nonzero(data['terminals'] == 1)))

num_steps = len(data['terminals'])
goal_first = data['desired_goals'][:num_steps, 0]
goal_changes = goal_first[:-1] != goal_first[1:]
data['terminals'][:-1] = np.where(goal_changes, 1, 0)

print(int(np.count_nonzero(data['terminals'] == 1)))

10642
10613


## Remove object velocity from observation

In [74]:
# Drop observation columns 15..20 (the object velocity, per this section's
# heading), keeping columns 0..14 and 21 onwards.
#
# Guard against re-running the cell: on an already-trimmed array the original
# row-wise slicing (obs[:15] + obs[21:]) would silently cut the observations
# down to 15 columns. Fail loudly instead.
# NOTE(review): 'next_observations' is not trimmed here - confirm that is intended.
observations = np.asarray(data['observations'])
if observations.ndim != 2 or observations.shape[1] < 21:
    raise ValueError(
        "Expected 2-D observations with at least 21 columns, got shape "
        f"{observations.shape}; has this cell already been run?"
    )

new_observation = np.delete(observations, np.s_[15:21], axis=1)
data['observations'] = new_observation

## Verify

In [75]:
# Sanity check: after cleaning, control modes 4 and 5 may only appear on
# terminal steps. Any offending step is printed as (index, mode, raw action).
for i, action in enumerate(data['actions']):
    control_mode = discretize_action_to_control_mode_E2E(action[1])[1]
    is_rotation = control_mode in (4, 5)
    if is_rotation and data['terminals'][i] != 1:
        print(i, control_mode, action[1])
print("It should print nothing")

# Final shape of every key, followed by the source files.
for key, values in data.items():
    print(key, np.shape(values))

print("Data: ")
print(file_paths)

It should print nothing
observations (82473, 18)
next_observations (93090, 24)
desired_goals (82473, 11)
sampled_desired_goals (82473, 11)
actions (82473, 2)
rewards (0,)
terminals (82473,)
Data: 
['/Users/qiyangyan/Desktop/TrainingFiles/Real4/Real4/demonstration/cubecylinder_3k5_demos_seed=0', '/Users/qiyangyan/Desktop/TrainingFiles/Real4/Real4/demonstration/cubecylinder_3k5_demos_seed=1', '/Users/qiyangyan/Desktop/TrainingFiles/Real4/Real4/demonstration/cubecylinder_3k5_demos_seed=2']


# Split training and testing data

In [76]:
# Split the data right after the terminal step that closes the
# `training_episode`-th episode: rows up to and including that step form the
# training set, everything after it the test set.
terminal_indices = np.flatnonzero(np.asarray(data['terminals']) == 1)
if not 1 <= training_episode <= len(terminal_indices):
    # Previously this case fell through silently and left both splits as
    # dicts of empty lists.
    raise ValueError(
        f"training_episode={training_episode} must be between 1 and the number "
        f"of available episodes ({len(terminal_indices)})"
    )
split_idx = int(terminal_indices[training_episode - 1]) + 1

# Every key is cut at the same row index along its first axis.
# NOTE(review): the shapes printed earlier show 'next_observations' with a
# different row count than the other keys - confirm this index lines up for it.
data_train = {}
data_test = {}
for key in data.keys():
    data_train[key] = data[key][:split_idx]
    data_test[key] = data[key][split_idx:]
    print(key, np.shape(data_train[key]))

# Episode counts of the two splits (training first, then test).
terminal_counts = int(np.count_nonzero(np.asarray(data_train['terminals']) == 1))
print(terminal_counts)

terminal_counts = int(np.count_nonzero(np.asarray(data_test['terminals']) == 1))
print(terminal_counts)

# file_path = f'/Users/qiyangyan/Desktop/Diffusion/Demo_random/train_10k_cube_cylinder_noObjVel.pkl'
# with open(file_path, 'wb') as file:
#     pickle.dump(data_train, file)
# 
# file_path = '/Users/qiyangyan/Desktop/Diffusion/Demo_random/test_10k_cube_cylinder_noObjVel.pkl'
# with open(file_path, 'wb') as file:
#     pickle.dump(data_test, file)
# 
# print(f"Data saved to {file_path}")

observations (77768, 18)
next_observations (77768, 24)
desired_goals (77768, 11)
sampled_desired_goals (77768, 11)
actions (77768, 2)
rewards (0,)
terminals (77768,)
10000
613
Data saved to /Users/qiyangyan/Desktop/Diffusion/Demo_random/test_10k_cube_cylinder_noObjVel.pkl
