# Filtering the datalist to valid entries using the stats dataframes

In [307]:
import numpy as np
import pandas as ps
import os
from IPython.display import JSON as DJSON
import random

In [176]:
name = '/fast/jamesn8/assembly_data/assembly_data_with_transforms_all.h5'
assembly_df = ps.read_hdf(name,'assembly')
mate_df = ps.read_hdf(name,'mate')
part_df = ps.read_hdf(name,'part')

In [214]:
part_df_indexed = part_df.copy()
part_df_indexed['PartIndex'] = part_df_indexed.index
part_df_indexed['AssemblyOccurrenceID'] = [f'{ass}-{occ}' for ass,occ in zip(part_df_indexed['Assembly'], part_df_indexed['PartOccurrenceID'])]
part_df_indexed.set_index('AssemblyOccurrenceID', inplace=True)

In [168]:
stats_path = '/fast/jamesn8/assembly_data/mate_torch_norm_match/stats'

In [326]:
stats_df = ps.read_parquet(os.path.join(stats_path, 'stats_all.parquet'))
mate_stats_df = ps.read_parquet(os.path.join(stats_path, 'mate_stats_all.parquet'))

In [327]:
only_valid = stats_df[lambda df: (df['invalid_mates'] == 0) & (df['num_invalid_transformed_parts'] == 0)]

In [328]:
only_valid.shape[0]/stats_df.shape[0]

0.6864439324116743

In [331]:
mate_stats_invalid_df = mate_stats_df[lambda df: df['invalid_frame_0'] | df['invalid_frame_1']]

In [335]:
indexfile = '/fast/jamesn8/assembly_data/mate_torch_norm_match/index.txt'
with open(indexfile,'r') as f:
    datalist = [[int(val) for val in os.path.splitext(l.rstrip())[0].split('-')] for l in f.readlines()]

In [333]:
mate_df_indexed = mate_df.copy()
mate_df_indexed['MateIndex'] = mate_df_indexed.index
mate_df_indexed['MateID'] = [str(tup[0]) + '-' + '-'.join(sorted(tup[1:])) for tup in zip(mate_df_indexed['Assembly'], mate_df_indexed['Part1'], mate_df_indexed['Part2'])]
mate_df_indexed.set_index('MateID', inplace=True)

In [None]:
datalist_df = ps.DataFrame(datalist, columns=['assembly','part1','part2'])
datalist_df.join(part_df[['Assembly','PartOccurrenceID']].set_index('Assembly', append=True), on=['part1','assembly'])

In [337]:
filtered_indexfile = [] #only the mate data that actually has a mate (and is a valid assembly)
excluded_based_on_mates = []
excluded_based_on_assembly = [] #those data whose assembly is not a valid one
for i,l in enumerate(datalist):
    if i % 1000 == 0:
        print(f'processed{i}/{len(datalist)}')
    if l[0] in stats_df.index:
        occ1 = part_df.loc[l[1], 'PartOccurrenceID']
        occ2 = part_df.loc[l[2], 'PartOccurrenceID']
        key = str(l[0]) + '-' + '-'.join(sorted([occ1, occ2]))
        if key in mate_df_indexed.index:
            mates = mate_df_indexed.loc[key,'MateIndex']
            if isinstance(mates, np.int64):
                if mates in mate_stats_invalid_df.index:
                    excluded_based_on_mates.append(i)
                else:
                    filtered_indexfile.append(i)
            else:
                hasInvalid = False
                for mate in mates:
                    if mate in mate_stats_invalid_df.index:
                        excluded_based_on_mates.append(i)
                        hasInvalid = True
                        break
                if not hasInvalid:
                    filtered_indexfile.append(i)
    else:
        excluded_based_on_assembly.append(i)

processed0/429600
processed1000/429600
processed2000/429600
processed3000/429600
processed4000/429600
processed5000/429600
processed6000/429600
processed7000/429600
processed8000/429600
processed9000/429600
processed10000/429600
processed11000/429600
processed12000/429600
processed13000/429600
processed14000/429600
processed15000/429600
processed16000/429600
processed17000/429600
processed18000/429600
processed19000/429600
processed20000/429600
processed21000/429600
processed22000/429600
processed23000/429600
processed24000/429600
processed25000/429600
processed26000/429600
processed27000/429600
processed28000/429600
processed29000/429600
processed30000/429600
processed31000/429600
processed32000/429600
processed33000/429600
processed34000/429600
processed35000/429600
processed36000/429600
processed37000/429600
processed38000/429600
processed39000/429600
processed40000/429600
processed41000/429600
processed42000/429600
processed43000/429600
processed44000/429600
processed45000/429600
p

In [338]:
print(len(filtered_indexfile))
print(len(excluded_based_on_assembly))
print(len(excluded_based_on_mates))

82122
0
9420


In [339]:
excluded_set = set(excluded_based_on_assembly + excluded_based_on_mates)

In [340]:
datalist_filtered = ['-'.join([str(d) for d in data]) + '.dat\n' for i,data in enumerate(datalist) if i not in excluded_set]

In [341]:
random.shuffle(datalist_filtered)

In [342]:
print('retained',len(datalist_filtered)/len(datalist))

retained 0.978072625698324


In [343]:
indexfile_filtered = '/fast/jamesn8/assembly_data/mate_torch_norm_match/index_filtered.txt'
with open(indexfile_filtered,'w') as f:
    f.writelines(datalist_filtered)

## End datalist filtering

In [215]:
mate_df_filtered['Part1Path'] = [f'{ass}-{occ}' for ass,occ in zip(mate_df_filtered['Assembly'],mate_df_filtered['Part1'])]
mate_df_filtered['Part2Path'] = [f'{ass}-{occ}' for ass,occ in zip(mate_df_filtered['Assembly'],mate_df_filtered['Part2'])]

In [216]:
part1_ids = part_df_indexed.loc[mate_df_filtered['Part1Path'],'PartIndex']
part2_ids = part_df_indexed.loc[mate_df_filtered['Part2Path'],'PartIndex']

In [219]:
mate_df_filtered['Part1Id'] = part1_ids
mate_df_filtered['Part2Id'] = part2_ids

ValueError: cannot reindex from a duplicate axis

In [217]:
part1_ids.shape

(81582,)

In [193]:
mate_df.loc[mate_stats_valid_df[lambda df: df['mc_pair_found'].isnull()].index,'Assembly'].value_counts()

34413    648
38699    510
62935    421
33650    365
29318    365
        ... 
15324      1
15390      1
16607      1
16645      1
17417      1
Name: Assembly, Length: 1933, dtype: int64

In [194]:
stats_df[lambda df: df['proposal_time'].isnull()].shape[0]

2079

# Debugging the network

In [35]:
DJSON('/projects/grail/benjones/cadlab/dalton_lightning_logs/real_all_fn_args_amounts_sum_directedhybridgcn12_rerun.json')

<IPython.core.display.JSON object>

In [36]:
import json

In [38]:
path = '/projects/grail/benjones/cadlab/dalton_lightning_logs/'
for f in os.scandir(path):
    if f.name.endswith('.json'):
        with open(os.path.join(path, f.name), 'r') as j:
            s = json.load(j)
            if 'model_class' in s:
                print(f'found model class in {f.name}')

In [1]:
with open('/fast/jamesn8/assembly_data/indexfile.txt','r') as f:
    datalist = [l.rstrip() for l in f.readlines()]

In [4]:
import random

In [5]:
random.shuffle(datalist)

In [7]:
val_size = int(len(datalist)/5)

In [9]:
datalist_val = datalist[:val_size]
datalist_train = datalist[val_size:]

In [None]:
with open('/fast/jamesn8/assembly_data/indexfile_train.txt','w') as f:
    f.writelines([l + '\n' for l in datalist_train])

In [13]:
with open('/fast/jamesn8/assembly_data/indexfile_val.txt','w') as f:
    f.writelines([l + '\n' for l in datalist_val])

In [144]:
from automate.data.saved_dataset import SavedDataset
from torch_geometric.data import Batch
from automate.data.transforms import select_full_relations
from automate.lightning_models.simplified import SimplifiedJointModel
from pytorch_lightning.utilities.argparse import get_init_arguments_and_types
import torch
from torch_geometric.data.dataloader import DataLoader
from pytorch_lightning import Trainer

In [99]:
weights_path='/projects/grail/benjones/cadlab/dalton_lightning_logs/real_all_fn_args_amounts_sum_directedhybridgcn12/version_0/checkpoints/epoch=46-val_auc=0.666113.ckpt'
model = SimplifiedJointModel.load_from_checkpoint(weights_path, map_location=torch.device('cpu'))

RuntimeError: Error(s) in loading state_dict for SimplifiedJointModel:
	Unexpected key(s) in state_dict: "nodal.nn.f.1.weight", "nodal.nn.f.1.bias", "nodal.nn.f.1.running_mean", "nodal.nn.f.1.running_var", "nodal.nn.f.1.num_batches_tracked", "gcn.vert_edge_conv.nn.nn.f.1.weight", "gcn.vert_edge_conv.nn.nn.f.1.bias", "gcn.vert_edge_conv.nn.nn.f.1.running_mean", "gcn.vert_edge_conv.nn.nn.f.1.running_var", "gcn.vert_edge_conv.nn.nn.f.1.num_batches_tracked", "gcn.edge_loop_conv.nn.nn.f.1.weight", "gcn.edge_loop_conv.nn.nn.f.1.bias", "gcn.edge_loop_conv.nn.nn.f.1.running_mean", "gcn.edge_loop_conv.nn.nn.f.1.running_var", "gcn.edge_loop_conv.nn.nn.f.1.num_batches_tracked", "gcn.loop_face_conv.nn.nn.f.1.weight", "gcn.loop_face_conv.nn.nn.f.1.bias", "gcn.loop_face_conv.nn.nn.f.1.running_mean", "gcn.loop_face_conv.nn.nn.f.1.running_var", "gcn.loop_face_conv.nn.nn.f.1.num_batches_tracked", "gcn.face_face_convs.0.nn.nn.f.1.weight", "gcn.face_face_convs.0.nn.nn.f.1.bias", "gcn.face_face_convs.0.nn.nn.f.1.running_mean", "gcn.face_face_convs.0.nn.nn.f.1.running_var", "gcn.face_face_convs.0.nn.nn.f.1.num_batches_tracked", "gcn.face_face_convs.1.nn.nn.f.1.weight", "gcn.face_face_convs.1.nn.nn.f.1.bias", "gcn.face_face_convs.1.nn.nn.f.1.running_mean", "gcn.face_face_convs.1.nn.nn.f.1.running_var", "gcn.face_face_convs.1.nn.nn.f.1.num_batches_tracked", "gcn.face_face_convs.2.nn.nn.f.1.weight", "gcn.face_face_convs.2.nn.nn.f.1.bias", "gcn.face_face_convs.2.nn.nn.f.1.running_mean", "gcn.face_face_convs.2.nn.nn.f.1.running_var", "gcn.face_face_convs.2.nn.nn.f.1.num_batches_tracked", "gcn.face_face_convs.3.nn.nn.f.1.weight", "gcn.face_face_convs.3.nn.nn.f.1.bias", "gcn.face_face_convs.3.nn.nn.f.1.running_mean", "gcn.face_face_convs.3.nn.nn.f.1.running_var", "gcn.face_face_convs.3.nn.nn.f.1.num_batches_tracked", "gcn.face_face_convs.4.nn.nn.f.1.weight", "gcn.face_face_convs.4.nn.nn.f.1.bias", "gcn.face_face_convs.4.nn.nn.f.1.running_mean", "gcn.face_face_convs.4.nn.nn.f.1.running_var", "gcn.face_face_convs.4.nn.nn.f.1.num_batches_tracked", "gcn.face_face_convs.5.nn.nn.f.1.weight", "gcn.face_face_convs.5.nn.nn.f.1.bias", "gcn.face_face_convs.5.nn.nn.f.1.running_mean", "gcn.face_face_convs.5.nn.nn.f.1.running_var", "gcn.face_face_convs.5.nn.nn.f.1.num_batches_tracked", "gcn.face_loop_conv.nn.nn.f.1.weight", "gcn.face_loop_conv.nn.nn.f.1.bias", "gcn.face_loop_conv.nn.nn.f.1.running_mean", "gcn.face_loop_conv.nn.nn.f.1.running_var", "gcn.face_loop_conv.nn.nn.f.1.num_batches_tracked", "gcn.loop_edge_conv.nn.nn.f.1.weight", "gcn.loop_edge_conv.nn.nn.f.1.bias", "gcn.loop_edge_conv.nn.nn.f.1.running_mean", "gcn.loop_edge_conv.nn.nn.f.1.running_var", "gcn.loop_edge_conv.nn.nn.f.1.num_batches_tracked", "gcn.edge_vert_conv.nn.nn.f.1.weight", "gcn.edge_vert_conv.nn.nn.f.1.bias", "gcn.edge_vert_conv.nn.nn.f.1.running_mean", "gcn.edge_vert_conv.nn.nn.f.1.running_var", "gcn.edge_vert_conv.nn.nn.f.1.num_batches_tracked", "classifier.nn.0.f.1.weight", "classifier.nn.0.f.1.bias", "classifier.nn.0.f.1.running_mean", "classifier.nn.0.f.1.running_var", "classifier.nn.0.f.1.num_batches_tracked". 

In [313]:
indexfile = '/fast/jamesn8/assembly_data/mate_torch_norm_match/index_partial2_filtered_train.txt'

In [314]:
dataset = SavedDataset(indexfile, '/fast/jamesn8/assembly_data/mate_torch_norm_match/data')

In [325]:
dataset[57].mc_pair_labels.shape

torch.Size([31612])

In [324]:
len(dataset[57].mc_pair_labels.nonzero())

63

In [103]:
follow_batch=[
            'node_types_g1', 'node_types_g2',
            'mc_index_g1', 'mc_index_g2',
            'mc_pair_labels',
            'left_mc_individual_labels', 'right_mc_individual_labels'
        ]
dataloader = DataLoader(dataset, batch_size=4, follow_batch=follow_batch)

In [15]:
from IPython.display import clear_output

In [105]:
invalid_inds = []
invalid_inds2 = []
N = len(dataset)
for i,batch in enumerate(dataloader):
    if i % 25 == 0:
        clear_output(wait=True)
        display(f'num_processed: {i}/{N}; invalid1: {len(invalid_inds)}; invalid2: {len(invalid_inds2)}')
    try:
        preds = model(batch)
    except ValueError as e:
        invalid_inds.append(i)
    except IndexError as e:
        invalid_inds2.append(i)

'num_processed: 450/83084; invalid1: 0; invalid2: 0'

KeyboardInterrupt: 

In [128]:
model_train = model.train()

In [146]:
model_train.automatic_optimization = False
model_train.optimizers()

AttributeError: 'NoneType' object has no attribute 'lightning_optimizers'

In [147]:
for i,test_batch in enumerate(dataloader):
    if i  == 1:
        break
preds = model_train(test_batch)
target = test_batch.mc_pair_labels
loss = model_train.loss(preds, target)

In [148]:
loss

tensor(nan, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)

In [None]:
pos = target.sum()
neg = target.size(0) - pos
dev = target.device
loss = torch.nn.functional.binary_cross_entropy_with_logits(
    preds, target, pos_weight=torch.full((1,), neg / pos, device=dev)
)

In [133]:
loss.backward()

In [19]:
invalid_set = set(invalid_inds + invalid_inds2)
with open(indexfile) as f:
    datalist = f.readlines()
with open('/fast/jamesn8/assembly_data/indexfile_filtered.txt','w') as f:
    for i,l in enumerate(datalist[:10125]):
        if i not in invalid_set:
            f.write(l)

In [153]:
t = torch.full((2,), 3)

In [161]:
t.sum() < 3

tensor(False)