diff --git a/ConfigSpace/conditions.py b/ConfigSpace/conditions.py index 7c592aeb..dee7bfdd 100644 --- a/ConfigSpace/conditions.py +++ b/ConfigSpace/conditions.py @@ -28,6 +28,7 @@ from abc import ABCMeta, abstractmethod from itertools import combinations +from typing import Any, List, Union import operator import io @@ -40,43 +41,43 @@ class ConditionComponent(object): __metaclass__ = ABCMeta @abstractmethod - def __init__(self): + def __init__(self) -> None: pass @abstractmethod - def __repr__(self): + def __repr__(self) -> str: pass @abstractmethod - def get_children(self): + def get_children(self) -> List['ConditionComponent']: pass @abstractmethod - def get_parents(self): + def get_parents(self) -> List['ConditionComponent']: pass @abstractmethod - def get_descendant_literal_conditions(self): + def get_descendant_literal_conditions(self) ->List['AbstractCondition']: pass @abstractmethod - def evaluate(self, instantiated_parent_hyperparameter): + def evaluate(self, instantiated_parent_hyperparameter: Hyperparameter) -> bool: pass # http://stackoverflow.com/a/25176504/4636294 - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: """Override the default Equals behavior""" if isinstance(other, self.__class__): return self.__dict__ == other.__dict__ return NotImplemented - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: """Define a non-equality test""" if isinstance(other, self.__class__): return not self.__eq__(other) return NotImplemented - def __hash__(self): + def __hash__(self) -> int: """Override the default hash behavior (that returns the id or the object)""" return hash(tuple(sorted(self.__dict__.items()))) @@ -85,7 +86,7 @@ class AbstractCondition(ConditionComponent): # TODO create a condition evaluator! @abstractmethod - def __init__(self, child, parent): + def __init__(self, child: Hyperparameter, parent: Hyperparameter) -> None: if not isinstance(child, Hyperparameter): raise ValueError("Argument 'child' is not an instance of " "HPOlibConfigSpace.hyperparameter.Hyperparameter.") @@ -98,26 +99,27 @@ def __init__(self, child, parent): self.child = child self.parent = parent - def get_children(self): + def get_children(self) -> List[Hyperparameter]: return [self.child] - def get_parents(self): + def get_parents(self) -> List[Hyperparameter]: return [self.parent] - def get_descendant_literal_conditions(self): + def get_descendant_literal_conditions(self) -> List['AbstractCondition']: return [self] - def evaluate(self, instantiated_parent_hyperparameter): + def evaluate(self, instantiated_parent_hyperparameter: Hyperparameter) -> bool: hp_name = self.parent.name return self._evaluate(instantiated_parent_hyperparameter[hp_name]) @abstractmethod - def _evaluate(self, instantiated_parent_hyperparameter): + def _evaluate(self, instantiated_parent_hyperparameter: Union[str, int, float]) -> bool: pass class AbstractConjunction(ConditionComponent): - def __init__(self, *args): + def __init__(self, *args: AbstractCondition) -> None: + super(AbstractConjunction, self).__init__() self.components = args # Test the classes @@ -125,7 +127,7 @@ def __init__(self, *args): if not isinstance(component, ConditionComponent): raise TypeError("Argument #%d is not an instance of %s, " "but %s" % ( - idx, ConditionComponent, type(component))) + idx, ConditionComponent, type(component))) # Test that all conjunctions and conditions have the same child! children = self.get_children() @@ -134,8 +136,8 @@ def __init__(self, *args): raise ValueError("All Conjunctions and Conditions must have " "the same child.") - def get_descendant_literal_conditions(self): - children = [] + def get_descendant_literal_conditions(self) -> List[AbstractCondition]: + children = [] # type: List[AbstractCondition] for component in self.components: if isinstance(component, AbstractConjunction): children.extend(component.get_descendant_literal_conditions()) @@ -143,19 +145,19 @@ def get_descendant_literal_conditions(self): children.append(component) return children - def get_children(self): - children = [] + def get_children(self) -> List[ConditionComponent]: + children = [] # type: List[ConditionComponent] for component in self.components: children.extend(component.get_children()) return children - def get_parents(self): - parents = [] + def get_parents(self) -> List[ConditionComponent]: + parents = [] # type: List[ConditionComponent] for component in self.components: parents.extend(component.get_parents()) return parents - def evaluate(self, instantiated_hyperparameters): + def evaluate(self, instantiated_hyperparameters: Hyperparameter) -> bool: # Then, check if all parents were passed conditions = self.get_descendant_literal_conditions() for condition in conditions: @@ -175,12 +177,12 @@ def evaluate(self, instantiated_hyperparameters): return self._evaluate(evaluations) @abstractmethod - def _evaluate(self, evaluations): + def _evaluate(self, evaluations: List[bool]) -> bool: pass class EqualsCondition(AbstractCondition): - def __init__(self, child, parent, value): + def __init__(self, child: Hyperparameter, parent: Hyperparameter, value: Union[str, float, int]) -> None: super(EqualsCondition, self).__init__(child, parent) if not parent.is_legal(value): raise ValueError("Hyperparameter '%s' is " @@ -189,16 +191,23 @@ def __init__(self, child, parent, value): (child.name, value, parent.name)) self.value = value - def __repr__(self): + def __repr__(self) -> str: return "%s | %s == %s" % (self.child.name, self.parent.name, repr(self.value)) - def _evaluate(self, value): - return value == self.value + def _evaluate(self, value: Union[str, float, int]) -> bool: + if not self.parent.is_legal(value): + return False + + cmp = self.parent.compare(value, self.value) + if cmp == 0: + return True + else: + return False class NotEqualsCondition(AbstractCondition): - def __init__(self, child, parent, value): + def __init__(self, child: Hyperparameter, parent: Hyperparameter, value: Union[str, float, int]) -> None: super(NotEqualsCondition, self).__init__(child, parent) if not parent.is_legal(value): raise ValueError("Hyperparameter '%s' is " @@ -207,63 +216,74 @@ def __init__(self, child, parent, value): (child.name, value, parent.name)) self.value = value - def __repr__(self): + def __repr__(self) -> str: return "%s | %s != %s" % (self.child.name, self.parent.name, repr(self.value)) - def _evaluate(self, value): - return value != self.value - + def _evaluate(self, value: Union[str, float, int]) -> bool: + if not self.parent.is_legal(value): + return False + + cmp = self.parent.compare(value, self.value) + if cmp != 0: + return True + else: + return False + + class LessThanCondition(AbstractCondition): - def __init__(self, child, parent, value): - if not isinstance(parent, (NumericalHyperparameter, - OrdinalHyperparameter)): - raise ValueError("Parent hyperparameter in a < condition must " - "be a subclass of NumericalHyperparameter or " - "OrdinalHyperparameter, but is %s" % type(parent)) + def __init__(self, child: Hyperparameter, parent: Hyperparameter, value: Union[str, float, int]) -> None: super(LessThanCondition, self).__init__(child, parent) + self.parent.allow_greater_less_comparison() if not parent.is_legal(value): raise ValueError("Hyperparameter '%s' is " "conditional on the illegal value '%s' of " "its parent hyperparameter '%s'" % (child.name, value, parent.name)) self.value = value - - def __repr__(self): + + def __repr__(self) -> str: return "%s | %s < %s" % (self.child.name, self.parent.name, - repr(self.value)) - def _evaluate(self, value): - if value is None: + repr(self.value)) + + def _evaluate(self, value: Union[str, float, int]) -> bool: + if not self.parent.is_legal(value): return False + + cmp = self.parent.compare(value, self.value) + if cmp == -1: + return True else: - return value < self.value - + return False + + class GreaterThanCondition(AbstractCondition): - def __init__(self, child, parent, value): - if not isinstance(parent, (NumericalHyperparameter, - OrdinalHyperparameter)): - raise ValueError("Parent hyperparameter in a > condition must " - "be a subclass of NumericalHyperparameter or " - "OrdinalHyperparameter, but is %s" % type(parent)) + def __init__(self, child: Hyperparameter, parent: Hyperparameter, value: Union[str, float, int]) -> None: super(GreaterThanCondition, self).__init__(child, parent) + self.parent.allow_greater_less_comparison() if not parent.is_legal(value): raise ValueError("Hyperparameter '%s' is " "conditional on the illegal value '%s' of " "its parent hyperparameter '%s'" % (child.name, value, parent.name)) self.value = value - - def __repr__(self): + + def __repr__(self) -> str: return "%s | %s > %s" % (self.child.name, self.parent.name, - repr(self.value)) - def _evaluate(self, value): - if value is None: + repr(self.value)) + + def _evaluate(self, value: Union[str, float, int]) -> bool: + if not self.parent.is_legal(value): return False + + cmp = self.parent.compare(value, self.value) + if cmp == 1: + return True else: - return value > self.value + return False class InCondition(AbstractCondition): - def __init__(self, child, parent, values): + def __init__(self, child: Hyperparameter, parent: Hyperparameter, values: List[Union[str, float, int]]) -> None: super(InCondition, self).__init__(child, parent) for value in values: if not parent.is_legal(value): @@ -273,25 +293,25 @@ def __init__(self, child, parent, values): (child.name, value, parent.name)) self.values = values - def __repr__(self): + def __repr__(self) -> str: return "%s | %s in {%s}" % (self.child.name, self.parent.name, ", ".join( [repr(value) for value in self.values])) - def _evaluate(self, value): + def _evaluate(self, value: Union[str, float, int]) -> bool: return value in self.values class AndConjunction(AbstractConjunction): # TODO: test if an AndConjunction results in an illegal state or a # Tautology! -> SAT solver - def __init__(self, *args): + def __init__(self, *args: AbstractCondition) -> None: if len(args) < 2: raise ValueError("AndConjunction must at least have two " "Conditions.") super(AndConjunction, self).__init__(*args) - def __repr__(self): + def __repr__(self) -> str: retval = io.StringIO() retval.write("(") for idx, component in enumerate(self.components): @@ -301,18 +321,18 @@ def __repr__(self): retval.write(")") return retval.getvalue() - def _evaluate(self, evaluations): + def _evaluate(self, evaluations: Any) -> bool: return reduce(operator.and_, evaluations) class OrConjunction(AbstractConjunction): - def __init__(self, *args): + def __init__(self, *args: AbstractCondition) -> None: if len(args) < 2: raise ValueError("OrConjunction must at least have two " "Conditions.") super(OrConjunction, self).__init__(*args) - def __repr__(self): + def __repr__(self) -> str: retval = io.StringIO() retval.write("(") for idx, component in enumerate(self.components): @@ -322,5 +342,5 @@ def __repr__(self): retval.write(")") return retval.getvalue() - def _evaluate(self, evaluations): + def _evaluate(self, evaluations: Any) -> bool: return reduce(operator.or_, evaluations) diff --git a/ConfigSpace/hyperparameters.py b/ConfigSpace/hyperparameters.py index 2a93ddae..715043b1 100644 --- a/ConfigSpace/hyperparameters.py +++ b/ConfigSpace/hyperparameters.py @@ -141,7 +141,7 @@ def has_neighbors(self) -> bool: def get_num_neighbors(self, value=None) -> int: return 0 - def get_neighbors(self, value: Any, rs: Any, number: int, transform: bool = False) -> List: + def get_neighbors(self, value: Any, rs: np.random.RandomState, number: int, transform: bool = False) -> List: return [] @@ -157,9 +157,22 @@ def __init__(self, name: str, default: Any) -> None: def has_neighbors(self) -> bool: return True - def get_num_neighbors(self, value=None) -> np.inf: + def get_num_neighbors(self, value=None) -> float: + return np.inf + def compare(self, value: Union[int, float, str], value2: Union[int, float, str]) -> int: + if value < value2: + return -1 + elif value > value2: + return 1 + elif value == value2: + return 0 + + def allow_greater_less_comparison(self) -> bool: + return True + + class FloatHyperparameter(NumericalHyperparameter): def __init__(self, name: str, default: Union[int, float]) -> None: @@ -308,7 +321,7 @@ def _inverse_transform(self, vector: Union[np.ndarray, None]) -> Union[float, np vector = np.log(vector) return (vector - self._lower) / (self._upper - self._lower) - def get_neighbors(self, value: Any, rs: np.random, number: int = 4, transform: bool = False) -> List[float]: + def get_neighbors(self, value: Any, rs: np.random.RandomState, number: int = 4, transform: bool = False) -> List[float]: neighbors = [] # type: List[float] while len(neighbors) < number: neighbor = rs.normal(value, 0.2) @@ -385,7 +398,7 @@ def to_integer(self) -> 'NormalIntegerHyperparameter': def is_legal(self, value: Union[float]) -> bool: return isinstance(value, float) or isinstance(value, int) - def _sample(self, rs: np.random, size: Union[None, int] = None) -> np.ndarray: + def _sample(self, rs: np.random.RandomState, size: Union[None, int] = None) -> np.ndarray: mu = self.mu sigma = self.sigma return rs.normal(mu, sigma, size=size) @@ -694,6 +707,12 @@ def __repr__(self) -> str: repr_str.seek(0) return repr_str.getvalue() + def compare(self, value: Union[int, float, str], value2: Union[int, float, str]) -> int: + if value == value2: + return 0 + else: + return 1 + def is_legal(self, value: Union[None, str, float, int]) -> bool: if value in self.choices: return True @@ -767,6 +786,13 @@ def get_neighbors(self, value: int, rs: np.random.RandomState, number: Union[int return neighbors + def allow_greater_less_comparison(self) -> bool: + raise ValueError("Parent hyperparameter in a > or < " + "condition must be a subclass of " + "NumericalHyperparameter or " + "OrdinalHyperparameter, but is " + "") + class OrdinalHyperparameter(Hyperparameter): def __init__(self, name: str, sequence: List[Union[float, int, str]], @@ -802,6 +828,14 @@ def __repr__(self) -> str: repr_str.seek(0) return repr_str.getvalue() + def compare(self, value: Union[int, float, str], value2: Union[int, float, str]) -> int: + if self.value_dict[value] < self.value_dict[value2]: + return -1 + elif self.value_dict[value] > self.value_dict[value2]: + return 1 + elif self.value_dict[value] == self.value_dict[value2]: + return 0 + def is_legal(self, value: Union[int, float, str]) -> bool: """ checks if a certain value is represented in the sequence @@ -915,3 +949,6 @@ def get_neighbors(self, value: Union[int, str, float], rs: None, number: int = 2 neighbors.append(neighbor_idx2) return neighbors + + def allow_greater_less_comparison(self) -> bool: + return True diff --git a/test/test_conditions.py b/test/test_conditions.py index 8b75b66e..46ff9b64 100644 --- a/test/test_conditions.py +++ b/test/test_conditions.py @@ -33,10 +33,11 @@ UniformFloatHyperparameter, NormalFloatHyperparameter, \ UniformIntegerHyperparameter, NormalIntegerHyperparameter, \ CategoricalHyperparameter, OrdinalHyperparameter -from ConfigSpace.conditions import EqualsCondition, NotEqualsCondition,\ - InCondition, AndConjunction, OrConjunction, LessThanCondition,\ +from ConfigSpace.conditions import EqualsCondition, NotEqualsCondition, \ + InCondition, AndConjunction, OrConjunction, LessThanCondition, \ GreaterThanCondition + class TestConditions(unittest.TestCase): # TODO: return only copies of the objects! def test_equals_condition(self): @@ -47,15 +48,15 @@ def test_equals_condition(self): # Test invalid conditions: self.assertRaisesRegexp(ValueError, "Argument 'parent' is not an " - "instance of HPOlibConfigSpace.hyperparameter." - "Hyperparameter.", EqualsCondition, hp2, + "instance of HPOlibConfigSpace.hyperparameter." + "Hyperparameter.", EqualsCondition, hp2, "parent", 0) self.assertRaisesRegexp(ValueError, "Argument 'child' is not an " - "instance of HPOlibConfigSpace.hyperparameter." - "Hyperparameter.", EqualsCondition, "child", + "instance of HPOlibConfigSpace.hyperparameter." + "Hyperparameter.", EqualsCondition, "child", hp1, 0) self.assertRaisesRegexp(ValueError, "The child and parent hyperparameter " - "must be different hyperparameters.", + "must be different hyperparameters.", EqualsCondition, hp1, hp1, 0) self.assertEqual(cond, cond_) @@ -71,12 +72,12 @@ def test_equals_condition_illegal_value(self): epsilon = UniformFloatHyperparameter("epsilon", 1e-5, 1e-1, default=1e-4, log=True) loss = CategoricalHyperparameter("loss", - ["hinge", "log", "modified_huber", "squared_hinge", "perceptron"], - default="hinge") + ["hinge", "log", "modified_huber", "squared_hinge", "perceptron"], + default="hinge") self.assertRaisesRegexp(ValueError, "Hyperparameter 'epsilon' is " - "conditional on the illegal value 'huber' of " - "its parent hyperparameter 'loss'", - EqualsCondition, epsilon, loss, "huber") + "conditional on the illegal value 'huber' of " + "its parent hyperparameter 'loss'", + EqualsCondition, epsilon, loss, "huber") def test_not_equals_condition(self): hp1 = CategoricalHyperparameter("parent", [0, 1]) @@ -136,18 +137,28 @@ def test_greater_and_less_condition(self): self.assertFalse(lt.evaluate({hp.name: None})) hp4 = CategoricalHyperparameter("cat", list(range(6))) - self.assertRaisesRegexp(ValueError, "Parent hyperparameter in a > " - "condition must be a subclass of " - "NumericalHyperparameter or " - "OrdinalHyperparameter, but is " - "", - GreaterThanCondition, child, hp4, 1) - self.assertRaisesRegexp(ValueError, "Parent hyperparameter in a < " - "condition must be a subclass of " - "NumericalHyperparameter or " - "OrdinalHyperparameter, but is " - "", - LessThanCondition, child, hp4, 1) + self.assertRaisesRegexp(ValueError, "Parent hyperparameter in a > or < " + "condition must be a subclass of " + "NumericalHyperparameter or " + "OrdinalHyperparameter, but is " + "", + GreaterThanCondition, child, hp4, 1) + self.assertRaisesRegexp(ValueError, "Parent hyperparameter in a > or < " + "condition must be a subclass of " + "NumericalHyperparameter or " + "OrdinalHyperparameter, but is " + "", + LessThanCondition, child, hp4, 1) + + hp5 = OrdinalHyperparameter("ord", ['cold', 'luke warm', 'warm', 'hot']) + + gt = GreaterThanCondition(child, hp5, 'warm') + self.assertTrue(gt.evaluate({hp5.name: 'hot'})) + self.assertFalse(gt.evaluate({hp5.name: 'cold'})) + + lt = LessThanCondition(child, hp5, 'warm') + self.assertTrue(lt.evaluate({hp5.name: 'luke warm'})) + self.assertFalse(lt.evaluate({hp5.name: 'warm'})) def test_in_condition_illegal_value(self): epsilon = UniformFloatHyperparameter("epsilon", 1e-5, 1e-1, @@ -191,7 +202,6 @@ def test_and_conjunction(self): self.assertNotEqual(andconj1, andconj3) self.assertNotEqual(andconj1, "String") - def test_or_conjunction(self): self.assertRaises(TypeError, AndConjunction, "String1", "String2") @@ -285,4 +295,3 @@ def test_get_parents(self): # All conjunctions inherit get_parents from abstractconjunction conjunction = AndConjunction(condition, condition2) self.assertEqual([_1_S_countercond, _1_0_restarts], conjunction.get_parents()) - \ No newline at end of file