Skip to content

Commit

Permalink
building rete net.. single-edge canonical labels
Browse files Browse the repository at this point in the history
  • Loading branch information
johnsekar committed Sep 29, 2021
1 parent 3a005d7 commit 665082d
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 14 deletions.
48 changes: 45 additions & 3 deletions tests/test_rete_1.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from wc_rules.schema.entity import Entity
from wc_rules.schema.attributes import BooleanAttribute
from wc_rules.schema.actions import AddNode,RemoveNode
from wc_rules.schema.attributes import BooleanAttribute, ManyToOneAttribute
from wc_rules.schema.actions import AddNode,RemoveNode,AddEdge,RemoveEdge
from wc_rules.modeling.pattern import GraphContainer, Pattern
from wc_rules.simulator.simulator import SimulationState
from wc_rules.matcher.core import ReteNet
Expand All @@ -12,7 +12,7 @@ class X(Entity):
pass

class Y(X):
pass
z = ManyToOneAttribute('Z',related_name='y')

class Z(X):
pass
Expand Down Expand Up @@ -78,8 +78,50 @@ def test_single_node_branching_classes(self):
cache_sizes = [len(g.state.cache) for g in gnodes]
self.assertEqual(cache_sizes,[0,0,0])

def test_single_edge_canonical_label(self):
net = ReteNet.default_initialization()
ss = SimulationState(matcher=net)

g = GraphContainer(Y('y',z=Z('z')).get_connected())
m,L,G = canonical_label(g)
net.initialize_canonical_label(L,G)
net.initialize_collector(L,'XYEdge')
gnode, collector = net.get_node(core=L), net.get_node(core='collector_XYEdge')


# command to push nodes y1 z1
ss.push_to_stack([
AddNode.make(Y,'y1'),
AddNode.make(Z,'z1'),
]
)
ss.simulate()

# both the rete node for the edge as well
# as the downstream collector
# should have zero cache entries
self.assertEqual(len(gnode.state.cache),0)
self.assertEqual(len(collector.state.cache),0)

# add edge y1-z1
ss.push_to_stack(AddEdge('y1','z','z1','y'))
ss.simulate()

# rete node for edge and collector should have
# one entry each
self.assertEqual(len(gnode.state.cache),1)
self.assertEqual(len(collector.state.cache),1)

# remove edge y1-z1
ss.push_to_stack(RemoveEdge('y1','z','z1','y'))
ss.simulate()

# rete node for edge should have no entries
# but collector should have two entries
# (one for add, one for remove)
self.assertEqual(len(gnode.state.cache),0)
self.assertEqual(len(collector.state.cache),2)




10 changes: 5 additions & 5 deletions wc_rules/matcher/actions.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from frozendict import frozendict
#from frozendict import frozendict

def make_node_token(_class,idx,action):
return frozendict(dict(_class=_class,idx=idx,action=action))
return dict(_class=_class,idx=idx,action=action)

def make_edge_token(_class1,idx1,attr1,idx2,attr2,action):
return frozendict(dict(_class=_class1,idx1=idx1,idx2=idx2,attr2=attr2,action=action))
return dict(_class=_class1,idx1=idx1,idx2=idx2,attr2=attr2,action=action)

def make_attr_token(_class,idx,attr,value,action):
return frozendict(dict(_class=_class,idx=idx,attr=attr,value=value,action=action))
return dict(_class=_class,idx=idx,attr=attr,value=value,action=action)

def make_cache_token(variables,values,action):
return frozendict(variables=variables,values=values,action=action)
return dict(variables=variables,values=values,action=action)


3 changes: 2 additions & 1 deletion wc_rules/matcher/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ def __init__(self):
self.outgoing = deque()

def pprint(self,nsep=2):
d = AttrDict({'cache':[x for x in self.cache],'incoming':self.incoming,'outgoing':self.outgoing})
d = dict(incoming=self.incoming,outgoing=self.outgoing)
d['cache'] = [x for x in self.cache] if self.cache is not None else None
return Record.print(d,nsep=2)


Expand Down
18 changes: 15 additions & 3 deletions wc_rules/matcher/functionalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ def function_node_collector(net,node,elem):

def function_node_canonical_label(net,node,elem):
clabel = node.core
entry, action, channel = [elem[x] for x in ['entry','action','channel']]
if len(clabel.names)==1:
entry, action = [elem[x] for x in ['entry','action']]
if len(clabel.names)<=2:
# its a single node graph
if elem['action'] == 'AddEntry':
assert net.filter_cache(clabel,entry) == []
Expand All @@ -36,9 +36,21 @@ def function_channel_pass(net,channel,elem):
def function_channel_transform_node_token(net,channel,elem):
if elem['action'] in ['AddNode','RemoveNode']:
action = {'AddNode':'AddEntry','RemoveNode':'RemoveEntry'}[elem['action']]
d = {'entry':{'a':elem['idx']},'action':action,'channel':channel}
entry = channel.data.mapping.transform(elem)
d = {'entry':entry,'action':action}
node = net.get_node(core=channel.target)
node.state.incoming.append(d)
return net

def function_channel_transform_edge_token(net,channel,elem):
if elem['action'] in ['AddEdge','RemoveEdge']:
action = {'AddEdge':'AddEntry','RemoveEdge':'RemoveEntry'}[elem['action']]
entry = channel.data.mapping.transform(elem)
d = {'entry':entry,'action':action}
node = net.get_node(core=channel.target)
node.state.incoming.append(d)
return net



default_functionalization_methods = [method for name,method in globals().items() if name.startswith('function_')]
7 changes: 5 additions & 2 deletions wc_rules/matcher/initialize.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from ..schema.base import BaseClass
from ..utils.random import generate_id
from ..utils.collections import Mapping
from collections import deque
# nodes must be a dict with keys
# 'type','core',
Expand All @@ -26,12 +27,14 @@ def initialize_collector(net,source,label):
return net

def initialize_canonical_label(net,clabel,symmetry_group):
if len(clabel.names)==1 and net.get_node(core=clabel) is None:
if len(clabel.names)<=2 and net.get_node(core=clabel) is None:
# it is a singleton graph
net.initialize_class(clabel.classes[0])
net.add_node(type='canonical_label',core=clabel,symmetry_group=symmetry_group)
net.initialize_cache(clabel,clabel.names)
net.add_channel(type='transform_node_token',source=clabel.classes[0],target=clabel,mapping={'idx':'a'})
chtype = {1:'node',2:'edge'}[len(clabel.names)]
mapping = {1:Mapping(['idx'],['a']),2:Mapping(['idx1','idx2'],['a','b'])}[len(clabel.names)]
net.add_channel(type=f'transform_{chtype}_token',source=clabel.classes[0],target=clabel,mapping=mapping)
return net


Expand Down
3 changes: 3 additions & 0 deletions wc_rules/utils/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ def __mul__(self,other):
return other.__class__([self.get(x) for x in other])
return self.get(other)

def transform(self,d):
return {self._dict[k]:d[k] for k in d if k in self.sources}

def sort(self,order=None):
if order is None:
order = sorted(self.sources)
Expand Down

0 comments on commit 665082d

Please sign in to comment.