Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Typecheck conditions p3 #17

Merged
merged 16 commits into from Feb 23, 2017
Merged
156 changes: 88 additions & 68 deletions ConfigSpace/conditions.py
Expand Up @@ -28,6 +28,7 @@

from abc import ABCMeta, abstractmethod
from itertools import combinations
from typing import Any, List, Union
import operator

import io
Expand All @@ -40,43 +41,43 @@ class ConditionComponent(object):
__metaclass__ = ABCMeta

@abstractmethod
def __init__(self):
def __init__(self) -> None:
pass

@abstractmethod
def __repr__(self):
def __repr__(self) -> str:
pass

@abstractmethod
def get_children(self):
def get_children(self) -> List['ConditionComponent']:
pass

@abstractmethod
def get_parents(self):
def get_parents(self) -> List['ConditionComponent']:
pass

@abstractmethod
def get_descendant_literal_conditions(self):
def get_descendant_literal_conditions(self) ->List['AbstractCondition']:
pass

@abstractmethod
def evaluate(self, instantiated_parent_hyperparameter):
def evaluate(self, instantiated_parent_hyperparameter: Hyperparameter) -> bool:
pass

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please annotate all abstract methods as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

# http://stackoverflow.com/a/25176504/4636294
def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
"""Override the default Equals behavior"""
if isinstance(other, self.__class__):
return self.__dict__ == other.__dict__
return NotImplemented

def __ne__(self, other):
def __ne__(self, other: Any) -> bool:
"""Define a non-equality test"""
if isinstance(other, self.__class__):
return not self.__eq__(other)
return NotImplemented

def __hash__(self):
def __hash__(self) -> int:
"""Override the default hash behavior (that returns the id or the object)"""
return hash(tuple(sorted(self.__dict__.items())))

Expand All @@ -85,7 +86,7 @@ class AbstractCondition(ConditionComponent):
# TODO create a condition evaluator!

@abstractmethod
def __init__(self, child, parent):
def __init__(self, child: Hyperparameter, parent: Hyperparameter) -> None:
if not isinstance(child, Hyperparameter):
raise ValueError("Argument 'child' is not an instance of "
"HPOlibConfigSpace.hyperparameter.Hyperparameter.")
Expand All @@ -98,34 +99,35 @@ def __init__(self, child, parent):
self.child = child
self.parent = parent

def get_children(self):
def get_children(self) -> List[Hyperparameter]:
return [self.child]

def get_parents(self):
def get_parents(self) -> List[Hyperparameter]:
return [self.parent]

def get_descendant_literal_conditions(self):
def get_descendant_literal_conditions(self) -> List['AbstractCondition']:
return [self]

def evaluate(self, instantiated_parent_hyperparameter):
def evaluate(self, instantiated_parent_hyperparameter: Hyperparameter) -> bool:
hp_name = self.parent.name
return self._evaluate(instantiated_parent_hyperparameter[hp_name])

@abstractmethod
def _evaluate(self, instantiated_parent_hyperparameter):
def _evaluate(self, instantiated_parent_hyperparameter: Union[str, int, float]) -> bool:
pass


class AbstractConjunction(ConditionComponent):
def __init__(self, *args):
def __init__(self, *args: AbstractCondition) -> None:
super(AbstractConjunction, self).__init__()
self.components = args

# Test the classes
for idx, component in enumerate(self.components):
if not isinstance(component, ConditionComponent):
raise TypeError("Argument #%d is not an instance of %s, "
"but %s" % (
idx, ConditionComponent, type(component)))
idx, ConditionComponent, type(component)))

# Test that all conjunctions and conditions have the same child!
children = self.get_children()
Expand All @@ -134,28 +136,28 @@ def __init__(self, *args):
raise ValueError("All Conjunctions and Conditions must have "
"the same child.")

def get_descendant_literal_conditions(self):
children = []
def get_descendant_literal_conditions(self) -> List[AbstractCondition]:
children = [] # type: List[AbstractCondition]
for component in self.components:
if isinstance(component, AbstractConjunction):
children.extend(component.get_descendant_literal_conditions())
else:
children.append(component)
return children

def get_children(self):
children = []
def get_children(self) -> List[ConditionComponent]:
children = [] # type: List[ConditionComponent]
for component in self.components:
children.extend(component.get_children())
return children

def get_parents(self):
parents = []
def get_parents(self) -> List[ConditionComponent]:
parents = [] # type: List[ConditionComponent]
for component in self.components:
parents.extend(component.get_parents())
return parents

def evaluate(self, instantiated_hyperparameters):
def evaluate(self, instantiated_hyperparameters: Hyperparameter) -> bool:
# Then, check if all parents were passed
conditions = self.get_descendant_literal_conditions()
for condition in conditions:
Expand All @@ -175,12 +177,12 @@ def evaluate(self, instantiated_hyperparameters):
return self._evaluate(evaluations)

@abstractmethod
def _evaluate(self, evaluations):
def _evaluate(self, evaluations: List[bool]) -> bool:
pass


class EqualsCondition(AbstractCondition):
def __init__(self, child, parent, value):
def __init__(self, child: Hyperparameter, parent: Hyperparameter, value: Union[str, float, int]) -> None:
super(EqualsCondition, self).__init__(child, parent)
if not parent.is_legal(value):
raise ValueError("Hyperparameter '%s' is "
Expand All @@ -189,16 +191,23 @@ def __init__(self, child, parent, value):
(child.name, value, parent.name))
self.value = value

def __repr__(self):
def __repr__(self) -> str:
return "%s | %s == %s" % (self.child.name, self.parent.name,
repr(self.value))

def _evaluate(self, value):
return value == self.value
def _evaluate(self, value: Union[str, float, int]) -> bool:
if not self.parent.is_legal(value):
return False

cmp = self.parent.compare(value, self.value)
if cmp == 0:
return True
else:
return False


class NotEqualsCondition(AbstractCondition):
def __init__(self, child, parent, value):
def __init__(self, child: Hyperparameter, parent: Hyperparameter, value: Union[str, float, int]) -> None:
super(NotEqualsCondition, self).__init__(child, parent)
if not parent.is_legal(value):
raise ValueError("Hyperparameter '%s' is "
Expand All @@ -207,63 +216,74 @@ def __init__(self, child, parent, value):
(child.name, value, parent.name))
self.value = value

def __repr__(self):
def __repr__(self) -> str:
return "%s | %s != %s" % (self.child.name, self.parent.name,
repr(self.value))

def _evaluate(self, value):
return value != self.value

def _evaluate(self, value: Union[str, float, int]) -> bool:
if not self.parent.is_legal(value):
return False

cmp = self.parent.compare(value, self.value)
if cmp != 0:
return True
else:
return False


class LessThanCondition(AbstractCondition):
def __init__(self, child, parent, value):
if not isinstance(parent, (NumericalHyperparameter,
OrdinalHyperparameter)):
raise ValueError("Parent hyperparameter in a < condition must "
"be a subclass of NumericalHyperparameter or "
"OrdinalHyperparameter, but is %s" % type(parent))
def __init__(self, child: Hyperparameter, parent: Hyperparameter, value: Union[str, float, int]) -> None:
super(LessThanCondition, self).__init__(child, parent)
self.parent.allow_greater_less_comparison()
if not parent.is_legal(value):
raise ValueError("Hyperparameter '%s' is "
"conditional on the illegal value '%s' of "
"its parent hyperparameter '%s'" %
(child.name, value, parent.name))
self.value = value
def __repr__(self):

def __repr__(self) -> str:
return "%s | %s < %s" % (self.child.name, self.parent.name,
repr(self.value))
def _evaluate(self, value):
if value is None:
repr(self.value))

def _evaluate(self, value: Union[str, float, int]) -> bool:
if not self.parent.is_legal(value):
return False

cmp = self.parent.compare(value, self.value)
if cmp == -1:
return True
else:
return value < self.value

return False


class GreaterThanCondition(AbstractCondition):
def __init__(self, child, parent, value):
if not isinstance(parent, (NumericalHyperparameter,
OrdinalHyperparameter)):
raise ValueError("Parent hyperparameter in a > condition must "
"be a subclass of NumericalHyperparameter or "
"OrdinalHyperparameter, but is %s" % type(parent))
def __init__(self, child: Hyperparameter, parent: Hyperparameter, value: Union[str, float, int]) -> None:
super(GreaterThanCondition, self).__init__(child, parent)
self.parent.allow_greater_less_comparison()
if not parent.is_legal(value):
raise ValueError("Hyperparameter '%s' is "
"conditional on the illegal value '%s' of "
"its parent hyperparameter '%s'" %
(child.name, value, parent.name))
self.value = value
def __repr__(self):

def __repr__(self) -> str:
return "%s | %s > %s" % (self.child.name, self.parent.name,
repr(self.value))
def _evaluate(self, value):
if value is None:
repr(self.value))

def _evaluate(self, value: Union[str, float, int]) -> bool:
if not self.parent.is_legal(value):
return False

cmp = self.parent.compare(value, self.value)
if cmp == 1:
return True
else:
return value > self.value
return False

class InCondition(AbstractCondition):
def __init__(self, child, parent, values):
def __init__(self, child: Hyperparameter, parent: Hyperparameter, values: List[Union[str, float, int]]) -> None:
super(InCondition, self).__init__(child, parent)
for value in values:
if not parent.is_legal(value):
Expand All @@ -273,25 +293,25 @@ def __init__(self, child, parent, values):
(child.name, value, parent.name))
self.values = values

def __repr__(self):
def __repr__(self) -> str:
return "%s | %s in {%s}" % (self.child.name, self.parent.name,
", ".join(
[repr(value) for value in self.values]))

def _evaluate(self, value):
def _evaluate(self, value: Union[str, float, int]) -> bool:
return value in self.values


class AndConjunction(AbstractConjunction):
# TODO: test if an AndConjunction results in an illegal state or a
# Tautology! -> SAT solver
def __init__(self, *args):
def __init__(self, *args: AbstractCondition) -> None:
if len(args) < 2:
raise ValueError("AndConjunction must at least have two "
"Conditions.")
super(AndConjunction, self).__init__(*args)

def __repr__(self):
def __repr__(self) -> str:
retval = io.StringIO()
retval.write("(")
for idx, component in enumerate(self.components):
Expand All @@ -301,18 +321,18 @@ def __repr__(self):
retval.write(")")
return retval.getvalue()

def _evaluate(self, evaluations):
def _evaluate(self, evaluations: Any) -> bool:
return reduce(operator.and_, evaluations)


class OrConjunction(AbstractConjunction):
def __init__(self, *args):
def __init__(self, *args: AbstractCondition) -> None:
if len(args) < 2:
raise ValueError("OrConjunction must at least have two "
"Conditions.")
super(OrConjunction, self).__init__(*args)

def __repr__(self):
def __repr__(self) -> str:
retval = io.StringIO()
retval.write("(")
for idx, component in enumerate(self.components):
Expand All @@ -322,5 +342,5 @@ def __repr__(self):
retval.write(")")
return retval.getvalue()

def _evaluate(self, evaluations):
def _evaluate(self, evaluations: Any) -> bool:
return reduce(operator.or_, evaluations)