# Multidevice Algorithm

This guide will show you how to write algorithms that can run on multiple devices (multiple GPUs) in EvoX.

In [None]:
import jax
import jax.numpy as jnp

from evox import dataclass, pytree_field, problems, workflows, monitors, algorithms, use_state
from evox.core.distributed import ShardingType
from evox.utils import *

In this example, we consider the following simple setup:
```
    Node1
      |
 +----+----+
 |         |
GPU       GPU
```
That is, we have a single node with multiple GPUs, which communicate with each other over PCIe or NVLink.
When running in a distributed setup, we need to make decisions on how to place the data on these GPUs.

Here, we use the vanilla PSO algorithm as an example. In PSO, each GPU can independently update the local (per-particle) information of its own particles. Updating the global information, on the other hand, requires communication between GPUs, but this can be done efficiently with an all-reduce operation.

Here is an illustration of the population in PSO. It has two dimensions: the number of particles (population size) and the problem dimension.
```
  Problem Dimension
+-------------------+
|                   |
|  GPU 0            |  Population Size
|  All particles    |
|                   |
+-------------------+
```
After sharding it across the population dimension, we have the following:
```
  Problem Dimension
+-------------------+
|                   |
|  GPU 0            |  Population Size / 2
+-------------------+
|                   |
|  GPU 1            |  Population Size / 2
+-------------------+
```
Similarly, we can shard the velocity and the other per-particle arrays (fitness, and each particle's best location and fitness).
With two GPUs, this halves the memory that these arrays occupy on each GPU.

In [3]:
# The only change compared with the base PSO state:
# add sharding metadata to the per-particle fields.
@dataclass
class SpecialPSOState:
    """PSO state whose per-particle arrays carry sharding metadata.

    Fields declared with ``pytree_field(sharding=ShardingType.SHARED_FIRST_DIM)``
    are split along their first (population) axis when the state is placed on
    multiple devices. Fields without sharding metadata are replicated on every
    device.
    """

    # Per-particle arrays: sharded along the first (population) axis.
    population: jax.Array = pytree_field(sharding=ShardingType.SHARED_FIRST_DIM)
    velocity: jax.Array = pytree_field(sharding=ShardingType.SHARED_FIRST_DIM)
    fitness: jax.Array = pytree_field(sharding=ShardingType.SHARED_FIRST_DIM)
    local_best_location: jax.Array = pytree_field(sharding=ShardingType.SHARED_FIRST_DIM)
    local_best_fitness: jax.Array = pytree_field(sharding=ShardingType.SHARED_FIRST_DIM)
    # Swarm-wide best: no sharding metadata, so replicated on every device.
    global_best_location: jax.Array
    global_best_fitness: jax.Array
    # PRNG key. Annotated as ``jax.Array`` because ``jax.random.PRNGKey`` is a
    # function that creates keys, not a type.
    key: jax.Array


# Subclass the stock PSO and override only `setup`, so that the initial state
# is built as a SpecialPSOState (the state type carrying sharding metadata).
@dataclass
class PSO(algorithms.PSO):
    def setup(self, key):
        """Build the initial swarm state.

        Positions are drawn uniformly from ``[lb, ub)`` and velocities
        uniformly from ``[-(ub - lb), ub - lb)``. All fitness entries start at
        +inf, and the global best location is initialized to the first
        particle.
        """
        k_state, k_pop, k_vel = jax.random.split(key, 3)
        span = self.ub - self.lb
        swarm_shape = (self.pop_size, self.dim)

        # Keep the arithmetic order (u * span + lb, u * span * 2 - span) so the
        # floating-point results match the base initialization exactly.
        positions = jax.random.uniform(k_pop, shape=swarm_shape) * span + self.lb
        velocities = jax.random.uniform(k_vel, shape=swarm_shape) * span * 2 - span

        return SpecialPSOState(
            population=positions,
            velocity=velocities,
            fitness=jnp.full((self.pop_size,), jnp.inf),
            local_best_location=positions,
            local_best_fitness=jnp.full((self.pop_size,), jnp.inf),
            global_best_location=positions[0],
            global_best_fitness=jnp.array([jnp.inf]),
            key=k_state,
        )

In [4]:
# Benchmark problem to optimize.
ackley = problems.numerical.Ackley()
# PSO with 100 particles searching the 2-D box [-32, 32] x [-32, 32].
pso = PSO(pop_size=100, lb=jnp.full((2,), -32), ub=jnp.full((2,), 32))

In [5]:
key = jax.random.PRNGKey(42)
# The evaluation monitor is queried later for the best solution.
monitor = monitors.EvalMonitor()
workflow = workflows.StdWorkflow(pso, ackley, monitors=[monitor])
# Initialize the state, then place it on all available devices; fields that
# carry sharding metadata are split across the devices.
state = workflow.enable_multi_devices(workflow.init(key))

In [6]:
# Inspect how every leaf of the state is placed: sharded fields report
# PartitionSpec('POP',), replicated fields report PartitionSpec().
jax.tree_util.tree_map(lambda leaf: leaf.sharding, state)

State(StdWorkflowState(generation=NamedSharding(mesh=Mesh('POP': 2), spec=PartitionSpec(), memory_kind=device), first_step=True), {'algorithm': State(SpecialPSOState(population=NamedSharding(mesh=Mesh('POP': 2), spec=PartitionSpec('POP',), memory_kind=device), velocity=NamedSharding(mesh=Mesh('POP': 2), spec=PartitionSpec('POP',), memory_kind=device), fitness=NamedSharding(mesh=Mesh('POP': 2), spec=PartitionSpec('POP',), memory_kind=device), local_best_location=NamedSharding(mesh=Mesh('POP': 2), spec=PartitionSpec('POP',), memory_kind=device), local_best_fitness=NamedSharding(mesh=Mesh('POP': 2), spec=PartitionSpec('POP',), memory_kind=device), global_best_location=NamedSharding(mesh=Mesh('POP': 2), spec=PartitionSpec(), memory_kind=device), global_best_fitness=NamedSharding(mesh=Mesh('POP': 2), spec=PartitionSpec(), memory_kind=device), key=NamedSharding(mesh=Mesh('POP': 2), spec=PartitionSpec(), memory_kind=device)), {}),'monitors0': State(EvalMonitorState(first_step=True, latest_sol

In [7]:
# Run the workflow for 50 steps; each `step` takes the current state and
# returns the updated one. The loop index is unused, so it is named `_step`
# (not `_`, which IPython reserves for the last cell output).
for _step in range(50):
    state = workflow.step(state)

In [8]:
# `use_state` makes the monitor method callable on the workflow state; it
# returns a (result, state) pair, and the returned state is not needed here.
best_solution, _state = use_state(
    monitor.get_best_solution
)(state)

In [9]:
# Show the best solution recorded by the monitor.
print(str(best_solution))

[ 0.0002041  -0.00019218]
