Skip to content

Commit

Permalink
Fixed some bugs related to optional wildcards.
Browse files Browse the repository at this point in the history
  • Loading branch information
wheerd committed Jul 21, 2017
1 parent d4ee2f3 commit 21e5e2a
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 16 deletions.
3 changes: 2 additions & 1 deletion matchpy/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -869,7 +869,8 @@ def __eq__(self, other):
return NotImplemented
return (
other.min_count == self.min_count and other.fixed_size == self.fixed_size and
self.variable_name == other.variable_name
self.variable_name == other.variable_name and
self.optional == other.optional
)

def __hash__(self):
Expand Down
41 changes: 27 additions & 14 deletions matchpy/matching/many_to_one.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
from ..expressions.substitution import Substitution
from ..expressions.functions import (
is_anonymous, contains_variables_from_set, create_operation_expression, preorder_iter_with_position,
rename_variables, op_iter
rename_variables, op_iter, preorder_iter
)
from ..utils import (VariableWithCount, commutative_sequence_variable_partition_iter)
from .. import functions
Expand Down Expand Up @@ -82,6 +82,9 @@
('subst', Substitution),
]) # yapf: disable


_VISITED = set()

class _MatchIter:
def __init__(self, matcher, subject, intial_associative=None):
self.matcher = matcher
Expand Down Expand Up @@ -132,6 +135,7 @@ def _internal_iter(self):
yield label, new_substitution

def _match(self, state: _State) -> Iterator[_State]:
_VISITED.add(state.number)
if len(self.subjects) == 0:
if state.number in self.matcher.finals or OPERATION_END in state.transitions:
yield state
Expand Down Expand Up @@ -315,11 +319,11 @@ def _match_regular_operation(self, transition: _Transition) -> Iterator[_State]:


class ManyToOneMatcher:
__slots__ = ('patterns', 'states', 'root', 'pattern_vars', 'constraints', 'constraint_vars', 'finals')
__slots__ = ('patterns', 'states', 'root', 'pattern_vars', 'constraints', 'constraint_vars', 'finals', 'rename')

_state_id = 0

def __init__(self, *patterns: Expression) -> None:
def __init__(self, *patterns: Expression, rename=True) -> None:
"""
Args:
*patterns: The patterns which the matcher should match.
Expand All @@ -331,6 +335,7 @@ def __init__(self, *patterns: Expression) -> None:
self.constraints = []
self.constraint_vars = {}
self.finals = set()
self.rename = rename

for pattern in patterns:
self.add(pattern)
Expand All @@ -357,7 +362,7 @@ def add(self, pattern: Pattern, label=None) -> None:
if pattern == p and label == l:
return i
# TODO: Avoid renaming in the pattern, use variable indices instead
renaming = self._collect_variable_renaming(pattern.expression)
renaming = self._collect_variable_renaming(pattern.expression) if self.rename else {}
self._internal_add(pattern, label, renaming)

def _internal_add(self, pattern: Pattern, label, renaming) -> int:
Expand Down Expand Up @@ -456,7 +461,7 @@ def _create_expression_transition(
commutative = isinstance(expression, CommutativeOperation)
matcher = None
for transition in transitions:
if transition.variable_name == variable_name and transition.label == label:
if transition.variable_name == variable_name and transition.label == label and transition.subst == subst:
transition.patterns.add(index)
if variable_name is not None:
constraints = set(
Expand Down Expand Up @@ -677,7 +682,7 @@ def _make_graph_nodes(self, graph: Digraph, finals: Optional[List[str]]) -> None
submatch_label = '<<b>Sub Matcher End</b>' if has_states else '<<b>Sub Matcher</b>'
for pattern_index, subpatterns, variables in state.matcher.patterns.values():
var_formatted = ', '.join(
'{}[{}]x{}{}'.format(self._colored_variable(n), m, c, 'W' if w else '')
'{}[{}]x{}{}{}'.format(self._colored_variable(n), m, c, 'W' if w else '', ': {}'.format(d) if d is not None else '')
for (n, c, m, d), w in variables
)
submatch_label += '<br/>\n{}: {} {}'.format(
Expand All @@ -692,6 +697,8 @@ def _make_graph_nodes(self, graph: Digraph, finals: Optional[List[str]]) -> None
graph.edge(name, 'n{}'.format(state.matcher.automaton.root.number))
else:
attrs = {'shape': ('doublecircle' if state.number in self.finals else 'circle')}
if state.number in _VISITED:
attrs['color'] = 'red'
graph.node(name, str(state.number), attrs)
if state.number in self.finals:
sp = state_patterns[state.number]
Expand Down Expand Up @@ -826,17 +833,21 @@ def add_pattern(self, operands: Iterable[Expression], constraints) -> int:
inserted_id = self.patterns[pattern_key][0]
return inserted_id

def get_match_iter(self, subject):
match_iter = _MatchIter(self.automaton, subject, self.associative)
for _ in match_iter._match(self.automaton.root):
for pattern_index in match_iter.patterns:
substitution = Substitution(match_iter.substitution)
yield pattern_index, substitution


def add_subject(self, subject: Expression) -> None:
if subject not in self.subjects:
subject_id, pattern_set = self.subjects[subject] = (len(self.subjects), set())
self.subjects[subject_id] = subject
match_iter = _MatchIter(self.automaton, subject, self.associative)
for _ in match_iter._match(self.automaton.root):
for pattern_index in match_iter.patterns:
variables = self.automaton.pattern_vars[pattern_index]
substitution = Substitution(match_iter.substitution)
self.bipartite.setdefault((subject_id, pattern_index), []).append(substitution)
pattern_set.add(pattern_index)
for pattern_index, substitution in self.get_match_iter(subject):
self.bipartite.setdefault((subject_id, pattern_index), []).append(Substitution(substitution))
pattern_set.add(pattern_index)
else:
subject_id, _ = self.subjects[subject]
return subject_id
Expand Down Expand Up @@ -893,7 +904,9 @@ def _extract_sequence_wildcards(self, operands: Iterable[Expression],
index = i
break
else:
index = self.automaton._internal_add(pattern, None, {})
vnames = set(e.variable_name for e in preorder_iter(pattern.expression) if hasattr(e, 'variable_name') and e.variable_name is not None)
renaming = {n: n for n in vnames}
index = self.automaton._internal_add(pattern, None, renaming)
if is_anonymous(pattern.expression):
self.anonymous_patterns.add(index)
pattern_set.add(index)
Expand Down
4 changes: 3 additions & 1 deletion matchpy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,9 @@ def commutative_sequence_variable_partition_iter(values: 'Multiset[T]', variable
for subst in generator_chain(initial, *generators):
valid = True
for var in variables:
if len(subst[var.name]) < var.minimum:
if var.default is not None and len(subst[var.name]) == 0:
subst[var.name] = var.default
elif len(subst[var.name]) < var.minimum:
valid = False
break
if valid:
Expand Down

0 comments on commit 21e5e2a

Please sign in to comment.