Skip to content

Commit

Permalink
Added support for custom operand iteration.
Browse files Browse the repository at this point in the history
This allows dicts to be used with MatchPy.
  • Loading branch information
wheerd committed Jul 13, 2017
1 parent 36a66a6 commit a1e92d2
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 30 deletions.
2 changes: 2 additions & 0 deletions matchpy/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,7 @@ def __copy__(self) -> 'Operation':
Operation.register(tuple)
Operation.register(set)
Operation.register(frozenset)
Operation.register(dict)


class AssociativeOperation(metaclass=ABCMeta):
Expand All @@ -622,6 +623,7 @@ def __subclasshook__(cls, C):

CommutativeOperation.register(set)
CommutativeOperation.register(frozenset)
CommutativeOperation.register(dict)


class Atom(Expression): # pylint: disable=abstract-method
Expand Down
36 changes: 27 additions & 9 deletions matchpy/expressions/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
def is_constant(expression):
"""Check if the given expression is constant, i.e. it does not contain Wildcards."""
if isinstance(expression, Operation):
return all(is_constant(o) for o in expression)
return all(is_constant(o) for o in op_iter(expression))
return not isinstance(expression, Wildcard)


Expand All @@ -26,7 +26,7 @@ def is_syntactic(expression):
if isinstance(expression, (AssociativeOperation, CommutativeOperation)):
return False
if isinstance(expression, Operation):
return all(is_syntactic(o) for o in expression)
return all(is_syntactic(o) for o in op_iter(expression))
if isinstance(expression, Wildcard):
return expression.fixed_size
return True
Expand Down Expand Up @@ -57,7 +57,7 @@ def preorder_iter(expression):
"""Iterate over the expression in preorder."""
yield expression
if isinstance(expression, Operation):
for operand in expression:
for operand in op_iter(expression):
yield from preorder_iter(operand)


Expand All @@ -68,7 +68,7 @@ def preorder_iter_with_position(expression):
"""
yield expression, ()
if isinstance(expression, Operation):
for i, operand in enumerate(expression):
for i, operand in enumerate(op_iter(expression)):
for child, pos in preorder_iter_with_position(operand):
yield child, (i, ) + pos

Expand All @@ -78,7 +78,7 @@ def is_anonymous(expression):
if hasattr(expression, 'variable_name') and expression.variable_name:
return False
if isinstance(expression, Operation):
return all(is_anonymous(o) for o in expression)
return all(is_anonymous(o) for o in op_iter(expression))
return True


Expand All @@ -87,7 +87,7 @@ def contains_variables_from_set(expression, variables):
if hasattr(expression, 'variable_name') and expression.variable_name in variables:
return True
if isinstance(expression, Operation):
return any(contains_variables_from_set(o, variables) for o in expression)
return any(contains_variables_from_set(o, variables) for o in op_iter(expression))
return False


Expand All @@ -108,9 +108,9 @@ def rename_variables(expression: Expression, renaming: Dict[str, str]) -> Expres
if hasattr(expression, 'variable_name'):
variable_name = renaming.get(expression.variable_name, expression.variable_name)
return create_operation_expression(
expression, [rename_variables(o, renaming) for o in expression], variable_name=variable_name
expression, [rename_variables(o, renaming) for o in op_iter(expression)], variable_name=variable_name
)
operands = [rename_variables(o, renaming) for o in expression]
operands = [rename_variables(o, renaming) for o in op_iter(expression)]
return create_operation_expression(expression, operands)
elif isinstance(expression, Expression):
expression = expression.__copy__()
Expand All @@ -119,6 +119,8 @@ def rename_variables(expression: Expression, renaming: Dict[str, str]) -> Expres


def simple_operation_factory(op, args, variable_name):
if variable_name not in (True, False, None):
raise NotImplementedError('Expressions of type {} cannot have a variable name.'.format(type(op)))
return type(op)(args)


Expand All @@ -127,13 +129,20 @@ def simple_operation_factory(op, args, variable_name):
tuple: simple_operation_factory,
set: simple_operation_factory,
frozenset: simple_operation_factory,
# TODO: Add support for dicts
dict: simple_operation_factory,
}

_operation_iterators = {
dict: lambda d: d.items(),
}


def register_operation_factory(operation, factory):
_operation_factories[operation] = factory

def register_operation_iterator(operation, iterator):
_operation_iterators[operation] = iterator


def create_operation_expression(old_operation, new_operands, variable_name=True):
operation = type(old_operation)
Expand All @@ -145,3 +154,12 @@ def create_operation_expression(old_operation, new_operands, variable_name=True)
if variable_name is False:
return operation(*new_operands)
return operation(*new_operands, variable_name=variable_name)


def op_iter(operation):
op_type = type(operation)
for parent in op_type.__mro__:
if parent in _operation_iterators:
iterator = _operation_iterators[parent]
return iterator(operation)
return iter(operation)
8 changes: 4 additions & 4 deletions matchpy/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
Expression, Operation, Pattern, Wildcard, SymbolWildcard, AssociativeOperation, CommutativeOperation
)
from .expressions.substitution import Substitution
from .expressions.functions import preorder_iter_with_position, create_operation_expression
from .expressions.functions import preorder_iter_with_position, create_operation_expression, op_iter
from .matching.one_to_one import match

__all__ = ['substitute', 'replace', 'replace_all', 'replace_many', 'is_match', 'ReplacementRule']
Expand Down Expand Up @@ -77,7 +77,7 @@ def _substitute(expression: Expression, substitution: Substitution) -> Tuple[Rep
elif isinstance(expression, Operation):
any_replaced = False
new_operands = []
for operand in expression:
for operand in op_iter(expression):
result, replaced = _substitute(operand, substitution)
if replaced:
any_replaced = True
Expand Down Expand Up @@ -126,7 +126,7 @@ def replace(expression: Expression, position: Sequence[int], replacement: Replac
if position[0] >= len(expression):
raise IndexError("Position {!r} out of range for expression {!s}".format(position, expression))
pos = position[0]
operands = list(expression)
operands = list(op_iter(expression))
subexpr = replace(operands[pos], position[1:], replacement)
if isinstance(subexpr, Sequence):
new_operands = tuple(operands[:pos]) + tuple(subexpr) + tuple(operands[pos + 1:])
Expand Down Expand Up @@ -189,7 +189,7 @@ def replace_many(expression: Expression, replacements: Sequence[Tuple[Sequence[i
return replace(expression, replacements[0][0], replacements[0][1])
if not isinstance(expression, Operation):
raise IndexError("Invalid replacements {!r} for expression {!s}".format(replacements, expression))
operands = list(expression)
operands = list(op_iter(expression))
new_operands = []
last_index = 0
for index, group in itertools.groupby(replacements, lambda r: r[0][0]):
Expand Down
2 changes: 1 addition & 1 deletion matchpy/matching/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,4 +160,4 @@ def __str__(self):
for name, count in self.fixed_variables.items():
parts.extend([name] * count)

return '{}({})'.format(self.operation.name, ', '.join(parts))
return '{}({})'.format(getattr(self.operation, 'name', self.operation.__name__), ', '.join(parts))
12 changes: 6 additions & 6 deletions matchpy/matching/many_to_one.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
from ..expressions.substitution import Substitution
from ..expressions.functions import (
is_anonymous, contains_variables_from_set, create_operation_expression, preorder_iter_with_position,
rename_variables
rename_variables, op_iter
)
from ..utils import (VariableWithCount, commutative_sequence_variable_partition_iter)
from .. import functions
Expand Down Expand Up @@ -251,7 +251,7 @@ def _match_commutative_operation(self, state: _State) -> Iterator[_State]:
matcher = state.matcher
substitution = self.substitution
matcher.add_subject(None)
for operand in subject:
for operand in op_iter(subject):
matcher.add_subject(operand)
for matched_pattern, new_substitution in matcher.match(subject, substitution):
restore_constraints = set()
Expand All @@ -278,7 +278,7 @@ def _match_commutative_operation(self, state: _State) -> Iterator[_State]:
def _match_regular_operation(self, transition: _Transition) -> Iterator[_State]:
subject = self.subjects.popleft()
after_subjects = self.subjects
operand_subjects = self.subjects = deque(subject)
operand_subjects = self.subjects = deque(op_iter(subject))
new_associative = transition.label if issubclass(transition.label, AssociativeOperation) else None
self.associative.append(new_associative)
for new_state in self._check_transition(transition, subject, False):
Expand Down Expand Up @@ -365,7 +365,7 @@ def _internal_add(self, pattern: Pattern, label, renaming) -> int:
if patterns_stack[-1]:
subpattern = patterns_stack[-1].popleft()
if isinstance(subpattern, Operation) and not isinstance(subpattern, CommutativeOperation):
patterns_stack.append(deque(subpattern))
patterns_stack.append(deque(op_iter(subpattern)))
variable_name = getattr(subpattern, 'variable_name', None)
state = self._create_expression_transition(state, subpattern, variable_name, pattern_index)
if isinstance(subpattern, CommutativeOperation):
Expand Down Expand Up @@ -813,7 +813,7 @@ def match(self, subjects: Sequence[Expression], substitution: Substitution) -> I
subject_ids.add(subject_id)
for _ in range(self.max_optional_count):
pattern_ids.update(subject_pattern_ids)
for subject in subjects:
for subject in op_iter(subjects):
subject_id, subject_pattern_ids = self.subjects[subject]
subject_ids.add(subject_id)
pattern_ids.update(subject_pattern_ids)
Expand Down Expand Up @@ -845,7 +845,7 @@ def _extract_sequence_wildcards(self, operands: Iterable[Expression],
pattern_set = Multiset()
pattern_vars = dict()
opt_count = 0
for operand in operands:
for operand in op_iter(operands):
if isinstance(operand, Wildcard) and operand.optional is not None:
opt_count += 1
if not self._is_sequence_wildcard(operand):
Expand Down
14 changes: 8 additions & 6 deletions matchpy/matching/one_to_one.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
)
from ..expressions.constraints import Constraint
from ..expressions.substitution import Substitution
from ..expressions.functions import is_constant, preorder_iter_with_position, match_head, create_operation_expression
from ..expressions.functions import is_constant, preorder_iter_with_position, match_head, create_operation_expression, op_iter
from ..utils import (
VariableWithCount, commutative_sequence_variable_partition_iter, fixed_integer_vector_iter, weak_composition_iter,
generator_chain, optional_iter
Expand Down Expand Up @@ -155,7 +155,7 @@ def _count_seq_vars(expressions, operation):
remaining = len(expressions)
sequence_var_count = 0
optional_count = 0
for operand in operation:
for operand in op_iter(operation):
if isinstance(operand, Wildcard):
if not operand.fixed_size or isinstance(operation, AssociativeOperation):
sequence_var_count += 1
Expand Down Expand Up @@ -183,7 +183,7 @@ def _build_full_partition(optional_parts, sequence_var_partition: Sequence[int],
var_index = 0
opt_index = 0
result = []
for operand in operation:
for operand in op_iter(operation):
wrap_associative = False
if isinstance(operand, Wildcard):
count = operand.min_count if operand.optional is None else 0
Expand Down Expand Up @@ -221,7 +221,7 @@ def _non_commutative_match(subjects, operation, subst, constraints, matcher):
continue
for part in weak_composition_iter(new_remaining, sequence_var_count):
partition = _build_full_partition(optional, part, subjects, operation)
factories = [_match_factory(e, o, constraints, matcher) for e, o in zip(partition, operation)]
factories = [_match_factory(e, o, constraints, matcher) for e, o in zip(partition, op_iter(operation))]

for new_subst in generator_chain(subst, *factories):
yield new_subst
Expand All @@ -235,7 +235,9 @@ def _match_operation(expressions, operation, subst, matcher, constraints):
if not isinstance(operation, CommutativeOperation):
yield from _non_commutative_match(expressions, operation, subst, constraints, matcher)
else:
parts = CommutativePatternsParts(type(operation), *operation)
parts = CommutativePatternsParts(type(operation), *op_iter(operation))
print(expressions)
print(parts)
yield from _match_commutative_operation(expressions, parts, subst, constraints, matcher)


Expand All @@ -246,7 +248,7 @@ def _match_commutative_operation(
constraints,
matcher
) -> Iterator[Substitution]:
subjects = Multiset(subject_operands) # type: Multiset
subjects = Multiset(op_iter(subject_operands)) # type: Multiset
if not pattern.constant <= subjects:
return
subjects -= pattern.constant
Expand Down
9 changes: 5 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# -*- coding: utf-8 -*-
import pytest

from matchpy.expressions.expressions import Operation, Wildcard
from matchpy.expressions.expressions import Wildcard, CommutativeOperation
from matchpy.matching.one_to_one import match as match_one_to_one
from matchpy.matching.many_to_one import ManyToOneMatcher
from matchpy.matching.syntactic import DiscriminationNet
from matchpy.expressions.functions import preorder_iter


def pytest_generate_tests(metafunc):
Expand All @@ -16,10 +17,10 @@ def pytest_generate_tests(metafunc):

def match_many_to_one(expression, pattern):
try:
commutative, _ = next(
p for p in pattern.expression.preorder_iter(lambda e: isinstance(e, Operation) and e.commutative)
commutative = next(
p for p in preorder_iter(pattern.expression) if isinstance(p, CommutativeOperation)
)
next(wc for wc in commutative.preorder_iter(lambda e: isinstance(e, Wildcard) and e.min_count > 1))
next(wc for wc in preorder_iter(commutative) if isinstance(wc, Wildcard) and wc.min_count > 1)
except StopIteration:
pass
else:
Expand Down

0 comments on commit a1e92d2

Please sign in to comment.