From 51206e38fc80d33b04132df6789e878d0dfb6598 Mon Sep 17 00:00:00 2001 From: John Sekar Date: Wed, 16 Mar 2022 05:04:10 -0400 Subject: [PATCH] DatabaseSymmetric insert Checks whether symmetry breaking is satisfied before inserting --- tests_rete/test_db.py | 25 ++++++++++++++++++++++++- wc_rules/graph/graph_partitioning.py | 2 +- wc_rules/graph/permutations.py | 15 +++++++++++---- wc_rules/matcher/dbase.py | 17 +++++++++++++++-- wc_rules/utils/collections.py | 1 + 5 files changed, 52 insertions(+), 8 deletions(-) diff --git a/tests_rete/test_db.py b/tests_rete/test_db.py index 7f0fd14..bf0135c 100644 --- a/tests_rete/test_db.py +++ b/tests_rete/test_db.py @@ -1,5 +1,7 @@ -from wc_rules.matcher.dbase import Database, DatabaseAlias +from wc_rules.matcher.dbase import Database, DatabaseAlias, DatabaseSymmetric from wc_rules.utils.collections import SimpleMapping +from wc_rules.graph.permutations import PermutationGroup, Permutation +from wc_rules.graph.examples import X,Y import unittest @@ -74,3 +76,24 @@ def test_database_alias(self): self.assertEqual(aliased2,dict(zip('pqr',range(3)))) self.assertEqual(aliased3,dict(zip('ijk',range(3)))) +class TestDatabaseSymmetric(unittest.TestCase): + + def test_symmetric_insertion(self): + p1 = Permutation.create(list('abc'),list('abc')) + p2 = Permutation.create(list('abc'),list('acb')) + G1 = PermutationGroup.create([p1,p2]) + + db = DatabaseSymmetric(fields=list('abc'),symmetry_group=G1) + self.assertEqual(db.symmetry_group,G1) + self.assertEqual(len(db),0) + + x1,y1,y2 = X('x1'), Y('y1'), Y('y2') + + match = {'a':x1,'b':y1,'c':y2} + db.insert(match) + self.assertEqual(len(db),1) + + # changing order of 'b' and 'c' should prevent insertion + match = {'a':x1,'b':y2,'c':y1} + db.insert(match) + self.assertEqual(len(db),1) \ No newline at end of file diff --git a/wc_rules/graph/graph_partitioning.py b/wc_rules/graph/graph_partitioning.py index 38c9173..67bb227 100644 --- a/wc_rules/graph/graph_partitioning.py +++ b/wc_rules/graph/graph_partitioning.py @@ -15,7 +15,7 @@ def partition_canonical_form(labeling,group): if len(labeling.edges) <= 1: return None,None - lg_nodes, lg_edges,lg_orbits = line_graph(labeling.names,labeling.edges,group.orbits()) + lg_nodes, lg_edges,lg_orbits = line_graph(labeling.names,labeling.edges,group.orbits) partition = kernighan_lin(lg_nodes.values(),lg_edges,lg_orbits) g1, g2 = [deinduce(labeling,lg_nodes,x) for x in partition] CL1, CL2 = [canonical_label(x) for x in [g1,g2]] diff --git a/wc_rules/graph/permutations.py b/wc_rules/graph/permutations.py index 3733f65..fdfd516 100644 --- a/wc_rules/graph/permutations.py +++ b/wc_rules/graph/permutations.py @@ -6,10 +6,12 @@ from copy import deepcopy from collections import Counter import math +from backports.cached_property import cached_property def print_cycles(cycles,lb=r'(',rb=r')',sep=','): return ''.join([f"{lb}{sep.join(x)}{rb}" for x in cycles]) +@dataclass(order=True,frozen=True) class Permutation(Mapping): ### INSIGHT @@ -72,16 +74,15 @@ def validate(self): x.validate() assert self.generators[0].is_identity(), f"Atleast one generator must be an identity permutation." - def orbits(self,simple=False): + @cached_property + def orbits(self): orbindex = index_dict(self.generators[0].sources) for g in self.generators: for cyc in g.cyclic_form(): if len(cyc) > 1: nums = [orbindex[x] for x in cyc] orbindex = remap_values(orbindex,nums,min(nums)) - orbits = invert_dict(orbindex).values() - if simple: - return print_cycles(orbits,r'{',r'}') + orbits = list(invert_dict(orbindex).values()) return tuplify(orbits) def iter_subgroups(self): @@ -135,5 +136,11 @@ def count_symmetries(self): def restrict(self,variables): return self.__class__.create([x.restrict(variables) for x in self.generators]) + def verify_symmetry_breaking(self,match): + if self.is_trivial(): + return True + elems = [[match[x].id for x in orb] for orb in self.orbits if len(orb)>1] + return all([elem==sorted(elem) for elem in elems]) + diff --git a/wc_rules/matcher/dbase.py b/wc_rules/matcher/dbase.py index 99e89a5..eee091b 100644 --- a/wc_rules/matcher/dbase.py +++ b/wc_rules/matcher/dbase.py @@ -1,5 +1,6 @@ from pydblite import Base from ..utils.collections import SimpleMapping +from ..utils.collections import subdict def dict_overlap(d1,d2): return len(set(d1.items()) & set(d2.items())) > 0 @@ -90,7 +91,19 @@ def filter(self,include_kwargs={},exclude_kwargs={}): class DatabaseSymmetric(Database): - pass + + def __init__(self,**kwargs): + super().__init__(**kwargs) + self.symmetry_group = kwargs.pop('symmetry_group',None) + + def insert(self,record): + if self.symmetry_group.verify_symmetry_breaking(record): + super().insert(record) + return self class DatabaseAliasSymmetric(DatabaseAlias): - pass + + def __init__(self,**kwargs): + super().__init__(**kwargs) + self.symmetry_group = kwargs.pop('symmetry_group',None) + diff --git a/wc_rules/utils/collections.py b/wc_rules/utils/collections.py index c7c4c98..5c4acb8 100644 --- a/wc_rules/utils/collections.py +++ b/wc_rules/utils/collections.py @@ -290,6 +290,7 @@ def subdict(d,keys,ignore=False): keys = [x for x in keys if x in d] return {k:d[k] for k in keys} + def triple_split(iter1,iter2): # L1, L2 are iters # returns (iter1-iter2), (iter1 & iter2), (iter2-iter1)