In [34]:
import pandas as pd
import json
import matplotlib.pyplot as plt
import os

from itertools import chain
from pathlib import *
from math import ceil
from sqlalchemy import create_engine

from helper import *
from tree_weight import *
from time import time


from typing import List
%matplotlib inline

In [59]:
# %%time
image_folder = 'graphs/'
tree = 'benchmark_models/jobshop/trees/jobshop_ft06.sqlite'
info_df = to_df(tree, 'info').set_index('NodeID')
nodes_df = to_df(tree, 'nodes').set_index('NodeID')

if 'DomainNodeWeight' in nodes_df.columns and 'DFSOrdering' in nodes_df.columns:
    test_df = nodes_df.drop(columns=['NodeWeight']).rename(columns={'DomainNodeWeight': 'NodeWeight'})
    dfs_ordering = nodes_df[~nodes_df['Status'].isin({3})].sort_values('DFSOrdering').index.to_list()
else:
    test_df = pd.DataFrame.copy(nodes_df)
    assign_weights(nodes_df, uniform_scheme)
    assign_weights(test_df, domain_weight_scheme, info_df=info_df)
    dfs_ordering = make_dfs_ordering(nodes_df)
    
uniform_cumsum = get_cum_weights(nodes_df, dfs_ordering)
domain_cumsum = get_cum_weights(test_df, dfs_ordering)

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(5,5), squeeze=True)

ax.set_title(tree.split('/')[1] + '_' + tree.split('/')[-1].strip('.sqlite'))
ax.plot(pd.Series(range(len(dfs_ordering))) / (len(dfs_ordering) - 1), label='Ground truth')
ax.plot(uniform_cumsum, label='Uniform scheme')
ax.plot(domain_cumsum, label='Domain scheme')
ax.tick_params(axis='x', rotation=90)
ax.legend()

if not os.path.exists(image_folder + ax.title.get_text()):
    fig.savefig(image_folder + ax.title.get_text())

    # write to sqlite file so we do not waste time recomputing

    engine = create_engine('sqlite:///' + tree)
    nodes_df['DomainNodeWeight'] = test_df['NodeWeight']
    nodes_df['DFSOrdering'] = -1
    nodes_df.loc[dfs_ordering, 'DFSOrdering'] = range(len(dfs_ordering))
    nodes_df.loc[:, 'DFSOrdering'] = nodes_df['DFSOrdering'].astype(int)
    write_df = nodes_df.reset_index().reindex(columns=['NodeID', 'ParentID', 'Alternative',
                                                  'NKids', 'Status', 'Label',
                                                  'NodeWeight', 'DomainNodeWeight', 'DFSOrdering'])
    write_df.to_sql('Nodes', engine, if_exists='replace', index=False)
    
plt.show()

AssertionError: 

In [71]:
info_df['ParentID'] = nodes_df['ParentID']
info_df['DomainSize'] = info_df['Info'].apply(get_domain_size)
info_df['DomainSize'] / info_df.groupby(['ParentID'])['DomainSize'].sum()

-1              NaN
 0       5.0724e+13
 1      1.61361e+09
 2          2391.86
 3          24475.5
           ...     
 776    2.61575e+10
 777            NaN
 778            NaN
 779            NaN
 780            NaN
Name: DomainSize, Length: 709, dtype: object

In [10]:
valid_df = pd.DataFrame.copy(nodes_df[~nodes_df['Status'].isin({3})])
valid_df.loc[0, 'NodeWeight'] = 1 # root node has weight 1

for node_id in range(valid_df.shape[0]):
    if valid_df[(valid_df['ParentID'] == node_id)].shape[0] == 0:
        continue
    
    kids = valid_df[valid_df['ParentID'] == node_id]
    domains = info_df.loc[kids.index, 'Info'].apply(get_domain_size)

    weights = (domains / domains.sum()).to_list()

    if abs(1 - sum(weights)) > 1e-6:
        print(weights)              

KeyError: "None of [Int64Index([2157, 2158], dtype='int64', name='NodeID')] are in the [index]"

In [None]:
invalid_status = [3]
# if no_good_no_domain:
#     invalid_status.append(1)
time_1 = 0
time_2 = 0

for node_id in range(nodes_df.shape[0]):

    kids = nodes_df[(nodes_df['ParentID'] == node_id)]
    non_restart_kids = kids[~kids['Status'].isin(invalid_status)]

    if non_restart_kids.shape[0] == 0:
#         print("Dead end node")
        pass
    else:
        start1 = time()
        weights1 = nodes_df[nodes_df['ParentID'] == node_id]['Status'].apply(lambda x: 1 / non_restart_kids.shape[0] if x not in invalid_status else 0)
        time_1 += time() - start1
        start2 = time()
        weights2 = (1 / non_restart_kids.shape[0]) * (~kids['Status'].isin(invalid_status))
        time_2 += time() - start2
        assert all(weights1 == weights2)
        assert abs(sum(weights1) - 1) < EPSILON
        assert abs(sum(weights2) - 1) < EPSILON

In [None]:
print("Time taken for apply method: ", time_1)
print("Time taken for series method: ", time_2)

nodes_df['
- total time: 20min
- make_dfs_ordering: 3m30s
- load info_df: 365ms
- load nodes_df: 404ms
- copy nodes_df: 7.96ms
- assign weights to nodes_df: 7min30s
- assign weights to test_df: 12min26s
- get_cum_weights for nodes_df: 156ms
- get_cum_weights for test_df: 125ms
