Skip to content

Commit

Permalink
Merge efa1dbb into 2c33155
Browse files Browse the repository at this point in the history
  • Loading branch information
jlnav committed Oct 9, 2023
2 parents 2c33155 + efa1dbb commit e9792d3
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 23 deletions.
21 changes: 15 additions & 6 deletions docs/function_guides/generator.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,13 @@ Generator Functions

Generator and :ref:`Simulator functions<funcguides-sim>` have relatively similar interfaces.

Writing a Generator
-------------------

.. code-block:: python
@input_fields(["f"])
@output_data([("x", float)])
def my_generator(Input, persis_info, gen_specs, libE_info):
batch_size = gen_specs["user"]["batch_size"]
Expand All @@ -22,19 +27,23 @@ Most ``gen_f`` function definitions written by users resemble::

where:

* ``Input`` is a selection of the :ref:`History array<funcguides-history>`
* :ref:`persis_info<datastruct-persis-info>` is a dictionary containing state information
* :ref:`gen_specs<datastruct-gen-specs>` is a dictionary of generator parameters, including which fields from the History array got sent
* ``libE_info`` is a dictionary containing libEnsemble-specific entries
* ``Input`` is a selection of the :ref:`History array<funcguides-history>`, a NumPy array.
* :ref:`persis_info<datastruct-persis-info>` is a dictionary containing state information.
* :ref:`gen_specs<datastruct-gen-specs>` is a dictionary of generator parameters.
* ``libE_info`` is a dictionary containing miscellaneous entries.

*Optional* ``input_fields`` and ``output_data`` decorators for the function describe the
fields to pass in and the output data format. Otherwise those fields
need to be specified in :class:`GenSpecs<libensemble.specs.GenSpecs>`.

Valid generator functions can accept a subset of the above parameters. So a very simple generator can start::

def my_generator(Input):

If gen_specs was initially defined::
If ``gen_specs`` was initially defined::

gen_specs = {
"gen_f": some_function,
"gen_f": my_generator,
"in": ["f"],
"out:" ["x", float, (1,)],
"user": {
Expand Down
4 changes: 4 additions & 0 deletions libensemble/sim_funcs/one_d_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@

import numpy as np

from libensemble.specs import input_fields, output_data


@input_fields(["x"])
@output_data([("f", float)])
def one_d_example(x, persis_info, sim_specs, _):
"""
Evaluates the six hump camel function for a single point ``x``.
Expand Down
3 changes: 3 additions & 0 deletions libensemble/sim_funcs/six_hump_camel.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@
import numpy as np

from libensemble.message_numbers import EVAL_SIM_TAG, FINISHED_PERSISTENT_SIM_TAG, PERSIS_STOP, STOP_TAG
from libensemble.specs import input_fields, output_data
from libensemble.tools.persistent_support import PersistentSupport


@input_fields(["x"])
@output_data([("f", float)])
def six_hump_camel(H, persis_info, sim_specs, libE_info):
"""
Evaluates the six hump camel function for a collection of points given in ``H["x"]``.
Expand Down
128 changes: 121 additions & 7 deletions libensemble/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
from pydantic import BaseConfig, BaseModel, Field, root_validator, validator

from libensemble.alloc_funcs.give_sim_work_first import give_sim_work_first
from libensemble.gen_funcs.sampling import latin_hypercube_sample
from libensemble.resources.platforms import Platform
from libensemble.sim_funcs.one_d_func import one_d_example
from libensemble.utils.specs_checkers import (
_check_any_workers_and_disable_rm_if_tcp,
_check_exit_criteria,
Expand Down Expand Up @@ -39,13 +37,13 @@ class SimSpecs(BaseModel):
Specifications for configuring a Simulation Function.
"""

sim_f: Callable = one_d_example
sim_f: Callable = None
"""
Python function matching the ``sim_f`` interface. Evaluates parameters
produced by a generator function.
"""

inputs: List[str] = Field([], alias="in")
inputs: Optional[List[str]] = Field([], alias="in")
"""
List of **field names** out of the complete history to pass
into the simulation function upon calling.
Expand All @@ -58,7 +56,7 @@ class SimSpecs(BaseModel):
"""

# list of tuples for dtype construction
outputs: List[Union[Tuple[str, Any], Tuple[str, Any, Union[int, Tuple]]]] = Field([], alias="out")
outputs: Optional[List[Union[Tuple[str, Any], Tuple[str, Any, Union[int, Tuple]]]]] = Field([], alias="out")
"""
List of 2- or 3-tuples corresponding to NumPy dtypes.
e.g. ``("dim", int, (3,))``, or ``("path", str)``.
Expand Down Expand Up @@ -95,13 +93,27 @@ def check_valid_in(cls, v):
raise ValueError(_IN_INVALID_ERR)
return v

@root_validator
def set_in_out_from_attrs(cls, values):
if not values.get("sim_f"):
from libensemble.sim_funcs.one_d_func import one_d_example

values["sim_f"] = one_d_example
if hasattr(values.get("sim_f"), "inputs") and not values.get("inputs"):
values["inputs"] = values.get("sim_f").inputs
if hasattr(values.get("sim_f"), "outputs") and not values.get("outputs"):
values["out"] = values.get("sim_f").outputs
if hasattr(values.get("sim_f"), "persis_in") and not values.get("persis_in"):
values["persis_in"] = values.get("sim_f").persis_in
return values


class GenSpecs(BaseModel):
"""
Specifications for configuring a Generator Function.
"""

gen_f: Optional[Callable] = latin_hypercube_sample
gen_f: Optional[Callable] = None
"""
Python function matching the ``gen_f`` interface. Produces parameters for evaluation by a
simulator function, and makes decisions based on simulator function output.
Expand All @@ -119,7 +131,7 @@ class GenSpecs(BaseModel):
throughout the run, following initialization.
"""

outputs: List[Union[Tuple[str, Any], Tuple[str, Any, Union[int, Tuple]]]] = Field([], alias="out")
outputs: Optional[List[Union[Tuple[str, Any], Tuple[str, Any, Union[int, Tuple]]]]] = Field([], alias="out")
"""
List of 2- or 3-tuples corresponding to NumPy dtypes.
e.g. ``("dim", int, (3,))``, or ``("path", str)``. Typically used to initialize an
Expand Down Expand Up @@ -155,6 +167,20 @@ def check_valid_in(cls, v):
raise ValueError(_IN_INVALID_ERR)
return v

@root_validator
def set_in_out_from_attrs(cls, values):
if not values.get("gen_f"):
from libensemble.gen_funcs.sampling import latin_hypercube_sample

values["gen_f"] = latin_hypercube_sample
if hasattr(values.get("gen_f"), "inputs") and not values.get("inputs"):
values["inputs"] = values.get("gen_f").inputs
if hasattr(values.get("gen_f"), "outputs") and not values.get("outputs"):
values["out"] = values.get("gen_f").outputs
if hasattr(values.get("gen_f"), "persis_in") and not values.get("persis_in"):
values["persis_in"] = values.get("gen_f").persis_in
return values


class AllocSpecs(BaseModel):
"""
Expand Down Expand Up @@ -579,3 +605,91 @@ def check_H0(cls, values):
if values.get("H0") is not None:
return _check_H0(values)
return values


def input_fields(fields: List[str]):
"""Decorates a user-function with a list of field names to pass in on initialization.
Decorated functions don't need those fields specified in ``SimSpecs.inputs`` or ``GenSpecs.inputs``.
.. code-block:: python
from libensemble.specs import input_fields, output_data
@input_fields(["x"])
@output_data([("f", float)])
def one_d_example(x, persis_info, sim_specs):
H_o = np.zeros(1, dtype=sim_specs["out"])
H_o["f"] = np.linalg.norm(x)
return H_o, persis_info
"""

def decorator(func):
setattr(func, "inputs", fields)
func.__doc__ = f"\n **Input Fields:** ``{func.inputs}``\n" + func.__doc__
return func

return decorator


def persistent_input_fields(fields: List[str]):
"""Decorates a *persistent* user-function with a list of field names to send in throughout runtime.
Decorated functions don't need those fields specified in ``SimSpecs.persis_in`` or ``GenSpecs.persis_in``.
.. code-block:: python
from libensemble.specs import persistent_input_fields, output_data
@persistent_input_fields(["f"])
@output_data(["x", float])
def persistent_uniform(_, persis_info, gen_specs, libE_info):
b, n, lb, ub = _get_user_params(gen_specs["user"])
ps = PersistentSupport(libE_info, EVAL_GEN_TAG)
tag = None
while tag not in [STOP_TAG, PERSIS_STOP]:
H_o = np.zeros(b, dtype=gen_specs["out"])
H_o["x"] = persis_info["rand_stream"].uniform(lb, ub, (b, n))
tag, Work, calc_in = ps.send_recv(H_o)
if hasattr(calc_in, "__len__"):
b = len(calc_in)
return H_o, persis_info, FINISHED_PERSISTENT_GEN_TAG
"""

def decorator(func):
setattr(func, "persis_in", fields)
func.__doc__ = f"\n **Persistent Input Fields:** ``{func.persis_inputs}``\n" + func.__doc__
return func

return decorator


def output_data(fields: List[Union[Tuple[str, Any], Tuple[str, Any, Union[int, Tuple]]]]):
"""Decorates a user-function with a list of tuples corresponding to NumPy dtypes for the function's output data.
Decorated functions don't need those fields specified in ``SimSpecs.outputs`` or ``GenSpecs.outputs``.
.. code-block:: python
from libensemble.specs import input_fields, output_data
@input_fields(["x"])
@output_data([("f", float)])
def one_d_example(x, persis_info, sim_specs):
H_o = np.zeros(1, dtype=sim_specs["out"])
H_o["f"] = np.linalg.norm(x)
return H_o, persis_info
"""

def decorator(func):
setattr(func, "outputs", fields)
func.__doc__ = f"\n **Output Datatypes:** ``{func.outputs}``\n" + func.__doc__
return func

return decorator
12 changes: 2 additions & 10 deletions libensemble/tests/regression_tests/test_1d_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,8 @@
# Main block is necessary only when using local comms with spawn start method (default on macOS and Windows).
if __name__ == "__main__":
sampling = Ensemble(parse_args=True)
sampling.libE_specs = LibeSpecs(
save_every_k_gens=300,
safe_mode=False,
disable_log_files=True,
)
sampling.sim_specs = SimSpecs(
sim_f=sim_f,
inputs=["x"],
outputs=[("f", float)],
)
sampling.libE_specs = LibeSpecs(save_every_k_gens=300, safe_mode=False, disable_log_files=True)
sampling.sim_specs = SimSpecs(sim_f=sim_f)
sampling.gen_specs = GenSpecs(
gen_f=gen_f,
outputs=[("x", float, (1,))],
Expand Down

0 comments on commit e9792d3

Please sign in to comment.