diff --git a/ConfigSpace/configuration_space.py b/ConfigSpace/configuration_space.py index a8fd0afd..44bc283a 100644 --- a/ConfigSpace/configuration_space.py +++ b/ConfigSpace/configuration_space.py @@ -30,13 +30,15 @@ import copy import numpy as np +# import six import io -from functools import reduce + import ConfigSpace.nx from ConfigSpace.hyperparameters import Hyperparameter, Constant, FloatHyperparameter from ConfigSpace.conditions import ConditionComponent, \ AbstractCondition, AbstractConjunction, EqualsCondition from ConfigSpace.forbidden import AbstractForbiddenComponent +from typing import Union, List, Any, Dict, Iterable, Set class ConfigurationSpace(object): @@ -48,10 +50,10 @@ class ConfigurationSpace(object): """Represent a configuration space. """ - def __init__(self, seed=None): - self._hyperparameters = OrderedDict() - self._hyperparameter_idx = dict() - self._idx_to_hyperparameter = dict() + def __init__(self, seed: Union[int, None] = None) -> None: + self._hyperparameters = OrderedDict() # type: OrderedDict[str, Hyperparameter] + self._hyperparameter_idx = dict() # type: Dict[str, int] + self._idx_to_hyperparameter = dict() # type: Dict[int, str] # Use dictionaries to make sure that we don't accidently add # additional keys to these mappings (which happened with defaultdict()). @@ -59,23 +61,23 @@ def __init__(self, seed=None): # spaces when _children of one instance contained all possible # hyperparameters as keys and empty dictionaries as values while the # other instance not containing these. - self._children = OrderedDict() - self._parents = OrderedDict() + self._children = OrderedDict() # type: OrderedDict[str, OrderedDict[str, Union[None, AbstractCondition]]] + self._parents = OrderedDict() # type: OrderedDict[str, OrderedDict[str, Union[None, AbstractCondition]]] # changing this to a normal dict will break sampling because there is # no guarantee that the parent of a condition was evaluated before - self._conditionsals = OrderedDict() - self.forbidden_clauses = [] + self._conditionsals = [] # type: List[str] + self.forbidden_clauses = [] # type: List['AbstractForbiddenComponent'] self.random = np.random.RandomState(seed) self._children['__HPOlib_configuration_space_root__'] = OrderedDict() - def generate_all_continuous_from_bounds(self, bounds): - for i,(l,u) in enumerate(bounds): + def generate_all_continuous_from_bounds(self, bounds: List[List[Any]]) -> None: + for i, (l, u) in enumerate(bounds): hp = ConfigSpace.UniformFloatHyperparameter('x%d' % i, l, u) self.add_hyperparameter(hp) - def add_hyperparameter(self, hyperparameter): + def add_hyperparameter(self, hyperparameter: Hyperparameter) -> Hyperparameter: """Add a hyperparameter to the configuration space. Parameters @@ -115,7 +117,7 @@ def add_hyperparameter(self, hyperparameter): return hyperparameter - def add_condition(self, condition): + def add_condition(self, condition: ConditionComponent) -> ConditionComponent: # Check if adding the condition is legal: # * The parent in a condition statement must exist # * The condition must add no cycles @@ -145,7 +147,7 @@ def add_condition(self, condition): raise Exception("This should never happen!") return condition - def _add_edge(self, parent_node, child_node, condition): + def _add_edge(self, parent_node: str, child_node: str, condition: AbstractCondition) -> None: self._check_edge(parent_node, child_node, condition) try: # TODO maybe this has to be done more carefully @@ -162,9 +164,9 @@ def _add_edge(self, parent_node, child_node, condition): self._children[parent_node][child_node] = condition self._parents[child_node][parent_node] = condition self._sort_hyperparameters() - self._conditionsals[child_node] = child_node + self._conditionsals.append(child_node) - def _check_edge(self, parent_node, child_node, condition): + def _check_edge(self, parent_node: str, child_node: str, condition: AbstractCondition) -> None: # check if both nodes are already inserted into the graph if child_node not in self._hyperparameters: raise ValueError("Child hyperparameter '%s' not in configuration " @@ -180,7 +182,7 @@ def _check_edge(self, parent_node, child_node, condition): tmp_dag.add_edge(parent_node, child_node) if not ConfigSpace.nx.is_directed_acyclic_graph(tmp_dag): - cycles = list(ConfigSpace.nx.simple_cycles(tmp_dag)) + cycles = list(ConfigSpace.nx.simple_cycles(tmp_dag)) # type: List[List[str]] for cycle in cycles: cycle.sort() cycles.sort() @@ -195,9 +197,9 @@ def _check_edge(self, parent_node, child_node, condition): "instead!\nAlready inserted: %s\nNew one: " "%s" % (str(other_condition), str(condition))) - def _sort_hyperparameters(self): - levels = OrderedDict() - to_visit = deque() + def _sort_hyperparameters(self) -> None: + levels = OrderedDict() # type: OrderedDict[str, int] + to_visit = deque() # type: ignore for hp_name in self._hyperparameters: to_visit.appendleft(hp_name) @@ -222,7 +224,7 @@ def _sort_hyperparameters(self): else: to_visit.appendleft(current) - by_level = defaultdict(list) + by_level = defaultdict(list) # type: defaultdict[int, List[str]] for hp in levels: level = levels[hp] by_level[level].append(hp) @@ -238,17 +240,17 @@ def _sort_hyperparameters(self): del self._hyperparameters[node] self._hyperparameters[node] = hp - hp = self._conditionsals.get(node) - if hp is not None: - del self._conditionsals[node] - self._conditionsals[node] = hp + # hp = self._conditionsals.get(node) + # if hp is not None: + # del self._conditionsals[node] + # self._conditionsals[node] = hp # Update to reflect sorting for i, hp in enumerate(self._hyperparameters): self._hyperparameter_idx[hp] = i self._idx_to_hyperparameter[i] = hp - def _create_tmp_dag(self): + def _create_tmp_dag(self) -> ConfigSpace.nx.DiGraph: tmp_dag = ConfigSpace.nx.DiGraph() for hp_name in self._hyperparameters: tmp_dag.add_node(hp_name) @@ -268,7 +270,7 @@ def _create_tmp_dag(self): return tmp_dag - def add_forbidden_clause(self, clause): + def add_forbidden_clause(self, clause: AbstractForbiddenComponent) -> AbstractForbiddenComponent: if not isinstance(clause, AbstractForbiddenComponent): raise TypeError("The method add_forbidden_clause must be called " "with an instance of " @@ -285,8 +287,8 @@ def add_forbidden_clause(self, clause): # HPOlibConfigSpace.nx.draw(self._dg, pos, with_labels=True) # plt.savefig('nx_test.png') - def add_configuration_space(self, prefix, configuration_space, - delimiter=":", parent_hyperparameter=None): + def add_configuration_space(self, prefix: str, configuration_space: 'ConfigurationSpace', + delimiter: str=":", parent_hyperparameter: Hyperparameter=None) -> 'ConfigurationSpace': if not isinstance(configuration_space, ConfigurationSpace): raise TypeError("The method add_configuration_space must be " "called with an instance of " @@ -349,10 +351,10 @@ def add_configuration_space(self, prefix, configuration_space, return configuration_space - def get_hyperparameters(self): + def get_hyperparameters(self) -> List[Hyperparameter]: return list(self._hyperparameters.values()) - def get_hyperparameter(self, name): + def get_hyperparameter(self, name: str) -> Hyperparameter: hp = self._hyperparameters.get(name) if hp is None: @@ -361,7 +363,7 @@ def get_hyperparameter(self, name): else: return hp - def get_hyperparameter_by_idx(self, idx): + def get_hyperparameter_by_idx(self, idx: int) -> str: hp = self._idx_to_hyperparameter.get(idx) if hp is None: @@ -370,7 +372,7 @@ def get_hyperparameter_by_idx(self, idx): else: return hp - def get_idx_by_hyperparameter_name(self, name): + def get_idx_by_hyperparameter_name(self, name: str) -> int: idx = self._hyperparameter_idx.get(name) if idx is None: @@ -379,9 +381,9 @@ def get_idx_by_hyperparameter_name(self, name): else: return idx - def get_conditions(self): + def get_conditions(self) -> List[AbstractCondition]: conditions = [] - added_conditions = set() + added_conditions = set() # type: Set[str] # Nodes is a list of nodes for source_node in self.get_hyperparameters(): @@ -396,16 +398,16 @@ def get_conditions(self): return conditions - def get_children_of(self, name): + def get_children_of(self, name: Hyperparameter) -> List[Hyperparameter]: conditions = self.get_child_conditions_of(name) - parents = [] + parents = [] # type: List[Hyperparameter] for condition in conditions: parents.extend(condition.get_children()) return parents - def get_child_conditions_of(self, name): + def get_child_conditions_of(self, name: Union[str, Hyperparameter]) -> List[AbstractCondition]: if isinstance(name, Hyperparameter): - name = name.name + name = name.name # type: ignore # This raises an exception if the hyperparameter does not exist self.get_hyperparameter(name) @@ -422,7 +424,7 @@ def get_child_conditions_of(self, name): if child_name != "__HPOlib_configuration_space_root__"] return conditions - def get_parents_of(self, name): + def get_parents_of(self, name: Union[str, Hyperparameter]) -> List[Hyperparameter]: """Return the parent hyperparameters of a given hyperparameter. Parameters @@ -437,41 +439,40 @@ def get_parents_of(self, name): List with all parent hyperparameters. """ conditions = self.get_parent_conditions_of(name) - parents = [] + parents = [] # type: List[Hyperparameter] for condition in conditions: parents.extend(condition.get_parents()) return parents - def get_parent_conditions_of(self, name): + def get_parent_conditions_of(self, name: Union[str, Hyperparameter]) -> List[AbstractCondition]: if isinstance(name, Hyperparameter): - name = name.name + name = name.name # type: ignore # This raises an exception if the hyperparameter does not exist self.get_hyperparameter(name) return self._get_parent_conditions_of(name) - def _get_parent_conditions_of(self, name): + def _get_parent_conditions_of(self, name: str) -> List[AbstractCondition]: parents = self._parents[name] conditions = [parents[parent_name] for parent_name in parents if parent_name != "__HPOlib_configuration_space_root__"] return conditions - def get_all_unconditional_hyperparameters(self): + def get_all_unconditional_hyperparameters(self) -> List[str]: hyperparameters = [hp_name for hp_name in self._children[ '__HPOlib_configuration_space_root__']] return hyperparameters - def get_all_conditional_hyperparameters(self): + def get_all_conditional_hyperparameters(self) -> List[str]: return self._conditionsals - def get_default_configuration(self): + def get_default_configuration(self) -> 'Configuration': return self._check_default_configuration() - def _check_default_configuration(self): - # Check if adding that hyperparameter leads to an illegal default - # configuration: - instantiated_hyperparameters = {} + def _check_default_configuration(self) -> 'Configuration': + # Check if adding that hyperparameter leads to an illegal default configuration + instantiated_hyperparameters = {} # type: Dict[str, Union[None, int, float, str]] for hp in self.get_hyperparameters(): conditions = self._get_parent_conditions_of(hp.name) active = True @@ -489,10 +490,6 @@ def _check_default_configuration(self): active = False if active == False: - # Condition evaluation must be called with all - # hyperparameters that are in the condition, even if they are - # inactive. For that, an inactive hyperparameter is assigned - # the value None instantiated_hyperparameters[hp.name] = None elif isinstance(hp, Constant): instantiated_hyperparameters[hp.name] = hp.value @@ -505,15 +502,15 @@ def _check_default_configuration(self): # configuration is forbidden! return Configuration(self, instantiated_hyperparameters) - def check_configuration(self, configuration): + def check_configuration(self, configuration: 'Configuration') -> None: if not isinstance(configuration, Configuration): raise TypeError("The method check_configuration must be called " - "with an instance of %s. " + "with an instance of %s. " "Your input was of type %s"% (Configuration, type(configuration))) self._check_configuration(configuration) - def _check_configuration(self, configuration, - allow_inactive_with_values=False): + def _check_configuration(self, configuration: 'Configuration', + allow_inactive_with_values: bool=False) -> None: for hp_name in self._hyperparameters: hyperparameter = self._hyperparameters[hp_name] hp_value = configuration[hp_name] @@ -557,14 +554,14 @@ def _check_configuration(self, configuration, (hp_name, hp_value)) self._check_forbidden(configuration) - def _check_forbidden(self, configuration): + def _check_forbidden(self, configuration: 'Configuration') -> None: for clause in self.forbidden_clauses: if clause.is_forbidden(configuration, strict=False): raise ValueError("%sviolates forbidden clause %s" % ( str(configuration), str(clause))) # http://stackoverflow.com/a/25176504/4636294 - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: """Override the default Equals behavior""" if isinstance(other, self.__class__): this_dict = self.__dict__.copy() @@ -574,17 +571,17 @@ def __eq__(self, other): return this_dict == other_dict return NotImplemented - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: """Define a non-equality test""" if isinstance(other, self.__class__): return not self.__eq__(other) return NotImplemented - def __hash__(self): + def __hash__(self) -> int: """Override the default hash behavior (that returns the id or the object)""" return hash(tuple(sorted(self.__dict__.items()))) - def __repr__(self): + def __repr__(self) -> str: retval = io.StringIO() retval.write("Configuration space object:\n Hyperparameters:\n") @@ -615,14 +612,14 @@ def __repr__(self): retval.seek(0) return retval.getvalue() - def __iter__(self): + def __iter__(self) -> Iterable: """ Allows to iterate over the hyperparameter names in (hopefully?) the right order.""" return iter(self._hyperparameters.keys()) - def sample_configuration(self, size=1): + def sample_configuration(self, size: int=1) -> Union['Configuration', List['Configuration']]: iteration = 0 missing = size - accepted_configurations = [] + accepted_configurations = [] # type: List['Configuration'] num_hyperparameters = len(self._hyperparameters) while len(accepted_configurations) < size: @@ -636,10 +633,10 @@ def sample_configuration(self, size=1): vector[:, i] = hyperparameter._sample(self.random, missing) for i in range(missing): - inactive = set() + inactive = set() # type: Set['str'] visited = set() visited.update(self.get_all_unconditional_hyperparameters()) - to_visit = deque() + to_visit = deque() # type: deque[str] to_visit.extendleft(self.get_all_conditional_hyperparameters()) infiniteloopcounter = 0 while len(to_visit) > 0: @@ -661,8 +658,7 @@ def sample_configuration(self, size=1): to_visit.appendleft(hp_name) break - parents = {parent_name: self._hyperparameters[ - parent_name]._transform(vector[i][ + parents = {parent_name: self._hyperparameters[parent_name]._transform(vector[i][ self._hyperparameter_idx[ parent_name]]) for parent_name in parent_names} @@ -702,15 +698,15 @@ def sample_configuration(self, size=1): else: return accepted_configurations - def seed(self, seed): + def seed(self, seed: int) -> None: self.random = np.random.RandomState(seed) class Configuration(object): - # TODO add a method to eliminate inactive hyperparameters from a - # configuration - def __init__(self, configuration_space, values=None, vector=None, - allow_inactive_with_values=False, origin=None): + # TODO add a method to eliminate inactive hyperparameters from a configuration + def __init__(self, configuration_space: ConfigurationSpace, values: Union[None, Dict[str, Union[str, float, int]]] = None, + vector: Union[None, np.ndarray]=None, allow_inactive_with_values: bool=False, origin: Any=None)\ + -> None: """A single configuration. Parameters @@ -741,7 +737,7 @@ def __init__(self, configuration_space, values=None, vector=None, self._query_values = False self._num_hyperparameters = len(self.configuration_space._hyperparameters) self.origin = origin - self._keys = None + self._keys = None # type: Union[None, List[str]] if values is not None and vector is not None: raise ValueError('Configuration specified both as dictionary and ' @@ -750,7 +746,7 @@ def __init__(self, configuration_space, values=None, vector=None, # Using cs._hyperparameters to iterate makes sure that the # hyperparameters in the configuration are sorted in the same way as # they are sorted in the configuration space - self._values = dict() + self._values = dict() # type: Dict[str, Union[str, float, int]] for key in configuration_space._hyperparameters: value = values.get(key) if value is None: @@ -790,11 +786,11 @@ def __init__(self, configuration_space, values=None, vector=None, raise ValueError('Configuration neither specified as dictionary ' 'or vector.') - def is_valid_configuration(self): + def is_valid_configuration(self) -> None: self.configuration_space._check_configuration( self, allow_inactive_with_values=self.allow_inactive_with_values) - def __getitem__(self, item): + def __getitem__(self, item: str) -> Any: if self._query_values or item in self._values: return self._values.get(item) @@ -809,18 +805,18 @@ def __getitem__(self, item): self._values[item] = value return self._values[item] - def get(self, item, default=None): + def get(self, item: str, default: Union[None, Any]=None) -> Union[None, Any]: try: return self[item] except: return default - def __contains__(self, item): + def __contains__(self, item: str) -> bool: self._populate_values() return item in self._values # http://stackoverflow.com/a/25176504/4636294 - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: """Override the default Equals behavior""" if isinstance(other, self.__class__): self._populate_values() @@ -829,24 +825,24 @@ def __eq__(self, other): self.configuration_space == other.configuration_space return NotImplemented - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: """Define a non-equality test""" if isinstance(other, self.__class__): return not self.__eq__(other) return NotImplemented - def __hash__(self): + def __hash__(self) -> int: """Override the default hash behavior (that returns the id or the object)""" self._populate_values() return hash(self.__repr__()) - def _populate_values(self): + def _populate_values(self) -> None: if self._query_values is False: for hyperparameter in self.configuration_space.get_hyperparameters(): self[hyperparameter.name] self._query_values = True - def __repr__(self): + def __repr__(self) -> str: self._populate_values() representation = io.StringIO() @@ -868,21 +864,20 @@ def __repr__(self): return representation.getvalue() - def __iter__(self): + def __iter__(self) -> Iterable: return iter(self.keys()) - def keys(self): + def keys(self) -> List[str]: # Cache the keys to speed up the process of retrieving the keys if self._keys is None: self._keys = list(self.configuration_space._hyperparameters.keys()) return self._keys - - def get_dictionary(self): + def get_dictionary(self) -> Dict[str, Union[str, float, int]]: self._populate_values() return self._values - def get_array(self): + def get_array(self) -> np.ndarray: """ Returns ------- diff --git a/test/test_configuration_space.py b/test/test_configuration_space.py index 8f2e6b7d..e5c87e2c 100644 --- a/test/test_configuration_space.py +++ b/test/test_configuration_space.py @@ -26,6 +26,7 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from collections import OrderedDict from itertools import product import json import sys @@ -641,10 +642,10 @@ def test_uniformfloat_transform(self): log=True)) for i in range(100): config = cs.sample_configuration() - value = config.get_dictionary() + value = OrderedDict(sorted(config.get_dictionary().items())) string = json.dumps(value) saved_value = json.loads(string) - saved_value = byteify(saved_value) + saved_value = OrderedDict(sorted(byteify(saved_value).items())) self.assertEqual(repr(value), repr(saved_value)) # Next, test whether the truncation also works when initializing the