Skip to content

Commit

Permalink
Forgot to check in expressions/functions.py
Browse files Browse the repository at this point in the history
Fixed some bugs.
  • Loading branch information
wheerd committed May 18, 2017
1 parent 1d41c02 commit 7561764
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 6 deletions.
5 changes: 2 additions & 3 deletions matchpy/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@

__all__ = [
'Expression', 'Arity', 'Atom', 'Symbol', 'Wildcard', 'Operation', 'SymbolWildcard', 'Pattern', 'make_dot_variable',
'make_plus_variable', 'make_star_variable', 'make_symbol_variable'
'make_plus_variable', 'make_star_variable', 'make_symbol_variable', 'AssociativeOperation', 'CommutativeOperation'
]

ExprPredicate = Optional[Callable[['Expression'], bool]]
Expand Down Expand Up @@ -609,8 +609,7 @@ def __subclasshook__(cls, C):
class Atom(Expression): # pylint: disable=abstract-method
"""Base for all atomic expressions."""

def __iter__(self):
raise NotImplementedError()
__iter__ = None


class Symbol(Atom):
Expand Down
109 changes: 109 additions & 0 deletions matchpy/expressions/functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from .expressions import Operation, Wildcard, AssociativeOperation, CommutativeOperation, SymbolWildcard, Pattern

__all__ = [
'is_constant', 'is_syntactic', 'get_head', 'match_head', 'preorder_iter', 'preorder_iter_with_position',
'is_anonymous', 'contains_variables_from_set', 'register_operation_factory', 'create_operation_expression'
]

def is_constant(expression):
"""Check if the given expression is constant, i.e. it does not contain Wildcards."""
if isinstance(expression, Operation):
return all(is_constant(o) for o in expression)
return not isinstance(expression, Wildcard)


def is_syntactic(expression):
"""
Check if the given expression is syntactic, i.e. it does not contain sequence wildcards or
associative/commutative operations.
"""
if isinstance(expression, (AssociativeOperation, CommutativeOperation)):
return False
if isinstance(expression, Operation):
return all(is_syntactic(o) for o in expression)
if isinstance(expression, Wildcard):
return expression.fixed_size
return True


def get_head(expression):
"""Returns the given expression's head."""
if isinstance(expression, Wildcard):
if isinstance(expression, SymbolWildcard):
return expression.symbol_type
return None
return type(expression)


def match_head(subject, pattern):
"""Checks if the head of subject matches the pattern's head."""
if isinstance(pattern, Pattern):
pattern = pattern.expression
pattern_head = get_head(pattern)
if pattern_head is None:
return True
subject_head = get_head(subject)
assert subject_head is not None
return issubclass(subject_head, pattern_head)

def preorder_iter(expression):
"""Iterate over the expression in preorder."""
yield expression
if isinstance(expression, Operation):
for operand in expression:
yield from preorder_iter(operand)

def preorder_iter_with_position(expression):
"""Iterate over the expression in preorder.
Also yields the position of each subexpression.
"""
yield expression, ()
if isinstance(expression, Operation):
for i, operand in enumerate(expression):
for child, pos in preorder_iter_with_position(operand):
yield child, (i, ) + pos

def is_anonymous(expression):
"""Returns True iff the expression does not contain any variables."""
if hasattr(expression, 'variable_name') and expression.variable_name:
return False
if isinstance(expression, Operation):
return all(is_anonymous(o) for o in expression)
return True

def contains_variables_from_set(expression, variables):
"""Returns True iff the expression contains any of the variables from the given set."""
if hasattr(expression, 'variable_name') and expression.variable_name in variables:
return True
if isinstance(expression, Operation):
return any(contains_variables_from_set(o, variables) for o in expression)
return False


def simple_operation_factory(op, args, variable_name):
return type(op)(args)


_operation_factories = {
list: simple_operation_factory,
tuple: simple_operation_factory,
set: simple_operation_factory,
frozenset: simple_operation_factory,
# TODO: Add support for dicts
}

def register_operation_factory(operation, factory):
_operation_factories[operation] = factory


def create_operation_expression(old_operation, new_operands, variable_name=True):
operation = type(old_operation)
for parent in operation.__mro__:
if parent in _operation_factories:
return _operation_factories[parent](old_operation, new_operands, variable_name)
if variable_name is True:
variable_name = getattr(old_operation, 'variable_name', None)
if variable_name is False:
return operation(*new_operands)
return operation(*new_operands, variable_name=variable_name)
2 changes: 1 addition & 1 deletion matchpy/matching/many_to_one.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ def _create_expression_transition(
) -> _State:
label, head = self._get_label_and_head(expression)
transitions = state.transitions.setdefault(head, [])
commutative = getattr(expression, 'commutative', False)
commutative = isinstance(expression, CommutativeOperation)
matcher = None
for transition in transitions:
if transition.variable_name == variable_name and transition.label == label:
Expand Down
6 changes: 4 additions & 2 deletions matchpy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import ast
import os
import tokenize
import copy
from types import LambdaType
# pylint: disable=unused-import
from typing import (Callable, Dict, Iterator, List, NamedTuple, Optional, Sequence, Tuple, TypeVar, cast, Union, Any)
Expand Down Expand Up @@ -133,9 +134,10 @@ def _factory(subst):
solutions = list(solve_linear_diop(total, *var_counts))
_linear_diop_solution_cache[cache_key] = solutions
for solution in solutions:
new_subst = copy.copy(subst)
for var, count in zip(variables, solution):
subst[var.name][value] = count
yield subst
new_subst[var.name][value] = count
yield new_subst

return _factory

Expand Down

0 comments on commit 7561764

Please sign in to comment.