In [None]:
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
from scipy.integrate import dblquad
from scipy.interpolate import RegularGridInterpolator
from tqdm.auto import tqdm

from bounded_rand_walkers.cpp import (
    bound_map,
    funky,
    generate_data,
    get_binned_2D,
    get_binned_data,
    get_cached_filename,
)
from bounded_rand_walkers.data_generation import Delaunay, DelaunayArray, in_bounds
from bounded_rand_walkers.rad_interp import (
    exact_radii_interp,
    inv_exact_radii_interp,
    rotation,
)
from bounded_rand_walkers.relief_matrix_shaper import gen_shaper2D
from bounded_rand_walkers.rotation_steps import get_pdf_transform_shaper
from bounded_rand_walkers.shaper_general import (
    gen_rad_shaper,
    gen_rad_shaper_exact,
    shaper_map,
)
from bounded_rand_walkers.utils import approx_edges, cache_dir, get_centres, normalise

In [None]:
# Boundary selection: `bound_map[name]()` builds the vertices of the named boundary.
bound_name = "square"
vertices = bound_map[bound_name]()

# Histogram grid over [-2, 2], shared by both axes (the y arrays are the very same
# objects as the x arrays).
n_bins = 50
f_t_x_edges = np.linspace(-2, 2, n_bins + 1)
f_t_y_edges = f_t_x_edges
f_t_x_centres = get_centres(f_t_x_edges)
f_t_y_centres = f_t_x_centres

# Resolution handed to `gen_shaper2D`; bumping to 400 does not improve things visibly.
order_divisions = 200

In [None]:
# Evaluate the boundary's shaper on a regular grid. Only `raw_shaper` is used in
# this cell; `raw_shaper_X` / `raw_shaper_Y` are not read here.
raw_shaper_X, raw_shaper_Y, raw_shaper = gen_shaper2D(order_divisions, vertices)

# Half-widths and resolution of the grid that `raw_shaper` is assumed to be sampled on.
# NOTE(review): these must match the extent/resolution used inside `gen_shaper2D` — confirm.
# The grid is currently square (divisions_y == divisions_x).
x0 = y0 = 2
divisions_x = divisions_y = order_divisions

# Linear interpolator over the cell centres of that grid. `RegularGridInterpolator`
# needs the leading axes of `raw_shaper` to have lengths (divisions_x, divisions_y);
# axis 0 is treated as x (assumed to match `gen_shaper2D`'s layout — confirm).
# Queries outside [-x0, x0] x [-y0, y0] evaluate to 0 instead of raising.
interp = RegularGridInterpolator(
    points=(
        get_centres(np.linspace(-x0, x0, divisions_x + 1)),
        get_centres(np.linspace(-y0, y0, divisions_y + 1)),
    ),
    values=raw_shaper,
    method="linear",
    bounds_error=False,
    fill_value=0.0,
)

# Resample onto the histogram centres. The result keeps the "ij" layout
# (axis 0 = x, axis 1 = y), matching `f_t_X_grid.shape`.
f_t_X_grid, f_t_Y_grid = np.meshgrid(f_t_x_centres, f_t_y_centres, indexing="ij")
interp_shaper = interp(
    np.column_stack((f_t_X_grid.ravel(), f_t_Y_grid.ravel()))
).reshape(f_t_X_grid.shape)

In [None]:
# Plot the shaper resampled onto the histogram grid.
# `interp_shaper` is "ij"-indexed (axis 0 = x, axis 1 = y), whereas `pcolormesh`
# expects C[row, col] = C[y, x]; transpose so x stays on the horizontal axis.
# (Without the transpose the image is mirrored about y = x, which is only
# invisible for boundaries symmetric under swapping x and y.)
fig, ax = plt.subplots()
mesh = ax.pcolormesh(f_t_x_edges, f_t_y_edges, interp_shaper.T)
ax.axis("scaled")
ax.set_xlabel("x")
ax.set_ylabel("y")
fig.colorbar(mesh, ax=ax)
_ = ax.set_title("Gridded Shaper")

In [None]:
# Extract a radial profile of the shaper from the gridded 2D values.
# `normalisation="multiply"` is forwarded to `exact_radii_interp`; the comparison
# plot multiplies the analytical shaper by the radius, so this presumably scales
# the profile by r as well — confirm against `exact_radii_interp`.
# (The analytical shaper is evaluated in the following cell; computing it here too
# was redundant because that cell overwrites `shaper_radii` / `shaper_rad`.)
radii, radial_shaper = exact_radii_interp(
    interp_shaper, f_t_x_centres, f_t_y_centres, normalisation="multiply"
)

In [None]:
# Evaluate the analytical radial shaper on an evenly spaced grid spanning the same
# radius range as the extracted profile `radii` (computed in the previous cell, so
# it is not re-extracted here).
n_shaper_radii = 100
shaper_radii = np.linspace(np.min(radii), np.max(radii), n_shaper_radii)

# Pass the boundary by name when `bound_map` knows it, otherwise pass the explicit
# vertices.
shaper_rad = gen_rad_shaper_exact(
    shaper_radii, vertices=bound_name if bound_name in bound_map else vertices
)

In [None]:
# Compare the radially interpolated profile with the analytical shaper. The
# analytical values are multiplied by the radius (presumably to match the
# "multiply" normalisation used during extraction), and both curves are passed
# through `normalise` over their own radii before plotting.
fig, ax = plt.subplots()
ax.plot(radii, normalise(radii, radial_shaper), label="Radially interpolated")
ax.plot(
    shaper_radii,
    normalise(shaper_radii, shaper_rad * shaper_radii),
    label="Analytical",
)
ax.set_xlabel("Radius")
ax.set_ylabel("Normalised shaper")
ax.grid(linestyle="--", alpha=0.4)
ax.legend(loc="best")
_ = ax.set_title(f"Shaper - {bound_name}")