From 3661f64e946130208bedc59311731afd349878b9 Mon Sep 17 00:00:00 2001 From: MarcoFavorito Date: Fri, 9 Jun 2023 16:20:49 +0200 Subject: [PATCH] feat: check terms consistency wrt type tag i.e. terms with the same name should have the same type tags. --- pddl/helpers/base.py | 2 +- pddl/logic/predicates.py | 32 +++++++++++++++++++++++------ pddl/logic/terms.py | 15 ++++++++++---- pddl/parser/typed_list_parser.py | 11 +++------- tests/test_logic/test_predicates.py | 28 +++++++++++++++++++++++++ tests/test_logic/test_terms.py | 2 +- 6 files changed, 70 insertions(+), 20 deletions(-) create mode 100644 tests/test_logic/test_predicates.py diff --git a/pddl/helpers/base.py b/pddl/helpers/base.py index dcc30dc..ad5b89b 100644 --- a/pddl/helpers/base.py +++ b/pddl/helpers/base.py @@ -72,7 +72,7 @@ def ensure_set(arg: Optional[Collection], immutable: bool = True) -> AbstractSet return op(arg) if arg is not None else op() -def check_no_duplicates(arg: Optional[Sequence[str]]) -> Optional[Collection]: +def check_no_duplicates(arg: Optional[Collection]) -> Optional[Collection]: """Check that the argument is a set.""" if arg is None: return None diff --git a/pddl/logic/predicates.py b/pddl/logic/predicates.py index f044c94..615bfc8 100644 --- a/pddl/logic/predicates.py +++ b/pddl/logic/predicates.py @@ -12,25 +12,45 @@ """This class implements PDDL predicates.""" import functools -from typing import Sequence +from typing import Dict, Sequence, Set -from pddl.custom_types import namelike, parse_name -from pddl.helpers.base import assert_ +from pddl.custom_types import name, namelike, parse_name +from pddl.helpers.base import assert_, check from pddl.helpers.cache_hash import cache_hash from pddl.logic.base import Atomic, Formula -from pddl.logic.terms import Term +from pddl.logic.terms import Term, _print_tag_set from pddl.parser.symbols import Symbols +def _check_terms_consistency(terms: Sequence[Term]): + """ + Check that the term sequence have consistent type tags. + + In particular, terms with the same name must have the same type tags. + """ + seen: Dict[name, Set[name]] = {} + for term in terms: + if term.name not in seen: + seen[term.name] = set(term.type_tags) + else: + check( + seen[term.name] == set(term.type_tags), + f"Term {term} has inconsistent type tags: " + f"previous type tags {_print_tag_set(seen[term.name])}, new type tags {_print_tag_set(term.type_tags)}", + exception_cls=ValueError, + ) + + @cache_hash @functools.total_ordering class Predicate(Atomic): """A class for a Predicate in PDDL.""" - def __init__(self, name: namelike, *terms: Term): + def __init__(self, predicate_name: namelike, *terms: Term): """Initialize the predicate.""" - self._name = parse_name(name) + self._name = parse_name(predicate_name) self._terms = tuple(terms) + _check_terms_consistency(self._terms) @property def name(self) -> str: diff --git a/pddl/logic/terms.py b/pddl/logic/terms.py index 22e578f..4fb37f1 100644 --- a/pddl/logic/terms.py +++ b/pddl/logic/terms.py @@ -20,26 +20,33 @@ from pddl.helpers.cache_hash import cache_hash +def _print_tag_set(type_tags: AbstractSet[name_type]) -> str: + """Print a tag set.""" + if len(type_tags) == 0: + return "[]" + return repr(sorted(map(str, type_tags))) + + @cache_hash @functools.total_ordering class Term: """A term in a formula.""" def __init__( - self, name: namelike, type_tags: Optional[Collection[namelike]] = None + self, term_name: namelike, type_tags: Optional[Collection[namelike]] = None ): """ Initialize a term. - :param name: the name for the term. + :param term_name: the name for the term. :param type_tags: the type tags associated to this term. """ assert_(type(self) is not Term, "Term is an abstract class") - self._name = parse_name(name) + self._name = parse_name(term_name) self._type_tags = frozenset(to_type(ensure_set(check_no_duplicates(type_tags)))) # type: ignore @property - def name(self) -> str: + def name(self) -> name_type: """Get the name.""" return self._name diff --git a/pddl/parser/typed_list_parser.py b/pddl/parser/typed_list_parser.py index 1bfc657..2f259ec 100644 --- a/pddl/parser/typed_list_parser.py +++ b/pddl/parser/typed_list_parser.py @@ -16,6 +16,7 @@ from pddl.custom_types import name, parse_name, parse_type from pddl.helpers.base import check, safe_index +from pddl.logic.terms import _print_tag_set from pddl.parser.symbols import Symbols @@ -203,7 +204,7 @@ def _check_item_types(self, item_name: name, type_tags: Set[name]) -> None: if previous_type_tags != type_tags: raise ValueError( f"invalid types for item '{item_name}': previous known tags were " - f"{self._print_tag_set(previous_type_tags)}, got {self._print_tag_set(type_tags)}" + f"{_print_tag_set(previous_type_tags)}, got {_print_tag_set(type_tags)}" ) def _add_item(self, item_name: name, type_tags: Set[name]) -> None: @@ -219,11 +220,5 @@ def _raise_multiple_types_error( """Raise an error if the item has multiple types.""" raise ValueError( f"typed list names should not have more than one type, got '{item_name}' with " - f"types {self._print_tag_set(type_tags)}" + f"types {_print_tag_set(type_tags)}" ) - - def _print_tag_set(self, type_tags: Set[name]) -> str: - """Print a tag set.""" - if len(type_tags) == 0: - return "[]" - return repr(sorted(map(str, type_tags))) diff --git a/tests/test_logic/test_predicates.py b/tests/test_logic/test_predicates.py new file mode 100644 index 0000000..4bdb6f6 --- /dev/null +++ b/tests/test_logic/test_predicates.py @@ -0,0 +1,28 @@ +# +# Copyright 2021-2023 WhiteMech +# +# ------------------------------ +# +# This file is part of pddl. +# +# Use of this source code is governed by an MIT-style +# license that can be found in the LICENSE file or at +# https://opensource.org/licenses/MIT. +# + +"""Test pddl.logic.predicates module.""" +import pytest + +from pddl.logic import Predicate +from pddl.logic.terms import Variable + + +def test_inconsistent_predicate_terms() -> None: + """Test that terms of a predicate must have consistent typing.""" + with pytest.raises( + ValueError, + match=r"Term \?a has inconsistent type tags: previous type tags \['t1', 't2'\], new type tags " + r"\['t3', 't4'\]", + ): + a1, a2 = Variable("a", ["t1", "t2"]), Variable("a", ["t3", "t4"]) + Predicate("p", a1, a2) diff --git a/tests/test_logic/test_terms.py b/tests/test_logic/test_terms.py index 7016332..0208240 100644 --- a/tests/test_logic/test_terms.py +++ b/tests/test_logic/test_terms.py @@ -10,7 +10,7 @@ # https://opensource.org/licenses/MIT. # -"""Test pddl.logic module.""" +"""Test pddl.logic.terms module.""" import pytest from pddl.logic.terms import Variable