In [1]:
# install evox, skip it if you have already installed evox
try:
    import evox
except ImportError:
    !pip install --disable-pip-version-check --upgrade -q evox
    import evox

In [2]:
from evox import algorithms, problems, workflows, monitors
import jax
import jax.numpy as jnp

In [3]:
# Search-space and population settings; edit these to change the experiment.
DIM = 2               # number of decision variables
LB, UB = -32.0, 32.0  # same lower/upper bound applied to every dimension
POP_SIZE = 100        # number of particles

# PSO over the box [LB, UB]^DIM. Float fill values keep the bound arrays
# floating-point; a Python int fill value would make `jnp.full` return ints.
algorithm = algorithms.PSO(
    lb=jnp.full(shape=(DIM,), fill_value=LB),
    ub=jnp.full(shape=(DIM,), fill_value=UB),
    pop_size=POP_SIZE,
)
# Ackley benchmark problem from EvoX's numerical problem set.
problem = problems.numerical.Ackley()
# Monitor that later cells query for the best fitness and best solution.
monitor = monitors.StdSOMonitor()

In [4]:
# Bundle the algorithm, problem and monitor into a single workflow object.
# NOTE(review): record_pop=True is forwarded to StdWorkflow, presumably so it
# records each generation's population; confirm against the EvoX docs.
workflow = workflows.StdWorkflow(algorithm, problem, monitor, record_pop=True)

In [5]:
# Run settings: PRNG seed and number of optimisation steps.
SEED = 42
NUM_STEPS = 100

# Build the initial workflow state from an explicit PRNG key so the random
# initialisation is fully determined by SEED.
key = jax.random.PRNGKey(SEED)
state = workflow.init(key)

# Advance the optimisation NUM_STEPS times; each step returns the next state.
# NOTE(review): `monitor` is created in an earlier cell; if it keeps history
# across runs, re-run from that cell rather than re-running this one alone.
for _ in range(NUM_STEPS):
    state = workflow.step(state)

In [6]:
# Best fitness value recorded by the monitor during the run (displayed as cell output).
monitor.get_best_fitness()

Array(0., dtype=float32)

In [7]:
# Best solution (decision vector) recorded by the monitor (displayed as cell output).
monitor.get_best_solution()

Array([-4.0062014e-07,  5.2837186e-07], dtype=float32)