# JAX and BACKEND

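```python
import sys
import os

# add local directories to the import path (presumably so the DESC source
# tree can be imported without installing it; paths are relative to this notebook)
sys.path.insert(0, os.path.abspath("."))
sys.path.append(os.path.abspath("../../../"))
```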

DESC uses JAX to speed up execution with just-in-time (JIT) compilation and to access automatic differentiation and other scientific computing tools.
The purpose of ``backend.py`` is to determine whether DESC can take advantage of JAX and GPUs or must fall back to standard ``numpy`` on CPUs.
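As a minimal illustration of the JIT compilation and automatic differentiation mentioned above, the sketch below uses plain JAX (not DESC) on a toy function; ``sum_of_squares`` is a hypothetical example name. ``jax.default_backend()`` reports which platform (for example ``"cpu"`` or ``"gpu"``) JAX will run on.

```python
import jax
import jax.numpy as jnp


def sum_of_squares(x):
    # toy scalar function f(x) = sum_i x_i**2 (hypothetical example)
    return jnp.sum(x**2)


# jax.jit traces the function for each new input shape/dtype and compiles it with XLA
fast_sum_of_squares = jax.jit(sum_of_squares)

# jax.grad builds a new function that computes df/dx = 2 * x
grad_sum_of_squares = jax.jit(jax.grad(sum_of_squares))

x = jnp.array([1.0, 2.0, 3.0])
print(fast_sum_of_squares(x))  # 14.0
print(grad_sum_of_squares(x))  # [2. 4. 6.]

# platform JAX dispatches to, e.g. "cpu" or "gpu"
print(jax.default_backend())
```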

JAX provides a ``numpy`` style API for array operations.
In many cases, to take advantage of JAX, one only needs to replace calls to ``numpy`` with calls to ``jax.numpy``.
A convenient way to do this is with the import statement ``import jax.numpy as jnp``.

```python
from desc.backend import jax, jnp
import numpy as np
```

```python
# compare array creation with jnp and numpy
zeros_jnp = jnp.zeros(4)
zeros_np = np.zeros(4)

print(zeros_jnp)
print(zeros_np)
```

```
[0. 0. 0. 0.]
[0. 0. 0. 0.]
```
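The printed values match, but the two arrays are not the same type. A short check with plain ``jax.numpy``, assuming JAX's default 32-bit configuration (DESC may enable 64-bit mode when it loads JAX), shows the difference and how ``np.asarray`` converts a JAX array back to a regular ``numpy`` array:

```python
import numpy as np
import jax.numpy as jnp

zeros_jnp = jnp.zeros(4)
zeros_np = np.zeros(4)

# numpy defaults to float64; JAX defaults to float32 unless 64-bit mode is enabled
print(zeros_np.dtype)
print(zeros_jnp.dtype)

# np.asarray returns a numpy ndarray (transferring data to host memory if needed)
back_to_np = np.asarray(zeros_jnp)
print(type(back_to_np))
```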


If such an import statement were used directly in DESC, running DESC on a machine without JAX installed would raise an ``ImportError``.
We would prefer that DESC still work on machines where JAX is not installed.
To that end, code that can benefit from JAX uses the import statement ``from desc.backend import jnp`` instead.
``desc.backend.jnp`` is an alias for ``jax.numpy`` if JAX is installed and for ``numpy`` otherwise.
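The aliasing idea can be sketched with a ``try``/``except`` import. This is a simplified illustration of the pattern, not the actual contents of ``desc/backend.py``; the names ``use_jax`` and ``normalize`` are hypothetical.

```python
try:
    # prefer JAX when it can be imported
    import jax.numpy as jnp

    use_jax = True
except ImportError:
    # fall back to numpy, which provides the same API for these calls
    import numpy as jnp

    use_jax = False


def normalize(v):
    # hypothetical helper written once against `jnp`; runs with either backend
    return v / jnp.linalg.norm(v)


print("using JAX" if use_jax else "using numpy")
print(normalize(jnp.array([3.0, 4.0])))
```

Code written against ``jnp`` in this way never needs to know which library it received, which is the property the ``desc.backend`` alias provides.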

While ``jax.numpy`` aims to be a drop-in replacement for ``numpy``, it imposes some constraints on how code is written.
For example, ``jax.numpy`` arrays are immutable, so in-place updates to array elements are not possible.
Instead, updating an element produces a new array containing the change (inside JIT-compiled code, XLA can often perform the update in place when the original array is no longer used).
Similarly, JAX's JIT compilation requires control flow that depends on array values, such as loops and conditionals, to be written with JAX primitives rather than plain Python ``for`` and ``if`` statements.
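To make the control-flow constraint concrete: inside a JIT-compiled function, array arguments are abstract tracers, so a Python ``if`` on their values fails. The sketch below uses plain JAX with hypothetical function names and expresses a conditional with ``jax.lax.cond`` and a loop with ``jax.lax.fori_loop`` instead.

```python
import jax
import jax.numpy as jnp


@jax.jit
def square_or_negate(x):
    # `if x > 0:` would raise an error here because x is a tracer under jit;
    # lax.cond(pred, true_fun, false_fun, operand) selects a branch at run time
    return jax.lax.cond(x > 0, lambda v: v**2, lambda v: -v, x)


@jax.jit
def sum_of_first_squares(n):
    # lax.fori_loop(lower, upper, body_fun, init_val) calls body_fun(i, acc)
    # for i = lower, ..., upper - 1; here n is a traced value
    return jax.lax.fori_loop(0, n, lambda i, acc: acc + i * i, jnp.int32(0))


print(square_or_negate(3.0))   # 9.0
print(square_or_negate(-2.0))  # 2.0
print(sum_of_first_squares(5))  # 30
```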

The utility functions in ``desc.backend`` provide a simple interface to these operations that works with either backend.

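```python
zeros_jnp = jnp.zeros(4)
# with the JAX backend, in-place assignment raises a TypeError
# because JAX arrays are immutable:
# zeros_jnp[0] = 1
# instead, use the .at[] property, which returns an updated copy
zeros_jnp = zeros_jnp.at[0].set(1)
print(zeros_jnp)
```

```
[1. 0. 0. 0.]
```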


```python
# desc.backend.put performs the same update and also works with the numpy backend
from desc.backend import put

zeros_jnp = put(zeros_jnp, 0, 2)
print(zeros_jnp)
```

```
[2. 0. 0. 0.]
```
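Besides ``set``, JAX's ``.at[]`` property offers other functional updates such as ``add`` and ``multiply``, and it accepts slices and integer index arrays. Each call returns a new array and leaves the original unchanged, as this plain-JAX sketch shows (it does not use ``desc.backend``).

```python
import jax.numpy as jnp

x = jnp.zeros(4)

y = x.at[1].add(5.0)  # new array with 5 added at index 1
z = y.at[1:].multiply(2.0)  # slices work too: scale every element from index 1 on
w = z.at[jnp.array([0, 3])].set(-1.0)  # an index array updates several elements

print(x)  # [0. 0. 0. 0.]  (the original is unchanged)
print(y)  # [0. 5. 0. 0.]
print(z)
print(w)
```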
