In [1]:
# import jax
# import jax.numpy as jnp

# jax.config.update("jax_compilation_cache_dir", "./jax-caches")
# jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
# jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

In [2]:
import sys
import os

# Make modules in the notebook directory and the directory three levels up
# (presumably the repository root) importable. The guards keep this cell
# idempotent: re-running it no longer pushes duplicate entries onto sys.path
# (the original also inserted "." twice in a row).
_notebook_dir = os.path.abspath(".")
_repo_root = os.path.abspath("../../../")
if sys.path[:1] != [_notebook_dir]:
    sys.path.insert(0, _notebook_dir)
if _repo_root not in sys.path:
    sys.path.append(_repo_root)

# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.25"
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
# os.environ["XLA_FLAGS"] = (
#     "--xla_disable_hlo_passes=constant_folding "  # this disables constant folding
#     # "--xla_cpu_use_thunk_runtime=false "
# )
from desc import set_device
# Select the GPU backend. This runs before the other `desc` submodules are
# imported further down; DESC documents that set_device should be called
# before importing the rest of the package.
set_device("gpu")

In [3]:
# import jax
# import jax.numpy as jnp

# jax.config.update("jax_explain_cache_misses", True)

In [4]:
# from desc import set_device, _set_cpu_count

# num_device = 2
# _set_cpu_count(num_device)
# set_device("cpu", num_device=num_device)

In [5]:
import numpy as np
# Print arrays in full (never summarized), on one line, with 4 decimals and
# without scientific notation for small values.
np.set_printoptions(linewidth=np.inf, precision=4, suppress=True, threshold=sys.maxsize)
import matplotlib.pyplot as plt
%matplotlib inline
import plotly.graph_objects as go
import functools
import scipy

In [6]:
import desc

from desc.basis import *
from desc.backend import *
from desc.compute import *
from desc.coils import *
from desc.equilibrium import *
from desc.examples import *
from desc.grid import *
from desc.geometry import *
from desc.io import *

from desc.objectives import *
from desc.objectives.objective_funs import *
from desc.objectives.getters import *
from desc.objectives.normalization import compute_scaling_factors
from desc.objectives.utils import *
from desc.optimize._constraint_wrappers import *

from desc.transform import Transform
from desc.plotting import *
from desc.optimize import *
from desc.perturbations import *
from desc.profiles import *
from desc.compat import *
from desc.utils import *
from desc.magnetic_fields import *
from desc.particles import *
from diffrax import *

from desc.__main__ import main
from desc.vmec_utils import vmec_boundary_subspace
from desc.input_reader import InputReader
from desc.continuation import solve_continuation_automatic
from desc.compute.data_index import register_compute_fun
from desc.optimize.utils import solve_triangular_regularized

# Report DESC/JAX versions and the active device (see the output below).
print_backend_info()

DESC version=0.16.0+584.g250e50058.
Using JAX backend: jax version=0.6.2, jaxlib version=0.6.2, dtype=float64.
Using device: NVIDIA GeForce RTX 4080 Laptop GPU (id=0), with 10.75 GB available memory.


In [7]:
from desc.plotting import sequential_colors

In [None]:
name = "precise_QA"
# Cache the rescaled vacuum equilibrium on disk so that re-running the
# notebook does not repeat the expensive solve.
eq_path = f"Optimization/eqs/{name}_vacuum_scaled_solved.h5"
if os.path.exists(eq_path):
    # Cached file present: load it. A corrupt or unreadable file now raises
    # instead of being silently rebuilt (the old bare `except:` also hid
    # KeyboardInterrupt and genuine errors).
    eq = desc.io.load(eq_path)
    eqi_scaled = eq.copy()
else:
    # Build from the bundled example: rescale so that "a" = 1.7044 and
    # "<B>" = 5.86 (presumably minor radius in m and volume-averaged |B| in T),
    # make it a vacuum equilibrium (zero pressure and current), and re-solve.
    eqi = get(name)
    eq = rescale(eq=eqi, L=("a", 1.7044), B=("<B>", 5.86), copy=True)
    eq.pressure = 0
    eq.current = 0
    eq.solve(ftol=1e-4, verbose=1)
    eqi_scaled = eq.copy()
    # Ensure the cache directory exists, so the save cannot fail after the solve.
    os.makedirs(os.path.dirname(eq_path), exist_ok=True)
    eq.save(eq_path)
# Attach the equilibrium's iota profile explicitly (used by the field fits below).
eq.iota = eq.get_profile("iota")

In [None]:
# Three test particles launched from the same point (rho = 0.4, theta = zeta = 0)
# with different initial pitches xi0 (presumably v_parallel / v, as the
# "vpar=...v" plot labels suggest).
xis = np.array([0.01, 0.2, 0.5])
rho = 0.4
# xis = 0.5
# rho = 0.4
# One color per particle; the plots below index it modulo len(colors).
colors = [
    "#c80016",  # red
    # "#dc5b0e",  # burnt orange
    # "#f0b528",  # light orange
    # "#dce953",  # yellow
    "#0f5c10",  # green
    # "#1fb7c9",  # teal
    # "#2192e3",  # medium blue
    # "#4f66d4",  # blue-violet
    "#9a45db",  # purple
]
particles = ManualParticleInitializerFlux(
    rho0=rho, theta0=0, zeta0=0.0, xi0=xis, E=3.5e6
)  # NOTE(review): E presumably in eV (3.5 MeV) — confirm units
model = VacuumGuidingCenterTrajectory(frame="flux")
# Save times for the trajectories (presumably seconds, given the "Time (ms)"
# axes below, which multiply by 1000).
ts = np.linspace(0, 7e-5, 2000)

In [None]:
# 3D view: translucent |B| surface plot of the equilibrium with the particle
# trajectories overlaid. With return_data=True, `data` holds the trajectory
# coordinates (data["rho"], data["theta"], ...) used by the 2D plots below.
fig = plot_3d(
    eq,
    "|B|",
    alpha=0.2,
    showaxislabels=False,
    showgrid=False,
    zeroline=False,
    showticklabels=False,
    showscale=False,
)
fig, data = plot_particle_trajectories(
    eq,
    model,
    particles,
    ts=ts,
    fig=fig,
    rtol=1e-5, 
    atol=1e-5, 
    min_step_size=1e-8,
    showaxislabels=False,
    showgrid=False,
    zeroline=False,
    showticklabels=False,
    return_data=True,
    color=colors,
)
fig

In [None]:
# rho(t) for every traced particle; the time axis is shown in ms.
for idx, rho_history in enumerate(data["rho"]):
    line_color = colors[idx % len(colors)]
    plt.plot(
        ts * 1000,
        rho_history,
        color=line_color,
        label=f"vpar={xis[idx]:.2f}v",
    )
plt.xlabel("Time (ms)")
plt.ylabel("Flux Surface Label (rho)")
plt.ylim([0, 1])
plt.legend()
# NOTE: the filename below says NCSX, but the equilibrium here is `name`.
# plt.savefig("rho-time-ncsx.png", dpi=500)

In [None]:
def fit_line(y, t=None):
    """Least-squares straight-line fit ``y ≈ c0 * t + c1`` that ignores NaNs.

    Parameters
    ----------
    y : array-like, shape (n,)
        Samples to fit, e.g. one particle's rho history. NaN entries are
        excluded from the fit.
    t : array-like, shape (n,), optional
        Sample abscissae. Defaults to the notebook-level ``ts`` array.

    Returns
    -------
    jax.Array, shape (2,)
        ``[slope, intercept]``, in the order returned by ``jnp.polyfit(..., 1)``.
    """
    if t is None:
        t = ts
    y = jnp.asarray(y)
    valid = ~jnp.isnan(y)
    # NaN samples get zero weight, so they drop out of the least-squares
    # problem entirely. They must still be replaced by a finite value, because
    # NaN * 0 is NaN. (Previously they were turned into extra (t, 0) points.
    # That biased the fit, since trajectories do not start at rho = 0.)
    y_filled = jnp.where(valid, y, 0.0)
    return jnp.polyfit(t, y_filled, 1, w=valid.astype(y_filled.dtype))

# Same rho(t) curves, each overlaid with its dashed straight-line fit.
for idx, rho_history in enumerate(data["rho"]):
    line_color = colors[idx % len(colors)]
    t_ms = ts * 1000
    plt.plot(t_ms, rho_history, color=line_color, label=f"vpar={xis[idx]:.2f}v")

    coeffs = fit_line(rho_history)
    slope, intercept = coeffs[0], coeffs[1]
    plt.plot(
        t_ms,
        (slope * ts + intercept),
        "--",
        color=line_color,
        label=f"vpar={xis[idx]:.2f}v line fit",
    )
plt.xlabel("Time (ms)")
plt.ylabel("Flux Surface Label (rho)")
plt.ylim([0, 1])
plt.legend()
# NOTE: the filename below says W7-X, but the equilibrium here is `name`.
# plt.savefig("W7-X-particle-line-fit.png", dpi=500)

In [None]:
# Polar view of each trajectory: angle = theta, radius = rho.
for idx, rho_history in enumerate(data["rho"]):
    line_color = colors[idx % len(colors)]
    plt.polar(
        data["theta"][idx],
        rho_history,
        color=line_color,
        label=f"vpar={xis[idx]:.2f}v",
    )
plt.ylim([0, 1])
plt.legend()
# NOTE: the filename below says NCSX, but the equilibrium here is `name`.
# plt.savefig("ncsx-particle-polar.png", dpi=500)

In [None]:
# Batch of N particles on rho = 0.3, theta = pi/2, zeta = 0, with pitches xi0
# spread over [0.1, 0.9]. They are used below to compare tracing through the
# exact equilibrium against two fitted field representations.
N = 5000  # number of particles traced
RHO0 = [0.3] * N
xi0 = np.linspace(0.1, 0.9, N, endpoint=True)

model = VacuumGuidingCenterTrajectory(frame="flux")
particles = ManualParticleInitializerFlux(
    rho0=RHO0,
    theta0=np.pi/2,
    zeta0=0,
    xi0=xi0,  # TODO: add negative region too
    E=3.5e6,
)
x0, model_args = particles.init_particles(model, eq)
# Field representation 1: FourierChebyshevField fit at the equilibrium's grid resolution.
interpolator = FourierChebyshevField(L=eq.L_grid, M=eq.M_grid, N=eq.N_grid)
interpolator.build(eq)
interpolator.fit(eq.params_dict, {"iota": eq.iota, "current": eq.current})

# Field representation 2: linear spline fit in flux coordinates; `res`
# multiplies its resolution relative to the equilibrium grid.
res = 1
spliner = SplineFieldFlux(L=eq.L_grid * res, M=eq.M_grid * res, N=eq.N_grid * res, method="linear")
spliner.build(eq)
spliner.fit(eq.params_dict, {"iota": eq.iota, "current": eq.current})

# NOTE(review): `stepsize_controller` is not used; each tracing cell below
# passes its own PIDController.
stepsize_controller = ConstantStepSize()
# NOTE(review): this redefines the global `ts` (previously 0..7e-5 s with 2000
# points), which fit_line uses as its default abscissa.
ts = np.linspace(0, 1e-4, 1000)
min_step_size = 1e-8
# Step budget: enough steps to reach ts[-1] at the minimum step size.
max_steps = int(ts[-1] / min_step_size)
# solver = Tsit5(scan_kind="bounded")
solver = Tsit5()
adjoint = RecursiveCheckpointAdjoint()
def default_event(t, y, args, **kwargs):
    """Termination condition: fires once sqrt(y[0]**2 + y[1]**2) exceeds 1.

    In the flux frame the state is (rho*cos(theta), rho*sin(theta), zeta, vpar).
    The condition is therefore rho > 1: the particle has left the rho <= 1 domain.
    ``t``, ``args`` and any extra keyword arguments are ignored.
    """
    radius_sq = y[0] ** 2 + y[1] ** 2
    return jnp.sqrt(radius_sq) > 1.0

# Stop integrating a particle once default_event fires (rho > 1).
event = Event(default_event)
# Passed as chunk_size to _trace_particles; None presumably means no chunking
# over particles — confirm.
particle_chunk_size = None

# NOTE(review): private DESC helper, imported mid-notebook; used directly so
# that the auxiliary outputs (return_aux=True) are available.
from desc.particles import _trace_particles

In [None]:

# Reference run: trace through the exact equilibrium field (field=eq with its params).
# SaveAt(steps=True) stores every accepted step, so each particle has its own
# time grid; aux1[0] is used as that time axis in the plots below.
rtz1, _, aux1 = _trace_particles(
    field=eq,
    y0=x0,
    model=model,
    model_args=model_args,
    ts=ts,
    params=eq.params_dict,
    stepsize_controller=PIDController(
        rtol=1e-6,
        atol=1e-6,
        dtmin=min_step_size,
        pcoeff=0.3,
        icoeff=0.3,
        dcoeff=0,
    ),
    saveat=SaveAt(steps=True),
    # saveat=SaveAt(ts=ts),
    max_steps=max_steps,
    min_step_size=min_step_size,
    solver=solver,
    adjoint=adjoint,
    event=event,
    options={},
    chunk_size=particle_chunk_size,
    throw=False,
    return_aux=True,
)

In [None]:
# Same run, but the field comes from the FourierChebyshevField fit `interpolator`.
# params=None, presumably because the fitted field carries its own
# coefficients — confirm.
rtz2, _, aux2 = _trace_particles(
    field=interpolator,
    y0=x0,
    model=model,
    model_args=model_args,
    ts=ts,
    params=None,
    stepsize_controller=PIDController(
        rtol=1e-6,
        atol=1e-6,
        dtmin=min_step_size,
        pcoeff=0.3,
        icoeff=0.3,
        dcoeff=0,
    ),
    # stepsize_controller=ConstantStepSize(),
    saveat=SaveAt(steps=True),
    # saveat=SaveAt(ts=ts),
    max_steps=max_steps,
    min_step_size=min_step_size,
    solver=solver,
    adjoint=adjoint,
    event=event,
    options={},
    chunk_size=particle_chunk_size,
    throw=False,
    return_aux=True,
)

In [None]:
# Same run, but the field comes from the linear spline fit `spliner`.
rtz3, _, aux3 = _trace_particles(
    field=spliner,
    y0=x0,
    model=model,
    model_args=model_args,
    ts=ts,
    params=None,
    stepsize_controller=PIDController(
        rtol=1e-6,
        atol=1e-6,
        dtmin=min_step_size,
        pcoeff=0.3,
        icoeff=0.3,
        dcoeff=0,
    ),
    # stepsize_controller=ConstantStepSize(),
    saveat=SaveAt(steps=True),
    # saveat=SaveAt(ts=ts),
    max_steps=max_steps,
    min_step_size=min_step_size,
    solver=solver,
    adjoint=adjoint,
    event=event,
    options={},
    chunk_size=particle_chunk_size,
    throw=False,
    return_aux=True,
)

In [None]:
# Compare one particle's trajectory across the three field representations:
# exact (rtz1), Fourier-Chebyshev fit (rtz2), spline fit (rtz3).
# NOTE(review): `id` shadows the Python builtin; the next cells also read it.
id = 999
plt.figure(figsize=(10, 5))
# x-axis: aux*[0] (per-step times); y-axis: state column 0 (presumably rho,
# given the "rtz" naming — confirm).
plt.plot(aux1[0][id], rtz1[id, :, 0], label="rtz1", linewidth=5)
plt.plot(aux2[0][id], rtz2[id, :, 0], label="rtz2", linewidth=3)
plt.plot(aux3[0][id], rtz3[id, :, 0], label="rtz3", linewidth=1)
plt.legend()
plt.title(f"xi = {xi0[id]}")

In [None]:
# Polar view of the same particle: angle = state column 1, radius = column 0
# (presumably theta and rho).
plt.figure(figsize=(10, 5))
plt.polar(rtz1[id, :, 1], rtz1[id, :, 0], label="rtz1", linewidth=5)
plt.polar(rtz2[id, :, 1], rtz2[id, :, 0], label="rtz2", linewidth=3)
plt.polar(rtz3[id, :, 1], rtz3[id, :, 0], label="rtz3", linewidth=1)
plt.legend()
plt.title(f"xi = {xi0[id]}")

In [None]:
# Differences of consecutive saved times. With SaveAt(steps=True) these are the
# accepted step sizes of the adaptive controller for each field representation.
plt.figure(figsize=(15, 8))
plt.semilogy(np.diff(aux1[0][id]), label="rtz1", linewidth=5)
plt.semilogy(np.diff(aux2[0][id]), label="rtz2", linewidth=3)
plt.semilogy(np.diff(aux3[0][id]), label="rtz3", linewidth=1)
plt.legend()

In [None]:
# Inspect the second auxiliary output of each run (meaning not visible here;
# presumably solver statistics or status — confirm against _trace_particles).
aux1[1], aux2[1], aux3[1]

In [None]:
# Inspect the third auxiliary output of the spline run (meaning not visible here).
aux3[2]

In [None]:
# 3D scatter of log10 |coefficient| over the mode numbers (l, m, n) of the
# FourierChebyshevField fit `interpolator`, for the first fitted field component.
# NOTE(review): these re-imports duplicate the top-of-notebook imports.
import plotly.graph_objects as go
import numpy as np
import jax.numpy as jnp

# 1. Get the 1D unique axis values
l_vals = interpolator.params_dict["l"]  # Radial (L)
m_vals = interpolator.params_dict["m"]  # Poloidal (M)
n_vals = interpolator.params_dict["n"]  # Toroidal (N)

# 2. Create the 3D grid
# Note: Input order matches the data shape (N, L, M) -> (z, x, y)
# N_grid varies along axis 0 (Toroidal/N)
# L_grid varies along axis 1 (Radial/L)
# M_grid varies along axis 2 (Poloidal/M)
# NOTE(review): assumes coefs_*[0] has shape (N, L, M) so that the flattened
# coefficients line up with these flattened grids — confirm.
N_grid, L_grid, M_grid = np.meshgrid(n_vals, l_vals, m_vals, indexing="ij")

# 3. Process Coefficients
# NOTE(review): index [0] selects the first fitted component, which the title
# calls B_r — confirm against FourierChebyshevField's coefficient layout.
# This also overwrites the global `coeffs` from the line-fit plotting cell.
coeffs = (
    interpolator.params_dict["coefs_real"][0]
    + 1j * interpolator.params_dict["coefs_imag"][0]
)
coeffs = jnp.abs(coeffs)
coeffs_log = jnp.log10(
    coeffs + 1e-16
)  # Use log10 for easier reading of orders of magnitude

# 4. Flatten everything for Plotly
# We map the grids to the desired visual axes:
# Visual X-axis -> Radial (L) -> L_grid
# Visual Y-axis -> Poloidal (M) -> M_grid
# Visual Z-axis -> Toroidal (N) -> N_grid
l_flat = L_grid.flatten()
m_flat = M_grid.flatten()
n_flat = N_grid.flatten()
c_flat = coeffs_log.flatten()

# NOTE(review): `fig` replaces the 3D trajectory figure from earlier; opacity is
# set both on the marker (0.8) and on the trace (0.5).
fig = go.Figure(
    data=[
        go.Scatter3d(
            x=l_flat,  # Radial Axis
            y=m_flat,  # Poloidal Axis
            z=n_flat,  # Toroidal Axis
            mode="markers",
            marker=dict(
                size=4,
                color=c_flat,
                colorscale="Turbo",
                opacity=0.8,
                colorbar=dict(title="log10(|C_lmn|)", thickness=20),
            ),
            opacity=0.5,
            hovertemplate=(
                "L (Radial): %{x}<br>"
                + "M (Poloidal): %{y}<br>"
                + "N (Toroidal): %{z}<br>"
                + "Log Magnitude: %{marker.color:.2f}<extra></extra>"
            ),
        )
    ]
)

fig.update_layout(
    title="3D Fourier-Chebyshev Spectrum (B_r)",
    scene=dict(
        xaxis_title="Radial Mode (L)",
        yaxis_title="Poloidal Mode (M)",
        zaxis_title="Toroidal Mode (N)",
        aspectmode="cube",
    ),
    margin=dict(r=0, b=0, l=0, t=40),
    width=1000,
    height=800,
)

fig.show()

# Spline and Fourier–Chebyshev field interpolation checks

In [None]:
# Fresh precise_QA equilibrium (replaces the rescaled vacuum `eq` above) with
# an explicit iota profile. It is fitted with a spline field at twice the
# equilibrium grid resolution in L, M and N. `interpolator` now refers to this
# spline fit, not the Fourier-Chebyshev fit.
eq = get("precise_QA")
eq.iota = eq.get_profile("iota")
interpolator = SplineFieldFlux(L=eq.L_grid*2, M=eq.M_grid*2, N=eq.N_grid*2)
interpolator.build(eq=eq)
interpolator.fit(
    params=eq.params_dict, profiles={"current": eq.current, "iota": eq.iota}
)

In [None]:
# Spot-check the spline fit against direct equilibrium computation on 5
# surfaces. Every key must agree to 5e-4 (relative and absolute) at every node.
rhos = np.linspace(0.05, 1.0, 5)
grid = LinearGrid(rho=rhos, M=3, N=3, NFP=eq.NFP, sym=eq.sym)
keys = ["|B|", "b", "grad(|B|)", "e^rho", "e^theta*rho", "e^zeta"]
data = eq.compute(keys, grid=grid)

for i, (rho, theta, zeta) in enumerate(grid.nodes):
    data_interp = interpolator.evaluate(rho, theta, zeta)
    for key in keys:
        np.testing.assert_allclose(
            data[key][i],
            data_interp[key],
            rtol=5e-4,
            atol=5e-4,
            err_msg=f"{key} mismatch at coord:{i} ρ={rho}, θ={theta}, ζ={zeta}",
        )

In [None]:
eq = get("precise_QA")
iota = eq.get_profile("iota")  # only used by the commented-out comparison below
interpolator = SplineFieldFlux(L=eq.L_grid*2, M=eq.M_grid, N=eq.N_grid)
interpolator.build(eq=eq)
# NOTE(review): unlike the previous cell, eq.iota is not assigned here. For a
# freshly loaded equilibrium it is presumably None, so iota would be derived
# from the current profile — confirm this is intended.
interpolator.fit(
    params=eq.params_dict, profiles={"current": eq.current, "iota": eq.iota}
)

model = VacuumGuidingCenterTrajectory(frame="flux")
rhos = np.linspace(0.05, 1.0, 3)
grid = LinearGrid(rho=rhos, M=2, N=2, NFP=eq.NFP, sym=eq.sym)
# Seeded generator so the random pitches xi0 ~ U[-1, 1) are identical on every
# Restart & Run All. The previous unseeded global np.random state was not
# reproducible.
rng = np.random.default_rng(0)
particles = ManualParticleInitializerFlux(
    rho0=grid.nodes[:, 0],
    theta0=grid.nodes[:, 1],
    zeta0=grid.nodes[:, 2],
    xi0=rng.uniform(-1.0, 1.0, grid.num_nodes),
    E=3.5e6,
)
x0, args = particles.init_particles(model=model, field=eq)

# for xi, argsi in zip(x0, args):
#     rho, theta, zeta, vpar = xi
#     xp = rho * np.cos(theta)
#     yp = rho * np.sin(theta)
#     x = jnp.array([xp, yp, zeta, vpar])

#     params = eq.params_dict
#     params["i_l"] = iota.params
#     exact = model.vf(0, x=x, args=[argsi, eq, params, {"iota": iota}])
#     interpolated = model.vf(0, x=x, args=[argsi, interpolator, None, {}])

#     comps = ["xp_dot", "yp_dot", "zeta_dot", "vpar_dot"]

#     for i, comp in enumerate(comps):
#         msg = f"{comp} mismatch at ρ={rho}, θ={theta}, ζ={zeta}"
#         np.testing.assert_allclose(
#             exact[i], interpolated[i], rtol=2e-2, atol=1e-3, err_msg=msg
#         )

In [10]:
class FourierChebyshevFieldTest(IOAble):
    """Diffrax-compatible field fit that evaluates in strictly real arithmetic.

    The quantities in ``data_keys`` are sampled on a tensor grid:
    Chebyshev nodes in rho and equispaced nodes in (theta, zeta).
    ``fit`` turns the samples into Chebyshev x Fourier x Fourier coefficients.
    ``evaluate`` sums the series at a single point.

    The coefficients are pre-normalized and pre-split into real/imaginary
    parts, so that ``evaluate`` is one real ``einsum``. This is intended to
    reduce memory allocation and kernel launches inside the Diffrax stepping
    loop.

    Parameters
    ----------
    L : int
        Number of Chebyshev nodes/modes in rho.
    M : int
        Poloidal resolution; ``2*M + 1`` theta samples/modes are used.
    N : int
        Toroidal resolution; ``2*N + 1`` zeta samples/modes are used.
    """

    _static_attrs = ["L", "M", "N", "M_fft", "N_fft", "data_keys"]

    def __init__(self, L, M, N):
        self.L = L
        self.M = M
        self.N = N

    def build(self, eq):
        """Build the sampling grid and compute transforms needed by ``fit``.

        Parameters
        ----------
        eq : Equilibrium
            Equilibrium whose ``NFP`` sets the toroidal period of the grid.
            It is also used to construct the compute transforms.
        """
        self.data_keys = ["B", "grad(|B|)", "e^rho", "e^theta*rho", "e^zeta"]
        self.l = jnp.arange(self.L)
        self.M_fft = 2 * self.M + 1
        self.N_fft = 2 * self.N + 1

        # Integer mode numbers in FFT order: [0, 1, ..., K, -K, ..., -1].
        self.m = jnp.fft.fftfreq(self.M_fft) * self.M_fft
        self.n = jnp.fft.fftfreq(self.N_fft) * self.N_fft

        # Chebyshev nodes x_l = cos(pi (2l+1) / (2L)) in (-1, 1), mapped to rho in (0, 1)
        x = jnp.cos(jnp.pi * (2 * self.l + 1) / (2 * self.L))
        rho = (x + 1) / 2

        self.grid = LinearGrid(rho=rho, M=self.M, N=self.N, sym=False, NFP=eq.NFP)
        self.transforms = get_transforms(self.data_keys, eq, self.grid)

    def fit(self, params, profiles):
        """Sample the field on ``self.grid`` and store real-valued coefficients.

        Parameters
        ----------
        params : dict
            Equilibrium parameters passed to ``compute_fun`` (e.g. ``eq.params_dict``).
        profiles : dict
            Profiles passed to ``compute_fun`` (e.g. ``{"iota": ..., "current": ...}``).

        Notes
        -----
        Populates ``self.params_dict`` with the following entries:

        - ``"coefs_opt"``: shape ``(2, 13, N_fft, L, M_fft)``, holding
          ``[Re(C), -Im(C)]``.
        - ``"l"``, ``"m"``, ``"n"``: the mode-number arrays.
        """
        # 1. Compute raw data
        data_raw = compute_fun(
            "desc.equilibrium.equilibrium.Equilibrium",
            self.data_keys,
            params,
            self.transforms,
            profiles,
        )
        L, M, N = self.L, self.M_fft, self.N_fft

        # 2. Stack data for batch processing
        # Order: B(3), grad|B|(3), e^rho(3), e^theta*rho(3), e^zeta_p(1) -> Total 13
        # The reshape to (N, L, M) assumes grid nodes are ordered zeta-slowest,
        # theta-fastest (presumably LinearGrid's ordering — confirm).
        keys3d = [key for key in self.data_keys if key != "e^zeta"]
        arrays = [
            data_raw[key][:, i].reshape(N, L, M) for key in keys3d for i in [0, 1, 2]
        ]
        # Only the second component of e^zeta is fitted; evaluate() rebuilds the
        # vector with zeros in the other two slots.
        arrays.append(data_raw["e^zeta"][:, 1].reshape(N, L, M))

        stacked_data = jnp.stack(arrays)  # Shape (13, N, L, M)

        # 3. Perform Transforms
        # Chebyshev Transform (DCT-II over the rho nodes):
        # c_0 = y_0 / (2L), c_k = y_k / L
        coefs = jax.scipy.fft.dct(stacked_data, axis=2, norm=None)
        coefs = coefs.at[:, :, 0, :].multiply(0.5)  # Fix 0-th mode
        coefs = coefs * (1.0 / self.L)

        # Fourier Transforms (FFT)
        coefs = jnp.fft.fft(coefs, axis=3, norm=None)  # M axis
        coefs = jnp.fft.fft(coefs, axis=1, norm=None)  # N axis

        # 4. Optimization: Pre-normalize
        # Move the division by (M*N) from evaluate() to here
        norm_factor = 1.0 / (self.M_fft * self.N_fft)
        coefs = coefs * norm_factor

        # 5. Optimization: Stack for Dot Product
        # We need Re(Field) = C_real * Basis_real - C_imag * Basis_imag
        # We store this as: Dot([C_real, -C_imag], [Basis_real, Basis_imag])
        # Result shape: (2, 13, N, L, M)
        coefs_optimized = jnp.stack([coefs.real, -coefs.imag], axis=0)

        self.params_dict = {
            "coefs_opt": coefs_optimized,
            "l": self.l,
            "m": self.m,
            "n": self.n,
        }

    # JIT this function or the function calling it!
    def evaluate(self, rho, theta, zeta, params=None):
        """Evaluate the fitted quantities at a single point (rho, theta, zeta).

        Parameters
        ----------
        rho, theta, zeta : float
            Flux coordinates of the evaluation point. They are assumed to be
            scalars; batched inputs would be mixed up by the outer/flatten below.
        params : dict, optional
            Coefficients in the layout produced by ``fit``. Defaults to
            ``self.params_dict``.

        Returns
        -------
        dict
            ``"|B|"``, ``"b"`` (= ``B / |B|``), ``"grad(|B|)"``, ``"e^rho"``,
            ``"e^theta*rho"`` and ``"e^zeta"``. Vector entries have 3
            components, in the same order ``compute_fun`` returned them.
        """
        if params is None:
            params = self.params_dict

        # --- 1. Chebyshev Basis (L) ---
        # NOTE(review): build() places the nodes at rho = (x + 1) / 2, while here
        # T_l is evaluated at 1 - 2*rho = -x. This is consistent only if the
        # sampled data is ordered by ascending rho (presumably LinearGrid sorts
        # its rho nodes), which reverses the node order the DCT sees. Confirm
        # before changing either mapping.
        r0p = 1 - 2 * rho
        # Shape: (L,)
        Tl = jnp.cos(params["l"] * jnp.arccos(r0p))

        # --- 2. Fourier Basis (M, N) ---
        # Scale zeta by NFP so one field period maps onto [0, 2pi)
        zeta = (zeta * self.grid.NFP) % (2 * jnp.pi)

        # Calculate trig components once
        m_theta = jnp.outer(
            theta, params["m"]
        ).flatten()  # Shape (M_fft,) for scalar theta
        n_zeta = jnp.outer(zeta, params["n"]).flatten()  # Shape (N_fft,) for scalar zeta

        cm, sm = jnp.cos(m_theta), jnp.sin(m_theta)
        cn, sn = jnp.cos(n_zeta), jnp.sin(n_zeta)

        # Real/imag parts of exp(i*(m*theta + n*zeta)) via angle addition:
        #   Re = cm*cn - sm*sn,  Im = sm*cn + cm*sn
        # Broadcasting (1, M) against (N, 1) gives shape (N, M). This already
        # matches the (..., N, L, M) layout of the coefficients.
        basis_real = cm[None, :] * cn[:, None] - sm[None, :] * sn[:, None]
        basis_imag = sm[None, :] * cn[:, None] + cm[None, :] * sn[:, None]

        # Shape (2, N, M), no transpose needed. A previous version transposed
        # to (2, M, N). That makes einsum raise when M_fft != N_fft, and it
        # silently swaps the poloidal and toroidal mode numbers when they are
        # equal.
        basis_stack = jnp.stack([basis_real, basis_imag], axis=0)

        # --- 3. Fused Contraction ---
        # p: Real/Imag stack (2)
        # k: Field components (13)
        # n: Toroidal modes
        # l: Radial modes
        # m: Poloidal modes
        # We sum over p, n, l, m.
        # Result shape: (k,)

        results = jnp.einsum(
            "pknlm,l,pnm->k", params["coefs_opt"], Tl, basis_stack, optimize="optimal"
        )

        # --- 4. Pack Output ---
        # This part is cheap.
        B = results[0:3]
        B_norm = jnp.linalg.norm(B)

        return {
            "|B|": B_norm,
            "b": B / B_norm,
            "grad(|B|)": results[3:6],
            "e^rho": results[6:9],
            "e^theta*rho": results[9:12],
            "e^zeta": jnp.array([0.0, results[12], 0.0]),
        }

In [17]:
eq = get("precise_QA")
iota = eq.get_profile("iota")
params = eq.params_dict
params["i_l"] = iota.params

model = VacuumGuidingCenterTrajectory(frame="flux")
rhos = np.linspace(0.05, 1.0, 3)
grid = LinearGrid(rho=rhos, M=2, N=2, NFP=eq.NFP, sym=eq.sym)
particles = ManualParticleInitializerFlux(
    rho0=grid.nodes[:, 0],
    theta0=grid.nodes[:, 1],
    zeta0=grid.nodes[:, 2],
    xi0=2 * np.random.rand(grid.num_nodes) - 1,
    E=3.5e6,
)
x0, args = particles.init_particles(model=model, field=eq)

xi = x0[0]
argsi = args[0]
rho, theta, zeta, vpar = xi

xp = rho * np.cos(theta)
yp = rho * np.sin(theta)

x = jnp.array([xp, yp, zeta, vpar])
spliner = SplineFieldFlux(L=eq.L_grid, M=eq.M_grid, N=eq.N_grid)
spliner.build(eq=eq)
spliner.fit(
    params=params, profiles={"current": eq.current, "iota": eq.iota}
)
interpolator = FourierChebyshevField(L=eq.L_grid, M=eq.M_grid, N=eq.N_grid)
interpolator.build(eq)
interpolator.fit(params, {"iota": eq.iota, "current": eq.current})


interpolator2 = FourierChebyshevFieldTest(L=eq.L_grid, M=eq.M_grid, N=eq.N_grid)
interpolator2.build(eq)
interpolator2.fit(params, {"iota": eq.iota, "current": eq.current})

%timeit -n 10 model._compute_flux_coordinates(x=x.squeeze(), eq=eq, params=params, m=argsi[0], q=argsi[1], mu=argsi[2], iota=iota).block_until_ready()
%timeit -n 10 model._compute_flux_coordinates_with_fit(x=x.squeeze(), field=interpolator, m=argsi[0], q=argsi[1], mu=argsi[2]).block_until_ready()
%timeit -n 10 model._compute_flux_coordinates_with_fit(x=x.squeeze(), field=interpolator2, m=argsi[0], q=argsi[1], mu=argsi[2]).block_until_ready()
%timeit -n 10 model._compute_flux_coordinates_with_fit(x=x.squeeze(), field=spliner, m=argsi[0], q=argsi[1], mu=argsi[2]).block_until_ready()

199 ms ± 5.61 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
9.61 ms ± 576 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
9.7 ms ± 499 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
9.4 ms ± 943 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [18]:
# Repeat the timings with a different call order and without .squeeze()
# (x is already 1-D, so the inputs are the same as in the previous cell).
%timeit -n 10 model._compute_flux_coordinates_with_fit(x=x, field=interpolator2, m=argsi[0], q=argsi[1], mu=argsi[2]).block_until_ready()
%timeit -n 10 model._compute_flux_coordinates(x=x, eq=eq, params=params, m=argsi[0], q=argsi[1], mu=argsi[2], iota=iota).block_until_ready()
%timeit -n 10 model._compute_flux_coordinates_with_fit(x=x, field=interpolator, m=argsi[0], q=argsi[1], mu=argsi[2]).block_until_ready()
%timeit -n 10 model._compute_flux_coordinates_with_fit(x=x, field=spliner, m=argsi[0], q=argsi[1], mu=argsi[2]).block_until_ready()

9.14 ms ± 783 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
193 ms ± 3.6 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
9.57 ms ± 690 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
9.64 ms ± 759 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
