diff --git a/matchpy/expressions/expressions.py b/matchpy/expressions/expressions.py index 63b2ce2..d91d722 100644 --- a/matchpy/expressions/expressions.py +++ b/matchpy/expressions/expressions.py @@ -600,6 +600,7 @@ def __copy__(self) -> 'Operation': Operation.register(tuple) Operation.register(set) Operation.register(frozenset) +Operation.register(dict) class AssociativeOperation(metaclass=ABCMeta): @@ -622,6 +623,7 @@ def __subclasshook__(cls, C): CommutativeOperation.register(set) CommutativeOperation.register(frozenset) +CommutativeOperation.register(dict) class Atom(Expression): # pylint: disable=abstract-method diff --git a/matchpy/expressions/functions.py b/matchpy/expressions/functions.py index 25a6392..2749dae 100644 --- a/matchpy/expressions/functions.py +++ b/matchpy/expressions/functions.py @@ -14,7 +14,7 @@ def is_constant(expression): """Check if the given expression is constant, i.e. it does not contain Wildcards.""" if isinstance(expression, Operation): - return all(is_constant(o) for o in expression) + return all(is_constant(o) for o in op_iter(expression)) return not isinstance(expression, Wildcard) @@ -26,7 +26,7 @@ def is_syntactic(expression): if isinstance(expression, (AssociativeOperation, CommutativeOperation)): return False if isinstance(expression, Operation): - return all(is_syntactic(o) for o in expression) + return all(is_syntactic(o) for o in op_iter(expression)) if isinstance(expression, Wildcard): return expression.fixed_size return True @@ -57,7 +57,7 @@ def preorder_iter(expression): """Iterate over the expression in preorder.""" yield expression if isinstance(expression, Operation): - for operand in expression: + for operand in op_iter(expression): yield from preorder_iter(operand) @@ -68,7 +68,7 @@ def preorder_iter_with_position(expression): """ yield expression, () if isinstance(expression, Operation): - for i, operand in enumerate(expression): + for i, operand in enumerate(op_iter(expression)): for child, pos in preorder_iter_with_position(operand): yield child, (i, ) + pos @@ -78,7 +78,7 @@ def is_anonymous(expression): if hasattr(expression, 'variable_name') and expression.variable_name: return False if isinstance(expression, Operation): - return all(is_anonymous(o) for o in expression) + return all(is_anonymous(o) for o in op_iter(expression)) return True @@ -87,7 +87,7 @@ def contains_variables_from_set(expression, variables): if hasattr(expression, 'variable_name') and expression.variable_name in variables: return True if isinstance(expression, Operation): - return any(contains_variables_from_set(o, variables) for o in expression) + return any(contains_variables_from_set(o, variables) for o in op_iter(expression)) return False @@ -108,9 +108,9 @@ def rename_variables(expression: Expression, renaming: Dict[str, str]) -> Expres if hasattr(expression, 'variable_name'): variable_name = renaming.get(expression.variable_name, expression.variable_name) return create_operation_expression( - expression, [rename_variables(o, renaming) for o in expression], variable_name=variable_name + expression, [rename_variables(o, renaming) for o in op_iter(expression)], variable_name=variable_name ) - operands = [rename_variables(o, renaming) for o in expression] + operands = [rename_variables(o, renaming) for o in op_iter(expression)] return create_operation_expression(expression, operands) elif isinstance(expression, Expression): expression = expression.__copy__() @@ -119,6 +119,8 @@ def rename_variables(expression: Expression, renaming: Dict[str, str]) -> Expres def simple_operation_factory(op, args, variable_name): + if variable_name not in (True, False, None): + raise NotImplementedError('Expressions of type {} cannot have a variable name.'.format(type(op))) return type(op)(args) @@ -127,13 +129,20 @@ def simple_operation_factory(op, args, variable_name): tuple: simple_operation_factory, set: simple_operation_factory, frozenset: simple_operation_factory, - # TODO: Add support for dicts + dict: simple_operation_factory, +} + +_operation_iterators = { + dict: lambda d: d.items(), } def register_operation_factory(operation, factory): _operation_factories[operation] = factory +def register_operation_iterator(operation, iterator): + _operation_iterators[operation] = iterator + def create_operation_expression(old_operation, new_operands, variable_name=True): operation = type(old_operation) @@ -145,3 +154,12 @@ def create_operation_expression(old_operation, new_operands, variable_name=True) if variable_name is False: return operation(*new_operands) return operation(*new_operands, variable_name=variable_name) + + +def op_iter(operation): + op_type = type(operation) + for parent in op_type.__mro__: + if parent in _operation_iterators: + iterator = _operation_iterators[parent] + return iterator(operation) + return iter(operation) diff --git a/matchpy/functions.py b/matchpy/functions.py index 3572383..728079a 100644 --- a/matchpy/functions.py +++ b/matchpy/functions.py @@ -19,7 +19,7 @@ Expression, Operation, Pattern, Wildcard, SymbolWildcard, AssociativeOperation, CommutativeOperation ) from .expressions.substitution import Substitution -from .expressions.functions import preorder_iter_with_position, create_operation_expression +from .expressions.functions import preorder_iter_with_position, create_operation_expression, op_iter from .matching.one_to_one import match __all__ = ['substitute', 'replace', 'replace_all', 'replace_many', 'is_match', 'ReplacementRule'] @@ -77,7 +77,7 @@ def _substitute(expression: Expression, substitution: Substitution) -> Tuple[Rep elif isinstance(expression, Operation): any_replaced = False new_operands = [] - for operand in expression: + for operand in op_iter(expression): result, replaced = _substitute(operand, substitution) if replaced: any_replaced = True @@ -126,7 +126,7 @@ def replace(expression: Expression, position: Sequence[int], replacement: Replac if position[0] >= len(expression): raise IndexError("Position {!r} out of range for expression {!s}".format(position, expression)) pos = position[0] - operands = list(expression) + operands = list(op_iter(expression)) subexpr = replace(operands[pos], position[1:], replacement) if isinstance(subexpr, Sequence): new_operands = tuple(operands[:pos]) + tuple(subexpr) + tuple(operands[pos + 1:]) @@ -189,7 +189,7 @@ def replace_many(expression: Expression, replacements: Sequence[Tuple[Sequence[i return replace(expression, replacements[0][0], replacements[0][1]) if not isinstance(expression, Operation): raise IndexError("Invalid replacements {!r} for expression {!s}".format(replacements, expression)) - operands = list(expression) + operands = list(op_iter(expression)) new_operands = [] last_index = 0 for index, group in itertools.groupby(replacements, lambda r: r[0][0]): diff --git a/matchpy/matching/_common.py b/matchpy/matching/_common.py index ff7861b..b45fbf5 100644 --- a/matchpy/matching/_common.py +++ b/matchpy/matching/_common.py @@ -160,4 +160,4 @@ def __str__(self): for name, count in self.fixed_variables.items(): parts.extend([name] * count) - return '{}({})'.format(self.operation.name, ', '.join(parts)) + return '{}({})'.format(getattr(self.operation, 'name', self.operation.__name__), ', '.join(parts)) diff --git a/matchpy/matching/many_to_one.py b/matchpy/matching/many_to_one.py index 036d53f..4e1451a 100644 --- a/matchpy/matching/many_to_one.py +++ b/matchpy/matching/many_to_one.py @@ -50,7 +50,7 @@ from ..expressions.substitution import Substitution from ..expressions.functions import ( is_anonymous, contains_variables_from_set, create_operation_expression, preorder_iter_with_position, - rename_variables + rename_variables, op_iter ) from ..utils import (VariableWithCount, commutative_sequence_variable_partition_iter) from .. import functions @@ -251,7 +251,7 @@ def _match_commutative_operation(self, state: _State) -> Iterator[_State]: matcher = state.matcher substitution = self.substitution matcher.add_subject(None) - for operand in subject: + for operand in op_iter(subject): matcher.add_subject(operand) for matched_pattern, new_substitution in matcher.match(subject, substitution): restore_constraints = set() @@ -278,7 +278,7 @@ def _match_commutative_operation(self, state: _State) -> Iterator[_State]: def _match_regular_operation(self, transition: _Transition) -> Iterator[_State]: subject = self.subjects.popleft() after_subjects = self.subjects - operand_subjects = self.subjects = deque(subject) + operand_subjects = self.subjects = deque(op_iter(subject)) new_associative = transition.label if issubclass(transition.label, AssociativeOperation) else None self.associative.append(new_associative) for new_state in self._check_transition(transition, subject, False): @@ -365,7 +365,7 @@ def _internal_add(self, pattern: Pattern, label, renaming) -> int: if patterns_stack[-1]: subpattern = patterns_stack[-1].popleft() if isinstance(subpattern, Operation) and not isinstance(subpattern, CommutativeOperation): - patterns_stack.append(deque(subpattern)) + patterns_stack.append(deque(op_iter(subpattern))) variable_name = getattr(subpattern, 'variable_name', None) state = self._create_expression_transition(state, subpattern, variable_name, pattern_index) if isinstance(subpattern, CommutativeOperation): @@ -813,7 +813,7 @@ def match(self, subjects: Sequence[Expression], substitution: Substitution) -> I subject_ids.add(subject_id) for _ in range(self.max_optional_count): pattern_ids.update(subject_pattern_ids) - for subject in subjects: + for subject in op_iter(subjects): subject_id, subject_pattern_ids = self.subjects[subject] subject_ids.add(subject_id) pattern_ids.update(subject_pattern_ids) @@ -845,7 +845,7 @@ def _extract_sequence_wildcards(self, operands: Iterable[Expression], pattern_set = Multiset() pattern_vars = dict() opt_count = 0 - for operand in operands: + for operand in op_iter(operands): if isinstance(operand, Wildcard) and operand.optional is not None: opt_count += 1 if not self._is_sequence_wildcard(operand): diff --git a/matchpy/matching/one_to_one.py b/matchpy/matching/one_to_one.py index 34eb40e..e70a65c 100644 --- a/matchpy/matching/one_to_one.py +++ b/matchpy/matching/one_to_one.py @@ -8,7 +8,7 @@ ) from ..expressions.constraints import Constraint from ..expressions.substitution import Substitution -from ..expressions.functions import is_constant, preorder_iter_with_position, match_head, create_operation_expression +from ..expressions.functions import is_constant, preorder_iter_with_position, match_head, create_operation_expression, op_iter from ..utils import ( VariableWithCount, commutative_sequence_variable_partition_iter, fixed_integer_vector_iter, weak_composition_iter, generator_chain, optional_iter @@ -155,7 +155,7 @@ def _count_seq_vars(expressions, operation): remaining = len(expressions) sequence_var_count = 0 optional_count = 0 - for operand in operation: + for operand in op_iter(operation): if isinstance(operand, Wildcard): if not operand.fixed_size or isinstance(operation, AssociativeOperation): sequence_var_count += 1 @@ -183,7 +183,7 @@ def _build_full_partition(optional_parts, sequence_var_partition: Sequence[int], var_index = 0 opt_index = 0 result = [] - for operand in operation: + for operand in op_iter(operation): wrap_associative = False if isinstance(operand, Wildcard): count = operand.min_count if operand.optional is None else 0 @@ -221,7 +221,7 @@ def _non_commutative_match(subjects, operation, subst, constraints, matcher): continue for part in weak_composition_iter(new_remaining, sequence_var_count): partition = _build_full_partition(optional, part, subjects, operation) - factories = [_match_factory(e, o, constraints, matcher) for e, o in zip(partition, operation)] + factories = [_match_factory(e, o, constraints, matcher) for e, o in zip(partition, op_iter(operation))] for new_subst in generator_chain(subst, *factories): yield new_subst @@ -235,7 +235,9 @@ def _match_operation(expressions, operation, subst, matcher, constraints): if not isinstance(operation, CommutativeOperation): yield from _non_commutative_match(expressions, operation, subst, constraints, matcher) else: - parts = CommutativePatternsParts(type(operation), *operation) + parts = CommutativePatternsParts(type(operation), *op_iter(operation)) + print(expressions) + print(parts) yield from _match_commutative_operation(expressions, parts, subst, constraints, matcher) @@ -246,7 +248,7 @@ def _match_commutative_operation( constraints, matcher ) -> Iterator[Substitution]: - subjects = Multiset(subject_operands) # type: Multiset + subjects = Multiset(op_iter(subject_operands)) # type: Multiset if not pattern.constant <= subjects: return subjects -= pattern.constant diff --git a/tests/conftest.py b/tests/conftest.py index 893f1c3..256437b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,11 @@ # -*- coding: utf-8 -*- import pytest -from matchpy.expressions.expressions import Operation, Wildcard +from matchpy.expressions.expressions import Wildcard, CommutativeOperation from matchpy.matching.one_to_one import match as match_one_to_one from matchpy.matching.many_to_one import ManyToOneMatcher from matchpy.matching.syntactic import DiscriminationNet +from matchpy.expressions.functions import preorder_iter def pytest_generate_tests(metafunc): @@ -16,10 +17,10 @@ def pytest_generate_tests(metafunc): def match_many_to_one(expression, pattern): try: - commutative, _ = next( - p for p in pattern.expression.preorder_iter(lambda e: isinstance(e, Operation) and e.commutative) + commutative = next( + p for p in preorder_iter(pattern.expression) if isinstance(p, CommutativeOperation) ) - next(wc for wc in commutative.preorder_iter(lambda e: isinstance(e, Wildcard) and e.min_count > 1)) + next(wc for wc in preorder_iter(commutative) if isinstance(wc, Wildcard) and wc.min_count > 1) except StopIteration: pass else: