Skip to content

Commit

Permalink
feat: add internal class to manage and handle types
Browse files Browse the repository at this point in the history
  • Loading branch information
marcofavorito committed Jun 7, 2023
1 parent 03abf7d commit fada331
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 73 deletions.
51 changes: 4 additions & 47 deletions pddl/_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,57 +11,14 @@
#

"""This module defines validation functions for PDDL data structures."""
from typing import Collection, Optional, Set, Tuple

from typing import Collection, Dict, Optional, Set, Tuple

from pddl.custom_types import name, to_names # noqa: F401
from pddl.custom_types import name, namelike, to_names # noqa: F401
from pddl.exceptions import PDDLValidationError
from pddl.helpers.base import ensure_set, find_cycle
from pddl.helpers.base import ensure_set
from pddl.logic import Constant, Predicate
from pddl.logic.terms import Term
from pddl.parser.symbols import ALL_SYMBOLS, Symbols


def _check_types_dictionary(type_dict: Dict[name, Optional[name]]) -> None:
"""
Check the consistency of the types dictionary.
1) Empty types dictionary is correct by definition:
>>> _check_types_dictionary({})
2) The `object` type cannot be a subtype:
>>> a = name("a")
>>> _check_types_dictionary({name("object"): a})
Traceback (most recent call last):
...
pddl.exceptions.PDDLValidationError: object must not have supertypes, but got 'object' is a subtype of 'a'
3) If cycles in the type hierarchy graph are present, an error is raised:
>>> a, b, c = to_names(["a", "b", "c"])
>>> _check_types_dictionary({a: b, b: c, c: a})
Traceback (most recent call last):
...
pddl.exceptions.PDDLValidationError: cycle detected in the type hierarchy: a -> b -> c
:param type_dict: the types dictionary
"""
if len(type_dict) == 0:
return

# check `object` type
object_name = name(Symbols.OBJECT.value)
if object_name in type_dict and type_dict[object_name] is not None:
object_supertype = type_dict[object_name]
raise PDDLValidationError(
f"object must not have supertypes, but got 'object' is a subtype of '{object_supertype}'"
)

# check cycles
cycle = find_cycle(type_dict) # type: ignore
if cycle is not None:
raise PDDLValidationError(
"cycle detected in the type hierarchy: " + " -> ".join(cycle)
)
from pddl.parser.symbols import ALL_SYMBOLS


def _find_inconsistencies_in_typed_terms(
Expand Down
129 changes: 106 additions & 23 deletions pddl/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,23 @@
from enum import Enum
from typing import AbstractSet, Collection, Dict, Optional, Sequence, Set, cast

from pddl._validation import (
_check_constant_types,
_check_types_dictionary,
_check_types_in_has_terms_objects,
)
from pddl._validation import _check_constant_types, _check_types_in_has_terms_objects
from pddl.custom_types import name as name_type
from pddl.custom_types import namelike, to_names_types
from pddl.custom_types import namelike, to_names, to_names_types # noqa: F401
from pddl.exceptions import PDDLValidationError
from pddl.helpers.base import (
_typed_parameters,
assert_,
ensure,
ensure_sequence,
ensure_set,
find_cycle,
)
from pddl.logic.base import Formula, TrueFormula, is_literal
from pddl.logic.predicates import DerivedPredicate, Predicate
from pddl.logic.terms import Constant, Term, Variable
from pddl.parser.symbols import RequirementSymbols as RS
from pddl.parser.symbols import Symbols


class Domain:
Expand Down Expand Up @@ -68,30 +67,19 @@ def __init__(
"""
self._name = name_type(name)
self._requirements = ensure_set(requirements)
self._types = to_names_types(ensure(types, dict()))
self._types = _Types(types, self._requirements)
self._constants = ensure_set(constants)
self._predicates = ensure_set(predicates)
self._derived_predicates = ensure_set(derived_predicates)
self._actions = ensure_set(actions)

self._all_types_set = self._get_all_types()

self._check_consistency()

def _get_all_types(self) -> Set[name_type]:
"""Get all types supported by this domain."""
if self._types is None:
return set()
result = set(self._types.keys()) | set(self._types.values())
result.discard(None)
return cast(Set[name_type], result)

def _check_consistency(self) -> None:
"""Check consistency of a domain instance object."""
_check_types_dictionary(self._types)
_check_constant_types(self._constants, self._all_types_set)
_check_types_in_has_terms_objects(self._predicates, self._all_types_set)
_check_types_in_has_terms_objects(self._actions, self._all_types_set) # type: ignore
_check_constant_types(self._constants, self._types.all_types)
_check_types_in_has_terms_objects(self._predicates, self._types.all_types)
_check_types_in_has_terms_objects(self._actions, self._types.all_types) # type: ignore
self._check_types_in_derived_predicates()

def _check_types_in_derived_predicates(self) -> None:
Expand All @@ -101,7 +89,7 @@ def _check_types_in_derived_predicates(self) -> None:
if self._derived_predicates
else set()
)
_check_types_in_has_terms_objects(dp_list, self._all_types_set)
_check_types_in_has_terms_objects(dp_list, self._types.all_types)

@property
def name(self) -> str:
Expand Down Expand Up @@ -136,7 +124,7 @@ def actions(self) -> AbstractSet["Action"]:
@property
def types(self) -> Dict[name_type, Optional[name_type]]:
"""Get the type definitions, if defined. Else, raise error."""
return self._types
return self._types.raw

def __eq__(self, other):
"""Compare with another object."""
Expand Down Expand Up @@ -388,3 +376,98 @@ def __lt__(self, other):
return self.value <= other.value
else:
return super().__lt__(other)


class _Types:
"""A class for representing and managing the types available in a PDDL Domain."""

def __init__(
self,
types: Optional[Dict[namelike, Optional[namelike]]] = None,
requirements: Optional[AbstractSet[Requirements]] = None,
) -> None:
"""Initialize the Types object."""
self._types = to_names_types(ensure(types, dict()))

self._all_types = self._get_all_types()
self._check_types_dictionary(self._types, ensure_set(requirements))

@property
def raw(self) -> Dict[name_type, Optional[name_type]]:
"""Get the raw types dictionary."""
return self._types

@property
def all_types(self) -> Set[name_type]:
"""Get all available types."""
return self._all_types

def _get_all_types(self) -> Set[name_type]:
"""Get all types supported by the domain."""
if self._types is None:
return set()
result = set(self._types.keys()) | set(self._types.values())
result.discard(None)
return cast(Set[name_type], result)

@classmethod
def _check_types_dictionary(
cls,
type_dict: Dict[name_type, Optional[name_type]],
requirements: AbstractSet[Requirements],
) -> None:
"""
Check the consistency of the types dictionary.
1) Empty types dictionary is correct by definition:
>>> _Types._check_types_dictionary({}, set())
2) There are supertypes, but :typing requirement not specified
>>> a, b, c = to_names(["a", "b", "c"])
>>> _Types._check_types_dictionary({a: b, b: c}, set())
Traceback (most recent call last):
...
pddl.exceptions.PDDLValidationError: typing requirement is not specified, but types are used: 'b', 'c'
3) The `object` type cannot be a subtype:
>>> a = name_type("a")
>>> _Types._check_types_dictionary({name_type("object"): a}, {Requirements.TYPING})
Traceback (most recent call last):
...
pddl.exceptions.PDDLValidationError: object must not have supertypes, but got 'object' is a subtype of 'a'
4) If cycles in the type hierarchy graph are present, an error is raised:
>>> a, b, c = to_names(["a", "b", "c"])
>>> _Types._check_types_dictionary({a: b, b: c, c: a}, {Requirements.TYPING})
Traceback (most recent call last):
...
pddl.exceptions.PDDLValidationError: cycle detected in the type hierarchy: a -> b -> c
:param type_dict: the types dictionary
"""
if len(type_dict) == 0:
return

# check typing requirement
supertypes = {t for t in type_dict.values() if t is not None}
if len(supertypes) > 0 and Requirements.TYPING not in requirements:
raise PDDLValidationError(
"typing requirement is not specified, but types are used: '"
+ "', '".join(map(str, sorted(supertypes)))
+ "'"
)

# check `object` type
object_name = name_type(Symbols.OBJECT.value)
if object_name in type_dict and type_dict[object_name] is not None:
object_supertype = type_dict[object_name]
raise PDDLValidationError(
f"object must not have supertypes, but got 'object' is a subtype of '{object_supertype}'"
)

# check cycles
cycle = find_cycle(type_dict) # type: ignore
if cycle is not None:
raise PDDLValidationError(
"cycle detected in the type hierarchy: " + " -> ".join(cycle)
)
10 changes: 7 additions & 3 deletions tests/test_domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import pytest

from pddl.core import Action, Domain
from pddl.core import Action, Domain, Requirements
from pddl.exceptions import PDDLValidationError
from pddl.logic import Constant, Variable
from pddl.logic.base import Not, TrueFormula
Expand Down Expand Up @@ -88,7 +88,11 @@ def test_cycles_in_type_defs_not_allowed() -> None:
with pytest.raises(
PDDLValidationError, match="cycle detected in the type hierarchy: A -> B -> C"
):
Domain("dummy", types={"A": "B", "B": "C", "C": "A"})
Domain(
"dummy",
requirements={Requirements.TYPING},
types={"A": "B", "B": "C", "C": "A"},
)


def test_object_must_not_be_subtype() -> None:
Expand All @@ -100,7 +104,7 @@ def test_object_must_not_be_subtype() -> None:
PDDLValidationError,
match=rf"object must not have supertypes, but got 'object' is a subtype of '{my_type}'",
):
Domain("test", types=type_set) # type: ignore
Domain("test", requirements={Requirements.TYPING}, types=type_set) # type: ignore


def test_constants_type_not_available() -> None:
Expand Down

0 comments on commit fada331

Please sign in to comment.