Skip to content

Commit

Permalink
DatabaseAlias
Browse files Browse the repository at this point in the history
Have to code and test filtering
  • Loading branch information
johnsekar committed Mar 9, 2022
1 parent fb90820 commit 8d63388
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 2 deletions.
65 changes: 65 additions & 0 deletions tests_rete/test_db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from wc_rules.matcher.dbase import Database, DatabaseAlias
from wc_rules.utils.collections import SimpleMapping

import unittest

class TestDatabase(unittest.TestCase):

def test_db(self):

db = Database(list('abc'))
self.assertEqual(db.fields,list('abc'))
self.assertEqual(len(db.filter(dict(a=1))),0)

record = dict(a=1,b=2,c=3)
db.insert(record)
self.assertEqual(len(db),1)
self.assertEqual(len(db.filter(dict(a=1))),1)
self.assertEqual(db.filter_one(dict(a=1)),dict(a=1,b=2,c=3))

record = dict(a=1,b=3,c=4)
db.insert(record)
self.assertEqual(len(db),2)
self.assertEqual(len(db.filter(dict(a=1))),2)
self.assertEqual(len(db.filter(dict(b=2))),1)

records = db.delete(dict(c=3))
self.assertEqual(len(records),1)
self.assertEqual(len(db),1)

records = db.delete(dict(c=4))
self.assertEqual(len(records),1)
self.assertEqual(len(db),0)

class TestAlias(unittest.TestCase):

def test_mapping(self):

m1 = SimpleMapping(zip('xyz','abc'))
self.assertEqual(m1,dict(zip('xyz','abc')))
self.assertEqual(m1.reverse,dict(zip('abc','xyz')))

m2 = SimpleMapping(zip('pqr','xyz'))
# testing multiply
self.assertEqual(m1*m2,dict(zip('pqr','abc')))
v = dict(zip('abc',range(3)))

self.assertEqual(m1.premultiply(v),dict(zip('xyz',range(3))))
self.assertEqual(SimpleMapping(m1*m2).premultiply(v),dict(zip('pqr',range(3))))


def test_database_alias(self):
db = Database(list('abc'))
db.insert(dict(a=1,b=2,c=3))

alias1 = DatabaseAlias(db, dict(zip('xyz','abc')))
alias2 = DatabaseAlias(alias1,dict(zip('pqr','xyz')))
alias3 = DatabaseAlias(alias2,dict(zip('ijk','pqr')))

self.assertEqual(alias1.target, db)
self.assertEqual(alias2.target, db)
self.assertEqual(alias3.target, db)

self.assertEqual(alias1.mapping, dict(zip('xyz','abc')))
self.assertEqual(alias2.mapping, dict(zip('pqr','abc')))
self.assertEqual(alias3.mapping, dict(zip('ijk','abc')))
20 changes: 19 additions & 1 deletion wc_rules/matcher/dbase.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pydblite import Base
from ..utils.collections import SimpleMapping

def dict_overlap(d1,d2):
return len(set(d1.items()) & set(d2.items())) > 0
Expand Down Expand Up @@ -40,4 +41,21 @@ def filter_one(self,include_kwargs):
return None

def __len__(self):
return len(self._db)
return len(self._db)

class DatabaseAlias:


def __init__(self,target,mapping):

# NOTE: mapping has keys=CURRENT variables, values=variables of cache it is aliasing
if isinstance(target,DatabaseAlias):
target, mapping = target.target, target.mapping*mapping

#assert isinstance(target,Database) and set(mapping.values())==target.fields

self.target = target
self.mapping = SimpleMapping(mapping)



32 changes: 31 additions & 1 deletion wc_rules/utils/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,27 @@
from dataclasses import dataclass, field
from typing import Tuple, Dict
from backports.cached_property import cached_property
from collections import defaultdict
from collections import defaultdict, UserDict

class SimpleMapping(UserDict):

def __init__(self,*args,**kwargs):
super().__init__(*args,**kwargs)

def __mul__(self,other):
# convention self * other == f o g
# takes g's keys and f's values
return {k:self[v] for k,v in other.items() if v in self}

@cached_property
def reverse(self):
return SimpleMapping({v:k for k,v in self.items()})

def premultiply(self,other):
return {k:other[v] for k,v in self.items() if v in other}




@dataclass(order=True,frozen=True)
class Mapping:
Expand Down Expand Up @@ -233,6 +253,16 @@ def remap_values(d,oldvalues,newvalue):
return d

###### Methods ######
def compose_mapping(m1,m2):
# mappings are simple dict
# downgrade Mapping above????
# m1: ad,be,cf
# m2: dg,eh,fi
# produces: ag,bh,ci
# equivalent to m2 o m1
return {k:m2[v] for k in m1 if v in m2}


def get_values(d,keys):
return [d[k] for k in keys]

Expand Down

0 comments on commit 8d63388

Please sign in to comment.