From 8d63388d7bcd39926feaa02a84ea684ac3b0f3f3 Mon Sep 17 00:00:00 2001 From: John Sekar Date: Wed, 9 Mar 2022 02:16:51 -0500 Subject: [PATCH] DatabaseAlias Have to code and test filtering --- tests_rete/test_db.py | 65 +++++++++++++++++++++++++++++++++++ wc_rules/matcher/dbase.py | 20 ++++++++++- wc_rules/utils/collections.py | 32 ++++++++++++++++- 3 files changed, 115 insertions(+), 2 deletions(-) create mode 100644 tests_rete/test_db.py diff --git a/tests_rete/test_db.py b/tests_rete/test_db.py new file mode 100644 index 0000000..c8847b9 --- /dev/null +++ b/tests_rete/test_db.py @@ -0,0 +1,65 @@ +from wc_rules.matcher.dbase import Database, DatabaseAlias +from wc_rules.utils.collections import SimpleMapping + +import unittest + +class TestDatabase(unittest.TestCase): + + def test_db(self): + + db = Database(list('abc')) + self.assertEqual(db.fields,list('abc')) + self.assertEqual(len(db.filter(dict(a=1))),0) + + record = dict(a=1,b=2,c=3) + db.insert(record) + self.assertEqual(len(db),1) + self.assertEqual(len(db.filter(dict(a=1))),1) + self.assertEqual(db.filter_one(dict(a=1)),dict(a=1,b=2,c=3)) + + record = dict(a=1,b=3,c=4) + db.insert(record) + self.assertEqual(len(db),2) + self.assertEqual(len(db.filter(dict(a=1))),2) + self.assertEqual(len(db.filter(dict(b=2))),1) + + records = db.delete(dict(c=3)) + self.assertEqual(len(records),1) + self.assertEqual(len(db),1) + + records = db.delete(dict(c=4)) + self.assertEqual(len(records),1) + self.assertEqual(len(db),0) + +class TestAlias(unittest.TestCase): + + def test_mapping(self): + + m1 = SimpleMapping(zip('xyz','abc')) + self.assertEqual(m1,dict(zip('xyz','abc'))) + self.assertEqual(m1.reverse,dict(zip('abc','xyz'))) + + m2 = SimpleMapping(zip('pqr','xyz')) + # testing multiply + self.assertEqual(m1*m2,dict(zip('pqr','abc'))) + v = dict(zip('abc',range(3))) + + self.assertEqual(m1.premultiply(v),dict(zip('xyz',range(3)))) + self.assertEqual(SimpleMapping(m1*m2).premultiply(v),dict(zip('pqr',range(3)))) + + + def test_database_alias(self): + db = Database(list('abc')) + db.insert(dict(a=1,b=2,c=3)) + + alias1 = DatabaseAlias(db, dict(zip('xyz','abc'))) + alias2 = DatabaseAlias(alias1,dict(zip('pqr','xyz'))) + alias3 = DatabaseAlias(alias2,dict(zip('ijk','pqr'))) + + self.assertEqual(alias1.target, db) + self.assertEqual(alias2.target, db) + self.assertEqual(alias3.target, db) + + self.assertEqual(alias1.mapping, dict(zip('xyz','abc'))) + self.assertEqual(alias2.mapping, dict(zip('pqr','abc'))) + self.assertEqual(alias3.mapping, dict(zip('ijk','abc'))) diff --git a/wc_rules/matcher/dbase.py b/wc_rules/matcher/dbase.py index d3e0f98..e79d8c5 100644 --- a/wc_rules/matcher/dbase.py +++ b/wc_rules/matcher/dbase.py @@ -1,4 +1,5 @@ from pydblite import Base +from ..utils.collections import SimpleMapping def dict_overlap(d1,d2): return len(set(d1.items()) & set(d2.items())) > 0 @@ -40,4 +41,21 @@ def filter_one(self,include_kwargs): return None def __len__(self): - return len(self._db) \ No newline at end of file + return len(self._db) + +class DatabaseAlias: + + + def __init__(self,target,mapping): + + # NOTE: mapping has keys=CURRENT variables, values=variables of cache it is aliasing + if isinstance(target,DatabaseAlias): + target, mapping = target.target, target.mapping*mapping + + #assert isinstance(target,Database) and set(mapping.values())==target.fields + + self.target = target + self.mapping = SimpleMapping(mapping) + + + \ No newline at end of file diff --git a/wc_rules/utils/collections.py b/wc_rules/utils/collections.py index f58bce5..91557dc 100644 --- a/wc_rules/utils/collections.py +++ b/wc_rules/utils/collections.py @@ -8,7 +8,27 @@ from dataclasses import dataclass, field from typing import Tuple, Dict from backports.cached_property import cached_property -from collections import defaultdict +from collections import defaultdict, UserDict + +class SimpleMapping(UserDict): + + def __init__(self,*args,**kwargs): + super().__init__(*args,**kwargs) + + def __mul__(self,other): + # convention self * other == f o g + # takes g's keys and f's values + return {k:self[v] for k,v in other.items() if v in self} + + @cached_property + def reverse(self): + return SimpleMapping({v:k for k,v in self.items()}) + + def premultiply(self,other): + return {k:other[v] for k,v in self.items() if v in other} + + + @dataclass(order=True,frozen=True) class Mapping: @@ -233,6 +253,16 @@ def remap_values(d,oldvalues,newvalue): return d ###### Methods ###### +def compose_mapping(m1,m2): + # mappings are simple dict + # downgrade Mapping above???? + # m1: ad,be,cf + # m2: dg,eh,fi + # produces: ag,bh,ci + # equivalent to m2 o m1 + return {k:m2[v] for k in m1 if v in m2} + + def get_values(d,keys): return [d[k] for k in keys]