diff --git a/wc_rules/matcher/core.py b/wc_rules/matcher/core.py index 0aaf634..29c9723 100644 --- a/wc_rules/matcher/core.py +++ b/wc_rules/matcher/core.py @@ -54,6 +54,11 @@ def sample_cache(self): return dict(Record.itemize(random.choice(self.filter()))) class ReteNodeStateWrapper: + # wrapper is useful to manage aliases + # typically a node's wrapper maps to the node's state with + # an identity mapping + # if an alias, then the wrapper points to the parent's state + # and the mapping is used to convert in both directions def __init__(self,target,mapping=None): self.target = target @@ -71,9 +76,12 @@ def set(self,target,mapping=None): class ReteNet: + # tracer is used for capturing output of specific nodes + def __init__(self): self.nodes = initialize_database(['type','core','data','state','wrapper','num']) self.channels = initialize_database(['type','source','target','data','num']) + self.tracer = False self.nodemax = 0 self.channelmax = 0 @@ -83,6 +91,10 @@ def configure(self,method,overwrite=False): setattr(self,method.__name__,m) return self + def configure_tracer(self,nodes=[],channels=[]): + self.tracer = AttrDict(nodes=nodes,channels=channels) + return self + def add_node(self,**kwargs): record = {k:kwargs.pop(k) for k in ['type','core']} record.update(dict(data=kwargs,state=ReteNodeState())) @@ -141,31 +153,56 @@ def get_channels(self,include_kwargs,exclude_kwargs=None): includes = Record.retrieve_minus(self.channels,include_kwargs,exclude_kwargs) if exclude_kwargs is not None else Record.retrieve(self.channels,include_kwargs) return [AttrDict(ch) for ch in includes] - def pprint(self,state=False): + def pprint_node(self,num,state=False): s = [] + node = self.get_node(num=num) + desc = Record.print(node,ignore_keys=['state','num']) + s += [f"Node {num}\n{desc}"] + if state: + s1 = node['state'].pprint() + s[-1] += f'\n{SEP}state:\n{s1}' + return '\n'.join(s) - def printfn(x): - return Record.print(node,ignore_keys=['state','num']) + def trace_node(self,num): + return self.tracer and num in self.tracer.nodes + + def trace_channel(self,num): + return self.tracer and num in self.tracer.channels + + def trace_elem(self,elem): + if self.tracer: + print - for node in self.nodes: - s += [f"Node {node.get('num')}\n{printfn(node)}"] - if state: - s1 = node['state'].pprint() - s[-1] += f'\n{SEP}state:\n{s1}' + def pprint_channel(self,num): + channel = self.get_channel(num=num) + desc = Record.print(channel,ignore_keys=['num']) + return f"Channel {num}\n{desc}" + def pprint(self,state=False): + s = [] + for node in self.nodes: + s.append(self.pprint_node(node.get('num'),state=state)) for channel in self.channels: - s += [f"Channel {channel.get('num')}\n{Record.print(channel,ignore_keys=['num'])}"] + s.append(self.pprint_channel(channel.get('num'))) return '\n'.join(s) def sync(self,node): - log.debug(f'Syncing node {node.core}') + num = node.get('num') + trace = self.trace_node(num) if node.state.outgoing: + if trace: + print(self.pprint_node(num,state=True)) log.debug(f'{SEP}Outgoing: {node.state.outgoing}') elem = node.state.outgoing.popleft() + if trace: + print(f'Popping elem {elem}') channels = self.get_outgoing_channels(node.core) for channel in channels: + chnum = channel.get('num') log.debug(f'{SEP}Channel: {channel}') method = getattr(self,f'function_channel_{channel.type}') + if self.trace_channel(chnum): + print(self.pprint_channel(chnum)) method(channel,elem) self.sync(self.get_node(core=channel.target)) self.sync(node) @@ -175,6 +212,7 @@ def sync(self,node): method = getattr(self,f'function_node_{node.type}') method(node,elem) self.sync(node) + return self def process(self,tokens): diff --git a/wc_rules/matcher/functionalize.py b/wc_rules/matcher/functionalize.py index ce6a873..0a64d1a 100644 --- a/wc_rules/matcher/functionalize.py +++ b/wc_rules/matcher/functionalize.py @@ -5,36 +5,92 @@ from copy import deepcopy from attrdict import AttrDict from ..expressions.builtins import global_builtins +from abc import ABC, abstractmethod -def function_node_start(net,node,elem): - node.state.outgoing.append(elem) - return net +class Behavior(ABC): + + # an object to indicate a callable behavior on the rete-net + callsign = '' + + @abstractmethod + def __call__(self,net,node_or_channel,elem): + return net + +class NodeBehavior(Behavior): + @abstractmethod + def __call__(self,net,node,elem): + return net -def function_node_end(net,node,elem): - node.state.cache.append(elem) - return +class StartNode(NodeBehavior): + callsign = 'start' -def function_node_class(net,node,elem): - if issubclass(elem['_class'],node.core): + def __call__(self,net,node,elem): node.state.outgoing.append(elem) - return net + return net -def function_node_collector(net,node,elem): - node.state.cache.append(elem) - return net +class EndNode(NodeBehavior): + callsign = 'end' + def __call__(self,net,node,elem): + node.state.cache.append(elem) + return -def function_node_canonical_label(net,node,elem): - clabel = node.core - entry, action = [elem[x] for x in ['entry','action']] - - if action == 'AddEntry': +class ClassNode(NodeBehavior): + callsign = 'class' + def __call__(self,net,node,elem): + if issubclass(elem['_class'],node.core): + node.state.outgoing.append(elem) + return net + +class CollectorNode(NodeBehavior): + callsign = 'collector' + def __call__(self,net,node,elem): + node.state.cache.append(elem) + return net + +class NodeBehaviorByAction(NodeBehavior): + callsign = None + actions = set() + # each behavior_{action} method must return + # either an empty list + # or a list of entries to be appended to node.state.outgoing + def __call__(self,net,node,elem): + action = elem.get('action',None) + if action in self.actions: + beh = getattr(self,f'behavior_{action}',None) + incoming,outgoing = beh(net,node,elem['entry']) + node.state.incoming.extend(incoming) + node.state.outgoing.extend(outgoing) + return net + +class CanonicalLabelNode(NodeBehaviorByAction): + callsign = 'canonical_label' + actions = set(['AddEntry','RemoveEntry']) + + def behavior_AddEntry(self,net,node,entry): + clabel = node.core assert net.filter_cache(clabel,entry) == [] net.insert_into_cache(clabel,entry) - if action == 'RemoveEntry': - assert net.filter_cache(clabel,entry) != [] - net.remove_from_cache(clabel,entry) - node.state.outgoing.append({'entry':entry,'action':action}) - return net + return [],[{'entry':entry,action:'AddEntry'}] + + def behavior_RemoveEntry(self,net,node,entry): + clabel = node.core + assert net.filter_cache(clabel,entry) == [] + net.insert_into_cache(clabel,entry) + return [], [{'entry':entry,action:'RemoveEntry'}] + + +# def function_node_canonical_label(net,node,elem): +# clabel = node.core +# entry, action = [elem[x] for x in ['entry','action']] + +# if action == 'AddEntry': +# assert net.filter_cache(clabel,entry) == [] +# net.insert_into_cache(clabel,entry) +# if action == 'RemoveEntry': +# assert net.filter_cache(clabel,entry) != [] +# net.remove_from_cache(clabel,entry) +# node.state.outgoing.append({'entry':entry,'action':action}) +# return net def run_match_through_constraints(match,helpers,parameters,constraints): extended_match = ChainMap(match,helpers,parameters) @@ -46,41 +102,82 @@ def run_match_through_constraints(match,helpers,parameters,constraints): elif not value: return None return match - -def function_node_pattern(net,node,elem): - if node.data.get('alias',False): - node.state.outgoing.append(elem) + +class PatternNode(NodeBehavior): + callsign = 'pattern' + actions = set(['AddEntry','RemoveEntry','UpdateEntry']) + + def __call__(self,net,node,elem): + if node.data.get('alias',False): + node.state.outgoing.append(elem) + else: + super().__call__(net,node,elem) return net - entry, action= [elem[x] for x in ['entry','action']] - outgoing_entries = [] - if action == 'AddEntry': + + def behavior_AddEntry(self,net,node,entry): match = {k:v for k,v in entry.items()} match = run_match_through_constraints(match, node.data.helpers, node.data.parameters, node.data.constraints) if match is not None: net.insert_into_cache(node.core,match) - node.state.outgoing.append({'entry':match,'action':action}) - if action == 'RemoveEntry': + return [],[{'entry':match,'action':action}] + return [],[] + + def behavior_RemoveEntry(self,net,node,entry): elems = net.filter_cache(node.core,entry) net.remove_from_cache(node.core,entry) - for e in elems: - node.state.outgoing.append({'entry':e,'action':action}) - - if action == 'UpdateEntry': + return [],[{'entry':e,'action':action} for e in elems] + + def behavior_UpdateEntry(self,net,node,entry): # update entry is resolved to AddEntry or RemoveEntry and inserted back into incoming queue # ReteNet.sync(node) runs until incoming is empty - match = {k:v for k,v in entry.items()} existing_elems = net.filter_cache(node.core,entry) match = run_match_through_constraints(match,node.data.helpers,node.data.parameters,node.data.constraints) + # do this!!!!! + # update_entry shd typically be from attr nodes + # as well as edge nodes + incoming = [] if match is None: - if len(existing_elems) > 0: - for e in existing_elems: - node.state.incoming.append({'entry':e,'action':'RemoveEntry'}) - elif match is not None: + incoming = [{'entry':e,'action':'RemoveEntry'} for e in elems] + else: if len(existing_elems) == 0: - node.state.incoming.append({'entry':match,'action':'AddEntry'}) + incoming = [{'entry':match,'action':'AddEntry'}] + return incoming,[] - return net + +# def function_node_pattern(net,node,elem): +# if node.data.get('alias',False): +# node.state.outgoing.append(elem) +# return net +# entry, action= [elem[x] for x in ['entry','action']] +# outgoing_entries = [] +# if action == 'AddEntry': +# match = {k:v for k,v in entry.items()} +# match = run_match_through_constraints(match, node.data.helpers, node.data.parameters, node.data.constraints) +# if match is not None: +# net.insert_into_cache(node.core,match) +# node.state.outgoing.append({'entry':match,'action':action}) +# if action == 'RemoveEntry': +# elems = net.filter_cache(node.core,entry) +# net.remove_from_cache(node.core,entry) +# for e in elems: +# node.state.outgoing.append({'entry':e,'action':action}) + +# if action == 'UpdateEntry': +# # update entry is resolved to AddEntry or RemoveEntry and inserted back into incoming queue +# # ReteNet.sync(node) runs until incoming is empty + +# match = {k:v for k,v in entry.items()} +# existing_elems = net.filter_cache(node.core,entry) +# match = run_match_through_constraints(match,node.data.helpers,node.data.parameters,node.data.constraints) +# if match is None: +# if len(existing_elems) > 0: +# for e in existing_elems: +# node.state.incoming.append({'entry':e,'action':'RemoveEntry'}) +# elif match is not None: +# if len(existing_elems) == 0: +# node.state.incoming.append({'entry':match,'action':'AddEntry'}) +# return net def function_node_rule(net,node,elem): if elem['action']=='UpdateRule': @@ -153,7 +250,6 @@ def function_channel_parent(net,channel,elem): def function_channel_update_pattern(net,channel,elem): action = elem['action'] - if action in ['AddEdge','RemoveEdge','SetAttr','AddEntry','RemoveEntry']: # it asks for parent of the target # then filters the parent cache to get candidate entries for target @@ -201,4 +297,17 @@ def function_channel_update_variable(net,channel,elem): -default_functionalization_methods = [method for name,method in globals().items() if name.startswith('function_')] \ No newline at end of file +#default_functionalization_methods = [method for name,method in globals().items() if name.startswith('function_')] +def subclass_iter(_class): + subs = _class.__subclasses__() + for sub in subs: + yield sub + for x in subclass_iter(sub): + yield x + +default_functionalization_methods = [] +for sub in subclass_iter(NodeBehavior): + if sub.callsign is not None: + fnobj = sub() + fnobj.__name__ = f'function_node_{sub.callsign}' + default_functionalization_methods.append(fnobj) diff --git a/wc_rules/matcher/initialize.py b/wc_rules/matcher/initialize.py index 7c830a7..f6a89a8 100644 --- a/wc_rules/matcher/initialize.py +++ b/wc_rules/matcher/initialize.py @@ -151,8 +151,8 @@ def initialize_pattern(net,pattern,parameters = dict()): helper_channels = set([(pname,mapping) for _,pname,mapping in constraint_pattern_relationships]) attr_channels = set([(var,attr) for _,var,attr in constraint_attr_relationships]) - # print(constraint_pattern_relationships) - # print(constraint_attr_relationships) + #print(constraint_pattern_relationships) + #print(constraint_attr_relationships) resolved_helpers = {h:net.get_node(core=helpers[h]).wrapper for h in helpers} net.add_node(type='pattern',core=pattern,symmetry_group=symmetry_group,exprgraph=graph,helpers=resolved_helpers,constraints=constraint_objects,parameters=parameters) names = [x for x in pattern.namespace if isinstance(pattern.namespace[x],type) and issubclass(pattern.namespace[x],BaseClass)] diff --git a/wc_rules/simulator/simulator2.py b/wc_rules/simulator/simulator2.py index d537743..803b3a5 100644 --- a/wc_rules/simulator/simulator2.py +++ b/wc_rules/simulator/simulator2.py @@ -38,6 +38,7 @@ def default_matcher(self,methods=None): net.configure(method) return net + def load_model(self,model,data): self.matcher.initialize_start() self.matcher.initialize_end()