In [None]:
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams.update({'font.size': 16, 'figure.figsize': (40, 8), 'font.family': 'serif', 'text.usetex': True, 'pgf.rcfonts': False})

%load_ext autoreload
%autoreload 2

import jax
jax.config.update('jax_platform_name', 'cpu')
    
from jax_smolyak import indices, points
import jax_smolyak.smolyak_jax as sj
import jax_smolyak.smolyak as sn

def target_f_2d_to_1d(x):
    """Test function used as the interpolation target.

    Evaluates ``sin(2*pi*x1) * cos(2*pi*x2)`` row-wise, where ``x1`` and
    ``x2`` are the first and second columns of the input.

    Parameters
    ----------
    x : array_like
        A single point of length 2, or an array with one point per row;
        only the first two columns are read.

    Returns
    -------
    numpy.ndarray
        One value per row of ``np.atleast_2d(x)``.
    """
    pts = np.atleast_2d(x)
    first, second = pts[:, 0], pts[:, 1]
    return np.sin(2 * np.pi * first) * np.cos(2 * np.pi * second)

def get_meshgrid(g, n=100):
    """Build a tensor-product grid covering the domain of every dimension in ``g``.

    Parameters
    ----------
    g : iterable
        Per-dimension objects exposing ``.domain``; ``domain[0]`` and
        ``domain[1]`` are used as the lower and upper bounds.
    n : int, default 100
        Number of equally spaced samples per dimension (endpoints included).

    Returns
    -------
    The coordinate arrays produced by ``np.meshgrid`` (default ``'xy'`` indexing).
    """
    axes_1d = []
    for dim in g:
        lower, upper = dim.domain[0], dim.domain[1]
        axes_1d.append(np.linspace(lower, upper, n))
    return np.meshgrid(*axes_1d)

In [None]:
# Compare the JAX and NumPy Smolyak interpolants against the target on [-1, 1]^2.
g = points.LejaMulti(domains=[[-1,1], [-1,1]])
l = 30  # Smolyak level
ip = sj.MultivariateSmolyakBarycentricInterpolator(g=g, k=[1,1], l=l, rank=1, f=target_f_2d_to_1d)
ipn = sn.MultivariateSmolyakBarycentricInterpolator(g=g, k=[1,1], l=[l], f=target_f_2d_to_1d)

n = 100
X, Y = get_meshgrid(g, n)
# Evaluation points as an (n*n, 2) array, built once and shared by all evaluations.
pts = np.stack([X.ravel(), Y.ravel()], axis=1)
Z_f = target_f_2d_to_1d(pts).reshape(n, n)
# Bring both interpolant outputs into float64 NumPy arrays so the difference is
# computed in NumPy float64. Subtracting a JAX array directly would promote the
# result into JAX (float32 unless jax_enable_x64 is set) and add rounding noise.
Z_ip = np.asarray(ip(pts), dtype=np.float64).reshape(n, n)
Z_ipn = np.asarray(ipn(pts), dtype=np.float64).reshape(n, n)
Z_diff = Z_ipn - Z_ip
max_abs_diff = float(np.max(np.abs(Z_diff)))

extent = [np.min(X), np.max(X), np.min(Y), np.max(Y)]
# Symmetric colour range for the difference so zero sits at the colormap centre;
# fall back to 1.0 when the interpolants agree exactly (avoids vmin == vmax).
diff_lim = max_abs_diff or 1.0

fig, axes = plt.subplots(1, 4, figsize=(24, 5))
panels = [
    (Z_f, 'target', {}),
    (Z_ipn, 'interpolant numpy', {}),
    (Z_ip, 'interpolant jax', {}),
    (Z_diff, 'difference', {'cmap': 'RdBu_r', 'vmin': -diff_lim, 'vmax': diff_lim}),
]
for ax, (Z, title, imshow_kw) in zip(axes, panels):
    im = ax.imshow(Z, extent=extent, origin='lower', **imshow_kw)
    ax.set_title(title)
    ax.set_xlabel('$x_1$')
    ax.set_ylabel('$x_2$')
    fig.colorbar(im, ax=ax)
fig.suptitle(rf'$l = {l}$, $\max |\Delta| = {max_abs_diff:.2e}$')

plt.show()

#### Chasing some weird artefacts on the non-symmetric domain $[-1, 2]^2$

In [None]:
# Same comparison on the non-symmetric domain [-1, 2]^2, where artefacts were observed.
g = points.LejaMulti(domains=[[-1,2.], [-1,2.]])
l = 3  # Smolyak level; also try l = 30
ip = sj.MultivariateSmolyakBarycentricInterpolator(g=g, k=[1,1], l=l, rank=1, f=target_f_2d_to_1d)
ipn = sn.MultivariateSmolyakBarycentricInterpolator(g=g, k=[1,1], l=[l], f=target_f_2d_to_1d)

n = 100
X, Y = get_meshgrid(g, n)
# Evaluation points as an (n*n, 2) array, built once and shared by all evaluations.
pts = np.stack([X.ravel(), Y.ravel()], axis=1)
Z_f = target_f_2d_to_1d(pts).reshape(n, n)
# Bring both interpolant outputs into float64 NumPy arrays so the difference is
# computed in NumPy float64. Subtracting a JAX array directly would promote the
# result into JAX (float32 unless jax_enable_x64 is set) and add rounding noise
# that can masquerade as an artefact.
Z_ip = np.asarray(ip(pts), dtype=np.float64).reshape(n, n)
Z_ipn = np.asarray(ipn(pts), dtype=np.float64).reshape(n, n)
Z_diff = Z_ipn - Z_ip
max_abs_diff = float(np.max(np.abs(Z_diff)))

extent = [np.min(X), np.max(X), np.min(Y), np.max(Y)]
# Symmetric colour range for the difference so zero sits at the colormap centre;
# fall back to 1.0 when the interpolants agree exactly (avoids vmin == vmax).
diff_lim = max_abs_diff or 1.0

fig, axes = plt.subplots(1, 4, figsize=(24, 5))
panels = [
    (Z_f, 'target', {}),
    (Z_ipn, 'interpolant numpy', {}),
    (Z_ip, 'interpolant jax', {}),
    (Z_diff, 'difference', {'cmap': 'RdBu_r', 'vmin': -diff_lim, 'vmax': diff_lim}),
]
for ax, (Z, title, imshow_kw) in zip(axes, panels):
    im = ax.imshow(Z, extent=extent, origin='lower', **imshow_kw)
    ax.set_title(title)
    ax.set_xlabel('$x_1$')
    ax.set_ylabel('$x_2$')
    fig.colorbar(im, ax=ax)
fig.suptitle(rf'$l = {l}$, $\max |\Delta| = {max_abs_diff:.2e}$')

plt.show()