In [1]:
# import jax
# import jax.numpy as jnp

# jax.config.update("jax_compilation_cache_dir", "./jax-caches")
# jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
# jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

In [2]:
import sys
import os

# Make the notebook's own directory (highest priority) and its parent directory
# importable. The current directory is inserted only once; inserting it twice
# just leaves a duplicate entry in sys.path.
sys.path.insert(0, os.path.abspath("."))
sys.path.append(os.path.abspath("../"))

# Optional XLA/JAX tuning knobs, kept for reference. They must be set before
# JAX is first imported to have any effect.
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.25"
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
# os.environ["XLA_FLAGS"] = (
#     "--xla_disable_hlo_passes=constant_folding "  # this disables constant folding
#     # "--xla_cpu_use_thunk_runtime=false "
# )
from desc import set_device
set_device("gpu")

In [3]:
# import jax
# import jax.numpy as jnp

# jax.config.update("jax_explain_cache_misses", True)

In [4]:
# from desc import set_device, _set_cpu_count

# num_device = 2
# _set_cpu_count(num_device)
# set_device("cpu", num_device=num_device)

In [5]:
import numpy as np
np.set_printoptions(linewidth=np.inf, precision=4, suppress=True, threshold=sys.maxsize)
import matplotlib.pyplot as plt
%matplotlib inline
import plotly.graph_objects as go
import functools
import scipy

In [6]:
import desc

from desc.basis import *
from desc.backend import *
from desc.compute import *
from desc.coils import *
from desc.equilibrium import *
from desc.examples import *
from desc.grid import *
from desc.geometry import *
from desc.io import *

from desc.objectives import *
from desc.objectives.objective_funs import *
from desc.objectives.getters import *
from desc.objectives.normalization import compute_scaling_factors
from desc.objectives.utils import *
from desc.optimize._constraint_wrappers import *

from desc.transform import Transform
from desc.plotting import *
from desc.optimize import *
from desc.perturbations import *
from desc.profiles import *
from desc.compat import *
from desc.utils import *
from desc.magnetic_fields import *
from desc.particles import *
from diffrax import *

from desc.__main__ import main
from desc.vmec_utils import vmec_boundary_subspace
from desc.input_reader import InputReader
from desc.continuation import solve_continuation_automatic
from desc.compute.data_index import register_compute_fun
from desc.optimize.utils import solve_triangular_regularized

print_backend_info()

DESC version=0.16.0+411.g6e1e51890.
Using JAX backend: jax version=0.6.2, jaxlib version=0.6.2, dtype=float64.
Using device: NVIDIA GeForce RTX 4080 Laptop GPU (id=0), with 11.62 GB available memory.


In [None]:
from collections import OrderedDict
from typing import Any, Union

import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
from jax import jit
from jaxtyping import Array, ArrayLike, Float, Inexact, Num, Real

from interpax._coefs import A_BICUBIC, A_CUBIC, A_TRICUBIC
from interpax._fd_derivs import approx_df
from interpax.utils import asarray_inexact, errorif, isbool, wrap_jit

# Names of the interpolation schemes that go through the cubic (derivative
# based) evaluation path. The general-grid ``interp3d`` in this notebook asserts
# that ``method`` is one of these.
CUBIC_METHODS = (
    "cubic",
    "cubic2",
    "cardinal",
    "catmull-rom",
    "akima",
    "monotonic",
    "monotonic-0",
)
# Non-cubic scheme names. They are part of METHODS_* and therefore pass the
# "unknown method" validation.
OTHER_METHODS = ("nearest", "linear")
# Accepted method names per dimensionality; currently the same set for all three.
METHODS_1D = CUBIC_METHODS + OTHER_METHODS
METHODS_2D = CUBIC_METHODS + OTHER_METHODS
METHODS_3D = CUBIC_METHODS + OTHER_METHODS


class AbstractInterpolator(eqx.Module):
    """ABC convenience class for representing an interpolated function.

    Subclasses should implement the `__call__` method to evaluate the
    interpolated function.

    The fields declared with ``eqx.AbstractVar`` carry no value here; concrete
    subclasses (e.g. ``Interpolator3D``) declare and assign them.

    """

    f: eqx.AbstractVar[Inexact[Array, "..."]]  # function values to interpolate
    # derivative arrays keyed by name (Interpolator3D uses "fx", "fy", ..., "fxyz")
    derivs: eqx.AbstractVar[dict[str, Inexact[Array, "..."]]]
    # name of the interpolation scheme; declared static via eqx.field(static=True)
    method: str = eqx.field(static=True)
    # out-of-bounds handling: extrapolate (True), nan (False) or a fill value
    extrap: eqx.AbstractVar[Union[bool, float, tuple]]
    # periodicity of the function, or None for a non-periodic function
    period: eqx.AbstractVar[Union[None, float, tuple]]
    # axis option stored by subclasses
    axis: eqx.AbstractVar[int]


class Interpolator3D(AbstractInterpolator):
    """Convenience class for representing a 3D interpolated function.

    The derivative arrays needed by the tricubic evaluation are computed once in
    ``__init__`` and reused on every call, so repeated evaluation on the same
    data does not recompute them.

    Parameters
    ----------
    x : ndarray, shape(Nx,)
        x coordinates of known function values ("knots")
    y : ndarray, shape(Ny,)
        y coordinates of known function values ("knots")
    z : ndarray, shape(Nz,)
        z coordinates of known function values ("knots")
    f : ndarray, shape(Nx,Ny,Nz,...)
        function values to interpolate
    method : str
        method of interpolation

        - ``'nearest'``: nearest neighbor interpolation
        - ``'linear'``: linear interpolation
        - ``'cubic'``: C1 cubic splines (aka local splines)
        - ``'cubic2'``: C2 cubic splines (aka natural splines)
        - ``'catmull-rom'``: C1 cubic centripetal "tension" splines
        - ``'cardinal'``: C1 cubic general tension splines. If used, can also pass
          keyword parameter ``c`` in float[0,1] to specify tension
        - ``'monotonic'``: C1 cubic splines that attempt to preserve monotonicity in the
          data, and will not introduce new extrema in the interpolated points
        - ``'monotonic-0'``: same as ``'monotonic'`` but with 0 first derivatives at
          both endpoints
        - ``'akima'``: C1 cubic splines that appear smooth and natural

    extrap : bool, float, array-like
        whether to extrapolate values beyond knots (True) or return nan (False),
        or a specified value to return for query points outside the bounds. Can
        also be passed as an array or tuple to specify different conditions
        [[xlow, xhigh],[ylow,yhigh],[zlow,zhigh]]
    period : float > 0, None, array-like, shape(3,)
        periodicity of the function in x, y, z directions. None denotes no periodicity,
        otherwise function is assumed to be periodic on the interval [0,period]. Use a
        single value for the same in all directions.
    **kwargs
        ``fx``, ``fy``, ``fz``, ``fxy``, ``fxz``, ``fyz``, ``fxyz``: optional
        precomputed derivative arrays with the same shape as ``f``; any that are
        missing are approximated with ``approx_df``. ``axis``: stored on the
        instance (default 0). All remaining keyword arguments are forwarded to
        ``approx_df``.

    Raises
    ------
    ValueError
        If the knot arrays are not 1D, their lengths do not match the leading
        three dimensions of ``f``, or ``method`` is not a known method name.

    """

    x: Float[Array, " Nx"]
    y: Float[Array, " Ny"]
    z: Float[Array, " Nz"]
    f: Inexact[Array, " Nx Ny Nz ..."]
    derivs: dict
    method: str = eqx.field(static=True)
    extrap: Union[bool, float, tuple]
    period: Union[None, float, tuple]
    axis: int

    def __init__(
        self,
        x: Real[ArrayLike, " Nx"],
        y: Real[ArrayLike, " Ny"],
        z: Real[ArrayLike, " Nz"],
        f: Num[ArrayLike, " Nx Ny Nz ..."],
        method: str = "cubic",
        extrap: Union[bool, float, tuple] = False,
        period: Union[None, float, tuple] = None,
        **kwargs,
    ):
        x, y, z, f = map(asarray_inexact, (x, y, z, f))
        # Pop (rather than get) so ``axis`` is not forwarded through **kwargs to
        # approx_df below, which already receives its axis positionally; leaving
        # it in kwargs would pass the axis twice.
        axis = kwargs.pop("axis", 0)

        errorif(
            (len(x) != f.shape[0]) or (x.ndim != 1),
            ValueError,
            "x and f must be arrays of equal length",
        )
        errorif(
            (len(y) != f.shape[1]) or (y.ndim != 1),
            ValueError,
            "y and f must be arrays of equal length",
        )
        errorif(
            (len(z) != f.shape[2]) or (z.ndim != 1),
            ValueError,
            "z and f must be arrays of equal length",
        )
        errorif(method not in METHODS_3D, ValueError, f"unknown method {method}")

        # User supplied derivatives take precedence over approximated ones. They
        # are removed from kwargs so that only approx_df options remain in it.
        fx = kwargs.pop("fx", None)
        fy = kwargs.pop("fy", None)
        fz = kwargs.pop("fz", None)
        fxy = kwargs.pop("fxy", None)
        fxz = kwargs.pop("fxz", None)
        fyz = kwargs.pop("fyz", None)
        fxyz = kwargs.pop("fxyz", None)

        self.x = x
        self.y = y
        self.z = z
        self.f = f
        self.axis = axis
        self.method = method
        self.extrap = extrap
        self.period = period

        # Mixed derivatives are built by differentiating the lower order ones
        # along the remaining axes (e.g. fxy = d/dy of fx).
        if fx is None:
            fx = approx_df(x, f, method, 0, **kwargs)
        if fy is None:
            fy = approx_df(y, f, method, 1, **kwargs)
        if fz is None:
            fz = approx_df(z, f, method, 2, **kwargs)
        if fxy is None:
            fxy = approx_df(y, fx, method, 1, **kwargs)
        if fxz is None:
            fxz = approx_df(z, fx, method, 2, **kwargs)
        if fyz is None:
            fyz = approx_df(z, fy, method, 2, **kwargs)
        if fxyz is None:
            fxyz = approx_df(z, fxy, method, 2, **kwargs)

        # Keys match the keyword names interp3d pops from its **kwargs.
        self.derivs = {
            "fx": fx,
            "fy": fy,
            "fz": fz,
            "fxy": fxy,
            "fxz": fxz,
            "fyz": fyz,
            "fxyz": fxyz,
        }

    def __call__(
        self,
        xq: Real[ArrayLike, "..."],
        yq: Real[ArrayLike, "..."],
        zq: Real[ArrayLike, "..."],
        dx: int = 0,
        dy: int = 0,
        dz: int = 0,
    ) -> Inexact[Array, "..."]:
        """Evaluate the interpolated function or its derivatives.

        Parameters
        ----------
        xq, yq, zq : ndarray, shape(Nq,)
            x, y, z query points where interpolation is desired
        dx, dy, dz : int >= 0
            Derivative to take in x, y, z directions.

        Returns
        -------
        fq : ndarray, shape(Nq, ...)
            Interpolated values.
        """
        # The cached derivative arrays are handed to interp3d as keyword
        # arguments so it does not have to approximate them again.
        return interp3d(
            xq,
            yq,
            zq,
            self.x,
            self.y,
            self.z,
            self.f,
            self.method,
            (dx, dy, dz),
            self.extrap,
            self.period,
            **self.derivs,
        )
@wrap_jit(static_argnames=["method"])
def interp3d(  # noqa: C901 - FIXME: break this up into simpler pieces
    xq: Real[ArrayLike, " Nq"],
    yq: Real[ArrayLike, " Nq"],
    zq: Real[ArrayLike, " Nq"],
    x: Real[ArrayLike, " Nx"],
    y: Real[ArrayLike, " Ny"],
    z: Real[ArrayLike, " Nz"],
    f: Num[ArrayLike, "Nx Ny Nz ..."],
    method: str = "cubic",
    derivative: Union[int, tuple] = 0,
    extrap: Union[bool, float, tuple] = False,
    period: Union[None, float, tuple] = None,
    **kwargs,
) -> Inexact[Array, "Nq ..."]:
    """Interpolate a 3d function.

    Parameters
    ----------
    xq : ndarray, shape(Nq,)
        x query points where interpolation is desired
    yq : ndarray, shape(Nq,)
        y query points where interpolation is desired
    zq : ndarray, shape(Nq,)
        z query points where interpolation is desired
    x : ndarray, shape(Nx,)
        x coordinates of known function values ("knots")
    y : ndarray, shape(Ny,)
        y coordinates of known function values ("knots")
    z : ndarray, shape(Nz,)
        z coordinates of known function values ("knots")
    f : ndarray, shape(Nx,Ny,Nz,...)
        function values to interpolate
    method : str
        method of interpolation

        - ``'nearest'``: nearest neighbor interpolation
        - ``'linear'``: linear interpolation
        - ``'cubic'``: C1 cubic splines (aka local splines)
        - ``'cubic2'``: C2 cubic splines (aka natural splines)
        - ``'catmull-rom'``: C1 cubic centripetal "tension" splines
        - ``'cardinal'``: C1 cubic general tension splines. If used, can also pass
          keyword parameter ``c`` in float[0,1] to specify tension
        - ``'monotonic'``: C1 cubic splines that attempt to preserve monotonicity in the
          data, and will not introduce new extrema in the interpolated points
        - ``'monotonic-0'``: same as ``'monotonic'`` but with 0 first derivatives at
          both endpoints
        - ``'akima'``: C1 cubic splines that appear smooth and natural

    derivative : int >= 0, array-like, shape(3,)
        derivative order to calculate in x,y,z directions. Use a single value for the
        same in all directions.
    extrap : bool, float, array-like
        whether to extrapolate values beyond knots (True) or return nan (False),
        or a specified value to return for query points outside the bounds. Can
        also be passed as an array or tuple to specify different conditions for
        [[xlow, xhigh],[ylow,yhigh],[zlow,zhigh]]
    period : float > 0, None, array-like, shape(3,)
        periodicity of the function in x, y, z directions. None denotes no periodicity,
        otherwise function is assumed to be periodic on the interval [0,period]. Use a
        single value for the same in all directions.
    **kwargs
        ``fx``, ``fy``, ``fz``, ``fxy``, ``fxz``, ``fyz``, ``fxyz``: optional
        precomputed derivative arrays; missing ones are approximated with
        ``approx_df``, which also receives any remaining keyword arguments.

    Returns
    -------
    fq : ndarray, shape(Nq,...)
        function value at query points

    Notes
    -----
    For repeated interpolation given the same x, y, z, f data, recommend using
    Interpolator3D which caches the calculation of the derivatives and spline
    coefficients.

    This copy of the function only contains the cubic evaluation path: although
    ``'nearest'`` and ``'linear'`` pass the name validation, the function
    asserts that ``method`` is one of ``CUBIC_METHODS``.

    """
    xq, yq, zq, x, y, z, f = map(asarray_inexact, (xq, yq, zq, x, y, z, f))
    errorif(
        (len(x) != f.shape[0]) or (x.ndim != 1),
        ValueError,
        "x and f must be arrays of equal length",
    )
    errorif(
        (len(y) != f.shape[1]) or (y.ndim != 1),
        ValueError,
        "y and f must be arrays of equal length",
    )
    errorif(
        (len(z) != f.shape[2]) or (z.ndim != 1),
        ValueError,
        "z and f must be arrays of equal length",
    )
    errorif(method not in METHODS_3D, ValueError, f"unknown method {method}")

    xq, yq, zq = jnp.broadcast_arrays(xq, yq, zq)
    outshape = xq.shape + f.shape[3:]

    # Promote scalar query points to 1D array.
    # Note this is done after the computation of outshape
    # to make jax.grad work in the scalar case.
    xq, yq, zq = map(jnp.atleast_1d, (xq, yq, zq))

    # Caller supplied derivatives; removed from kwargs so that only approx_df
    # options remain in it.
    fx = kwargs.pop("fx", None)
    fy = kwargs.pop("fy", None)
    fz = kwargs.pop("fz", None)
    fxy = kwargs.pop("fxy", None)
    fxz = kwargs.pop("fxz", None)
    fyz = kwargs.pop("fyz", None)
    fxyz = kwargs.pop("fxyz", None)

    # Expand scalar options to one entry per dimension (and per bound for extrap).
    periodx, periody, periodz = _parse_ndarg(period, 3)
    derivative_x, derivative_y, derivative_z = _parse_ndarg(derivative, 3)
    lowx, highx, lowy, highy, lowz, highz = _parse_extrap(extrap, 3)

    # For a periodic direction the queries are wrapped into the period and the
    # knots/data are padded by one wrapped sample on each side. The extrap flags
    # are then set to True so that no clipping is applied along that direction.
    if periodx is not None:
        xq, x, f, fx, fy, fz, fxy, fxz, fyz, fxyz = _make_periodic(
            xq, x, periodx, 0, f, fx, fy, fz, fxy, fxz, fyz, fxyz
        )
        lowx = highx = True
    if periody is not None:
        yq, y, f, fx, fy, fz, fxy, fxz, fyz, fxyz = _make_periodic(
            yq, y, periody, 1, f, fx, fy, fz, fxy, fxz, fyz, fxyz
        )
        lowy = highy = True
    if periodz is not None:
        zq, z, f, fx, fy, fz, fxy, fxz, fyz, fxyz = _make_periodic(
            zq, z, periodz, 2, f, fx, fy, fz, fxy, fxz, fyz, fxyz
        )
        lowz = highz = True

    # Only the cubic family is implemented below.
    assert method in CUBIC_METHODS
    # Approximate whichever derivatives were not supplied. Mixed derivatives are
    # obtained by differentiating the lower order ones along the remaining axes.
    if fx is None:
        fx = approx_df(x, f, method, 0, **kwargs)
    if fy is None:
        fy = approx_df(y, f, method, 1, **kwargs)
    if fz is None:
        fz = approx_df(z, f, method, 2, **kwargs)
    if fxy is None:
        fxy = approx_df(y, fx, method, 1, **kwargs)
    if fxz is None:
        fxz = approx_df(z, fx, method, 2, **kwargs)
    if fyz is None:
        fyz = approx_df(z, fy, method, 2, **kwargs)
    if fxyz is None:
        fxyz = approx_df(z, fxy, method, 2, **kwargs)
    assert (
        fx.shape
        == fy.shape
        == fz.shape
        == fxy.shape
        == fxz.shape
        == fyz.shape
        == fxyz.shape
        == f.shape
    )
    # Upper knot index of the cell containing each query point. Clipping to
    # [1, N-1] keeps both i-1 and i valid, also for out-of-range queries.
    i = jnp.clip(jnp.searchsorted(x, xq, side="right"), 1, len(x) - 1)
    j = jnp.clip(jnp.searchsorted(y, yq, side="right"), 1, len(y) - 1)
    k = jnp.clip(jnp.searchsorted(z, zq, side="right"), 1, len(z) - 1)

    # Cell width, offset into the cell and normalized coordinate t = offset/width
    # for each direction. A zero-width cell yields an inverse width of 0 instead
    # of a division by zero.
    dx = x[i] - x[i - 1]
    deltax = xq - x[i - 1]
    dxi = jnp.where(dx == 0, 0, 1 / dx)
    tx = deltax * dxi

    dy = y[j] - y[j - 1]
    deltay = yq - y[j - 1]
    dyi = jnp.where(dy == 0, 0, 1 / dy)
    ty = deltay * dyi

    dz = z[k] - z[k - 1]
    deltaz = zq - z[k - 1]
    dzi = jnp.where(dz == 0, 0, 1 / dz)
    tz = deltaz * dzi

    fs = OrderedDict()
    fs["f"] = f
    fs["fx"] = fx
    fs["fy"] = fy
    fs["fz"] = fz
    fs["fxy"] = fxy
    fs["fxz"] = fxz
    fs["fyz"] = fyz
    fs["fxyz"] = fxyz
    # Gather the 8 fields at the 8 cell corners (64 entries per query point).
    # Insertion order is: field, then kk, jj, ii with ii varying fastest; this is
    # the order in which the entries are stacked and multiplied by A_TRICUBIC.
    fsq = OrderedDict()
    for ff in fs.keys():
        for kk in [0, 1]:
            for jj in [0, 1]:
                for ii in [0, 1]:
                    s = ff + str(ii) + str(jj) + str(kk)
                    fsq[s] = fs[ff][i - 1 + ii, j - 1 + jj, k - 1 + kk]
                    # Multiply each derivative by the cell width of every
                    # direction it is taken in. The double transpose lets the
                    # per-query widths broadcast along the leading (query) axis.
                    if "x" in ff:
                        fsq[s] = (dx * fsq[s].T).T
                    if "y" in ff:
                        fsq[s] = (dy * fsq[s].T).T
                    if "z" in ff:
                        fsq[s] = (dz * fsq[s].T).T

    # Apply A_TRICUBIC to the 64 gathered values of every query point.
    F = jnp.stack([foo for foo in fsq.values()], axis=0).T
    coef = jnp.vectorize(jnp.matmul, signature="(n,n),(n)->(n)")(A_TRICUBIC, F).T
    # Fortran-order reshape: the first of the three size-4 axes varies fastest
    # in the flat index. After moveaxis the layout is (query, 4, 4, 4, ...),
    # with the three size-4 axes contracted against ttx, tty, ttz below.
    coef = jnp.moveaxis(coef.reshape((4, 4, 4, *coef.shape[1:]), order="F"), 3, 0)
    # Monomial basis (or its derivative of the requested order) per direction.
    ttx = _get_t_der(tx, derivative_x, dxi)
    tty = _get_t_der(ty, derivative_y, dyi)
    ttz = _get_t_der(tz, derivative_z, dzi)
    fq = jnp.einsum("lijk...,li,lj,lk->l...", coef, ttx, tty, ttz)

    # Replace out-of-bounds results per direction unless extrapolation is on.
    fq = _extrap(xq, fq, x, lowx, highx)
    fq = _extrap(yq, fq, y, lowy, highy)
    fq = _extrap(zq, fq, z, lowz, highz)

    return fq.reshape(outshape)


@wrap_jit(static_argnames=["axis"])
def _make_periodic(
    xq: jax.Array,
    x: jax.Array,
    period: float,
    axis: int,
    *arrs: jax.Array,
) -> tuple[jax.Array, ...]:
    """Wrap query points and knots into one period and pad the data to match.

    Parameters
    ----------
    xq : jax.Array
        Query points; returned reduced modulo ``abs(period)``.
    x : jax.Array
        Knots; reduced modulo ``abs(period)``, sorted, and extended by one
        shifted knot at each end.
    period : float
        Period length; its absolute value is used.
    axis : int
        Axis of each array in ``arrs`` that corresponds to ``x``.
    *arrs : jax.Array or None
        Data arrays to reorder and pad along ``axis``. ``None`` entries are
        passed through unchanged.

    Returns
    -------
    tuple
        ``(xq, x, *arrs)`` after wrapping, sorting and padding.
    """
    period = abs(period)
    xq = xq % period
    x = x % period

    # Sort the wrapped knots, then add the last knot shifted down by one period
    # in front and the first knot shifted up by one period at the end.
    order = jnp.argsort(x)
    x = x[order]
    x = jnp.concatenate([x[-1:] - period, x, x[:1] + period])

    def _reorder_and_pad(arr):
        # None marks a derivative that was not supplied; leave it untouched.
        if arr is None:
            return None
        # Apply the same permutation as the knots, then pad with the last and
        # first slices so the data lines up with the extended knot vector.
        arr = jnp.take(arr, order, axis, mode="wrap")
        last_slice = jnp.take(arr, jnp.array([-1]), axis)
        first_slice = jnp.take(arr, jnp.array([0]), axis)
        return jnp.concatenate([last_slice, arr, first_slice], axis=axis)

    padded = [_reorder_and_pad(arr) for arr in arrs]
    return (xq, x, *padded)


@jit
def _get_t_der(t: jax.Array, derivative: int, dxi: jax.Array):
    """Get arrays of [1,t,t^2,t^3] for cubic interpolation."""
    t0 = jnp.zeros_like(t)
    t1 = jnp.ones_like(t)
    dxi = jnp.atleast_1d(dxi)[:, None]
    # derivatives of monomials
    d0 = lambda: jnp.array([t1, t, t**2, t**3]).T * dxi**0
    d1 = lambda: jnp.array([t0, t1, 2 * t, 3 * t**2]).T * dxi
    d2 = lambda: jnp.array([t0, t0, 2 * t1, 6 * t]).T * dxi**2
    d3 = lambda: jnp.array([t0, t0, t0, 6 * t1]).T * dxi**3
    d4 = lambda: jnp.array([t0, t0, t0, t0]).T * (dxi * 0)

    return jax.lax.switch(derivative, [d0, d1, d2, d3, d4])


def _parse_ndarg(arg: Any, n: int) -> Union[Any, tuple]:
    try:
        k = len(arg)
    except TypeError:
        arg = tuple(arg for _ in range(n))
        k = n
    assert k == n, "got too many args"
    return arg


def _parse_extrap(extrap, n):
    """Expand an ``extrap`` option to a flat ``(lo, hi)`` tuple per dimension.

    Accepted forms are a single bool or scalar (used for every bound), a
    two-element ``(lo, hi)`` sequence of scalars (used for every dimension), or
    a sequence of ``n`` two-element ``(lo, hi)`` pairs.

    Returns a tuple of length ``2 * n`` ordered ``(lo_0, hi_0, lo_1, hi_1, ...)``
    and raises ``ValueError`` for any other layout.
    """
    # One bool or scalar: same setting for the lower and upper bound everywhere.
    if isbool(extrap) or jnp.isscalar(extrap):
        return (extrap,) * (2 * n)

    # A single (lo, hi) pair shared by all dimensions.
    if len(extrap) == 2 and jnp.isscalar(extrap[0]):
        return tuple(extrap) * n

    # One (lo, hi) pair per dimension: flatten in order.
    if len(extrap) == n and all(len(extrap[dim]) == 2 for dim in range(n)):
        flat = []
        for pair in extrap:
            flat.extend(pair)
        return tuple(flat)

    raise ValueError(
        "extrap should either be a scalar, 2 element sequence (lo, hi), "
        + "or a sequence with 2 elements for each dimension"
    )


@jit
def _extrap(
    xq: jax.Array,
    fq: jax.Array,
    x: jax.Array,
    lo: Union[bool, float],
    hi: Union[bool, float],
):
    """Overwrite results at query points outside the knot range.

    Parameters
    ----------
    xq : jax.Array
        Query coordinates along one direction.
    fq : jax.Array
        Interpolated values at the query points.
    x : jax.Array
        Knots along the same direction; ``x[0]`` and ``x[-1]`` are the bounds.
    lo, hi : bool or float
        Handling below ``x[0]`` / above ``x[-1]``: boolean True keeps the
        computed value, any other boolean fills with nan, and a number is used
        as the fill value.

    Returns
    -------
    jax.Array
        ``fq`` with the out-of-range entries replaced where requested.
    """

    def _fill_for(setting):
        # A boolean reaching a clip branch means "no extrapolation": fill with nan.
        return jnp.nan if isbool(setting) else setting

    def clip_below(vals: jax.Array, setting: Union[bool, float]):
        # The transposes let the 1D mask broadcast along the leading query axis.
        return jnp.where(xq < x[0], _fill_for(setting), vals.T).T

    def clip_above(vals: jax.Array, setting: Union[bool, float]):
        return jnp.where(xq > x[-1], _fill_for(setting), vals.T).T

    def passthrough(vals, *_):
        return vals

    # "boolean and truthy" identifies extrap=True, also for numpy/jax booleans;
    # in that case the values are left as computed.
    keep_low = isbool(lo) & jnp.asarray(lo).astype(bool)
    fq = jax.lax.cond(keep_low, passthrough, clip_below, fq, lo)

    keep_high = isbool(hi) & jnp.asarray(hi).astype(bool)
    fq = jax.lax.cond(keep_high, passthrough, clip_above, fq, hi)

    return fq

In [None]:
@wrap_jit(static_argnames=["method"])
def interp3d(
    xq: Real[ArrayLike, " Nq"],
    yq: Real[ArrayLike, " Nq"],
    zq: Real[ArrayLike, " Nq"],
    x: Real[ArrayLike, " Nx"],
    y: Real[ArrayLike, " Ny"],
    z: Real[ArrayLike, " Nz"],
    f: Num[ArrayLike, "Nx Ny Nz ..."],
    method: str = "cubic",
    derivative: Union[int, tuple] = 0,
    extrap: Union[bool, float, tuple] = False,
    period: Union[None, float, tuple] = None,
    **kwargs,
) -> Inexact[Array, "Nq ..."]:
    """Interpolate a 3d function assuming a REGULAR grid (constant dx, dy, dz).

    Variant of the general tricubic interpolation that locates the cell of each
    query point arithmetically from the spacing of the first two knots instead
    of searching the knot vectors. The knots are not checked for uniformity.

    Parameters
    ----------
    xq, yq, zq : ndarray, shape(Nq,)
        query points where interpolation is desired
    x, y, z : ndarray, shape(Nx,), shape(Ny,), shape(Nz,)
        uniformly spaced coordinates of known function values ("knots"); the
        spacing is taken from the first two entries of each
    f : ndarray, shape(Nx,Ny,Nz,...)
        function values to interpolate; any number of trailing dimensions
    method : str
        name of the scheme handed to ``approx_df`` for derivatives that are not
        supplied. Evaluation always uses the tricubic kernel; there is no
        separate code path for ``'nearest'`` or ``'linear'``.
    derivative : int >= 0, array-like, shape(3,)
        derivative order to calculate in x,y,z directions. Use a single value for the
        same in all directions.
    extrap : bool, float, array-like
        whether to extrapolate values beyond knots (True) or return nan (False),
        or a specified value to return for query points outside the bounds. Can
        also be passed as an array or tuple to specify different conditions for
        [[xlow, xhigh],[ylow,yhigh],[zlow,zhigh]]
    period : float > 0, None, array-like, shape(3,)
        periodicity of the function in x, y, z directions. None denotes no periodicity.
    **kwargs
        ``fx``, ``fy``, ``fz``, ``fxy``, ``fxz``, ``fyz``, ``fxyz``: optional
        precomputed derivative arrays; missing ones are approximated with
        ``approx_df``, which also receives any remaining keyword arguments.

    Returns
    -------
    fq : ndarray, shape(Nq,...)
        function value at query points

    Raises
    ------
    ValueError
        If the knot arrays are not 1D, their lengths do not match the leading
        three dimensions of ``f``, or ``method`` is not a known method name.
    """
    # 1. Inputs and checks
    xq, yq, zq, x, y, z, f = map(asarray_inexact, (xq, yq, zq, x, y, z, f))
    errorif((len(x) != f.shape[0]) or (x.ndim != 1), ValueError, "x/f mismatch")
    errorif((len(y) != f.shape[1]) or (y.ndim != 1), ValueError, "y/f mismatch")
    errorif((len(z) != f.shape[2]) or (z.ndim != 1), ValueError, "z/f mismatch")
    errorif(method not in METHODS_3D, ValueError, f"unknown method {method}")

    # Output shape is recorded before scalars are promoted to 1D arrays.
    xq, yq, zq = jnp.broadcast_arrays(xq, yq, zq)
    outshape = xq.shape + f.shape[3:]
    xq, yq, zq = map(jnp.atleast_1d, (xq, yq, zq))

    # 2. Derivatives supplied by the caller (e.g. cached by Interpolator3D)
    fx = kwargs.pop("fx", None)
    fy = kwargs.pop("fy", None)
    fz = kwargs.pop("fz", None)
    fxy = kwargs.pop("fxy", None)
    fxz = kwargs.pop("fxz", None)
    fyz = kwargs.pop("fyz", None)
    fxyz = kwargs.pop("fxyz", None)

    # 3. Per-dimension options
    periodx, periody, periodz = _parse_ndarg(period, 3)
    derivative_x, derivative_y, derivative_z = _parse_ndarg(derivative, 3)
    lowx, highx, lowy, highy, lowz, highz = _parse_extrap(extrap, 3)

    # Periodic directions: wrap the queries, pad knots/data and disable clipping.
    if periodx is not None:
        xq, x, f, fx, fy, fz, fxy, fxz, fyz, fxyz = _make_periodic(
            xq, x, periodx, 0, f, fx, fy, fz, fxy, fxz, fyz, fxyz
        )
        lowx = highx = True
    if periody is not None:
        yq, y, f, fx, fy, fz, fxy, fxz, fyz, fxyz = _make_periodic(
            yq, y, periody, 1, f, fx, fy, fz, fxy, fxz, fyz, fxyz
        )
        lowy = highy = True
    if periodz is not None:
        zq, z, f, fx, fy, fz, fxy, fxz, fyz, fxyz = _make_periodic(
            zq, z, periodz, 2, f, fx, fy, fz, fxy, fxz, fyz, fxyz
        )
        lowz = highz = True

    # 4. Fill in derivatives that were not supplied. Mixed derivatives are
    # obtained by differentiating the lower order ones along the remaining axes.
    if fx is None:
        fx = approx_df(x, f, method, 0, **kwargs)
    if fy is None:
        fy = approx_df(y, f, method, 1, **kwargs)
    if fz is None:
        fz = approx_df(z, f, method, 2, **kwargs)
    if fxy is None:
        fxy = approx_df(y, fx, method, 1, **kwargs)
    if fxz is None:
        fxz = approx_df(z, fx, method, 2, **kwargs)
    if fyz is None:
        fyz = approx_df(z, fy, method, 2, **kwargs)
    if fxyz is None:
        fxyz = approx_df(z, fxy, method, 2, **kwargs)

    # 5. Arithmetic cell lookup for the regular grid
    # Spacing is taken from the first two knots and assumed constant.
    dx = x[1] - x[0]
    dy = y[1] - y[0]
    dz = z[1] - z[0]

    # Lower corner index of the containing cell. Clipping to [0, N-2] keeps
    # index + 1 valid, also for out-of-range queries.
    i = jnp.clip(jnp.floor((xq - x[0]) / dx).astype(int), 0, len(x) - 2)
    j = jnp.clip(jnp.floor((yq - y[0]) / dy).astype(int), 0, len(y) - 2)
    k = jnp.clip(jnp.floor((zq - z[0]) / dz).astype(int), 0, len(z) - 2)

    # Normalized coordinate inside the cell, using x[i] = x[0] + i * dx.
    dxi = 1.0 / dx
    dyi = 1.0 / dy
    dzi = 1.0 / dz
    tx = (xq - (x[0] + i * dx)) * dxi
    ty = (yq - (y[0] + j * dy)) * dyi
    tz = (zq - (z[0] + k * dz)) * dzi

    # 6. Gather the 8 fields at the 8 cell corners (64 values per query point)
    # Each derivative is multiplied by the cell width of every direction it is
    # taken in, i.e. converted to a derivative w.r.t. the normalized coordinate.
    fields = (
        (f, 1.0),
        (fx, dx),
        (fy, dy),
        (fz, dz),
        (fxy, dx * dy),
        (fxz, dx * dz),
        (fyz, dy * dz),
        (fxyz, dx * dy * dz),
    )
    # Ordering must match the general-grid implementation: field first, then
    # corner offsets with the x offset varying fastest, then y, then z.
    corner_vals = []
    for arr, scale in fields:
        for kk in (0, 1):
            for jj in (0, 1):
                for ii in (0, 1):
                    corner_vals.append(arr[i + ii, j + jj, k + kk] * scale)

    # (64, Nq, ...) -> (Nq, 64, ...)
    F = jnp.moveaxis(jnp.stack(corner_vals, axis=0), 0, 1)

    # 7. Polynomial coefficients
    # The ellipsis carries any number of trailing dimensions of f (including
    # none); a fixed subscript would only accept exactly one.
    coef = jnp.einsum("cv,qv...->qc...", A_TRICUBIC, F)

    # The flat coefficient index is i + 4*j + 16*k with i the power of x (the
    # general-grid implementation reshapes in Fortran order). A C-order reshape
    # therefore yields axes ordered (z power, y power, x power).
    coef = coef.reshape((coef.shape[0], 4, 4, 4) + coef.shape[2:])

    # 8. Evaluate the polynomial (or its derivative) at the query points
    ttx = _get_t_der(tx, derivative_x, dxi)  # shape (Nq, 4)
    tty = _get_t_der(ty, derivative_y, dyi)  # shape (Nq, 4)
    ttz = _get_t_der(tz, derivative_z, dzi)  # shape (Nq, 4)
    # Axis order of coef is (query, z, y, x), hence "qkji".
    fq = jnp.einsum("qkji...,qi,qj,qk->q...", coef, ttx, tty, ttz)

    # 9. Replace out-of-bounds results per direction unless extrapolation is on
    fq = _extrap(xq, fq, x, lowx, highx)
    fq = _extrap(yq, fq, y, lowy, highy)
    fq = _extrap(zq, fq, z, lowz, highz)

    return fq.reshape(outshape)