# Meta Optimization

In this notebook, we will use the meta optimization to optimize the hyperparameters of PSO.
We will have an outer loop that will optimize the hyperparameters of PSO, and an inner loop that accepts the hyperparameters and runs the PSO algorithm to optimize the Ackley function, the result (cost on the ackley function) will be used as an indicator of how good the hyperparameters are, and this information will be feed to the outer loop to optimize the hyperparameters.

In [1]:
%load_ext autoreload
%autoreload 2
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# install evox, skip it if you have already installed evox
try:
    import evox
except ImportError:
    !pip install --disable-pip-version-check --upgrade -q evox
    import evox

In [2]:
from evox import algorithms, problems, workflows, monitors, Algorithm, Problem, State, Workflow, dataclass, pytree_field, Stateful, use_state
import dataclasses
import jax
import jax.numpy as jnp
from jax import vmap

2024-09-24 18:42:15.869911: W external/xla/xla/service/gpu/nvptx_compiler.cc:718] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.4.131). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


## Setup the inner loop

The inner loop will be the PSO algorithm that will optimize the Ackley function.

In [8]:
from typing import Any, Tuple
from evox.core.module import Any, Tuple
from evox.core.state import State


algorithm = algorithms.PSO(
    lb=jnp.full(shape=(2,), fill_value=-32),
    ub=jnp.full(shape=(2,), fill_value=32),
    pop_size=100,
)
problem = problems.numerical.Ackley()
# We only care about the best solution found, not the full history
monitor = monitors.EvalMonitor(full_fit_history=False)
inner_workflow = workflows.StdWorkflow(algorithm, problem, monitors=[monitor])
inner_workflows = Stateful.stack([inner_workflow for _ in range(10)])

@dataclass
class InnerLoop(Problem):
    inner_workflows: Workflow = pytree_field(stack=True)

    def evaluate(self, state: State, pop: Any):
        """Evaluate the population (a set of hyperparameters)"""

        def _replace(workflow, pop):
            algorithm = workflow.algorithm.replace(**pop)
            workflow = workflow.replace(algorithm=algorithm)
            return workflow

        inner_workflows = vmap(_replace)(self.inner_workflows, pop)
        state = use_state(vmap(workflows.StdWorkflow.step))(inner_workflows, state)
        best_fit, _state = use_state(inner_workflows.monitors[0].get_best_fitness)(state)

        return best_fit, state

In [9]:
inner_loop = InnerLoop(inner_workflows=inner_workflows)
state = inner_loop.init(key=jax.random.PRNGKey(0))

inner_workflows {'static': False, 'stack': True, 'sharding': <ShardingType.REPLICATED: 2>}
algorithm {}
monitors0 {'nested': True, 'static': False, 'stack': False, 'sharding': <ShardingType.REPLICATED: 2>}
problem {}
algorithm {}
monitors0 {'nested': True, 'static': False, 'stack': False, 'sharding': <ShardingType.REPLICATED: 2>}
problem {}


In [13]:
print(inner_loop.inner_workflows.monitors[0]._node_id)
print(inner_loop.inner_workflows._registered_hooks["post_ask"][0]._node_id)

3


AttributeError: 'EvalMonitor' object has no attribute '_node_id'

In [11]:
fit, state = inner_loop.evaluate(state, pop={"w": jnp.array([0.1] * 10), "phi_p": jnp.array([0.1] * 10), "phi_g": jnp.array([0.1] * 10)})

algorithm_node_id 2
monitor 3
algorithm_node_id 2
algorithm_node_id 2


ValueError: EvalMonitor(multi_obj=False, full_fit_history=False, full_sol_history=False, topk=1, fitness_history=[], solution_history=[], opt_direction=1) is not initialized, did you forget to call `init`?

In [6]:
key = jax.random.PRNGKey(0)
keys = jnp.stack([key] * 10)
state = inner_workflow.init(key)

for i in range(10):
    state = inner_workflow.step(state)

TypeError: mul got incompatible shapes for broadcasting: (100, 2), (10, 2).

In [9]:
import dataclasses
from functools import partial

algorithms = vmap(partial(dataclasses.replace, algorithm))(
    w=jnp.array([0.1, 0.2, 0.3, 0.4])
)

workflows = []
for algorithm in algorithms:
    workflow = workflows.StdWorkflow(algorithm, problem, monitors=[monitor])
    workflows.append(workflow)


@dataclass
class ParallelWorkflow(Workflow):
    workflows: list[Workflow] = pytree_field(metadata={"stack": True})

    def setup(self, key: jax.Array) -> State:
        return State(key, workflows=self.workflows)

    def run(self, state: State) -> State:
        for workflow in self.workflows:
            state = workflow.run(state)
        return state


PSO(dim=2, lb=Array([[-32, -32],
       [-32, -32],
       [-32, -32],
       [-32, -32]], dtype=int32, weak_type=True), ub=Array([[32, 32],
       [32, 32],
       [32, 32],
       [32, 32]], dtype=int32, weak_type=True), pop_size=100, w=Array([0.1, 0.2, 0.3, 0.4], dtype=float32), phi_p=Array([2.5, 2.5, 2.5, 2.5], dtype=float32, weak_type=True), phi_g=Array([0.8, 0.8, 0.8, 0.8], dtype=float32, weak_type=True), mean=None, stdev=None, bound_method='clip')

In [4]:
# create a workflow
workflow = workflows.StdWorkflow(
    algorithm,
    problem,
    monitors=[monitor],
)

In [5]:
# init the workflow
key = jax.random.PRNGKey(42)
state = workflow.init(key)

# run the workflow for 100 steps
for i in range(100):
    state = workflow.step(state)

In [6]:
monitor.get_best_fitness()

Array(0., dtype=float32)

In [7]:
monitor.get_best_solution()

Array([-4.0062014e-07,  5.2837186e-07], dtype=float32)

In [5]:
import dataclasses
dataclasses.is_dataclass(inner_workflow)

True

In [6]:
jax.tree.map(lambda x: x.shape, inner_workflow)

AttributeError: 'float' object has no attribute 'shape'