Skip to content

Commit

Permalink
Pattern uses set() to keep its nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
johnsekar committed Oct 1, 2018
1 parent d7bb700 commit 6ea11e9
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 22 deletions.
13 changes: 8 additions & 5 deletions tests/test_pattern.py
Expand Up @@ -37,7 +37,9 @@ def test_pattern_init(self):

p2 = p1.duplicate(preserve_ids=True)
self.assertTrue(len(p2)==3)
self.assertEqual(sorted(p1._nodes.keys()), sorted(p2._nodes.keys()))
p1_ids = sorted([x.id for x in p1])
p2_ids = sorted([x.id for x in p2])
self.assertEqual(p1_ids,p2_ids)

def test_generate_queries(self):
a = A(id='a')
Expand All @@ -46,19 +48,20 @@ def test_generate_queries(self):
a.add_sites(y1,y2)

p = Pattern('p').add_node(a,recurse=True)
d = p.as_dict()
qdict = p.generate_queries()

self.assertEqual(sorted(list(qdict.keys())), ['attr','rel','type'])

for idx in qdict['type']:
node = p._nodes[idx]
node = d[idx]
tuplist = qdict['type'][idx]
for tup in tuplist:
_class = tup[1]
self.assertTrue(isinstance(node,_class))

for idx in qdict['attr']:
node = p._nodes[idx]
node = d[idx]
tuplist = qdict['attr'][idx]
for tup in tuplist:
attr = tup[1]
Expand All @@ -72,7 +75,7 @@ def test_generate_queries(self):
attr2 = tup[3]
idx2 = tup[4]

node1 = p._nodes[idx1]
node2 = p._nodes[idx2]
node1 = d[idx1]
node2 = d[idx2]
self.assertTrue(node1 in utils.listify(getattr(node2,attr2)))
self.assertTrue(node2 in utils.listify(getattr(node1,attr1)))
4 changes: 2 additions & 2 deletions tests/test_simulation_state.py
Expand Up @@ -25,7 +25,7 @@ def test_simulation_state(self):

ss =SimulationState().load_new_species(list(sf.generate(100,preserve_ids=True)))

c = sum(s._nodes['x'].ph1 and s._nodes['x'].ph2 for x,s in ss._species.items())
c = sum(s.as_dict()['x'].ph1 and s.as_dict()['x'].ph2 for x,s in ss._species.items())
self.assertTrue(c < 50)
c = sum(s._nodes['x'].ph1 or s._nodes['x'].ph2 for x,s in ss._species.items())
c = sum(s.as_dict()['x'].ph1 or s.as_dict()['x'].ph2 for x,s in ss._species.items())
self.assertTrue(c > 50)
12 changes: 6 additions & 6 deletions tests/test_species.py
Expand Up @@ -35,16 +35,16 @@ def test_species_factory(self):
factory1 = SpeciesFactory().add_species(s1)
sp_list =list(factory1.generate(100,preserve_ids=True))

c = Counter(x._nodes['a'].label for x in sp_list)
c = Counter(x.as_dict()['a'].label for x in sp_list)
self.assertEqual(c['A'],100)

c = Counter(x._nodes['x'].label for x in sp_list)
c = Counter(x.as_dict()['x'].label for x in sp_list)
self.assertEqual(c['X'],100)

c = Counter(x._nodes['x'].ph1 for x in sp_list)
c = Counter(x.as_dict()['x'].ph1 for x in sp_list)
self.assertEqual(c[True],100)

c = Counter(x._nodes['x'].ph2 for x in sp_list)
c = Counter(x.as_dict()['x'].ph2 for x in sp_list)
self.assertEqual(c[True],100)
del sp_list

Expand All @@ -53,8 +53,8 @@ def test_species_factory(self):
factory2.add_species(s2,weight=2)
sp_list =list(factory2.generate(100,preserve_ids=True))

c = Counter(x._nodes['x'].ph1 for x in sp_list)
c = Counter(x.as_dict()['x'].ph1 for x in sp_list)
self.assertEqual(c[True],100)

c = Counter(x._nodes['x'].ph2 for x in sp_list)
c = Counter(x.as_dict()['x'].ph2 for x in sp_list)
self.assertTrue(c[True] < 50)
19 changes: 18 additions & 1 deletion wc_rules/base.py
Expand Up @@ -92,16 +92,33 @@ def get_nonempty_related_attributes(self):
if getattr(self,attr,None) is not None: final_attrs.append(attr)
return final_attrs

def duplicate(self,id=None):
def duplicate(self,id=None,preserve_id=False):
''' duplicates node up to scalar attributes '''
new_node = self.__class__()
for attr in self.get_nonempty_scalar_attributes():
if attr=='id': continue
setattr(new_node,attr,getattr(self,attr))
if id:
new_node.set_id(id)
elif preserve_id:
# use cautiously
new_node.set_id(self.id)
return new_node

def generate_attr_contents(self,nonempty_only=True):
''' return a dict of attrname:[list of objects] '''
v = dict()
if nonempty_only:
attrs = self.get_nonempty_related_attributes()
else:
attrs = self.get_related_attributes()
for attr in attrs:
v[attr] = utils.listify(getattr(self,attr))
return v

def generate_appendability_dict(self):
return {x:self.attribute_properties[x]['append'] for x in self.get_related_attributes()}

def set_id(self, id):
""" Sets id attribute.
Expand Down
58 changes: 50 additions & 8 deletions wc_rules/pattern.py
@@ -1,4 +1,4 @@
from wc_rules.indexer import Index_By_ID
from wc_rules.indexer import Index_By_ID, HashableDict
from wc_rules.utils import listify,generate_id
from operator import eq
import random
Expand All @@ -7,18 +7,28 @@
class Pattern(object):
def __init__(self,idx,nodelist=None,recurse=True):
self.id = idx
self._nodes = Index_By_ID()
self._nodes = set()
if nodelist is not None:
for node in nodelist:
self.add_node(node,recurse)

def __contains__(self,node):
return node in self._nodes

def __iter__(self):
return iter(self._nodes)

def as_dict(self):
return { x.id:x for x in self }

def __getitem__(self,key):
d = self.as_dict()
return d[key]

def add_node(self,node,recurse=True):
if node in self:
return self
self._nodes.append(node)
self._nodes.add(node)
if recurse:
for attr in node.get_nonempty_related_attributes():
nodelist = listify(getattr(node,attr))
Expand All @@ -27,12 +37,12 @@ def add_node(self,node,recurse=True):
return self

def __str__(self):
s = pprint.pformat(self) + '\n' + pprint.pformat(self._nodes)
s = pprint.pformat(self) + '\n' + pprint.pformat(sorted(self._nodes,key=lambda x: (x.label,x.id)))
return s

def __len__(self): return len(self._nodes)

def duplicate(self,idx=None,preserve_ids=False):
def duplicate2(self,idx=None,preserve_ids=False):
nodemap = {}
if idx is None:
idx = generate_id()
Expand Down Expand Up @@ -70,10 +80,39 @@ def duplicate(self,idx=None,preserve_ids=False):
new_pattern.add_node(x,recurse=False)
return new_pattern

def duplicate(self,idx=None,preserve_ids=False):
if idx is None:
idx = generate_id()
new_pattern = self.__class__(idx)
nodemap = dict()
for node in self:
# this duplicates upto scalar attributes
new_node = node.duplicate(preserve_id=preserve_ids)
nodemap[node.id] = new_node
new_pattern.add_node(new_node,recurse=False)
encountered = set()
for node in self:
attrcontents = node.generate_attr_contents()
appendable = node.generate_appendability_dict()
for attr in attrcontents:
objs = set(attrcontents[attr]) - encountered
if len(objs) == 0: continue
new_objs = [nodemap[x.id] for x in objs]
new_node = nodemap[node.id]
if appendable[attr]:
new_attr = getattr(new_node,attr)
new_attr.extend(new_objs)
else:
setattr(new_node,attr,new_objs.pop())
encountered.add(node)
return new_pattern


def generate_queries_TYPE(self):
''' Generates tuples ('type',_class) '''
type_queries = {}
for idx,node in self._nodes.items():
for node in self:
idx = node.id
type_queries[idx] = []
list_of_classes = node.__class__.__mro__
for _class in reversed(list_of_classes):
Expand All @@ -86,7 +125,8 @@ def generate_queries_TYPE(self):
def generate_queries_ATTR(self):
''' Generates tuples ('attr',attrname,operator,value) '''
attr_queries = {}
for idx,node in self._nodes.items():
for node in self:
idx = node.id
attr_queries[idx] = []
for attr in sorted(node.get_nonempty_scalar_attributes()):
if attr=='id': continue
Expand All @@ -98,7 +138,8 @@ def generate_queries_REL(self):
''' Generate tuples ('rel',idx1,attrname,related_attrname,idx2) '''
rel_queries = []
already_encountered = []
for idx,node in self._nodes.items():
for node in self:
idx = node.id
for attr in node.get_nonempty_related_attributes():
nodelist = listify(getattr(node,attr))
for node2 in nodelist:
Expand All @@ -124,5 +165,6 @@ def generate_queries(self):
def main():
pass


if __name__ == '__main__':
main()

0 comments on commit 6ea11e9

Please sign in to comment.