In [4]:
import numpy as np
!pip install scipy
import scipy.interpolate as interpolate
from scipy.spatial.transform import Slerp
from scipy.spatial.transform import Rotation as R



In [5]:
env = ['duck_lift', 'banana_pass1', 'hammer_strike'][-1]
# data = np.load(f'{env}.npz', allow_pickle=True,)
# entries = data.files
# print(entries)
def load_motion(motion_file):
    motion_file = np.load(motion_file, allow_pickle=True)
    reference_motion =  {k:v for k, v in motion_file.items()}
    reference_motion['s_0'] = reference_motion['s_0'][()]
    return reference_motion
data = load_motion(f'{env}.npz')
data.keys()

dict_keys(['length', 'SIM_SUBSTEPS', 'DATA_SUBSTEPS', 'object_translation', 'object_orientation', 's_0'])

In [6]:
def data_analysis(data):
    print('range of data: ', np.min(data, axis=0), np.max(data, axis=0))
    print('start and end: ', data[0], data[-1])

In [7]:
print(data['object_translation'].shape)
trans_data = data['object_translation']
data_analysis(trans_data)

(59, 3)
range of data:  [-0.13407614 -0.20917316  0.04889371] [ 0.07301792 -0.11842044  0.22849914]
start and end:  [-0.12142183 -0.20917316  0.04889371] [-0.08168487 -0.15548259  0.18643826]


In [8]:
print(data['object_orientation'].shape)
ori_data = data['object_orientation']
data_analysis(ori_data)

(59, 4)
range of data:  [ 0.55069668  0.70205983 -0.30280827 -0.21246096] [ 0.70803234  0.81144097 -0.04634558  0.04666741]
start and end:  [ 0.70803234  0.70205983 -0.06042589  0.04637796] [ 0.61398888  0.7786923  -0.11786583 -0.05256997]


In [9]:
def interpolate_data(data, random_sample=5):
    points = np.random.uniform(np.min(data), np.max(data), size=random_sample)
    points = np.concatenate([[data[0]], points])
    len_points = points.shape[0]
    f = interpolate.interp1d(np.arange(len_points), points, kind='quadratic', fill_value="extrapolate")  # ‘linear’, ‘nearest’, ‘nearest-up’, ‘zero’, ‘slinear’, ‘quadratic’, ‘cubic’, ‘previous’, or ‘next’. 

    x_new = np.arange(0, len_points-1, (len_points-1)/data.shape[0])
    y_new = f(x_new)
    y_new[-1] = points[-1]
    return y_new

def interpolate_quat(data):
    random_quat = np.random.uniform(0, 1, size=(1,4))
    random_rot = R.from_quat(random_quat)
    # what if final quat is not feasible (s.t. object intercepting the table)?

    init_rot = R.from_quat(np.hstack((data[0,-1], data[0,:3])))
    quats = R.concatenate([init_rot, random_rot])
    len_quats = 2
    f = Slerp(np.arange(len_quats), quats)

    x_new = np.arange(0, len_quats-1, (len_quats-1)/data.shape[0])
    y_new = f(x_new)
    # y_new[-1] = len_quats[-1]
    return y_new


In [13]:
new_trans_data = []
for d in trans_data.T:  # each dimension
    new_trans_data.append(interpolate_data(d))
new_trans_data = np.array(new_trans_data).T

new_ori_data = []
# for d in ori_data.T:
#     new_ori_data.append(interpolate_data(d))
# new_ori_data = np.array(new_ori_data).T
new_rot = interpolate_quat(ori_data)
for rot in new_rot:
    new_ori_data += [rot.as_quat()]
new_ori_data = np.vstack(new_ori_data)

In [14]:
import copy
save_dir = f'generated_trajs_quat_slerp/{env}'
os.makedirs(save_dir, exist_ok=True)
new_traj = copy.copy(dict(data))

new_traj['object_translation'] = new_trans_data
new_traj['object_orientation'] = new_ori_data
print(new_traj.keys())
np.savez(save_dir+f'/{env}', **new_traj)  # save a dict as npz

dict_keys(['length', 'SIM_SUBSTEPS', 'DATA_SUBSTEPS', 'object_translation', 'object_orientation', 's_0'])


In [15]:
def interpolate_trans(data, initial_point=None, random_sample=5):
    points = np.random.uniform(np.min(data), np.max(data), size=random_sample)
    if initial_point is not None: # set the same initial point as original traj
        points[0] = initial_point
    else:
        points[0] = data[0]
    len_points = points.shape[0]
    f = interpolate.interp1d(np.arange(len_points), points, kind='quadratic', fill_value="extrapolate")  # ‘linear’, ‘nearest’, ‘nearest-up’, ‘zero’, ‘slinear’, ‘quadratic’, ‘cubic’, ‘previous’, or ‘next’. 

    x_new = np.arange(0, len_points-1, (len_points-1)/data.shape[0])
    y_new = f(x_new)
    y_new[-1] = points[-1]
    return y_new

def interpolate_quat(data):
    random_quat = np.random.uniform(0, 1, size=(1,4))
    random_rot = R.from_quat(random_quat)
    # what if final quat is not feasible (s.t. object intercepting the table)?

    init_rot = R.from_quat(np.hstack((data[0,-1], data[0,:3])))
    quats = R.concatenate([init_rot, random_rot])
    len_quats = 2
    f = Slerp(np.arange(len_quats), quats)

    x_new = np.arange(0, len_quats-1, (len_quats-1)/data.shape[0])
    y_new = f(x_new)
    # y_new[-1] = len_quats[-1]
    return y_new


In [17]:
# loop over all envs
import sys
import copy
sys.path.append("..")
from tcdm.common import ENVS

for env in ENVS:
    # try:
    data = load_motion(f'{env}.npz')
    trans_data = data['object_translation']
    ori_data = data['object_orientation']

    new_trans_data = []
    for i, d in enumerate(trans_data.T):
        if data['s_0']['motion_planned']['position'].shape[0] == 36: # 30 hand + 3 pos + 3 ori of object
            initial_point = data['s_0']['motion_planned']['position'][30+i] # set the original position as initial sampled point position
        else:
            initial_point = None
        new_trans_data.append(interpolate_trans(d, initial_point))
    new_trans_data = np.array(new_trans_data).T

    new_ori_data = []
    new_rot = interpolate_quat(ori_data)
    for rot in new_rot:
        new_ori_data += [rot.as_quat()]
    new_ori_data = np.array(new_ori_data)

    print(f'{env}: {new_trans_data.shape}, {new_ori_data.shape}')

    save_dir = f'generated_trajs_quat_slerp/{env}'
    os.makedirs(save_dir, exist_ok=True)
    new_traj = copy.copy(dict(data))

    new_traj['object_translation'] = new_trans_data
    new_traj['object_orientation'] = new_ori_data
    new_traj['SIM_SUBSTEPS'] = int(data['SIM_SUBSTEPS']/3) # to adapt to current simulator
    print('substeps: ', new_traj['SIM_SUBSTEPS'])
    np.savez(save_dir+f'/{env}', **new_traj)  # save a dict as npz
    # except:
    #     print(f'{env} is not working, shape: {trans_data.shape}')


AttributeError: 'list' object has no attribute 'shape'