From 5d0a635b82a042c4b5d68a85501760cf4ed41cd6 Mon Sep 17 00:00:00 2001 From: John Sekar Date: Mon, 9 May 2022 14:30:03 -0400 Subject: [PATCH] AddRemove model to test creating and deleting nodes --- examples/simple_addremove/model.py | 79 ++++++++++++++++++------------ tests/simulator/test_sim_arch.py | 29 +++++++++++ wc_rules/expressions/executable.py | 36 +++++++++++--- wc_rules/modeling/rule.py | 26 +--------- wc_rules/schema/actions.py | 14 +++++- 5 files changed, 122 insertions(+), 62 deletions(-) diff --git a/examples/simple_addremove/model.py b/examples/simple_addremove/model.py index edf913d..b589732 100644 --- a/examples/simple_addremove/model.py +++ b/examples/simple_addremove/model.py @@ -1,6 +1,6 @@ from wc_rules.schema.entity import Entity from wc_rules.graph.collections import GraphContainer, GraphFactory -from wc_rules.modeling.pattern import Pattern,Observable +from wc_rules.modeling.pattern import Pattern,SimpleObservable from wc_rules.modeling.rule import InstanceRateRule from wc_rules.modeling.model import RuleBasedModel @@ -13,34 +13,53 @@ class Y(Entity): class Z(Entity): pass -gx,gy,gz = [GraphFactory([C(c)]) for C,c in [(X,'x'),(Y,'y'),(Z,'z')]] -px,py,pz = [Pattern(x) for x in [gx,gy,gz]] - -xrule = InstanceRateRule( - name = 'x_addition_rule', - factories = {'px':gx}, - actions = ['px.build()'], - rate_prefix = 'kx', - parameters = ['kx'] - ) - -yrule = InstanceRateRule( - name = 'y_addition_rule', - factories = {'py':gy}, - actions = ['py.build()'], - rate_prefix = 'ky', - parameters = ['ky'] - ) - -xyzrule = InstanceRateRule( - name = 'xyz_rule', - reactants = {'px':px,'py':py}, - factories = {'pz':gz}, - actions = ['px.remove()','py.remove()','pz.build()'], - rate_prefix = 'ky', - parameters = ['ky'] - ) - -model = RuleBasedModel('addremove_model',rules=[xrule,yrule,xyzrule ]) +class SimpleAddRemoveModel(RuleBasedModel): + data = {'k1':1,'k2':1,'k3':1,'k4':1} + + def __init__(self,name): + gx,gy,gz = [GraphFactory([C(c)]) for C,c in [(X,'x'),(Y,'y'),(Z,'z')]] + px,py,pz = [Pattern(x) for x in [gx,gy,gz]] + + r1 = InstanceRateRule( + name = 'adding_x', + factories = {'px':gx}, + actions = ['add(px)'], + rate_prefix = 'k1', + parameters = ['k1'] + ) + + r2 = InstanceRateRule( + name = 'adding_y', + factories = {'py':gy}, + actions = ['add(py)'], + rate_prefix = 'k2', + parameters = ['k2'] + ) + + r3 = InstanceRateRule( + name = 'transforming_xy_to_z', + reactants = {'px':px,'py':py}, + factories = {'pz':gz}, + actions = ['remove(px)','remove(py)','add(pz)'], + rate_prefix = 'k3', + parameters = ['k3'] + ) + + r4 = InstanceRateRule( + name = 'removing_z', + reactants = {'pz':pz,}, + actions = ['remove(pz)'], + rate_prefix = 'k4', + parameters = ['k4'] + ) + + obs1 = SimpleObservable(name='x',target=pz) + obs2 = SimpleObservable(name='y',target=pz) + obs3 = SimpleObservable(name='z',target=pz) + + super().__init__(name,rules=[r1,r2,r3,r4],observables=[obs1,obs2,obs3]) + +# model = RuleBasedModel('addremove_model',rules=[r1,r2,r3,r4]) +model = SimpleAddRemoveModel('addremove_model') diff --git a/tests/simulator/test_sim_arch.py b/tests/simulator/test_sim_arch.py index 10c9682..32fc995 100644 --- a/tests/simulator/test_sim_arch.py +++ b/tests/simulator/test_sim_arch.py @@ -19,6 +19,8 @@ from simple_binding.model import X,Y from simple_flip.model import model as flip_model from simple_flip.model import Point +from simple_addremove.model import model as addremove_model + def get_lengths(elems): return list(map(len,elems)) @@ -251,6 +253,33 @@ def test_fire(self): self.assertEqual([sim.cache[x].v for x in ['x1','x2']],[False,True]) self.assertEqual(sim.get_updated_variables(),['flip_model.flipping_rule.propensity']) +class TestAddRemove(unittest.TestCase): + + def setUp(self): + self.model = addremove_model + self.data = {'addremove_model':{'k1':1,'k2':1,'k3':1,'k4':1}} + self.sim = SimulationEngine(model=self.model,parameters=self.data) + + def test_fire(self): + sim = self.sim + r1 = 'addremove_model.adding_x' + r2 = 'addremove_model.adding_y' + r3 = 'addremove_model.transforming_xy_to_z' + r4 = 'addremove_model.removing_z' + + self.assertEqual(len(sim.cache),0) + sim.fire(r1) + self.assertEqual(len(sim.cache),1) + sim.fire(r2) + self.assertEqual(len(sim.cache),2) + sim.fire(r3) + self.assertEqual(len(sim.cache),1) + sim.fire(r4) + self.assertEqual(len(sim.cache),0) + + + + class TestScheduler(unittest.TestCase): diff --git a/wc_rules/expressions/executable.py b/wc_rules/expressions/executable.py index 6247b4f..d7274f2 100644 --- a/wc_rules/expressions/executable.py +++ b/wc_rules/expressions/executable.py @@ -1,12 +1,13 @@ from .parse import process_expression_string, serialize from .dependency import DependencyCollector from ..utils.collections import subdict -from ..schema.actions import RollbackAction, TerminateAction, PrimaryAction, CompositeAction, SimulatorAction +from ..schema.actions import RollbackAction, TerminateAction, PrimaryAction, CompositeAction, SimulatorAction, action_builtins, CollectReferences from .builtins import ordered_builtins, global_builtins from .exprgraph import dfs_make from collections import ChainMap, defaultdict from sortedcontainers import SortedSet from frozendict import frozendict +from collections.abc import Sequence import inspect @@ -50,12 +51,11 @@ def initialize(cls,s): keywords = list(deps.variables) # this step figures what builtins to use, picks them from the global_builtins list - #builtins = subdict(cls.builtins, ['__builtins__'] + list(deps.builtins)) - #builtins = subdict(global_builtins,['__builtins__'] + list(deps.builtins)) + builtins = subdict(cls.builtins, ['__builtins__'] + list(deps.builtins)) code2 = 'lambda {vars}: {code}'.format(vars=','.join(keywords),code=code) try: - fn = eval(code2,global_builtins) - x = cls(keywords=keywords,builtins=global_builtins,fn=fn,code=code,deps=deps) + fn = eval(code2,builtins) + x = cls(keywords=keywords,builtins=builtins,fn=fn,code=code,deps=deps) except: x = None except: @@ -151,7 +151,7 @@ def terminate(expr): # is equivalent to an action method call class ActionCaller(ExecutableExpression): start = 'function_call' - builtins = ChainMap(global_builtins,dict(rollback=rollback,terminate=terminate)) + builtins = ChainMap(global_builtins,dict(rollback=rollback,terminate=terminate),action_builtins) allowed_forms = [' ( )', '.. ()', '. ()'] allowed_returns = None @@ -222,3 +222,27 @@ def exec(self,match,*dicts): elif not c.exec(match,*dicts): return None return match + + +class ActionManager: + def __init__(self,action_execs,factories): + self.execs = action_execs + + for e in self.execs: + for fnametuple in e.deps.function_calls: + if fnametuple == ('add',): + assert len(e.deps.variables)==1 + var = list(e.deps.variables)[0] + assert var in factories + setattr(e,'build_variable',var) + + def exec(self,match,*dicts): + for c in self.execs: + actions = c.exec(match,*dicts) + if hasattr(c,'build_variable'): + assert isinstance(actions[-1],CollectReferences) + actions[-1].variable = c.build_variable + if isinstance(actions,Sequence): + yield from actions + else: + yield actions diff --git a/wc_rules/modeling/rule.py b/wc_rules/modeling/rule.py index 2f13d04..97a4f2f 100644 --- a/wc_rules/modeling/rule.py +++ b/wc_rules/modeling/rule.py @@ -2,34 +2,10 @@ from ..utils.validate import * from ..graph.collections import GraphContainer, GraphFactory from .pattern import Pattern -from ..expressions.executable import ActionCaller,Constraint, Computation, RateLaw, initialize_from_string +from ..expressions.executable import ActionCaller,Constraint, Computation, RateLaw, initialize_from_string, ActionManager from collections import Counter,ChainMap -from collections.abc import Sequence from ..utils.collections import sort_by_value from ..schema.actions import CollectReferences - -class ActionManager: - def __init__(self,action_execs,factories): - self.execs = action_execs - - for e in self.execs: - for fnametuple in e.deps.function_calls: - if fnametuple[-1] == 'build': - assert len(fnametuple)==2 - assert fnametuple[0] in factories - setattr(e,'build_variable',fnametuple[0]) - - def exec(self,match,*dicts): - for c in self.execs: - if hasattr(c,'build_variable'): - actions, idmap = c.exec(match,*dicts) - actions.append(CollectReferences(variable=c.build_variable,data=idmap)) - else: - actions = c.exec(match,*dicts) - if isinstance(actions,Sequence): - yield from actions - else: - yield actions class Rule: diff --git a/wc_rules/schema/actions.py b/wc_rules/schema/actions.py index 403f99f..c627758 100644 --- a/wc_rules/schema/actions.py +++ b/wc_rules/schema/actions.py @@ -374,4 +374,16 @@ class CollectReferences: def execute(self,match,cache): match[self.variable] = {k:cache[v] for k,v in self.data.items()} - return self \ No newline at end of file + return self + + + +# Builtins for actions +def add(graphfactory): + idmap = graphfactory.build_random_idmap() + return list(graphfactory.generate_actions(idmap)) + [CollectReferences(data=idmap)] + +def remove(match): + return [Remove(source=elem) for elem in match.values()] + +action_builtins = {'add':add,'remove':remove}