Skip to content

Commit

Permalink
building rete net.. merging behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
johnsekar committed Sep 30, 2021
1 parent 4fd376a commit 7f52b3b
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 15 deletions.
49 changes: 47 additions & 2 deletions tests/test_rete_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ def test_single_edge_canonical_label(self):
g = GraphContainer(Y('y',z=Z('z')).get_connected())
m,L,G = canonical_label(g)
net.initialize_canonical_label(L,G)
net.initialize_collector(L,'XYEdge')
gnode, collector = net.get_node(core=L), net.get_node(core='collector_XYEdge')
net.initialize_collector(L,'YZEdge')
gnode, collector = net.get_node(core=L), net.get_node(core='collector_YZEdge')


# command to push nodes y1 z1
Expand Down Expand Up @@ -121,7 +121,52 @@ def test_single_edge_canonical_label(self):
# (one for add, one for remove)
self.assertEqual(len(gnode.state.cache),0)
self.assertEqual(len(collector.state.cache),2)

def test_two_edges_canonical_label(self):
	"""Exercise a two-edge pattern (one Z bonded to two Ys) through the rete net.

	Pushes nodes and edges, then checks the cache sizes of the graph's rete
	node and its downstream collector after each simulation step.
	"""
	net = ReteNet.default_initialization()
	ss = SimulationState(matcher=net)

	g = GraphContainer(Z('z',y=[Y('y1'),Y('y2')]).get_connected())
	m,L,G = canonical_label(g)
	net.initialize_canonical_label(L,G)
	net.initialize_collector(L,'ZYYGraph')
	gnode, collector = net.get_node(core=L), net.get_node(core='collector_ZYYGraph')

	# command to push nodes y1,y2,z1, and edge y1-z1
	ss.push_to_stack([
		AddNode.make(Y,'y1'),
		AddNode.make(Y,'y2'),
		AddNode.make(Z,'z1'),
		AddEdge('y1','z','z1','y'),
	])
	ss.simulate()

	# only one edge exists, so the two-edge pattern cannot match yet:
	# both the rete node for the graph as well as the downstream collector
	# should have zero cache entries
	self.assertEqual(len(gnode.state.cache),0)
	self.assertEqual(len(collector.state.cache),0)

	# add edge y2-z1, completing the pattern
	ss.push_to_stack(AddEdge('y2','z','z1','y'))
	ss.simulate()

	# rete node for graph should have two entries -- the two symmetric
	# matches (z1-y1-y2) and (z1-y2-y1) --
	# and collector should have two entries as well
	self.assertEqual(len(gnode.state.cache),2)
	self.assertEqual(len(collector.state.cache),2)

	# remove edge y1-z1, which invalidates both matches
	ss.push_to_stack(RemoveEdge('y1','z','z1','y'))
	ss.simulate()

	# rete node for graph should have no entries
	# but collector should have four entries
	# (two for add, two for remove)
	self.assertEqual(len(gnode.state.cache),0)
	self.assertEqual(len(collector.state.cache),4)



4 changes: 4 additions & 0 deletions wc_rules/matcher/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ def get_channel(self,**kwargs):
def get_outgoing_channels(self,source):
return [AttrDict(ch) for ch in Record.retrieve(self.channels,{'source':source})]

def get_channels(self,include_kwargs,exclude_kwargs=None):
	"""Return channels matching ``include_kwargs`` as AttrDicts.

	When ``exclude_kwargs`` is given, channels matching it are filtered out.
	"""
	if exclude_kwargs is not None:
		matches = Record.retrieve_minus(self.channels,include_kwargs,exclude_kwargs)
	else:
		matches = Record.retrieve(self.channels,include_kwargs)
	return [AttrDict(match) for match in matches]

def pprint(self,state=False):
s = []

Expand Down
8 changes: 7 additions & 1 deletion wc_rules/matcher/dbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,10 @@ def insert(dbase,record):
@staticmethod
def remove(dbase,record):
dbase.delete(dbase(**record))
return
return

@staticmethod
def retrieve_minus(dbase,include_kwargs,exclude_kwargs):
	"""Retrieve records matching ``include_kwargs`` minus those matching ``exclude_kwargs``.

	Records are compared by their ``__id__`` field; matches are returned as
	plain dicts (via ``Record.itemize``).
	"""
	# use a set for O(1) membership tests instead of scanning a list per record
	exclude_ids = {x['__id__'] for x in dbase(**exclude_kwargs)}
	includes = [dict(Record.itemize(x)) for x in dbase(**include_kwargs) if x['__id__'] not in exclude_ids]
	return includes
65 changes: 53 additions & 12 deletions wc_rules/matcher/functionalize.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from frozendict import frozendict

from collections import deque,Counter
from ..utils.collections import merge_lists,triple_split, subdict, merge_dicts, no_overlaps

def function_node_start(net,node,elem):
node.state.outgoing.append(elem)
Expand All @@ -17,15 +18,15 @@ def function_node_collector(net,node,elem):
def function_node_canonical_label(net,node,elem):
	"""Process an Add/RemoveEntry token at a canonical-label rete node.

	Updates the net's cache for this node's canonical label and forwards
	the token downstream via ``node.state.outgoing``.
	"""
	clabel = node.core
	entry, action = [elem[x] for x in ['entry','action']]
	if action == 'AddEntry':
		# an entry must not already be cached when it is added
		assert net.filter_cache(clabel,entry) == []
		net.insert_into_cache(clabel,entry)
	if action == 'RemoveEntry':
		# an entry must be present in the cache before it can be removed
		assert net.filter_cache(clabel,entry) != []
		net.remove_from_cache(clabel,entry)
	node.state.outgoing.append({'entry':entry,'action':action})
	return net

def function_channel_pass(net,channel,elem):
Expand All @@ -52,10 +53,50 @@ def function_channel_transform_edge_token(net,channel,elem):
return net

def function_channel_merge(net,channel,elem):
	"""Propagate an Add/RemoveEntry token across a merge channel.

	For an AddEntry, the incoming entry is mapped into the target's variable
	space and extended into full entries by merging against the caches of all
	sibling merge channels feeding the same target. For a RemoveEntry, the
	affected entries are looked up in the target's cache. Either way, the
	resulting entries are queued on the target node's incoming buffer.
	"""
	entry, action = channel.data.mapping.transform(elem['entry']), elem['action']
	if action not in ['AddEntry','RemoveEntry']:
		return net
	if action == 'AddEntry':
		entries = [entry]
		# sibling merge channels sharing this channel's source, excluding itself
		channels = [x for x in net.get_channels({'source':channel.source,'type':'merge'}) if x!=channel]
		while channels:
			# greedily pick the channel sharing the most variables with what
			# we have so far (see sort_channels), then merge against its cache
			channels = deque(sort_channels(channels,entries[0].keys()))
			ch = channels.popleft()
			# NOTE(review): if a merge step yields no entries, entries[0] on the
			# next iteration would raise IndexError -- presumably callers
			# guarantee at least one merge candidate; verify upstream.
			entries = merge_lists([merge_from(net,e,ch.data.mapping,ch.source) for e in entries])
	if action == 'RemoveEntry':
		entries = net.filter_cache(channel.target,entry)
	node = net.get_node(core=channel.target)
	for e in entries:
		node.state.incoming.append({'entry':e,'action':action})
	return net

def sort_channels(channels,variables):
	"""Order channels to maximize variable sharing.

	Channels sharing the most target variables with ``variables`` come first;
	ties are broken by preferring channels with fewer total target variables
	(i.e. minimizing the number of new variables introduced).
	"""
	wanted = set(variables)
	def sort_key(ch):
		targets = ch.data.mapping.targets
		return (-len(wanted.intersection(targets)), len(targets))
	return sorted(channels, key=sort_key)

def merge_from(net,entry,mapping,source):
	# Extend a partial match `entry` by merging it with compatible cached
	# entries from `source`, mapped through `mapping`; returns a list of
	# merged dicts.
	# the filter condition on the cache is:
	# keys shared between entry and mapping.targets must have the same values
	# the reject condition on any candidate for merging:
	# keys unique to entry and mapping.targets must have unique values
	merges = []
	# L: keys only in entry, M: keys shared with mapping.targets,
	# R: keys only in mapping.targets
	L, M, R = triple_split(entry.keys(),mapping.targets)
	Ld, Md = [subdict(entry,x) for x in [L,M]]
	# query source's cache on the shared keys, mapped back into the
	# source's own variable space
	for x in net.filter_cache(source,mapping.reverse().transform(Md)):
		Rd = subdict(mapping.transform(x),R)
		# reject candidates whose values on their unique keys collide with
		# the entry's values on its unique keys
		if no_overlaps([Ld.values(),Rd.values()]):
			merges.append(merge_dicts([Ld,Md,Rd]))
	return merges

def print_channel(channel,tab=''):
	"""Return a one-line debug string for a channel, prefixed by ``tab``.

	Source and target are abbreviated to ``hash(...) % 100`` for readability.
	"""
	parts = ['Channel:', hash(channel.source)%100, hash(channel.target)%100, channel.data.mapping]
	return tab + ' '.join(str(part) for part in parts)



default_functionalization_methods = [method for name,method in globals().items() if name.startswith('function_')]
6 changes: 6 additions & 0 deletions wc_rules/utils/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,12 @@ def subdict(d,keys,ignore=False):
keys = [x for x in keys if x in d]
return {k:d[k] for k in keys}

def triple_split(iter1,iter2):
	"""Partition two iterables into three lists.

	Returns ``[only-in-iter1, in-both, only-in-iter2]``, computed via set
	difference/intersection (so duplicates collapse and order is unspecified).
	"""
	s1 = set(iter1)
	s2 = set(iter2)
	return [list(s1 - s2), list(s1 & s2), list(s2 - s1)]

def strgen(n,template='abcdefgh'):
if n< len(template):
return template[:n]
Expand Down

0 comments on commit 7f52b3b

Please sign in to comment.