# How to use Multiple Devices

In this tutorial, we show how to run DESC on multiple devices, which lets optimization problems scale to computing clusters.

This tutorial cannot be run directly in a Jupyter Notebook, so we display the content of each script here and run the underlying Python script to show the results.

## Solving Equilibrium

In [1]:
import sys
import os

sys.path.insert(0, os.path.abspath("."))
sys.path.append(os.path.abspath("../../../"))

from IPython.display import Markdown

In [2]:
# Display the content of mpi-eq-solve.py
with open("mpi-tutorials/mpi-eq-solve.py", "r") as f:
    code = f.read()

Markdown(f"```python\n{code}\n```")

```python
import os
import sys

# Add the path to the parent directory to augment search for module
sys.path.insert(0, os.path.abspath("."))
sys.path.append(os.path.abspath("../../../"))
sys.path.append(os.path.abspath("../../../../"))

import numpy as np
from mpi4py import MPI

from desc import _set_cpu_count, set_device

kind = "cpu"  # or "gpu"
num_device = 2
# ====== Using CPUs ======
# These calls divide the single CPU into multiple virtual CPUs
# so that JAX and XLA think there are multiple devices
if kind == "cpu":
    # !!! If you have multiple CPUs, you shouldn't call `_set_cpu_count` !!!
    _set_cpu_count(num_device)
    set_device("cpu", num_device=num_device, mpi=MPI)

# ====== Using GPUs ======
# When we have multiple processes using the same devices (for example, 3 processes
# using 3 GPUs), each process will try to pre-allocate 75% of the GPU memory which will
# cause the memory allocation to fail. To avoid this, we can set the allocator to `platform`
# such that there is no pre-allocation. This is a bit conservative (and probably there is room
# for improvement), but if a process needs more memory, it can use more memory on the fly.
elif kind == "gpu":
    os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
    set_device("gpu", num_device=num_device)

from desc import config as desc_config
from desc.backend import jax, print_backend_info
from desc.examples import get
from desc.grid import LinearGrid
from desc.objectives import ForceBalance, ObjectiveFunction
from desc.objectives.getters import get_fixed_boundary_constraints

if __name__ == "__main__":
    rank = MPI.COMM_WORLD.Get_rank()
    size = MPI.COMM_WORLD.Get_size()
    if rank == 0:
        print(f"====== TOTAL OF {size} RANKS ======")

    # see which rank is running on which device
    # Note: JAX has 2 functions for this: `jax.devices()` and `jax.local_devices()`
    # `jax.devices()` will return all devices available to JAX, while `jax.local_devices()`
    # will return only the devices that are available to the current process. This is
    # useful when you have multiple processes running on multiple nodes and you want
    # to see which devices are available to each process.
    if desc_config["kind"] == "gpu":
        print(
            f"Rank {rank} can see {jax.local_devices(backend='gpu')} "
            f"and {jax.local_devices(backend='cpu')}\n"
        )
    else:
        print(f"Rank {rank} can see {jax.local_devices(backend='cpu')}\n")

    if rank == 0:
        print("====== BACKEND INFO ======")
        print_backend_info()
        print("\n")

    eq = get("HELIOTRON")
    if desc_config["kind"] == "cpu":
        # for local testing use lower resolution
        eq.change_resolution(M=3, N=2, M_grid=6, N_grid=4)

    # setup 2 grids for 2 objectives covering different flux surfaces
    rhos = np.linspace(0.1, 1.0, eq.L_grid)
    grid1 = LinearGrid(
        rho=rhos[: rhos.size // 2],
        M=eq.M_grid,
        N=eq.N_grid,
        NFP=eq.NFP,
    )
    grid2 = LinearGrid(
        rho=rhos[rhos.size // 2 :],
        M=eq.M_grid,
        N=eq.N_grid,
        NFP=eq.NFP,
    )
    obj = ObjectiveFunction(
        [
            ForceBalance(eq, grid=grid1, device_id=0),
            ForceBalance(eq, grid=grid2, device_id=1),
        ],
        mpi=MPI,
        deriv_mode="blocked",
    )
    obj.build()
    cons = get_fixed_boundary_constraints(eq)

    # Up to this line, the code runs on all ranks, so some information may be
    # printed multiple times. The part below runs only on the
    # master rank

    # this context manager will put the workers in a loop to listen to the master
    # to compute the objective function and its derivatives
    with obj:
        # apart from cost evaluation and derivatives, everything else will be only
        # performed on the master rank
        if rank == 0:
            eq.solve(
                objective=obj,
                constraints=cons,
                maxiter=10,
                ftol=0,
                gtol=0,
                xtol=0,
                verbose=3,
            )

    # any code placed here runs on all ranks

```
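
The GPU branch of the script avoids pre-allocation failures by switching to the `platform` allocator. JAX also documents the `XLA_PYTHON_CLIENT_MEM_FRACTION` environment variable, which caps the fraction of GPU memory each process pre-allocates. The sketch below shows that alternative under stated assumptions: `gpu_memory_env` is a hypothetical helper (not part of DESC or JAX), the 0.75 default mirrors the 75% mentioned in the script comments, and the combination has not been checked against `set_device`. The variable only takes effect if it is set before JAX is imported.

```python
import os


def gpu_memory_env(processes_per_gpu, total_fraction=0.75):
    """Return XLA settings that split `total_fraction` of GPU memory evenly.

    Hypothetical helper for illustration; not part of DESC or JAX.
    """
    if processes_per_gpu < 1:
        raise ValueError("processes_per_gpu must be at least 1")
    fraction = total_fraction / processes_per_gpu
    return {"XLA_PYTHON_CLIENT_MEM_FRACTION": f"{fraction:.2f}"}


# 3 processes sharing each GPU, as in the example in the script comments.
settings = gpu_memory_env(processes_per_gpu=3)
os.environ.update(settings)  # must happen before `import jax`
print(settings)
```

Compared with the `platform` allocator, this keeps pre-allocation but gives each process a fixed budget, so a process that needs more than its share will fail rather than grow on the fly.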

In [3]:
!mpirun -n 2 python mpi-tutorials/mpi-eq-solve.py

Rank 0 can see [CpuDevice(id=0), CpuDevice(id=1)]

DESC version=0.15.0+189.g6c33270c1.dirty.
Using JAX backend: jax version=0.6.2, jaxlib version=0.6.2, dtype=float64.
Using 2 CPUs with 19.06 GB total available memory:
	 CPU : 0  13th Gen Intel(R) Core(TM) i9-13900HX
	 CPU : 1  13th Gen Intel(R) Core(TM) i9-13900HX

Note: The backend information assumes that the user has 1 process per CPU (node). Using multiple processes per CPU (node) is not the most efficient way to use MPI with purely CPUs.


Rank 1 can see [CpuDevice(id=0), CpuDevice(id=1)]

Building objective: force
Precomputing transforms
Building objective: force
Precomputing transforms
Building objective: force
Precomputing transforms
Building objective: force
Precomputing transforms
Putting objective force on device 1
------------------------------------------------------------
Rank 0 will run objective(s): ['ForceBalance']
Rank 1 will run objective(s): ['ForceBalance']
---------------------------------------------------------
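
The script splits the `rho` values in half by hand because it uses two devices. With more devices, the same idea generalizes to contiguous chunks, one per `device_id`. The following is a minimal sketch in plain Python; `split_contiguous` is a hypothetical helper, not a DESC function, and the sample `rhos` values are made up for illustration. Each chunk would then be passed as `rho=` to its own `LinearGrid`, as in the script.

```python
def split_contiguous(values, n_parts):
    """Split `values` into `n_parts` contiguous chunks of near-equal size."""
    if not 1 <= n_parts <= len(values):
        raise ValueError("need between 1 and len(values) parts")
    base, extra = divmod(len(values), n_parts)
    chunks, start = [], 0
    for part in range(n_parts):
        # The first `extra` chunks take one additional element.
        stop = start + base + (1 if part < extra else 0)
        chunks.append(values[start:stop])
        start = stop
    return chunks


rhos = [0.1, 0.25, 0.4, 0.55, 0.7, 0.85, 1.0]
for device_id, chunk in enumerate(split_contiguous(rhos, 3)):
    print(f"device_id={device_id}: rho={chunk}")
```

Unlike the `rhos.size // 2` split in the script, this helper puts the larger chunks first when the values do not divide evenly.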

## Using other Objectives
Above, we used MPI with the force balance objective, but it can also be used for general optimization.

**Note:** Currently, if the optimizer solves the equilibrium at each step, this equilibrium solve cannot use MPI.

In [4]:
# Display the content of mpi-proximal.py
with open("mpi-tutorials/mpi-proximal.py", "r") as f:
    code = f.read()

Markdown(f"```python\n{code}\n```")

```python
import os
import sys

# Add the path to the parent directory to augment search for module
sys.path.insert(0, os.path.abspath("."))
sys.path.append(os.path.abspath("../../../"))
sys.path.append(os.path.abspath("../../../../"))

from mpi4py import MPI

from desc import _set_cpu_count, set_device

kind = "cpu"  # or "gpu"
num_device = 2
# ====== Using CPUs ======
# These calls divide the single CPU into multiple virtual CPUs
# so that JAX and XLA think there are multiple devices
if kind == "cpu":
    # !!! If you have multiple CPUs, you shouldn't call `_set_cpu_count` !!!
    _set_cpu_count(num_device)
    set_device("cpu", num_device=num_device, mpi=MPI)

# ====== Using GPUs ======
# When we have multiple processes using the same devices (for example, 3 processes
# using 3 GPUs), each process will try to pre-allocate 75% of the GPU memory which will
# cause the memory allocation to fail. To avoid this, we can set the allocator to `platform`
# such that there is no pre-allocation. This is a bit conservative (and probably there is room
# for improvement), but if a process needs more memory, it can use more memory on the fly.
elif kind == "gpu":
    os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
    set_device("gpu", num_device=num_device)


import numpy as np

from desc import config as desc_config
from desc.backend import jax, jnp, print_backend_info
from desc.examples import get
from desc.grid import LinearGrid
from desc.objectives import (
    AspectRatio,
    FixBoundaryR,
    FixBoundaryZ,
    FixCurrent,
    FixPressure,
    FixPsi,
    ForceBalance,
    ObjectiveFunction,
    QuasisymmetryTwoTerm,
)
from desc.optimize import Optimizer

if __name__ == "__main__":
    rank = MPI.COMM_WORLD.Get_rank()
    size = MPI.COMM_WORLD.Get_size()
    if rank == 0:
        print(f"====== TOTAL OF {size} RANKS ======")

    # see which rank is running on which device
    # Note: JAX has 2 functions for this: `jax.devices()` and `jax.local_devices()`
    # `jax.devices()` will return all devices available to JAX, while `jax.local_devices()`
    # will return only the devices that are available to the current process. This is
    # useful when you have multiple processes running on multiple nodes and you want
    # to see which devices are available to each process.
    if desc_config["kind"] == "gpu":
        print(
            f"Rank {rank} is running on {jax.local_devices(backend='gpu')} "
            f"and {jax.local_devices(backend='cpu')}\n"
        )
    else:
        print(f"Rank {rank} is running on {jax.local_devices(backend='cpu')}\n")

    if rank == 0:
        print("====== BACKEND INFO ======")
        print_backend_info()
        print("\n")

    eq = get("precise_QA")
    if desc_config["kind"] == "cpu":
        eq.change_resolution(M=3, N=2, M_grid=6, N_grid=4)

    # create two grids with different rho values, this will effectively separate
    # the quasisymmetry objective into two parts
    grid1 = LinearGrid(
        M=eq.M_grid,
        N=eq.N_grid,
        NFP=eq.NFP,
        rho=jnp.linspace(0.2, 0.5, 4),
        sym=True,
    )
    grid2 = LinearGrid(
        M=eq.M_grid,
        N=eq.N_grid,
        NFP=eq.NFP,
        rho=jnp.linspace(0.6, 1.0, 6),
        sym=True,
    )

    # when using parallel objectives, the user needs to supply the device_id
    obj1 = QuasisymmetryTwoTerm(eq=eq, helicity=(1, eq.NFP), grid=grid1, device_id=0)
    obj2 = QuasisymmetryTwoTerm(eq=eq, helicity=(1, eq.NFP), grid=grid2, device_id=1)
    obj3 = AspectRatio(eq=eq, target=8, weight=100, device_id=0)
    objs = [obj1, obj2, obj3]

    # Parallel objective function needs the MPI communicator
    # If you don't specify `deriv_mode=blocked`, you will get a warning and DESC will
    # automatically switch to `blocked`.
    objective = ObjectiveFunction(
        objs, deriv_mode="blocked", mpi=MPI, rank_per_objective=np.array([0, 1, 0])
    )
    if rank == 0:
        objective.build(verbose=3)
    else:
        objective.build(verbose=0)

    # we will fix some modes as usual
    k = 1
    R_modes = np.vstack(
        (
            [0, 0, 0],
            eq.surface.R_basis.modes[
                np.max(np.abs(eq.surface.R_basis.modes), 1) > k, :
            ],
        )
    )
    Z_modes = eq.surface.Z_basis.modes[
        np.max(np.abs(eq.surface.Z_basis.modes), 1) > k, :
    ]
    constraints = (
        ForceBalance(eq=eq),
        FixBoundaryR(eq=eq, modes=R_modes),
        FixBoundaryZ(eq=eq, modes=Z_modes),
        FixPressure(eq=eq),
        FixPsi(eq=eq),
        FixCurrent(eq=eq),
    )
    optimizer = Optimizer("proximal-lsq-exact")

    # Up to this line, the code runs on all ranks, so some information may be
    # printed multiple times. The part below runs only on the
    # master rank

    # this context manager will put the workers in a loop to listen to the master
    # to compute the objective function and its derivatives
    with objective:
        # apart from cost evaluation and derivatives, everything else will be only
        # performed on the master rank
        if rank == 0:
            eq.optimize(
                objective=objective,
                constraints=constraints,
                optimizer=optimizer,
                maxiter=3,
                verbose=3,
                options={
                    "initial_trust_ratio": 1.0,
                },
            )

    # any code placed here runs on all ranks

```

In [5]:
!mpirun -n 2 python mpi-tutorials/mpi-proximal.py

Rank 1 is running on [CpuDevice(id=0), CpuDevice(id=1)]

Rank 0 is running on [CpuDevice(id=0), CpuDevice(id=1)]

DESC version=0.15.0+189.g6c33270c1.dirty.
Using JAX backend: jax version=0.6.2, jaxlib version=0.6.2, dtype=float64.
Using 2 CPUs with 19.37 GB total available memory:
	 CPU : 0  13th Gen Intel(R) Core(TM) i9-13900HX
	 CPU : 1  13th Gen Intel(R) Core(TM) i9-13900HX

Note: The backend information assumes that the user has 1 process per CPU (node). Using multiple processes per CPU (node) is not the most efficient way to use MPI with purely CPUs.


Building objective: QS two-term
Precomputing transforms
Timer: Precomputing transforms = 771 ms
Building objective: QS two-term
Precomputing transforms
Timer: Precomputing transforms = 656 ms
Putting objective QS two-term on device 1
Building objective: aspect ratio
Precomputing transforms
Timer: Precomputing transforms = 636 ms
------------------------------------------------------------
Rank 0 will run o
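
Both scripts wrap the solve in a `with` block on the objective, and their comments say that this puts the workers in a loop listening to the master. The toy example below illustrates that general pattern only. It uses a thread and queues instead of MPI ranks, the class `ToyWorkerPool` is hypothetical, and it makes no claim about how DESC implements the context manager internally.

```python
import queue
import threading


class ToyWorkerPool:
    """Toy stand-in for the `with objective:` pattern (threads, not MPI)."""

    def __init__(self, func):
        self._func = func
        self._requests = queue.Queue()
        self._results = queue.Queue()
        self._thread = threading.Thread(target=self._listen)

    def _listen(self):
        # Worker loop: block until the master sends work or a stop signal.
        while True:
            message = self._requests.get()
            if message is None:  # stop signal sent by __exit__
                break
            self._results.put(self._func(message))

    def __enter__(self):
        self._thread.start()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Leaving the block releases the worker from its loop.
        self._requests.put(None)
        self._thread.join()
        return False

    def compute(self, x):
        # Master side: send a request and wait for the worker's answer.
        self._requests.put(x)
        return self._results.get()


with ToyWorkerPool(lambda x: x**2) as pool:
    values = [pool.compute(x) for x in (1, 2, 3)]
print(values)
```

The point of the pattern is that the worker does nothing on its own inside the block; it only responds to requests, which is why the scripts guard the actual solve with `if rank == 0`.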

## Using Slurm for Multi-Node and Multi-Process Scripts

**Note:** These instructions may differ on the cluster you are using. We give this example to establish some terminology for users who are not familiar with multi-node and multi-process jobs.

**Note:** For more details, see the Princeton University Research Computing page [here](https://researchcomputing.princeton.edu/support/knowledge-base/slurm#Multinode--Multithreaded-Jobs).

You need a proper Slurm script to run parallel code on a cluster. Here, we give an example that uses 2 nodes, 8 processes per node, and 4 CPU cores per process. A *node* is one physical computer in the cluster, so we will have 2 computers connected to each other (in this tutorial, each node is treated as a single CPU). In total, we will have 16 processes and 64 CPU cores. Additionally, you can specify the number of GPUs per node.

```bash

#!/bin/bash
#SBATCH --job-name=mpi-example        # create a short name for your job
#SBATCH --nodes=2                # node count
#SBATCH --ntasks-per-node=8      # total number of tasks per node
#SBATCH --cpus-per-task=4        # cpu-cores per task (>1 if multi-threaded tasks)
#SBATCH --mem-per-cpu=4G         # memory per cpu-core (4G is default)
#SBATCH --time=00:10:00          # total run time limit (HH:MM:SS)
#SBATCH --gres=gpu:4             # number of GPUs per node (in this case 8 GPUs in total)

export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK
export SRUN_CPUS_PER_TASK=$SLURM_CPUS_PER_TASK
module purge

# module names and versions may differ between clusters
module load anaconda3/2024.6
module load openmpi/gcc/4.1.6

# activate the environment that has DESC requirements
# as well as proper mpi4py installation
conda activate mpi-env

srun python your-script.py

```
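
To double-check the totals implied by the `#SBATCH` lines, a script can read the environment variables that Slurm sets for a job. The sketch below is a hedged illustration: `describe_allocation` is a hypothetical helper, and it assumes `SLURM_JOB_NUM_NODES`, `SLURM_NTASKS_PER_NODE`, and `SLURM_CPUS_PER_TASK` are defined, which is normally the case only when the corresponding options are given, as they are above.

```python
import os


def describe_allocation(env=None):
    """Summarize a Slurm allocation from its environment variables."""
    env = os.environ if env is None else env
    nodes = int(env.get("SLURM_JOB_NUM_NODES", 1))
    tasks_per_node = int(env.get("SLURM_NTASKS_PER_NODE", 1))
    cpus_per_task = int(env.get("SLURM_CPUS_PER_TASK", 1))
    processes = nodes * tasks_per_node
    return {
        "nodes": nodes,
        "processes": processes,
        "cpu_cores": processes * cpus_per_task,
    }


# Values matching the #SBATCH lines above; inside a real job, call
# describe_allocation() with no argument to read them from the environment.
example_env = {
    "SLURM_JOB_NUM_NODES": "2",
    "SLURM_NTASKS_PER_NODE": "8",
    "SLURM_CPUS_PER_TASK": "4",
}
summary = describe_allocation(example_env)
print(summary)
```

With the example values, this reproduces the 16 processes and 64 CPU cores quoted in the text.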

When using MPI with multiple nodes, each process sees 1 CPU (with multiple cores), and if you requested GPUs, only the GPUs connected to that CPU are visible to your program. With this in mind, if you want to use, for example, 2 nodes with 3 GPUs and 3 processes per node, you can use 6 objectives as follows.

```python

# each node will see 3 GPUs
num_device = 3
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
set_device("gpu", num_device=num_device)


...


# this will run on node 1, GPU 0 (rank=0)
obj1 = QuasisymmetryTwoTerm(eq=eq, helicity=(1, eq.NFP), grid=grid1, device_id=0)
# this will run on node 1, GPU 1 (rank=1)
obj2 = QuasisymmetryTwoTerm(eq=eq, helicity=(1, eq.NFP), grid=grid2, device_id=1)
# this will run on node 1, GPU 2 (rank=2)
obj3 = QuasisymmetryTwoTerm(eq=eq, helicity=(1, eq.NFP), grid=grid3, device_id=2)
# this will run on node 2, GPU 0 (rank=3)
obj4 = AspectRatio(eq=eq, target=8, weight=100, device_id=0)
# this will run on node 2, GPU 1 (rank=4)
obj5 = Objective(..., device_id=1)
# this will run on node 2, GPU 2 (rank=5)
obj6 = Objective(..., device_id=2)
objs = [obj1, obj2, obj3, obj4, obj5, obj6]

# Parallel objective function needs the MPI communicator
objective = ObjectiveFunction(objs, deriv_mode="blocked", mpi=MPI)

```

When you write your script for multiple nodes, select the number of devices and the device IDs as if there were only 1 node and only the local GPUs were visible. Other nodes are reached through the `rank` of the MPI communicator.
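
The relation between a global MPI rank and the local device it uses can be written out explicitly. This is a minimal sketch that assumes ranks are assigned to nodes in contiguous blocks and that each process on a node uses a different local device; the actual placement depends on how your cluster and launcher distribute tasks. `locate_rank` is a hypothetical helper, not part of DESC or mpi4py.

```python
def locate_rank(rank, processes_per_node):
    """Return (node_index, local_device_id) for an MPI rank.

    Assumes ranks are assigned to nodes in contiguous blocks and that each
    process on a node uses a different local device.
    """
    return divmod(rank, processes_per_node)


for rank in range(6):
    node, device_id = locate_rank(rank, processes_per_node=3)
    # Nodes are numbered from 1 here to match the comments in the example.
    print(f"rank={rank} -> node {node + 1}, device_id={device_id}")
```

Under these assumptions, ranks 0 to 2 land on node 1 and ranks 3 to 5 on node 2, each with `device_id` values 0, 1, and 2, which matches the comments in the example above.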

**Note: Most clusters have multiple GPUs connected to each node, so before using multiple nodes, use all the GPUs available on a single node. Multi-node communication is significantly slower, and a single-node script is easier to write correctly.**

Note: In this example, you should have at least 6 objectives, that is, at least 1 objective per device. If you want to run multiple objectives on the same device, you can specify `rank_per_objective` in the `ObjectiveFunction` keywords. By default, the initializer assigns a different rank to each sub-objective.
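
When there are more objectives than ranks, one simple way to build a `rank_per_objective` assignment is round-robin. The helper below is a hypothetical sketch, not a DESC function; it returns a plain list, which would be wrapped in `np.array` as in the proximal script earlier in this tutorial. Each objective's `device_id` still has to be set separately.

```python
def round_robin_ranks(n_objectives, n_ranks):
    """Assign objective i to rank i % n_ranks."""
    if n_ranks < 1:
        raise ValueError("n_ranks must be at least 1")
    return [i % n_ranks for i in range(n_objectives)]


# 3 objectives on 2 ranks, the same sizes as in the proximal script.
ranks = round_robin_ranks(n_objectives=3, n_ranks=2)
print(ranks)  # [0, 1, 0]
```

Round-robin ignores how expensive each objective is, so for unevenly sized objectives a hand-written assignment may balance the work better.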