Skip to content

Commit

Permalink
Merge b75bf78 into 8ce4722
Browse files Browse the repository at this point in the history
  • Loading branch information
smohsinali committed Feb 9, 2017
2 parents 8ce4722 + b75bf78 commit 3d21d32
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 98 deletions.
156 changes: 88 additions & 68 deletions ConfigSpace/conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from abc import ABCMeta, abstractmethod
from itertools import combinations
from typing import Any, List, Union
import operator

import io
Expand All @@ -40,43 +41,43 @@ class ConditionComponent(object):
__metaclass__ = ABCMeta

@abstractmethod
def __init__(self):
def __init__(self) -> None:
pass

@abstractmethod
def __repr__(self):
def __repr__(self) -> str:
pass

@abstractmethod
def get_children(self):
def get_children(self) -> List['ConditionComponent']:
pass

@abstractmethod
def get_parents(self):
def get_parents(self) -> List['ConditionComponent']:
pass

@abstractmethod
def get_descendant_literal_conditions(self):
def get_descendant_literal_conditions(self) ->List['AbstractCondition']:
pass

@abstractmethod
def evaluate(self, instantiated_parent_hyperparameter):
def evaluate(self, instantiated_parent_hyperparameter: Hyperparameter) -> bool:
pass

# http://stackoverflow.com/a/25176504/4636294
def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
"""Override the default Equals behavior"""
if isinstance(other, self.__class__):
return self.__dict__ == other.__dict__
return NotImplemented

def __ne__(self, other):
def __ne__(self, other: Any) -> bool:
"""Define a non-equality test"""
if isinstance(other, self.__class__):
return not self.__eq__(other)
return NotImplemented

def __hash__(self):
def __hash__(self) -> int:
"""Override the default hash behavior (that returns the id or the object)"""
return hash(tuple(sorted(self.__dict__.items())))

Expand All @@ -85,7 +86,7 @@ class AbstractCondition(ConditionComponent):
# TODO create a condition evaluator!

@abstractmethod
def __init__(self, child, parent):
def __init__(self, child: Hyperparameter, parent: Hyperparameter) -> None:
if not isinstance(child, Hyperparameter):
raise ValueError("Argument 'child' is not an instance of "
"HPOlibConfigSpace.hyperparameter.Hyperparameter.")
Expand All @@ -98,34 +99,35 @@ def __init__(self, child, parent):
self.child = child
self.parent = parent

def get_children(self):
def get_children(self) -> List[Hyperparameter]:
return [self.child]

def get_parents(self):
def get_parents(self) -> List[Hyperparameter]:
return [self.parent]

def get_descendant_literal_conditions(self):
def get_descendant_literal_conditions(self) -> List['AbstractCondition']:
return [self]

def evaluate(self, instantiated_parent_hyperparameter):
def evaluate(self, instantiated_parent_hyperparameter: Hyperparameter) -> bool:
hp_name = self.parent.name
return self._evaluate(instantiated_parent_hyperparameter[hp_name])

@abstractmethod
def _evaluate(self, instantiated_parent_hyperparameter):
def _evaluate(self, instantiated_parent_hyperparameter: Union[str, int, float]) -> bool:
pass


class AbstractConjunction(ConditionComponent):
def __init__(self, *args):
def __init__(self, *args: AbstractCondition) -> None:
super(AbstractConjunction, self).__init__()
self.components = args

# Test the classes
for idx, component in enumerate(self.components):
if not isinstance(component, ConditionComponent):
raise TypeError("Argument #%d is not an instance of %s, "
"but %s" % (
idx, ConditionComponent, type(component)))
idx, ConditionComponent, type(component)))

# Test that all conjunctions and conditions have the same child!
children = self.get_children()
Expand All @@ -134,28 +136,28 @@ def __init__(self, *args):
raise ValueError("All Conjunctions and Conditions must have "
"the same child.")

def get_descendant_literal_conditions(self):
children = []
def get_descendant_literal_conditions(self) -> List[AbstractCondition]:
children = [] # type: List[AbstractCondition]
for component in self.components:
if isinstance(component, AbstractConjunction):
children.extend(component.get_descendant_literal_conditions())
else:
children.append(component)
return children

def get_children(self):
children = []
def get_children(self) -> List[ConditionComponent]:
children = [] # type: List[ConditionComponent]
for component in self.components:
children.extend(component.get_children())
return children

def get_parents(self):
parents = []
def get_parents(self) -> List[ConditionComponent]:
parents = [] # type: List[ConditionComponent]
for component in self.components:
parents.extend(component.get_parents())
return parents

def evaluate(self, instantiated_hyperparameters):
def evaluate(self, instantiated_hyperparameters: Hyperparameter) -> bool:
# Then, check if all parents were passed
conditions = self.get_descendant_literal_conditions()
for condition in conditions:
Expand All @@ -175,12 +177,12 @@ def evaluate(self, instantiated_hyperparameters):
return self._evaluate(evaluations)

@abstractmethod
def _evaluate(self, evaluations):
def _evaluate(self, evaluations: List[bool]) -> bool:
pass


class EqualsCondition(AbstractCondition):
def __init__(self, child, parent, value):
def __init__(self, child: Hyperparameter, parent: Hyperparameter, value: Union[str, float, int]) -> None:
super(EqualsCondition, self).__init__(child, parent)
if not parent.is_legal(value):
raise ValueError("Hyperparameter '%s' is "
Expand All @@ -189,16 +191,23 @@ def __init__(self, child, parent, value):
(child.name, value, parent.name))
self.value = value

def __repr__(self):
def __repr__(self) -> str:
return "%s | %s == %s" % (self.child.name, self.parent.name,
repr(self.value))

def _evaluate(self, value):
return value == self.value
def _evaluate(self, value: Union[str, float, int]) -> bool:
if not self.parent.is_legal(value):
return False

cmp = self.parent.compare(value, self.value)
if cmp == 0:
return True
else:
return False


class NotEqualsCondition(AbstractCondition):
def __init__(self, child, parent, value):
def __init__(self, child: Hyperparameter, parent: Hyperparameter, value: Union[str, float, int]) -> None:
super(NotEqualsCondition, self).__init__(child, parent)
if not parent.is_legal(value):
raise ValueError("Hyperparameter '%s' is "
Expand All @@ -207,63 +216,74 @@ def __init__(self, child, parent, value):
(child.name, value, parent.name))
self.value = value

def __repr__(self):
def __repr__(self) -> str:
return "%s | %s != %s" % (self.child.name, self.parent.name,
repr(self.value))

def _evaluate(self, value):
return value != self.value

def _evaluate(self, value: Union[str, float, int]) -> bool:
if not self.parent.is_legal(value):
return False

cmp = self.parent.compare(value, self.value)
if cmp != 0:
return True
else:
return False


class LessThanCondition(AbstractCondition):
def __init__(self, child, parent, value):
if not isinstance(parent, (NumericalHyperparameter,
OrdinalHyperparameter)):
raise ValueError("Parent hyperparameter in a < condition must "
"be a subclass of NumericalHyperparameter or "
"OrdinalHyperparameter, but is %s" % type(parent))
def __init__(self, child: Hyperparameter, parent: Hyperparameter, value: Union[str, float, int]) -> None:
super(LessThanCondition, self).__init__(child, parent)
self.parent.allow_greater_less_comparison()
if not parent.is_legal(value):
raise ValueError("Hyperparameter '%s' is "
"conditional on the illegal value '%s' of "
"its parent hyperparameter '%s'" %
(child.name, value, parent.name))
self.value = value
def __repr__(self):

def __repr__(self) -> str:
return "%s | %s < %s" % (self.child.name, self.parent.name,
repr(self.value))
def _evaluate(self, value):
if value is None:
repr(self.value))

def _evaluate(self, value: Union[str, float, int]) -> bool:
if not self.parent.is_legal(value):
return False

cmp = self.parent.compare(value, self.value)
if cmp == -1:
return True
else:
return value < self.value

return False


class GreaterThanCondition(AbstractCondition):
def __init__(self, child, parent, value):
if not isinstance(parent, (NumericalHyperparameter,
OrdinalHyperparameter)):
raise ValueError("Parent hyperparameter in a > condition must "
"be a subclass of NumericalHyperparameter or "
"OrdinalHyperparameter, but is %s" % type(parent))
def __init__(self, child: Hyperparameter, parent: Hyperparameter, value: Union[str, float, int]) -> None:
super(GreaterThanCondition, self).__init__(child, parent)
self.parent.allow_greater_less_comparison()
if not parent.is_legal(value):
raise ValueError("Hyperparameter '%s' is "
"conditional on the illegal value '%s' of "
"its parent hyperparameter '%s'" %
(child.name, value, parent.name))
self.value = value
def __repr__(self):

def __repr__(self) -> str:
return "%s | %s > %s" % (self.child.name, self.parent.name,
repr(self.value))
def _evaluate(self, value):
if value is None:
repr(self.value))

def _evaluate(self, value: Union[str, float, int]) -> bool:
if not self.parent.is_legal(value):
return False

cmp = self.parent.compare(value, self.value)
if cmp == 1:
return True
else:
return value > self.value
return False

class InCondition(AbstractCondition):
def __init__(self, child, parent, values):
def __init__(self, child: Hyperparameter, parent: Hyperparameter, values: List[Union[str, float, int]]) -> None:
super(InCondition, self).__init__(child, parent)
for value in values:
if not parent.is_legal(value):
Expand All @@ -273,25 +293,25 @@ def __init__(self, child, parent, values):
(child.name, value, parent.name))
self.values = values

def __repr__(self):
def __repr__(self) -> str:
return "%s | %s in {%s}" % (self.child.name, self.parent.name,
", ".join(
[repr(value) for value in self.values]))

def _evaluate(self, value):
def _evaluate(self, value: Union[str, float, int]) -> bool:
return value in self.values


class AndConjunction(AbstractConjunction):
# TODO: test if an AndConjunction results in an illegal state or a
# Tautology! -> SAT solver
def __init__(self, *args):
def __init__(self, *args: AbstractCondition) -> None:
if len(args) < 2:
raise ValueError("AndConjunction must at least have two "
"Conditions.")
super(AndConjunction, self).__init__(*args)

def __repr__(self):
def __repr__(self) -> str:
retval = io.StringIO()
retval.write("(")
for idx, component in enumerate(self.components):
Expand All @@ -301,18 +321,18 @@ def __repr__(self):
retval.write(")")
return retval.getvalue()

def _evaluate(self, evaluations):
def _evaluate(self, evaluations: Any) -> bool:
return reduce(operator.and_, evaluations)


class OrConjunction(AbstractConjunction):
def __init__(self, *args):
def __init__(self, *args: AbstractCondition) -> None:
if len(args) < 2:
raise ValueError("OrConjunction must at least have two "
"Conditions.")
super(OrConjunction, self).__init__(*args)

def __repr__(self):
def __repr__(self) -> str:
retval = io.StringIO()
retval.write("(")
for idx, component in enumerate(self.components):
Expand All @@ -322,5 +342,5 @@ def __repr__(self):
retval.write(")")
return retval.getvalue()

def _evaluate(self, evaluations):
def _evaluate(self, evaluations: Any) -> bool:
return reduce(operator.or_, evaluations)

0 comments on commit 3d21d32

Please sign in to comment.