Skip to content

Commit

Permalink
Observables initialization on rete net and behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
johnsekar committed Apr 24, 2022
1 parent cf58c83 commit 01faa31
Show file tree
Hide file tree
Showing 6 changed files with 137 additions and 29 deletions.
75 changes: 72 additions & 3 deletions tests/tests_matcher/test_pattern.py
@@ -1,6 +1,6 @@
from wc_rules.schema.entity import Entity
from wc_rules.schema.attributes import BooleanAttribute, IntegerAttribute, StringAttribute, computation
from wc_rules.modeling.pattern import Pattern, GraphContainer
from wc_rules.modeling.pattern import Pattern, GraphContainer, Observable
from wc_rules.matcher.core import build_rete_net_class
from wc_rules.matcher.token import *
from wc_rules.graph.examples import X,Y
Expand Down Expand Up @@ -133,8 +133,6 @@ def test_pattern_assigned_variables(self):
def test_pattern_helpers(self):
ReteNet = build_rete_net_class()
net = ReteNet().initialize_start()


p = Pattern(GraphContainer([BigAttributeClass('elem',x=True,y=True)]))
q = Pattern(GraphContainer([BigAttributeClass('elem',n1=0,n2=0)]))

Expand All @@ -153,6 +151,27 @@ def test_pattern_helpers(self):
self.assertEqual(pr.data.transformer.datamap,{'elem':'elem'})
self.assertEqual(qr.data.transformer.datamap,{'elem':'elem'})

def test_observable(self):
ReteNet = build_rete_net_class()
net = ReteNet().initialize_start().initialize_end()
p = Pattern(GraphContainer([BigAttributeClass('elem',x=True,y=True)]))
q = Observable(name='obsTrue',helpers={'p':p},expression='p.count()')

net.initialize_observable('obsTrue',q)
q_rn = net.get_node(core='obsTrue')
self.assertTrue(q_rn is not None)
self.assertTrue(net.get_node(core=p) is not None)
ch = net.get_channel(source=p,target='obsTrue')
self.assertTrue(ch is not None)
self.assertEqual(ch.type,'variable_update')
self.assertEqual(q_rn.state.cache.value,0)
ch = net.get_channel(source='obsTrue',target='end')
self.assertTrue(ch is not None)





class TestRetePatternBehavior(unittest.TestCase):
# test insert/remove/filter

Expand Down Expand Up @@ -412,3 +431,53 @@ def test_pattern_helpers(self):
# we went from p:[elem2], q:[elem1], r:[]
# to p:[elem1,elem2], q:[elem1,elem2], r:[elem1,elem2]
# to p:[elem1], q:[elem2], r:[]

def test_observable(self):
ReteNet = build_rete_net_class()
net = ReteNet().initialize_start().initialize_end()
start = net.get_node(type='start')
p = Pattern(GraphContainer([BigAttributeClass('elem',x=True,y=True)]))
q = Observable(name='obsTrue',helpers={'p':p},expression='p.count()')

net.initialize_observable('obsTrue',q)
q_rn = net.get_node(core='obsTrue')
self.assertEqual(q_rn.state.cache.value,0)

elem1 = BigAttributeClass('elem1',x=True,y=True,n1=0,n2=0,s1='',s2='')
token = make_node_token(BigAttributeClass,elem1,'AddNode')
start.state.incoming.append(token)
net.sync(start)
self.assertEqual(q_rn.state.cache.value,1)

elem2 = BigAttributeClass('elem2',x=True,y=False,n1=1,n2=0,s1='',s2='')
token = make_node_token(BigAttributeClass,elem2,'AddNode')
start.state.incoming.append(token)
net.sync(start)
self.assertEqual(q_rn.state.cache.value,1)

elem2.y = True
token = make_attr_token(BigAttributeClass,elem2,'y','SetAttr')
start.state.incoming.append(token)
net.sync(start)
self.assertEqual(q_rn.state.cache.value,2)

elem1.x = False
token = make_attr_token(BigAttributeClass,elem1,'x','SetAttr')
start.state.incoming.append(token)
net.sync(start)
self.assertEqual(q_rn.state.cache.value,1)

elem2.x = False
token = make_attr_token(BigAttributeClass,elem2,'x','SetAttr')
start.state.incoming.append(token)
net.sync(start)
self.assertEqual(q_rn.state.cache.value,0)

end = net.get_node(type='end')
self.assertEqual(list(end.state.cache),['obsTrue'])

# first add elem1 which matches q
# then elem2 which does not match q
# then modify elem2 so it does match q
# then modify elem1 so it does not match q
# then modify elem2 so it does not match q
7 changes: 7 additions & 0 deletions wc_rules/expressions/executable.py
Expand Up @@ -128,6 +128,13 @@ class RateLaw(ExecutableExpression):
allowed_forms = ['<expr>']
allowed_returns = (int,float,)


class ObservableExpression(ExecutableExpression):
start = 'expression'
builtins = global_builtins
allowed_forms = ['<expr>']
allowed_returns = None

######## Simulator methods #########
def rollback(expr):
assert isinstance(expr,bool), "Rollback condition must evaluate to a boolean."
Expand Down
21 changes: 20 additions & 1 deletion wc_rules/matcher/initialize_methods.py
Expand Up @@ -145,7 +145,9 @@ def initialize_pattern(self,pattern):
caches[var] = self.get_node(core=pat).state.cache

manager = pattern.make_executable_expression_manager()
self.add_node_pattern(pattern=pattern,subtype='default',executables=manager,caches=caches)

if requires_parent:
self.add_node_pattern(pattern=pattern,subtype='default',executables=manager,caches=caches)

for variable, attr in manager.get_attribute_calls():
assert issubclass(pattern.namespace[variable],BaseClass)
Expand Down Expand Up @@ -200,4 +202,21 @@ def initialize_rules(self,rules,parameters):
self.add_node_variable(name,value,subtype='fixed')
for name,rule in rules.items():
self.initialize_rule(name,rule)
return self

def initialize_observable(self,name,observable):
caches = dict()
for pname,pattern in observable.helpers.items():
self.initialize_pattern(pattern)
caches[pname]= self.get_node(core=pattern).state.cache
self.add_node_variable(
name = name,
default_value = observable.default,
executable = observable.make_executable(),
caches = caches,
subtype = 'recompute'
)
for pname,pattern in observable.helpers.items():
self.add_channel_variable_update(source=pattern,target=name,variable=pname)
self.add_channel_variable_update(source=name,target='end',variable=name)
return self
4 changes: 3 additions & 1 deletion wc_rules/modeling/model.py
@@ -1,5 +1,5 @@
from .rule import Rule
from .observable import Observable
from .pattern import Observable
from ..utils.validate import *
from ..utils.collections import DictLike,merge_lists
from collections.abc import Sequence
Expand All @@ -22,6 +22,8 @@ def __init__(self,name,rules,observables=[]):
validate_set(rule_names,'Rule names in a model')
self.rules = rules
validate_list(observables,Observable,'Observable')
obs_names = [x.name for x in observables]
validate_set(obs_names,'Rule names in a model')
self.observables = observables
self._dict = {x.name:x for x in self.rules}

Expand Down
6 changes: 0 additions & 6 deletions wc_rules/modeling/observable.py

This file was deleted.

53 changes: 35 additions & 18 deletions wc_rules/modeling/pattern.py
Expand Up @@ -2,8 +2,8 @@
from ..utils.validate import *
from ..graph.collections import GraphContainer
from ..utils.collections import split_string
from ..expressions.executable import Constraint, Computation, initialize_from_string, ExecutableExpressionManager
from attrdict import AttrDict
from ..expressions.executable import Constraint, Computation, initialize_from_string, ExecutableExpressionManager, ObservableExpression



def make_attr_constraints(attrs):
Expand Down Expand Up @@ -105,19 +105,36 @@ def make_executable_expression_manager(self):
manager = ExecutableExpressionManager(execs,self.namespace)
return manager

def get_initialization_code(self):
code = AttrDict({
'parent': None,
'constraints': bool(self.constraints),
'helpers': bool(self.helpers)
})
if isinstance(self.parent,GraphContainer) and len(self.parent)>0:
code['parent'] = 'graph'
elif isinstance(self.parent,Pattern):
code['parent'] = 'pattern'

if not any([code.helpers,code.constraints]):
code['subtype'] = 'alias'
else:
code['subtype'] = 'default'
return code

class Observable:

def __init__(self,name,helpers,expression,default=0):
self.name = name
self.default = default

self.validate_helpers(helpers)
self.helpers = helpers

self.validate_expression(expression)
self.expression = expression

def validate_helpers(self,helpers):
validate_class(helpers,dict,'Helpers')
validate_keywords(helpers.keys(),'Helper')
validate_dict(helpers,Pattern,'Helper')

def validate_expression(self,expression):
validate_class(expression,str,'Expression')
namespace = list(self.helpers.keys())
x = initialize_from_string(expression,(ObservableExpression,))
validate_contains(namespace,x.keywords,'Variable')

def make_executable(self):
return initialize_from_string(self.expression,(ObservableExpression,))

class SimpleObservable:

def __init__(self,name,target,default=0):
helpers = {'target':target}
expression = 'target.count()'
super().__init__(name,helpers,expression,default)

0 comments on commit 01faa31

Please sign in to comment.