Skip to content

Commit

Permalink
temp commit - moving function_node_* methods to NodeBehavior classes
Browse files Browse the repository at this point in the history
  • Loading branch information
johnsekar committed Jan 16, 2022
1 parent d2b2016 commit 416e8c1
Show file tree
Hide file tree
Showing 4 changed files with 204 additions and 56 deletions.
58 changes: 48 additions & 10 deletions wc_rules/matcher/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ def sample_cache(self):
return dict(Record.itemize(random.choice(self.filter())))

class ReteNodeStateWrapper:
# wrapper is useful to manage aliases
# typically a node's wrapper maps to the node's state with
# an identity mapping
# if an alias, then the wrapper points to the parent's state
# and the mapping is used to convert in both directions

def __init__(self,target,mapping=None):
self.target = target
Expand All @@ -71,9 +76,12 @@ def set(self,target,mapping=None):

class ReteNet:

# tracer is used for capturing output of specific nodes

def __init__(self):
self.nodes = initialize_database(['type','core','data','state','wrapper','num'])
self.channels = initialize_database(['type','source','target','data','num'])
self.tracer = False
self.nodemax = 0
self.channelmax = 0

Expand All @@ -83,6 +91,10 @@ def configure(self,method,overwrite=False):
setattr(self,method.__name__,m)
return self

def configure_tracer(self,nodes=[],channels=[]):
self.tracer = AttrDict(nodes=nodes,channels=channels)
return self

def add_node(self,**kwargs):
record = {k:kwargs.pop(k) for k in ['type','core']}
record.update(dict(data=kwargs,state=ReteNodeState()))
Expand Down Expand Up @@ -141,31 +153,56 @@ def get_channels(self,include_kwargs,exclude_kwargs=None):
includes = Record.retrieve_minus(self.channels,include_kwargs,exclude_kwargs) if exclude_kwargs is not None else Record.retrieve(self.channels,include_kwargs)
return [AttrDict(ch) for ch in includes]

def pprint(self,state=False):
def pprint_node(self,num,state=False):
s = []
node = self.get_node(num=num)
desc = Record.print(node,ignore_keys=['state','num'])
s += [f"Node {num}\n{desc}"]
if state:
s1 = node['state'].pprint()
s[-1] += f'\n{SEP}state:\n{s1}'
return '\n'.join(s)

def printfn(x):
return Record.print(node,ignore_keys=['state','num'])
def trace_node(self,num):
return self.tracer and num in self.tracer.nodes

def trace_channel(self,num):
return self.tracer and num in self.tracer.channels

def trace_elem(self,elem):
if self.tracer:
print

for node in self.nodes:
s += [f"Node {node.get('num')}\n{printfn(node)}"]
if state:
s1 = node['state'].pprint()
s[-1] += f'\n{SEP}state:\n{s1}'
def pprint_channel(self,num):
channel = self.get_channel(num=num)
desc = Record.print(channel,ignore_keys=['num'])
return f"Channel {num}\n{desc}"

def pprint(self,state=False):
s = []
for node in self.nodes:
s.append(self.pprint_node(node.get('num'),state=state))
for channel in self.channels:
s += [f"Channel {channel.get('num')}\n{Record.print(channel,ignore_keys=['num'])}"]
s.append(self.pprint_channel(channel.get('num')))
return '\n'.join(s)

def sync(self,node):
log.debug(f'Syncing node {node.core}')
num = node.get('num')
trace = self.trace_node(num)
if node.state.outgoing:
if trace:
print(self.pprint_node(num,state=True))
log.debug(f'{SEP}Outgoing: {node.state.outgoing}')
elem = node.state.outgoing.popleft()
if trace:
print(f'Popping elem {elem}')
channels = self.get_outgoing_channels(node.core)
for channel in channels:
chnum = channel.get('num')
log.debug(f'{SEP}Channel: {channel}')
method = getattr(self,f'function_channel_{channel.type}')
if self.trace_channel(chnum):
print(self.pprint_channel(chnum))
method(channel,elem)
self.sync(self.get_node(core=channel.target))
self.sync(node)
Expand All @@ -175,6 +212,7 @@ def sync(self,node):
method = getattr(self,f'function_node_{node.type}')
method(node,elem)
self.sync(node)

return self

def process(self,tokens):
Expand Down
197 changes: 153 additions & 44 deletions wc_rules/matcher/functionalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,36 +5,92 @@
from copy import deepcopy
from attrdict import AttrDict
from ..expressions.builtins import global_builtins
from abc import ABC, abstractmethod

def function_node_start(net,node,elem):
node.state.outgoing.append(elem)
return net
class Behavior(ABC):

# an object to indicate a callable behavior on the rete-net
callsign = ''

@abstractmethod
def __call__(self,net,node_or_channel,elem):
return net

class NodeBehavior(Behavior):
@abstractmethod
def __call__(self,net,node,elem):
return net

def function_node_end(net,node,elem):
node.state.cache.append(elem)
return
class StartNode(NodeBehavior):
callsign = 'start'

def function_node_class(net,node,elem):
if issubclass(elem['_class'],node.core):
def __call__(self,net,node,elem):
node.state.outgoing.append(elem)
return net
return net

def function_node_collector(net,node,elem):
node.state.cache.append(elem)
return net
class EndNode(NodeBehavior):
callsign = 'end'
def __call__(self,net,node,elem):
node.state.cache.append(elem)
return

def function_node_canonical_label(net,node,elem):
clabel = node.core
entry, action = [elem[x] for x in ['entry','action']]

if action == 'AddEntry':
class ClassNode(NodeBehavior):
callsign = 'class'
def __call__(self,net,node,elem):
if issubclass(elem['_class'],node.core):
node.state.outgoing.append(elem)
return net

class CollectorNode(NodeBehavior):
callsign = 'collector'
def __call__(self,net,node,elem):
node.state.cache.append(elem)
return net

class NodeBehaviorByAction(NodeBehavior):
callsign = None
actions = set()
# each behavior_{action} method must return
# either an empty list
# or a list of entries to be appended to node.state.outgoing
def __call__(self,net,node,elem):
action = elem.get('action',None)
if action in self.actions:
beh = getattr(self,f'behavior_{action}',None)
incoming,outgoing = beh(net,node,elem['entry'])
node.state.incoming.extend(incoming)
node.state.outgoing.extend(outgoing)
return net

class CanonicalLabelNode(NodeBehaviorByAction):
callsign = 'canonical_label'
actions = set(['AddEntry','RemoveEntry'])

def behavior_AddEntry(self,net,node,entry):
clabel = node.core
assert net.filter_cache(clabel,entry) == []
net.insert_into_cache(clabel,entry)
if action == 'RemoveEntry':
assert net.filter_cache(clabel,entry) != []
net.remove_from_cache(clabel,entry)
node.state.outgoing.append({'entry':entry,'action':action})
return net
return [],[{'entry':entry,action:'AddEntry'}]

def behavior_RemoveEntry(self,net,node,entry):
clabel = node.core
assert net.filter_cache(clabel,entry) == []
net.insert_into_cache(clabel,entry)
return [], [{'entry':entry,action:'RemoveEntry'}]


# def function_node_canonical_label(net,node,elem):
# clabel = node.core
# entry, action = [elem[x] for x in ['entry','action']]

# if action == 'AddEntry':
# assert net.filter_cache(clabel,entry) == []
# net.insert_into_cache(clabel,entry)
# if action == 'RemoveEntry':
# assert net.filter_cache(clabel,entry) != []
# net.remove_from_cache(clabel,entry)
# node.state.outgoing.append({'entry':entry,'action':action})
# return net

def run_match_through_constraints(match,helpers,parameters,constraints):
extended_match = ChainMap(match,helpers,parameters)
Expand All @@ -46,41 +102,82 @@ def run_match_through_constraints(match,helpers,parameters,constraints):
elif not value:
return None
return match

def function_node_pattern(net,node,elem):
if node.data.get('alias',False):
node.state.outgoing.append(elem)

class PatternNode(NodeBehavior):
callsign = 'pattern'
actions = set(['AddEntry','RemoveEntry','UpdateEntry'])

def __call__(self,net,node,elem):
if node.data.get('alias',False):
node.state.outgoing.append(elem)
else:
super().__call__(net,node,elem)
return net
entry, action= [elem[x] for x in ['entry','action']]
outgoing_entries = []
if action == 'AddEntry':

def behavior_AddEntry(self,net,node,entry):
match = {k:v for k,v in entry.items()}
match = run_match_through_constraints(match, node.data.helpers, node.data.parameters, node.data.constraints)
if match is not None:
net.insert_into_cache(node.core,match)
node.state.outgoing.append({'entry':match,'action':action})
if action == 'RemoveEntry':
return [],[{'entry':match,'action':action}]
return [],[]

def behavior_RemoveEntry(self,net,node,entry):
elems = net.filter_cache(node.core,entry)
net.remove_from_cache(node.core,entry)
for e in elems:
node.state.outgoing.append({'entry':e,'action':action})

if action == 'UpdateEntry':
return [],[{'entry':e,'action':action} for e in elems]

def behavior_UpdateEntry(self,net,node,entry):
# update entry is resolved to AddEntry or RemoveEntry and inserted back into incoming queue
# ReteNet.sync(node) runs until incoming is empty

match = {k:v for k,v in entry.items()}
existing_elems = net.filter_cache(node.core,entry)
match = run_match_through_constraints(match,node.data.helpers,node.data.parameters,node.data.constraints)
# do this!!!!!
# update_entry shd typically be from attr nodes
# as well as edge nodes
incoming = []
if match is None:
if len(existing_elems) > 0:
for e in existing_elems:
node.state.incoming.append({'entry':e,'action':'RemoveEntry'})
elif match is not None:
incoming = [{'entry':e,'action':'RemoveEntry'} for e in elems]
else:
if len(existing_elems) == 0:
node.state.incoming.append({'entry':match,'action':'AddEntry'})
incoming = [{'entry':match,'action':'AddEntry'}]
return incoming,[]

return net

# def function_node_pattern(net,node,elem):
# if node.data.get('alias',False):
# node.state.outgoing.append(elem)
# return net
# entry, action= [elem[x] for x in ['entry','action']]
# outgoing_entries = []
# if action == 'AddEntry':
# match = {k:v for k,v in entry.items()}
# match = run_match_through_constraints(match, node.data.helpers, node.data.parameters, node.data.constraints)
# if match is not None:
# net.insert_into_cache(node.core,match)
# node.state.outgoing.append({'entry':match,'action':action})
# if action == 'RemoveEntry':
# elems = net.filter_cache(node.core,entry)
# net.remove_from_cache(node.core,entry)
# for e in elems:
# node.state.outgoing.append({'entry':e,'action':action})

# if action == 'UpdateEntry':
# # update entry is resolved to AddEntry or RemoveEntry and inserted back into incoming queue
# # ReteNet.sync(node) runs until incoming is empty

# match = {k:v for k,v in entry.items()}
# existing_elems = net.filter_cache(node.core,entry)
# match = run_match_through_constraints(match,node.data.helpers,node.data.parameters,node.data.constraints)
# if match is None:
# if len(existing_elems) > 0:
# for e in existing_elems:
# node.state.incoming.append({'entry':e,'action':'RemoveEntry'})
# elif match is not None:
# if len(existing_elems) == 0:
# node.state.incoming.append({'entry':match,'action':'AddEntry'})
# return net

def function_node_rule(net,node,elem):
if elem['action']=='UpdateRule':
Expand Down Expand Up @@ -153,7 +250,6 @@ def function_channel_parent(net,channel,elem):

def function_channel_update_pattern(net,channel,elem):
action = elem['action']

if action in ['AddEdge','RemoveEdge','SetAttr','AddEntry','RemoveEntry']:
# it asks for parent of the target
# then filters the parent cache to get candidate entries for target
Expand Down Expand Up @@ -201,4 +297,17 @@ def function_channel_update_variable(net,channel,elem):



default_functionalization_methods = [method for name,method in globals().items() if name.startswith('function_')]
#default_functionalization_methods = [method for name,method in globals().items() if name.startswith('function_')]
def subclass_iter(_class):
subs = _class.__subclasses__()
for sub in subs:
yield sub
for x in subclass_iter(sub):
yield x

default_functionalization_methods = []
for sub in subclass_iter(NodeBehavior):
if sub.callsign is not None:
fnobj = sub()
fnobj.__name__ = f'function_node_{sub.callsign}'
default_functionalization_methods.append(fnobj)
4 changes: 2 additions & 2 deletions wc_rules/matcher/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,8 @@ def initialize_pattern(net,pattern,parameters = dict()):
helper_channels = set([(pname,mapping) for _,pname,mapping in constraint_pattern_relationships])
attr_channels = set([(var,attr) for _,var,attr in constraint_attr_relationships])

# print(constraint_pattern_relationships)
# print(constraint_attr_relationships)
#print(constraint_pattern_relationships)
#print(constraint_attr_relationships)
resolved_helpers = {h:net.get_node(core=helpers[h]).wrapper for h in helpers}
net.add_node(type='pattern',core=pattern,symmetry_group=symmetry_group,exprgraph=graph,helpers=resolved_helpers,constraints=constraint_objects,parameters=parameters)
names = [x for x in pattern.namespace if isinstance(pattern.namespace[x],type) and issubclass(pattern.namespace[x],BaseClass)]
Expand Down
1 change: 1 addition & 0 deletions wc_rules/simulator/simulator2.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def default_matcher(self,methods=None):
net.configure(method)
return net


def load_model(self,model,data):
self.matcher.initialize_start()
self.matcher.initialize_end()
Expand Down

0 comments on commit 416e8c1

Please sign in to comment.