# JAX timing: 2D vs N-D interpolation (self-contained)

This notebook compares the runtime of two interpolation implementations used in dcegm:
- Specialized 2D implementation (regular grid × irregular wealth grid).
- General N-D implementation with exactly one irregular axis (wealth).

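We mirror the setup of the tests `test_interpNd_policy_matches_2d_impl` and `test_interpNd_value_matches_2d_impl`, wrap the batched calls in `jax.vmap` and `jax.jit`, and time both implementations with `%timeit`. We use `jax.block_until_ready` so that asynchronous dispatch does not hide the actual execution time.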

In [13]:
import os
import numpy as np
import jax
import jax.numpy as jnp
from jax import jit

# Enable 64-bit floats/ints globally; this must run before any JAX arrays are
# created so that they default to float64.
jax.config.update("jax_enable_x64", True)
# Seed NumPy's global RNG; make_test_case below draws all its data from it.
np.random.seed(1234)

# Report which backend JAX will execute on (the recorded output is "cpu").
print(jax.default_backend())

cpu


## Helpers: 1D primitives and indexing
Minimal helpers adapted from the project's interpolation utilities.

In [14]:
def get_index_high_and_low(x: jnp.ndarray, x_new: float | jnp.ndarray):
    """Get index of the highest value in x that is smaller than x_new.

    Returns (ind_high, ind_low).
    """
    ind_high = jnp.searchsorted(x, x_new).clip(max=(x.shape[0] - 1), min=1)
    ind_high -= jnp.isnan(x[ind_high]).astype(int)
    return ind_high, ind_high - 1


def _interp1d_one(xrow: jnp.ndarray, vrow: jnp.ndarray, xnew: float | jnp.ndarray):
    """Linearly interpolate samples ``vrow`` (taken at grid ``xrow``) at ``xnew``.

    The bracketing interval comes from ``get_index_high_and_low``. The fractional
    position is not clamped, so queries outside the grid are linearly
    extrapolated from the first or last interval.
    """
    upper, lower = get_index_high_and_low(xrow, xnew)
    x_left, x_right = xrow[lower], xrow[upper]
    v_left, v_right = vrow[lower], vrow[upper]
    # Floor the interval width at machine epsilon so repeated grid points cannot
    # trigger a division by zero.
    width = jnp.maximum(x_right - x_left, jnp.finfo(xrow.dtype).eps)
    frac = (xnew - x_left) / width
    return v_left + frac * (v_right - v_left)


def _indices_weight_regular(regular_grid: jnp.ndarray, y: float | jnp.ndarray):
    """Locate ``y`` on ``regular_grid``.

    Returns:
        ``(lo, hi, t)`` -- the bracketing indices and the fractional position
        ``t = (y - grid[lo]) / (grid[hi] - grid[lo])``. The denominator is floored
        at machine epsilon, and ``t`` is not clamped to ``[0, 1]``.
    """
    upper, lower = get_index_high_and_low(regular_grid, y)
    left = regular_grid[lower]
    span = jnp.maximum(regular_grid[upper] - left, jnp.finfo(regular_grid.dtype).eps)
    return lower, upper, (y - left) / span

## N-D interpolation (1 irregular axis)
Copied and lightly adapted from `interpNd.py`. It handles an arbitrary number R of regular axes plus a single irregular (wealth) axis, which must be the last axis of the grid and value arrays.

In [15]:
import jax
import jax.numpy as jnp


def _regular_indices_and_weights_static(regular_grids, regular_point):
    """Bracket each coordinate of ``regular_point`` on its regular grid.

    Delegates to ``_indices_weight_regular`` for every axis, so the N-D and 2D
    code paths share a single bracketing/weighting rule. The loop runs over a
    Python sequence of grids, so the number of axes R is static under ``jit``.

    Args:
        regular_grids: Sequence of R one-dimensional grids.
        regular_point: Indexable query; entry ``i`` is located on
            ``regular_grids[i]`` (only the first R entries are used).

    Returns:
        ``(idx_lo, idx_hi, ts)``: three length-R arrays holding the lower index,
        upper index and fractional position along each regular axis.
    """
    idx_lo, idx_hi, ts = [], [], []
    for axis, grid in enumerate(regular_grids):
        lo, hi, t = _indices_weight_regular(grid, regular_point[axis])
        idx_lo.append(lo)
        idx_hi.append(hi)
        ts.append(t)
    return jnp.array(idx_lo), jnp.array(idx_hi), jnp.array(ts)


def _enumerate_corners(R: int) -> jnp.ndarray:
    C = 1 << R
    ks = jnp.arange(C, dtype=jnp.uint32)[:, None]
    axes = jnp.arange(R, dtype=jnp.uint32)[None, :]
    return (ks >> axes) & 1


def _flat_strides(dims):
    if not dims:
        return jnp.array([], dtype=jnp.int32)
    return jnp.concatenate(
        [
            jnp.array([1], dtype=jnp.int32),
            jnp.cumprod(jnp.array(dims[:-1], dtype=jnp.int32)),
        ]
    )


def interpNd_one_irregular(
    regular_grids, irregular_grid, values, regular_point, irregular_point
):
    """Interpolate over R regular axes plus one irregular trailing axis.

    For each of the 2**R corners of the regular-grid cell containing
    ``regular_point``, the matching row of ``values`` is linearly interpolated on
    that row's own irregular grid at ``irregular_point``. The 2**R results are
    then combined with multilinear weights (the product over axes of ``1 - t`` or
    ``t``).

    Args:
        regular_grids: Sequence of R one-dimensional regular grids.
        irregular_grid: Array whose last axis (length nW) holds each row's
            irregular grid. It is flattened to ``(-1, nW)`` rows, which are
            addressed via ``_flat_strides`` of the regular grid sizes (layout
            assumed to match -- TODO confirm for R >= 2).
        values: Function values with the same layout as ``irregular_grid``.
        regular_point: Length-R query along the regular axes.
        irregular_point: Scalar query along the irregular axis.

    Returns:
        The interpolated scalar.
    """
    R = len(regular_grids)
    dims = tuple(g.shape[0] for g in regular_grids)
    nW = irregular_grid.shape[-1]

    # Bracketing indices and fractional positions along each regular axis.
    idx_lo, idx_hi, t = _regular_indices_and_weights_static(
        regular_grids, regular_point
    )

    # sel[c, i] == 0 -> corner c uses the lower neighbour on axis i, else the
    # upper one. jnp.where broadcasts the length-R vectors against (2**R, R).
    sel = _enumerate_corners(R).astype(jnp.int32)
    corner_idx = jnp.where(sel == 0, idx_lo, idx_hi)
    w_corners = jnp.prod(jnp.where(sel == 0, 1.0 - t, t), axis=1)

    # Gather the irregular grid row and value row belonging to each corner.
    flat_idx = jnp.sum(corner_idx * _flat_strides(dims), axis=1)
    irr_sel = irregular_grid.reshape((-1, nW))[flat_idx]
    val_sel = values.reshape((-1, nW))[flat_idx]

    # 1D interpolation along each corner's own irregular grid, using the shared
    # helper rather than a private copy of it.
    z_corner = jax.vmap(_interp1d_one, in_axes=(0, 0, None))(
        irr_sel, val_sel, irregular_point
    )
    return jnp.sum(w_corners * z_corner)


def interpNd_policy(
    regular_grids, wealth_grid, policy_grid, regular_point, wealth_point
):
    """Interpolate a policy function on R regular axes plus a wealth axis.

    Thin alias of ``interpNd_one_irregular`` in which the irregular axis is
    wealth and the interpolated values are the policy grid.
    """
    return interpNd_one_irregular(
        regular_grids=regular_grids,
        irregular_grid=wealth_grid,
        values=policy_grid,
        regular_point=regular_point,
        irregular_point=wealth_point,
    )


def interpNd_value_with_cc(
    regular_grids,
    wealth_grid,
    value_grid,
    regular_point,
    wealth_point,
    compute_utility,
    state_choice_vec,
    params,
    discount_factor,
):
    """N-D value interpolation with a credit-constraint override per corner.

    For every corner of the regular-grid cell around ``regular_point``:

    * if ``wealth_point`` is at or below that corner row's wealth entry in column
      1, the value is ``compute_utility(consumption=wealth_point, ...)`` plus
      ``discount_factor`` times the row's value in column 0;
    * otherwise the row's value grid is linearly interpolated on its wealth grid.

    The corner values are then blended with multilinear weights.

    Returns:
        The interpolated scalar value.
    """
    n_axes = len(regular_grids)
    sizes = tuple(grid.shape[0] for grid in regular_grids)
    n_wealth = wealth_grid.shape[-1]

    lower, upper, frac = _regular_indices_and_weights_static(
        regular_grids, regular_point
    )
    corner_bits = _enumerate_corners(n_axes).astype(jnp.int32)
    take_lower = corner_bits == 0

    # Flat row index of every corner in the (-1, n_wealth) view of the grids.
    corner_pos = jnp.where(take_lower, lower, upper)
    rows = jnp.sum(corner_pos * _flat_strides(sizes), axis=1)
    wealth_rows = wealth_grid.reshape((-1, n_wealth))[rows]
    value_rows = value_grid.reshape((-1, n_wealth))[rows]

    # Credit-constrained candidate: utility with consumption equal to the queried
    # wealth plus the discounted value stored in column 0 of each row.
    constrained_value = (
        compute_utility(
            consumption=wealth_point,
            params=params,
            continuous_state=regular_point,
            **state_choice_vec,
        )
        + discount_factor * value_rows[:, 0]
    )
    interpolated = jax.vmap(_interp1d_one, in_axes=(0, 0, None))(
        wealth_rows, value_rows, wealth_point
    )
    corner_values = jnp.where(
        wealth_point <= wealth_rows[:, 1], constrained_value, interpolated
    )

    weights = jnp.prod(jnp.where(take_lower, 1.0 - frac, frac), axis=1)
    return jnp.sum(weights * corner_values)

## Specialized 2D interpolation (1 regular axis + 1 irregular wealth axis)
This version avoids corner enumeration. It interpolates along the wealth grids of the two neighboring regular-grid rows and then blends the two results linearly.

In [16]:
def interp2d_policy_on_wealth_and_regular_grid(
    regular_grid: jnp.ndarray,
    wealth_grid: jnp.ndarray,
    policy_grid: jnp.ndarray,
    regular_point_to_interp: float | jnp.ndarray,
    wealth_point_to_interp: float | jnp.ndarray,
):
    """Interpolate a policy on one regular axis times a row-wise wealth grid.

    Row ``i`` of ``wealth_grid``/``policy_grid`` belongs to ``regular_grid[i]``.
    The two rows bracketing the regular query are each interpolated in wealth,
    and the two results are blended linearly by the regular-axis weight.
    """
    lo, hi, t = _indices_weight_regular(regular_grid, regular_point_to_interp)
    v_lo, v_hi = (
        _interp1d_one(wealth_grid[row], policy_grid[row], wealth_point_to_interp)
        for row in (lo, hi)
    )
    return (1.0 - t) * v_lo + t * v_hi


def interp2d_value_on_wealth_and_regular_grid(
    regular_grid: jnp.ndarray,
    wealth_grid: jnp.ndarray,
    value_grid: jnp.ndarray,
    regular_point_to_interp: float | jnp.ndarray,
    wealth_point_to_interp: float | jnp.ndarray,
    compute_utility,
    state_choice_vec,
    params,
    discount_factor,
):
    """2D value interpolation with a credit-constraint override per row.

    For each of the two rows bracketing the regular query:

    * if the wealth query is at or below that row's wealth entry in column 1,
      the value is ``compute_utility(consumption=wealth, ...)`` plus
      ``discount_factor`` times the row's value in column 0;
    * otherwise the row's value grid is linearly interpolated on its wealth grid.

    The two row values are blended linearly by the regular-axis weight.
    """
    lo, hi, t = _indices_weight_regular(regular_grid, regular_point_to_interp)
    wealth = wealth_point_to_interp
    utility_now = compute_utility(
        consumption=wealth,
        params=params,
        continuous_state=regular_point_to_interp,
        **state_choice_vec,
    )

    def row_value(row):
        # Utility with consumption equal to the queried wealth plus the discounted
        # column-0 value, used at/below the row's column-1 wealth point.
        constrained = utility_now + discount_factor * value_grid[row, 0]
        unconstrained = _interp1d_one(wealth_grid[row], value_grid[row], wealth)
        return jnp.where(wealth <= wealth_grid[row, 1], constrained, unconstrained)

    return (1.0 - t) * row_value(lo) + t * row_value(hi)

## Synthetic data and utility function
The data mirror the structure used in the unit tests.

In [17]:
def make_test_case(
    n_regular=10,
    n_wealth=100,
    x_low=1.0,
    x_high=100.0,
    n_queries=5000,
    query_low=30.0,
    query_high=40.0,
):
    """Build a synthetic (regular x irregular-wealth) problem like the unit tests.

    All randomness comes from NumPy's global RNG (seed it for reproducibility).
    Draws happen in this order: ``a``, ``b``, one wealth row per regular grid
    point, the wealth queries, and then the regular queries.

    Args:
        n_regular: Number of points on the regular grid.
        n_wealth: Number of points in each (sorted) wealth grid row.
        x_low: Currently unused; kept for signature compatibility. Wealth rows
            are drawn as ``exp(U(1, log(x_high)))``, so they start above e.
        x_high: Upper end of the regular grid and of the wealth draws.
        n_queries: Number of (wealth, regular) query pairs.
        query_low: Lower bound of the uniformly drawn wealth queries.
        query_high: Upper bound of the uniformly drawn wealth queries.

    Returns:
        Dict of jnp arrays: ``regular_grid`` (n_regular,), ``wealth_grid``,
        ``policy_grid`` and ``value_grid`` (n_regular, n_wealth), plus
        ``test_x``/``test_y`` (n_queries,). Each ``test_y`` is drawn from the
        regular grid points.
    """
    # functional form like in tests
    a, b = np.random.uniform(1, 10), np.random.uniform(1, 10)

    def f(x, y):
        return a + np.log((x + y) * b)

    irregular_grids = np.empty((n_regular, n_wealth))
    for k in range(n_regular):
        irregular_grids[k, :] = np.sort(
            np.exp(np.random.uniform(1, np.log(x_high), n_wealth))
        )
    regular_grid = np.linspace(1e-8, x_high, n_regular)

    # Broadcasting the regular grid across the wealth axis gives the same values
    # as tiling it into an (n_regular, n_wealth) array.
    policy = f(irregular_grids, regular_grid[:, None])
    value = policy * 3.5

    # queries
    test_x = np.random.uniform(query_low, query_high, n_queries)  # wealth queries
    test_y = np.random.choice(regular_grid, n_queries)  # regular grid queries

    return {
        "regular_grid": jnp.array(regular_grid),
        "wealth_grid": jnp.array(irregular_grids),
        "policy_grid": jnp.array(policy),
        "value_grid": jnp.array(value),
        "test_x": jnp.array(test_x),
        "test_y": jnp.array(test_y),
    }


# simple CRRA-like utility compatible with signatures used above
PARAMS = {"discount_factor": 0.95, "rho": 1.5}


def compute_utility(consumption, params, continuous_state=None, choice=0):
    """CRRA utility ``(c**(1 - rho) - 1) / (1 - rho)``; log utility when rho ~ 1.

    ``rho`` is read from ``params`` (default 1.5), and consumption is floored at
    1e-10. ``continuous_state`` and ``choice`` are accepted only for signature
    compatibility with the interpolation routines and are not used.

    Both branches of ``jnp.where`` are always evaluated (and differentiated).
    The CRRA branch therefore uses a denominator that is replaced by 1 when
    rho ~ 1. This keeps that unused branch finite so that ``jax.grad`` does not
    produce NaN (the "double where" pattern). Forward values are unchanged.
    """
    rho = params.get("rho", 1.5)
    c = jnp.maximum(consumption, 1e-10)
    # Use JAX conditionals to avoid Python boolean conversion of a tracer.
    is_log = jnp.isclose(rho, 1.0)
    one_minus_rho = jnp.where(is_log, 1.0, 1.0 - rho)
    crra = (c**one_minus_rho - 1) / one_minus_rho
    return jnp.where(is_log, jnp.log(c), crra)


data = make_test_case()
# The cell's last expression is displayed by Jupyter: one shape per array entry.
{name: (arr.shape if isinstance(arr, jnp.ndarray) else type(arr)) for name, arr in data.items()}

{'regular_grid': (10,),
 'wealth_grid': (10, 100),
 'policy_grid': (10, 100),
 'value_grid': (10, 100),
 'test_x': (5000,),
 'test_y': (5000,)}

## Correctness sanity check
Compare outputs from both implementations for a small batch.

In [18]:
# Grids shared by every wrapper below, plus a 16-query batch for the check.
rg, wg, pg, vg = (
    data[key] for key in ("regular_grid", "wealth_grid", "policy_grid", "value_grid")
)
xs, ys = data["test_x"][:16], data["test_y"][:16]


# --- policy -------------------------------------------------------------------
def nd_policy_point(x_in, y_in):
    """Policy at one (wealth, regular) query through the N-D routine (R = 1)."""
    regular_point = jnp.array([y_in])
    return interpNd_policy([rg], wg, pg, regular_point, x_in)


def d2_policy_point(x_in, y_in):
    """Policy at one (wealth, regular) query through the specialized 2D routine."""
    return interp2d_policy_on_wealth_and_regular_grid(rg, wg, pg, y_in, x_in)


p_nd, p_2d = (
    jax.vmap(point_fn)(xs, ys) for point_fn in (nd_policy_point, d2_policy_point)
)
print("policy max abs diff:", jnp.max(jnp.abs(p_nd - p_2d)))

# --- value with credit constraint ---------------------------------------------
disc = PARAMS["discount_factor"]
state_choice = {"choice": 0}


def nd_value_point(x_in, y_in):
    """Value at one query through the N-D routine with credit constraint."""
    regular_point = jnp.array([y_in])
    return interpNd_value_with_cc(
        [rg], wg, vg, regular_point, x_in, compute_utility, state_choice, PARAMS, disc
    )


def d2_value_point(x_in, y_in):
    """Value at one query through the specialized 2D routine with constraint."""
    return interp2d_value_on_wealth_and_regular_grid(
        rg, wg, vg, y_in, x_in, compute_utility, state_choice, PARAMS, disc
    )


v_nd, v_2d = (
    jax.vmap(point_fn)(xs, ys) for point_fn in (nd_value_point, d2_value_point)
)
print("value max abs diff:", jnp.max(jnp.abs(v_nd - v_2d)))

policy max abs diff: 0.0
value max abs diff: 0.0


## Batched functions (JIT + VMAP)
Each implementation is vectorized with `jax.vmap` and compiled with `jax.jit`, so both are timed under the same conditions.

In [19]:
# Batched policy. The single-point wrappers from the sanity-check cell are reused
# so that the code being timed is exactly the code whose output was checked.
def nd_policy_batched(xs, ys):
    """N-D policy interpolation vectorized over paired (wealth, regular) queries."""
    return jax.vmap(nd_policy_point)(xs, ys)


def d2_policy_batched(xs, ys):
    """Specialized 2D policy interpolation vectorized over paired queries."""
    return jax.vmap(d2_policy_point)(xs, ys)


nd_policy_compiled = jit(nd_policy_batched)
d2_policy_compiled = jit(d2_policy_batched)


# Batched value with CC
def nd_value_batched(xs, ys):
    """N-D value interpolation (credit constraint) vectorized over queries."""
    return jax.vmap(nd_value_point)(xs, ys)


def d2_value_batched(xs, ys):
    """Specialized 2D value interpolation (credit constraint) over queries."""
    return jax.vmap(d2_value_point)(xs, ys)


nd_value_compiled = jit(nd_value_batched)
d2_value_compiled = jit(d2_value_batched)

# Warmup + correctness on full batch. The first call to each compiled function
# traces and compiles it, and block_until_ready waits for the result. These
# results therefore double as the warm-up for the timing cells below.
X = data["test_x"]
Y = data["test_y"]
p_nd_full = jax.block_until_ready(nd_policy_compiled(X, Y))
p_2d_full = jax.block_until_ready(d2_policy_compiled(X, Y))
v_nd_full = jax.block_until_ready(nd_value_compiled(X, Y))
v_2d_full = jax.block_until_ready(d2_value_compiled(X, Y))
print("policy full max abs diff:", jnp.max(jnp.abs(p_nd_full - p_2d_full)))
print("value  full max abs diff:", jnp.max(jnp.abs(v_nd_full - v_2d_full)))

policy full max abs diff: 0.0
value  full max abs diff: 0.0


## Timing: policy interpolation
We wait for results with `jax.block_until_ready` and use `%timeit` to report execution time. Compilation is excluded because every compiled function has already been called once (warm-up) above.

In [20]:
# One-time compilation is already done above (warmup).
# NOTE(review): when the cells run in order, these two calls repeat the warm-up
# from the batching cell. They are not required before %timeit, but they are cheap.
jax.block_until_ready(nd_policy_compiled(X, Y))
jax.block_until_ready(d2_policy_compiled(X, Y))
# block_until_ready makes each %timeit loop wait for the computation to finish,
# so the timing covers execution rather than only dispatch.
%timeit jax.block_until_ready(nd_policy_compiled(X, Y))
%timeit jax.block_until_ready(d2_policy_compiled(X, Y))

883 μs ± 57.4 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
803 μs ± 21.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
803 μs ± 21.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


## Timing: value interpolation with credit-constraint handling

In [21]:
# Extra warm-up calls; both functions were already compiled in the batching cell.
jax.block_until_ready(nd_value_compiled(X, Y))
jax.block_until_ready(d2_value_compiled(X, Y))
# Steady-state timing of the value interpolation (credit-constraint handling included).
%timeit jax.block_until_ready(nd_value_compiled(X, Y))
%timeit jax.block_until_ready(d2_value_compiled(X, Y))

500 μs ± 22.2 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
581 μs ± 18.4 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
581 μs ± 18.4 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


### Notes
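- The first call to each compiled function includes tracing and compilation time. The `%timeit` cells run after that warm-up and therefore measure steady-state execution.
- Adjust the `make_test_case` sizes to stress-test performance. Larger batches emphasize the throughput advantages of vectorized code.
- The timings above were recorded on the CPU backend (see the output of the first cell). On GPU/TPU the results can differ; check that JAX sees the intended device (`jax.default_backend()`, `jax.devices()`) before running.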