Skip to content

Commit

Permalink
SacessOptimizer: collect worker stats (#1381)
Browse files Browse the repository at this point in the history
So far, most stats from different SacessOptimizer workers have only been available from the logs.
Now they are also available via `SacessOptimizer.worker_results`.
Additionally, the total number of objective evaluations across all workers ís logged.
  • Loading branch information
dweindl committed May 3, 2024
1 parent 4f778cc commit e4b15be
Showing 1 changed file with 55 additions and 5 deletions.
60 changes: 55 additions & 5 deletions pypesto/optimize/ess/sacess.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import multiprocessing
import os
import time
from dataclasses import dataclass
from math import ceil, sqrt
from multiprocessing import get_context
from multiprocessing.managers import SyncManager
Expand Down Expand Up @@ -128,6 +129,7 @@ def __init__(
self.exit_flag = ESSExitFlag.DID_NOT_RUN
self.ess_loglevel = ess_loglevel
self.sacess_loglevel = sacess_loglevel
self.worker_results: list[SacessWorkerResult] = []
logger.setLevel(self.sacess_loglevel)

self._tmpdir = tmpdir
Expand Down Expand Up @@ -249,21 +251,29 @@ def minimize(

# wait for finish
# collect results
histories = [
self.worker_results = [
sacess_manager._result_queue.get()
for _ in range(self.num_workers)
]
self.histories = histories
for p in worker_processes:
p.join()

logging_thread.stop()

self.histories = [
worker_result.history for worker_result in self.worker_results
]

result = self._create_result(problem)

walltime = time.time() - start_time
n_eval_total = sum(
worker_result.n_eval for worker_result in self.worker_results
)
logger.info(
f"{self.__class__.__name__} stopped after {walltime:3g}s with global best "
f"{result.optimize_result[0].fval}."
f"{self.__class__.__name__} stopped after {walltime:3g}s "
f"and {n_eval_total} objective evaluations "
f"with global best {result.optimize_result[0].fval}."
)

return result
Expand Down Expand Up @@ -587,7 +597,15 @@ def run(
)

ess.history.finalize(exitflag=ess.exit_flag.name)
self._manager._result_queue.put(ess.history)
worker_result = SacessWorkerResult(
x=ess.x_best,
fx=ess.fx_best,
history=ess.history,
n_eval=ess.evaluator.n_eval,
n_iter=ess.n_iter,
exit_flag=ess.exit_flag,
)
self._manager._result_queue.put(worker_result)
ess._report_final()

def _setup_ess(self, startpoint_method: StartpointMethod) -> ESSOptimizer:
Expand Down Expand Up @@ -1025,3 +1043,35 @@ def __call__(

def __repr__(self):
return f"{self.__class__.__name__}(fides_options={self._fides_options}, fides_kwargs={self._fides_kwargs})"


@dataclass
class SacessWorkerResult:
"""Container for :class:`SacessWorker` results.
Contains various information about the optimization process of a single
:class:`SacessWorker` instance that is to be sent to
:class:`SacessOptimizer`.
Attributes
----------
x:
Best parameters found.
fx:
Objective value corresponding to ``x``.
n_eval:
Number of objective evaluations performed.
n_iter:
Number of scatter search iterations performed.
history:
History object containing information about the optimization process.
exit_flag:
Exit flag of the optimization process.
"""

x: np.array
fx: float
n_eval: int
n_iter: int
history: "pypesto.history.memory.MemoryHistory"
exit_flag: ESSExitFlag

0 comments on commit e4b15be

Please sign in to comment.