Skip to content

Commit

Permalink
Matcher tests: Model with simple binding rule
Browse files Browse the repository at this point in the history
  • Loading branch information
johnsekar committed Apr 12, 2022
1 parent 48944cd commit 91335f5
Show file tree
Hide file tree
Showing 10 changed files with 223 additions and 17 deletions.
94 changes: 94 additions & 0 deletions tests/tests_matcher/test_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from wc_rules.modeling.pattern import GraphContainer, Pattern
from wc_rules.modeling.rule import InstanceRateRule
from wc_rules.modeling.model import RuleBasedModel,AggregateModel
from wc_rules.matcher.core import build_rete_net_class
from wc_rules.graph.examples import X,Y
from wc_rules.matcher.token import make_node_token, make_attr_token

import unittest

ReteNet = build_rete_net_class()

class BindingRuleModel(RuleBasedModel):
defaults = {'k':1}

def __init__(self,name):
px = Pattern(GraphContainer([X('x')]))
py = Pattern(GraphContainer([Y('y')]),constraints = ['len(y.x)==0'])
rule = InstanceRateRule('binding_rule',
reactants = {'rX':px, 'rY':py},
actions = ['rX.x.add_y(rY.y)'],
rate_prefix = 'k',
parameters = ['k']
)
super().__init__(name=name,rules=[rule])

class ModelInitialization(unittest.TestCase):

def test_binding_rule(self):

model = AggregateModel('model',models=[BindingRuleModel('binding_model')])
model_rule_names = [n for n,r in model.iter_rules()]
self.assertEqual(model_rule_names,['binding_model.binding_rule'])
model_param_names = [p for p,v in model.iter_parameters()]
self.assertEqual(model_param_names,['binding_model.k'])
model.verify(data={'binding_model':{'k':1}})

rn = ReteNet().initialize_start().initialize_end()
rn.initialize_model(model)
self.assertTrue(rn.get_node(core='binding_model.binding_rule.propensity') is not None)
self.assertTrue(rn.get_node(core='binding_model.k') is not None)
self.assertTrue(rn.get_channel(
source = 'binding_model.k',
target = 'binding_model.binding_rule.propensity',
type = 'variable_update'
) is not None)
reactants = model.models[0].rules[0].reactants.values()
for pattern in reactants:
self.assertTrue(rn.get_channel(
source = pattern,
target = 'binding_model.binding_rule.propensity',
type = 'variable_update'
) is not None)
self.assertTrue(rn.get_channel(
source = 'binding_model.binding_rule.propensity',
target = 'end',
type = 'variable_update'
) is not None)

class ModelBehavior(unittest.TestCase):

def test_binding_rule(self):

model = AggregateModel('model',models=[BindingRuleModel('binding_model')])
rn = ReteNet().initialize_start().initialize_end()
rn.initialize_model(model)
start,end = [rn.get_node(type=x) for x in ['start','end',]]
prop = rn.get_node(core='binding_model.binding_rule.propensity')

x1,y1 = X('x1'),Y('y1')
self.assertEqual(prop.state.cache.value,0)
self.assertEqual(len(end.state.cache),0)

tokens = [
make_node_token(X,x1,'AddNode'),
make_node_token(Y,y1,'AddNode'),
]
start.state.incoming.extend(tokens)
rn.sync(start)
self.assertEqual(prop.state.cache.value,1)
self.assertEqual(len(end.state.cache),1)

x1.y.add(y1)
tokens = [
make_attr_token(X,x1,'y','SetAttr'),
make_attr_token(Y,y1,'x','SetAttr')
]

start.state.incoming.extend(tokens)
rn.sync(start)
self.assertEqual(prop.state.cache.value,0)
self.assertEqual(len(end.state.cache),2)



1 change: 0 additions & 1 deletion wc_rules/expressions/executable.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,4 +202,3 @@ def exec(self,match,*dicts):
return None
return match


30 changes: 27 additions & 3 deletions wc_rules/matcher/add_methods.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
from ..utils.collections import UniversalSet, SimpleMapping
from ..schema.base import BaseClass
from .dbase import Database, DatabaseAlias, DatabaseSymmetric, DatabaseAliasSymmetric
from .dbase import Database, DatabaseAlias, DatabaseSingleValue, DatabaseSymmetric, DatabaseAliasSymmetric
from .token import TokenTransformer

from collections import deque

class AddMethods:

DATABASE_CLASS = Database
DATABASE_SINGLE_VALUE = DatabaseSingleValue
DATABASE_ALIAS_CLASS = DatabaseAlias

def add_node_start(self):
self.add_node(type='start',core=BaseClass)
return self

def add_node_end(self):
self.add_node(type='end',core='end',cache=deque())
return self

def add_node_class(self,_class):
self.add_node(type='class',core=_class)
return self
Expand Down Expand Up @@ -48,6 +53,17 @@ def add_node_pattern(self,pattern,cache,subtype='default',executables=[],caches=
caches = caches
)

def add_node_variable(self,name,default_value=None,executable=None,parameters={},caches={},subtype=None):
self.add_node(
type='variable',
core=name,
cache=self.DATABASE_SINGLE_VALUE(value=default_value),
executable = executable,
parameters = parameters,
caches=caches,
subtype=subtype
)

def add_data_property(self,core,variable,value):
node = self.get_node(core=core)
node.data[variable]=value
Expand All @@ -62,9 +78,8 @@ def generate_cache_reference(self,target,mapping,symmetry_group=None):
)
return cache_ref

def generate_cache(self,fields,**kwargs):
def generate_cache(self,fields,**kwargs):
return self.DATABASE_CLASS(fields=fields,**kwargs)


def update_node_data(self,core,update_dict):
data = self.get_node(core=core).data
Expand Down Expand Up @@ -92,6 +107,15 @@ def add_channel_transform(self,source,target,datamap,actionmap,filter_data=lambd
)
return self

def add_channel_variable_update(self,source,target,variable):
self.add_channel(
type = 'variable_update',
source = source,
target = target,
variable = variable
)
return self


class AddMethodsSymmetric(AddMethods):

Expand Down
14 changes: 13 additions & 1 deletion wc_rules/matcher/channel_functions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# see tokens.py for the different types of tokens
from .token import VarToken

class ChannelFunctions:

def function_channel_pass(self,channel,token):
Expand All @@ -7,6 +9,16 @@ def function_channel_pass(self,channel,token):
return self

def function_channel_transform(self,channel,token):
if token.action not in channel.data.allowed_token_actions:
return self
if not channel.data.filter_data(token.data):
return self
newtoken = channel.data.transformer.transform(token,channel=channel.num)
self.function_channel_pass(channel,newtoken)
return self
return self

def function_channel_variable_update(self,channel,token):
token = VarToken(variable=channel.data.variable)
self.function_channel_pass(channel,token)
return

2 changes: 1 addition & 1 deletion wc_rules/matcher/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def sync_outgoing(self,node):
token = node.state.outgoing.popleft()
channels = self.get_channels(source=node.core)
for channel in channels:
if token.action in channel.data.allowed_token_actions and channel.data.filter_data(token.data):
#if token.action in channel.data.allowed_token_actions and channel.data.filter_data(token.data):
method = getattr(self,f'function_channel_{channel.type}')
method(channel,token)
self.sync(self.get_node(core=channel.target))
Expand Down
23 changes: 15 additions & 8 deletions wc_rules/matcher/dbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,6 @@ def dict_overlap(d1,d2):
def clean_record(r):
return {k:v for k,v in r.items() if k not in ['__id__','__version__']}

class SingleValueDatabase:
def __init__(self,field,**kwargs):
self.field = field
self.value = None



class Database:

def __init__(self,fields,**kwargs):
Expand Down Expand Up @@ -43,7 +36,7 @@ def update(self,kwargs={},update_kwargs={}):
self._db.update(record,**update_kwargs)
return self

def filter_one(self,kwargs):
def filter_one(self,kwargs={}):
records = self.filter(kwargs)
if len(records)==1:
return records[0]
Expand All @@ -52,6 +45,17 @@ def filter_one(self,kwargs):
def __len__(self):
return len(self._db)

def count(self):
return len(self)

class DatabaseSingleValue:

def __init__(self,value=None):
self.value = value

def update(self,value):
self.value = value

class DatabaseAlias:

def __init__(self,target,mapping,**kwargs):
Expand Down Expand Up @@ -90,6 +94,9 @@ def filter(self,kwargs={}):
rotated = [self.forward_transform(x) for x in records]
return rotated

def count(self):
return len(self.target)


class DatabaseSymmetric(Database):

Expand Down
41 changes: 40 additions & 1 deletion wc_rules/matcher/initialize_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ def initialize_start(self):
self.add_node_start()
return self

def initialize_end(self):
self.add_node_end()
return self

def initialize_class(self,_class):
if self.node_exists(core=_class):
return self
Expand Down Expand Up @@ -155,6 +159,41 @@ def initialize_pattern_constraints(self,pattern,parent,mapping):
self.add_channel_transform(source=_class,target=pattern,datamap=datamap,actionmap=actionmap,filter_data=filter_data)

# to do: HELPERS
return self
return self

def initialize_rule(self,name,rule):
model_name = '.'.join(name.split('.')[:-1])
parameters, caches = dict(),dict()
for param in rule.parameters:
parameters[param] = f'{model_name}.{param}'
assert self.get_node(core=f'{model_name}.{param}') is not None
for pname,pattern in rule.reactants.items():
self.initialize_pattern(pattern)
cache_ref = self.get_node(core=pattern).state.cache
caches[pname] = cache_ref
rate_law_executable = rule.get_rate_law_executable()
node_name = f'{name}.propensity'
self.add_node_variable(
name = node_name,
default_value = 0,
executable = rate_law_executable,
parameters = parameters,
caches = caches,
subtype = 'recompute'
)
for k,v in parameters.items():
self.add_channel_variable_update(source=v,target=node_name,variable=k)
for pname, pattern in rule.reactants.items():
self.add_channel_variable_update(source=pattern,target=node_name,variable=pname)

self.add_channel_variable_update(source=node_name,target='end',variable=node_name)


def initialize_model(self,model):
for name,value in model.iter_parameters():
self.add_node_variable(name,value,subtype='fixed')
for name,rule in model.iter_rules():
self.initialize_rule(name,rule)
return self


18 changes: 17 additions & 1 deletion wc_rules/matcher/node_functions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# see tokens.py for the different types of tokens
from ..utils.collections import subdict
from .token import CacheToken
from .token import CacheToken, VarToken

def no_common_values(d1,d2):
return len(set(d1.values()) & set(d2.values())) == 0
Expand All @@ -20,6 +20,10 @@ def function_node_receiver(self,node,token):
node.state.cache.append(token)
return self

def function_node_end(self,node,token):
self.function_node_receiver(node,token)
return self

def function_node_canonical_label(self,node,token):
if len(node.core.names)==1:
self.function_node_canonical_label_single_node(node,token)
Expand Down Expand Up @@ -102,4 +106,16 @@ def function_node_pattern(self,node,token):
self.function_node_alias(node,token)
else:
self.function_node_constraints(node,token)
return self

def function_node_variable(self,node,token):
if node.data.subtype == 'recompute':
params = {k:self.get_node(core=c).state.cache.value for k,c in node.data.parameters.items()}
value = node.data.executable.exec(params,node.data.caches)
if value != node.state.cache.value:
node.state.cache.update(value)
token = VarToken(variable=node.core)
node.state.outgoing.append(token)
else:
assert False, "Not supported yet!"
return self
9 changes: 8 additions & 1 deletion wc_rules/matcher/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,20 @@ class CacheToken:
data: Dict
channel: int=-1

@dataclass(eq=True)
class VarToken:
variable: str

class TokenTransformer:
def __init__(self,datamap,actionmap):
self.datamap = datamap
self.actionmap = actionmap

def transform(self,token,channel=-1):
data = {self.datamap[k]:v for k,v in token.data.items()}
if self.datamap is not None:
data = {self.datamap[k]:v for k,v in token.data.items()}
else:
data = token.data
action = self.actionmap[token.action]
return CacheToken(data=data,action=action,channel=channel)

Expand Down
8 changes: 8 additions & 0 deletions wc_rules/modeling/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ def iter_observables(self,prefix=''):
for observable in self.observables:
yield add_prefix(prefix,observable.name), observable

def iter_parameters(self,prefix=''):
for param,value in self.collect_parameters().items():
yield add_prefix(prefix,param), value




Expand Down Expand Up @@ -147,4 +151,8 @@ def iter_observables(self,prefix=''):
for n,m in self.iter_models(prefix):
yield from m.iter_observables()

def iter_parameters(self,prefix=''):
for n,m in self.iter_models(prefix):
yield from m.iter_parameters(prefix=n)


0 comments on commit 91335f5

Please sign in to comment.