In [1]:
import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt 
from scipy.spatial.transform import Rotation as R 
import pickle 
import os 
import glob 
from mpl_toolkits.mplot3d import Axes3D 
import random 

In [2]:
# functions
# function to compute distance between a line segment and point 
def dist_point_to_segment(P0, PA, PB):
    """Shortest Euclidean distance from point ``P0`` to the segment ``PA``-``PB``.

    ``P0`` is projected onto the line through ``PA`` and ``PB``. The projection
    parameter is clamped to [0, 1], so the closest point always lies on the
    segment; when the projection falls outside it, the nearer end point is used.
    This works for points of any dimension. Unlike the cross-product formula,
    it does not need ``np.linalg.cross``, which only exists in NumPy >= 2.0.
    See https://stackoverflow.com/questions/849211/shortest-distance-between-a-point-and-a-line-segment

    Parameters
    ----------
    P0, PA, PB : array_like, shape (n,)
        Query point and the two segment end points (same dimension).

    Returns
    -------
    float
        Distance from ``P0`` to the closest point of the segment. For a
        degenerate segment (``PA == PB``) this is the distance to ``PA``.
    """
    P0 = np.asarray(P0, dtype=float)
    PA = np.asarray(PA, dtype=float)
    PB = np.asarray(PB, dtype=float)
    seg = PB - PA
    len_sq = np.dot(seg, seg)
    if len_sq == 0.0:
        # degenerate segment: both end points coincide
        return np.linalg.norm(P0 - PA)
    # normalised position of the projection along the segment (0 -> PA, 1 -> PB),
    # clamped so the closest point never leaves the segment
    t = np.clip(np.dot(P0 - PA, seg) / len_sq, 0.0, 1.0)
    closest = PA + t * seg
    return np.linalg.norm(P0 - closest)

# function to classify hole point into hole geometry element 
def classify_hole_point(point, hole_parameters, tol=0.1):
    """Classify a contact point into the hole geometry element(s) it lies on.

    Geometry as encoded below (hole frame, same length unit as the parameters):

    - a bore of radius ``R`` centred on the z axis;
    - joined to a straight slot of width ``a`` whose far end lies at
      ``y = -(b + sqrt(R**2 - (a/2)**2))``;
    - top and bottom surfaces on the cylinder of diameter ``d`` whose axis runs
      along y at ``z = -d/2``, i.e.
      ``z_top(x) = -d/2 + sqrt((d/2)**2 - x**2)`` and ``z_bot(x) = -d - z_top(x)``.

    Parameters
    ----------
    point : array_like, shape (3,)
        Contact point ``(x, y, z)`` in the hole frame.
    hole_parameters : dict
        Must provide ``'R'``, ``'a'``, ``'b'`` and ``'d'``.
    tol : float, optional
        Tolerance for the vertex, edge and face tests. The z test for the
        curved top edges uses the tighter ``0.1 * tol``.

    Returns
    -------
    list of str
        Every matching label among:

        - ``'HV1'``-``'HV4'`` (top vertices);
        - ``'HE1'``/``'HE2'`` (curved top edges of the bore, x <= 0 / x > 0 half);
        - ``'HE3'``-``'HE6'`` (vertical slot edges);
        - ``'HF1'``/``'HF2'`` (bore wall, x <= 0 / x > 0 half).

        Empty if nothing matches; a point may carry several labels.
    """
    point = np.asarray(point, dtype=float)
    x, y, z = point[0], point[1], point[2]
    hole_geom_class = []

    RR = hole_parameters['R']
    a = hole_parameters['a']
    b = hole_parameters['b']
    d = hole_parameters['d']

    # depth of the cylindrical top surface below z = 0 at the slot walls (x = +-a/2)
    sag = d/2 - np.sqrt((d/2)**2 - (a/2)**2)
    z_h_top = -sag
    z_h_bot = -d + sag
    # y where the slot walls meet the bore, and y of the far end of the slot
    y_mouth = -np.sqrt(RR**2 - (a/2)**2)
    y_end = y_mouth - b

    # slot vertices on the top surface and their counterparts on the bottom surface
    HV1 = np.array([+a/2, y_mouth, z_h_top])
    HV2 = np.array([-a/2, y_mouth, z_h_top])
    HV3 = np.array([+a/2, y_end, z_h_top])
    HV4 = np.array([-a/2, y_end, z_h_top])
    HV1_bot = np.array([+a/2, y_mouth, z_h_bot])
    HV2_bot = np.array([-a/2, y_mouth, z_h_bot])
    HV3_bot = np.array([+a/2, y_end, z_h_bot])
    HV4_bot = np.array([-a/2, y_end, z_h_bot])

    # vertices (only the top ones are labelled)
    for label, vertex in (('HV1', HV1), ('HV2', HV2), ('HV3', HV3), ('HV4', HV4)):
        if np.linalg.norm(point - vertex) < tol:
            hole_geom_class.append(label)

    # Angular position around the bore: theta is measured from +y and is
    # positive towards -x. +-hole_theta_arc is where the bore meets the slot walls.
    hole_theta_arc = np.pi - np.arcsin((a/2)/RR)
    theta = np.arctan2(-x, y)
    rp = np.sqrt(x**2 + y**2)
    on_bore = np.abs(rp - RR) < tol
    # theta == 0 (x == 0, y > 0) is the seam between the two halves; include it
    # in the first half so such points are not left unlabelled.
    in_half1 = 0 <= theta <= hole_theta_arc
    in_half2 = -hole_theta_arc <= theta < 0

    # The cylindrical top/bottom surfaces are only defined for |x| < d/2. Outside
    # that band the sqrt is undefined, so no curved edge or face can match there.
    if np.abs(x) < d/2:
        z_top = -(d/2) + np.sqrt((d/2)**2 - x**2)
        z_bot = -d - z_top
    else:
        z_top = z_bot = None

    # curved edges (bore intersected with the top surface)
    if z_top is not None and on_bore and np.abs(z - z_top) < 0.1*tol:
        if in_half1:
            hole_geom_class.append('HE1')
        elif in_half2:
            hole_geom_class.append('HE2')

    # straight (vertical) edges of the slot, from top to bottom vertex
    straight_edges = (
        ('HE3', HV1, HV1_bot),
        ('HE4', HV2, HV2_bot),
        ('HE5', HV3, HV3_bot),
        ('HE6', HV4, HV4_bot),
    )
    for label, top, bot in straight_edges:
        if dist_point_to_segment(point, top, bot) < tol:
            hole_geom_class.append(label)

    # curved faces (bore wall strictly between the bottom and top surfaces)
    if z_top is not None and on_bore and z_bot < z < z_top:
        if in_half1:
            hole_geom_class.append('HF1')
        elif in_half2:
            hole_geom_class.append('HF2')

    return hole_geom_class

# function to classify peg point into peg geometry element 
def classify_peg_point(point, peg_parameters, tol=0.1):
    """Classify a contact point into the peg geometry element(s) it lies on.

    Geometry as encoded below (peg frame):

    - the cross-section is an arc of radius ``r`` around the z axis;
    - the arc is joined to two flats at ``x = +-e``, which run from
      ``y = -sqrt(r**2 - e**2)`` to ``y = -lp``.

    The labelled edges lie in the ``z = 0`` plane. The faces are tested for
    ``0 < z < hp`` (plus tolerance).

    Parameters
    ----------
    point : array_like, shape (3,)
        Contact point ``(x, y, z)`` in the peg frame.
    peg_parameters : dict
        Must provide ``'r'``, ``'e'``, ``'lp'`` and ``'hp'``.
    tol : float, optional
        Tolerance for the edge and face tests.

    Returns
    -------
    list of str
        Every matching label among:

        - ``'PE1'``/``'PE2'`` (arc edge, x <= 0 / x > 0 half);
        - ``'PE3'``/``'PE4'`` (flat edges at x = +e / -e);
        - ``'PF1'``/``'PF2'`` (curved face halves);
        - ``'PF3'``/``'PF4'`` (flat faces at x = +e / -e).

        Empty if nothing matches; a point may carry several labels.
    """
    point = np.asarray(point, dtype=float)
    x, y, z = point[0], point[1], point[2]
    peg_geom_class = []

    r = peg_parameters['r']
    e = peg_parameters['e']
    lp = peg_parameters['lp']
    hp = peg_parameters['hp']

    # y where the flats meet the arc
    y_shoulder = -np.sqrt(r**2 - e**2)

    # vertices of the z = 0 outline: flat/arc junctions (PV1, PV2) and the
    # far ends of the flats (PV3, PV4)
    PV1 = np.array([+e, y_shoulder, 0.0])
    PV2 = np.array([-e, y_shoulder, 0.0])
    PV3 = np.array([+e, -lp, 0.0])
    PV4 = np.array([-e, -lp, 0.0])

    # straight edges along the flats
    for label, start, end in (('PE3', PV1, PV3), ('PE4', PV2, PV4)):
        if dist_point_to_segment(point, start, end) < tol:
            peg_geom_class.append(label)

    # Angular position around the peg axis: theta is measured from +y and is
    # positive towards -x. +-peg_theta_arc is where the arc meets the flats.
    peg_theta_arc = np.pi - np.arcsin(e/r)
    theta = np.arctan2(-x, y)
    rp = np.sqrt(x**2 + y**2)
    on_arc = np.abs(rp - r) < tol
    # theta == 0 (x == 0, y > 0) is the seam between the two halves; include it
    # in the first half so such points are not left unlabelled.
    in_half1 = 0 <= theta <= peg_theta_arc
    in_half2 = -peg_theta_arc <= theta < 0

    # curved edges (arc in the z = 0 plane)
    if on_arc and np.abs(z) < tol:
        if in_half1:
            peg_geom_class.append('PE1')
        elif in_half2:
            peg_geom_class.append('PE2')

    # curved faces
    if on_arc and 0 < z < hp + tol:
        if in_half1:
            peg_geom_class.append('PF1')
        elif in_half2:
            peg_geom_class.append('PF2')

    # flat faces x = +e (PF3) and x = -e (PF4)
    if -lp - tol < y < y_shoulder + tol and -tol < z < hp + tol:
        if np.abs(x - e) < tol:
            peg_geom_class.append('PF3')
        elif np.abs(x + e) < tol:
            peg_geom_class.append('PF4')

    return peg_geom_class

In [100]:
# define geometry parameters (lengths in mm)
hole_parameters = {
    'R': 3.05,
    'a': 4.5,
    # chosen so the far end of the slot sits at y = -7.0:
    # sqrt(R**2 - (a/2)**2) + b == 7.0
    'b': 7.0 - np.sqrt(3.05**2 - (4.5/2)**2),
    'd': 10,
}
peg_parameters = {
    'r': 3.0,
    'e': 2.2,
    'lp': 29,
    'hp': 30,
}
hole_classes = ['HF1','HF2','HE1','HE2','HE3','HE4','HE5','HE6','HV1','HV2','HV3','HV4']
peg_classes = ['PF1','PF2','PF3','PF4','PE1','PE2','PE3','PE4']
# label -> row / column of the contact-state matrix
hole_class_index = {label: idx for idx, label in enumerate(hole_classes)}
peg_class_index = {label: idx for idx, label in enumerate(peg_classes)}

# classification tolerances (mm) for hole-frame and peg-frame contact points
HOLE_TOL = 0.1
PEG_TOL = 0.5

# directory with one pickle file per trial; files are processed oldest first
dir_pkl = "./results/data_v1/pkl"
pkl_files = sorted(glob.glob(os.path.join(dir_pkl, "*.pkl")), key=os.path.getmtime)

# Cap on the number of trials processed; set to None to process every file.
MAX_TRIALS = 101
pkl_files_used = pkl_files if MAX_TRIALS is None else pkl_files[:MAX_TRIALS]

# contact_state_hist_all[h, p, t, trial] is True when hole element h touches
# peg element p at time step t. Trials shorter than N_timesteps are padded
# with False (no contact).
N_timesteps = 500
N_trials = len(pkl_files_used)
contact_state_hist_all = np.zeros((len(hole_classes), len(peg_classes), N_timesteps, N_trials), dtype=bool)

for i, pkl_file in enumerate(pkl_files_used):

    # NOTE: pickle.load can execute arbitrary code; only load trusted result files.
    with open(pkl_file, 'rb') as f:
        data = pickle.load(f)

    # unpack data (only state_hist and contact_pos are used in this cell)
    state_hist = data['state_hist']
    contact_num = data['contact_num']
    contact_geom1 = data['contact_geom1']
    contact_geom2 = data['contact_geom2']
    contact_dist = data['contact_dist']
    contact_pos = data['contact_pos']
    contact_frame = data['contact_frame']
    ctrl_hist = data['ctrl_hist']

    n_steps = len(state_hist)
    if n_steps > N_timesteps:
        raise ValueError(
            f"{pkl_file} has {n_steps} time steps; increase N_timesteps (currently {N_timesteps})"
        )

    # contact state history of this trial
    contact_state_hist = np.zeros((len(hole_classes), len(peg_classes), n_steps), dtype=bool)

    for j, contact_pos_j in enumerate(contact_pos):  # iterate through each time step
        if len(contact_pos_j) == 0:  # no contact at this step
            continue
        # peg pose used to map hole-frame points into the peg frame:
        # position (m -> mm) and scalar-first quaternion
        peg_pos = state_hist[j, 1:4] * 1e3
        peg_quat = state_hist[j, 4:8]
        peg_R = R.from_quat(peg_quat, scalar_first=True).as_matrix()
        for contact_pos_m in contact_pos_j:  # each contact at the current step
            # Scale into a new array (m -> mm). In-place `*=` would mutate the
            # loaded data and fails for integer arrays or tuples.
            contact_pos_hole_frame = np.asarray(contact_pos_m, dtype=float) * 1e3
            # express the contact point in the peg frame: R^T (p - t)
            contact_pos_peg_frame = peg_R.T @ (contact_pos_hole_frame - peg_pos)
            point_hole_class = classify_hole_point(contact_pos_hole_frame, hole_parameters, tol=HOLE_TOL)
            point_peg_class = classify_peg_point(contact_pos_peg_frame, peg_parameters, tol=PEG_TOL)

            # mark every (hole element, peg element) pair this contact touches
            for hole_class in point_hole_class:
                for peg_class in point_peg_class:
                    contact_state_hist[hole_class_index[hole_class], peg_class_index[peg_class], j] = True

    # store this trial (zero-padded to N_timesteps)
    contact_state_hist_all[:, :, :n_steps, i] = contact_state_hist

In [101]:
# Number of time steps, summed over all processed trials, in which each hole
# element (rows) was in contact with each peg element (columns).
contact_counts = pd.DataFrame(
    contact_state_hist_all.sum(axis=(2, 3)),
    index=hole_classes,
    columns=peg_classes,
)
contact_counts

[[11357  5292   885   664 11332  5292   677   566]
 [ 5687  7602   323   796  5684  7598   260   707]
 [ 3759  1099   347   181  3738  1099   287   167]
 [ 1512  2800    29   380  1512  2801    27   365]
 [  370  1810    69   225   370  1809    42   158]
 [ 3897   157   225    63  3874   157   194    62]
 [    0     0    13    14     0     0    13     8]
 [    0     0    47   468     0     0     1   442]
 [  359  1624    69   214   359  1623    42   144]
 [ 3877   153   222    62  3854   153   191    61]
 [    0     0    13    12     0     0    13     6]
 [    0     0    45   327     0     0     1   327]]
