Skip to content

Commit

Permalink
building rete net.. start, class and collector nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
johnsekar committed Sep 28, 2021
1 parent 1227911 commit ce8ea57
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 10 deletions.
45 changes: 44 additions & 1 deletion tests/test_action_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from wc_rules.schema.attributes import *
from wc_rules.schema.actions import *
from wc_rules.simulator.simulator import SimulationState
from wc_rules.matcher.core import ReteNet
import unittest


Expand Down Expand Up @@ -211,4 +212,46 @@ def test_expand_behavior(self):
self.assertEqual( RemoveEdgeAttr(card,'owner').expand(), [RemoveEdge('card','owner','john','card')])
self.assertEqual( RemoveEdgeAttr(card,'signatories').expand(), [RemoveEdge('card','signatories','jack','cards') ])
self.assertEqual( RemoveAllEdges(card).expand(), [RemoveEdgeAttr(card,'owner'), RemoveEdgeAttr(card,'signatories') ])
self.assertEqual( Remove(card).expand(), [RemoveAllEdges(card), RemoveNode(ReportCard,'card',attrs)])
self.assertEqual( Remove(card).expand(), [RemoveAllEdges(card), RemoveNode(ReportCard,'card',attrs)])


def test_rete_simple(self):
actions = [
AddNode(Animal,'doggy',{'sound':'ruff'}),
AddNode(Animal,'kitty',{'sound':'woof'}),
AddNode(Person,'john',{}),
SetAttr('kitty','sound','meow','woof'),
AddEdge('doggy','owner','john','pets'),
AddEdge('kitty','owner','john','pets'),
AddEdge('doggy','friends','kitty','friends'),
RemoveEdge('john','pets','doggy','owner'),
RemoveEdge('john','pets','kitty','owner'),
RemoveEdge('kitty','friends','doggy','friends'),
RemoveNode(Person,'john',{}),
RemoveNode(Animal,'doggy',{'sound':'ruff'}),
RemoveNode(Animal,'kitty',{'sound':'meow'}),
]

ss = ReteNet.default_initialization()
net = ReteNet.default_initialization()
net.initialize_class(Animal)
net.initialize_class(Person)
net.initialize_collector(Animal,'Animal')
net.initialize_collector(Person,'Person')

ss = SimulationState(matcher=net)
ss.push_to_stack(actions)
ss.simulate()

collector_Animal = ss.matcher.get_node(core='collector_Animal')
collector_Person = ss.matcher.get_node(core='collector_Person')

# 2 AddNodes, 2 RemoveNodes, 1 SetAttr
# 2x1 AddEdge + 2x1 RemoveEdge with John
# 1x2 AddEdge + 1x2 RemoveEdge with each other
self.assertEqual(len(collector_Animal.state.cache),13)

# 1 AddNode, 1 RemoveNode
# 1 AddEdge + 1 RemoveEdge with kitty
# 1 AddEdge + 1 RemoveEdge with doggy
self.assertEqual(len(collector_Person.state.cache),6)
15 changes: 15 additions & 0 deletions wc_rules/matcher/actions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from frozendict import frozendict

def make_node_token(_class,idx,action):
return frozendict(dict(_class=_class,idx=idx,action=action))

def make_edge_token(_class1,idx1,attr1,idx2,attr2,action):
return frozendict(dict(_class=_class1,idx1=idx1,idx2=idx2,attr2=attr2,action=action))

def make_attr_token(_class,idx,attr,value,action):
return frozendict(dict(_class=_class,idx=idx,attr=attr,value=value,action=action))

def make_cache_token(variables,values,action):
return frozendict(variables=variables,values=values,action=action)


41 changes: 38 additions & 3 deletions wc_rules/matcher/core.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@

from collections import deque
from .configuration import ReteNetConfiguration
from .dbase import initialize_database, Record
from .dbase import initialize_database, Record, SEP
from .actions import *
from attrdict import AttrDict
import logging
import os

log = logging.getLogger(__name__)
FORMAT = '%(message)s'
logging.basicConfig(level=os.environ.get("LOGLEVEL","NOTSET"), format=FORMAT)

log.propagate = False



class ReteNodeState:

Expand All @@ -11,6 +22,10 @@ def __init__(self):
self.incoming = deque()
self.outgoing = deque()

def pprint(self,nsep=2):
d = AttrDict({'cache':self.cache,'incoming':self.incoming,'outgoing':self.outgoing})
return Record.print(d,nsep=2)


class ReteNet:

Expand Down Expand Up @@ -45,27 +60,47 @@ def get_channel(self,**kwargs):
def get_outgoing_channels(self,source):
return [AttrDict(ch) for ch in Record.retrieve(self.channels,{'source':source})]

def pprint(self):
def pprint(self,state=False):
s = []

def printfn(x):
return Record.print(node,ignore_keys=['state'])

for node in self.nodes:
s += [f'Node\n{Record.print(node)}']
s += [f'Node\n{printfn(node)}']
if state:
s1 = node['state'].pprint()
if s1:
s[-1] += f'\n{SEP}state:\n{s1}'

for channel in self.channels:
s += [f'Channel\n{Record.print(channel)}']
return '\n'.join(s)

def sync(self,node):
log.debug(f'Syncing node {node.core}')
if node.state.outgoing:
log.debug(f'{SEP}Outgoing: {node.state.outgoing}')
elem = node.state.outgoing.popleft()
channels = self.get_outgoing_channels(node.core)
for channel in channels:
log.debug(f'{SEP}Channel: {channel}')
method = getattr(self,f'function_channel_{channel.type}')
method(channel,elem)
self.sync(self.get_node(core=channel.target))
self.sync(node)
if node.state.incoming:
log.debug(f'{SEP}Incoming: {node.state.incoming}')
elem = node.state.incoming.popleft()
method = getattr(self,f'function_node_{node.type}')
method(node,elem)
self.sync(node)
return self

def process(self,tokens):
start = self.get_node(type='start')
for token in tokens:
start.state.incoming.append(token)
self.sync(start)
return

10 changes: 6 additions & 4 deletions wc_rules/matcher/dbase.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pydblite import Base

SEP = ' '

def initialize_database(fields):
db = Base(':memory:')
Expand All @@ -9,8 +10,8 @@ def initialize_database(fields):
class Record:

@staticmethod
def itemize(r):
return ((k,v) for k,v in r.items() if not k.startswith('__') and v)
def itemize(r,ignore_keys=[]):
return ((k,v) for k,v in r.items() if k not in ignore_keys and not k.startswith('__') and v)

@staticmethod
def retrieve(dbase,kwargs):
Expand All @@ -24,8 +25,9 @@ def retrieve_exactly(dbase,kwargs):
return record

@staticmethod
def print(record,sep=' '):
return '\n'.join([f'{sep}{k}: {v}' for k,v in Record.itemize(record)])
def print(record,nsep=1,ignore_keys=[]):
sep = SEP*nsep
return '\n'.join([f'{sep}{k}: {v}' for k,v in Record.itemize(record,ignore_keys)])

@staticmethod
def insert(dbase,record):
Expand Down
5 changes: 5 additions & 0 deletions wc_rules/matcher/functionalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,17 @@ def function_node_start(net,node,elem):
return net

def function_node_class(net,node,elem):
if issubclass(elem['_class'],node.core):
node.state.outgoing.append(elem)
return net

def function_node_collector(net,node,elem):
node.state.cache.append(elem)
return net

def function_channel_pass(net,channel,elem):
target = net.get_node(core=channel.target)
target.state.incoming.append(elem)
return net

default_functionalization_methods = [method for name,method in globals().items() if name.startswith('function_')]
2 changes: 2 additions & 0 deletions wc_rules/matcher/initialize/schema.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from ...schema.base import BaseClass
from ...utils.random import generate_id
from collections import deque
# nodes must be a dict with keys
# 'type','core',
# 'state' gets automatically initialized
Expand All @@ -20,6 +21,7 @@ def initialize_class(net,_class):
def initialize_collector(net,source,label):
idx = f'collector_{label}'
net.add_node(type='collector',core=idx)
net.get_node(type='collector',core=idx).state.cache = deque()
net.add_channel(type='pass',source=source,target=idx)
return net

Expand Down
27 changes: 25 additions & 2 deletions wc_rules/simulator/simulator.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from collections import deque
from ..matcher.core import ReteNet
from ..matcher.actions import make_node_token, make_edge_token, make_attr_token

class SimulationState:
def __init__(self,nodes=[]):
def __init__(self,nodes=[],**kwargs):
self.state = {x.id:x for x in nodes}
# for both stacks, use LIFO semantics using appendleft and popleft
self.action_stack = deque()
self.rollback_stack = deque()
self.matcher = kwargs.get('matcher',ReteNet.default_initialization())

def resolve(self,idx):
return self.state[idx]
Expand Down Expand Up @@ -45,8 +48,9 @@ def simulate(self):
if hasattr(action,'expand'):
self.push_to_stack(action.expand())
else:
self.rollback_stack.appendleft(action.rollback(self))
self.rollback_stack.appendleft(action)
action.execute(self)
self.matcher.process(self.compile_to_matcher_tokens(action))
return self

def rollback(self):
Expand All @@ -55,6 +59,25 @@ def rollback(self):
action.execute(self)
return self

def compile_to_matcher_tokens(self,action):
action_name = action.__class__.__name__
#d = {'AddNode':'add','RemoveNode':'remove','AddEdge':'add','RemoveEdge':'remove'}
if action_name in ['AddNode','RemoveNode']:
return [make_node_token(action._class, action.idx, action_name)]
if action_name in ['SetAttr']:
_class = self.resolve(action.idx).__class__
return [make_attr_token(_class, action.idx, action.attr, action.value, action_name)]
if action_name in ['AddEdge','RemoveEdge']:
i1,a1,i2,a2 = [getattr(action,x) for x in ['source_idx','source_attr','target_idx','target_attr']]
c1,c2 = [self.resolve(x).__class__ for x in [i1,i2]]
return [
make_edge_token(c1,i1,a1,i2,a2,action_name),
make_edge_token(c2,i2,a2,i1,a1,action_name)
]
return []






0 comments on commit ce8ea57

Please sign in to comment.