# How to use Multiple Devices

In this tutorial, we will see how to run DESC on multiple devices. This lets an optimization problem scale across the nodes of a computing cluster.

This tutorial cannot be run directly in a Jupyter Notebook, so we show the content of each script here and run the underlying Python script to show the results.

## Solving Equilibrium

In [1]:
import sys
import os

sys.path.insert(0, os.path.abspath("."))
sys.path.append(os.path.abspath("../../../"))

In [2]:
num_device = 4
from desc import set_device, _set_cpu_count

# These calls divide the single CPU into multiple virtual CPUs
# so that JAX and XLA think there are multiple devices.
# Note that this only tricks JAX. Since JAX can already use multiple cores and threads
# on a single CPU, this will not give a speedup. It is only meant for testing the code.
_set_cpu_count(num_device)
set_device("cpu", num_device=num_device)

In [None]:
from desc.backend import print_backend_info

print_backend_info()

DESC version=0.14.2+102.g199b09f73.dirty.
Using JAX backend: jax version=0.5.0, jaxlib version=0.5.0, dtype=float64.
Using 4 CPUs:
	 CPU 0: TFRT_CPU_0 with 8.26 GB available memory
	 CPU 1: TFRT_CPU_1 with 8.26 GB available memory
	 CPU 2: TFRT_CPU_2 with 8.26 GB available memory
	 CPU 3: TFRT_CPU_3 with 8.26 GB available memory
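
The four `TFRT_CPU_*` entries above are virtual devices on one physical CPU. A common way to get this behavior in plain JAX is the XLA flag `--xla_force_host_platform_device_count`, which has to be set before JAX initializes its backends. The sketch below only builds the flag string; whether `_set_cpu_count` does the same thing internally is an assumption, and `with_virtual_cpu_flag` is a hypothetical helper, not part of DESC or JAX.

```python
import os

FLAG = "--xla_force_host_platform_device_count"


def with_virtual_cpu_flag(existing_flags: str, num_device: int) -> str:
    """Return an XLA_FLAGS string that requests `num_device` virtual CPU devices.

    Hypothetical helper for illustration only.
    """
    # drop any earlier setting of the same flag, keep unrelated flags
    kept = [f for f in existing_flags.split() if not f.startswith(FLAG + "=")]
    kept.append(f"{FLAG}={num_device}")
    return " ".join(kept)


# Setting the variable before the first `import jax` is the safest ordering.
os.environ["XLA_FLAGS"] = with_virtual_cpu_flag(os.environ.get("XLA_FLAGS", ""), 4)
```

With this variable set, a subsequent `jax.devices("cpu")` call should list four CPU devices, similar to the backend info printed above.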


```python

import os
import sys

# Add the path to the parent directory to augment search for module
sys.path.insert(0, os.path.abspath("."))
sys.path.append(os.path.abspath("../../../"))

from desc import _set_cpu_count, set_device

# ====== Using CPUs ======
num_device = 4
# These calls divide the single CPU into multiple virtual CPUs
# so that JAX and XLA think there are multiple devices.
# If you have multiple CPUs, you don't need to call `_set_cpu_count`
_set_cpu_count(num_device)
set_device("cpu", num_device=num_device)

# ====== Using GPUs ======
# When we have multiple processes using the same devices (for example, 3 processes
# using 3 GPUs), each process will try to pre-allocate 75% of the GPU memory which will
# cause the memory allocation to fail. To avoid this, we can set the memory fraction
# to 1/(num_device + 2) which will allow each process to allocate 1/(num_device + 2) of
# the GPU memory. This is a bit conservative, but if a process needs more memory, it can
# allocate more memory on the fly.
#
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(1 / (num_device + 2))
# set_device("gpu", num_device=num_device)

from mpi4py import MPI

from desc import config as desc_config
from desc.backend import jax, print_backend_info
from desc.examples import get
from desc.objectives.getters import (
    get_fixed_boundary_constraints,
    get_parallel_forcebalance,
)

if __name__ == "__main__":
    rank = MPI.COMM_WORLD.Get_rank()
    size = MPI.COMM_WORLD.Get_size()
    if rank == 0:
        print(f"====== TOTAL OF {size} RANKS ======")

    # see which rank is running on which device
    # Note: JAX has 2 functions for this: `jax.devices()` and `jax.local_devices()`
    # `jax.devices()` will return all devices available to JAX, while `jax.local_devices()`
    # will return only the devices that are available to the current process. This is
    # useful when you have multiple processes running on multiple nodes and you want
    # to see which devices are available to each process.
    if desc_config["kind"] == "gpu":
        print(
            f"Rank {rank} is running on {jax.local_devices(backend='gpu')} "
            f"and {jax.local_devices(backend='cpu')}\n"
        )
    else:
        print(f"Rank {rank} is running on {jax.local_devices(backend='cpu')}\n")
    if rank == 0:
        print("====== BACKEND INFO ======")
        print_backend_info()
        print("\n")

    eq = get("HELIOTRON")
    eq.change_resolution(M=3, N=2, M_grid=6, N_grid=4)

    # this will create a parallel objective function
    # user can create their own parallel objective function as well which will be
    # shown in the next example
    obj = get_parallel_forcebalance(eq, num_device=num_device, mpi=MPI, verbose=1)
    cons = get_fixed_boundary_constraints(eq)

    # Up to this line, the code runs on all ranks, so some information may be
    # printed multiple times. The following part will only be performed on the
    # master rank

    # this context manager will put the workers in a loop to listen to the master
    # to compute the objective function and its derivatives
    with obj as obj:
        # apart from cost evaluation and derivatives, everything else will be only
        # performed on the master rank
        if rank == 0:
            eq.solve(
                objective=obj,
                constraints=cons,
                maxiter=3,
                ftol=0,
                gtol=0,
                xtol=0,
                verbose=3,
            )

    # any code placed here will run on all ranks



```
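
The comments in the script above say that the context manager puts the workers in a loop where they listen to the master. The following is a minimal sketch of that pattern, not DESC's implementation: threads and queues stand in for MPI ranks and messages, the master is the main thread, and `WorkerPool`, `worker` and the `"STOP"` message are hypothetical names.

```python
import queue
import threading


def worker(rank, commands, results):
    """Wait for commands from the master until told to stop."""
    while True:
        message = commands.get()  # blocks until the master sends something
        if message == "STOP":
            break
        # stand-in for evaluating one objective on this worker
        results.put((rank, (rank + 1) * message))


class WorkerPool:
    """Hypothetical stand-in for a parallel objective used as a context manager."""

    def __init__(self, num_workers):
        self.results = queue.Queue()
        self.commands = [queue.Queue() for _ in range(num_workers)]
        self.threads = [
            threading.Thread(target=worker, args=(r, q, self.results))
            for r, q in enumerate(self.commands)
        ]

    def __enter__(self):
        for t in self.threads:
            t.start()  # workers begin listening
        return self

    def compute(self, x):
        for q in self.commands:
            q.put(x)  # send the same input to every worker
        parts = dict(self.results.get() for _ in self.commands)
        return [parts[r] for r in sorted(parts)]  # ordered by rank

    def __exit__(self, *exc):
        for q in self.commands:
            q.put("STOP")  # release the workers from their loop
        for t in self.threads:
            t.join()
        return False


with WorkerPool(num_workers=3) as pool:
    values = pool.compute(2.0)
print(values)  # [2.0, 4.0, 6.0]
```

Leaving the `with` block is what stops the workers, which is why code placed after it runs normally on every rank again.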

In [4]:
!mpirun -n 4 python mpi-tutorials/mpi-eq-solve.py

  pid, fd = os.forkpty()


Rank 0 is running on [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3)]

DESC version=0.14.2+102.g199b09f73.dirty.
Using JAX backend: jax version=0.5.0, jaxlib version=0.5.0, dtype=float64.
Using 4 CPUs:
	 CPU 0: TFRT_CPU_0 with 8.15 GB available memory
	 CPU 1: TFRT_CPU_1 with 8.15 GB available memory
	 CPU 2: TFRT_CPU_2 with 8.15 GB available memory
	 CPU 3: TFRT_CPU_3 with 8.15 GB available memory


Rank 1 is running on [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3)]

Building objective: force
Precomputing transforms
Building objective: force
Precomputing transforms
Rank 2 is running on [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3)]

Rank 3 is running on [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3)]

Building objective: force
Precomputing transforms
Putting objective force on device 1
Building objective: force
Precomputing transforms
Building objective: force
Precomputing transforms
Putting objective

## Using other Objectives
Above, we used the convenience function for the force balance objective, but the same approach works with other objectives.

```python
import os
import sys

# Add the path to the parent directory to augment search for module
sys.path.insert(0, os.path.abspath("."))
sys.path.append(os.path.abspath("../../../"))

from desc import _set_cpu_count, set_device

# ====== Using CPUs ======
num_device = 2
# These calls divide the single CPU into multiple virtual CPUs
# so that JAX and XLA think there are multiple devices.
# If you have multiple CPUs, you don't need to call `_set_cpu_count`
_set_cpu_count(num_device)
set_device("cpu", num_device=num_device)

# ====== Using GPUs ======
# When we have multiple processes using the same devices (for example, 3 processes
# using 3 GPUs), each process will try to pre-allocate 75% of the GPU memory which will
# cause the memory allocation to fail. To avoid this, we can set the memory fraction
# to 1/(num_device + 2) which will allow each process to allocate 1/(num_device + 2) of
# the GPU memory. This is a bit conservative, but if a process needs more memory, it can
# allocate more memory on the fly.
#
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(1 / (num_device + 2))
# set_device("gpu", num_device=num_device)


import numpy as np
from mpi4py import MPI

from desc import config as desc_config
from desc.backend import jax, jnp, print_backend_info
from desc.examples import get
from desc.grid import LinearGrid
from desc.objectives import (
    AspectRatio,
    FixBoundaryR,
    FixBoundaryZ,
    FixCurrent,
    FixPressure,
    FixPsi,
    ForceBalance,
    ObjectiveFunction,
    QuasisymmetryTwoTerm,
)
from desc.optimize import Optimizer

if __name__ == "__main__":
    rank = MPI.COMM_WORLD.Get_rank()
    size = MPI.COMM_WORLD.Get_size()
    if rank == 0:
        print(f"====== TOTAL OF {size} RANKS ======")

    # see which rank is running on which device
    # Note: JAX has 2 functions for this: `jax.devices()` and `jax.local_devices()`
    # `jax.devices()` will return all devices available to JAX, while `jax.local_devices()`
    # will return only the devices that are available to the current process. This is
    # useful when you have multiple processes running on multiple nodes and you want
    # to see which devices are available to each process.
    if desc_config["kind"] == "gpu":
        print(
            f"Rank {rank} is running on {jax.local_devices(backend='gpu')} "
            f"and {jax.local_devices(backend='cpu')}\n"
        )
    else:
        print(f"Rank {rank} is running on {jax.local_devices(backend='cpu')}\n")

    if rank == 0:
        print("====== BACKEND INFO ======")
        print_backend_info()
        print("\n")

    eq = get("precise_QA")
    eq.change_resolution(M=3, N=2, M_grid=6, N_grid=4)

    # create two grids with different rho values, this will effectively separate
    # the quasisymmetry objective into two parts
    grid1 = LinearGrid(
        M=eq.M_grid, N=eq.N_grid, NFP=eq.NFP, rho=jnp.linspace(0.2, 0.5, 4), sym=True
    )
    grid2 = LinearGrid(
        M=eq.M_grid, N=eq.N_grid, NFP=eq.NFP, rho=jnp.linspace(0.6, 1.0, 6), sym=True
    )

    # when using parallel objectives, the user needs to supply the device_id
    obj1 = QuasisymmetryTwoTerm(eq=eq, helicity=(1, eq.NFP), grid=grid1, device_id=0)
    obj2 = QuasisymmetryTwoTerm(eq=eq, helicity=(1, eq.NFP), grid=grid2, device_id=1)
    obj3 = AspectRatio(eq=eq, target=8, weight=100, device_id=0)
    objs = [obj1, obj2, obj3]

    # Parallel objective function needs the MPI communicator
    # If you don't specify `deriv_mode=blocked`, you will get a warning and DESC will
    # automatically switch to `blocked`.
    objective = ObjectiveFunction(
        objs, deriv_mode="blocked", mpi=MPI, rank_per_objective=np.array([0, 1, 0])
    )
    if rank == 0:
        objective.build(verbose=3)
    else:
        objective.build(verbose=0)

    # we will fix some modes as usual
    k = 1
    R_modes = np.vstack(
        (
            [0, 0, 0],
            eq.surface.R_basis.modes[
                np.max(np.abs(eq.surface.R_basis.modes), 1) > k, :
            ],
        )
    )
    Z_modes = eq.surface.Z_basis.modes[
        np.max(np.abs(eq.surface.Z_basis.modes), 1) > k, :
    ]
    constraints = (
        ForceBalance(eq=eq),
        FixBoundaryR(eq=eq, modes=R_modes),
        FixBoundaryZ(eq=eq, modes=Z_modes),
        FixPressure(eq=eq),
        FixPsi(eq=eq),
        FixCurrent(eq=eq),
    )
    optimizer = Optimizer("proximal-lsq-exact")

    # Up to this line, the code runs on all ranks, so some information may be
    # printed multiple times. The following part will only be performed on the
    # master rank

    # this context manager will put the workers in a loop to listen to the master
    # to compute the objective function and its derivatives
    with objective as objective:
        # apart from cost evaluation and derivatives, everything else will be only
        # performed on the master rank
        if rank == 0:
            eq.optimize(
                objective=objective,
                constraints=constraints,
                optimizer=optimizer,
                maxiter=3,
                verbose=3,
                options={
                    "initial_trust_ratio": 1.0,
                },
            )

    # any code placed here will run on all ranks
```

In [5]:
!mpirun -n 2 python mpi-tutorials/mpi-proximal.py

Rank 0 is running on [CpuDevice(id=0), CpuDevice(id=1)]

DESC version=0.14.2+102.g199b09f73.dirty.
Using JAX backend: jax version=0.5.0, jaxlib version=0.5.0, dtype=float64.
Using 2 CPUs:
	 CPU 0: TFRT_CPU_0 with 8.13 GB available memory
	 CPU 1: TFRT_CPU_1 with 8.13 GB available memory


Rank 1 is running on [CpuDevice(id=0), CpuDevice(id=1)]

Building objective: QS two-term
Precomputing transforms
Timer: Precomputing transforms = 1.54 sec
Building objective: QS two-term
Precomputing transforms
Timer: Precomputing transforms = 1.37 sec
Putting objective QS two-term on device 1
Building objective: aspect ratio
Precomputing transforms
Timer: Precomputing transforms = 1.37 sec
------------------------------------------------------------
Rank 0 will run objective(s): ['QuasisymmetryTwoTerm', 'AspectRatio']
Rank 1 will run objective(s): ['QuasisymmetryTwoTerm']
------------------------------------------------------------
Timer: Objective build = 5.35 sec
Building objective: force
Precomput
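
The `Rank ... will run objective(s)` lines in the output follow directly from the `rank_per_objective` array passed to `ObjectiveFunction`. A minimal sketch of that grouping is shown below; `objectives_per_rank` is a hypothetical helper written for illustration and is not how DESC builds the assignment internally.

```python
from collections import defaultdict


def objectives_per_rank(names, rank_per_objective):
    """Group objective names by the rank assigned to each one."""
    if len(names) != len(rank_per_objective):
        raise ValueError("need exactly one rank per objective")
    groups = defaultdict(list)
    for name, rank in zip(names, rank_per_objective):
        groups[int(rank)].append(name)
    return dict(sorted(groups.items()))


# same order as objs = [obj1, obj2, obj3] with rank_per_objective = [0, 1, 0]
names = ["QuasisymmetryTwoTerm", "QuasisymmetryTwoTerm", "AspectRatio"]
assignment = objectives_per_rank(names, [0, 1, 0])
for rank, objs in assignment.items():
    print(f"Rank {rank} will run objective(s): {objs}")
```

Here rank 0 receives the first quasisymmetry objective and the aspect ratio objective, while rank 1 receives only the second quasisymmetry objective.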

## Using Slurm for Multi-Node and Multi-Process Scripts

**Note:** These instructions may differ for the cluster you are trying to use. We give this example mainly to establish some terminology for users who are not familiar with multi-node and multi-process jobs.

**Note:** For more details, see the Princeton University Research Computing page [here](https://researchcomputing.princeton.edu/support/knowledge-base/slurm#Multinode--Multithreaded-Jobs).

You need a suitable Slurm script to run parallel code on a cluster. Here, we give an example that uses 2 nodes, 8 processes per node, and 4 CPU cores per process. A *node* is one physical machine in the cluster, so you can think of this job as running on 2 computers that are connected to each other. In total, we will have 16 processes and 64 CPU cores. Additionally, you can specify the number of GPUs per node.

```bash

#!/bin/bash
#SBATCH --job-name=mpi-example        # create a short name for your job
#SBATCH --nodes=2                # node count
#SBATCH --ntasks-per-node=8      # total number of tasks per node
#SBATCH --cpus-per-task=4        # cpu-cores per task (>1 if multi-threaded tasks)
#SBATCH --mem-per-cpu=4G         # memory per cpu-core (4G is default)
#SBATCH --time=00:10:00          # total run time limit (HH:MM:SS)
#SBATCH --gres=gpu:4             # number of GPUs per node (in this case 8 GPUs in total)

export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK
export SRUN_CPUS_PER_TASK=$SLURM_CPUS_PER_TASK
module purge
module load intel/2022.2.0
module load intel-mpi/intel/2021.7.0

srun python your-script.py

```
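
The totals implied by the `#SBATCH` directives above can be checked with a few lines of Python. This is only an arithmetic sketch; `slurm_totals` is a hypothetical helper, not a Slurm or DESC function.

```python
def slurm_totals(nodes, tasks_per_node, cpus_per_task, gpus_per_node=0):
    """Return the total processes, CPU cores and GPUs requested by a job."""
    processes = nodes * tasks_per_node
    return {
        "processes": processes,
        "cpu_cores": processes * cpus_per_task,
        "gpus": nodes * gpus_per_node,
    }


# values taken from the example script: --nodes=2, --ntasks-per-node=8,
# --cpus-per-task=4, --gres=gpu:4
totals = slurm_totals(nodes=2, tasks_per_node=8, cpus_per_task=4, gpus_per_node=4)
print(totals)  # {'processes': 16, 'cpu_cores': 64, 'gpus': 8}
```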

When using MPI across multiple nodes, each process sees only the CPU of the node it runs on and, if you requested GPUs, only the GPUs attached to that node. With this in mind, if you want to use, for example, 2 nodes with 3 GPUs and 3 processes per node, you can distribute 6 objectives in this way.

```python

# each node will see 3 GPUs
num_device = 3
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(1 / (num_device + 2))
set_device("gpu", num_device=num_device)


...


# this will run on node 1, GPU 0 (rank=0)
obj1 = QuasisymmetryTwoTerm(eq=eq, helicity=(1, eq.NFP), grid=grid1, device_id=0)
# this will run on node 1, GPU 1 (rank=1)
obj2 = QuasisymmetryTwoTerm(eq=eq, helicity=(1, eq.NFP), grid=grid2, device_id=1)
# this will run on node 1, GPU 2 (rank=2)
obj3 = QuasisymmetryTwoTerm(eq=eq, helicity=(1, eq.NFP), grid=grid3, device_id=2)
# this will run on node 2, GPU 0 (rank=3)
obj4 = AspectRatio(eq=eq, target=8, weight=100, device_id=0)
# this will run on node 2, GPU 1 (rank=4)
obj5 = Objective(..., device_id=1)
# this will run on node 2, GPU 2 (rank=5)
obj6 = Objective(..., device_id=2)
objs = [obj1, obj2, obj3, obj4, obj5, obj6]

# Parallel objective function needs the MPI communicator
objective = ObjectiveFunction(objs, deriv_mode="blocked", mpi=MPI)

```
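
The rank, node and `device_id` pairing used in the comments above can be written as a small calculation. The sketch below assumes that ranks fill one node before moving to the next, which is a common default but depends on the launcher settings of your cluster. `placement` and `mem_fraction` are hypothetical helpers for illustration.

```python
def placement(rank, procs_per_node):
    """Return (node_index, local_device_id), both zero-based, for a global rank.

    Assumes ranks are assigned node by node (block distribution).
    """
    return divmod(rank, procs_per_node)


def mem_fraction(num_device):
    """Memory fraction used above for XLA_PYTHON_CLIENT_MEM_FRACTION."""
    return 1 / (num_device + 2)


for rank in range(6):
    node, device_id = placement(rank, procs_per_node=3)
    # nodes are printed one-based to match the comments in the listing
    print(f"rank {rank}: node {node + 1}, GPU {device_id}")

print(mem_fraction(3))  # 0.2
```

With 3 devices per node, each process is limited to one fifth of the GPU memory up front, which leaves headroom when several processes share the same GPUs.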