Skip to content

Commit

Permalink
building rete net... Add/Remove/Update Entry for pattern nodes
Browse files Browse the repository at this point in the history
TODO: mechanism of handling requests to helpers
  • Loading branch information
johnsekar committed Oct 6, 2021
1 parent 0126b52 commit 95bfc18
Show file tree
Hide file tree
Showing 8 changed files with 98 additions and 20 deletions.
2 changes: 1 addition & 1 deletion tests/test_rete_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_single_node_canonical_label(self):
# gnode collects true matches X('')
# collector collects all tokens coming out of gnode
self.assertEqual(len(gnode.state.cache),1)
self.assertEqual(gnode.state.cache[0]['a'],'x1')
self.assertEqual(gnode.state.cache[0]['a'].id,'x1')
self.assertEqual(len(collector.state.cache),1)

# Command to remove an instance X('x1')
Expand Down
4 changes: 2 additions & 2 deletions tests/test_rete_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def test_pattern_alias(self):
self.assertEqual(len(x.state.cache),math.factorial(n))

# check if filtering through alias works
fil = net.filter_cache(p,{'z1':'z1'})
fil = net.filter_cache(p,{'z1':ss.resolve('z1')})
self.assertEqual(len(fil),math.factorial(n))

# # # remove z1-y1 edge
Expand Down Expand Up @@ -113,5 +113,5 @@ def test_double_alias_pattern(self):


self.assertEqual(len(collector.state.cache),math.factorial(n))
self.assertEqual(len(net.filter_cache(q,{'z1':'z1'})), math.factorial(n))
self.assertEqual(len(net.filter_cache(q,{'z1':ss.resolve('z1')})), math.factorial(n))

3 changes: 1 addition & 2 deletions wc_rules/matcher/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,7 @@ def filter_cache(self,core,elem):
if node.data.get('alias',False):
ch = self.get_channel(target=core,type='alias')
elem = ch.data.mapping.reverse().transform(elem)
return self.filter_cache(core=ch.source,elem=elem)

return self.filter_cache(core=ch.source,elem=elem)
return Record.retrieve(node.state.cache,elem)

def insert_into_cache(self,core,elem):
Expand Down
73 changes: 68 additions & 5 deletions wc_rules/matcher/functionalize.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from frozendict import frozendict
from collections import deque,Counter
from collections import deque, Counter, ChainMap
#from ..utils.collections import merge_lists,triple_split, subdict, merge_dicts, no_overlaps, tuplify_dict
from ..utils.collections import merge_dicts_strictly, is_one_to_one
from copy import deepcopy

def function_node_start(net,node,elem):
node.state.outgoing.append(elem)
Expand All @@ -20,20 +21,55 @@ def function_node_canonical_label(net,node,elem):
clabel = node.core
entry, action = [elem[x] for x in ['entry','action']]

if elem['action'] == 'AddEntry':
if action == 'AddEntry':
assert net.filter_cache(clabel,entry) == []
net.insert_into_cache(clabel,entry)
if elem['action'] == 'RemoveEntry':
if action == 'RemoveEntry':
assert net.filter_cache(clabel,entry) != []
net.remove_from_cache(clabel,entry)
node.state.outgoing.append({'entry':entry,'action':action})
return net

def run_match_through_constraints(match,helpers,constraints):
extended_match = ChainMap(match,helpers)
terminate_early = False
for c in constraints:
value = c.exec(match)
if c.deps.declared_variable is not None:
extended_match[c.deps.declared_variable] = value
elif not value:
return None
return match

def function_node_pattern(net,node,elem):
if node.data.get('alias',False):
node.state.outgoing.append(elem)
else:
assert False, "Not Yet!"
return net
entry, action= [elem[x] for x in ['entry','action']]
outgoing_entries = []
if action == 'AddEntry':
match = {k:v for k,v in entry.items()}
match = run_match_through_constraints(match, node.data.helpers, node.data.constraints)
if match is not None:
net.insert_into_cache(node.core,match)
node.state.outgoing.append({'entry':match,'action':action})
if action == 'RemoveEntry':
elems = net.filter_cache(node.core,entry)
net.remove_from_cache(node.core,entry)
for e in elems:
node.state.outgoing.append({'entry':e,'action':action})
if action == 'UpdateEntry':
# update entry is resolved to AddEntry or RemoveEntry and inserted back into incoming queue
# ReteNet.sync(node) runs until incoming is empty
match = {k:v for k,v in entry.items()}
existing_elems = net.filter_cache(node.core,entry)
match = run_match_through_constraints(match,node.data.helpers,node.data.constraints)
if match is None and len(existing_elems)>0:
for e in existing_elems:
node.state.incoming.append({'entry':e,'action':'RemoveEntry'})
if match is not None and len(existing_elems)==0:
for e in existing_elems:
node.state.incoming.append({'entry':e,'action':'AddEntry'})
return net

def function_channel_pass(net,channel,elem):
Expand Down Expand Up @@ -87,4 +123,31 @@ def function_channel_alias(net,channel,elem):
node.state.incoming.append(outd)
return net

def function_channel_parent(net,channel,elem):
action = elem['action']
if action in ['AddEntry', 'RemoveEntry']:
node = net.get_node(core=channel.target)
entry = channel.data.mapping.transform(elem['entry'])
node.state.incoming.append({'entry':entry,'action':action})
return net

def function_channel_update(net,channel,elem):
action = elem['action']
if action == 'SetAttr':
# it asks for parent of the target
# then filters the parent cache to get candidate entries for target
# then asks target to do 'UpdateEntry' on all candidates
attr = elem['attr']
node = net.get_node(core=channel.target)
entry = channel.data.mapping.transform(elem)
parent_channel = net.get_channel(target=channel.target,type='parent')
parent = parent_channel.source
parent_mapping = parent_channel.data.mapping
filterelem = parent_mapping.reverse().transform(entry)
elems = [parent_mapping.transform(x) for x in net.filter_cache(parent,filterelem)]
for e in elems:
node.state.incoming.append({'entry':e,'action':'UpdateEntry','attr':attr})

return net

default_functionalization_methods = [method for name,method in globals().items() if name.startswith('function_')]
8 changes: 5 additions & 3 deletions wc_rules/matcher/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,20 +138,22 @@ def initialize_pattern(net,pattern):
constraint_attr_relationships.add((x,var,attr,))


helper_channels = set([(pattern,mapping) for _,pattern,mapping in constraint_pattern_relationships])
helper_channels = set([(net.get_node(core=pattern),mapping) for _,pattern,mapping in constraint_pattern_relationships])
attr_channels = set([(var,attr) for _,var,attr in constraint_attr_relationships])


# print(constraint_pattern_relationships)
# print(constraint_attr_relationships)
net.add_node(type='pattern',core=pattern,symmetry_group=symmetry_group,exprgraph=graph,helpers=helpers)
net.add_node(type='pattern',core=pattern,symmetry_group=symmetry_group,exprgraph=graph,helpers=helpers,constraints=constraint_objects)
names = [x for x in pattern.namespace if isinstance(pattern.namespace[x],type) and issubclass(pattern.namespace[x],BaseClass)]
net.initialize_cache(pattern,names)
net.add_channel(type='parent',source=pdict.parent,target=pattern,mapping=pdict.mapping)

for pname, mapping in helper_channels:
net.add_channel(type='update',source=pname,target=pattern,mapping=mapping)
for var,attr in attr_channels:
varclass = pattern.namespace[var]
mapping = Mapping.create(['a'],[var])
mapping = Mapping.create(['idx'],[var])
net.initialize_class(varclass)
net.add_channel(type='update',source=varclass,target=pattern,mapping=mapping)

Expand Down
7 changes: 7 additions & 0 deletions wc_rules/modeling/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,13 @@ def namespace(self):
d[v] = "Assigned Variable"
return d

@property
def cache_variables(self):
if isinstance(self.parent,GraphContainer):
return self.parent.names + self.assigned_variables
return self.parent.cache_variables + self.assigned_variables


def asdict(self):
return dict(**self.namespace,constraints=self.constraints)

Expand Down
4 changes: 2 additions & 2 deletions wc_rules/schema/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,8 @@ def safely_remove_edge(self,attr,target):
return self

def __repr__(self):
return 'Object of {0} with id \'{1}\''.format(self.__class__,self.id)

#return 'Object of {0} with id \'{1}\''.format(self.__class__,self.id)
return f'<{self.__class__.__name__}: {self.id}>'

def pprint(self):
s = [f'<{self.__class__.__name__}: {self.id}>']
Expand Down
17 changes: 12 additions & 5 deletions wc_rules/simulator/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,16 @@ def simulate(self):
action = self.action_stack.popleft()
if hasattr(action,'expand'):
self.push_to_stack(action.expand())
elif action.__class__.__name__ == 'RemoveNode':
self.rollback_stack.appendleft(action)
matcher_tokens = self.compile_to_matcher_tokens(action)
action.execute(self)
self.matcher.process(matcher_tokens)
else:
self.rollback_stack.appendleft(action)
action.execute(self)
self.matcher.process(self.compile_to_matcher_tokens(action))
matcher_tokens = self.compile_to_matcher_tokens(action)
self.matcher.process(matcher_tokens)
return self

def rollback(self):
Expand All @@ -62,17 +68,18 @@ def rollback(self):
def compile_to_matcher_tokens(self,action):
action_name = action.__class__.__name__
#d = {'AddNode':'add','RemoveNode':'remove','AddEdge':'add','RemoveEdge':'remove'}
# NOTE: WE"RE ATTACHING ACTUAL NODES HERE, NOT IDS, FIX action.idx,idx1,idx2 later
if action_name in ['AddNode','RemoveNode']:
return [make_node_token(action._class, action.idx, action_name)]
return [make_node_token(action._class, self.resolve(action.idx), action_name)]
if action_name in ['SetAttr']:
_class = self.resolve(action.idx).__class__
return [make_attr_token(_class, action.idx, action.attr, action.value, action_name)]
return [make_attr_token(_class, self.resolve(action.idx), action.attr, action.value, action_name)]
if action_name in ['AddEdge','RemoveEdge']:
i1,a1,i2,a2 = [getattr(action,x) for x in ['source_idx','source_attr','target_idx','target_attr']]
c1,c2 = [self.resolve(x).__class__ for x in [i1,i2]]
return [
make_edge_token(c1,i1,a1,i2,a2,action_name),
make_edge_token(c2,i2,a2,i1,a1,action_name)
make_edge_token(c1,self.resolve(i1),a1,self.resolve(i2),a2,action_name),
make_edge_token(c2,self.resolve(i2),a2,self.resolve(i1),a1,action_name)
]
return []

Expand Down

0 comments on commit 95bfc18

Please sign in to comment.