Skip to content

Commit

Permalink
Enabling extension to symmetry_aware
Browse files Browse the repository at this point in the history
  • Loading branch information
johnsekar committed Mar 16, 2022
1 parent ece97e0 commit 4a29848
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 5 deletions.
7 changes: 6 additions & 1 deletion wc_rules/matcher/add_methods.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from ..utils.collections import UniversalSet, SimpleMapping
from ..schema.base import BaseClass
from .dbase import Database,DatabaseAlias
from .dbase import Database, DatabaseAlias, DatabaseSymmetric, DatabaseAliasSymmetric
from .token import TokenTransformer

from collections import deque
Expand Down Expand Up @@ -71,3 +71,8 @@ def add_channel_transform(self,source,target,datamap,actionmap):
transformer = TokenTransformer(datamap,actionmap)
)
return self

class AddMethodsSymmetric(AddMethods):

DATABASE_CLASS = DatabaseSymmetric
DATABASE_ALIAS_CLASS = DatabaseAliasSymmetric
9 changes: 6 additions & 3 deletions wc_rules/matcher/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from attrdict import AttrDict

from .dbase import Database
from .add_methods import AddMethods
from .add_methods import AddMethods, AddMethodsSymmetric
from .initialize_methods import InitializationMethods
from .state import ReteNodeState
from .node_functions import NodeFunctions
Expand Down Expand Up @@ -76,6 +76,9 @@ def sync(self,node):
return self


def build_rete_net_class(bases=bases,name='ReteNet'):
ReteNet = type(name,(ReteNetBase,) + tuple(bases),{})
def build_rete_net_class(bases=bases,name='ReteNet',symmetry_aware=False):
all_bases = [ReteNetBase,] + bases
if symmetry_aware:
all_bases[all_bases.index(AddMethods)]= AddMethodsSymmetric
ReteNet = type(name,tuple(all_bases),{})
return ReteNet
6 changes: 5 additions & 1 deletion wc_rules/matcher/dbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,8 @@ def filter(self,include_kwargs={},exclude_kwargs={}):
return rotated



class DatabaseSymmetric(Database):
pass

class DatabaseAliasSymmetric(DatabaseAlias):
pass

0 comments on commit 4a29848

Please sign in to comment.