# Multidevice Algorithm

This guide shows how to write algorithms in EvoX that can run on multiple devices (e.g. multiple GPUs).

In [2]:
from typing import Optional

import jax
import jax.numpy as jnp

from evox import Algorithm, dataclass, pytree_field, problems, workflows, monitors, use_state
from evox.core.distributed import ShardingType
from evox.utils import *

In this example, we consider the following simple setup:
```
    Node1
      |
 +----+----+
 |    |    |
GPU  GPU  GPU
```
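Here we have a single node with multiple GPUs, which communicate with each other over PCIe or NVLink.
When running in such a distributed setup, we must decide how the data is placed on these GPUs. In EvoX this is expressed by attaching sharding metadata to the fields of the algorithm's state: fields marked `ShardingType.SHARED_FIRST_DIM` are split across the devices along their first (population) dimension, while fields without this marker are replicated on every device.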

In [3]:
# Compared with a single-device PSO, the only change is the sharding
# metadata on the state fields whose first dimension is the population.
@dataclass
class SpecialPSOState:
    """State of the multi-device PSO.

    Fields tagged with ``ShardingType.SHARED_FIRST_DIM`` carry one entry per
    particle and are split across devices along their first (population)
    axis. Untagged fields are whole-swarm quantities; after
    ``enable_multi_devices`` they end up replicated (``PartitionSpec()``) on
    every device.
    """

    # Current particle positions, shape (pop_size, dim).
    population: jax.Array = pytree_field(sharding=ShardingType.SHARED_FIRST_DIM)
    # Current particle velocities, shape (pop_size, dim).
    velocity: jax.Array = pytree_field(sharding=ShardingType.SHARED_FIRST_DIM)
    # Per-particle fitness, shape (pop_size,).
    fitness: jax.Array = pytree_field(sharding=ShardingType.SHARED_FIRST_DIM)
    # Best position each particle has visited so far, shape (pop_size, dim).
    local_best_location: jax.Array = pytree_field(sharding=ShardingType.SHARED_FIRST_DIM)
    # Fitness at each particle's best position, shape (pop_size,).
    local_best_fitness: jax.Array = pytree_field(sharding=ShardingType.SHARED_FIRST_DIM)
    # Best position found by the whole swarm, shape (dim,); replicated.
    global_best_location: jax.Array
    # Fitness at the global best position, shape (1,); replicated.
    global_best_fitness: jax.Array
    # PRNG key for the stochastic velocity update. The annotation is
    # ``jax.Array``: ``jax.random.PRNGKey`` is a function, not a type.
    key: jax.Array


@dataclass
class PSO(Algorithm):
    """Particle Swarm Optimization (PSO) with a device-shardable state.

    The optimisation logic is ordinary PSO. Multi-device support comes
    entirely from the sharding metadata on :class:`SpecialPSOState`. The
    per-particle arrays are split along the population axis, so ``pop_size``
    presumably has to be divisible by the number of devices.
    TODO confirm against ``workflow.enable_multi_devices``.

    Fitness is minimised.

    Bounds handling (``bound_method``):
        * ``"clip"``: positions are clipped to ``[lb, ub]``.
        * ``"reflect"``: out-of-bounds coordinates are mirrored at the violated
          bound, their velocity is negated, and the result is then clipped.
        * any other value: no bounds are enforced after the update.

    In both bounded modes, velocities are clipped per dimension to
    ``[-v_max, v_max]`` with ``v_max = (ub - lb) / 2``.
    """

    # Dimensionality of the search space; derived from ``lb`` in __post_init__.
    dim: int = pytree_field(static=True, init=False)
    # Per-dimension lower / upper bounds of the search box (1-D arrays whose
    # length defines ``dim``).
    lb: jax.Array
    ub: jax.Array
    # Number of particles.
    pop_size: int = pytree_field(static=True)
    # Inertia weight applied to the previous velocity.
    w: jax.Array = pytree_field(default=0.6)
    # Cognitive coefficient: pull towards each particle's own best position.
    phi_p: jax.Array = pytree_field(default=2.5)
    # Social coefficient: pull towards the swarm's global best position.
    phi_g: jax.Array = pytree_field(default=0.8)
    # Optional Gaussian initialisation; used only when BOTH are provided.
    mean: Optional[jax.Array] = pytree_field(default=None)
    stdev: Optional[jax.Array] = pytree_field(default=None)
    # "clip" or "reflect"; see the class docstring.
    bound_method: str = pytree_field(static=True, default="clip")

    def __post_init__(self):
        # ``dim`` is static and excluded from __init__, so derive it from lb.
        self.set_frozen_attr("dim", self.lb.shape[0])

    def setup(self, key):
        """Create the initial swarm state.

        Args:
            key: JAX PRNG key used to sample initial positions and velocities.

        Returns:
            A ``SpecialPSOState``. Positions and velocities have shape
            ``(pop_size, dim)``. All best-fitness values start at ``inf``, and
            the global best location starts at the first particle.
        """
        state_key, init_pop_key, init_v_key = jax.random.split(key, 3)
        shape = (self.pop_size, self.dim)
        if self.mean is not None and self.stdev is not None:
            # Gaussian initialisation centred on ``mean``, then clipped into
            # the search box.
            population = self.mean + self.stdev * jax.random.normal(
                init_pop_key, shape=shape
            )
            population = jnp.clip(population, self.lb, self.ub)
            velocity = self.stdev * jax.random.normal(init_v_key, shape=shape)
        else:
            # Uniform positions in [lb, ub]; uniform velocities in
            # [-(ub - lb), ub - lb].
            length = self.ub - self.lb
            population = jax.random.uniform(init_pop_key, shape=shape)
            population = population * length + self.lb
            velocity = jax.random.uniform(init_v_key, shape=shape)
            velocity = velocity * length * 2 - length

        return SpecialPSOState(
            population=population,
            velocity=velocity,
            fitness=jnp.full((self.pop_size,), jnp.inf),
            local_best_location=population,
            local_best_fitness=jnp.full((self.pop_size,), jnp.inf),
            global_best_location=population[0],
            global_best_fitness=jnp.array([jnp.inf]),
            key=state_key,
        )

    def ask(self, state):
        """Return the current positions as the candidates to evaluate."""
        return state.population, state

    def tell(self, state, fitness):
        """Update the swarm using the fitness of the positions from ``ask``.

        Args:
            state: current ``SpecialPSOState``.
            fitness: fitness of ``state.population``, shape ``(pop_size,)``;
                lower is better.

        Returns:
            The updated ``SpecialPSOState``. Its ``fitness`` field holds the
            values just passed in, i.e. the fitness of the positions *before*
            this update.
        """
        key, rg_key, rp_key = jax.random.split(state.key, 3)

        rg = jax.random.uniform(rg_key, shape=(self.pop_size, self.dim))
        rp = jax.random.uniform(rp_key, shape=(self.pop_size, self.dim))

        # Personal bests: take the evaluated position only where it improved.
        compare = state.local_best_fitness > fitness
        local_best_location = jnp.where(
            compare[:, jnp.newaxis], state.population, state.local_best_location
        )
        local_best_fitness = jnp.minimum(state.local_best_fitness, fitness)

        # Global best over the previous global best and the current swarm.
        # ``min_by`` comes from evox.utils; presumably it returns the
        # (location, fitness) pair with the lowest fitness. Confirm how it
        # breaks ties.
        global_best_location, global_best_fitness = min_by(
            [state.global_best_location[jnp.newaxis, :], state.population],
            [state.global_best_fitness, fitness],
        )
        # Keep the (1,) shape used in ``setup`` so the state keeps the same
        # structure between iterations.
        global_best_fitness = jnp.atleast_1d(global_best_fitness)

        velocity = (
            self.w * state.velocity
            + self.phi_p * rp * (local_best_location - state.population)
            + self.phi_g * rg * (global_best_location - state.population)
        )
        population = state.population + velocity

        # Velocity limit: half the box width per dimension. For a box
        # symmetric about zero this is exactly [lb, ub]. Unlike clipping a
        # velocity to the *position* bounds, it stays a symmetric range for
        # boxes not centred at 0 (with lb = 0, clipping to [lb, ub] would
        # forbid every negative velocity).
        v_max = (self.ub - self.lb) / 2

        if self.bound_method == "clip":
            population = jnp.clip(population, self.lb, self.ub)
            velocity = jnp.clip(velocity, -v_max, v_max)
        elif self.bound_method == "reflect":
            lower_bound_violation = population < self.lb
            upper_bound_violation = population > self.ub

            population = jnp.where(
                lower_bound_violation, 2 * self.lb - population, population
            )
            population = jnp.where(
                upper_bound_violation, 2 * self.ub - population, population
            )
            velocity = jnp.where(
                lower_bound_violation | upper_bound_violation, -velocity, velocity
            )
            # Enforce the bounds in case the reflected particles are still out
            # of bounds (e.g. when the overshoot exceeds the box width).
            population = jnp.clip(population, self.lb, self.ub)
            velocity = jnp.clip(velocity, -v_max, v_max)

        return state.replace(
            population=population,
            velocity=velocity,
            fitness=fitness,
            local_best_location=local_best_location,
            local_best_fitness=local_best_fitness,
            global_best_location=global_best_location,
            global_best_fitness=global_best_fitness,
            key=key,
        )


In [4]:
# A 100-particle swarm searching the 2-D box [-32, 32]^2, applied to Ackley.
pso = PSO(
    pop_size=100,
    lb=jnp.full((2,), -32),
    ub=jnp.full((2,), 32),
)
ackley = problems.numerical.Ackley()

In [5]:
key = jax.random.PRNGKey(42)
monitor = monitors.EvalMonitor()
workflow = workflows.StdWorkflow(pso, ackley, monitors=[monitor])
state = workflow.init(key)
# Place the freshly initialised state on the available devices; the
# resulting shardings are inspected in the next cell.
state = workflow.enable_multi_devices(state)

In [6]:
# Show how every array leaf of the workflow state is placed on the devices.
jax.tree_util.tree_map(lambda leaf: leaf.sharding, state)

State(StdWorkflowState(generation=NamedSharding(mesh=Mesh('POP': 2), spec=PartitionSpec(), memory_kind=device), first_step=True), {'algorithm': State(SpecialPSOState(population=NamedSharding(mesh=Mesh('POP': 2), spec=PartitionSpec('POP',), memory_kind=device), velocity=NamedSharding(mesh=Mesh('POP': 2), spec=PartitionSpec('POP',), memory_kind=device), fitness=NamedSharding(mesh=Mesh('POP': 2), spec=PartitionSpec('POP',), memory_kind=device), local_best_location=NamedSharding(mesh=Mesh('POP': 2), spec=PartitionSpec('POP',), memory_kind=device), local_best_fitness=NamedSharding(mesh=Mesh('POP': 2), spec=PartitionSpec('POP',), memory_kind=device), global_best_location=NamedSharding(mesh=Mesh('POP': 2), spec=PartitionSpec(), memory_kind=device), global_best_fitness=NamedSharding(mesh=Mesh('POP': 2), spec=PartitionSpec(), memory_kind=device), key=NamedSharding(mesh=Mesh('POP': 2), spec=PartitionSpec(), memory_kind=device)), {}),'monitors0': State(EvalMonitorState(first_step=True, latest_sol

In [7]:
# Advance the optimisation by 50 generations.
for _ in range(50):
    state = workflow.step(state)

In [8]:
# use_state adapts the monitor method so it can be called with the workflow
# state; the call returns the result together with a state value.
get_best_solution = use_state(monitor.get_best_solution)
best_solution, _state = get_best_solution(state)

In [9]:
# Best solution recorded by the monitor. Assuming the standard (unshifted)
# Ackley function, whose minimum is at the origin, values close to 0
# indicate convergence.
print(best_solution)

[ 0.0002041  -0.00019218]
