Skip to content

Commit

Permalink
bugfix and testing for rete net with pattern helpers
Browse files Browse the repository at this point in the history
  • Loading branch information
johnsekar committed Oct 7, 2021
1 parent 01bcf05 commit da4535f
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 22 deletions.
73 changes: 73 additions & 0 deletions tests/test_rete_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,3 +241,76 @@ def test_pattern_with_attr_updates_computefn(self):
self.assertEqual(len(collector.state.cache),6)


def test_pattern_with_helper(self):
net = ReteNet.default_initialization()
ss = SimulationState(matcher=net)

p1 = Pattern(parent=GraphContainer([Y('n',a=True)]))
z = Z('z',y=[Y('y')])
p2 = Pattern(
parent=GraphContainer(z.get_connected()),
helpers = {'ypos':p1},
constraints = ['ypos.contains(n=y) == True']
)

net.initialize_pattern(p2)
# check retenet structure
self.assertTrue(net.get_channel(source=p1,target=p2,type='update'))


net.initialize_collector(p1,'p1')
net.initialize_collector(p2,'p2')
node1, node2, collector1, collector2 = [net.get_node(core=x) for x in [p1,p2,'collector_p1','collector_p2']]
rete_nodes = [node1, collector1, node2, collector2]

# First push two nodes Z, Y1(true),
# p1 must be non-empty, p2 must be empty
ss.push_to_stack([
AddNode.make(Z,'z1'),
AddNode.make(Y,'y1',{'a':True}),
])
ss.simulate()

for x,y in zip(rete_nodes,[1,1,0,0]):
self.assertEqual(len(x.state.cache),y)

# now add edge y1-z1
# both p1 and p2 must be non-empty
ss.push_to_stack([
AddEdge('z1','y','y1','z'),
])
ss.simulate()

for x,y in zip(rete_nodes,[1,1,1,1]):
self.assertEqual(len(x.state.cache),y)

# now set y1(false)
# p1 and p2 must be empty,
# collectors must have 1 add and 1 remove entry each
ss.push_to_stack([
SetAttr('y1','a',False,True)
])
ss.simulate()

for x,y in zip(rete_nodes,[0,2,0,2]):
self.assertEqual(len(x.state.cache),y)

# set y1(true) and p1 and p2 must go back to 1 each
# collectors increment by 1
ss.push_to_stack([
SetAttr('y1','a',True,False)
])
ss.simulate()

for x,y in zip(rete_nodes,[1,3,1,3]):
self.assertEqual(len(x.state.cache),y)

# remove edge and p2 must go back to zero
# collector_p2 increment by 1
ss.push_to_stack([
RemoveEdge('y1','z','z1','y'),
])
ss.simulate()

for x,y in zip(rete_nodes,[1,3,0,4]):
self.assertEqual(len(x.state.cache),y)
45 changes: 26 additions & 19 deletions wc_rules/matcher/functionalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def run_match_through_constraints(match,helpers,constraints):
extended_match = ChainMap(match,helpers)
terminate_early = False
for c in constraints:
value = c.exec(match)
value = c.exec(extended_match)
if c.deps.declared_variable is not None:
extended_match[c.deps.declared_variable] = value
elif not value:
Expand All @@ -58,18 +58,22 @@ def function_node_pattern(net,node,elem):
net.remove_from_cache(node.core,entry)
for e in elems:
node.state.outgoing.append({'entry':e,'action':action})

if action == 'UpdateEntry':
# update entry is resolved to AddEntry or RemoveEntry and inserted back into incoming queue
# ReteNet.sync(node) runs until incoming is empty

match = {k:v for k,v in entry.items()}
existing_elems = net.filter_cache(node.core,entry)
match = run_match_through_constraints(match,node.data.helpers,node.data.constraints)
if match is None:
if len(existing_elems) > 0:
for e in existing_elems:
node.state.incoming.append({'entry':e,'action':'RemoveEntry'})
elif match is not None:
if len(existing_elems) == 0:
node.state.incoming.append({'entry':match,'action':'AddEntry'})

if match is None and len(existing_elems)>0:
for e in existing_elems:
node.state.incoming.append({'entry':e,'action':'RemoveEntry'})
if match is not None and len(existing_elems)==0:
node.state.incoming.append({'entry':match,'action':'AddEntry'})
return net

def function_channel_pass(net,channel,elem):
Expand Down Expand Up @@ -134,23 +138,26 @@ def function_channel_parent(net,channel,elem):
def function_channel_update(net,channel,elem):
action = elem['action']

if action in ['AddEdge','RemoveEdge','SetAttr']:
if action in ['AddEdge','RemoveEdge','SetAttr','AddEntry','RemoveEntry']:
# it asks for parent of the target
# then filters the parent cache to get candidate entries for target
# then asks target to do 'UpdateEntry' on all candidates
attr = elem['attr'] if action == 'SetAttr' else elem['attr1']

if attr == channel.data.get('attr',None):
if action in ['AddEdge','RemoveEdge']:
attr = elem['attr1']
elif action == 'SetAttr':
attr = elem['attr']
else:
attr = None

node = net.get_node(core=channel.target)
entry = channel.data.mapping.transform(elem)
parent_channel = net.get_channel(target=channel.target,type='parent')
parent = parent_channel.source
parent_mapping = parent_channel.data.mapping
filterelem = parent_mapping.reverse().transform(entry)
elems = [parent_mapping.transform(x) for x in net.filter_cache(parent,filterelem)]
for e in elems:
node.state.incoming.append({'entry':e,'action':'UpdateEntry','attr':attr})
node = net.get_node(core=channel.target)
entry = channel.data.mapping.transform(elem)
parent_channel = net.get_channel(target=channel.target,type='parent')
parent = parent_channel.source
parent_mapping = parent_channel.data.mapping
filterelem = parent_mapping.reverse().transform(entry)
elems = [parent_mapping.transform(x) for x in net.filter_cache(parent,filterelem)]
for e in elems:
node.state.incoming.append({'entry':e,'action':'UpdateEntry','attr':attr})

return net

Expand Down
7 changes: 4 additions & 3 deletions wc_rules/matcher/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,18 +138,19 @@ def initialize_pattern(net,pattern):
constraint_attr_relationships.add((x,var,attr,))


helper_channels = set([(net.get_node(core=pattern),mapping) for _,pattern,mapping in constraint_pattern_relationships])
helper_channels = set([(pname,mapping) for _,pname,mapping in constraint_pattern_relationships])
attr_channels = set([(var,attr) for _,var,attr in constraint_attr_relationships])

# print(constraint_pattern_relationships)
# print(constraint_attr_relationships)
net.add_node(type='pattern',core=pattern,symmetry_group=symmetry_group,exprgraph=graph,helpers=helpers,constraints=constraint_objects)
resolved_helpers = {h:net.get_node(core=helpers[h]).state for h in helpers}
net.add_node(type='pattern',core=pattern,symmetry_group=symmetry_group,exprgraph=graph,helpers=resolved_helpers,constraints=constraint_objects)
names = [x for x in pattern.namespace if isinstance(pattern.namespace[x],type) and issubclass(pattern.namespace[x],BaseClass)]
net.initialize_cache(pattern,names)
net.add_channel(type='parent',source=pdict.parent,target=pattern,mapping=pdict.mapping)

for pname, mapping in helper_channels:
net.add_channel(type='update',source=pname,target=pattern,mapping=mapping)
net.add_channel(type='update',source=helpers[pname],target=pattern,mapping=mapping)

for var,attr in attr_channels:
varclass = pattern.namespace[var]
Expand Down

0 comments on commit da4535f

Please sign in to comment.