Skip to content

Commit

Permalink
DatabaseSymmetric insert
Browse files Browse the repository at this point in the history
Checks whether symmetry breaking is satisfied before inserting
  • Loading branch information
johnsekar committed Mar 16, 2022
1 parent 4a29848 commit 51206e3
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 8 deletions.
25 changes: 24 additions & 1 deletion tests_rete/test_db.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from wc_rules.matcher.dbase import Database, DatabaseAlias
from wc_rules.matcher.dbase import Database, DatabaseAlias, DatabaseSymmetric
from wc_rules.utils.collections import SimpleMapping
from wc_rules.graph.permutations import PermutationGroup, Permutation
from wc_rules.graph.examples import X,Y

import unittest

Expand Down Expand Up @@ -74,3 +76,24 @@ def test_database_alias(self):
self.assertEqual(aliased2,dict(zip('pqr',range(3))))
self.assertEqual(aliased3,dict(zip('ijk',range(3))))

class TestDatabaseSymmetric(unittest.TestCase):

def test_symmetric_insertion(self):
p1 = Permutation.create(list('abc'),list('abc'))
p2 = Permutation.create(list('abc'),list('acb'))
G1 = PermutationGroup.create([p1,p2])

db = DatabaseSymmetric(fields=list('abc'),symmetry_group=G1)
self.assertEqual(db.symmetry_group,G1)
self.assertEqual(len(db),0)

x1,y1,y2 = X('x1'), Y('y1'), Y('y2')

match = {'a':x1,'b':y1,'c':y2}
db.insert(match)
self.assertEqual(len(db),1)

# changing order of 'b' and 'c' should prevent insertion
match = {'a':x1,'b':y2,'c':y1}
db.insert(match)
self.assertEqual(len(db),1)
2 changes: 1 addition & 1 deletion wc_rules/graph/graph_partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def partition_canonical_form(labeling,group):
if len(labeling.edges) <= 1:
return None,None

lg_nodes, lg_edges,lg_orbits = line_graph(labeling.names,labeling.edges,group.orbits())
lg_nodes, lg_edges,lg_orbits = line_graph(labeling.names,labeling.edges,group.orbits)
partition = kernighan_lin(lg_nodes.values(),lg_edges,lg_orbits)
g1, g2 = [deinduce(labeling,lg_nodes,x) for x in partition]
CL1, CL2 = [canonical_label(x) for x in [g1,g2]]
Expand Down
15 changes: 11 additions & 4 deletions wc_rules/graph/permutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
from copy import deepcopy
from collections import Counter
import math
from backports.cached_property import cached_property

def print_cycles(cycles,lb=r'(',rb=r')',sep=','):
return ''.join([f"{lb}{sep.join(x)}{rb}" for x in cycles])

@dataclass(order=True,frozen=True)
class Permutation(Mapping):

### INSIGHT
Expand Down Expand Up @@ -72,16 +74,15 @@ def validate(self):
x.validate()
assert self.generators[0].is_identity(), f"Atleast one generator must be an identity permutation."

def orbits(self,simple=False):
@cached_property
def orbits(self):
orbindex = index_dict(self.generators[0].sources)
for g in self.generators:
for cyc in g.cyclic_form():
if len(cyc) > 1:
nums = [orbindex[x] for x in cyc]
orbindex = remap_values(orbindex,nums,min(nums))
orbits = invert_dict(orbindex).values()
if simple:
return print_cycles(orbits,r'{',r'}')
orbits = list(invert_dict(orbindex).values())
return tuplify(orbits)

def iter_subgroups(self):
Expand Down Expand Up @@ -135,5 +136,11 @@ def count_symmetries(self):
def restrict(self,variables):
return self.__class__.create([x.restrict(variables) for x in self.generators])

def verify_symmetry_breaking(self,match):
if self.is_trivial():
return True
elems = [[match[x].id for x in orb] for orb in self.orbits if len(orb)>1]
return all([elem==sorted(elem) for elem in elems])



17 changes: 15 additions & 2 deletions wc_rules/matcher/dbase.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pydblite import Base
from ..utils.collections import SimpleMapping
from ..utils.collections import subdict

def dict_overlap(d1,d2):
return len(set(d1.items()) & set(d2.items())) > 0
Expand Down Expand Up @@ -90,7 +91,19 @@ def filter(self,include_kwargs={},exclude_kwargs={}):


class DatabaseSymmetric(Database):
pass

def __init__(self,**kwargs):
super().__init__(**kwargs)
self.symmetry_group = kwargs.pop('symmetry_group',None)

def insert(self,record):
if self.symmetry_group.verify_symmetry_breaking(record):
super().insert(record)
return self

class DatabaseAliasSymmetric(DatabaseAlias):
pass

def __init__(self,**kwargs):
super().__init__(**kwargs)
self.symmetry_group = kwargs.pop('symmetry_group',None)

1 change: 1 addition & 0 deletions wc_rules/utils/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,7 @@ def subdict(d,keys,ignore=False):
keys = [x for x in keys if x in d]
return {k:d[k] for k in keys}


def triple_split(iter1,iter2):
# L1, L2 are iters
# returns (iter1-iter2), (iter1 & iter2), (iter2-iter1)
Expand Down

0 comments on commit 51206e3

Please sign in to comment.