From 1092fae39308d826f3e506ec291b4bf4e4f2dc1d Mon Sep 17 00:00:00 2001 From: MarcoFavorito Date: Wed, 7 Jun 2023 16:36:28 +0200 Subject: [PATCH] feat: add internal class to manage and handle types --- pddl/_validation.py | 51 ++--------------- pddl/core.py | 129 +++++++++++++++++++++++++++++++++++-------- tests/test_domain.py | 6 +- 3 files changed, 113 insertions(+), 73 deletions(-) diff --git a/pddl/_validation.py b/pddl/_validation.py index 8dfda49..dacf770 100644 --- a/pddl/_validation.py +++ b/pddl/_validation.py @@ -11,57 +11,14 @@ # """This module defines validation functions for PDDL data structures.""" +from typing import Collection, Optional, Set, Tuple -from typing import Collection, Dict, Optional, Set, Tuple - -from pddl.custom_types import name, to_names # noqa: F401 +from pddl.custom_types import name, namelike, to_names # noqa: F401 from pddl.exceptions import PDDLValidationError -from pddl.helpers.base import ensure_set, find_cycle +from pddl.helpers.base import ensure_set from pddl.logic import Constant, Predicate from pddl.logic.terms import Term -from pddl.parser.symbols import ALL_SYMBOLS, Symbols - - -def _check_types_dictionary(type_dict: Dict[name, Optional[name]]) -> None: - """ - Check the consistency of the types dictionary. - - 1) Empty types dictionary is correct by definition: - >>> _check_types_dictionary({}) - - 2) The `object` type cannot be a subtype: - >>> a = name("a") - >>> _check_types_dictionary({name("object"): a}) - Traceback (most recent call last): - ... - pddl.exceptions.PDDLValidationError: object must not have supertypes, but got 'object' is a subtype of 'a' - - 3) If cycles in the type hierarchy graph are present, an error is raised: - >>> a, b, c = to_names(["a", "b", "c"]) - >>> _check_types_dictionary({a: b, b: c, c: a}) - Traceback (most recent call last): - ... - pddl.exceptions.PDDLValidationError: cycle detected in the type hierarchy: a -> b -> c - - :param type_dict: the types dictionary - """ - if len(type_dict) == 0: - return - - # check `object` type - object_name = name(Symbols.OBJECT.value) - if object_name in type_dict and type_dict[object_name] is not None: - object_supertype = type_dict[object_name] - raise PDDLValidationError( - f"object must not have supertypes, but got 'object' is a subtype of '{object_supertype}'" - ) - - # check cycles - cycle = find_cycle(type_dict) # type: ignore - if cycle is not None: - raise PDDLValidationError( - "cycle detected in the type hierarchy: " + " -> ".join(cycle) - ) +from pddl.parser.symbols import ALL_SYMBOLS def _find_inconsistencies_in_typed_terms( diff --git a/pddl/core.py b/pddl/core.py index 9d0e8cf..347b5db 100644 --- a/pddl/core.py +++ b/pddl/core.py @@ -19,24 +19,23 @@ from enum import Enum from typing import AbstractSet, Collection, Dict, Optional, Sequence, Set, cast -from pddl._validation import ( - _check_constant_types, - _check_types_dictionary, - _check_types_in_has_terms_objects, -) +from pddl._validation import _check_constant_types, _check_types_in_has_terms_objects from pddl.custom_types import name as name_type -from pddl.custom_types import namelike, to_names_types +from pddl.custom_types import namelike, to_names, to_names_types # noqa: F401 +from pddl.exceptions import PDDLValidationError from pddl.helpers.base import ( _typed_parameters, assert_, ensure, ensure_sequence, ensure_set, + find_cycle, ) from pddl.logic.base import Formula, TrueFormula, is_literal from pddl.logic.predicates import DerivedPredicate, Predicate from pddl.logic.terms import Constant, Term, Variable from pddl.parser.symbols import RequirementSymbols as RS +from pddl.parser.symbols import Symbols class Domain: @@ -68,30 +67,19 @@ def __init__( """ self._name = name_type(name) self._requirements = ensure_set(requirements) - self._types = to_names_types(ensure(types, dict())) + self._types = _Types(types, self._requirements) self._constants = ensure_set(constants) self._predicates = ensure_set(predicates) self._derived_predicates = ensure_set(derived_predicates) self._actions = ensure_set(actions) - self._all_types_set = self._get_all_types() - self._check_consistency() - def _get_all_types(self) -> Set[name_type]: - """Get all types supported by this domain.""" - if self._types is None: - return set() - result = set(self._types.keys()) | set(self._types.values()) - result.discard(None) - return cast(Set[name_type], result) - def _check_consistency(self) -> None: """Check consistency of a domain instance object.""" - _check_types_dictionary(self._types) - _check_constant_types(self._constants, self._all_types_set) - _check_types_in_has_terms_objects(self._predicates, self._all_types_set) - _check_types_in_has_terms_objects(self._actions, self._all_types_set) # type: ignore + _check_constant_types(self._constants, self._types.all_types) + _check_types_in_has_terms_objects(self._predicates, self._types.all_types) + _check_types_in_has_terms_objects(self._actions, self._types.all_types) # type: ignore self._check_types_in_derived_predicates() def _check_types_in_derived_predicates(self) -> None: @@ -101,7 +89,7 @@ def _check_types_in_derived_predicates(self) -> None: if self._derived_predicates else set() ) - _check_types_in_has_terms_objects(dp_list, self._all_types_set) + _check_types_in_has_terms_objects(dp_list, self._types.all_types) @property def name(self) -> str: @@ -136,7 +124,7 @@ def actions(self) -> AbstractSet["Action"]: @property def types(self) -> Dict[name_type, Optional[name_type]]: """Get the type definitions, if defined. Else, raise error.""" - return self._types + return self._types.raw def __eq__(self, other): """Compare with another object.""" @@ -388,3 +376,98 @@ def __lt__(self, other): return self.value <= other.value else: return super().__lt__(other) + + +class _Types: + """A class for representing and managing the types available in a PDDL Domain.""" + + def __init__( + self, + types: Optional[Dict[namelike, Optional[namelike]]] = None, + requirements: Optional[AbstractSet[Requirements]] = None, + ) -> None: + """Initialize the Types object.""" + self._types = to_names_types(ensure(types, dict())) + + self._all_types = self._get_all_types() + self._check_types_dictionary(self._types, ensure_set(requirements)) + + @property + def raw(self) -> Dict[name_type, Optional[name_type]]: + """Get the raw types dictionary.""" + return self._types + + @property + def all_types(self) -> Set[name_type]: + """Get all available types.""" + return self._all_types + + def _get_all_types(self) -> Set[name_type]: + """Get all types supported by the domain.""" + if self._types is None: + return set() + result = set(self._types.keys()) | set(self._types.values()) + result.discard(None) + return cast(Set[name_type], result) + + @classmethod + def _check_types_dictionary( + cls, + type_dict: Dict[name_type, Optional[name_type]], + requirements: AbstractSet[Requirements], + ) -> None: + """ + Check the consistency of the types dictionary. + + 1) Empty types dictionary is correct by definition: + >>> _Types._check_types_dictionary({}, set()) + + 2) There are supertypes, but :typing requirement not specified + >>> a, b, c = to_names(["a", "b", "c"]) + >>> _Types._check_types_dictionary({a: b, b: c}, set()) + Traceback (most recent call last): + ... + pddl.exceptions.PDDLValidationError: typing requirement is not specified, but types are used: 'b', 'c' + + 3) The `object` type cannot be a subtype: + >>> a = name_type("a") + >>> _Types._check_types_dictionary({name_type("object"): a}, {Requirements.TYPING}) + Traceback (most recent call last): + ... + pddl.exceptions.PDDLValidationError: object must not have supertypes, but got 'object' is a subtype of 'a' + + 4) If cycles in the type hierarchy graph are present, an error is raised: + >>> a, b, c = to_names(["a", "b", "c"]) + >>> _Types._check_types_dictionary({a: b, b: c, c: a}, {Requirements.TYPING}) + Traceback (most recent call last): + ... + pddl.exceptions.PDDLValidationError: cycle detected in the type hierarchy: a -> b -> c + + :param type_dict: the types dictionary + """ + if len(type_dict) == 0: + return + + # check typing requirement + supertypes = {t for t in type_dict.values() if t is not None} + if len(supertypes) > 0 and Requirements.TYPING not in requirements: + raise PDDLValidationError( + "typing requirement is not specified, but types are used: '" + + "', '".join(map(str, sorted(supertypes))) + + "'" + ) + + # check `object` type + object_name = name_type(Symbols.OBJECT.value) + if object_name in type_dict and type_dict[object_name] is not None: + object_supertype = type_dict[object_name] + raise PDDLValidationError( + f"object must not have supertypes, but got 'object' is a subtype of '{object_supertype}'" + ) + + # check cycles + cycle = find_cycle(type_dict) # type: ignore + if cycle is not None: + raise PDDLValidationError( + "cycle detected in the type hierarchy: " + " -> ".join(cycle) + ) diff --git a/tests/test_domain.py b/tests/test_domain.py index d8d35e9..70cbdb6 100644 --- a/tests/test_domain.py +++ b/tests/test_domain.py @@ -17,7 +17,7 @@ import pytest -from pddl.core import Action, Domain +from pddl.core import Action, Domain, Requirements from pddl.exceptions import PDDLValidationError from pddl.logic import Constant, Variable from pddl.logic.base import Not, TrueFormula @@ -88,7 +88,7 @@ def test_cycles_in_type_defs_not_allowed() -> None: with pytest.raises( PDDLValidationError, match="cycle detected in the type hierarchy: A -> B -> C" ): - Domain("dummy", types={"A": "B", "B": "C", "C": "A"}) + Domain("dummy", requirements={Requirements.TYPING}, types={"A": "B", "B": "C", "C": "A"}) def test_object_must_not_be_subtype() -> None: @@ -100,7 +100,7 @@ def test_object_must_not_be_subtype() -> None: PDDLValidationError, match=rf"object must not have supertypes, but got 'object' is a subtype of '{my_type}'", ): - Domain("test", types=type_set) # type: ignore + Domain("test", requirements={Requirements.TYPING}, types=type_set) # type: ignore def test_constants_type_not_available() -> None: