# Filtering the datalist to valid entries using the stats dataframes

In [1]:
import numpy as np
import pandas as ps
import os
from IPython.display import JSON as DJSON
import random

In [2]:
# Load the three tables (assembly, mate, part) stored in one HDF5 file.
name = '/fast/jamesn8/assembly_data/assembly_data_with_transforms_all.h5'
assembly_df, mate_df, part_df = (
    ps.read_hdf(name, key) for key in ('assembly', 'mate', 'part')
)

In [3]:
# Copy of part_df keyed by "<Assembly>-<PartOccurrenceID>", keeping the
# original row index in a PartIndex column.
part_df_indexed = part_df.assign(
    PartIndex=part_df.index,
    AssemblyOccurrenceID=[
        f'{assembly_id}-{occurrence_id}'
        for assembly_id, occurrence_id in zip(part_df['Assembly'], part_df['PartOccurrenceID'])
    ],
).set_index('AssemblyOccurrenceID')

In [23]:
# Root directory of the dataset; index.txt and the stats/ and data/ subfolders
# used by later cells are all resolved relative to it.
base_path = '/fast/jamesn8/assembly_data/mate_torch_norm_match'

In [5]:
# Folder holding the parquet statistics tables.
stats_path = os.path.join(base_path,'stats')

In [6]:
# stats_df is read from stats_all.parquet, mate_stats_df from mate_stats_all.parquet.
stats_df, mate_stats_df = (
    ps.read_parquet(os.path.join(stats_path, filename))
    for filename in ('stats_all.parquet', 'mate_stats_all.parquet')
)

In [7]:
# Keep only rows with no invalid mates and no invalid transformed parts.
only_valid = stats_df.loc[
    (stats_df['invalid_mates'] == 0)
    & (stats_df['num_invalid_transformed_parts'] == 0)
]

In [8]:
# Fraction of stats rows that are fully valid.
len(only_valid) / len(stats_df)

0.6864439324116743

In [9]:
# Mates for which at least one of the two frames is flagged invalid.
mate_stats_invalid_df = mate_stats_df.loc[
    mate_stats_df['invalid_frame_0'] | mate_stats_df['invalid_frame_1']
]

In [10]:
# Parse the index file into a list of integer lists: each line's extension is
# stripped and the remaining stem is split on '-'.
indexfile = os.path.join(base_path,'index.txt')
with open(indexfile,'r') as f:
    datalist = [
        [int(val) for val in os.path.splitext(line.rstrip())[0].split('-')]
        for line in f
        # A blank line (e.g. trailing newline at EOF) would make int('') raise.
        if line.strip()
    ]

In [11]:
def sort_columns(df,names):
    """Sort the values of the `names` columns within each row, in place.

    For every row, the values found in the columns listed in `names` are
    sorted and written back so that the first named column holds the smallest
    value, the second the next one, and so on. The named columns are dropped
    and re-added, so they end up as the last columns of `df`.

    Args:
        df: DataFrame to modify in place.
        names: column labels whose per-row values are sorted among themselves.

    Returns:
        None; `df` is mutated.
    """
    names = list(names)
    sorted_rows = [sorted(row) for row in zip(*(df[name] for name in names))]
    if sorted_rows:
        # Transpose the row-wise sorted values back into one sequence per column.
        cols = list(zip(*sorted_rows))
    else:
        # Zero-row frame: transposing yields nothing, so supply empty columns
        # explicitly; otherwise the columns would be dropped and never restored.
        cols = [[] for _ in names]
    df.drop(names, inplace=True, axis=1)
    for name, col in zip(names, cols):
        df[name] = col

In [12]:
# Copy of mate_df with its row index kept as MateIndex, the Part1/Part2 pair
# ordered within each row, and (Assembly, Part1, Part2) as the index.
mate_df_indexed = mate_df.assign(MateIndex=mate_df.index)
sort_columns(mate_df_indexed, ['Part1','Part2'])
mate_df_indexed = mate_df_indexed.set_index(['Assembly','Part1','Part2'])
# Unused alternative: a single string MateID key instead of the MultiIndex.
#mate_df_indexed['MateID'] = [str(tup[0]) + '-' + '-'.join(sorted(tup[1:])) for tup in zip(mate_df_indexed['Assembly'], mate_df_indexed['Part1'], mate_df_indexed['Part2'])]
#mate_df_indexed.set_index('MateID', inplace=True)

In [13]:
# One row per index-file entry; row position matches the position in datalist.
datalist_df = ps.DataFrame(
    data=datalist,
    columns=['assembly', 'part1', 'part2'],
)

In [14]:
# Look up the PartOccurrenceID of part1 and part2 (matched against part_df's
# index) and attach them as Part1OccurrenceID / Part2OccurrenceID.
minipartdf = part_df['PartOccurrenceID']
datalist_with_occ = (
    datalist_df
    .join(minipartdf, on='part1')
    .rename(columns={'PartOccurrenceID': 'Part1OccurrenceID'})
    .join(minipartdf, on='part2')
    .rename(columns={'PartOccurrenceID': 'Part2OccurrenceID'})
)

In [15]:
# Order the two occurrence IDs within each row, the same way Part1/Part2 were
# ordered in mate_df_indexed, so the two tables can be joined on them.
sort_columns(datalist_with_occ,['Part1OccurrenceID','Part2OccurrenceID'])

In [16]:
# Attach the MateIndex for each (assembly, occurrence pair) by joining against
# the (Assembly, Part1, Part2) index of mate_df_indexed.
datalist_with_mates = datalist_with_occ.join(
    mate_df_indexed['MateIndex'],
    on=['assembly', 'Part1OccurrenceID', 'Part2OccurrenceID'],
)

In [17]:
# Row labels of datalist entries whose MateIndex appears among the invalid mates.
excluded_based_on_mates = datalist_with_mates.index[
    [
        mate_id in mate_stats_invalid_df.index
        for mate_id in datalist_with_mates['MateIndex']
    ]
]

In [18]:
# Inspect the excluded row labels.
excluded_based_on_mates

Int64Index([    10,     36,    129,    221,    244,    375,    382,    426,
               565,    599,
            ...
            429036, 429059, 429093, 429121, 429125, 429206, 429342, 429456,
            429468, 429506],
           dtype='int64', length=9423)

### Todo: exclude assemblies with any invalid mates? (just filter to only_valid)

In [76]:
# Materialise the excluded row labels as a real set: the name promises a set,
# and membership is tested once per datalist entry in the filtering step.
excluded_set = set(excluded_based_on_mates)

In [77]:
# Rebuild the "<a>-<b>-<c>.dat" lines for every entry that was not excluded.
datalist_filtered = [
    '-'.join(map(str, data)) + '.dat\n'
    for i, data in enumerate(datalist)
    if i not in excluded_set
]

In [75]:
# Report how many row labels were excluded because of invalid mates.
print('num elements excluded based on invalid mates',len(excluded_based_on_mates))

num elements excluded based on invalid mates 9423


In [None]:
# Union of both exclusion criteria. `+` must not be used here: on a pandas
# Index it is element-wise addition, not concatenation.
# NOTE(review): `excluded_based_on_assembly` is not defined in any visible
# cell; it has to be computed before this cell can run.
excluded_set = set(excluded_based_on_assembly) | set(excluded_based_on_mates)

In [None]:
# Rebuild the "<a>-<b>-<c>.dat" lines, skipping positions in excluded_set.
datalist_filtered = [
    '-'.join(str(d) for d in data) + '.dat\n'
    for i, data in enumerate(datalist)
    if i not in excluded_set
]

In [78]:
# Shuffle with a fixed seed so that re-running the notebook writes the same
# filtered index file.
random.Random(0).shuffle(datalist_filtered)

In [79]:
# Fraction of the original index entries that survived filtering.
print('retained',len(datalist_filtered)/len(datalist))

retained 0.978072625698324


In [80]:
# Persist the filtered entries; each one already ends with a newline.
indexfile_filtered = os.path.join(base_path, 'index_filtered2.txt')
with open(indexfile_filtered, mode='w') as f:
    f.write(''.join(datalist_filtered))

In [86]:
# Index labels of stats_df as a plain list; these are compared against the
# first field of each index entry further down.
assemblies = stats_df.index.tolist()

In [87]:
# Shuffle with a fixed seed so the train/val assembly split is reproducible
# and stays consistent with the assembly and index files written below.
random.Random(0).shuffle(assemblies)

In [93]:
# First fifth of the shuffled assemblies goes to validation, the rest to training.
split = len(assemblies) // 5
assemblies_val, assemblies_train = set(assemblies[:split]), set(assemblies[split:])

In [94]:
# Route every filtered entry by the integer before its first '-' so that
# entries sharing that value stay in the same split.
datalist_train = [
    entry for entry in datalist_filtered
    if int(entry.split('-', 1)[0]) in assemblies_train
]
datalist_val = [
    entry for entry in datalist_filtered
    if int(entry.split('-', 1)[0]) in assemblies_val
]

In [98]:
# The split is made over assemblies, so the share of entries in validation
# need not be exactly one fifth.
print('final split ratio of mate data:',len(datalist_val)/(len(datalist_val)+len(datalist_train)))

final split ratio of mate data: 0.20652577466799943


In [105]:
# One training assembly per line.
with open(os.path.join(base_path,'assemblies_train.txt'),'w') as f:
    f.writelines(f'{assembly}\n' for assembly in assemblies_train)

In [106]:
# One validation assembly per line.
with open(os.path.join(base_path,'assemblies_val.txt'),'w') as f:
    f.writelines(f'{assembly}\n' for assembly in assemblies_val)

In [24]:
# Paths of the per-split index files (written below and later passed to SavedDataset).
indexfile_train = os.path.join(base_path, 'index_filtered2_train.txt')
indexfile_val = os.path.join(base_path, 'index_filtered2_val.txt')

In [102]:
# Entries already carry their trailing newline, so they are written verbatim.
with open(indexfile_train, mode='w') as f:
    f.write(''.join(datalist_train))
with open(indexfile_val, mode='w') as f:
    f.write(''.join(datalist_val))

# Debugging the network

In [55]:
# Sum of the total_mates column over all rows of stats_df.
stats_df['total_mates'].sum()

151182

In [67]:
# Mates for which neither frame is flagged invalid.
mate_stats_df[~(mate_stats_df['invalid_frame_0'] | mate_stats_df['invalid_frame_1'])]

Unnamed: 0_level_0,invalid_frame_0,invalid_frame_0_coincident_origins,invalid_frame_0_permuted_z,matches_frame_0,extra_matches_frame_0,invalid_frame_1,invalid_frame_1_coincident_origins,invalid_frame_1_permuted_z,matches_frame_1,extra_matches_frame_1,type,truncated_mc_pairs,part_pair_found,mc_pair_found
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1
376,False,False,False,7,0,False,False,False,7,0,FASTENED,False,,
377,False,False,False,6,0,False,False,False,6,0,FASTENED,False,,
378,False,False,False,3,0,False,False,False,4,0,REVOLUTE,False,,
379,False,False,False,6,0,False,False,False,1,0,FASTENED,False,,
380,False,False,False,1,0,False,False,False,1,0,FASTENED,False,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1882555,False,False,False,3,0,False,False,False,3,0,FASTENED,False,1.0,1.0
1882556,False,False,False,3,0,False,False,False,3,0,FASTENED,False,1.0,1.0
1882557,False,False,False,3,0,False,False,False,3,0,FASTENED,False,1.0,1.0
1882558,False,False,False,3,0,False,False,False,3,0,FASTENED,False,1.0,1.0


In [58]:
# Column labels available in the mate statistics table.
mate_stats_df.columns

Index(['invalid_frame_0', 'invalid_frame_0_coincident_origins',
       'invalid_frame_0_permuted_z', 'matches_frame_0',
       'extra_matches_frame_0', 'invalid_frame_1',
       'invalid_frame_1_coincident_origins', 'invalid_frame_1_permuted_z',
       'matches_frame_1', 'extra_matches_frame_1', 'type',
       'truncated_mc_pairs', 'part_pair_found', 'mc_pair_found'],
      dtype='object')

In [35]:
# Display one run's JSON file through IPython's JSON display object.
DJSON('/projects/grail/benjones/cadlab/dalton_lightning_logs/real_all_fn_args_amounts_sum_directedhybridgcn12_rerun.json')

<IPython.core.display.JSON object>

In [36]:
import json

In [38]:
# Report every .json file in the log directory whose contents include 'model_class'.
path = '/projects/grail/benjones/cadlab/dalton_lightning_logs/'
for f in os.scandir(path):
    if not f.name.endswith('.json'):
        continue
    # DirEntry.path is the scanned directory joined with the entry name.
    with open(f.path, 'r') as j:
        s = json.load(j)
    if 'model_class' in s:
        print(f'found model class in {f.name}')

In [1]:
# Read the index file as raw lines without trailing whitespace.
# Note: this rebinds `datalist`, which held parsed integer lists earlier on.
with open('/fast/jamesn8/assembly_data/indexfile.txt', mode='r') as f:
    datalist = [line.rstrip() for line in f]

In [4]:
import random

In [5]:
# Shuffle with a fixed seed so the train/val files written below are reproducible.
random.Random(0).shuffle(datalist)

In [7]:
# One fifth of the entries, rounded down, go to validation.
val_size = len(datalist) // 5

In [9]:
# Leading val_size entries are validation; everything after is training.
datalist_val, datalist_train = datalist[:val_size], datalist[val_size:]

In [None]:
# Entries were stripped on read, so the newline is added back per line.
with open('/fast/jamesn8/assembly_data/indexfile_train.txt','w') as f:
    f.writelines(f'{entry}\n' for entry in datalist_train)

In [13]:
# Entries were stripped on read, so the newline is added back per line.
with open('/fast/jamesn8/assembly_data/indexfile_val.txt','w') as f:
    f.writelines(f'{entry}\n' for entry in datalist_val)

In [21]:
from automate.data.saved_dataset import SavedDataset
from torch_geometric.data import Batch
from automate.data.transforms import select_full_relations
from automate.lightning_models.simplified import SimplifiedJointModel
from pytorch_lightning.utilities.argparse import get_init_arguments_and_types
import torch
from torch_geometric.data.dataloader import DataLoader
from pytorch_lightning import Trainer

In [99]:
# Restore a SimplifiedJointModel from a Lightning checkpoint, mapped onto the CPU.
# NOTE(review): the recorded run of this cell raised RuntimeError (unexpected
# "...f.1.weight" / "...running_mean" style keys in the state_dict), so `model`
# was not bound by it. Presumably the checkpoint was trained with layers that the
# model as constructed here lacks; TODO confirm the hyperparameters it needs.
weights_path='/projects/grail/benjones/cadlab/dalton_lightning_logs/real_all_fn_args_amounts_sum_directedhybridgcn12/version_0/checkpoints/epoch=46-val_auc=0.666113.ckpt'
model = SimplifiedJointModel.load_from_checkpoint(weights_path, map_location=torch.device('cpu'))

RuntimeError: Error(s) in loading state_dict for SimplifiedJointModel:
	Unexpected key(s) in state_dict: "nodal.nn.f.1.weight", "nodal.nn.f.1.bias", "nodal.nn.f.1.running_mean", "nodal.nn.f.1.running_var", "nodal.nn.f.1.num_batches_tracked", "gcn.vert_edge_conv.nn.nn.f.1.weight", "gcn.vert_edge_conv.nn.nn.f.1.bias", "gcn.vert_edge_conv.nn.nn.f.1.running_mean", "gcn.vert_edge_conv.nn.nn.f.1.running_var", "gcn.vert_edge_conv.nn.nn.f.1.num_batches_tracked", "gcn.edge_loop_conv.nn.nn.f.1.weight", "gcn.edge_loop_conv.nn.nn.f.1.bias", "gcn.edge_loop_conv.nn.nn.f.1.running_mean", "gcn.edge_loop_conv.nn.nn.f.1.running_var", "gcn.edge_loop_conv.nn.nn.f.1.num_batches_tracked", "gcn.loop_face_conv.nn.nn.f.1.weight", "gcn.loop_face_conv.nn.nn.f.1.bias", "gcn.loop_face_conv.nn.nn.f.1.running_mean", "gcn.loop_face_conv.nn.nn.f.1.running_var", "gcn.loop_face_conv.nn.nn.f.1.num_batches_tracked", "gcn.face_face_convs.0.nn.nn.f.1.weight", "gcn.face_face_convs.0.nn.nn.f.1.bias", "gcn.face_face_convs.0.nn.nn.f.1.running_mean", "gcn.face_face_convs.0.nn.nn.f.1.running_var", "gcn.face_face_convs.0.nn.nn.f.1.num_batches_tracked", "gcn.face_face_convs.1.nn.nn.f.1.weight", "gcn.face_face_convs.1.nn.nn.f.1.bias", "gcn.face_face_convs.1.nn.nn.f.1.running_mean", "gcn.face_face_convs.1.nn.nn.f.1.running_var", "gcn.face_face_convs.1.nn.nn.f.1.num_batches_tracked", "gcn.face_face_convs.2.nn.nn.f.1.weight", "gcn.face_face_convs.2.nn.nn.f.1.bias", "gcn.face_face_convs.2.nn.nn.f.1.running_mean", "gcn.face_face_convs.2.nn.nn.f.1.running_var", "gcn.face_face_convs.2.nn.nn.f.1.num_batches_tracked", "gcn.face_face_convs.3.nn.nn.f.1.weight", "gcn.face_face_convs.3.nn.nn.f.1.bias", "gcn.face_face_convs.3.nn.nn.f.1.running_mean", "gcn.face_face_convs.3.nn.nn.f.1.running_var", "gcn.face_face_convs.3.nn.nn.f.1.num_batches_tracked", "gcn.face_face_convs.4.nn.nn.f.1.weight", "gcn.face_face_convs.4.nn.nn.f.1.bias", "gcn.face_face_convs.4.nn.nn.f.1.running_mean", "gcn.face_face_convs.4.nn.nn.f.1.running_var", "gcn.face_face_convs.4.nn.nn.f.1.num_batches_tracked", 
"gcn.face_face_convs.5.nn.nn.f.1.weight", "gcn.face_face_convs.5.nn.nn.f.1.bias", "gcn.face_face_convs.5.nn.nn.f.1.running_mean", "gcn.face_face_convs.5.nn.nn.f.1.running_var", "gcn.face_face_convs.5.nn.nn.f.1.num_batches_tracked", "gcn.face_loop_conv.nn.nn.f.1.weight", "gcn.face_loop_conv.nn.nn.f.1.bias", "gcn.face_loop_conv.nn.nn.f.1.running_mean", "gcn.face_loop_conv.nn.nn.f.1.running_var", "gcn.face_loop_conv.nn.nn.f.1.num_batches_tracked", "gcn.loop_edge_conv.nn.nn.f.1.weight", "gcn.loop_edge_conv.nn.nn.f.1.bias", "gcn.loop_edge_conv.nn.nn.f.1.running_mean", "gcn.loop_edge_conv.nn.nn.f.1.running_var", "gcn.loop_edge_conv.nn.nn.f.1.num_batches_tracked", "gcn.edge_vert_conv.nn.nn.f.1.weight", "gcn.edge_vert_conv.nn.nn.f.1.bias", "gcn.edge_vert_conv.nn.nn.f.1.running_mean", "gcn.edge_vert_conv.nn.nn.f.1.running_var", "gcn.edge_vert_conv.nn.nn.f.1.num_batches_tracked", "classifier.nn.0.f.1.weight", "classifier.nn.0.f.1.bias", "classifier.nn.0.f.1.running_mean", "classifier.nn.0.f.1.running_var", "classifier.nn.0.f.1.num_batches_tracked". 

In [313]:
#indexfile = '/fast/jamesn8/assembly_data/mate_torch_norm_match/index_partial2_filtered_train.txt'

In [25]:
# Training dataset: built from the training index file and the data/ folder under base_path.
dataset = SavedDataset(indexfile_train, os.path.join(base_path, 'data'))

In [45]:
# True for every training file whose mc_pair_labels has at least one nonzero entry.
mask_train = [
    torch.load(sample_file).mc_pair_labels.nonzero().shape[0] > 0
    for sample_file in dataset.index
]

In [46]:
# Validation dataset: built from the validation index file and the same data/ folder.
dataset_val = SavedDataset(indexfile_val, os.path.join(base_path, 'data'))

In [47]:
# True for every validation file whose mc_pair_labels has at least one nonzero entry.
mask_val = [
    torch.load(sample_file).mc_pair_labels.nonzero().shape[0] > 0
    for sample_file in dataset_val.index
]

In [60]:
# Number of training files with at least one nonzero mc_pair_labels entry.
sum(mask_train)

64780

In [51]:
# Same count as a fraction of all training files.
sum(mask_train)/len(dataset.index)

0.1942999742053137

In [34]:
# Attribute names passed to the loader's follow_batch option.
follow_batch = [
    'node_types_g1',
    'node_types_g2',
    'mc_index_g1',
    'mc_index_g2',
    'mc_pair_labels',
    'left_mc_individual_labels',
    'right_mc_individual_labels',
]
# One sample per batch, so a failure can be attributed to a single dataset index.
dataloader = DataLoader(
    dataset,
    batch_size=1,
    follow_batch=follow_batch,
)

In [15]:
from IPython.display import clear_output

In [105]:
# Run every training sample through the model once and record the indices whose
# forward pass raises, split by exception type.
invalid_inds = []   # indices where the forward pass raised ValueError
invalid_inds2 = []  # indices where the forward pass raised IndexError
N = len(dataset)
# Inference only: no_grad avoids building an autograd graph for every sample.
# The model's train/eval mode is deliberately left untouched.
with torch.no_grad():
    for i,batch in enumerate(dataloader):
        if i % 25 == 0:
            # Refresh a single progress line instead of flooding the output.
            clear_output(wait=True)
            display(f'num_processed: {i}/{N}; invalid1: {len(invalid_inds)}; invalid2: {len(invalid_inds2)}')
        try:
            preds = model(batch)
        except ValueError:
            invalid_inds.append(i)
        except IndexError:
            invalid_inds2.append(i)

'num_processed: 450/83084; invalid1: 0; invalid2: 0'

KeyboardInterrupt: 