In [2]:
import numpy as np
import numpy.linalg as LA
import meshplot as mp
from IPython.display import JSON as DJSON
from IPython.display import clear_output
from pspart import Part
from pspart import NNHash
import os
import pandas as ps
from mate_proposals import mate_proposals, homogenize_frame
from scipy.spatial.transform import Rotation as R
import meshplot as mp
import onshape.brepio as brepio
import time
from automate.data.data import UniformMateData
import torch
import random

In [3]:
datapath = '/projects/grail/benjones/cadlab'

In [4]:
name = '/fast/jamesn8/assembly_data/assembly_data_with_transforms_all.h5'
assembly_df = ps.read_hdf(name,'assembly')
mate_df = ps.read_hdf(name,'mate')
part_df = ps.read_hdf(name,'part')
mate_df['MateIndex'] = mate_df.index
part_df['PartIndex'] = part_df.index

In [5]:
mate_df.set_index('Assembly', inplace=True)
part_df.set_index('Assembly', inplace=True)    

In [6]:
with open('fully_connected_moving_no_multimates.txt','r') as f:
    set_E_indices = [int(l.rstrip()) for l in f.readlines()]

In [7]:
mate_types = [
            'PIN_SLOT',
            'BALL',
            'PARALLEL',
            'SLIDER',
            'REVOLUTE',
            'CYLINDRICAL',
            'PLANAR',
            'FASTENED'
        ]

In [24]:
start_index = 1243

outpath = '/fast/jamesn8/assembly_data/mate_torch_data'
def LOG(st):
    with open(logfile,'a') as logf:
        logf.write(st + '\n')
statspath = '/fast/jamesn8/assembly_data/mate_torch_stats'
logfile = os.path.join(statspath, 'log.txt')

epsilon_rel = 0.001
max_groups = 10
max_mcs = 10000
max_mc_pairs = 20000
stride = 200
last_mate_ckpt = 0
last_ckpt = 0

all_stats = []
mate_stats = []
processed_indices = []
mate_indices = []
run_start_time = time.time()
for num_processed,ind in enumerate(set_E_indices[start_index:]):
    stats = dict()
    curr_mate_stats = []
    #clear_output(wait=True)
    display(f'num_processed: {num_processed}/{len(set_E_indices)}')

    #1. spatially hash all MCFs (cache hash maps for each part for re-use with individual mates)
    #2. for all mates, ensure that each MCF is represented (keep track of closest/equivalent MCFs, log percentage of assemblies for which this holds)
    #3. get proposals, edit appropriate ones to true based on equivalence class computed per mated pair of parts (taking outer product of equivalent MCs on left and right)

    LOG(f'{num_processed}/{len(set_E_indices)}: processing {assembly_df.loc[ind,"AssemblyPath"]} at {time.time()-run_start_time}')
    
    part_subset = part_df.loc[ind]
    mate_subset = mate_df.loc[ind]
    if mate_subset.ndim == 1:
        mate_subset = ps.DataFrame([mate_subset], columns=mate_subset.keys())
    
    parts = []
    part_paths = []
    transforms = []
    mcf_hashes = []
    mco_hashes = []
    mc_frames_all = []
    occ_to_index = dict()
    

    #debug
    #all_points = []
    
    for j in range(part_subset.shape[0]):
        path = os.path.join(datapath, 'data/models', *[part_subset.iloc[j][k] for k in ['did','mv','eid','config']], f'{part_subset.iloc[j]["PartId"]}.xt')
        assert(os.path.isfile(path))
        part = Part(path)
        part_paths.append(path)
        parts.append(part)
        tf = part_subset.iloc[j]['Transform']
        transforms.append(tf)
        occ_to_index[part_subset.iloc[j]['PartOccurrenceID']] = j
    
    
    allpoints_transformed = [(tf[:3,:3] @ part.V.T + tf[:3,3,np.newaxis]).T for tf, part in zip(transforms, parts) if parts.V.shape[0] > 0]
    if len(allpoints_transformed) == 0:
        LOG('skipping due to no geometry')
        continue
    minPt = np.array([points.min(axis=0) for points in allpoints_transformed]).min(axis=0)
    maxPt = np.array([points.max(axis=0) for points in allpoints_transformed]).max(axis=0)
    
    median = (minPt + maxPt)/2
    dims = maxPt - minPt
    maxdim = max(dims)
    #maxdim = max([(part.bounding_box()[1]-part.bounding_box()[0]).max() for part in parts])
    threshold = maxdim * epsilon_rel
    
    total_mcs = sum([len(part.all_mate_connectors) for part in parts])
    stats['total_mates'] = mate_subset.shape[0]
    stats['total_parts'] = len(parts)
    stats['maxdim'] = maxdim
    stats['total_mcs'] = total_mcs

    for j in range(len(parts)):
        part = parts[j]
        tf = transforms[j]
        mc_frames = []
        mc_origins = []
        for mc in part.all_mate_connectors:
            cs = mc.get_coordinate_system()
            frame = tf[:3,:3] @ cs[:3,:3]
            frame_homogenized = homogenize_frame(frame, z_flip_only=True)
            origin = tf[:3,:3] @ cs[:3,3] + tf[:3,3]
            #all_points.append(origin)
            rot = R.from_matrix(frame_homogenized).as_quat()
            mc_origins.append(origin)
            mc_frames.append(np.concatenate([origin/maxdim, rot]))
        mc_frames_all.append(mc_frames)
        frame_hash = NNHash(mc_frames, 7, epsilon_rel)
        origin_hash = NNHash(mc_origins, 3, threshold)
        #frame_hash = NNHash([mc_frame[:3] for mc_frame in mc_frames], 3, threshold)
        mcf_hashes.append(frame_hash)
        mco_hashes.append(origin_hash)

    stats['invalid_frames'] = 0
    stats['invalid_mates'] = 0
    stats['invalid_coincident_origins'] = 0
    stats['invalid_permuted_z'] = 0
    
    mate_matches = [] #list of (left MC IDs, right MC Ids) based on the type of mate
    part_pair_to_mate = dict()

    #all_points = np.array(all_points)
    #p = mp.plot(all_points)
    mate_invalids = []
    for j in range(mate_subset.shape[0]):
        matches = [set(), set()]
        m_stats = dict()
        part_indices = []
        mate_invalid = False
        for i in range(2):
            occId = mate_subset.iloc[j][f'Part{i+1}']
            partIndex = occ_to_index[occId]
            part_indices.append(partIndex)
            assert(part_subset.iloc[partIndex]['PartOccurrenceID'] == occId)
            origin_local = mate_subset.iloc[j][f'Origin{i+1}']
            frame_local = mate_subset.iloc[j][f'Axes{i+1}']
            tf = transforms[partIndex]
            origin = tf[:3,:3] @ origin_local + tf[:3,3]            
            frame = tf[:3,:3] @ frame_local
            frame_homogenized = homogenize_frame(frame, z_flip_only=True)
            rot = R.from_matrix(frame_homogenized).as_quat()
            mc_frame = np.concatenate([origin/maxdim, rot])
            neighbors = mcf_hashes[partIndex].get_nearest_points(mc_frame)

            for n in neighbors:
                matches[i].add(n)
            b_invalid = len(neighbors) == 0
            b_num_matches = len(neighbors)
            b_invalid_coincident_origins = False
            b_invalid_permuted_z = False
            if b_invalid:
                stats['invalid_frames'] += 1
                if not mate_invalid:
                    stats['invalid_mates'] += 1
                mate_invalid = True
                origin_neighbors = mco_hashes[partIndex].get_nearest_points(origin)
                if len(origin_neighbors) > 0:
                    b_invalid_coincident_origins = True
                    stats['invalid_coincident_origins'] += 1
                    n = next(iter(origin_neighbors))
                    c_frame = R.from_quat(mc_frames_all[partIndex][n][3:]).as_matrix()
                    c_frame_homogenized = homogenize_frame(c_frame, z_flip_only=False)
                    mate_frame_homogenized = homogenize_frame(frame, z_flip_only=False)
                    dist = LA.norm(c_frame_homogenized - mate_frame_homogenized)

                    if dist < threshold:
                        b_invalid_permuted_z = True
                        stats['invalid_permuted_z'] += 1
            else:
                mateType = mate_subset.iloc[j]['Type']
                for k in range(len(mc_frames_all[partIndex])):
                    c_origin_quat = mc_frames_all[partIndex][k]
                    c_origin = c_origin_quat[:3]
                    c_frame = R.from_quat(c_origin_quat[3:]).as_matrix()
                    axisdist = LA.norm(c_frame[:,2] - frame_homogenized[:,2])
                    if axisdist < epsilon_rel:
                        if mateType == 'CYLINDRICAL' or mateType == 'SLIDER':
                            c_origin_proj = c_origin @ c_frame[:,:2]
                            origin_proj = (origin @ c_frame[:,:2])/maxdim
                            projdist = LA.norm(c_origin_proj - origin_proj)
                            if projdist < epsilon_rel:
                                matches[i].add(k)
                        elif mateType == 'PLANAR' or mateType == 'PARALLEL':
                            c_origin_proj = c_origin.dot(c_frame[:,2])
                            origin_proj = origin.dot(c_frame[:,2])/maxdim
                            projdist = abs(c_origin_proj - origin_proj)
                            if projdist < epsilon_rel:
                                matches[i].add(k)
                    
            
            m_stats[f'invalid_frame_{i}'] = b_invalid
            m_stats[f'invalid_frame_{i}_coincident_origins'] = b_invalid_coincident_origins
            m_stats[f'invalid_frame_{i}_permuted_z'] = b_invalid_permuted_z
            m_stats[f'matches_frame_{i}'] = b_num_matches
            m_stats[f'extra_matches_frame_{i}'] = len(matches[i]) - b_num_matches
        m_stats['type'] = mate_subset.iloc[j]['Type']
        m_stats['truncated_mc_pairs'] = False
        curr_mate_stats.append(m_stats)
        mate_indices.append(mate_subset.iloc[j]['MateIndex'])
        mate_invalids.append(mate_invalid)
        mate_matches.append(matches)
        part_indices = tuple(sorted(part_indices))
        part_pair_to_mate[part_indices] = j#mate_subset.iloc[j]['Type']
    
    if total_mcs <= max_mcs:
        stats['false_part_pairs'] = 0
        stats['missed_part_pairs'] = 0
        stats['missed_mc_pairs'] = 0
        #find assembly-level normalization matrix
        p_normalized = np.identity(4, dtype=float)
        p_normalized[:3,3] = -median
        p_normalized[3,3] = maxdim #todo: figure out if this is double the factor
        
        #find match proposals
        start = time.time()
        proposals = mate_proposals(list(zip(transforms, parts)), epsilon_rel=epsilon_rel, max_groups=max_groups)
        end = time.time()
        stats['num_proposals'] = len(proposals)
        stats['proposal_time'] = end-start
        
        #initialize pairs based on proposals
        part_proposals = dict()
        for proposal in proposals:
            part_pair = proposal[:2]
            if part_pair not in part_proposals:
                mc_pair_dict = dict()
                part_proposals[part_pair] = mc_pair_dict
            else:
                mc_pair_dict = part_proposals[part_pair]
            mc_pair_dict[proposal[2:]] = -1 #mate type

        #populate pairs with labels
        #print('populating pairs with labels')
        part_pair_found=False
        mc_pair_found=False
        for j in range(mate_subset.shape[0]):
            if not mate_invalids[j]:
                mate_type = mate_subset.iloc[j]['Type']
                partIds = [occ_to_index[mate_subset.iloc[j][f'Part{i+1}']] for i in range(2)]
                matches = mate_matches[j]

                if partIds[0] > partIds[1]:
                    partIds.reverse()
                    matches = matches.copy()
                    matches.reverse()
                partIds = tuple(partIds)

                if partIds in part_proposals:
                    part_pair_found=True
                    mc_pair_dict = part_proposals[partIds]
                    for index1 in matches[0]:
                        for index2 in matches[1]:
                            mc_pair = index1, index2
                            if mc_pair in mc_pair_dict:
                                mc_pair_found=True
                                mc_pair_dict[mc_pair] = mate_types.index(mate_type)
            if not part_pair_found:
                stats['missed_part_pairs'] += 1
            if not mc_pair_found:
                stats['missed_mc_pairs'] += 1
            curr_mate_stats[j]['part_pair_found'] =  part_pair_found           
            curr_mate_stats[j]['mc_pair_found'] =  mc_pair_found     
        
        #create data object for each part pair
        #print('creating data object')
        for part_pair in part_proposals:
            mateIndex = -1
            if part_pair in part_pair_to_mate:
                mateIndex = part_pair_to_mate[part_pair]
                mateType = mate_subset.iloc[mateIndex]['Type']
            else:
                mateType='FASTENED'
                stats['false_part_pairs'] += 1
            mc_pairs = part_proposals[part_pair]

            if len(mc_pairs) > max_mc_pairs:
                curr_mate_stats[mateIndex]['truncated_mc_pairs'] = True
                mc_pairs_final=[]
                mc_pairs_false=[]
                for pair in mc_pairs:
                    if mc_pairs[pair] >= 0:
                        mc_pairs_final.append(pair)
                    else:
                        mc_pairs_false.append(pair)
                N_true = len(mc_pairs_final)
                N_remainder = max_mc_pairs - N_true
                random.shuffle(mc_pairs_false)
                for pair in mc_pairs_false[:N_remainder]:
                    mc_pairs_final.append(pair)
            else:
                mc_pairs_final = mc_pairs
            
            part1 = parts[part_pair[0]]
            part2 = parts[part_pair[1]]
            or1, loc1, inf1 = part1.get_onshape_def_from_mc(part1.all_mate_connectors[0])
            or2, loc2, inf2 = part2.get_onshape_def_from_mc(part2.all_mate_connectors[0])
            data = UniformMateData(
                part_paths[part_pair[0]],
                or1,
                loc1,
                inf1,
                p_normalized,
                part_paths[part_pair[1]],
                or2,
                loc2,
                inf2,
                p_normalized,
                mateType
            )

            data.mc_pairs = torch.empty((6, len(mc_pairs_final)), dtype=torch.int)
            data.mc_pair_labels = torch.zeros(len(mc_pairs_final), dtype=torch.int)
            all_mcs = [parts[part_pair[lr]].all_mate_connectors for lr in range(2)]
            for k,p in enumerate(mc_pairs_final):
                type_index = mc_pairs[p]
                if type_index >= 0:
                    data.mc_pair_labels[k] = 1
                mcs = [all_mcs[lr][p[lr]] for lr in range(2)]
                col = torch.tensor([mcs[0].orientation_inference.topology_ref, mcs[0].location_inference.topology_ref, mcs[0].location_inference.inference_type.value,
                      mcs[1].orientation_inference.topology_ref, mcs[1].location_inference.topology_ref, mcs[1].location_inference.inference_type.value], dtype=torch.int)
                data.mc_pairs[:,k] = col
            #dataname = f'{assembly_df.loc[ind,"AssemblyPath"]}-{part_subset.iloc[part_pair[0]]["PartOccurrenceID"].replace("/","_")}-{part_subset.iloc[part_pair[1]]["PartOccurrenceID"].replace("/","_")}.dat'
            dataname = f'{ind}-{part_subset.iloc[part_pair[0]]["PartIndex"]}-{part_subset.iloc[part_pair[1]]["PartIndex"]}.dat'
            torch.save(data, os.path.join(outpath, dataname))
            del data
            
    for stat in curr_mate_stats:
        mate_stats.append(stat)
    all_stats.append(stats)
    processed_indices.append(ind)
    
    if (num_processed+1) % stride == 0:
        
        stat_df_mini = ps.DataFrame(all_stats[last_ckpt:], index=processed_indices[last_ckpt:])
        mate_stat_df_mini = ps.DataFrame(mate_stats[last_mate_ckpt:], index=mate_indices[last_mate_ckpt:])
        stat_df_mini.to_parquet(os.path.join(statspath, f'stats_{num_processed}.parquet'))
        mate_stat_df_mini.to_parquet(os.path.join(statspath, f'mate_stats_{num_processed}.parquet'))
        print(stat_df_mini.shape)
        last_mate_ckpt = len(mate_indices)
        last_ckpt = len(processed_indices)
    

'num_processed: 0/20845'

'num_processed: 1/20845'

'num_processed: 2/20845'

OSError: [Errno 36] File name too long: '/fast/jamesn8/assembly_data/mate_torch_data/de0155b7acbdde81620c0779_c23e613e49df9b61c54991d3_11b50d2755721e251405d166-MTDtQzeSq1A+4F2fU:Mpzcw2zh_N_wHy6iY:Ms58ZDAtXaLf7of4T:M_VZfLPJ0nskgK_x4:Mm0_dLYAgyo6uch9_-MTDtQzeSq1A+4F2fU:Mpzcw2zh_N_wHy6iY:Ms58ZDAtXaLf7of4T:M_VZfLPJ0nskgK_x4:M0qR3x_224bQEh9rf.dat'

In [16]:
stat_df_mini['missed_mc_pairs'].value_counts()

0.0    154
1.0     14
2.0      4
4.0      3
3.0      1
Name: missed_mc_pairs, dtype: int64

In [22]:
processed_indices[-1]

7472

In [23]:
set_E_indices[1241]

7472

In [9]:
stats_df = ps.DataFrame(all_stats, index=processed_indices)
mate_stats_df = ps.DataFrame(mate_stats, index=mate_indices)

ValueError: Shape of passed values is (917, 13), indices imply (918, 13)

In [10]:
stats_df.shape

(94, 13)

In [66]:
len(part_proposals[(23,25)])

3264

In [69]:
torch.empty((3, 4), dtype=torch.int)



tensor([[          0,           0,         113,           0],
        [-1221734976,       21869,  -609547136,       32600],
        [        248,         248,           0,           0]],
       dtype=torch.int32)