Notebook for speeding up `interp2d`. Assuming the grids are regular, swapping `searchsorted` for `ceil(...)` gives a marginal speedup in some cases.

In [181]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [183]:
from math import pi
from typing import Optional

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from scipy.interpolate import RectBivariateSpline, interp2d as sinterp2d

from jaxinterp2d import interp2d

Array = jnp.ndarray

In [184]:
xp = jnp.linspace(-1, 1, 2000)
yp = jnp.linspace(-1, 1, 100)
zp = jnp.exp(-(jnp.stack((jnp.meshgrid(xp, yp))) ** 2).sum(0) / 2).T / (2 * pi)

In [185]:
def searchsorted_right_grid(xp: Array, x: Array) -> Array:
    # Equivalent to `jnp.searchsorted(xp, x, side="right")` when `xp`
    # is an evenly-spaced array.
    xn = (x - xp.min()) / (xp.max() - xp.min())
    return jnp.ceil(xn * (len(xp) - 1)).astype(int)


def interp2d_new(
    x: Array,
    y: Array,
    xp: Array,
    yp: Array,
    zp: Array,
    fill_value: float = jnp.nan,
) -> Array:
    # Assumes xp and yp are regularly-spaced grids
    ix = jnp.clip(searchsorted_right_grid(xp, x), 1, len(xp) - 1)
    iy = jnp.clip(searchsorted_right_grid(yp, y), 1, len(yp) - 1)

    x1 = xp[ix - 1]
    x2 = xp[ix]
    y1 = yp[iy - 1]
    y2 = yp[iy]
    z11 = zp[ix - 1, iy - 1]
    z21 = zp[ix, iy - 1]
    z12 = zp[ix - 1, iy]
    z22 = zp[ix, iy]
    z = (
        z11 * (x2 - x) * (y2 - y)
        + z21 * (x - x1) * (y2 - y)
        + z12 * (x2 - x) * (y - y1)
        + z22 * (x - x1) * (y - y1)
    ) / ((x2 - x1) * (y2 - y1))
    return z

In [196]:
fn_interp = jax.jit(lambda x, y: interp2d(x, y, xp, yp, zp))
fn_interp_new = jax.jit(lambda x, y: interp2d_new(x, y, xp, yp, zp))
fn_interp_scipy = RectBivariateSpline(xp, yp, zp)

Benchmarking

In [212]:
xs = jnp.array(np.random.rand(500) * 2 - 1)
ys = jnp.array(np.random.rand(500) * 2 - 1)
jnp.allclose(fn_interp(xs, ys), fn_interp_new(xs, ys))

DeviceArray(True, dtype=bool)

In [213]:
jnp.allclose(fn_interp(xs, ys), jnp.array(fn_interp_scipy(xs, ys, grid=False)))

DeviceArray(False, dtype=bool)

In [214]:
%timeit fn_interp(*jnp.array(np.random.rand(2, 1000) * 2 - 1)).block_until_ready()

329 µs ± 45.5 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [215]:
%timeit fn_interp_new(*jnp.array(np.random.rand(2, 1000) * 2 - 1)).block_until_ready()

322 µs ± 13.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [220]:
%timeit fn_interp_scipy(*(np.random.rand(2, 1000) * 2 - 1), grid=False)

641 µs ± 5.36 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
