Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add assert_contains to print a more precise error message #1013

Merged
merged 16 commits into from
Dec 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 45 additions & 12 deletions src/orion/algo/space/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
import copy
import logging
import numbers
from dataclasses import dataclass, field
from distutils.log import error
from functools import singledispatch
from typing import Any, Generic, TypeVar

Expand Down Expand Up @@ -1095,6 +1097,44 @@ def __setitem__(self, key, value):
)
super().__setitem__(key, value)

def assert_contains(self, trial):
"""Same as __contains__ but instead of return true or false it will raise an exception
with the exact causes of the mismatch.

Raises
------
ValueError if the trial has parameters that are not contained by the space.

"""
if isinstance(trial, str):
if not super().__contains__(trial):
raise ValueError("{trial} does not belong to the dimension")
return

flattened_params = flatten(trial.params)
keys = set(flattened_params.keys())
errors = []

for dim_name, dim in self.items():
if dim_name not in keys:
errors.append(f"{dim_name} is missing")
continue

value = flattened_params[dim_name]
if value not in dim:
errors.append(f"{value} does not belong to the dimension {dim}")

keys.remove(dim_name)

if len(errors) > 0:
raise ValueError(f"Trial {trial.id} is not contained in space:\n{errors}")

if len(keys) != 0:
errors = "\n - ".join(keys)
raise ValueError(f"Trial {trial.id} has additional parameters:\n{errors}")

return True

def __contains__(self, key_or_trial):
Delaunay marked this conversation as resolved.
Show resolved Hide resolved
"""Check whether `trial` is within the bounds of the space.
Or check if a name for a dimension is registered in this space.
Expand All @@ -1105,19 +1145,12 @@ def __contains__(self, key_or_trial):
If str, test if the string is a dimension part of the search space.
If a Trial, test if trial's hyperparameters fit the current search space.
"""
if isinstance(key_or_trial, str):
return super().__contains__(key_or_trial)

trial = key_or_trial
flattened_params = flatten(trial.params)
keys = set(flattened_params.keys())
for dim_name, dim in self.items():
if dim_name not in keys or flattened_params[dim_name] not in dim:
return False

keys.remove(dim_name)

return len(keys) == 0
try:
self.assert_contains(key_or_trial)
return True
except ValueError:
return False

def __repr__(self):
"""Represent as a string the space and the dimensions it contains."""
Expand Down
6 changes: 1 addition & 5 deletions src/orion/core/worker/algo_wrappers/space_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,4 @@ def reverse_transform(self, trial: Trial) -> Trial:

def _verify_trial(self, trial: Trial, space: Space | None = None) -> None:
space = space or self.space
if trial not in space:
raise ValueError(
f"Trial {trial.id} not contained in space:"
f"\nParams: {trial.params}\nSpace: {space}"
)
space.assert_contains(trial)
24 changes: 19 additions & 5 deletions src/orion/core/worker/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ def __init__(self, categories):
self.categories = categories
map_dict = {cat: i for i, cat in enumerate(categories)}
self._map = numpy.vectorize(lambda x: map_dict[x], otypes="i")
self._imap = numpy.vectorize(lambda x: categories[x], otypes=[numpy.object])
self._imap = numpy.vectorize(lambda x: categories[x], otypes=[object])

def __deepcopy__(self, memo):
"""Make a deepcopy"""
Expand Down Expand Up @@ -866,6 +866,19 @@ def sample(self, n_samples=1, seed=None):
trials = self.original.sample(n_samples=n_samples, seed=seed)
return [self.reshape(trial) for trial in trials]

def assert_contains(self, trial):
"""Check if the trial or key is contained inside the space, if not an exception is raised

Raises
------
TypeError when a dimension is not compatible with the space

"""
if isinstance(trial, str):
super().assert_contains(trial)

return self.original.assert_contains(self.restore_shape(trial))

def __contains__(self, key_or_trial):
"""Check whether `trial` is within the bounds of the space.
Or check if a name for a dimension is registered in this space.
Expand All @@ -877,10 +890,11 @@ def __contains__(self, key_or_trial):
If a Trial, test if trial's hyperparameters fit the current search space.

"""
if isinstance(key_or_trial, str):
return super().__contains__(key_or_trial)

return self.restore_shape(key_or_trial) in self.original
try:
self.assert_contains(key_or_trial)
return True
except ValueError:
return False

@property
def cardinality(self):
Expand Down
8 changes: 4 additions & 4 deletions tests/unittests/algo/test_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,7 @@ def test_cast_list_multidim(self):
categories[0] = "asdfa"
categories[2] = "lalala"
dim = Categorical("yolo", categories, shape=2)
sample = ["asdfa", "1"] # np.array(['asdfa', '1'], dtype=np.object)
sample = ["asdfa", "1"] # np.array(['asdfa', '1'], dtype=object)
assert dim.cast(sample) == ["asdfa", 1]

def test_cast_array_multidim(self):
Expand All @@ -633,14 +633,14 @@ def test_cast_array_multidim(self):
categories[0] = "asdfa"
categories[2] = "lalala"
dim = Categorical("yolo", categories, shape=2)
sample = np.array(["asdfa", "1"], dtype=np.object)
assert np.all(dim.cast(sample) == np.array(["asdfa", 1], dtype=np.object))
sample = np.array(["asdfa", "1"], dtype=object)
assert np.all(dim.cast(sample) == np.array(["asdfa", 1], dtype=object))

def test_cast_bad_category(self):
"""Make sure array are cast to int and returned as array of values"""
categories = list(range(10))
dim = Categorical("yolo", categories, shape=2)
sample = np.array(["asdfa", "1"], dtype=np.object)
sample = np.array(["asdfa", "1"], dtype=object)
with pytest.raises(ValueError) as exc:
dim.cast(sample)
assert "Invalid category: asdfa" in str(exc.value)
Expand Down
13 changes: 11 additions & 2 deletions tests/unittests/core/worker/algo_wrappers/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,17 @@ def test_verify_trial(self, algo_wrapper: SpaceTransform[DumbAlgo], space: Space

assert algo_wrapper.space is space

with pytest.raises(ValueError, match="not contained in space:"):
with pytest.raises(ValueError, match="yolo is missing"):
invalid_trial = Trial(
params=[
dict(name="yolo2", value=0, type="real"),
dict(name="yolo3", value=3.5, type="real"),
],
status="new",
)
algo_wrapper._verify_trial(invalid_trial)

with pytest.raises(ValueError, match="does not belong to the dimension"):
invalid_trial = format_trials.tuple_to_trial((("asdfa", 2), 10, 3.5), space)
algo_wrapper._verify_trial(invalid_trial)

Expand All @@ -59,7 +69,6 @@ def test_verify_trial(self, algo_wrapper: SpaceTransform[DumbAlgo], space: Space

# transform point
ttrial = tspace.transform(trial)
# TODO: https://github.com/Epistimio/orion/issues/804
assert ttrial in tspace

# Transformed point is not in original space
Expand Down