In [1]:
import itertools

import numpy as np
import jax.numpy as jnp
from jax import vmap, jit
from jax.config import config
from jax.scipy.ndimage import map_coordinates

from lcm.interpolation import linear_interpolation
from numpy.testing import assert_array_almost_equal as aaae
from scipy.interpolate import interp1d, interp2d, RegularGridInterpolator


# Enable 64-bit floats in JAX so that its interpolation results can be compared
# against the float64 NumPy/SciPy reference results at the precision `aaae` checks.
config.update("jax_enable_x64", True)

In [2]:
def linear_linspace_interpolation(values, point, starts, step_lengths):
    """Linearly interpolate ``values`` at ``point`` on evenly spaced grids.

    Hand-specialized to linspace grids, this function is a baseline used to
    check that the generic, dynamically dispatched interpolation routine does
    not pay a performance penalty.

    Args:
        values: Array of function values on the grid, one axis per dimension.
        point: Coordinates of the evaluation point, one entry per dimension.
        starts: First grid value along each dimension.
        step_lengths: Distance between adjacent grid values along each
            dimension.

    Returns:
        The interpolated value. Coordinates outside the grid are clamped to
        the edge (``mode="nearest"``).
    """
    # Translate physical coordinates into fractional array indices, which is
    # what map_coordinates expects.
    index_coordinates = (point - starts) / step_lengths
    return map_coordinates(
        input=values,
        coordinates=index_coordinates,
        order=1,
        mode="nearest",
    )

In [3]:
def f(a, b):
    """Test function of two variables: ``2 * a**3 + 3 * b**2``, elementwise."""
    cubic_term = 2 * a ** 3
    quadratic_term = 3 * b ** 2
    return cubic_term + quadratic_term


def g(a, b, c):
    """Test function of three variables: ``f(a, b) + c``."""
    base = f(a, b)
    return base + c


def h(a, b, c, d):
    """Test function of four variables: ``g(a, b, c) - d**5``."""
    quintic_term = d ** 5
    return g(a, b, c) - quintic_term

## 1d Benchmarks

In [4]:
def get_1d_inputs(n_grid=1000, n_inter=100_000, target="jax_linspace"):
    """Build a 1d interpolation benchmark as a ``(func, args)`` pair.

    The function ``x ** 1.1 + sin(x)`` is tabulated on ``n_grid`` evenly
    spaced points in [1, 5] and interpolated at ``n_inter`` evenly spaced
    points in [1, 4], visited in a shuffled (seeded) order.

    Args:
        n_grid (int): Number of grid points. Must be at least 2 for the
            "jax_linspace" target, which reads the step from the grid.
        n_inter (int): Number of interpolation points.
        target (str): Implementation to benchmark: "jax_linspace", "jax"
            or "scipy_interp1d".

    Returns:
        tuple: ``(func, args)`` such that ``func(*args)`` returns the
        interpolated values at all points, in shuffled order.

    Raises:
        ValueError: If ``target`` is not one of the supported names.
    """
    # Grid bounds are defined once so that the grid, the "jax_linspace"
    # start and the "jax" grid_info cannot drift apart.
    grid_start, grid_stop = 1, 5

    np.random.seed(1234)
    grid = np.linspace(grid_start, grid_stop, n_grid)
    values = grid ** 1.1 + np.sin(grid)
    points = np.linspace(1, 4, n_inter)
    # Shuffle so that the points are not visited in sorted order.
    np.random.shuffle(points)

    if target == "jax_linspace":
        vmapped = vmap(linear_linspace_interpolation, in_axes=(None, 0, None, None))
        func = jit(vmapped)
        args = (
            jnp.array(values),
            jnp.array(points.reshape(-1, 1)),
            jnp.array([float(grid_start)]),
            grid[1] - grid[0],
        )

    elif target == "jax":
        vmapped = vmap(linear_interpolation, in_axes=(None, 0, None))
        # grid_info is a hashable tuple, so it is passed as a static argument.
        func = jit(vmapped, static_argnums=2)
        grid_info = (("linspace", (grid_start, grid_stop, n_grid)),)
        args = (
            jnp.array(values),
            jnp.array(points.reshape(-1, 1)),
            grid_info,
        )

    elif target == "scipy_interp1d":
        func = interp1d(grid, values)
        args = (points,)

    else:
        raise ValueError(
            f"Unknown target {target!r}; expected one of 'jax_linspace', "
            "'jax', 'scipy_interp1d'."
        )

    return func, args

In [5]:
# SciPy reference result and timing for the 1d benchmark.
func, args = get_1d_inputs(target="scipy_interp1d")
scipy_res_1d = func(*args)
%timeit func(*args)

7.74 ms ± 43.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [6]:
func, args = get_1d_inputs(target="jax_linspace")
jax_linspace_res = func(*args)
%timeit func(*args)



253 µs ± 19.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [7]:
func, args = get_1d_inputs(target="jax")
jax_res = func(*args)
%timeit func(*args)

159 µs ± 9.09 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [8]:
# Both JAX implementations must agree with the SciPy interp1d reference
# (assert_array_almost_equal checks 6 decimals by default).
aaae(scipy_res_1d, jax_linspace_res)
aaae(scipy_res_1d, jax_res)

## 2d Benchmarks

In [9]:
def get_2d_inputs(n_grid=200, n_inter=100_000, target="jax_linspace"):
    """Build a 2d interpolation benchmark as a ``(func, args)`` pair.

    ``f`` is tabulated on the regular grid [1, 5] x [-1, 5] with ``n_grid``
    points per dimension and interpolated on the Cartesian product of two
    evenly spaced vectors with ``int(sqrt(n_inter))`` entries each.

    Args:
        n_grid (int): Number of grid points per dimension (at least 2).
        n_inter (int): Approximate number of interpolation points; the
            actual number is ``int(sqrt(n_inter)) ** 2``.
        target (str): Implementation to benchmark: "jax_linspace", "jax",
            "scipy_interp2d" or "scipy".

    Returns:
        tuple: ``(func, args)``. For "scipy" call ``func(args)``; for the
        other targets call ``func(*args)``. "scipy_interp2d" returns a 2d
        array of shape ``(len(inter_grid1), len(inter_grid2))`` whose
        ``ravel()`` follows the order of the points used by the other
        targets; the other targets return a flat array in that order.

    Raises:
        ValueError: If ``target`` is not one of the supported names.
    """
    # NOTE(review): nothing below draws random numbers; the seed only
    # mirrors get_1d_inputs.
    np.random.seed(1234)

    # (start, stop) per dimension, shared by the grids, the "jax_linspace"
    # starts and the "jax" grid_info so they cannot drift apart.
    bounds = ((1, 5), (-1, 5))
    grid1, grid2 = (np.linspace(start, stop, n_grid) for start, stop in bounds)

    # Sparse (broadcastable) coordinate arrays produce the same dense result
    # for the elementwise function f without materialising full meshgrids.
    values = f(*np.meshgrid(grid1, grid2, indexing="ij", sparse=True))

    root_n_inter = int(np.sqrt(n_inter))
    inter_grid1 = np.linspace(1.33, 4.11, root_n_inter)
    inter_grid2 = np.linspace(-0.66, 3.79, root_n_inter)

    # Cartesian product with the first coordinate varying slowest (the same
    # order as itertools.product), built without a Python-level loop.
    points = np.stack(
        np.meshgrid(inter_grid1, inter_grid2, indexing="ij"), axis=-1
    ).reshape(-1, 2)

    if target == "jax_linspace":
        vmapped = vmap(linear_linspace_interpolation, in_axes=(None, 0, None, None))
        func = jit(vmapped)
        starts = jnp.array([float(start) for start, _ in bounds])
        step_sizes = jnp.array([grid1[1] - grid1[0], grid2[1] - grid2[0]])
        args = (jnp.array(values), jnp.array(points), starts, step_sizes)

    elif target == "jax":
        vmapped = vmap(linear_interpolation, in_axes=(None, 0, None))
        func = jit(vmapped, static_argnums=2)
        grid_info = tuple(
            ("linspace", (start, stop, n_grid)) for start, stop in bounds
        )
        args = (jnp.array(values), jnp.array(points), grid_info)

    elif target == "scipy_interp2d":
        # interp2d treats x as the column (second-axis) coordinate and y as
        # the row (first-axis) coordinate, both when constructed and when
        # called as func(x_new, y_new). The new x values therefore come from
        # inter_grid2 and the new y values from inter_grid1. The result then
        # has shape (len(inter_grid1), len(inter_grid2)).
        # NOTE: interp2d is deprecated since SciPy 1.10 and removed in 1.14.
        func = interp2d(x=grid2, y=grid1, z=values)
        args = (inter_grid2, inter_grid1)

    elif target == "scipy":
        func = RegularGridInterpolator(
            points=(grid1, grid2), values=values, method="linear"
        )
        args = points

    else:
        raise ValueError(
            f"Unknown target {target!r}; expected one of 'jax_linspace', "
            "'jax', 'scipy_interp2d', 'scipy'."
        )

    return func, args

In [10]:
# RegularGridInterpolator reference for 2d. The interpolator is called with
# the points array itself (not unpacked), unlike the other targets.
func, args = get_2d_inputs(target="scipy")
scipy_res = func(args)
%timeit func(args)

11.3 ms ± 82.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [11]:
# Timing for SciPy's interp2d.
# NOTE(review): scipy_res_2d is not checked against the other 2d results in
# the comparison cell of this section.
func, args = get_2d_inputs(target="scipy_interp2d")
scipy_res_2d = func(*args)
%timeit func(*args)

951 µs ± 11 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [12]:
func, args = get_2d_inputs(target="jax_linspace")
jax_linspace_res = func(*args)
%timeit func(*args)

481 µs ± 20.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [13]:
func, args = get_2d_inputs(target="jax")
jax_res = func(*args)
%timeit func(*args)

325 µs ± 21.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [14]:
# Both JAX implementations must agree with the RegularGridInterpolator reference.
# NOTE(review): the interp2d result (scipy_res_2d) is not compared here.
aaae(scipy_res, jax_linspace_res)
aaae(scipy_res, jax_res)

## 3d Benchmarks

In [15]:
def get_3d_inputs(n_grid=200, n_inter=100_000, target="jax_linspace"):
    """Build a 3d interpolation benchmark as a ``(func, args)`` pair.

    ``g`` is tabulated on the regular grid [1, 5] x [-1, 5] x [2, 7] with
    ``n_grid`` points per dimension and interpolated on the Cartesian product
    of three evenly spaced vectors with ``int(sqrt(n_inter))`` entries each.

    Args:
        n_grid (int): Number of grid points per dimension (at least 2).
        n_inter (int): Controls the number of interpolation points, which is
            ``int(sqrt(n_inter)) ** 3``, not ``n_inter`` itself.
        target (str): Implementation to benchmark: "jax_linspace", "jax"
            or "scipy".

    Returns:
        tuple: ``(func, args)``. For "scipy" call ``func(args)``; for the
        other targets call ``func(*args)``.

    Raises:
        ValueError: If ``target`` is not one of the supported names.
    """
    # NOTE(review): nothing below draws random numbers; the seed only
    # mirrors get_1d_inputs.
    np.random.seed(1234)

    # (start, stop) per dimension, shared by the grids, the "jax_linspace"
    # starts and the "jax" grid_info so they cannot drift apart.
    bounds = ((1, 5), (-1, 5), (2, 7))
    grids = tuple(np.linspace(start, stop, n_grid) for start, stop in bounds)

    # Sparse (broadcastable) coordinate arrays produce the same dense result
    # for the elementwise function g without materialising full meshgrids.
    values = g(*np.meshgrid(*grids, indexing="ij", sparse=True))

    # NOTE(review): with sqrt, the point count grows as n_inter ** 1.5 in 3d;
    # this is kept as is so that existing timings stay comparable.
    root_n_inter = int(np.sqrt(n_inter))
    inter_grids = (
        np.linspace(1.33, 4.11, root_n_inter),
        np.linspace(-0.66, 3.79, root_n_inter),
        np.linspace(2.07, 6.99, root_n_inter),
    )

    # Cartesian product with the first coordinate varying slowest (the same
    # order as itertools.product), built without a Python-level loop.
    points = np.stack(np.meshgrid(*inter_grids, indexing="ij"), axis=-1).reshape(
        -1, len(inter_grids)
    )

    if target == "jax_linspace":
        vmapped = vmap(linear_linspace_interpolation, in_axes=(None, 0, None, None))
        func = jit(vmapped)
        starts = jnp.array([float(start) for start, _ in bounds])
        step_sizes = jnp.array([grid[1] - grid[0] for grid in grids])
        args = (jnp.array(values), jnp.array(points), starts, step_sizes)

    elif target == "jax":
        vmapped = vmap(linear_interpolation, in_axes=(None, 0, None))
        func = jit(vmapped, static_argnums=2)
        grid_info = tuple(
            ("linspace", (start, stop, n_grid)) for start, stop in bounds
        )
        args = (jnp.array(values), jnp.array(points), grid_info)

    elif target == "scipy":
        func = RegularGridInterpolator(points=grids, values=values, method="linear")
        args = points

    else:
        raise ValueError(
            f"Unknown target {target!r}; expected one of 'jax_linspace', "
            "'jax', 'scipy'."
        )

    return func, args

In [16]:
# RegularGridInterpolator reference for 3d on a smaller problem:
# int(sqrt(1_000)) ** 3 = 29_791 interpolation points on a 100**3 grid.
func, args = get_3d_inputs(n_grid=100, n_inter=1_000, target="scipy")
scipy_res = func(args)
%timeit -r 7 -n 1000 func(args)

6.94 ms ± 430 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [17]:
func, args = get_3d_inputs(n_grid=100, n_inter=1_000, target="jax_linspace")
jax_linspace_res = func(*args)
%timeit -r 7 -n 1000 func(*args)

700 µs ± 57.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [18]:
func, args = get_3d_inputs(n_grid=100, n_inter=1_000, target="jax")
jax_res = func(*args)
%timeit -r 7 -n 1000 func(*args)

526 µs ± 11.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [19]:
# Both JAX implementations must agree with the 3d RegularGridInterpolator reference.
aaae(scipy_res, jax_linspace_res)
aaae(scipy_res, jax_res)

## 4d Benchmarks

In [20]:
def get_4d_inputs(n_grid=200, n_inter=100_000, target="jax_linspace"):
    """Build a 4d interpolation benchmark as a ``(func, args)`` pair.

    ``h`` is tabulated on the regular grid [1, 5] x [-1, 5] x [2, 7] x [-3, 7]
    with ``n_grid`` points per dimension and interpolated on the Cartesian
    product of four evenly spaced vectors with ``int(sqrt(n_inter))``
    entries each.

    Args:
        n_grid (int): Number of grid points per dimension (at least 2). The
            value array has ``n_grid ** 4`` entries.
        n_inter (int): Controls the number of interpolation points, which is
            ``int(sqrt(n_inter)) ** 4``, not ``n_inter`` itself.
        target (str): Implementation to benchmark: "jax_linspace", "jax"
            or "scipy".

    Returns:
        tuple: ``(func, args)``. For "scipy" call ``func(args)``; for the
        other targets call ``func(*args)``.

    Raises:
        ValueError: If ``target`` is not one of the supported names.
    """
    # NOTE(review): nothing below draws random numbers; the seed only
    # mirrors get_1d_inputs.
    np.random.seed(1234)

    # (start, stop) per dimension, shared by the grids, the "jax_linspace"
    # starts and the "jax" grid_info so they cannot drift apart.
    bounds = ((1, 5), (-1, 5), (2, 7), (-3, 7))
    grids = tuple(np.linspace(start, stop, n_grid) for start, stop in bounds)

    # Sparse (broadcastable) coordinate arrays produce the same dense result
    # for the elementwise function h without materialising four full
    # n_grid**4 meshgrids (about 3.2 GB of float64 at n_grid=100).
    values = h(*np.meshgrid(*grids, indexing="ij", sparse=True))

    # NOTE(review): with sqrt, the point count grows as n_inter ** 2 in 4d;
    # this is kept as is so that existing timings stay comparable.
    root_n_inter = int(np.sqrt(n_inter))
    inter_grids = (
        np.linspace(1.33, 4.11, root_n_inter),
        np.linspace(-0.66, 3.79, root_n_inter),
        np.linspace(2.07, 6.99, root_n_inter),
        np.linspace(-2.84, 4.77, root_n_inter),
    )

    # Cartesian product with the first coordinate varying slowest (the same
    # order as itertools.product), built without a Python-level loop.
    points = np.stack(np.meshgrid(*inter_grids, indexing="ij"), axis=-1).reshape(
        -1, len(inter_grids)
    )

    if target == "jax_linspace":
        vmapped = vmap(linear_linspace_interpolation, in_axes=(None, 0, None, None))
        func = jit(vmapped)
        starts = jnp.array([float(start) for start, _ in bounds])
        step_sizes = jnp.array([grid[1] - grid[0] for grid in grids])
        args = (jnp.array(values), jnp.array(points), starts, step_sizes)

    elif target == "jax":
        vmapped = vmap(linear_interpolation, in_axes=(None, 0, None))
        func = jit(vmapped, static_argnums=2)
        grid_info = tuple(
            ("linspace", (start, stop, n_grid)) for start, stop in bounds
        )
        args = (jnp.array(values), jnp.array(points), grid_info)

    elif target == "scipy":
        func = RegularGridInterpolator(points=grids, values=values, method="linear")
        args = points

    else:
        raise ValueError(
            f"Unknown target {target!r}; expected one of 'jax_linspace', "
            "'jax', 'scipy'."
        )

    return func, args

In [22]:
func, args = get_4d_inputs(n_grid=100, n_inter=1_000, target="scipy")
scipy_res = func(args)
%timeit -r 7 -n 100 func(args)

670 ms ± 8.03 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [23]:
func, args = get_4d_inputs(n_grid=100, n_inter=1_000, target="jax_linspace")
jax_linspace_res = func(*args)
%timeit -r 7 -n 100 func(*args)

42.9 ms ± 5.52 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [24]:
func, args = get_4d_inputs(n_grid=100, n_inter=1_000, target="jax")
jax_res = func(*args)
%timeit -r 7 -n 100 func(*args)

27.3 ms ± 3.2 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [25]:
# Both JAX implementations must agree with the 4d RegularGridInterpolator reference.
aaae(scipy_res, jax_linspace_res)
aaae(scipy_res, jax_res)