Skip to content

Commit

Permalink
building rete net.. expressions (TODO: update channels)
Browse files Browse the repository at this point in the history
  • Loading branch information
johnsekar committed Oct 5, 2021
1 parent 661bba4 commit 9192bb9
Show file tree
Hide file tree
Showing 9 changed files with 108 additions and 23 deletions.
7 changes: 6 additions & 1 deletion wc_rules/expressions/executable.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ def exec(self,*dicts):
def build_exprgraph(self):
tree, deps = process_expression_string(self.code,start=self.__class__.start)
return dfs_make(tree)


def full_string(self):
return self.code

class Constraint(ExecutableExpression):
start = 'boolean_expression'
Expand All @@ -111,6 +113,9 @@ def build_exprgraph(self):
tree, deps = process_expression_string(code,start=self.__class__.start)
return dfs_make(tree)

def full_string(self):
return f'{self.deps.declared_variable} = {self.code}'


class RateLaw(ExecutableExpression):
start = 'expression'
Expand Down
5 changes: 5 additions & 0 deletions wc_rules/expressions/exprgraph_serialize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .exprgraph_utils import dfs_iter, dfs_visit, dfs_print


def serialize(graph):
print(dfs_print(graph))
1 change: 1 addition & 0 deletions wc_rules/expressions/exprgraph_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from ..schema.base import BaseClass
from ..schema.attributes import StringAttribute
from lark import Token

class ExprBase(BaseClass):
data = StringAttribute()
Expand Down
2 changes: 0 additions & 2 deletions wc_rules/expressions/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,6 @@ def category_pair(cat):
kwarg = lambda x,y: '='.join(y)
kwargs = args


def function_call(self,args):
d,s = dict(args), ''
if 'function_name' in d and 'args' not in d:
Expand All @@ -303,4 +302,3 @@ def serialize(tree):
s = Serializer().transform(tree=tree)
return s


3 changes: 3 additions & 0 deletions wc_rules/graph/permutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,5 +132,8 @@ def count_symmetries(self):
stabilizer = self.stabilizer(variable)
return norbits*stabilizer.count_symmetries()

def restrict(self,variables):
return self.__class__.create([x.restrict(variables) for x in self.generators])



95 changes: 79 additions & 16 deletions wc_rules/matcher/initialize.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from ..schema.base import BaseClass
from ..utils.random import generate_id
from ..utils.collections import Mapping
from ..utils.collections import Mapping, quoted, unzip
from ..expressions.exprgraph import PatternReference
from ..graph.graph_partitioning import partition_canonical_form
from ..graph.canonical_labeling import canonical_label
from ..graph.collections import GraphContainer
from ..modeling.pattern import Pattern
from collections import deque,Counter
from collections import deque,Counter, defaultdict
from attrdict import AttrDict
# nodes must be a dict with keys
# 'type','core',
# 'state' gets automatically initialized
Expand Down Expand Up @@ -63,24 +65,85 @@ def initialize_pattern(net,pattern):
if net.get_node(core=pattern) is not None:
return net

if len(pattern.constraints) == 0:
parent = pattern.parent
# handle helpers
helpers = dict()
if len(pattern.helpers) > 0:
for var,p in pattern.helpers.items():
net.initialize_pattern(p)
helpers[var] = p

# handle parent
parent,constraints = pattern.parent, pattern.constraints
if isinstance(parent,GraphContainer):
parent,attrs = parent.strip_attrs()
constraints = [f'{var}.{attr} == {quoted(value)}' for var in attrs for attr,value in attrs[var].items()] + constraints
m, L, G = canonical_label(parent)
net.initialize_canonical_label(L,G)
pdict = AttrDict(parent=L,mapping=m,symmetry_group=G)
elif isinstance(parent,Pattern):
net.initialize_pattern(parent)
m = Mapping.create(parent.variables)
G = net.get_node(core=parent).data.symmetry_group
pdict = AttrDict(parent = parent, mapping=m, symmetry_group=G)
else:
assert False, f'Some trouble initializing parent from {parent}'

graph = parent.duplicate() if isinstance(parent,GraphContainer) else net.get_node(core=parent).data.exprgraph
constraint_objects = []
if len(constraints) == 0:
# is an alias for its parent
if isinstance(parent,GraphContainer):
m, L, G = canonical_label(parent)
net.initialize_canonical_label(L,G)
net.add_node(type='pattern',core=pattern,symmetry_group=G,parent=L,mapping=m,alias=True)
net.add_channel(type='alias',source=L,target=pattern,mapping=m)
# is an alias for another pattern
if isinstance(parent,Pattern):
net.initialize_pattern(parent)
m = Mapping.create(parent.variables)
G = net.get_node(core=parent).data.symmetry_group
net.add_node(type='pattern',core=pattern,symmetry_group=G,parent=parent,mapping=m,alias=True)
net.add_channel(type='alias',source=parent,target=pattern,mapping=m)
net.add_node(type='pattern',core=pattern,symmetry_group=pdict.symmetry_group,exprgraph = graph,alias=True)
net.add_channel(type='alias',source=pdict.parent,target=pattern,mapping=pdict.mapping)

if len(constraints) > 0:

for var,p in helpers.items():
graph.add(PatternReference(id=var,pattern_id=id(p)))
for c in constraints:
x = pattern.make_executable_constraint(c)
constraint_objects.append(x)
exprgraph = GraphContainer(x.build_exprgraph().get_connected())
for _,node in exprgraph.iter_nodes():
if getattr(node,'variable',None) is not None:
reference = node.variable_reference()
assert reference in graph._dict, f'{reference} not found'
node.attach_source(graph[reference])
constraint_objects.append(x)

graph = graph + exprgraph
m,L,G = canonical_label(graph)
symmetry_group = G.duplicate(m).restrict(pattern.variables)

# collect update channels
constraint_pattern_relationships = set()
for x in constraint_objects:
for fname_tuple in x.deps.function_calls:
pname = fname_tuple[0]
if pname in helpers:
kwpairs = x.deps.function_calls[fname_tuple]['kwpairs']
kwpairs = [(x,y) for x,y in kwpairs if x in helpers[pname].variables and y in pattern.variables]
m = Mapping.create(*unzip(kwpairs)) if kwpairs else None
constraint_pattern_relationships.add((x,pname,m,))
constraint_attr_relationships = set()
for x in constraint_objects:
for var, attrs in x.deps.attribute_calls.items():
for attr in attrs:
if var in pattern.variables:
constraint_attr_relationships.add((x,var,attr,))

print(constraint_pattern_relationships)
print(constraint_attr_relationships)
net.add_node(type='pattern',core=pattern,symmetry_group=symmetry_group,exprgraph=graph)
net.add_channel(type='parent',source=pdict.parent,target=pattern,mapping=pdict.mapping)

# for x,pname,m in constraint_pattern_relationships:
# net.add_channel(type='update',source=pname,target=pattern,mapping=m)
# for x,var,attr in constraint_attr_relationships:
# net.add_channel(type='update',source=<get_var_class>,target=pattern,mapping=Mapping.create(['a'],[var]))

return net


def print_merge_form(names,m1,m2):
s1 = '(' + ','.join(names) + ')'
s2 = '(' + ','.join(m1.sources) + ')' + '->' + '(' + ','.join(m1.targets) + ')'
Expand Down
8 changes: 6 additions & 2 deletions wc_rules/modeling/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,17 @@ def validate_constraints(self,constraints):
kwdeps[v] = x.keywords
executables.append(x)
validate_acyclic(kwdeps)

return newvars

def make_executable_constraints(self):
# mark for downgrade
return [initialize_from_string(s,(Constraint,Computation)) for s in self.constraints]

def compute_symmetry_group(self,source_symmetry_group):
return None
def make_executable_constraint(self,s):
return initialize_from_string(s,(Constraint,Computation))



class SynthesisPattern:

Expand Down
2 changes: 1 addition & 1 deletion wc_rules/simulator/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def push_to_stack(self,action):
# assume list has to be executed left to right
self.action_stack = deque(action) + self.action_stack
else:
self.action_stack.append(action)
self.action_stack.appendleft(action)
return self

def simulate(self):
Expand Down
8 changes: 7 additions & 1 deletion wc_rules/utils/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,4 +322,10 @@ def rotate_until(dq,conditionfn):
dq.rotate(-1)
nrots += 1
assert nrots < len(dq)
return dq
return dq

def quoted(x):
return f'\"{x}\"' if isinstance(x,str) else x

def unzip(zipped):
return list(zip(*zipped))

0 comments on commit 9192bb9

Please sign in to comment.