In [108]:
import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt 
import glob 
import os 
import pickle 
from scipy.spatial.transform import Rotation as R 
from simpleicp import PointCloud, SimpleICP
from scipy.spatial import KDTree 

In [2]:
# convert pose to transformation matrix 
def pose2transform(pose):
    T = np.eye(4) 
    T[:3, :3] = R.from_quat(pose[3:], scalar_first=True).as_matrix() 
    T[:3, 3] = pose[:3] 
    return T 

# convert transformation matrix to pose 
def transform2pose(T):
    pose = np.zeros(7) 
    pose[:3] = T[:3, 3] 
    pose[3:] = R.from_matrix(T[:3, :3]).as_quat(scalar_first=True) 
    return pose 

def pose6D2transform(pose):
    T = np.eye(4) 
    T[:3, :3] = R.from_euler('xyz', pose[3:], degrees=True).as_matrix() 
    T[:3, 3] = pose[:3] 
    return T 

def transform2pose6D(T):   
    pose = np.zeros(6) 
    pose[:3] = T[:3, 3] 
    pose[3:] = R.from_matrix(T[:3, :3]).as_euler('xyz', degrees=True) 
    return pose 

# transform array of poses by a given pose 
def transform_poses(poses, delta_pose): 
    T_delta = pose2transform(delta_pose) 
    transformed_poses = np.zeros_like(poses) 
    for i in range(poses.shape[0]):
        transformed_poses[i] = transform2pose(np.linalg.inv(T_delta) @ pose2transform(poses[i])) # T_P0Pi = T_P0H @ T_HPi  
    return transformed_poses 

def transform_poses6D(poses, delta_pose): 
    T_delta = pose6D2transform(delta_pose) 
    transformed_poses = np.zeros_like(poses) 
    for i in range(poses.shape[0]):
        transformed_poses[i] = transform2pose6D(np.linalg.inv(T_delta) @ pose6D2transform(poses[i])) # T_P0Pi = T_P0H @ T_HPi  
    return transformed_poses 

In [3]:
dir_results = "/media/rp/Elements/abhay_ws/mujoco_contact_graph_generation/results/cross_peg_data_perturb" 
dir_pkl = dir_results + "/pkl" 
pkl_files = sorted(glob.glob(os.path.join(dir_pkl, "*.pkl")), key=os.path.getmtime)

dir_save = dir_pkl.removesuffix("pkl") + "processed_data"
if not os.path.exists(dir_save): 
    os.makedirs(dir_save)

# list of all contact state history 
N_timesteps = 500 
N_trials_max = 10_000 
N_trials = len(pkl_files) if len(pkl_files) < N_trials_max else N_trials_max
pkl_files = pkl_files[:N_trials] 
contact_poses_list = [] 
contact_delta_poses_list = [] 
trials_pose_0_list = [] 

pose_indices = ['x', 'y', 'z', 'a', 'b', 'c']

for i, pkl_file in enumerate(pkl_files): 
    pose_boundary_list = [] 

    # Read the pickle file
    with open(pkl_file, 'rb') as f:
        data = pickle.load(f)

    # unpack data 
    state_hist = data['state_hist'] 
    contact_pos = data['contact_pos'] 

    peg_pose_0 = state_hist[0, 1:8] 
    trials_pose_0_list.append(peg_pose_0) 

    # check if there is contact within the hole area, if so, add the pose to the list 
    for j, contact_pos_j in enumerate(contact_pos): # iterate through each time step 
        if len(contact_pos_j) > 0: # if there is contact 
            for k, contact_pos_hole_frame in enumerate(contact_pos_j): # iterate through each contact at current time step 
                if contact_pos_hole_frame[2] < 0 and max(contact_pos_hole_frame[:2]) < 0.012: # if contact is below the surface and within hole area 
                    peg_pose = state_hist[j, 1:8] 
                    pose_boundary_list.append(peg_pose) 
                    continue # don't need to check pose again 

    # print progress rate every 10% of total iterations 
    if (i+1) % max(np.floor(len(pkl_files)/10),1) == 0: 
        print(f"Completion Progress: {i+1}/{len(pkl_files)}")  

    # compute delta poses 
    contact_poses = np.array(pose_boundary_list) 
    contact_delta_poses = transform_poses(np.array(pose_boundary_list), peg_pose_0) 

    # convert list to dataframe 
    contact_poses_df = pd.DataFrame(contact_poses, columns=['x', 'y', 'z', 'qw', 'qx', 'qy', 'qz']) 
    contact_delta_poses_df = pd.DataFrame(contact_delta_poses, columns=['x', 'y', 'z', 'qw', 'qx', 'qy', 'qz']) 

    # convert quaternion to euler angles 
    quaternions = contact_poses_df[['qw', 'qx', 'qy', 'qz']].values 
    euler_angles = R.from_quat(quaternions, scalar_first=True).as_euler("xyz", degrees=True) 
    contact_poses_df['a'] = euler_angles[:,2]
    contact_poses_df['b'] = euler_angles[:,1]
    contact_poses_df['c'] = euler_angles[:,0] 

    quaternions = contact_delta_poses_df[['qw', 'qx', 'qy', 'qz']].values 
    euler_angles = R.from_quat(quaternions, scalar_first=True).as_euler("xyz", degrees=True) 
    contact_delta_poses_df['a'] = euler_angles[:,2]
    contact_delta_poses_df['b'] = euler_angles[:,1]
    contact_delta_poses_df['c'] = euler_angles[:,0] 

    # convert position from meters to millimeters 
    contact_poses_df[['x', 'y', 'z']] *= 1000 
    contact_delta_poses_df[['x', 'y', 'z']] *= 1000 

    # append to list of data of all trials 
    contact_poses_list.append(contact_poses_df)
    contact_delta_poses_list.append(contact_delta_poses_df)
    peg_pose_0 = contact_poses_list[0][pose_indices].values[0]  
    trials_pose_0_list.append(peg_pose_0)

# save the dataframe 
# if not os.path.exists(dir_save): 
#     os.makedirs(dir_save)
# contact_delta_poses_df.to_csv(os.path.join(dir_save, "pose_boundary_data_10k.csv"), index=False)  

Completion Progress: 1/1


In [4]:
# read contact map 
# dir_results = "/media/rp/Elements/abhay_ws/mujoco_contact_graph_generation/results/cross_peg_data_v2" 
# dir_pkl = dir_results + "/pkl" 
# dir_map = dir_pkl.removesuffix("pkl") + "processed_data"
# df_map = pd.read_csv(os.path.join(dir_map, "pose_boundary_data.csv")) 
# map = df_map[pose_indices].values 

# map = contact_poses_list[0][pose_indices].values # test ICP algorithm by transforming trial data to create map 

# # for each trial go through contact list and perform ICP to get the pose of the peg 
# for contact_delta_poses_df in contact_delta_poses_list: 
#     observation = contact_delta_poses_df[pose_indices].values 


In [114]:
# optimization based registration 
# p_true = peg_pose_0 
# initial_pose_est = np.zeros(6) 

map = contact_delta_poses_df[pose_indices].values 
observation = transform_poses6D(map, peg_pose_0) 

p_true = transform2pose6D(np.linalg.inv(pose6D2transform(peg_pose_0)))
initial_pose_est = p_true + np.random.normal(0, 0.25, 6)
max_iter = 10_000 
alpha = 1e-6
error_threshold = 1e-1 
error_hist = np.zeros((max_iter,6)) 
flag_compute_local_grad = True 
p_perturb_sigma = np.ones(6) * 1e-7 
flag_kdtree = True 

# downsample to increase speed 
downsample_factor = 10   
observation = observation[::downsample_factor]    

for i in range(max_iter): 

    # compute transformed observations 
    if i == 0: 
        p = initial_pose_est 
    else: 
        # compute update 
        if i > 1: 
            if flag_compute_local_grad: 
                p_perturb = p + np.random.normal(0, p_perturb_sigma, 6) 
                observations_transformed_perturb = transform_poses6D(observation, p_perturb) 
                closest_points = [] 
                if flag_kdtree: 
                    tree = KDTree(map) 
                    closest_points = map[tree.query(observations_transformed_perturb)[1]] 
                else: 
                    for obs in observations_transformed_perturb: 
                        closest_point = map[np.argmin(np.linalg.norm(map - obs, axis=1))] 
                        closest_points.append(closest_point) 
                error_perturb = np.mean((observations_transformed_perturb - np.array(closest_points)), axis=0) 
                J_est = (error_perturb - error) / (p_perturb - p) 
            else: 
                J_est = (error - error_prev) / (p - p_prev) 
        else: 
            J_est = np.ones(6) * 1e-6 
        
        error_prev = error 
        p_prev = p 

        update = - alpha * J_est 
        p = p_prev + update  
    observations_transformed = transform_poses6D(observation, p) 

    # for each observation, find the closest point in the map 
    if flag_kdtree: 
        tree = KDTree(map) 
        closest_points = map[tree.query(observations_transformed)[1]] 
    else: 
        closest_points = [] 
        for obs in observations_transformed: 
            closest_point = map[np.argmin(np.linalg.norm(map - obs, axis=1))] 
            closest_points.append(closest_point) 
    
    # compute the total distance error 
    error = np.mean((observations_transformed - closest_points), axis=0)
    error_hist[i:,:] = np.mean(np.abs(observations_transformed - np.array(closest_points)), axis=0) 

    # check for convergence 
    if np.mean(np.abs(error)) < error_threshold: 
        print(f"Converged at iteration {i}: Error: {error}") 
        break 

    # print progress every 1% of total iterations 
    if (i+1) % max(np.floor(max_iter/100),1) == 0 or i == 0:  
        print(f"Iteration: {i}, Error: {error}, Mean. Abs. Error: {np.mean(np.abs(error),axis=0)}") 

# print pose estimate 
print(f"True Pose: {p_true}") 
print(f"Pose Estimate: {p}") 
print(f"Initial Error: {initial_pose_est-p_true}")
print(f"Error: {p-p_true}") 
print(f"Mean Initial Error: {np.mean(initial_pose_est-p_true)}")
print(f"Mean Error: {np.mean(p-p_true)}") 



Iteration: 0, Error: [-0.35098071 -0.04676979 -0.15071153 -0.0272922  -0.05451026 -0.04520291], Mean. Abs. Error: 0.11257790003771379
Iteration: 99, Error: [-0.351236   -0.04815228 -0.15114508 -0.02737982 -0.05471825 -0.0453088 ], Mean. Abs. Error: 0.1129900382287587
Iteration: 199, Error: [-0.35110421 -0.04836896 -0.15105765 -0.02742792 -0.05447848 -0.04543524], Mean. Abs. Error: 0.11297874218577929
Iteration: 299, Error: [-0.35118091 -0.04847425 -0.15113309 -0.02728862 -0.05449146 -0.04555532], Mean. Abs. Error: 0.11302060758690041
Iteration: 399, Error: [-0.35061363 -0.04850714 -0.15115348 -0.0274401  -0.05462918 -0.04549538], Mean. Abs. Error: 0.11297315086088396
Iteration: 499, Error: [-0.35054871 -0.04851785 -0.15123825 -0.02748335 -0.05468388 -0.04560843], Mean. Abs. Error: 0.11301341132041622
Iteration: 599, Error: [-0.35067525 -0.04847463 -0.15145116 -0.02749484 -0.05471659 -0.04579715], Mean. Abs. Error: 0.11310160224953086
Iteration: 699, Error: [-0.3508599  -0.04860225 -0.1

In [1]:
# plot error history 
%matplotlib qt 
plt.figure()
plt.plot(error_hist)
plt.plot(np.mean(error_hist, axis=1), 'k--')
plt.xlim([0, i-1])    
plt.xlabel("Iteration")
plt.ylabel("Error")
plt.legend([pose_indices,'MAE']) 


NameError: name 'plt' is not defined