In [1]:
import onshape.brepio as brepio
import meshplot as mp
import numpy as np
from visualize import inspect, occ_to_mesh, emptyplot, add_axis
import os
from utils import adjacency_list_from_brepdata, homogenize, connected_components
from pspart import Part
import pandas as ps
import intervaltree
import pspart
from intervaltree import IntervalTree, Interval
import numpy.linalg as LA
import ipywidgets as widgets

In [2]:
# HDF5 store that holds the 'assembly', 'part' and 'mate' tables used below.
df_name = '/fast/jamesn8/assembly_data/assembly_data_with_transforms.h5'
assembly_df = ps.read_hdf(df_name, key='assembly')
part_df = ps.read_hdf(df_name, key='part')

In [3]:
mate_df = ps.read_hdf(path_or_buf=df_name, key='mate')  # one row per mate, same store as the assembly/part tables

In [4]:
# An assembly "has all geometry" when every one of its parts has geometry;
# assignment aligns the per-assembly result on assembly_df's index.
has_geometry = (
    part_df
    .groupby('Assembly')['HasGeometry']
    .agg(all)
)
assembly_df['HasAllGeometry'] = has_geometry

In [5]:
# Load the per-chunk mate statistics files, named `<basename>_chunk_<first>-<last>.parquet`
# with `stride` entries per chunk. Loading stops at the first missing chunk.
basename = 'mate_statistics_heuristics'
stride=1000
max_chunks = 17  # upper bound on the number of chunk files to look for
dfs = []
for i in range(max_chunks):
    loc=i*stride
    lastloc = (i+1)*stride-1
    fname = f'{basename}_chunk_{loc}-{lastloc}.parquet'
    if not os.path.isfile(fname):
        break
    dfs.append(ps.read_parquet(fname))
if not dfs:
    # ps.concat([]) would otherwise fail with an unhelpful "No objects to concatenate"
    raise FileNotFoundError(f'no {basename}_chunk_*.parquet files found in {os.getcwd()}')
mate_statistics_df = ps.concat(dfs, axis=0)
del dfs

In [6]:
mate_statistics_df['Assembly'] = mate_df.loc[:, 'Assembly']  # aligned on the shared row index

In [7]:
# Per assembly: do all of its mates pass each heuristic check?
mcs_coincide = (
    mate_statistics_df
    .groupby('Assembly')
    .agg(dict(originsCoincide=all, mc1Exists=all, mc0Exists=all))
)

In [8]:
# Keep assemblies whose mates all pass the MC checks, that have geometry for every part,
# form a single connected component, and have more than one rigid piece.
mc_checks_pass = mcs_coincide['originsCoincide'] & mcs_coincide['mc0Exists'] & mcs_coincide['mc1Exists']
assembly_df_filtered = assembly_df.loc[mcs_coincide[mc_checks_pass].index]
usable_assembly = (
    assembly_df_filtered['HasAllGeometry']
    & (assembly_df_filtered['ConnectedComponents'] == 1)
    & (assembly_df_filtered['RigidPieces'] > 1)
)
assembly_df_filtered = assembly_df_filtered[usable_assembly]

In [9]:
datapath = "/projects/grail/benjones/cadlab"  # root directory passed to brepio.Loader

In [10]:
# Assembly loader rooted at `datapath`; display_sample uses it via loader.load_flattened(...)
loader = brepio.Loader(datapath)

In [11]:
def mate_proposals(parts, epsilon_rel=0.001):
    """
    Find probable mate locations between pairs of parts.

    `parts`: list of (transform, pspart.Part); each transform's [:3,:3] block and [:3,3] column
             are applied to the mate connector coordinate systems
    `epsilon_rel`: fraction of maximum part dimension to use as epsilon for finding neighboring mate connectors

    Returns a set of (part_index1, part_index2, mc1, mc2) tuples with part_index1 < part_index2,
    where mc1/mc2 are mate connectors of parts[part_index1]/parts[part_index2] whose transformed
    (origin, z axis) descriptors are neighbors in the NN hash. Parts without mate connectors are
    skipped; an empty set is returned when there is nothing to compare.
    """
    if not parts:
        return set()

    dims = []
    for _, part in parts:
        bbox = part.bounding_box()
        dims.append((bbox[1] - bbox[0]).max())
    maxdim = max(dims)
    if not maxdim > 0:
        # degenerate (zero-size or NaN) geometry: avoid dividing origins by zero
        maxdim = 1.0

    mc_locations = []
    mc_owner = []  # global mate connector index -> (part index, mate connector)
    for part_index, (tf, part) in enumerate(parts):
        for mc in part.all_mate_connectors:
            cs = mc.get_coordinate_system()
            origin = tf[:3,:3] @ cs[:3,3] + tf[:3,3]
            z_axis = tf[:3,:3] @ cs[:3,2]
            # origins are scaled by maxdim so epsilon is relative to the largest part; axes are not rescaled
            mc_locations.append(np.concatenate([origin/maxdim, z_axis], axis=0))
            mc_owner.append((part_index, mc))

    if not mc_locations:
        return set()

    # 6 = length of each (origin, z axis) descriptor
    nnhash = pspart.NNHash(mc_locations, 6, epsilon_rel)

    proposals = set()
    for i, loc in enumerate(mc_locations):
        part_index, mc = mc_owner[i]
        for j in nnhash.get_nearest_points(loc):
            other_part_index, other_mc = mc_owner[j]
            if other_part_index == part_index:
                continue
            # store each pair with the lower part index first so (i, j) and (j, i) collapse
            if part_index < other_part_index:
                proposals.add((part_index, other_part_index, mc, other_mc))
            else:
                proposals.add((other_part_index, part_index, other_mc, mc))
    return proposals

In [12]:
from automate.lightning_models.simplified import SimplifiedJointModel
from extension.ml import Predictor

In [13]:
# Trained checkpoints: one model scores mate locations, the other predicts mate types.
checkpoint_root = '/projects/grail/benjones/cadlab/dalton_lightning_logs'
predictor_location = Predictor(
    f'{checkpoint_root}/real_all_fn_args_amounts_sum_directedhybridgcn12/version_0/checkpoints/epoch=46-val_auc=0.666113.ckpt'
)
predictor_type = Predictor(
    f'{checkpoint_root}/real_all_fn_args_amounts_sum_directedhybridgcn12_type/version_0/checkpoints/epoch=5-val_auc=0.948979.ckpt'
)

  stream(template_mgs % msg_args)


In [14]:
# Per-sample caches of predictor outputs, keyed by assembly path, so re-selecting a sample skips inference.
cached_results_loc, cached_results_type = {}, {}

In [19]:
def infer_assembly(geo, mates, results_loc):
    """
    Map location-predictor output back to part pairs and their mate connectors.

    `geo`: mapping occurrence id -> (transform, part)
    `mates`: unused; kept for call compatibility
    `results_loc`: (scores, indices) where column i of `indices` holds two 3-entry descriptors
                   (rows 0-2 and 3-5) of (orientation topology, location topology, inference type)
                   in the assembly-wide topology numbering, and scores[i] is that column's score

    Returns a dict (occ1, occ2) -> (mc1, mc2, score) with occ1 <= occ2 and mc1 belonging to occ1.
    Only the first column seen for each pair is kept (presumably results are sorted by
    descending score — TODO confirm against the predictor).
    """
    # Topology references are local to each part; offset them so they match the global numbering.
    descriptor_to_mc = {}
    topo_offset = 0
    for occ_id in geo:
        _tf, part = geo[occ_id]
        for mc in part.all_mate_connectors:
            descriptor = (
                mc.orientation_inference.topology_ref + topo_offset,
                mc.location_inference.topology_ref + topo_offset,
                mc.location_inference.inference_type.value,
            )
            descriptor_to_mc[descriptor] = (occ_id, mc)
        topo_offset += part.num_topologies

    scores, pair_indices = results_loc[0], results_loc[1]
    best_mates = {}
    for col in range(pair_indices.shape[1]):
        occ_a, mc_a = descriptor_to_mc[tuple(pair_indices[:3, col].numpy())]
        occ_b, mc_b = descriptor_to_mc[tuple(pair_indices[3:, col].numpy())]
        if occ_a > occ_b:
            occ_a, occ_b, mc_a, mc_b = occ_b, occ_a, mc_b, mc_a
        pair = (occ_a, occ_b)
        if pair not in best_mates:
            best_mates[pair] = (mc_a, mc_b, scores[col].item())

    return best_mates
    
def inference_statistics(geo, mates, best_mates, best_types, prob_threshold, type_names=None):
    """
    Print a comparison of inferred mates against the ground-truth mates of an assembly.

    `geo`: mapping occurrence id -> (transform, part)
    `mates`: ground-truth mates; each has `.type` and `.matedEntities`, whose entries are
             (occurrence id, (origin, axes)) with the origin in that occurrence's local frame.
             Mates without exactly two entities are ignored.
    `best_mates`: mapping (occ1, occ2) -> (mc1, mc2, probability) with mc1 on occ1
                  (as returned by infer_assembly)
    `best_types`: predicted type index per entry of `best_mates`, in the same iteration order
    `prob_threshold`: inferred mates with probability below this are treated as absent
    `type_names`: list mapping type indices to mate type names; defaults to the global `mate_types`

    Returns a dict with the printed counts and the per-pair origin distances.
    """
    if type_names is None:
        type_names = mate_types

    def _world_origin(occ, local_origin):
        # map a point from occurrence `occ`'s local frame into assembly coordinates
        tf = geo[occ][0]
        return tf[:3,:3] @ local_origin + tf[:3,3]

    gt_mates = dict()
    for gt_mate in mates:
        if len(gt_mate.matedEntities) != 2:
            continue
        key = tuple(sorted((gt_mate.matedEntities[0][0], gt_mate.matedEntities[1][0])))
        gt_mates.setdefault(key, []).append(gt_mate)

    gt_types = []
    extra_mates = 0
    missing_mates = 0
    matched_mates = 0
    missing_duplicates = 0
    mate_distances = dict()
    misclassified = 0
    for j, pair in enumerate(best_mates):
        mc1, _mc2, prob = best_mates[pair]
        gt_type = ['NONE']
        if prob >= prob_threshold:
            if pair in gt_mates:
                candidates = gt_mates[pair]
                gt_type = [gt_mate.type for gt_mate in candidates]
                origin = _world_origin(pair[0], mc1.get_coordinate_system()[:3,3])
                missing_duplicates += len(candidates) - 1
                matches = [type_names.index(gt_mate.type) == best_types[j] for gt_mate in candidates]
                if any(matches):
                    matched_mates += 1
                    compare_to = [candidates[matches.index(True)]]
                else:
                    misclassified += 1
                    compare_to = candidates
                # Each ground-truth origin is local to its own first mated occurrence, which is
                # not necessarily pair[0] (pair keys are sorted), so use that occurrence's transform.
                mate_distances[pair] = min(
                    float(LA.norm(origin - _world_origin(gt_mate.matedEntities[0][0], gt_mate.matedEntities[0][1][0])))
                    for gt_mate in compare_to
                )
            else:
                extra_mates += 1
        gt_types.append(gt_type)
    for pair in gt_mates:
        if pair not in best_mates or best_mates[pair][2] < prob_threshold:
            missing_mates += 1
    assert(len(gt_types) == len(best_types))
    above_threshold = [prob >= prob_threshold for _, _, prob in best_mates.values()]
    print('inferred types:',[(type_names[t], gt) for t, gt, keep in zip(best_types, gt_types, above_threshold) if keep])
    print(f'unmated parts: {missing_mates}\nextra mated parts: {extra_mates}\nmisclassified: {misclassified}\nmissing duplicate mates: {missing_duplicates}\ncorrectly classified mates: {matched_mates}\nmate distances: {list(mate_distances.values())}')
    return {
        'unmated_parts': missing_mates,
        'extra_mated_parts': extra_mates,
        'misclassified': misclassified,
        'missing_duplicate_mates': missing_duplicates,
        'correctly_classified_mates': matched_mates,
        'mate_distances': mate_distances,
    }
    

In [20]:
mate_epsilon_rel = 0.001  # passed to mate_proposals as epsilon_rel
# Index order is assumed to match the type predictor's output classes (used as mate_types[best_types[i]]).
mate_types = ['PIN_SLOT', 'BALL', 'PARALLEL', 'SLIDER', 'REVOLUTE', 'CYLINDRICAL', 'PLANAR', 'FASTENED']
@mp.interact(sample=[(f'{assembly_df_filtered.loc[ind]["AssemblyPath"][:10]}; {assembly_df_filtered.loc[ind]["RigidPieces"]} moving parts',assembly_df_filtered.loc[ind]['AssemblyPath']) for ind in assembly_df_filtered.index[:100]])
def display_sample(sample):
    """
    Interactive viewer for one assembly from `assembly_df_filtered`.

    Loads the assembly, prints connectivity/mate summaries, then offers a second widget to show a
    single two-entity mate, the full assembly (-1), or the inferred assembly (-2), where predictor
    outputs are cached per sample in cached_results_loc / cached_results_type.
    """
    print(sample)
    try:
        geo, mates = loader.load_flattened(sample + '.json', skipInvalid=True)
    except FileNotFoundError as e:
        print(f'File not found: {e}')
        return
    mate_counts = dict()
    for mate in mates:
        mate_counts[mate.type] = mate_counts.get(mate.type, 0) + 1
    adj = homogenize(adjacency_list_from_brepdata(geo, mates))
    num_connected = connected_components(adj)
    num_rigid = connected_components(adj, connectionType='fasten')
    if num_connected > 1:
        print('warning:',num_connected,'connected components')
    print('rigid pieces:',num_rigid)
    print('total parts:',len(geo))
    print(f'mates: {len(mates)}: ',mate_counts)

    # Only two-entity mates are offered individually; -1 and -2 are sentinel choices.
    choices = [(f'mate {i} ({mates[i].type}) ({mates[i].name})',i) for i in range(len(mates)) if len(mates[i].matedEntities) == 2]
    choices.append(('fullAssembly', -1))
    choices.append(('inferAssembly', -2))
    p = emptyplot()
    badOccs = [k for k in geo if geo[k][1] is None or geo[k][1].V.shape[0] == 0]
    if len(badOccs) > 0:
        print(f'warning: {len(badOccs)} invalid parts!')
    # invalid parts are only reported here; they are left in `geo`

    @mp.interact(mate=choices, wireframe=False, show_parts=True)
    def ff(mate, wireframe, show_parts):
        if mate == -2:
            if sample in cached_results_loc:
                results_loc = cached_results_loc[sample]
            else:
                proposals = mate_proposals([geo[k] for k in geo], epsilon_rel = mate_epsilon_rel)
                if len(proposals) > 0:
                    results_loc = predictor_location.predict_assembly([geo[k][1] for k in geo], proposals)
                    cached_results_loc[sample] = results_loc
                else:
                    print('no proposal MCs')
                    return

            best_mates = infer_assembly(geo, mates, results_loc)

            if sample in cached_results_type:
                results_type = cached_results_type[sample]
            else:
                occ2index = {occ: i for i, occ in enumerate(geo)}
                results_type = predictor_type.predict_assembly_types([geo[k][1] for k in geo], [(occ2index[pair[0]], occ2index[pair[1]], best_mates[pair][0], best_mates[pair][1]) for pair in best_mates])
                cached_results_type[sample] = results_type

            best_types = results_type.argmax(axis=1).numpy() #corresponds to best_mates

            @mp.interact(prob_threshold=widgets.BoundedFloatText(
                value=.5,
                min=0,
                max=1.0,
                step=0.1,
                description='Probability threshold:',
                disabled=False
            ))
            def show_result(prob_threshold):
                inference_statistics(geo, mates, best_mates, best_types,prob_threshold)
                inferred_mates = [brepio.Mate(mcs=best_mates[pair][:2], occIds=[pair[0], pair[1]], mateType=mate_types[best_types[i]], name=f'Mate {i}') for i,pair in enumerate(best_mates)]
                inspect(geo, inferred_mates, p=p, wireframe=wireframe, show_parts=show_parts)
        elif mate == -1:
            print('displaying full assembly')
            inspect(geo, mates, p=p, wireframe=wireframe, show_parts=show_parts)
        elif len(mates[mate].matedEntities) == 2:
            me = mates[mate].matedEntities
            print('mated parts:',me[0][0],me[1][0])
            if me[0][0] in badOccs or me[1][0] in badOccs:
                print('invalid parts in mate')
                return
            occs = [geo[me[i][0]] for i in range(2)]
            maxdim = max([max(geo[i[0]][1].V.max(0)-geo[i[0]][1].V.min(0)) for i in me if geo[i[0]][1].V.shape[0] > 0])

            meshes = [occ_to_mesh(occ) for occ in occs]
            if wireframe:
                p.reset()
                p.add_edges(meshes[0][0], meshes[0][1], shading={'line_color': 'red'})
                p.add_edges(meshes[1][0], meshes[1][1], shading={'line_color': 'blue'})
            else:
                mp.plot(meshes[0][0], meshes[0][1],c=np.array([1, 0, 0]), plot=p)
                p.add_mesh(meshes[1][0], meshes[1][1],c=np.array([0, 0, 1]))

            for i in range(2):
                # mated-entity origin/axes are in the local frame of occurrence me[i][0]
                tf = occs[i][0]
                newaxes = tf[:3, :3] @ me[i][1][1]
                neworigin = tf[:3,:3] @ me[i][1][0] + tf[:3,3]
                print(f'origin {i}: {neworigin}')
                add_axis(p, neworigin, newaxes[:,0], newaxes[:,1], newaxes[:,2], scale=maxdim/2)
        else:
            # `me` is only bound in the branch above, so read the entity count from the mate itself
            print(f'nonstandard mate with {len(mates[mate].matedEntities)} entities')
    p

interactive(children=(Dropdown(description='sample', options=(('58aace5054; 40 moving parts', '58aace5054540c1…