In [21]:
import pandas as pd
import json
import matplotlib.pyplot as plt
import os

from itertools import chain
from pathlib import *
from math import ceil
from sqlalchemy import create_engine

from helper import *
from tree_weight import *
from time import time
from main import make_graph_from_tree

from typing import List, Optional
%matplotlib inline

In [22]:
# Folder for generated figures (only referenced by the commented-out graph code below)
# and the SQLite database of the search tree under study.
image_folder = 'graphs/'
tree = 'benchmark_models/league/trees/model15-4-3.sqlite'

# Load the 'info' and 'nodes' tables, both keyed by NodeID so their rows line up.
# NOTE(review): to_df comes from a star import; presumably it reads one table from the
# SQLite file into a DataFrame — confirm.
info_df, nodes_df = (
    to_df(tree, table).set_index('NodeID')
    for table in ('info', 'nodes')
)
# # nodes_df['Label'].str.split(' ', expand=True)[1].unique()
# a = nodes_df[nodes_df['Label'] != '']['Label'].str.split(' ', expand=True)[0]
# # a
# (fig, ax), nodes_df, cum_sums = \
#     make_graph_from_tree(image_folder, tree, 
#     schemes=[
#         uniform_scheme, 
#         domain_scheme, 
#         searchSpace_scheme
#     ],
#     write_to_sqlite=False)

# if 'NodeWeight' in nodes_df.columns:
#     nodes_df = nodes_df.rename(columns={'NodeWeight': 'UniformNodeWeight'})
#     engine = create_engine('sqlite:///' + tree)
#     write_df = nodes_df.reset_index()
#     write_df.to_sql('Nodes', engine, if_exists='replace', index=False)

In [52]:
# Find the split variable: the variable whose domains in the children do not intersect.
# No-good domains are unreliable (one or more variables should have a null domain),
# but the split variable's domain should still be reliable (?).
# Some split labels do not appear in the domains and vice versa, because the domains
# output only certain names in FlatZinc, and some variables are never split on but are
# assigned by propagation.

def pairwise_diff(lst):
    """Return True if every element of ``lst`` is distinct (elements must be hashable)."""
    distinct = set(lst)
    return len(lst) == len(distinct)

def is_unequal_split(nodes_df, children_idx):
    """Return True when the children's labels use exactly the operators '=' and '!='.

    Each label is split on single spaces and its second token is taken as the
    operator, e.g. ``"x = 3"`` -> ``"="`` and ``"x != 3"`` -> ``"!="``.
    """
    tokens = nodes_df.loc[children_idx, 'Label'].str.split(' ', expand=True)
    operators = set(tokens[1].unique())
    return operators == {'=', '!='}

def has_no_split_children(nodes_df, par_idx):
    """Return True when no child of ``par_idx`` has Status 2.

    Status 2 is the status find_split_variable requires of a node it resolves,
    i.e. a node that was itself split on.
    """
    child_status = nodes_df.loc[nodes_df['ParentID'] == par_idx, 'Status']
    n_split_children = child_status.isin({2}).sum()
    return n_split_children == 0


def _candidates_single_child(child_label, child_domain, par_domain):
    """Candidates when only one child is usable because its siblings were skipped.

    ``child_label`` looks like ``"<var> = v"`` or ``"<var> != v"``. The split
    variable must be fixed to ``v`` in the child (``=``), or must have lost exactly
    ``v`` relative to the parent (``!=``).
    """
    split_val = int(child_label.split('=')[1])
    is_neq = '!' in child_label
    cands = []
    for variable in par_domain:
        if variable not in child_domain:
            # The child does not report this variable, so it cannot be matched.
            continue
        child_vals = child_domain[variable]
        par_vals = par_domain[variable]
        if is_neq:
            matched = split_val not in child_vals and child_vals.union({split_val}) == par_vals
        else:
            matched = child_vals == {split_val} and split_val in par_vals
        if matched:
            cands.append(variable)
    return cands


def _candidates_eq_split(split_vals, children_domain, par_domain):
    """Candidates when child ``i`` is labelled ``"<var> = split_vals[i]"``.

    The split variable must be fixed to ``split_vals[i]`` in child ``i``, and the
    children's values together must reproduce the parent's domain.
    """
    cands = []
    for variable in par_domain:
        if not all(variable in dom for dom in children_domain):
            continue
        child_vals = [dom[variable] for dom in children_domain]
        each_fixed = all(vals == {val} for vals, val in zip(child_vals, split_vals))
        if each_fixed and set.union(*child_vals) == par_domain[variable]:
            cands.append(variable)
    return cands


def _candidates_eq_neq_split(split_val, children_domain, par_domain):
    """Candidates for a binary ``"<var> = v"`` / ``"<var> != v"`` split.

    Either child may be the ``=`` branch. One child's values must be exactly
    ``{v}`` and the other's must exclude ``v``. Adding ``v`` back to the latter
    must give the parent's domain.
    """
    dom_a, dom_b = children_domain
    cands = []
    for variable in par_domain:
        if variable not in dom_a or variable not in dom_b:
            continue
        vals_a, vals_b = dom_a[variable], dom_b[variable]
        par_vals = par_domain[variable]
        a_is_eq = (vals_a == {split_val} and split_val not in vals_b
                   and vals_b.union({split_val}) == par_vals)
        b_is_eq = (vals_b == {split_val} and split_val not in vals_a
                   and vals_a.union({split_val}) == par_vals)
        if a_is_eq or b_is_eq:
            cands.append(variable)
    return cands


def find_split_variable(par_idx, nodes_df, info_df, mappings: Optional[dict] = None):
    """Identify which domain variable the node ``par_idx`` was split on.

    Branching labels on the children (e.g. ``"assign_to[1] = 3"`` and
    ``"assign_to[1] != 3"``) use different names from the domains stored in
    ``info_df`` (e.g. ``X_INTRODUCED_3_``). The split variable is therefore found
    by comparing each child's domain with the parent's domain.

    Parameters
    ----------
    par_idx :
        NodeID of the candidate parent node.
    nodes_df : pandas.DataFrame
        Nodes table indexed by NodeID with ``ParentID``, ``Status`` and ``Label``
        columns. Only nodes with Status 2 are resolved. Children with Status 3
        are treated as skipped.
    info_df : pandas.DataFrame
        Info table indexed by NodeID with an ``Info`` column. It is parsed with
        ``parse_info_string``, which is assumed to return a dict of variable
        name -> set of values (TODO confirm).
    mappings : dict, optional
        Cache of resolved names, stored in both directions (label name -> domain
        name and domain name -> label name). It is updated in place when exactly
        one candidate is found. A fresh dict is used when omitted.

    Returns
    -------
    list
        Candidate domain-variable names. The list is empty when ``par_idx`` is
        not a resolvable parent. A one-element list is an unambiguous match.
        When the label is already in ``mappings``, the cached name is returned
        without re-checking.
    """
    # A shared mutable default would leak mappings between unrelated calls.
    if mappings is None:
        mappings = {}

    children = nodes_df[nodes_df['ParentID'] == par_idx]
    children_idx = children.index
    is_skipped = (children['Status'] == 3).to_numpy()

    if nodes_df.loc[par_idx, 'Status'] != 2 or len(children_idx) == 0 or is_skipped.all():
        return []

    # The children are assumed to branch on the same variable, so the first label is used.
    label_var = nodes_df.loc[children_idx[0], 'Label'].split(' ')[0]
    if label_var in mappings:
        return [mappings[label_var]]

    par_domain = parse_info_string(info_df.loc[par_idx, 'Info'])

    if is_skipped.any():
        # Only one child is left to infer the split variable from.
        child_idx = children_idx[~is_skipped][0]
        cands = _candidates_single_child(
            nodes_df.loc[child_idx, 'Label'],
            parse_info_string(info_df.loc[child_idx, 'Info']),
            par_domain,
        )
    else:
        # Children of no-goods may carry unreliable domains; this is not handled yet.
        children_domain = [parse_info_string(info_df.loc[i, 'Info']) for i in children_idx]
        split_vals = (
            nodes_df.loc[children_idx, 'Label']
            .str.split('=', expand=True)[1]
            .astype(int)
            .unique()
        )
        if not is_unequal_split(nodes_df, children_idx):
            # split_vals is unique by construction, so this assertion also guarantees
            # that every child carries a different value.
            assert len(split_vals) == len(children_idx)
            cands = _candidates_eq_split(split_vals, children_domain, par_domain)
        else:
            assert len(children_domain) == 2
            assert len(split_vals) == 1
            cands = _candidates_eq_neq_split(split_vals[0], children_domain, par_domain)

    # Drop variables already mapped to another label, and variables already fixed in the parent.
    cands = [name for name in cands if name not in mappings and len(par_domain[name]) != 1]

    if len(cands) == 1:
        mappings[label_var] = cands[0]
        mappings[cands[0]] = label_var
    return cands

In [53]:
# Resolve, node by node, which domain variable each branching label refers to.
# find_split_variable records every unambiguous match in `mappings` in both
# directions (label -> domain name and domain name -> label).
mappings = {}
all_split_vars = nodes_df['Label'].str.split(' ', expand=True)[0].unique()
all_split_vars = set(all_split_vars) - {''}

# NodeID -> candidate list for nodes where more than one variable fits.
# These are recorded for inspection instead of stopping in the debugger mid-loop.
ambiguous_nodes = {}

# Iterate over the real NodeIDs in ascending order rather than assuming they are 0..n-1.
# Nodes that are not split parents simply return [].
for node_idx in nodes_df.index.sort_values():
    cands = find_split_variable(node_idx, nodes_df, info_df, mappings)
    if len(cands) > 1:
        ambiguous_nodes[node_idx] = cands

    # Stop once every label seen in the tree has been mapped (each pair is stored twice).
    if len(mappings) // 2 == len(all_split_vars):
        break

In [54]:
# Resolved names, stored in both directions (branching label <-> domain variable).
mappings

{'assign_to[1]': 'X_INTRODUCED_3_',
 'X_INTRODUCED_3_': 'assign_to[1]',
 'assign_to[2]': 'X_INTRODUCED_4_',
 'X_INTRODUCED_4_': 'assign_to[2]',
 'assign_to[3]': 'X_INTRODUCED_5_',
 'X_INTRODUCED_5_': 'assign_to[3]',
 'assign_to[4]': 'X_INTRODUCED_6_',
 'X_INTRODUCED_6_': 'assign_to[4]',
 'assign_to[5]': 'X_INTRODUCED_7_',
 'X_INTRODUCED_7_': 'assign_to[5]',
 'assign_to[6]': 'X_INTRODUCED_8_',
 'X_INTRODUCED_8_': 'assign_to[6]',
 'assign_to[7]': 'X_INTRODUCED_9_',
 'X_INTRODUCED_9_': 'assign_to[7]',
 'assign_to[8]': 'X_INTRODUCED_10_',
 'X_INTRODUCED_10_': 'assign_to[8]',
 'assign_to[9]': 'X_INTRODUCED_11_',
 'X_INTRODUCED_11_': 'assign_to[9]',
 'assign_to[12]': 'X_INTRODUCED_14_',
 'X_INTRODUCED_14_': 'assign_to[12]',
 'assign_to[13]': 'X_INTRODUCED_15_',
 'X_INTRODUCED_15_': 'assign_to[13]',
 'assign_to[14]': 'X_INTRODUCED_16_',
 'X_INTRODUCED_16_': 'assign_to[14]',
 'assign_to[15]': 'X_INTRODUCED_17_',
 'X_INTRODUCED_17_': 'assign_to[15]'}

In [55]:
# Distinct variable names appearing in branching labels (the first token of each Label).
# Names absent from `mappings` were not resolved by the loop above.
all_split_vars

{'assign_to[10]',
 'assign_to[12]',
 'assign_to[13]',
 'assign_to[14]',
 'assign_to[15]',
 'assign_to[1]',
 'assign_to[2]',
 'assign_to[3]',
 'assign_to[4]',
 'assign_to[5]',
 'assign_to[6]',
 'assign_to[7]',
 'assign_to[8]',
 'assign_to[9]',
 'rank_diff[1]',
 'rank_diff[2]'}

In [40]:
# Hard-coded raw 'Info' text of one node, used to look up which NodeID it belongs to.
x = '''{
	"domains": "var 49992..89992: obj;
var int: X_INTRODUCED_3_ = 3;
var int: X_INTRODUCED_4_ = 3;
var int: X_INTRODUCED_5_ = 3;
var int: X_INTRODUCED_6_ = 3;
var int: X_INTRODUCED_7_ = 3;
var int: X_INTRODUCED_8_ = 2;
var int: X_INTRODUCED_9_ = 2;
var int: X_INTRODUCED_10_ = 2;
var int: X_INTRODUCED_11_ = 2;
var int: X_INTRODUCED_12_ = 2;
var int: X_INTRODUCED_13_ = 1;
var int: X_INTRODUCED_14_ = 1;
var int: X_INTRODUCED_15_ = 1;
var int: X_INTRODUCED_16_ = 1;
var int: X_INTRODUCED_17_ = 1;
var int: X_INTRODUCED_39_ = 2;
var int: X_INTRODUCED_40_ = 3;
var int: X_INTRODUCED_41_ = 3;
var int: X_INTRODUCED_18_ = 4;
var int: X_INTRODUCED_19_ = 4;
var int: X_INTRODUCED_20_ = 4;
var 1..3: X_INTRODUCED_21_;
var 1..3: X_INTRODUCED_22_;
var int: X_INTRODUCED_23_ = 1;
"
}'''
# NodeIDs whose stored Info string equals `x` exactly (whitespace included).
info_df[info_df['Info'] == x].index

Int64Index([19], dtype='int64', name='NodeID')

In [None]:
# Show how the labels of nodes 1 and 2 split on '=', which is the split
# find_split_variable uses to extract the branching value (column 1).
nodes_df.loc[[1, 2], 'Label'].str.split('=', expand=True)

In [None]:
# Labels of the children of `node_idx`, split on spaces into (variable, operator, value) columns.
# NOTE(review): `node_idx` is whatever the mapping loop above left behind, so this output
# depends on how far that loop ran.
(
    nodes_df
    .loc[nodes_df['ParentID'] == node_idx, 'Label']
    .str.split(' ', expand=True)
)

In [None]:
# Number of distinct sizes (entries returned by parse_info_string) across all nodes' Info.
info_df['Info'].map(lambda info: len(parse_info_string(info))).nunique()

In [None]:
# NOTES

# Some children's domains are larger than their parent's domain; these are dead ends whose domains were never updated to null.

In [None]:
# Put the structural columns first, followed by every remaining column in its existing order.
# (A plain `list + nodes_df.columns` would add strings element-wise, not concatenate lists.)
leading_cols = ['NodeID', 'ParentID', 'Alternative', 'NKids', 'Status', 'Label']
nodes_df.reset_index().reindex(
    columns=leading_cols + [col for col in nodes_df.columns if col not in leading_cols]
)

In [None]:
# Write the in-memory nodes table back to the tree database.
# NOTE: if_exists='replace' overwrites the existing 'Nodes' table on disk.
engine = create_engine('sqlite:///' + tree)
try:
    # errors='ignore' keeps the cell re-runnable once NodeWeight is no longer present.
    write_df = nodes_df.reset_index().drop(columns='NodeWeight', errors='ignore')
    write_df.to_sql('Nodes', engine, if_exists='replace', index=False)
finally:
    # Release pooled connections so the SQLite file is not left open.
    engine.dispose()

In [None]:
# NOTE(review): time_1 and time_2 are not defined in any visible cell of this notebook;
# presumably they came from a benchmarking cell that was removed, so this cell raises
# NameError on a fresh kernel.
print("Time taken for apply method: ", time_1)
print("Time taken for series method: ", time_2)
- total time: 20min
- make_dfs_ordering: 3m30s
- load info_df: 365ms
- load nodes_df: 404ms
- copy nodes_df: 7.96ms
- assign weights to nodes_df: 7min30s
- assign weights to test_df: 12min26s
- get_cum_weights for nodes_df: 156ms
- get_cum_weights for test_df: 125ms
