From 7f52b3b0812606356dc33270eaf90956414d2f0d Mon Sep 17 00:00:00 2001
From: John Sekar
Date: Thu, 30 Sep 2021 02:49:23 -0400
Subject: [PATCH] building rete net: merging behavior

---
 tests/test_rete_1.py              | 49 ++++++++++++++++++++++-
 wc_rules/matcher/core.py          |  4 ++
 wc_rules/matcher/dbase.py         |  8 +++-
 wc_rules/matcher/functionalize.py | 65 +++++++++++++++++++++++++------
 wc_rules/utils/collections.py     |  6 +++
 5 files changed, 117 insertions(+), 15 deletions(-)

diff --git a/tests/test_rete_1.py b/tests/test_rete_1.py
index f060fa3..f2a5fca 100644
--- a/tests/test_rete_1.py
+++ b/tests/test_rete_1.py
@@ -85,8 +85,8 @@ def test_single_edge_canonical_label(self):
 		g = GraphContainer(Y('y',z=Z('z')).get_connected())
 		m,L,G = canonical_label(g)
 		net.initialize_canonical_label(L,G)
-		net.initialize_collector(L,'XYEdge')
-		gnode, collector = net.get_node(core=L), net.get_node(core='collector_XYEdge')
+		net.initialize_collector(L,'YZEdge')
+		gnode, collector = net.get_node(core=L), net.get_node(core='collector_YZEdge')
 
 		# command to push nodes y1 z1
@@ -121,7 +121,52 @@
 		# (one for add, one for remove)
 		self.assertEqual(len(gnode.state.cache),0)
 		self.assertEqual(len(collector.state.cache),2)
+
+	def test_two_edges_canonical_label(self):
+		net = ReteNet.default_initialization()
+		ss = SimulationState(matcher=net)
+
+		g = GraphContainer(Z('z',y=[Y('y1'),Y('y2')]).get_connected())
+		m,L,G = canonical_label(g)
+		net.initialize_canonical_label(L,G)
+		net.initialize_collector(L,'ZYYGraph')
+		gnode, collector = net.get_node(core=L), net.get_node(core='collector_ZYYGraph')
+
+		# command to push nodes y1,y2,z1, and edge y1-z1
+		ss.push_to_stack([
+			AddNode.make(Y,'y1'),
+			AddNode.make(Y,'y2'),
+			AddNode.make(Z,'z1'),
+			AddEdge('y1','z','z1','y'),
+			]
+		)
+		ss.simulate()
+		# both the rete node for the graph as well
+		# as the downstream collector
+		# should have zero cache entries
+		self.assertEqual(len(gnode.state.cache),0)
+		self.assertEqual(len(collector.state.cache),0)
+
+		# add edge y2-z1
+		ss.push_to_stack(AddEdge('y2','z','z1','y'))
+		ss.simulate()
+
+		# rete node for graph should have two entries
+		# (z1-y1-y2) and (z1-y2-y1)
+		# and collector should have two entries as well
+		self.assertEqual(len(gnode.state.cache),2)
+		self.assertEqual(len(collector.state.cache),2)
+
+		# remove edge y1-z1
+		ss.push_to_stack(RemoveEdge('y1','z','z1','y'))
+		ss.simulate()
+
+		# rete node for graph should have no entries
+		# but collector should have four entries
+		# (two for add, two for remove)
+		self.assertEqual(len(gnode.state.cache),0)
+		self.assertEqual(len(collector.state.cache),4)
diff --git a/wc_rules/matcher/core.py b/wc_rules/matcher/core.py
index 6f6fa0e..178bfde 100644
--- a/wc_rules/matcher/core.py
+++ b/wc_rules/matcher/core.py
@@ -86,6 +86,10 @@ def get_channel(self,**kwargs):
 	def get_outgoing_channels(self,source):
 		return [AttrDict(ch) for ch in Record.retrieve(self.channels,{'source':source})]
 
+	def get_channels(self,include_kwargs,exclude_kwargs=None):
+		includes = Record.retrieve_minus(self.channels,include_kwargs,exclude_kwargs) if exclude_kwargs is not None else Record.retrieve(self.channels,include_kwargs)
+		return [AttrDict(ch) for ch in includes]
+
 	def pprint(self,state=False):
 		s = []
diff --git a/wc_rules/matcher/dbase.py b/wc_rules/matcher/dbase.py
index 7fa3f27..24685e9 100644
--- a/wc_rules/matcher/dbase.py
+++ b/wc_rules/matcher/dbase.py
@@ -37,4 +37,10 @@ def insert(dbase,record):
 	@staticmethod
 	def remove(dbase,record):
 		dbase.delete(dbase(**record))
-		return
\ No newline at end of file
+		return
+
+	@staticmethod
+	def retrieve_minus(dbase,include_kwargs,exclude_kwargs):
+		exclude_ids = [x['__id__'] for x in dbase(**exclude_kwargs)]
+		includes = [dict(Record.itemize(x)) for x in dbase(**include_kwargs) if x['__id__'] not in exclude_ids]
+		return includes
diff --git a/wc_rules/matcher/functionalize.py b/wc_rules/matcher/functionalize.py
index 0701eaf..1c2532f 100644
--- a/wc_rules/matcher/functionalize.py
+++ b/wc_rules/matcher/functionalize.py
@@ -1,5 +1,6 @@
 from frozendict import frozendict
-
+from collections import deque,Counter
+from ..utils.collections import merge_lists,triple_split, subdict, merge_dicts, no_overlaps
 
 def function_node_start(net,node,elem):
 	node.state.outgoing.append(elem)
@@ -17,15 +18,15 @@ def function_node_collector(net,node,elem):
 def function_node_canonical_label(net,node,elem):
 	clabel = node.core
 	entry, action = [elem[x] for x in ['entry','action']]
-	if len(clabel.names)<=2:
-		# its a single node graph
-		if elem['action'] == 'AddEntry':
-			assert net.filter_cache(clabel,entry) == []
-			net.insert_into_cache(clabel,entry)
-		if elem['action'] == 'RemoveEntry':
-			assert net.filter_cache(clabel,entry) != []
-			net.remove_from_cache(clabel,entry)
-		node.state.outgoing.append({'entry':entry,'action':action})
+
+	if elem['action'] == 'AddEntry':
+		assert net.filter_cache(clabel,entry) == []
+		net.insert_into_cache(clabel,entry)
+	if elem['action'] == 'RemoveEntry':
+		assert net.filter_cache(clabel,entry) != []
+		net.remove_from_cache(clabel,entry)
+	node.state.outgoing.append({'entry':entry,'action':action})
+
 	return net
 
 def function_channel_pass(net,channel,elem):
@@ -52,10 +53,50 @@ def function_channel_transform_edge_token(net,channel,elem):
 	return net
 
 def function_channel_merge(net,channel,elem):
-	if elem['action'] in ['AddEntry','RemoveEntry']:
-		pass
+	entry, action = channel.data.mapping.transform(elem['entry']), elem['action']
+	if action not in ['AddEntry','RemoveEntry']:
+		return net
+	if action == 'AddEntry':
+		# grow the incoming partial match by merging against the caches behind the other merge channels
+		entries = [entry]
+		channels = [x for x in net.get_channels({'source':channel.source,'type':'merge'}) if x!=channel]
+		while channels and entries:
+			channels = deque(sort_channels(channels,entries[0].keys()))
+			ch = channels.popleft()
+			entries = merge_lists([merge_from(net,e,ch.data.mapping,ch.source) for e in entries])
+	if action == 'RemoveEntry':
+		# retract every cached match at the target that is consistent with the removed entry
+		entries = net.filter_cache(channel.target,entry)
+
+	node = net.get_node(core=channel.target)
+	for e in entries:
+		node.state.incoming.append({'entry':e,'action':action})
 	return net
 
+def sort_channels(channels,variables):
+	# prefer channels sharing the most variables with the current entries; break ties by fewest target variables
+	def shared(ch):
+		return len(set(variables).intersection(ch.data.mapping.targets))
+	def total(ch):
+		return len(ch.data.mapping.targets)
+	return sorted(channels,key = lambda ch: (-shared(ch),total(ch),) )
+
+def merge_from(net,entry,mapping,source):
+	# filter condition on the cache:
+	# keys shared between entry and mapping.targets must have the same values
+	# reject condition on any candidate merge:
+	# keys unique to entry and keys unique to mapping.targets must map to disjoint values
+	merges = []
+	L, M, R = triple_split(entry.keys(),mapping.targets)
+	Ld, Md = [subdict(entry,x) for x in [L,M]]
+	for x in net.filter_cache(source,mapping.reverse().transform(Md)):
+		Rd = subdict(mapping.transform(x),R)
+		if no_overlaps([Ld.values(),Rd.values()]):
+			merges.append(merge_dicts([Ld,Md,Rd]))
+	return merges
+
+def print_channel(channel,tab=''):
+	L = ['Channel:',hash(channel.source)%100, hash(channel.target)%100, channel.data.mapping]
+	return tab + ' '.join([str(x) for x in L])
+
 default_functionalization_methods = [method for name,method in globals().items() if name.startswith('function_')]
\ No newline at end of file
diff --git a/wc_rules/utils/collections.py b/wc_rules/utils/collections.py
index 44e44a9..0bb8625 100644
--- a/wc_rules/utils/collections.py
+++ b/wc_rules/utils/collections.py
@@ -238,6 +238,12 @@ def subdict(d,keys,ignore=False):
 		keys = [x for x in keys if x in d]
 	return {k:d[k] for k in keys}
 
+def triple_split(iter1,iter2):
+	# iter1, iter2 are iterables
+	# returns [iter1 - iter2, iter1 & iter2, iter2 - iter1] as lists
+	s1, s2 = set(iter1), set(iter2)
+	return [list(x) for x in [s1 - s2, s1 & s2, s2 - s1]]
+
 def strgen(n,template='abcdefgh'):
 	if n< len(template):
 		return template[:n]
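
For reference, a minimal standalone sketch of the merge step that merge_from implements above: triple_split and subdict are re-implemented inline, mirroring the definitions added to wc_rules/utils/collections.py; a plain disjointness check stands in for no_overlaps and dict unpacking stands in for merge_dicts; the entry/candidate dicts and the variable names y_a, y_b, z are hypothetical, modeled on the y1/y2/z1 instances pushed in test_two_edges_canonical_label. The real merge_from additionally routes keys and values through the channel mapping (mapping.transform, mapping.reverse()) and pulls candidates from the source node's cache via net.filter_cache.

def triple_split(iter1, iter2):
    # keys only in iter1, shared keys, keys only in iter2
    s1, s2 = set(iter1), set(iter2)
    return [list(x) for x in [s1 - s2, s1 & s2, s2 - s1]]

def subdict(d, keys):
    # restrict a dict to the given keys
    return {k: d[k] for k in keys}

entry     = {'z': 'z1', 'y_a': 'y1'}   # partial match arriving at the target graph node
candidate = {'z': 'z1', 'y_b': 'y2'}   # entry pulled from the sibling source's cache

L, M, R = triple_split(entry.keys(), candidate.keys())
Ld, Md, Rd = subdict(entry, L), subdict(entry, M), subdict(candidate, R)

# accept the merge only if the non-shared variables bind distinct nodes
# (the analogue of no_overlaps([Ld.values(), Rd.values()]))
if set(Ld.values()).isdisjoint(Rd.values()):
    merged = {**Ld, **Md, **Rd}
    print(merged)   # {'y_a': 'y1', 'z': 'z1', 'y_b': 'y2'}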