diff --git a/pyproject.toml b/pyproject.toml index 21dd2eee..55a5c01b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -172,6 +172,7 @@ select = [ ignore = [ "T201", # TODO: Remove + "COM812", # Causes issues with ruff formatter "D100", "D104", # Missing docstring in public package "D105", # Missing docstring in magic mthod diff --git a/src/ConfigSpace/__init__.py b/src/ConfigSpace/__init__.py index bce613bd..0dcdb9bf 100644 --- a/src/ConfigSpace/__init__.py +++ b/src/ConfigSpace/__init__.py @@ -85,51 +85,51 @@ ) __all__ = [ - "__authors__", - "__version__", - "Configuration", - "ConfigurationSpace", - "CategoricalHyperparameter", - "UniformFloatHyperparameter", - "UniformIntegerHyperparameter", + "ActiveHyperparameterNotSetError", + "AmbiguousConditionError", + "AndConjunction", + "Beta", "BetaFloatHyperparameter", "BetaIntegerHyperparameter", - "NormalFloatHyperparameter", - "NormalIntegerHyperparameter", + "Categorical", + "CategoricalHyperparameter", + "ChildNotFoundError", + "Configuration", + "ConfigurationSpace", "Constant", - "UnParametrizedHyperparameter", - "OrdinalHyperparameter", - "AndConjunction", - "OrConjunction", + "CyclicDependancyError", + "Distribution", "EqualsCondition", - "NotEqualsCondition", - "InCondition", - "GreaterThanCondition", - "LessThanCondition", + "Float", "ForbiddenAndConjunction", "ForbiddenEqualsClause", - "ForbiddenInClause", - "ForbiddenLessThanRelation", "ForbiddenEqualsRelation", "ForbiddenGreaterThanRelation", - "Beta", - "Categorical", - "Distribution", - "Float", + "ForbiddenInClause", + "ForbiddenLessThanRelation", + "ForbiddenValueError", + "GreaterThanCondition", + "HyperparameterAlreadyExistsError", + "HyperparameterIndexError", + "HyperparameterNotFoundError", + "IllegalValueError", + "InCondition", + "InactiveHyperparameterSetError", "Integer", + "LessThanCondition", "Normal", + "NormalFloatHyperparameter", + "NormalIntegerHyperparameter", + "NotEqualsCondition", + "OrConjunction", + "OrdinalHyperparameter", + "ParentNotFoundError", + "UnParametrizedHyperparameter", "Uniform", + "UniformFloatHyperparameter", + "UniformIntegerHyperparameter", + "__authors__", + "__version__", "distributions", "types", - "ForbiddenValueError", - "IllegalValueError", - "ActiveHyperparameterNotSetError", - "InactiveHyperparameterSetError", - "HyperparameterNotFoundError", - "ChildNotFoundError", - "ParentNotFoundError", - "HyperparameterIndexError", - "AmbiguousConditionError", - "HyperparameterAlreadyExistsError", - "CyclicDependancyError", ] diff --git a/src/ConfigSpace/_condition_tree.py b/src/ConfigSpace/_condition_tree.py index 30fbac9e..b2c2b537 100644 --- a/src/ConfigSpace/_condition_tree.py +++ b/src/ConfigSpace/_condition_tree.py @@ -41,7 +41,7 @@ import numpy as np from more_itertools import unique_everseen -from ConfigSpace.conditions import Condition, Conjunction +from ConfigSpace.conditions import Condition, ConditionLike, Conjunction from ConfigSpace.exceptions import ( AmbiguousConditionError, ChildNotFoundError, @@ -62,7 +62,6 @@ from ConfigSpace.types import f64 if TYPE_CHECKING: - from ConfigSpace.conditions import ConditionLike from ConfigSpace.hyperparameters import Hyperparameter from ConfigSpace.types import Array @@ -558,6 +557,7 @@ def _parents(_f: ForbiddenLike) -> list[str]: conditions = [] for node in self.nodes.values(): if node.parent_condition is not None: + # print(node.name) conditions.append(node.parent_condition) self.conditions = list(unique_everseen(conditions, key=str)) @@ -658,6 +658,109 @@ def add(self, hp: Hyperparameter) -> None: self.nodes[hp.name] = node self.roots[hp.name] = node + def remove(self, value: Hyperparameter) -> None: + """Remove a hyperparameter from the DAG.""" + if not self._updating: + raise RuntimeError( + "Cannot remove hyperparameters outside of transaction." + "Please use `remove` inside `with dag.transaction():`", + ) + + existing = self.nodes.get(value.name, None) + if existing is None: + raise HyperparameterNotFoundError( + f"Hyperparameter '{value.name}' does not exist in space.", + ) + + # Update each condition containing this hyperparameter + def remove_hyperparameter_from_condition( + target: Conjunction | Condition | ForbiddenRelation | ForbiddenClause, + ) -> ( + Conjunction + | Condition + | ForbiddenClause + | ForbiddenRelation + | ForbiddenConjunction + | None + ): + if isinstance(target, ForbiddenRelation) and ( + value in (target.left, target.right) + ): + return None + if isinstance(target, ForbiddenClause) and target.hyperparameter == value: + return None + if isinstance(target, Condition) and ( + value in (target.parent, target.child) + ): + return None + if isinstance(target, (Conjunction, ForbiddenConjunction)): + new_components = [] + for component in target.components: + new_component = remove_hyperparameter_from_condition(component) + if new_component is not None: + new_components.append(new_component) + if len(new_components) >= 2: # Can create a conjunction + return type(target)(*new_components) + if len(new_components) == 1: # Only one component remains + return new_components[0] + return None # No components remain + return target # Nothing to change + + # Update each of the forbiddens containing this hyperparameter + for findex, forbidden in enumerate(self.unconditional_forbiddens): + self.unconditional_forbiddens[findex] = ( + remove_hyperparameter_from_condition(forbidden) + ) + for findex, forbidden in enumerate(self.conditional_forbiddens): + self.conditional_forbiddens[findex] = remove_hyperparameter_from_condition( + forbidden + ) + # Filter None values from the forbiddens + self.unconditional_forbiddens = [ + f for f in self.unconditional_forbiddens if f is not None + ] + self.conditional_forbiddens = [ + f for f in self.conditional_forbiddens if f is not None + ] + + for node in self.nodes.values(): + if node.parent_condition is None: + continue + node.parent_condition = remove_hyperparameter_from_condition( + node.parent_condition + ) + + self.nodes.pop(value.name) + for child, _ in existing.children.values(): + del child.parents[existing.name] + + # Recalculate the depth of the children + def mark_children_recursively(node: HPNode, marked: set[str]): + for child, _ in node.children.values(): + if child.maximum_depth == node.maximum_depth + 1: + marked.add(child.name) + mark_children_recursively(child, marked) + + marked_nodes: set[str] = set() + mark_children_recursively(existing, marked_nodes) + while marked_nodes: # Update the maximum depth of the marked nodes + remove = [] + for node_name in marked_nodes: + node = self.nodes.get(node_name) + if not node.parents: + # print("Parentless node:", node.name) + node.maximum_depth = 0 + remove.append(node_name) + elif all(p.name not in marked_nodes for p, _ in node.parents.values()): + # print("New maximum depth node:", node.name, node.parents) + node.maximum_depth = ( + max(parent.maximum_depth for parent, _ in node.parents.values()) + + 1 + ) + remove.append(node_name) + for node_name in remove: + marked_nodes.remove(node_name) + def add_condition(self, condition: ConditionLike) -> None: """Add a condition to the DAG.""" if not self._updating: @@ -784,6 +887,7 @@ def _minimum_conditions(self) -> list[ConditionNode]: for node in self.nodes.values(): # This node has no parent as is a root if node.parent_condition is None: + # print(self.roots.keys()) assert node.name in self.roots continue diff --git a/src/ConfigSpace/api/__init__.py b/src/ConfigSpace/api/__init__.py index e1ea41b9..473cafe5 100644 --- a/src/ConfigSpace/api/__init__.py +++ b/src/ConfigSpace/api/__init__.py @@ -3,13 +3,13 @@ from ConfigSpace.api.types import Categorical, Float, Integer __all__ = [ - "types", - "distributions", "Beta", - "Distribution", - "Normal", - "Uniform", "Categorical", + "Distribution", "Float", "Integer", + "Normal", + "Uniform", + "distributions", + "types", ] diff --git a/src/ConfigSpace/configuration.py b/src/ConfigSpace/configuration.py index 1b556f7f..ed080b8b 100644 --- a/src/ConfigSpace/configuration.py +++ b/src/ConfigSpace/configuration.py @@ -73,10 +73,10 @@ def __init__( ConfigSpace package. """ if ( - values is not None - and vector is not None - or values is None - and vector is None + (values is not None + and vector is not None) + or (values is None + and vector is None) ): raise ValueError( "Specify Configuration as either a dictionary or a vector.", diff --git a/src/ConfigSpace/configuration_space.py b/src/ConfigSpace/configuration_space.py index c0c0ecc6..e38e65f8 100644 --- a/src/ConfigSpace/configuration_space.py +++ b/src/ConfigSpace/configuration_space.py @@ -350,6 +350,38 @@ def _put_to_list( self._len = len(self._dag.nodes) self._check_default_configuration() + def remove( + self, + *args: Hyperparameter, + ) -> None: + """Remove a hyperparameter from the configuration space. + + If the hyperparameter has children, the children are also removed. + This includes defined conditions and conjunctions! + + !!! note + + If removing multiple hyperparameters, it is better to remove them all + at once with one call to `remove()`, as we rebuilt a cache after each + call to `remove()`. + + Args: + args: Hyperparameter(s) to remove + """ + hps = [] + for arg in args: + if isinstance(arg, Hyperparameter): + hps.append(arg) + else: + raise TypeError(f"Unknown type {type(arg)}") + + with self._dag.update(): + for hp in hps: + self._dag.remove(hp) + + self._len = len(self._dag.nodes) + self._check_default_configuration() + def add_configuration_space( self, prefix: str, @@ -878,7 +910,7 @@ def __iter__(self) -> Iterator[str]: return iter(self._dag.nodes.keys()) def items(self) -> ItemsView[str, Hyperparameter]: - """Return an items view of the hyperparameters, same as `dict.items()`.""" # noqa: D402 + """Return an items view of the hyperparameters, same as `dict.items()`.""" return {name: node.hp for name, node in self._dag.nodes.items()}.items() def __len__(self) -> int: diff --git a/src/ConfigSpace/hyperparameters/__init__.py b/src/ConfigSpace/hyperparameters/__init__.py index 75a67a14..545703fa 100644 --- a/src/ConfigSpace/hyperparameters/__init__.py +++ b/src/ConfigSpace/hyperparameters/__init__.py @@ -24,7 +24,7 @@ "NormalIntegerHyperparameter", "NumericalHyperparameter", "OrdinalHyperparameter", + "UnParametrizedHyperparameter", "UniformFloatHyperparameter", "UniformIntegerHyperparameter", - "UnParametrizedHyperparameter", ] diff --git a/src/ConfigSpace/read_and_write/dictionary.py b/src/ConfigSpace/read_and_write/dictionary.py index 360b874e..22163f52 100644 --- a/src/ConfigSpace/read_and_write/dictionary.py +++ b/src/ConfigSpace/read_and_write/dictionary.py @@ -56,7 +56,7 @@ def _backwards_compat(item: dict[str, Any]) -> dict[str, Any]: ) if (default := item.pop("default", None)) is not None: warnings.warn( - "The field 'default' should be 'default_value' !" f"\nFound in item {item}", + f"The field 'default' should be 'default_value' !\nFound in item {item}", stacklevel=3, ) item["default_value"] = default diff --git a/test/test_configuration_space.py b/test/test_configuration_space.py index 2748942a..c4fb2d8f 100644 --- a/test/test_configuration_space.py +++ b/test/test_configuration_space.py @@ -82,6 +82,110 @@ def test_add(): cs.add(hp) +def test_remove(): + cs = ConfigurationSpace() + hp = UniformIntegerHyperparameter("name", 0, 10) + hp2 = UniformFloatHyperparameter("name2", 0, 10) + hp3 = CategoricalHyperparameter( + "weather", ["dry", "rainy", "snowy"], default_value="dry" + ) + cs.add(hp, hp2, hp3) + cs.remove(hp) + assert len(cs) == 2 + + # Test multi removal + cs.add(hp) + cs.remove(hp, hp2) + assert len(cs) == 1 + + # Test faulty input + with pytest.raises(TypeError): + cs.remove(object()) + + # Non existant HP + with pytest.raises(HyperparameterNotFoundError): + cs.remove(hp) + + cs.add(hp, hp2) + # Test one correct one faulty, nothing should happen + with pytest.raises(TypeError): + cs.remove(hp, object()) + assert len(cs) == 3 + + # Make hp2 a conditional parameter, the condition should also be removed when hp is removed + cond = EqualsCondition(hp, hp2, 1) + cs.add(cond) + cs.remove(hp) + assert len(cs) == 2 + assert cs.conditional_hyperparameters == [] + assert cs.conditions == [] + + # Set up forbidden relation, the relation should also be removed + forb = ForbiddenEqualsClause(hp3, "snowy") + cs.add(forb) + cs.remove(hp3) + assert len(cs) == 1 + assert cs.forbidden_clauses == [] + + # And now for more complicated conditions + cs = ConfigurationSpace() + hp1 = CategoricalHyperparameter("input1", [0, 1]) + cs.add(hp1) + hp2 = CategoricalHyperparameter("input2", [0, 1]) + cs.add(hp2) + hp3 = CategoricalHyperparameter("input3", [0, 1]) + cs.add(hp3) + hp4 = CategoricalHyperparameter("input4", [0, 1]) + cs.add(hp4) + hp5 = CategoricalHyperparameter("input5", [0, 1]) + cs.add(hp5) + hp6 = Constant("constant1", "True") + cs.add(hp6) + + cond1 = EqualsCondition(hp6, hp1, 1) + cond2 = NotEqualsCondition(hp6, hp2, 1) + cond3 = InCondition(hp6, hp3, [1]) + cond4 = EqualsCondition(hp6, hp4, 1) + cond5 = EqualsCondition(hp6, hp5, 1) + + conj1 = AndConjunction(cond1, cond2) + conj2 = OrConjunction(conj1, cond3) + conj3 = AndConjunction(conj2, cond4, cond5) + cs.add(conj3) + + cs.remove(hp3) + assert len(cs) == 5 + # Only one part of the condition should be removed, not the entire condition + assert len(cs.conditional_hyperparameters) == 1 + assert len(cs.conditions) == 1 + # Test the exact value + assert ( + str(cs.conditions[0]) + == "((constant1 | input1 == 1 && constant1 | input2 != 1) && constant1 | input4 == 1 && constant1 | input5 == 1)" + ) + + # Now more complicated forbiddens + cs = ConfigurationSpace() + cs.add([hp1, hp2, hp3, hp4, hp5, hp6]) + cs.add(conj3) + + forb1 = ForbiddenEqualsClause(hp1, 1) + forb2 = ForbiddenAndConjunction(forb1, ForbiddenEqualsClause(hp2, 1)) + forb3 = ForbiddenAndConjunction(forb2, ForbiddenEqualsClause(hp3, 1)) + forb4 = ForbiddenEqualsClause(hp3, 1) + forb5 = ForbiddenEqualsClause(hp4, 1) + cs.add(forb3, forb4, forb5) + + cs.remove(hp3) + assert len(cs) == 5 + assert len(cs.forbidden_clauses) == 2 + assert ( + str(cs.forbidden_clauses[0]) + == "(Forbidden: input1 == 1 && Forbidden: input2 == 1)" + ) + assert str(cs.forbidden_clauses[1]) == "Forbidden: input4 == 1" + + def test_add_non_hyperparameter(): cs = ConfigurationSpace() with pytest.raises(TypeError): @@ -423,7 +527,7 @@ def test_get_conditions(): cs.add(hp1) hp2 = UniformIntegerHyperparameter("child", 0, 10) cs.add(hp2) - assert [] == cs.conditions + assert cs.conditions == [] cond1 = EqualsCondition(hp2, hp1, 0) cs.add(cond1) assert [cond1] == cs.conditions