From 450aa305b7029232fa04d4924be72cbb1178bde7 Mon Sep 17 00:00:00 2001 From: John Sekar Date: Wed, 9 Mar 2022 17:34:45 -0500 Subject: [PATCH] DatabaseAlias with filter method --- tests_rete/test_db.py | 21 ++++++++++++++++----- wc_rules/matcher/dbase.py | 30 +++++++++++++++++++++++++----- wc_rules/utils/collections.py | 5 +---- 3 files changed, 42 insertions(+), 14 deletions(-) diff --git a/tests_rete/test_db.py b/tests_rete/test_db.py index c8847b9..7f0fd14 100644 --- a/tests_rete/test_db.py +++ b/tests_rete/test_db.py @@ -34,7 +34,7 @@ def test_db(self): class TestAlias(unittest.TestCase): def test_mapping(self): - + # xyz->abc * pqr->xyz = pqr->xyz m1 = SimpleMapping(zip('xyz','abc')) self.assertEqual(m1,dict(zip('xyz','abc'))) self.assertEqual(m1.reverse,dict(zip('abc','xyz'))) @@ -42,15 +42,15 @@ def test_mapping(self): m2 = SimpleMapping(zip('pqr','xyz')) # testing multiply self.assertEqual(m1*m2,dict(zip('pqr','abc'))) - v = dict(zip('abc',range(3))) - self.assertEqual(m1.premultiply(v),dict(zip('xyz',range(3)))) - self.assertEqual(SimpleMapping(m1*m2).premultiply(v),dict(zip('pqr',range(3)))) + v = dict(zip('abc',range(3))) + self.assertEqual(SimpleMapping(v)*m1,dict(zip('xyz',range(3)))) + self.assertEqual(SimpleMapping(v)*SimpleMapping(m1*m2),dict(zip('pqr',range(3)))) def test_database_alias(self): db = Database(list('abc')) - db.insert(dict(a=1,b=2,c=3)) + db.insert(dict(zip('abc',range(3)))) alias1 = DatabaseAlias(db, dict(zip('xyz','abc'))) alias2 = DatabaseAlias(alias1,dict(zip('pqr','xyz'))) @@ -63,3 +63,14 @@ def test_database_alias(self): self.assertEqual(alias1.mapping, dict(zip('xyz','abc'))) self.assertEqual(alias2.mapping, dict(zip('pqr','abc'))) self.assertEqual(alias3.mapping, dict(zip('ijk','abc'))) + + # checking filtering + record = db.filter(dict(a=0))[0] + self.assertEqual(record,dict(zip('abc',range(3)))) + aliased1 = alias1.filter(dict(x=0))[0] + aliased2 = alias2.filter(dict(p=0))[0] + aliased3 = alias3.filter(dict(i=0))[0] + self.assertEqual(aliased1,dict(zip('xyz',range(3)))) + self.assertEqual(aliased2,dict(zip('pqr',range(3)))) + self.assertEqual(aliased3,dict(zip('ijk',range(3)))) + diff --git a/wc_rules/matcher/dbase.py b/wc_rules/matcher/dbase.py index e79d8c5..b5c9e0b 100644 --- a/wc_rules/matcher/dbase.py +++ b/wc_rules/matcher/dbase.py @@ -45,17 +45,37 @@ def __len__(self): class DatabaseAlias: - def __init__(self,target,mapping): - # NOTE: mapping has keys=CURRENT variables, values=variables of cache it is aliasing + + ''' + Example: Parent database with fields a b c + Child Alias with fields x y z + Mapping stored in child: x->a, y->b, z->c + Sending data downstream (foward) + {a:1,b:2,c:3}*{x:a,y:b,z:3} = {x:1,y:2,z:3} + Sending data request upstream (reverse) + {x:1,y:2,z:3}*{x:a,y:b,z:3}^-1 = {a:1,b:2,c:3} + ''' + if isinstance(target,DatabaseAlias): target, mapping = target.target, target.mapping*mapping - #assert isinstance(target,Database) and set(mapping.values())==target.fields - self.target = target self.mapping = SimpleMapping(mapping) - + + def forward_transform(self,match): + return SimpleMapping(match)*self.mapping + + def reverse_transform(self,match): + return SimpleMapping(match)*self.mapping.reverse + + def filter(self,include_kwargs={},exclude_kwargs={}): + includes = self.reverse_transform(include_kwargs) + excludes = self.reverse_transform(exclude_kwargs) + records = self.target.filter(includes,excludes) + rotated = [self.forward_transform(x) for x in records] + return rotated + \ No newline at end of file diff --git a/wc_rules/utils/collections.py b/wc_rules/utils/collections.py index 91557dc..b0931d7 100644 --- a/wc_rules/utils/collections.py +++ b/wc_rules/utils/collections.py @@ -18,16 +18,13 @@ def __init__(self,*args,**kwargs): def __mul__(self,other): # convention self * other == f o g # takes g's keys and f's values + # {x:a,y:b,z:c}*{p:x,q:y,r:z} = {p:a,q:b,r:c} return {k:self[v] for k,v in other.items() if v in self} @cached_property def reverse(self): return SimpleMapping({v:k for k,v in self.items()}) - def premultiply(self,other): - return {k:other[v] for k,v in self.items() if v in other} - - @dataclass(order=True,frozen=True)