diff --git a/ConfigSpace/forbidden.py b/ConfigSpace/forbidden.py index ce7dadd2..0fab0aa2 100644 --- a/ConfigSpace/forbidden.py +++ b/ConfigSpace/forbidden.py @@ -32,12 +32,13 @@ import io from functools import reduce -from ConfigSpace.hyperparameters import Hyperparameter +from ConfigSpace.hyperparameters import Hyperparameter +from typing import List, Dict, Any, Union class AbstractForbiddenComponent(object): __metaclass__ = ABCMeta - + hyperparameter = None # type: Hyperparameter @abstractmethod def __init__(self): pass @@ -47,19 +48,19 @@ def __repr__(self): pass # http://stackoverflow.com/a/25176504/4636294 - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: """Override the default Equals behavior""" if isinstance(other, self.__class__): return self.__dict__ == other.__dict__ return NotImplemented - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: """Define a non-equality test""" if isinstance(other, self.__class__): return not self.__eq__(other) return NotImplemented - def __hash__(self): + def __hash__(self) -> int: """Override the default hash behavior (that returns the id or the object)""" return hash(tuple(sorted(self.__dict__.items()))) @@ -68,17 +69,17 @@ def get_descendant_literal_clauses(self): pass @abstractmethod - def is_forbidden(self, instantiated_hyperparameters): + def is_forbidden(self, instantiated_hyperparameters, strict): pass class AbstractForbiddenClause(AbstractForbiddenComponent): - def get_descendant_literal_clauses(self): + def get_descendant_literal_clauses(self) -> List[AbstractForbiddenComponent]: return [self] class SingleValueForbiddenClause(AbstractForbiddenClause): - def __init__(self, hyperparameter, value): + def __init__(self, hyperparameter: Hyperparameter, value: Any) -> None: super(SingleValueForbiddenClause, self).__init__() if not isinstance(hyperparameter, Hyperparameter): raise TypeError("'%s' is not of type %s." % @@ -90,7 +91,7 @@ def __init__(self, hyperparameter, value): "'%s'" % (hyperparameter, str(value))) self.value = value - def is_forbidden(self, instantiated_hyperparameters, strict=True): + def is_forbidden(self, instantiated_hyperparameters: Dict[str, Union[None, str, float, int]], strict: bool=True) -> bool: value = instantiated_hyperparameters.get(self.hyperparameter.name) if value is None: @@ -110,7 +111,7 @@ def _is_forbidden(self, target_instantiated_hyperparameter): class MultipleValueForbiddenClause(AbstractForbiddenClause): - def __init__(self, hyperparameter, values): + def __init__(self, hyperparameter: Hyperparameter, values: Any) -> None: super(MultipleValueForbiddenClause, self).__init__() if not isinstance(hyperparameter, Hyperparameter): raise TypeError("Argument 'hyperparameter' is not of type %s." % @@ -123,7 +124,7 @@ def __init__(self, hyperparameter, values): "'%s'" % (hyperparameter, str(value))) self.values = values - def is_forbidden(self, instantiated_hyperparameters, strict=True): + def is_forbidden(self, instantiated_hyperparameters: Dict[str, Union[None, str, float, int]], strict: bool=True) -> bool: value = instantiated_hyperparameters.get(self.hyperparameter.name) if value is None: @@ -143,31 +144,31 @@ def _is_forbidden(self, target_instantiated_hyperparameter): class ForbiddenEqualsClause(SingleValueForbiddenClause): - def __repr__(self): + def __repr__(self) -> str: return "Forbidden: %s == %s" % (self.hyperparameter.name, repr(self.value)) - def _is_forbidden(self, value): + def _is_forbidden(self, value: Any) -> bool: return value == self.value class ForbiddenInClause(MultipleValueForbiddenClause): - def __init__(self, hyperparameter, values): + def __init__(self, hyperparameter: Dict[str, Union[None, str, float, int]], values: Any) -> None: super(ForbiddenInClause, self).__init__(hyperparameter, values) self.values = set(self.values) - def __repr__(self): + def __repr__(self) -> str: return "Forbidden: %s in %s" % ( self.hyperparameter.name, "{" + ", ".join((repr(value) for value in sorted(self.values))) + "}") - def _is_forbidden(self, value): + def _is_forbidden(self, value: Any) -> bool: return value in self.values class AbstractForbiddenConjunction(AbstractForbiddenComponent): - def __init__(self, *args): + def __init__(self, *args: AbstractForbiddenComponent) -> None: super(AbstractForbiddenConjunction, self).__init__() # Test the classes for idx, component in enumerate(args): @@ -183,7 +184,8 @@ def __init__(self, *args): def __repr__(self): pass - def get_descendant_literal_clauses(self): + # todo:recheck is return type should be AbstractForbiddenComponent or AbstractForbiddenConjunction or Hyperparameter + def get_descendant_literal_clauses(self) -> List[AbstractForbiddenComponent]: children = [] for component in self.components: if isinstance(component, AbstractForbiddenConjunction): @@ -192,7 +194,7 @@ def get_descendant_literal_clauses(self): children.append(component) return children - def is_forbidden(self, instantiated_hyperparameters, strict=True): + def is_forbidden(self, instantiated_hyperparameters: Dict[str, Union[None, str, float, int]], strict: bool=True) -> bool: ihp_names = list(instantiated_hyperparameters.keys()) dlcs = self.get_descendant_literal_clauses() @@ -222,7 +224,7 @@ def _is_forbidden(self, evaluations): class ForbiddenAndConjunction(AbstractForbiddenConjunction): - def __repr__(self): + def __repr__(self) -> str: retval = io.StringIO() retval.write("(") for idx, component in enumerate(self.components): @@ -232,5 +234,5 @@ def __repr__(self): retval.write(")") return retval.getvalue() - def _is_forbidden(self, evaluations): + def _is_forbidden(self, evaluations: List[bool]) -> bool: return reduce(operator.and_, evaluations)