Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 23 additions & 21 deletions ConfigSpace/forbidden.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,13 @@
import io
from functools import reduce

from ConfigSpace.hyperparameters import Hyperparameter

from ConfigSpace.hyperparameters import Hyperparameter
from typing import List, Dict, Any, Union

class AbstractForbiddenComponent(object):
__metaclass__ = ABCMeta

hyperparameter = None # type: Hyperparameter
@abstractmethod
def __init__(self):
pass
Expand All @@ -47,19 +48,19 @@ def __repr__(self):
pass

# http://stackoverflow.com/a/25176504/4636294
def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
"""Override the default Equals behavior"""
if isinstance(other, self.__class__):
return self.__dict__ == other.__dict__
return NotImplemented

def __ne__(self, other):
def __ne__(self, other: Any) -> bool:
"""Define a non-equality test"""
if isinstance(other, self.__class__):
return not self.__eq__(other)
return NotImplemented

def __hash__(self):
def __hash__(self) -> int:
"""Override the default hash behavior (that returns the id or the object)"""
return hash(tuple(sorted(self.__dict__.items())))

Expand All @@ -68,17 +69,17 @@ def get_descendant_literal_clauses(self):
pass

@abstractmethod
def is_forbidden(self, instantiated_hyperparameters):
def is_forbidden(self, instantiated_hyperparameters, strict):
pass


class AbstractForbiddenClause(AbstractForbiddenComponent):
def get_descendant_literal_clauses(self):
def get_descendant_literal_clauses(self) -> List[AbstractForbiddenComponent]:
return [self]


class SingleValueForbiddenClause(AbstractForbiddenClause):
def __init__(self, hyperparameter, value):
def __init__(self, hyperparameter: Hyperparameter, value: Any) -> None:
super(SingleValueForbiddenClause, self).__init__()
if not isinstance(hyperparameter, Hyperparameter):
raise TypeError("'%s' is not of type %s." %
Expand All @@ -90,7 +91,7 @@ def __init__(self, hyperparameter, value):
"'%s'" % (hyperparameter, str(value)))
self.value = value

def is_forbidden(self, instantiated_hyperparameters, strict=True):
def is_forbidden(self, instantiated_hyperparameters: Dict[str, Union[None, str, float, int]], strict: bool=True) -> bool:
value = instantiated_hyperparameters.get(self.hyperparameter.name)

if value is None:
Expand All @@ -110,7 +111,7 @@ def _is_forbidden(self, target_instantiated_hyperparameter):


class MultipleValueForbiddenClause(AbstractForbiddenClause):
def __init__(self, hyperparameter, values):
def __init__(self, hyperparameter: Hyperparameter, values: Any) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above.

super(MultipleValueForbiddenClause, self).__init__()
if not isinstance(hyperparameter, Hyperparameter):
raise TypeError("Argument 'hyperparameter' is not of type %s." %
Expand All @@ -123,7 +124,7 @@ def __init__(self, hyperparameter, values):
"'%s'" % (hyperparameter, str(value)))
self.values = values

def is_forbidden(self, instantiated_hyperparameters, strict=True):
def is_forbidden(self, instantiated_hyperparameters: Dict[str, Union[None, str, float, int]], strict: bool=True) -> bool:
value = instantiated_hyperparameters.get(self.hyperparameter.name)

if value is None:
Expand All @@ -143,31 +144,31 @@ def _is_forbidden(self, target_instantiated_hyperparameter):


class ForbiddenEqualsClause(SingleValueForbiddenClause):
def __repr__(self):
def __repr__(self) -> str:
return "Forbidden: %s == %s" % (self.hyperparameter.name,
repr(self.value))

def _is_forbidden(self, value):
def _is_forbidden(self, value: Any) -> bool:
return value == self.value


class ForbiddenInClause(MultipleValueForbiddenClause):
def __init__(self, hyperparameter, values):
def __init__(self, hyperparameter: Dict[str, Union[None, str, float, int]], values: Any) -> None:
super(ForbiddenInClause, self).__init__(hyperparameter, values)
self.values = set(self.values)

def __repr__(self):
def __repr__(self) -> str:
return "Forbidden: %s in %s" % (
self.hyperparameter.name,
"{" + ", ".join((repr(value)
for value in sorted(self.values))) + "}")

def _is_forbidden(self, value):
def _is_forbidden(self, value: Any) -> bool:
return value in self.values


class AbstractForbiddenConjunction(AbstractForbiddenComponent):
def __init__(self, *args):
def __init__(self, *args: AbstractForbiddenComponent) -> None:
super(AbstractForbiddenConjunction, self).__init__()
# Test the classes
for idx, component in enumerate(args):
Expand All @@ -183,7 +184,8 @@ def __init__(self, *args):
def __repr__(self):
pass

def get_descendant_literal_clauses(self):
# todo:recheck is return type should be AbstractForbiddenComponent or AbstractForbiddenConjunction or Hyperparameter
def get_descendant_literal_clauses(self) -> List[AbstractForbiddenComponent]:
children = []
for component in self.components:
if isinstance(component, AbstractForbiddenConjunction):
Expand All @@ -192,7 +194,7 @@ def get_descendant_literal_clauses(self):
children.append(component)
return children

def is_forbidden(self, instantiated_hyperparameters, strict=True):
def is_forbidden(self, instantiated_hyperparameters: Dict[str, Union[None, str, float, int]], strict: bool=True) -> bool:
ihp_names = list(instantiated_hyperparameters.keys())

dlcs = self.get_descendant_literal_clauses()
Expand Down Expand Up @@ -222,7 +224,7 @@ def _is_forbidden(self, evaluations):


class ForbiddenAndConjunction(AbstractForbiddenConjunction):
def __repr__(self):
def __repr__(self) -> str:
retval = io.StringIO()
retval.write("(")
for idx, component in enumerate(self.components):
Expand All @@ -232,5 +234,5 @@ def __repr__(self):
retval.write(")")
return retval.getvalue()

def _is_forbidden(self, evaluations):
def _is_forbidden(self, evaluations: List[bool]) -> bool:
return reduce(operator.and_, evaluations)