In [1]:
# import jax
# import jax.numpy as jnp

# jax.config.update("jax_compilation_cache_dir", "./jax-caches")
# jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
# jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

In [2]:
import sys
import os

# Make the notebook directory importable ahead of installed packages, and the
# parent directory (presumably the repository root -- TODO confirm) as a fallback.
# The original inserted the cwd twice and grew sys.path on every re-run; this
# version is idempotent. Import resolution is unchanged because only later
# duplicates are dropped and the first occurrence always wins.
_nb_dir = os.path.abspath(".")
_parent_dir = os.path.abspath("../")
# Mutate in place so the interpreter keeps referencing the same list object.
sys.path[:] = [_nb_dir] + [entry for entry in sys.path if entry != _nb_dir]
if _parent_dir not in sys.path:
    sys.path.append(_parent_dir)
del _nb_dir, _parent_dir

# Optional XLA/JAX memory and compiler settings, currently disabled. XLA reads
# these environment variables when JAX initializes its backend, so if re-enabled
# they must stay here, above the first import of desc/jax.
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.25"
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
# os.environ["XLA_FLAGS"] = (
#     "--xla_disable_hlo_passes=constant_folding "  # this disables constant folding
#     # "--xla_cpu_use_thunk_runtime=false "
# )
# Select the GPU backend for DESC. DESC expects set_device to be called before the
# rest of desc is imported (TODO confirm against the DESC docs for this version).
from desc import set_device
set_device("gpu")

In [3]:
# import jax
# import jax.numpy as jnp

# jax.config.update("jax_explain_cache_misses", True)

In [4]:
# from desc import set_device, _set_cpu_count

# num_device = 2
# _set_cpu_count(num_device)
# set_device("cpu", num_device=num_device)

In [5]:
import numpy as np
# Print arrays on one wide line, 4 decimals, fixed-point notation.
# NOTE: threshold=sys.maxsize disables numpy's "..." summarization, so displaying a
# large array (e.g. a 5000x5000 table) prints every element -- slice before display.
np.set_printoptions(linewidth=np.inf, precision=4, suppress=True, threshold=sys.maxsize)
import matplotlib.pyplot as plt
%matplotlib inline
import plotly.graph_objects as go
import functools
import scipy

In [6]:
import desc

from desc.basis import *
from desc.backend import *
from desc.compute import *
from desc.coils import *
from desc.equilibrium import *
from desc.examples import *
from desc.grid import *
from desc.geometry import *
from desc.io import *

from desc.objectives import *
from desc.objectives.objective_funs import *
from desc.objectives.getters import *
from desc.objectives.normalization import compute_scaling_factors
from desc.objectives.utils import *
from desc.optimize._constraint_wrappers import *

from desc.transform import Transform
from desc.plotting import *
from desc.optimize import *
from desc.perturbations import *
from desc.profiles import *
from desc.compat import *
from desc.utils import *
from desc.magnetic_fields import *
from desc.particles import *
from diffrax import *

from desc.__main__ import main
from desc.vmec_utils import vmec_boundary_subspace
from desc.input_reader import InputReader
from desc.continuation import solve_continuation_automatic
from desc.compute.data_index import register_compute_fun
from desc.optimize.utils import solve_triangular_regularized

# Record DESC/JAX versions and the active device: the timings below are only
# meaningful for this software/hardware combination.
print_backend_info()

DESC version=0.16.0+47.g50bd906d1.
Using JAX backend: jax version=0.7.2, jaxlib version=0.7.2, dtype=float64.
Using device: NVIDIA GeForce RTX 4080 Laptop GPU (id=0), with 10.78 GB available memory.


In [9]:
@jax.jit
def vf1(t, x, args):
    """Synthetic, deliberately expensive ODE right-hand side used only as a benchmark.

    With state ``x = [p0, p1, q0, q1]`` the result is ``[q0, q1, g0, g1]``.
    ``(g0, g1)`` come from 20 sequential update sweeps seeded with ``(p0, p1)``.
    The expression has no physical meaning.

    ``t`` and ``args`` are unused; they are present to match the diffrax
    ``ODETerm`` calling convention.
    """
    # Results are cast back to the input dtype each sweep, mirroring what writing
    # into a copy of ``x`` (an ``x.dtype`` array) would do.
    out_dtype = x.dtype
    g0, g1 = x[0], x[1]
    # A Python loop is unrolled at trace time: 20 copies of the body are compiled.
    for _ in range(20):
        denom = (g0 ** 2 + g1 ** 2 + 1e-6) ** (1.5)
        g0 = (jnp.sin(g0) / denom + jnp.exp(-g1 ** 2)).astype(out_dtype)
        # g1 deliberately uses the freshly updated g0 from the line above.
        g1 = (jnp.cos(g1) / denom + jnp.exp(-g0 ** 2)).astype(out_dtype)
    return jnp.array([x[2], x[3], g0, g1])


@jax.jit
def vf1_approx(t, x, args):
    """Table-lookup surrogate for ``vf1`` using the nearest node of a uniform grid.

    Parameters
    ----------
    t : float
        Time; unused, kept for the diffrax ``ODETerm`` signature.
    x : array, shape (4,)
        State; only ``x[0]`` and ``x[1]`` select the table entry.
    args : tuple
        ``(x_all, y_all, f1_all, f2_all, dx1, dx2)``:
        - ``x_all``/``y_all`` are uniformly spaced 1-D grid nodes.
        - ``f1_all[i, j]``/``f2_all[i, j]`` are the tabulated values at
          ``(x_all[i], y_all[j])``.
        - ``dx1``/``dx2`` are the grid spacings.

    Returns
    -------
    array, shape (4,)
        ``[x[2], x[3], f1, f2]``. States outside the grid use the closest edge node.
    """
    x_all, y_all, f1_all, f2_all, dx1, dx2 = args
    # Round (the original truncated, which picked the lower node rather than the
    # nearest one) and clip in float space *before* the integer cast, so states far
    # outside the grid cannot overflow the index. int32 comfortably covers any grid
    # that fits in memory and avoids the int64 warning when x64 is disabled.
    i1 = jnp.clip(jnp.round((x[0] - x_all[0]) / dx1), 0, len(x_all) - 1).astype(jnp.int32)
    i2 = jnp.clip(jnp.round((x[1] - y_all[0]) / dx2), 0, len(y_all) - 1).astype(jnp.int32)
    f1 = f1_all[i1, i2]
    f2 = f2_all[i1, i2]
    return jnp.array([x[2], x[3], f1, f2])


# Tabulate vf1's last two outputs on a uniform GRID_POINTS x GRID_POINTS grid over
# [-GRID_HALF_WIDTH, GRID_HALF_WIDTH]^2, so vf1_approx can replace vf1 by a lookup.
# vf1's last two outputs depend only on x[0], x[1], so velocities are set to zero.
GRID_HALF_WIDTH = 50.0
GRID_POINTS = 5000

x_all = jnp.linspace(-GRID_HALF_WIDTH, GRID_HALF_WIDTH, GRID_POINTS)
y_all = jnp.linspace(-GRID_HALF_WIDTH, GRID_HALF_WIDTH, GRID_POINTS)
dx1 = x_all[1] - x_all[0]
dx2 = y_all[1] - y_all[0]

# "ij" indexing + C-order ravel: flat index k <-> (i, j) = divmod(k, len(y_all)).
# The reshape below therefore gives tables indexed [x index, y index], which is the
# same ordering the previous "xy" meshgrid + Fortran-order flatten produced.
X_all, Y_all = jnp.meshgrid(x_all, y_all, indexing="ij")
flat_x = X_all.ravel()
flat_y = Y_all.ravel()
del X_all, Y_all
zero_velocity = jnp.zeros_like(flat_x)
grid_states = jnp.stack([flat_x, flat_y, zero_velocity, zero_velocity], axis=-1)
del flat_x, flat_y, zero_velocity

# Evaluate in batches of 1000 points to bound peak device memory. The result is
# named grid_rhs; the previous name `all` shadowed the Python builtin.
grid_rhs = jax.lax.map(lambda y: vf1(0, y, None), grid_states, batch_size=1000)
del grid_states
F1_all = grid_rhs[:, 2].reshape((len(x_all), len(y_all)))
F2_all = grid_rhs[:, 3].reshape((len(x_all), len(y_all)))
del grid_rhs  # ~N^2 x 4 floats; only the two tables are needed from here on
args = (x_all, y_all, F1_all, F2_all, dx1, dx2)

In [10]:
x = jnp.array([1.0, 0.0, 0.0, 1.0])
%timeit vf1(0, x, None).block_until_ready()
%timeit vf1_approx(0, x, args).block_until_ready()

1.31 ms ± 352 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
The slowest run took 4.42 times longer than the fastest. This could mean that an intermediate result is being cached.
377 μs ± 261 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [17]:
# 100k initial states drawn uniformly from [0, 1)^4 with a fixed key (reproducible).
y0s = jax.random.uniform(jax.random.PRNGKey(0), shape=(100000, 4))
ts = jnp.linspace(0, 3, 100)
saveat = SaveAt(ts=ts)
solver = Tsit5()
term = ODETerm(vf1)
term2 = ODETerm(vf1_approx)


def _solve_batch(ode_term, initial_states, ode_args):
    """Integrate ``ode_term`` from each row of ``initial_states``; return saved states.

    Shared by vmap_org/vmap_approx so both benchmarks use identical solver
    settings. The two only differ in the right-hand side and its ``args``.
    """

    def solve_one(y0):
        return diffeqsolve(
            ode_term,
            solver,
            t0=ts[0],
            t1=ts[-1],
            dt0=1e-2,
            y0=y0,
            args=ode_args,
            max_steps=10000,
            saveat=saveat,
        )

    return jax.vmap(solve_one)(initial_states).ys


def vmap_org(y0s):
    """Batched solve using the exact right-hand side ``vf1``."""
    return _solve_batch(term, y0s, None)


def vmap_approx(y0s):
    """Batched solve using the grid-lookup surrogate ``vf1_approx``."""
    return _solve_batch(term2, y0s, args)


# Warm-up so compilation is excluded from the timings in the next cell. Block until
# the (asynchronously dispatched) work finishes so it cannot spill into those
# timings. Results are not stored in `_`, which IPython reserves for the last output.
vmap_org(y0s).block_until_ready()
vmap_approx(y0s).block_until_ready()

In [18]:
# Compare exact vs. lookup-table RHS over the same 100k-trajectory batch.
# block_until_ready() is required because JAX dispatches asynchronously; without it
# %timeit would measure only dispatch. Both functions were already called once in
# the previous cell, so compilation should be cached. The un-jitted jax.vmap
# wrapper is still re-traced on every call, which is included in these numbers.
%timeit _ = vmap_org(y0s).block_until_ready()
%timeit _ = vmap_approx(y0s).block_until_ready()

4.16 s ± 42.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
322 ms ± 2.26 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
