In [1]:
# import jax
# import jax.numpy as jnp

# jax.config.update("jax_compilation_cache_dir", "./jax-caches")
# jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
# jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

In [None]:
import sys
import os

# Make the notebook directory importable (front of sys.path) and, as a fallback,
# the directory two levels up (presumably the repository root -- confirm).
sys.path.insert(0, os.path.abspath("."))
sys.path.append(os.path.abspath("../../"))

# Optional XLA/JAX memory and compiler knobs. NOTE(review): these environment
# variables are read when JAX initialises its backend, so they only have an
# effect if uncommented before anything imports jax.
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.25"
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
# os.environ["XLA_FLAGS"] = (
#     "--xla_disable_hlo_passes=constant_folding"  # this disables constant folding
# )
from desc import set_device

# Select the GPU backend before the bulk of the desc imports in later cells.
set_device("gpu")

In [2]:
# from desc import set_device, _set_cpu_count

# num_device = 2
# _set_cpu_count(num_device)
# set_device("cpu", num_device=num_device)

In [3]:
import numpy as np
# Print arrays in full (no "..." elision) on a single unbounded line, with
# 4 decimals and fixed-point (non-scientific) formatting.
np.set_printoptions(
    threshold=sys.maxsize,
    linewidth=np.inf,
    precision=4,
    suppress=True,
)
import matplotlib.pyplot as plt
%matplotlib inline
import plotly.graph_objects as go
import functools
import scipy

In [4]:
import desc

from desc.basis import *
from desc.backend import *
from desc.compute import *
from desc.coils import *
from desc.equilibrium import *
from desc.examples import *
from desc.grid import *
from desc.geometry import *

from desc.objectives import *
from desc.objectives.objective_funs import *
from desc.objectives.getters import *
from desc.objectives.normalization import compute_scaling_factors
from desc.objectives.utils import *
from desc.optimize._constraint_wrappers import *

from desc.transform import Transform
from desc.plotting import *
from desc.optimize import *
from desc.perturbations import *
from desc.profiles import *
from desc.compat import *
from desc.utils import *
from desc.magnetic_fields import *

from desc.__main__ import main
from desc.vmec_utils import vmec_boundary_subspace
from desc.input_reader import InputReader
from desc.continuation import solve_continuation_automatic
from desc.compute.data_index import register_compute_fun
from desc.optimize.utils import solve_triangular_regularized

# Report the DESC / JAX versions and the active device so the outputs recorded
# in this notebook can be tied to a specific environment.
print_backend_info()

DESC version=0.15.0+530.gda8bd1d8d.dirty.
Using JAX backend: jax version=0.6.2, jaxlib version=0.6.2, dtype=float64.
Using device: NVIDIA GeForce RTX 4080 Laptop GPU (id=0), with 10.37 GB available memory.


In [5]:
from diffrax import *

In [7]:
# Lorenz system parameters.
sigma, beta, rho = 10.0, 8.0 / 3.0, 14.0

# Save grid: 1000 evenly spaced output times on [0, 30].
ts = np.linspace(0, 30, 1000)

# Adaptive step-size settings for the solves below.
atol = rtol = 1e-6  # absolute and relative tolerances
dtmin = 1e-9  # smallest step the controller may take (also passed as dt0)
max_steps = 2000  # hard cap on solver steps per trajectory


def vector_field(t, y, args):
    """Lorenz right-hand side for state ``y = (x, y, z)``.

    ``t`` and ``args`` are accepted for the diffrax ``ODETerm`` signature but
    unused; the parameters come from the module-level ``sigma``, ``rho``,
    ``beta``. Returns a length-3 ``jnp`` array with the time derivatives.
    """
    px, py, pz = y
    return jnp.array(
        [
            sigma * (py - px),
            px * (rho - pz) - py,
            px * py - beta * pz,
        ]
    )


def default_event(t, y, args, **kwargs):
    """Boolean event condition: True once the z-component of ``y`` exceeds 20."""
    z_component = y[2]
    return z_component > 20
    
# diffrax building blocks shared by every call to `solve` below.
solver = Tsit5()  # explicit 5th-order Runge-Kutta (Tsitouras)
term = ODETerm(vector_field)
event = Event(default_event)  # stop the solve once the event condition is True
saveat = SaveAt(ts=ts)  # record the state at every time in `ts`
stepsize_controller = PIDController(atol=atol, rtol=rtol, dtmin=dtmin)

def solve(y0):
    """Integrate the Lorenz system from initial state ``y0`` over ``[ts[0], ts[-1]]``.

    Uses the module-level ``term``, ``solver``, ``saveat``,
    ``stepsize_controller`` and terminal ``event``. Returns the diffrax
    solution object (``.ys`` sampled at ``ts``, counters in ``.stats``).
    """
    integration_options = {
        "t0": ts[0],
        "t1": ts[-1],
        "dt0": dtmin,  # first trial step; the adaptive controller adjusts from here
        "y0": y0,
        "saveat": saveat,
        "stepsize_controller": stepsize_controller,
        "max_steps": max_steps,
        "adjoint": RecursiveCheckpointAdjoint(),
        "event": event,
    }
    return diffeqsolve(term, solver, **integration_options)

# Draw N random initial conditions uniformly from [-15, 15]^3. Each coordinate
# column gets its own fixed PRNG key (z <- 0, y <- 1, x <- 2), so the batch is
# reproducible and matches the recorded outputs.
N = 50
inputs = jnp.zeros((N, 3))
for column, seed in ((2, 0), (1, 1), (0, 2)):
    inputs = inputs.at[:, column].set(
        jax.random.uniform(jax.random.PRNGKey(seed), (N,), minval=-15.0, maxval=15.0)
    )

# Solve the whole batch in one vmapped call; per-trajectory step counts are in
# sol.stats["num_steps"], one entry per initial condition.
sol = jax.vmap(solve)(inputs)
xyzs = sol.ys
fig = go.Figure()
for i in range(N):
    # Single quotes inside the f-string: reusing the enclosing double quotes is
    # only legal from Python 3.12 (PEP 701) and a SyntaxError before that.
    print(f"num_steps = {sol.stats['num_steps'][i]}")
    xyz = xyzs[i]
    fig.add_trace(
        go.Scatter3d(
            x=xyz[:, 0],
            y=xyz[:, 1],
            z=xyz[:, 2],
            mode="lines",
            showlegend=False,
        )
    )
fig.update_layout(width=1000, height=800)
fig

num_steps = 18
num_steps = 19
num_steps = 291
num_steps = 21
num_steps = 16
num_steps = 24
num_steps = 296
num_steps = 21
num_steps = 26
num_steps = 18
num_steps = 331
num_steps = 18
num_steps = 29
num_steps = 322
num_steps = 37
num_steps = 28
num_steps = 31
num_steps = 325
num_steps = 33
num_steps = 274
num_steps = 14
num_steps = 15
num_steps = 26
num_steps = 17
num_steps = 28
num_steps = 31
num_steps = 327
num_steps = 21
num_steps = 21
num_steps = 343
num_steps = 31
num_steps = 15
num_steps = 232
num_steps = 17
num_steps = 15
num_steps = 27
num_steps = 20
num_steps = 287
num_steps = 18
num_steps = 15
num_steps = 17
num_steps = 27
num_steps = 37
num_steps = 15
num_steps = 39
num_steps = 20
num_steps = 21
num_steps = 27
num_steps = 329
num_steps = 24


In [8]:
# Per-sample counterpart of the vmapped cell above: solve each initial condition
# on its own so the printed num_steps can be compared line-by-line with the
# batched run.
fig = go.Figure()
for i in range(N):
    sol = solve(inputs[i, :])
    xyz = sol.ys
    # Single quotes inside the f-string: reusing the enclosing double quotes is
    # only legal from Python 3.12 (PEP 701) and a SyntaxError before that.
    print(f"num_steps = {sol.stats['num_steps']}")
    fig.add_trace(
        go.Scatter3d(
            x=xyz[:, 0],
            y=xyz[:, 1],
            z=xyz[:, 2],
            mode="lines",
            showlegend=False,
        )
    )
fig.update_layout(width=1000, height=800)
fig

num_steps = 18
num_steps = 19
num_steps = 291
num_steps = 21
num_steps = 16
num_steps = 24
num_steps = 296
num_steps = 21
num_steps = 26
num_steps = 18
num_steps = 331
num_steps = 18
num_steps = 29
num_steps = 322
num_steps = 37
num_steps = 28
num_steps = 31
num_steps = 325
num_steps = 33
num_steps = 274
num_steps = 14
num_steps = 15
num_steps = 26
num_steps = 17
num_steps = 28
num_steps = 31
num_steps = 327
num_steps = 21
num_steps = 21
num_steps = 343
num_steps = 31
num_steps = 15
num_steps = 232
num_steps = 17
num_steps = 15
num_steps = 27
num_steps = 20
num_steps = 287
num_steps = 18
num_steps = 15
num_steps = 17
num_steps = 27
num_steps = 37
num_steps = 15
num_steps = 39
num_steps = 20
num_steps = 21
num_steps = 27
num_steps = 329
num_steps = 24


In [11]:
import diffrax
# 1. Define the stiff vector field: Van der Pol oscillator
def vector_field(t, y, args):
    """Van der Pol oscillator right-hand side.

    ``y = (position, velocity)``; ``args`` is the stiffness parameter ``mu``
    (larger ``mu`` -> stiffer ODE); ``t`` is unused. Returns
    ``[velocity, mu * (1 - position**2) * velocity - position]`` as a jnp array.
    """
    mu = args
    pos, vel = y[0], y[1]
    damping = mu * (1 - pos ** 2)
    return jnp.array([vel, damping * vel - pos])


# 2. Define a terminal event
def event_fn(t, y, args, **kwargs):
    """Terminal event condition for the van der Pol solve.

    Returns a boolean that is True once ``y[0]`` exceeds 15.0. Passed to
    ``diffrax.Event``, a True condition terminates the integration.
    ``t``, ``args`` and ``kwargs`` are accepted for diffrax's calling
    convention but unused.
    """
    return y[0] > 15.0


def run_mwe():
    """Check whether ``jax.vmap`` changes diffrax step counts on a stiff ODE.

    Solves the van der Pol system (``vector_field`` with ``args=mu``) with the
    explicit Tsit5 solver, an adaptive PID step-size controller and the
    terminal ``event_fn``. Each initial condition is solved once in a Python
    loop and once for the whole batch via ``jax.vmap``. Prints both sets of
    ``num_steps`` and reports "SUCCESS" if any batched solve took more than
    1.5x the steps of its looped counterpart (i.e. the suspected vmap
    slowdown was reproduced), otherwise "FAILURE".
    """
    # A high `mu` value makes the system stiff.
    mu = 1000.0

    # Standard diffrax problem components.
    term = diffrax.ODETerm(vector_field)
    solver = diffrax.Tsit5()  # explicit solver: stiffness shows up as many small steps
    t0 = 0.0
    # NOTE(review): for large mu one van der Pol relaxation period is O(mu), so
    # t1 = 1.0 only covers the initial transient -- enough to make the explicit
    # solver take many steps, which is all this comparison needs.
    t1 = 1.0
    dt0 = None  # let the solver pick the initial step size
    event = diffrax.Event(event_fn)
    stepsize_controller = diffrax.PIDController(rtol=1e-6, atol=1e-6, dtmin=1e-8)

    y0_batch = jnp.array([[2.0, 0.0], [0.1, 100.1]])

    @jax.jit
    def solve_func(y0):
        sol = diffrax.diffeqsolve(
            term,
            solver,
            t0=t0,
            t1=t1,
            dt0=dt0,
            y0=y0,
            args=mu,
            stepsize_controller=stepsize_controller,
            event=event,
            max_steps=200_000,
        )
        return sol.ys, sol.stats

    print(f"Stiffness mu = {mu}")

    # Reference: solve each initial condition separately with the jitted solver.
    loop_steps = np.array(
        [solve_func(y0_batch[i])[1]["num_steps"] for i in range(y0_batch.shape[0])]
    )

    # Same solves, batched with vmap.
    _, vmapped_stats = jax.vmap(solve_func)(y0_batch)
    vmap_steps = np.asarray(vmapped_stats["num_steps"])

    print("\n--- Analysis ---")
    print(f"Loop step counts: {loop_steps}")
    print(f"Vmap step counts: {vmap_steps}")

    # Flag the issue if any batched solve needed >1.5x its looped step count.
    if np.any(vmap_steps > 1.5 * loop_steps):
        print("\nSUCCESS: The MWE reproduced the issue.")
    else:
        print("\nFAILURE: The MWE did not reproduce the issue.")


run_mwe()

Stiffness mu = 1000.0

--- Analysis ---
Loop step counts: [1091  778]
Vmap step counts: [1091  778]

FAILURE: The MWE did not reproduce the issue.
