Skip to content

Commit

Permalink
AddRemove model to test creating and deleting nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
johnsekar committed May 9, 2022
1 parent 0a8359d commit 5d0a635
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 62 deletions.
79 changes: 49 additions & 30 deletions examples/simple_addremove/model.py
@@ -1,6 +1,6 @@
from wc_rules.schema.entity import Entity
from wc_rules.graph.collections import GraphContainer, GraphFactory
from wc_rules.modeling.pattern import Pattern,Observable
from wc_rules.modeling.pattern import Pattern,SimpleObservable
from wc_rules.modeling.rule import InstanceRateRule
from wc_rules.modeling.model import RuleBasedModel

Expand All @@ -13,34 +13,53 @@ class Y(Entity):
class Z(Entity):
pass

gx,gy,gz = [GraphFactory([C(c)]) for C,c in [(X,'x'),(Y,'y'),(Z,'z')]]
px,py,pz = [Pattern(x) for x in [gx,gy,gz]]

xrule = InstanceRateRule(
name = 'x_addition_rule',
factories = {'px':gx},
actions = ['px.build()'],
rate_prefix = 'kx',
parameters = ['kx']
)

yrule = InstanceRateRule(
name = 'y_addition_rule',
factories = {'py':gy},
actions = ['py.build()'],
rate_prefix = 'ky',
parameters = ['ky']
)

xyzrule = InstanceRateRule(
name = 'xyz_rule',
reactants = {'px':px,'py':py},
factories = {'pz':gz},
actions = ['px.remove()','py.remove()','pz.build()'],
rate_prefix = 'ky',
parameters = ['ky']
)

model = RuleBasedModel('addremove_model',rules=[xrule,yrule,xyzrule ])
class SimpleAddRemoveModel(RuleBasedModel):

data = {'k1':1,'k2':1,'k3':1,'k4':1}

def __init__(self,name):
gx,gy,gz = [GraphFactory([C(c)]) for C,c in [(X,'x'),(Y,'y'),(Z,'z')]]
px,py,pz = [Pattern(x) for x in [gx,gy,gz]]

r1 = InstanceRateRule(
name = 'adding_x',
factories = {'px':gx},
actions = ['add(px)'],
rate_prefix = 'k1',
parameters = ['k1']
)

r2 = InstanceRateRule(
name = 'adding_y',
factories = {'py':gy},
actions = ['add(py)'],
rate_prefix = 'k2',
parameters = ['k2']
)

r3 = InstanceRateRule(
name = 'transforming_xy_to_z',
reactants = {'px':px,'py':py},
factories = {'pz':gz},
actions = ['remove(px)','remove(py)','add(pz)'],
rate_prefix = 'k3',
parameters = ['k3']
)

r4 = InstanceRateRule(
name = 'removing_z',
reactants = {'pz':pz,},
actions = ['remove(pz)'],
rate_prefix = 'k4',
parameters = ['k4']
)

obs1 = SimpleObservable(name='x',target=pz)
obs2 = SimpleObservable(name='y',target=pz)
obs3 = SimpleObservable(name='z',target=pz)

super().__init__(name,rules=[r1,r2,r3,r4],observables=[obs1,obs2,obs3])

# model = RuleBasedModel('addremove_model',rules=[r1,r2,r3,r4])
model = SimpleAddRemoveModel('addremove_model')

29 changes: 29 additions & 0 deletions tests/simulator/test_sim_arch.py
Expand Up @@ -19,6 +19,8 @@
from simple_binding.model import X,Y
from simple_flip.model import model as flip_model
from simple_flip.model import Point
from simple_addremove.model import model as addremove_model


def get_lengths(elems):
return list(map(len,elems))
Expand Down Expand Up @@ -251,6 +253,33 @@ def test_fire(self):
self.assertEqual([sim.cache[x].v for x in ['x1','x2']],[False,True])
self.assertEqual(sim.get_updated_variables(),['flip_model.flipping_rule.propensity'])

class TestAddRemove(unittest.TestCase):

def setUp(self):
self.model = addremove_model
self.data = {'addremove_model':{'k1':1,'k2':1,'k3':1,'k4':1}}
self.sim = SimulationEngine(model=self.model,parameters=self.data)

def test_fire(self):
sim = self.sim
r1 = 'addremove_model.adding_x'
r2 = 'addremove_model.adding_y'
r3 = 'addremove_model.transforming_xy_to_z'
r4 = 'addremove_model.removing_z'

self.assertEqual(len(sim.cache),0)
sim.fire(r1)
self.assertEqual(len(sim.cache),1)
sim.fire(r2)
self.assertEqual(len(sim.cache),2)
sim.fire(r3)
self.assertEqual(len(sim.cache),1)
sim.fire(r4)
self.assertEqual(len(sim.cache),0)





class TestScheduler(unittest.TestCase):

Expand Down
36 changes: 30 additions & 6 deletions wc_rules/expressions/executable.py
@@ -1,12 +1,13 @@
from .parse import process_expression_string, serialize
from .dependency import DependencyCollector
from ..utils.collections import subdict
from ..schema.actions import RollbackAction, TerminateAction, PrimaryAction, CompositeAction, SimulatorAction
from ..schema.actions import RollbackAction, TerminateAction, PrimaryAction, CompositeAction, SimulatorAction, action_builtins, CollectReferences
from .builtins import ordered_builtins, global_builtins
from .exprgraph import dfs_make
from collections import ChainMap, defaultdict
from sortedcontainers import SortedSet
from frozendict import frozendict
from collections.abc import Sequence

import inspect

Expand Down Expand Up @@ -50,12 +51,11 @@ def initialize(cls,s):
keywords = list(deps.variables)

# this step figures what builtins to use, picks them from the global_builtins list
#builtins = subdict(cls.builtins, ['__builtins__'] + list(deps.builtins))
#builtins = subdict(global_builtins,['__builtins__'] + list(deps.builtins))
builtins = subdict(cls.builtins, ['__builtins__'] + list(deps.builtins))
code2 = 'lambda {vars}: {code}'.format(vars=','.join(keywords),code=code)
try:
fn = eval(code2,global_builtins)
x = cls(keywords=keywords,builtins=global_builtins,fn=fn,code=code,deps=deps)
fn = eval(code2,builtins)
x = cls(keywords=keywords,builtins=builtins,fn=fn,code=code,deps=deps)
except:
x = None
except:
Expand Down Expand Up @@ -151,7 +151,7 @@ def terminate(expr):
# is equivalent to an action method call
class ActionCaller(ExecutableExpression):
start = 'function_call'
builtins = ChainMap(global_builtins,dict(rollback=rollback,terminate=terminate))
builtins = ChainMap(global_builtins,dict(rollback=rollback,terminate=terminate),action_builtins)
allowed_forms = ['<actioncall> ( <boolexpr> )', '<pattern>.<var>.<actioncall> (<params>)', '<pattern>.<actioncall> (<params>)']
allowed_returns = None

Expand Down Expand Up @@ -222,3 +222,27 @@ def exec(self,match,*dicts):
elif not c.exec(match,*dicts):
return None
return match


class ActionManager:
def __init__(self,action_execs,factories):
self.execs = action_execs

for e in self.execs:
for fnametuple in e.deps.function_calls:
if fnametuple == ('add',):
assert len(e.deps.variables)==1
var = list(e.deps.variables)[0]
assert var in factories
setattr(e,'build_variable',var)

def exec(self,match,*dicts):
for c in self.execs:
actions = c.exec(match,*dicts)
if hasattr(c,'build_variable'):
assert isinstance(actions[-1],CollectReferences)
actions[-1].variable = c.build_variable
if isinstance(actions,Sequence):
yield from actions
else:
yield actions
26 changes: 1 addition & 25 deletions wc_rules/modeling/rule.py
Expand Up @@ -2,34 +2,10 @@
from ..utils.validate import *
from ..graph.collections import GraphContainer, GraphFactory
from .pattern import Pattern
from ..expressions.executable import ActionCaller,Constraint, Computation, RateLaw, initialize_from_string
from ..expressions.executable import ActionCaller,Constraint, Computation, RateLaw, initialize_from_string, ActionManager
from collections import Counter,ChainMap
from collections.abc import Sequence
from ..utils.collections import sort_by_value
from ..schema.actions import CollectReferences

class ActionManager:
def __init__(self,action_execs,factories):
self.execs = action_execs

for e in self.execs:
for fnametuple in e.deps.function_calls:
if fnametuple[-1] == 'build':
assert len(fnametuple)==2
assert fnametuple[0] in factories
setattr(e,'build_variable',fnametuple[0])

def exec(self,match,*dicts):
for c in self.execs:
if hasattr(c,'build_variable'):
actions, idmap = c.exec(match,*dicts)
actions.append(CollectReferences(variable=c.build_variable,data=idmap))
else:
actions = c.exec(match,*dicts)
if isinstance(actions,Sequence):
yield from actions
else:
yield actions

class Rule:

Expand Down
14 changes: 13 additions & 1 deletion wc_rules/schema/actions.py
Expand Up @@ -374,4 +374,16 @@ class CollectReferences:

def execute(self,match,cache):
match[self.variable] = {k:cache[v] for k,v in self.data.items()}
return self
return self



# Builtins for actions
def add(graphfactory):
idmap = graphfactory.build_random_idmap()
return list(graphfactory.generate_actions(idmap)) + [CollectReferences(data=idmap)]

def remove(match):
return [Remove(source=elem) for elem in match.values()]

action_builtins = {'add':add,'remove':remove}

0 comments on commit 5d0a635

Please sign in to comment.