# How to use Multiple Devices

In this tutorial, we show how to run DESC on multiple devices, which makes it possible to scale optimization problems up to computing clusters.

This tutorial cannot run inside a Jupyter Notebook, because the scripts have to be launched with `mpirun`. We therefore show the content of each script here and run the underlying Python script from the notebook to show the results.

## Solving Equilibrium

```python
import sys
import os

sys.path.insert(0, os.path.abspath("."))
sys.path.append(os.path.abspath("../../../"))
```

```python
num_device = 4
from desc import set_device, _set_cpu_count

# These calls divide the single CPU into multiple virtual CPUs
# so that JAX and XLA think there are multiple devices.
# Note that this only tricks JAX: since JAX can already use multiple cores and
# threads on a single CPU, this will not give a speedup. It is only for testing the code.
_set_cpu_count(num_device)
set_device("cpu", num_device=num_device)
```
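The comments above describe splitting one CPU into several virtual devices. A common way to do this in JAX is the XLA flag `--xla_force_host_platform_device_count`. We assume, without having checked DESC's source, that `_set_cpu_count` does something similar; the helper `request_virtual_cpus` below is hypothetical and only sketches the idea using the standard library.

```python
import os

FLAG = "--xla_force_host_platform_device_count"


def request_virtual_cpus(n):
    """Hypothetical helper: ask XLA to expose the host CPU as ``n`` devices.

    The flag must be in XLA_FLAGS before JAX initializes its CPU backend,
    so call this before importing jax.
    """
    # keep any other XLA flags, but drop a previous device-count setting
    kept = [f for f in os.environ.get("XLA_FLAGS", "").split() if not f.startswith(FLAG)]
    kept.append(f"{FLAG}={n}")
    os.environ["XLA_FLAGS"] = " ".join(kept)
    return os.environ["XLA_FLAGS"]


request_virtual_cpus(4)
# import jax; jax.devices("cpu") would now list 4 CPU devices
```

As the comments in the cell point out, these virtual devices share the same physical cores, so this is useful for testing multi-device code paths rather than for speed.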

```python
from desc.backend import print_backend_info

print_backend_info()
```

```text
DESC version=0.13.0+1675.g3b4f847fd.dirty.
Using JAX backend: jax version=0.5.0, jaxlib version=0.5.0, dtype=float64.
Using 4 CPUs:
	 CPU 0: TFRT_CPU_0 with 7.15 GB available memory
	 CPU 1: TFRT_CPU_1 with 7.15 GB available memory
	 CPU 2: TFRT_CPU_2 with 7.15 GB available memory
	 CPU 3: TFRT_CPU_3 with 7.15 GB available memory
```


```python
# mpi-tutorials/mpi-eq-solve.py
import os
import sys

# Add the parent directories to the module search path
sys.path.insert(0, os.path.abspath("."))
sys.path.append(os.path.abspath("../../../"))

# These calls divide the single CPU into multiple virtual CPUs
# so that JAX and XLA think there are multiple devices
from desc import _set_cpu_count, set_device

num_device = 4
_set_cpu_count(num_device)
set_device("cpu", num_device=num_device)

from mpi4py import MPI

from desc.examples import get
from desc.objectives.getters import (
    get_fixed_boundary_constraints,
    get_parallel_forcebalance,
)

if __name__ == "__main__":
    rank = MPI.COMM_WORLD.Get_rank()
    eq = get("HELIOTRON")
    eq.change_resolution(6, 6, 6, 12, 12, 12)

    # This creates a parallel objective function. Users can also build their own
    # parallel objective function, as shown in the next example.
    obj = get_parallel_forcebalance(eq, num_device=num_device, mpi=MPI, verbose=1)
    cons = get_fixed_boundary_constraints(eq)

    # Until this line, the code is performed on all ranks, so it might print some
    # information multiple times. The following part will only be performed on the
    # master rank

    # this context manager will put the workers in a loop to listen to the master
    # to compute the objective function and its derivatives
    with obj as obj:
        # apart from evaluating the cost and its derivatives, everything else is
        # performed only on the master rank
        if rank == 0:
            eq.solve(
                objective=obj,
                constraints=cons,
                maxiter=1,
                ftol=0,
                gtol=0,
                xtol=0,
                verbose=3,
            )

    # any code placed here will run on all ranks

```
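The comments in the script describe a master/worker pattern: inside the `with` block, worker ranks wait for requests while rank 0 drives the solver, and code after the block runs everywhere again. The toy below mimics only that control flow, with threads standing in for MPI ranks and `queue.Queue` standing in for the communicator. It is not DESC's implementation; `ToyParallelObjective`, `compute`, and `run` are hypothetical names.

```python
import queue
import threading

STOP = "stop"  # sentinel the master sends to release the workers


class ToyParallelObjective:
    """Hypothetical stand-in for a parallel objective (not DESC's class)."""

    def __init__(self, rank, inboxes, results):
        self.rank = rank
        self.inboxes = inboxes  # {worker_rank: Queue} for master -> worker messages
        self.results = results  # shared Queue for worker -> master replies

    def _partial_cost(self, x):
        # each rank owns one block of the residual; here simply x * (rank + 1)
        return x * (self.rank + 1)

    def __enter__(self):
        if self.rank != 0:
            # workers stay in this loop, answering requests, until told to stop
            while True:
                msg = self.inboxes[self.rank].get()
                if msg == STOP:
                    break
                self.results.put(self._partial_cost(msg))
        return self

    def __exit__(self, exc_type, exc, tb):
        if self.rank == 0:
            # the master releases every worker when it leaves the block
            for inbox in self.inboxes.values():
                inbox.put(STOP)
        return False  # do not swallow exceptions

    def compute(self, x):
        # called on the master only: send x to all workers and sum the replies
        for inbox in self.inboxes.values():
            inbox.put(x)
        replies = [self.results.get() for _ in self.inboxes]
        return self._partial_cost(x) + sum(replies)


def run(num_ranks=3):
    inboxes = {r: queue.Queue() for r in range(1, num_ranks)}
    results = queue.Queue()
    out = {}

    def rank_main(rank):
        obj = ToyParallelObjective(rank, inboxes, results)
        with obj as obj:
            if rank == 0:
                # stands in for eq.solve(...), which only runs on the master
                out["costs"] = [obj.compute(x) for x in (1.0, 2.0)]
        # code here would run on every "rank"

    threads = [threading.Thread(target=rank_main, args=(r,)) for r in range(num_ranks)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return out["costs"]


print(run(3))  # [6.0, 12.0]
```

This is why every rank must enter the `with` block even though only rank 0 calls the solver: without the workers sitting in their loop, nobody would answer the master's requests.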

```shell
mpirun -n 4 python mpi-tutorials/mpi-eq-solve.py
```



```text
Precomputing transforms
Precomputing transforms
Precomputing transforms
Precomputing transforms
Precomputing transforms
Precomputing transforms
Precomputing transforms
Precomputing transforms
Building objective: lcfs R
Building objective: lcfs Z
Building objective: fixed Psi
Precomputing transforms
Building objective: fixed pressure
Precomputing transforms
Building objective: fixed iota
Building objective: fixed sheet current
Building objective: self_consistency R
Building objective: self_consistency Z
Building objective: lambda gauge
Building objective: axis R self consistency
Building objective: axis Z self consistency
Timer: Objective build = 1.53 sec
Precomputing transforms
Precomputing transforms
Precomputing transforms
Precomputing transforms
Precomputing transforms
Precomputing transforms
Timer: LinearConstraintProjection build = 4.31 sec
Number of parameters: 609
Number of objectives: 15000
Timer: Initializing the optimization = 5.90 sec

Starting optimization
Using method: lsq
```

## Using other Objectives
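Above, we used the convenience function `get_parallel_forcebalance` for the force balance objective, but other objectives can be parallelized with the same approach. This requires a few extra steps: each objective needs a `device_id`, must be built and moved to its device, and the objectives are then combined into an `ObjectiveFunction` that is given the MPI communicator.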
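One way to see why a least-squares objective can be split across devices: assuming the combined `ObjectiveFunction` stacks the weighted residuals $\mathbf{f}_i$ of the sub-objectives assigned to `device_id=i` (as in the script that follows, with $i = 0, 1, 2$), the cost, gradient, and Gauss-Newton matrix all decompose into per-device pieces:

$$
\mathbf{f} =
\begin{bmatrix} \mathbf{f}_0 \\ \mathbf{f}_1 \\ \mathbf{f}_2 \end{bmatrix},
\qquad
\tfrac{1}{2}\lVert \mathbf{f} \rVert^2 = \sum_{i} \tfrac{1}{2}\lVert \mathbf{f}_i \rVert^2,
\qquad
\mathbf{J} =
\begin{bmatrix} \mathbf{J}_0 \\ \mathbf{J}_1 \\ \mathbf{J}_2 \end{bmatrix},
\qquad
\mathbf{J}^\mathsf{T}\mathbf{f} = \sum_{i} \mathbf{J}_i^\mathsf{T}\mathbf{f}_i,
\qquad
\mathbf{J}^\mathsf{T}\mathbf{J} = \sum_{i} \mathbf{J}_i^\mathsf{T}\mathbf{J}_i,
$$

where $\mathbf{J}_i = \partial \mathbf{f}_i / \partial \mathbf{x}$. Each device therefore only needs to evaluate its own block $\mathbf{f}_i$ and $\mathbf{J}_i$, and the master rank assembles the results.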

```python
# mpi-tutorials/mpi-proximal.py
import os
import sys

# Add the parent directories to the module search path
sys.path.insert(0, os.path.abspath("."))
sys.path.append(os.path.abspath("../../../"))

# These calls divide the single CPU into multiple virtual CPUs
# so that JAX and XLA think there are multiple devices
from desc import _set_cpu_count, set_device

num_device = 3
_set_cpu_count(num_device)
set_device("cpu", num_device=num_device)

import numpy as np
from mpi4py import MPI

from desc.backend import jax, jnp
from desc.examples import get
from desc.grid import LinearGrid
from desc.objectives import (
    AspectRatio,
    FixBoundaryR,
    FixBoundaryZ,
    FixCurrent,
    FixPressure,
    FixPsi,
    ForceBalance,
    ObjectiveFunction,
    QuasisymmetryTwoTerm,
)
from desc.optimize import Optimizer

if __name__ == "__main__":
    rank = MPI.COMM_WORLD.Get_rank()

    eq = get("precise_QA")
    eq.change_resolution(3, 3, 3, 6, 6, 6)

    # create two grids with different rho values, this will effectively separate
    # the quasisymmetry objective into two parts
    grid1 = LinearGrid(
        M=eq.M_grid, N=eq.N_grid, NFP=eq.NFP, rho=jnp.linspace(0.2, 0.5, 4), sym=True
    )
    grid2 = LinearGrid(
        M=eq.M_grid, N=eq.N_grid, NFP=eq.NFP, rho=jnp.linspace(0.6, 1.0, 6), sym=True
    )

    # when using parallel objectives, the user needs to supply the device_id
    obj1 = QuasisymmetryTwoTerm(eq=eq, helicity=(1, eq.NFP), grid=grid1, device_id=0)
    obj2 = QuasisymmetryTwoTerm(eq=eq, helicity=(1, eq.NFP), grid=grid2, device_id=1)
    obj3 = AspectRatio(eq=eq, target=8, weight=100, device_id=2)

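    objs = [obj1, obj2, obj3]
    # Build each objective and move it to its assigned device.
    # This part will probably be automated in the future.
    for i, obji in enumerate(objs):
        obji.build(verbose=3)
        # jax.device_put returns a new object, so store it back in the list
        objs[i] = jax.device_put(obji, obji._device)
        # reset the equilibrium reference to the original eq object
        objs[i].things[0] = eq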

    # Parallel objective function needs the MPI communicator
    objective = ObjectiveFunction(objs, mpi=MPI)
    objective.build(verbose=3)

    # we will fix some modes as usual
    k = 1
    R_modes = np.vstack(
        (
            [0, 0, 0],
            eq.surface.R_basis.modes[
                np.max(np.abs(eq.surface.R_basis.modes), 1) > k, :
            ],
        )
    )
    Z_modes = eq.surface.Z_basis.modes[
        np.max(np.abs(eq.surface.Z_basis.modes), 1) > k, :
    ]
    constraints = (
        ForceBalance(eq=eq),
        FixBoundaryR(eq=eq, modes=R_modes),
        FixBoundaryZ(eq=eq, modes=Z_modes),
        FixPressure(eq=eq),
        FixPsi(eq=eq),
        FixCurrent(eq=eq),
    )
    optimizer = Optimizer("proximal-lsq-exact")

    # Until this line, the code is performed on all ranks, so it might print some
    # information multiple times. The following part will only be performed on the
    # master rank

    # this context manager will put the workers in a loop to listen to the master
    # to compute the objective function and its derivatives
    with objective as objective:
        # apart from evaluating the cost and its derivatives, everything else is
        # performed only on the master rank
        if rank == 0:
            eq.optimize(
                objective=objective,
                constraints=constraints,
                optimizer=optimizer,
                maxiter=1,
                verbose=3,
                options={
                    "initial_trust_ratio": 1.0,
                },
            )

    # any code placed here will run on all ranks

```
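The script above splits the quasisymmetry objective by hand into two radial grids (4 and 6 surfaces). With more devices, a small helper could produce one contiguous chunk of `rho` values per device, each of which could then be passed as `rho=` to its own `LinearGrid` and paired with a matching `device_id`. The helper `split_rho` is hypothetical and not part of DESC; it balances the number of surfaces per device, which does not necessarily balance the computational cost.

```python
def split_rho(rho, num_device):
    """Hypothetical helper: split radial points into contiguous chunks.

    Returns ``num_device`` lists whose lengths differ by at most one and
    whose concatenation is the original sequence.
    """
    if num_device < 1:
        raise ValueError("num_device must be at least 1")
    base, extra = divmod(len(rho), num_device)
    chunks, start = [], 0
    for i in range(num_device):
        size = base + (1 if i < extra else 0)  # first `extra` chunks get one more point
        chunks.append(list(rho[start : start + size]))
        start += size
    return chunks


rho = [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
chunks = split_rho(rho, 2)
# -> [[0.2, 0.3, 0.4, 0.5, 0.6], [0.7, 0.8, 0.9, 1.0]]
# chunks[0] would go to the objective with device_id=0, chunks[1] to device_id=1
```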

```shell
mpirun -n 3 python mpi-tutorials/mpi-proximal.py
```

Precomputing transforms
Precomputing transforms
Timer: Precomputing transforms = 1.37 sec
Timer: Precomputing transforms = 1.40 sec
Precomputing transforms
Precomputing transforms
Timer: Precomputing transforms = 1.25 sec
Precomputing transforms
Timer: Precomputing transforms = 1.27 sec
Precomputing transforms
Precomputing transforms
Timer: Precomputing transforms = 1.24 sec
Timer: Objective build = 21.9 ms
Timer: Precomputing transforms = 1.26 sec
Timer: Objective build = 22.0 ms
Building objective: force
Precomputing transforms
Timer: Precomputing transforms = 2.27 sec
Timer: Precomputing transforms = 1.57 sec
Precomputing transforms
Timer: Objective build = 2.37 sec
Timer: Objective build = 25.1 ms
Timer: Precomputing transforms = 2.02 sec
Precomputing transforms
Timer: Precomputing transforms = 2.01 sec
Timer: Objective build = 36.1 ms
Timer: Eq Update LinearConstraintProjection build = 5.33 sec
Timer: Proximal projection build = 9.29 sec
Building objective: lcfs R
Building objecti