In [180]:
from jax.config import config
config.update("jax_enable_x64", True)

from jax.scipy.ndimage import map_coordinates
import jax.numpy as jnp
from jax import vmap, jit
import numpy as np
from numpy.testing import assert_array_almost_equal as aaae
from scipy.interpolate import interp1d, interp2d
import itertools
from lcm.interpolation import linear_interpolation

In [181]:
def linear_linspace_interpolation(values, point, starts, step_lengths):
    """Specialized version of linear interpolation to linspaces.
    
    This is meant as a comparison to check that the dynamic function
    calling doest not cause any performance penalty.
    
    """
    mapped_point = (point - starts) / step_lengths

    res = map_coordinates(
        input=values,
        coordinates=mapped_point,
        order=1,
        mode="nearest"
    )

    return res

## 1d Benchmarks

In [182]:
def get_1d_inputs(n_grid=1000, n_inter=10_000, target="jax_linspace"):
    np.random.seed(1234)
    grid = np.linspace(1, 5, n_grid)
    values = grid ** 1.1 + np.sin(grid)
    points = np.linspace(1, 4, n_inter)
    np.random.shuffle(points)
    
    if target == "jax_linspace":
        vmapped = vmap(linear_linspace_interpolation, in_axes=(None, 0, None, None))
        func = jit(vmapped)
        args = (
            jnp.array(values), 
            jnp.array(points.reshape(-1, 1)), 
            jnp.array([1.]), grid[1] - grid[0]
        )
    
    elif target == "jax":
        vmapped = vmap(linear_interpolation, in_axes=(None, 0, None))
        func = jit(vmapped, static_argnums=2)
        grid_info = (("linspace", (1, 5, n_grid)),)
        args = (
            jnp.array(values), 
            jnp.array(points.reshape(-1, 1)), 
            grid_info)
    elif target == "scipy":
        func = interp1d(grid, values)
        args = (points,)
    else:
        raise ValueError()
        
    return func, args


def get_1d_runtime(n_grid=1000, n_inter=10_000, target="jax_linespace"):
    func, args = get_1d_inputs(n_grid, n_inter, target)
    # firs evaluation for jit overhead
    func(*args)
    timeit_res = %timeit -o func(*args)
    runtime = timeit_res.average
    return runtime

In [183]:
func, args = get_1d_inputs(target="scipy")
%timeit func(*args)
scipy_res = func(*args)

752 µs ± 15.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [184]:
func, args = get_1d_inputs(target="jax_linspace")
%timeit func(*args).block_until_ready()
linspace_res = func(*args)

58.6 µs ± 1.55 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [185]:
func, args = get_1d_inputs(target="jax")
%timeit func(*args).block_until_ready()
jax_res = func(*args)

50.7 µs ± 870 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [186]:
aaae(scipy_res, linspace_res)
aaae(scipy_res, jax_res)

## 2d Benchmarks

In [187]:
def get_2d_inputs(n_grid=200, n_inter=100_000, target="jax_linspace"):
    np.random.seed(1234)
    grid1 = np.linspace(1, 5, n_grid)
    grid2 = np.linspace(-1, 5, n_grid)
    
    root_n_inter = int(np.sqrt(n_inter))
    
    inter_grid1 = np.linspace(1.33, 4.11, root_n_inter)
    inter_grid2 = np.linspace(-0.66, 3.79, root_n_inter)
    
    points = np.array(list(itertools.product(inter_grid1, inter_grid2)))
    
    product_arr = np.array(list(itertools.product(grid1, grid2)))
    values = (product_arr ** 2).sum(axis=1).reshape(n_grid, n_grid)
    
    if target == "jax_linspace":
        vmapped = vmap(linear_linspace_interpolation, in_axes=(None, 0, None, None))
        func = jit(vmapped)
        step_sizes=jnp.array([grid1[1] - grid1[0], grid2[1] - grid2[0]])
        args = (
            jnp.array(values),
            jnp.array(points), 
            jnp.array([1., -1]), 
            step_sizes)
    
    elif target == "jax":
        vmapped = vmap(linear_interpolation, in_axes=(None, 0, None))
        func = jit(vmapped, static_argnums=2)
        grid_info = (
            ("linspace", (1, 5, n_grid)),
            ("linspace", (-1, 5, n_grid))
        )
        args = (jnp.array(values), jnp.array(points), grid_info)
    elif target == "scipy":
        func = interp2d(x=grid2, y=grid1, z=values)
        args = (inter_grid1, inter_grid2)
    else:
        raise ValueError()
        
    return func, args
    




In [188]:
func, args = get_2d_inputs(target="scipy")
scipy_res = func(*args)
%timeit func(*args)

753 µs ± 5.34 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [189]:
func, args = get_2d_inputs(target="jax_linspace")
linspace_res = func(*args)
%timeit func(*args).block_until_ready()

281 µs ± 16.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [190]:
jax_res = get_2d_inputs(target="jax")
jax_res = func(*args)
%timeit func(*args).block_until_ready()

282 µs ± 11.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [191]:
import pandas as pd
df = pd.DataFrame()
df["scipy"] = scipy_res.flatten()
df["linspace"] = linspace_res
df["jax"] = jax_res
df.corr()

Unnamed: 0,scipy,linspace,jax
scipy,1.0,0.96638,0.96638
linspace,0.96638,1.0,1.0
jax,0.96638,1.0,1.0
