Skip to content

Commit

Permalink
revert change
Browse files Browse the repository at this point in the history
  • Loading branch information
Delaunay committed Oct 8, 2021
1 parent 7e8bd78 commit 2f4a626
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 88 deletions.
89 changes: 39 additions & 50 deletions src/orion/client/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
from orion.core.worker.trial import Trial, TrialCM
from orion.core.worker.trial_pacemaker import TrialPacemaker
from orion.executor.base import Executor
from orion.ext.extensions import OrionExtensionManager
from orion.plotting.base import PlotAccessor
from orion.storage.base import FailedUpdate
from orion.ext.extensions import OrionExtensionManager

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -772,37 +772,6 @@ def workon(

return sum(trials)

def _optimize_trial(self, fct, trial, trial_arg, kwargs, worker_broken_trials, max_broken, on_error):
kwargs.update(flatten(trial.params))

if trial_arg:
kwargs[trial_arg] = trial

try:
with self.extensions.trial(trial):
results = self.executor.wait(
[self.executor.submit(fct, **unflatten(kwargs))]
)[0]
self.observe(trial, results=results)
except (KeyboardInterrupt, InvalidResult):
raise
except BaseException as e:
if on_error is None or on_error(self, trial, e, worker_broken_trials):
log.error(traceback.format_exc())
worker_broken_trials += 1
else:
log.error(str(e))
log.debug(traceback.format_exc())

if worker_broken_trials >= max_broken:
raise BrokenExperiment(
"Worker has reached broken trials threshold"
)
else:
self.release(trial, status="broken")

return worker_broken_trials

def _optimize(
self, fct, pool_size, max_trials, max_broken, trial_arg, on_error, **kwargs
):
Expand All @@ -812,24 +781,44 @@ def _optimize(
max_trials = min(max_trials, self.max_trials)

while not self.is_done and trials - worker_broken_trials < max_trials:
try:
with self.suggest(pool_size=pool_size) as trial:

worker_broken_trials = self._optimize_trial(
fct,
trial,
trial_arg,
kwargs,
worker_broken_trials,
max_broken,
on_error
)

except CompletedExperiment as e:
log.warning(e)
break

trials += 1
try:
with self.suggest(pool_size=pool_size) as trial:

kwargs.update(flatten(trial.params))

if trial_arg:
kwargs[trial_arg] = trial

try:
with self.extensions.trial(trial):
results = self.executor.wait(
[self.executor.submit(fct, **unflatten(kwargs))]
)[0]
self.observe(trial, results=results)
except (KeyboardInterrupt, InvalidResult):
raise
except BaseException as e:
if on_error is None or on_error(
self, trial, e, worker_broken_trials
):
log.error(traceback.format_exc())
worker_broken_trials += 1
else:
log.error(str(e))
log.debug(traceback.format_exc())

if worker_broken_trials >= max_broken:
raise BrokenExperiment(
"Worker has reached broken trials threshold"
)
else:
self.release(trial, status="broken")

except CompletedExperiment as e:
log.warning(e)
break

trials += 1

return trials

Expand Down
45 changes: 25 additions & 20 deletions src/orion/ext/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class EventDelegate:
if false events are triggered as soon as broadcast is called
if true the events will need to be triggered manually
"""

def __init__(self, name, deferred=False) -> None:
self.handlers = []
self.deferred_calls = []
Expand Down Expand Up @@ -48,7 +49,9 @@ def _execute(self, args, kwargs):
fun(*args, **kwargs)
except Exception as err:
if self.manager:
self.manager.on_extension_error.broadcast(self.name, fun, err, args=(args, kwargs))
self.manager.on_extension_error.broadcast(
self.name, fun, err, args=(args, kwargs)
)

def execute(self):
"""Execute all our deferred handlers if any"""
Expand Down Expand Up @@ -86,41 +89,40 @@ class OrionExtensionManager:

def __init__(self):
self._events = {}
self._get_event('on_extension_error')
self._get_event("on_extension_error")

# -- Trials
self._get_event('new_trial')
self._get_event('on_trial_error')
self._get_event('end_trial')
self._get_event("new_trial")
self._get_event("on_trial_error")
self._get_event("end_trial")

# -- Experiments
self._get_event('start_experiment')
self._get_event('on_experiment_error')
self._get_event('end_experiment')
self._get_event("start_experiment")
self._get_event("on_experiment_error")
self._get_event("end_experiment")

def experiment(self, *args, **kwargs):
"""Initialize a context manager that will call start/error/end events automatically"""
return _DelegateStartEnd(
self.start_experiment,
self.on_experiment_error,
self.end_experiment,
self._get_event('start_experiment'),
self._get_event('on_experiment_error'),
self._get_event('end_experiment'),
*args,
**kwargs
)

def trial(self, *args, **kwargs):
"""Initialize a context manager that will call start/error/end events automatically"""
return _DelegateStartEnd(
self.new_trial,
self.on_trial_error,
self.end_trial,
self._get_event('new_trial'),
self._get_event('on_trial_error'),
self._get_event('end_trial'),
*args,
**kwargs
)

def __getattr__(self, name):
if name in self._events:
return self._get_event(name)
def broadcast(self, name, *args, **kwargs):
return self._get_event(name).broadcast(*args, **kwargs)

def _get_event(self, key):
"""Retrieve or generate a new event delegate"""
Expand Down Expand Up @@ -183,7 +185,9 @@ def on_extension_error(self, name, fun, exception, args):
"""
return

def on_trial_error(self, trial, exception_type, exception_value, exception_traceback):
def on_trial_error(
self, trial, exception_type, exception_value, exception_traceback
):
"""Called when a error occur during the optimization process"""
return

Expand All @@ -195,7 +199,9 @@ def end_trial(self, trial):
"""Called when the trial finished"""
return

def on_experiment_error(self, experiment, exception_type, exception_value, exception_traceback):
def on_experiment_error(
self, experiment, exception_type, exception_value, exception_traceback
):
"""Called when a error occur during the optimization process"""
return

Expand All @@ -206,4 +212,3 @@ def start_experiment(self, experiment):
def end_experiment(self, experiment):
"""Called at the end of the optimization process after the worker exits"""
return

45 changes: 27 additions & 18 deletions tests/unittests/ext/test_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,28 +42,30 @@
"params": [],
}


class OrionExtensionTest:
"""Base orion extension interface you need to implement"""

def __init__(self) -> None:
self.calls = defaultdict(int)

def on_experiment_error(self, *args, **kwargs):
self.calls['on_experiment_error'] += 1
self.calls["on_experiment_error"] += 1

def on_trial_error(self, *args, **kwargs):
self.calls['on_trial_error'] += 1
self.calls["on_trial_error"] += 1

def start_experiment(self, *args, **kwargs):
self.calls['start_experiment'] += 1
self.calls["start_experiment"] += 1

def new_trial(self, *args, **kwargs):
self.calls['new_trial'] += 1
self.calls["new_trial"] += 1

def end_trial(self, *args, **kwargs):
self.calls['end_trial'] += 1
self.calls["end_trial"] += 1

def end_experiment(self, *args, **kwargs):
self.calls['end_experiment'] += 1
self.calls["end_experiment"] += 1


def test_client_extension():
Expand All @@ -88,31 +90,38 @@ def foo(x):
n_broken = len(experiment.fetch_trials_by_status("broken"))
n_reserved = len(experiment.fetch_trials_by_status("reserved"))

assert ext.calls['new_trial'] == n_trials + n_broken - n_reserved, 'all trials should have triggered callbacks'
assert ext.calls['end_trial'] == n_trials + n_broken - n_reserved, 'all trials should have triggered callbacks'
assert ext.calls['on_trial_error'] == n_broken, 'failed trial should be reported '
assert (
ext.calls["new_trial"] == n_trials + n_broken - n_reserved
), "all trials should have triggered callbacks"
assert (
ext.calls["end_trial"] == n_trials + n_broken - n_reserved
), "all trials should have triggered callbacks"
assert (
ext.calls["on_trial_error"] == n_broken
), "failed trial should be reported "

assert ext.calls['start_experiment'] == 1, 'experiment should have started'
assert ext.calls['end_experiment'] == 1, 'experiment should have ended'
assert ext.calls['on_experiment_error'] == 1, 'failed experiment '
assert ext.calls["start_experiment"] == 1, "experiment should have started"
assert ext.calls["end_experiment"] == 1, "experiment should have ended"
assert ext.calls["on_experiment_error"] == 1, "failed experiment "

unregistered_callback = client.extensions.unregister(ext)
assert unregistered_callback == 6, "All ext callbacks got unregistered"


class BadOrionExtensionTest:
"""Base orion extension interface you need to implement"""

def __init__(self) -> None:
self.calls = defaultdict(int)

def on_extension_error(self, name, fun, exception, args):
self.calls['on_extension_error'] += 1
self.calls["on_extension_error"] += 1

def on_experiment_error(self, *args, **kwargs):
self.calls['on_experiment_error'] += 1
self.calls["on_experiment_error"] += 1

def on_trial_error(self, *args, **kwargs):
self.calls['on_trial_error'] += 1
self.calls["on_trial_error"] += 1

def new_trial(self, *args, **kwargs):
raise RuntimeError()
Expand All @@ -132,9 +141,9 @@ def foo(x):
assert client.max_trials == MAX_TRIALS
client.workon(foo, max_trials=MAX_TRIALS, max_broken=MAX_BROKEN)

assert ext.calls['on_trial_error'] == 0, 'Orion worked as expected'
assert ext.calls['on_experiment_error'] == 0, 'Orion worked as expected'
assert ext.calls['on_extension_error'] == 9, 'Extension error got reported'
assert ext.calls["on_trial_error"] == 0, "Orion worked as expected"
assert ext.calls["on_experiment_error"] == 0, "Orion worked as expected"
assert ext.calls["on_extension_error"] == 9, "Extension error got reported"

unregistered_callback = client.extensions.unregister(ext)
assert unregistered_callback == 4, "All ext callbacks got unregistered"

0 comments on commit 2f4a626

Please sign in to comment.