# To do:
- experiment with a monotone scale
- implement and test a monotone scale

# Insights:
- if order > 1, all entries matter (not only those of the enclosing grid cell)
- the base of the log scale is irrelevant

In [None]:
%load_ext autoreload
%autoreload 2

from jax.config import config

config.update("jax_enable_x64", True)

from jax.scipy.ndimage import map_coordinates
import jax.numpy as jnp
from jax import vmap, jit
import numpy as np
from numpy.testing import assert_array_almost_equal as aaae
from scipy.interpolate import interp1d
from scipy.interpolate import RegularGridInterpolator

import itertools
from lcm.interpolation import linear_interpolation
from lcm.grids import (
    get_linspace_coordinate,
    get_logspace_coordinate,
    logspace,
)

In [None]:
import jax.numpy as jnp
import lcm.grids as grids_module
from jax.scipy.ndimage import map_coordinates


def linear_interpolation(values, point, grid_info):
    """Linearly interpolate gridded ``values`` at ``point``.

    Each entry of ``grid_info`` is a ``(grid_type, args)`` pair describing the
    grid along one axis of ``values``, e.g. ``("linspace", (start, stop,
    n_points))``. The coordinate of ``point`` along that axis is converted to a
    (fractional) array index by ``grids_module.get_<grid_type>_coordinate``,
    called as ``func(coordinate, *args)``. The resulting index vector is passed
    to ``map_coordinates`` with ``order=1``.

    Args:
        values: Array of function values on the grid, one axis per entry of
            ``grid_info``.
        point: Sequence of coordinates, one per entry of ``grid_info`` and in
            the same order.
        grid_info: Sequence of ``(grid_type, args)`` pairs as described above.

    Returns:
        The interpolated value as returned by ``map_coordinates``.

    Raises:
        AttributeError: If ``grids_module`` has no coordinate function for a
            given ``grid_type``.

    Notes:
        ``mode="nearest"`` makes ``map_coordinates`` use the nearest edge value
        for indices outside the grid instead of extrapolating.
    """
    mapped_point = jnp.array(
        [
            getattr(grids_module, f"get_{grid_type}_coordinate")(coordinate, *args)
            for coordinate, (grid_type, args) in zip(point, grid_info)
        ]
    )

    return map_coordinates(
        input=values, coordinates=mapped_point, order=1, mode="nearest"
    )

# Try different scales

## linspace

In [None]:
def test_linear_interpolation_2d():
    """Compare ``linear_interpolation`` with scipy on a 2D linspace grid.

    The data are ``x**2 + y**2`` on a 5x3 grid. The grids handed to scipy are
    exactly the ones described by ``grid_info``.
    """
    grid1 = np.array([1, 2, 3, 4, 5.0])
    grid2 = np.array([2, 3, 4.0])

    # values[i, j] = grid1[i] ** 2 + grid2[j] ** 2
    prod_grid = np.array(list(itertools.product(grid1, grid2)))
    values = (prod_grid**2).sum(axis=1).reshape(5, 3)

    points = np.array([[2.5, 3.5], [2.1, 3.8], [2.7, 3.3]])

    grid_info = [("linspace", (1, 5, 5)), ("linspace", (2, 4, 3))]

    # The scipy reference only depends on the grid, so build it once.
    scipy_func = RegularGridInterpolator(
        points=(grid1, grid2), values=values, method="linear"
    )

    for point in points:
        calculated = linear_interpolation(
            values=values,
            point=point,
            grid_info=grid_info,
        )
        scipy_res = scipy_func(point)

        aaae(calculated, scipy_res)

In [None]:
# Run the linspace comparison against scipy defined in the previous cell.
test_linear_interpolation_2d()

## logspace

In [None]:
def logspace(start, stop, n_points):
    """Return ``n_points`` geometrically spaced values from ``start`` to ``stop``.

    Unlike ``jnp.logspace``, ``start`` and ``stop`` are the end points
    themselves, not exponents. They are converted with the natural log, so the
    matching base ``e`` is used; any base would give the same grid.
    """
    start_exp = jnp.log(start)
    stop_exp = jnp.log(stop)
    return jnp.logspace(start_exp, stop_exp, n_points, base=jnp.e)

In [None]:
def get_linspace_coordinate(value, start, stop, n_points):
    """Convert ``value`` to the fractional index expected by ``map_coordinates``.

    The grid consists of ``n_points`` evenly spaced values from ``start`` to
    ``stop``; index 0 corresponds to ``start``.
    """
    spacing = (stop - start) / (n_points - 1)
    offset = value - start
    return offset / spacing

In [None]:
def get_logspace_coordinate(value, start, stop, n_points):
    """Map a value into the input needed for map_coordinates.

    Returns the fractional index of ``value`` on ``n_points`` geometrically
    spaced points from ``start`` to ``stop``: ``log(value / start)`` measured in
    units of the log step between neighbouring grid points.

    This index is linear in ``log(value)``. Interpolating linearly in it is
    therefore not the same as interpolating linearly in value space between two
    grid points.
    """
    # Work with the log step directly; exponentiating it only to take the log
    # again costs precision and two extra transcendental calls.
    log_step = (jnp.log(stop) - jnp.log(start)) / (n_points - 1)
    return jnp.log(value / start) / log_step

In [None]:
def get_logspace_coordinate1(value, start, stop, n_points):
    """Map a value into the input needed for map_coordinates.

    Computes the fractional index of ``value`` on a geometric grid of
    ``n_points`` values from ``start`` to ``stop`` by measuring the distance from
    ``log(start)`` in units of the (natural) log step between grid points.
    """
    log_start = jnp.log(start)
    log_step = (jnp.log(stop) - log_start) / (n_points - 1)
    return (jnp.log(value) - log_start) / log_step

In [None]:
def get_logspace_coordinate2(value, start, stop, n_points):
    """Map a value on a log-spaced grid to a linearly interpolating index.

    The grid has ``n_points`` geometrically spaced values from ``start`` to
    ``stop``. The integer part of the result is the index of the grid point at
    or just below ``value``, located on the log scale. The fractional part is
    the *linear* position of ``value`` between that point and the next one.

    Feeding this index to ``map_coordinates(order=1)`` therefore interpolates
    linearly in value space between neighbouring grid points.

    The computation is branch-free (no Python ``if`` on array values) and has no
    debug output, so it can be traced by ``jit`` and ``vmap``.
    """
    log_start = jnp.log(start)
    log_step = (jnp.log(stop) - log_start) / (n_points - 1)
    # Fractional index on the log scale; only its integer part is used below.
    lower_index = jnp.floor((jnp.log(value) - log_start) / log_step)
    lower_point = jnp.exp(log_start + log_step * lower_index)
    # Always use the next grid point as the upper neighbour. Its distance from
    # lower_point is never zero, so no special case is needed when value sits
    # exactly on a grid point (the fraction is then ~0).
    upper_point = jnp.exp(log_start + log_step * (lower_index + 1))
    return lower_index + (value - lower_point) / (upper_point - lower_point)

In [None]:
# 3x3 grid with points 1, 10, 100 in each dimension and values x**2 + y**2.
grid1 = np.array([1, 10, 100.0])
grid2 = np.array([1, 10, 100.0])

prod_grid = np.array(list(itertools.product(grid1, grid2)))
values = (prod_grid**2).sum(axis=1).reshape(3, 3)

points = np.array([[1, 1], [2.5, 3.5], [2.1, 3.8], [2.7, 3.3]])

# grid_info must describe grid1 and grid2: geometric grids from 1 to 100 with
# 3 points each (linspace (1, 5, 5) / (2, 4, 3) do not match this data).
grid_info = [("logspace", (1, 100, 3)), ("logspace", (1, 100, 3))]

# The scipy reference only depends on the grid, so build it once.
scipy_func = RegularGridInterpolator(
    points=(grid1, grid2), values=values, method="linear"
)

for point in points:
    calculated = linear_interpolation(
        values=values,
        point=point,
        grid_info=grid_info,
    )
    scipy_res = scipy_func(point)
    print(calculated, scipy_res)
    aaae(calculated, scipy_res)

In [None]:
# Inspect the grids and tabulated values from the cell above.
grid1, grid2, values

In [None]:
# NOTE(review): appears to be a hand check. 2 and 101 equal values[0, 0] and
# values[0, 1] of the 3x3 grid above (x**2 + y**2 at (1, 1) and (1, 10)).
(101 + 2) / 2

In [None]:
# Geometric 3x3 grid: np.logspace over log10 end points gives [1, 10, 100].
grid1 = np.logspace(np.log10(1), np.log10(100), 3)
grid2 = np.logspace(np.log10(1), np.log10(100), 3)

# values[i, j] = grid1[i] ** 2 + grid2[j] ** 2
prod_grid = np.array(list(itertools.product(grid1, grid2)))
values = (prod_grid**2).sum(axis=1).reshape(3, 3)

points = np.array([[5.5, 1], [10, 10], [9.8, 2.3], [2.1, 8.2], [2.7, 1.1]])

grid_info = [("logspace", (1, 100, 3)), ("logspace", (1, 100, 3))]

# The scipy reference only depends on the grid, so build it once.
scipy_func = RegularGridInterpolator(
    points=(grid1, grid2), values=values, method="linear"
)

for point in points:
    calculated = linear_interpolation(
        values=values,
        point=point,
        grid_info=grid_info,
    )
    scipy_res = scipy_func(point)
    print(calculated, scipy_res)
    aaae(calculated, scipy_res)

In [None]:
# NOTE(review): grid4 is first assigned in the next cell, so running the
# notebook top to bottom raises NameError here.
grid4

In [None]:
def f(a, b, c):
    """Polynomial test function ``2 * a**3 + 3 * b**2 - c``."""
    cubic_term = 2 * a**3
    square_term = 3 * b**2
    return cubic_term + square_term - c


def g(a, b, c, d):
    """Extend ``f`` by a linear term in ``d``: ``f(a, b, c) - d``."""
    shifted = f(a, b, c)
    return shifted - d


def h(a, b, c, d, e):
    """Extend ``g`` by a fifth-power term in ``e``: ``g(a, b, c, d) + e**5``."""
    partial = g(a, b, c, d)
    return partial + e**5


# Five geometric grids; each grid_info entry below is (start, stop, n_points)
# of the matching gridN.
grid1 = np.logspace(np.log10(1), np.log10(5), 5)
grid2 = np.logspace(np.log10(4), np.log10(7), 4)
grid3 = np.logspace(np.log10(7), np.log10(9), 2)
grid4 = np.logspace(np.log10(10), np.log10(11), 2)
grid5 = np.logspace(np.log10(3), np.log10(4), 10)

values = h(*np.meshgrid(grid1, grid2, grid3, grid4, grid5, indexing="ij", sparse=False))
print(grid1, grid2, grid3, grid4, grid5)
points = np.array(
    [[2.1, 6.2, 8.3, 10.4, 3], [5, 4.3, 7, 10.99999, 4], [3.3, 5.2, 7.1, 10, 3.6]]
)

grid_info = [
    ("logspace", (1, 5, 5)),
    ("logspace", (4, 7, 4)),
    ("logspace", (7, 9, 2)),
    ("logspace", (10, 11, 2)),
    ("logspace", (3, 4, 10)),
]

# The scipy reference only depends on the grids, so build it once.
scipy_func = RegularGridInterpolator(
    points=(grid1, grid2, grid3, grid4, grid5), values=values, method="linear"
)

for point in points:
    print(point)
    calculated = linear_interpolation(
        values=values,
        point=point,
        grid_info=grid_info,
    )
    scipy_res = scipy_func(point)
    aaae(calculated, scipy_res)

In [None]:
# Same geometric 3x3 grid ([1, 10, 100] per axis); print instead of asserting.
grid1 = np.logspace(np.log10(1), np.log10(100), 3)
grid2 = np.logspace(np.log10(1), np.log10(100), 3)

prod_grid = np.array(list(itertools.product(grid1, grid2)))
values = (prod_grid**2).sum(axis=1).reshape(3, 3)

points = np.array([[1, 1], [10, 10], [2.7, 1.1]])

grid_info = [("logspace", (1, 100, 3)), ("logspace", (1, 100, 3))]

# The scipy reference only depends on the grid, so build it once.
scipy_func = RegularGridInterpolator(
    points=(grid1, grid2), values=values, method="linear"
)

for point in points:
    calculated = linear_interpolation(
        values=values,
        point=point,
        grid_info=grid_info,
    )
    scipy_res = scipy_func(point)
    print(point, calculated, scipy_res)
print(grid1, grid2)
print(values)

In [None]:
# Geometric grid from 1 to 100 with 7 points, using the notebook's logspace.
logspace(1, 100, 7)

In [None]:
# Check the floor/ceil behaviour that get_logspace_coordinate2 relies on.
jnp.floor(3.01), jnp.ceil(3.01)

In [None]:
# Compare the three log-grid coordinate variants at value 10, the last point
# (index 10) of the geometric grid from 1 to 10 with 11 points.
get_logspace_coordinate(10, 1, 10, 11), get_logspace_coordinate1(
    10, 1, 10, 11
), get_logspace_coordinate2(10, 1, 10, 11)

In [None]:
# Log-scale coordinate of 2 on the geometric grid from 1 to 10 with 11 points.
get_logspace_coordinate(2, 1, 10, 11)

In [None]:
# Linearly interpolating coordinate of 9 on the geometric grid from 1 to 10
# with 11 points (the function is named get_logspace_coordinate2, without the
# underscore before 2).
get_logspace_coordinate2(9, 1, 10, 11)

In [None]:
# Geometric 3x3 grid from 1 to 10 built with the notebook's logspace helper.
grid1 = logspace(1, 10, 3)
grid2 = logspace(1, 10, 3)

# values[i, j] = grid1[i] ** 2 + grid2[j] ** 2
prod_grid = np.array(list(itertools.product(grid1, grid2)))
values = (prod_grid**2).sum(axis=1).reshape(3, 3)

points = np.array([[9, 1], [2.1, 1], [2.7, 1]])

grid_info = [("logspace", (1, 10, 3)), ("logspace", (1, 10, 3))]

# The scipy reference only depends on the grid, so build it once.
scipy_func = RegularGridInterpolator(
    points=(grid1, grid2), values=values, method="linear"
)

for point in points:
    calculated = linear_interpolation(
        values=values,
        point=point,
        grid_info=grid_info,
    )
    scipy_res = scipy_func(point)
    print(scipy_res, calculated)
    aaae(calculated, scipy_res)

In [None]:
# Inspect the tabulated values from the cell above.
values

In [None]:
# 1D check: geometric grid from 1 to 10 with 11 points, values x**2.
grid1 = logspace(1, 10, 11)

values = grid1**2

points = np.array([[2.5], [2.1], [2.7]])

grid_info = [("logspace", (1, 10, 11))]

# values is one-dimensional, so scipy must get exactly one grid axis (a second
# axis makes RegularGridInterpolator raise). Hand scipy NumPy arrays rather
# than JAX arrays. Built once because it does not depend on the point.
scipy_func = RegularGridInterpolator(
    points=(np.asarray(grid1),), values=np.asarray(values), method="linear"
)

for point in points:
    calculated = linear_interpolation(
        values=values,
        point=point,
        grid_info=grid_info,
    )
    scipy_res = scipy_func(point)
    print(scipy_res, calculated)
    # aaae(calculated, scipy_res)

In [None]:
# Inspect the tabulated values from the cell above.
values

In [None]:
# Base-10 log grids: grid1 = 10**0 ... 10**4 (5 points), grid2 = 10**1 ... 10**3
# (3 points).
grid1 = np.array([1, 10, 100, 1000, 10000.0])
grid2 = np.array([10, 100, 1000.0])


def f(a, b):
    """Additive test function ``a + b`` (shadows the earlier three-argument f)."""
    return a + b


values = f(*np.meshgrid(grid1, grid2, indexing="ij", sparse=False))
print(values)

prod_grid = np.array(list(itertools.product(grid1, grid2)))
# values = (prod_grid ** 2).sum(axis=1).reshape(5, 3)

points = np.array([[5.5, 10], [2.1, 38], [2.7, 33]])

# grid_info entries are (start, stop, n_points) and must describe grid1/grid2.
grid_info = [("logspace", (1, 10_000, 5)), ("logspace", (10, 1_000, 3))]

# The scipy reference only depends on the grids, so build it once.
scipy_func = RegularGridInterpolator(
    points=(grid1, grid2), values=values, method="linear"
)

for point in points:
    # calculated = linear_interpolation(
    #    values=values,
    #    point=point,
    #    grid_info=grid_info,
    # )

    scipy_res = scipy_func(point)
    print(scipy_res)
    # aaae(calculated, scipy_res)

In [None]:
# Geometric grids matching grid_info below: 5 points from 1 to 5 and 3 points
# from 2 to 4 (evenly spaced grids would not match a logspace grid_info).
grid1 = np.logspace(np.log10(1), np.log10(5), 5)
grid2 = np.logspace(np.log10(2), np.log10(4), 3)

prod_grid = np.array(list(itertools.product(grid1, grid2)))
values = (prod_grid**2).sum(axis=1).reshape(5, 3)

points = np.array([[2.5, 3.5], [2.1, 3.8], [2.7, 3.3]])

grid_info = [("logspace", (1, 5, 5)), ("logspace", (2, 4, 3))]

# The scipy reference only depends on the grids, so build it once.
scipy_func = RegularGridInterpolator(
    points=(grid1, grid2), values=values, method="linear"
)

for point in points:
    # calculated = linear_interpolation(
    #    values=values,
    #    point=point,
    #    grid_info=grid_info,
    # )

    scipy_res = scipy_func(point)
    print(scipy_res)
    # aaae(calculated, scipy_res)

In [None]:
def test_linear_interpolation_2d():
    """Compare ``linear_interpolation`` with scipy on a 2D logspace grid.

    The data are ``x**2 + y**2`` on geometric grids with 5 points from 1 to 5
    and 3 points from 2 to 4. These are exactly the grids described by
    ``grid_info``, so both interpolators see the same tabulated function.
    """
    grid1 = np.logspace(np.log10(1), np.log10(5), 5)
    grid2 = np.logspace(np.log10(2), np.log10(4), 3)

    # values[i, j] = grid1[i] ** 2 + grid2[j] ** 2
    prod_grid = np.array(list(itertools.product(grid1, grid2)))
    values = (prod_grid**2).sum(axis=1).reshape(5, 3)

    points = np.array([[2.5, 3.5], [2.1, 3.8], [2.7, 3.3]])

    grid_info = [("logspace", (1, 5, 5)), ("logspace", (2, 4, 3))]

    # The scipy reference only depends on the grid, so build it once.
    scipy_func = RegularGridInterpolator(
        points=(grid1, grid2), values=values, method="linear"
    )

    for point in points:
        calculated = linear_interpolation(
            values=values,
            point=point,
            grid_info=grid_info,
        )
        scipy_res = scipy_func(point)

        aaae(calculated, scipy_res)

In [None]:
# Run the logspace version of the comparison test defined in the previous cell.
test_linear_interpolation_2d()

In [None]:
def linear_linspace_interpolation(values, point, starts, step_lengths):
    """Linear interpolation specialised to grids that are all linspaces.

    Serves as a baseline for checking that the dynamic, ``grid_info``-driven
    function lookup in ``linear_interpolation`` does not cost any performance.

    Args:
        values: Array of function values on the grid.
        point: Coordinates of the point, one per axis of ``values``.
        starts: First grid value along each axis.
        step_lengths: Distance between neighbouring grid points (broadcast
            against ``point``).

    Returns:
        The interpolated value. Indices outside the grid take the nearest edge
        value because of ``mode="nearest"``.
    """
    return map_coordinates(
        input=values,
        coordinates=(point - starts) / step_lengths,
        order=1,
        mode="nearest",
    )

## 1D benchmarks

In [None]:
def get_1d_inputs(n_grid=1000, n_inter=10_000, target="jax_linspace"):
    """Build a function and its arguments for the 1D interpolation benchmark.

    The test function ``x**1.1 + sin(x)`` is tabulated on ``n_grid`` evenly
    spaced points in [1, 5]. It is evaluated at ``n_inter`` evenly spaced
    points in [1, 4], shuffled after reseeding NumPy's global RNG with 1234.

    Args:
        n_grid: Number of grid points (at least 2).
        n_inter: Number of interpolation points.
        target: ``"jax_linspace"`` (jitted, vmapped
            ``linear_linspace_interpolation``), ``"jax"`` (jitted, vmapped
            ``linear_interpolation``) or ``"scipy"`` (``interp1d``).

    Returns:
        ``(func, args)`` such that ``func(*args)`` interpolates at all points.

    Raises:
        ValueError: If ``target`` is not one of the options above.
    """
    lower, upper = 1, 5
    np.random.seed(1234)
    grid = np.linspace(lower, upper, n_grid)
    values = grid**1.1 + np.sin(grid)
    points = np.linspace(1, 4, n_inter)
    np.random.shuffle(points)

    if target == "jax_linspace":
        vmapped = vmap(linear_linspace_interpolation, in_axes=(None, 0, None, None))
        func = jit(vmapped)
        # Derive start and step from the grid itself so they cannot drift
        # apart from the grid definition above.
        args = (
            jnp.array(values),
            jnp.array(points.reshape(-1, 1)),
            jnp.array([grid[0]]),
            grid[1] - grid[0],
        )
    elif target == "jax":
        vmapped = vmap(linear_interpolation, in_axes=(None, 0, None))
        # grid_info is static (a hashable tuple of tuples) so the coordinate
        # functions can be looked up at trace time.
        func = jit(vmapped, static_argnums=2)
        grid_info = (("linspace", (lower, upper, n_grid)),)
        args = (jnp.array(values), jnp.array(points.reshape(-1, 1)), grid_info)
    elif target == "scipy":
        func = interp1d(grid, values)
        args = (points,)
    else:
        raise ValueError(
            f"Unknown target {target!r}; expected 'jax_linspace', 'jax' or 'scipy'."
        )

    return func, args


def get_1d_runtime(n_grid=1000, n_inter=10_000, target="jax_linespace"):
    func, args = get_1d_inputs(n_grid, n_inter, target)
    # firs evaluation for jit overhead
    func(*args)
    timeit_res = %timeit -o func(*args)
    runtime = timeit_res.average
    return runtime

In [None]:
# Benchmark scipy's interp1d on the 1D problem and keep its result for the
# comparison at the end of this section.
func, args = get_1d_inputs(target="scipy")
%timeit func(*args)
scipy_res = func(*args)

In [None]:
# Benchmark the jitted, vmapped linspace-only interpolation. JAX dispatches
# asynchronously, so block_until_ready makes %timeit include the computation.
func, args = get_1d_inputs(target="jax_linspace")
%timeit func(*args).block_until_ready()
linspace_res = func(*args)

In [None]:
# Benchmark the generic, grid_info-driven interpolation (jitted and vmapped)
# to compare against the specialised linspace version above.
func, args = get_1d_inputs(target="jax")
%timeit func(*args).block_until_ready()
jax_res = func(*args)

In [None]:
# Both JAX variants must reproduce scipy's interp1d result.
aaae(scipy_res, linspace_res)
aaae(scipy_res, jax_res)

# Sandbox

Janos' idea: constant growth of the differences

In [None]:
# Sandbox plot: exp(x) versus 2**x shifted up by 10, on 1000 points in [0, 5].
import matplotlib.pyplot as plt

x = np.linspace(0, 5, 1000)
y_exp = np.exp(x)
y_bas2 = 2**x + 10
plt.plot(x, y_exp)
plt.plot(x, y_bas2)

In [None]:
# Inspect the shifted base-2 values plotted above.
y_bas2