In [None]:
from functools import partial

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from scipy.integrate import dblquad
from scipy.interpolate import RegularGridInterpolator
from tqdm.auto import tqdm

from bounded_rand_walkers.cpp import (
    bound_map,
    funky,
    generate_data,
    get_binned_2D,
    get_binned_data,
    get_cached_filename,
)
from bounded_rand_walkers.data_generation import Delaunay, DelaunayArray, in_bounds
from bounded_rand_walkers.rad_interp import (
    exact_radii_interp,
    inv_exact_radii_interp,
    rotation,
)
from bounded_rand_walkers.relief_matrix_shaper import gen_shaper2D
from bounded_rand_walkers.rotation_steps import get_pdf_transform_shaper
from bounded_rand_walkers.shaper_general import (
    gen_rad_shaper,
    gen_rad_shaper_exact,
    shaper_map,
)
from bounded_rand_walkers.utils import approx_edges, cache_dir, get_centres, normalise

# Apply the matplotlib style settings stored in ./matplotlibrc (the path is
# resolved relative to the current working directory).
mpl.rc_file("matplotlibrc")

In [None]:
# XXX:
import pandas as pd

### Calculate the shaper function using the 2D approach and compare it to the analytical function

#### Investigate ideal binning parameters for 2D → radial interpolation

In [None]:
# Boundary under study, and the half-width of the square binning grid whose
# edges span [-lim, lim].
bound_name = "square"
lim = 1.5

vertices = bound_map[bound_name]()


def add_one(x):
    """Return the flattened values of `x`, followed by those same values plus one.

    For example ``[41, 44]`` becomes ``[41, 44, 42, 45]``. This makes it easy to
    sample each bin count together with its successor (e.g. odd and even counts).
    """
    flat = x.ravel()
    return np.concatenate((flat, flat + 1))


# The second ("high") MSE only considers radii strictly above this value.
HIGH_RADIUS = 0.5


def _align_to(reference, curve):
    """Return a rescaled copy of `curve` that matches `reference` in scale.

    The scale factor is the NaN-ignoring mean of ``reference / curve`` over the
    entries where neither array is close to zero. If there are no such entries,
    the factor (and so the result) is NaN. `curve` itself is not modified.
    """
    nonzero = ~(np.isclose(reference, 0) | np.isclose(curve, 0))
    return curve * np.nanmean(reference[nonzero] / curve[nonzero])


def _mse(a, b):
    """Mean squared difference between two equally shaped arrays."""
    return np.mean((a - b) ** 2)


# mses[n_bins][mode]  -> MSE over all radii
# bmses[n_bins][mode] -> MSE over radii > HIGH_RADIUS
mses = {}
bmses = {}
for n_bins in tqdm(
    np.unique(add_one(np.linspace(41, 201, 50, dtype=np.int64))), desc="n_bins"
):
    f_t_x_edges = f_t_y_edges = np.linspace(-lim, lim, n_bins + 1)
    f_t_x_centres = f_t_y_centres = get_centres(f_t_x_edges)

    num_2d_shaper = gen_shaper2D(vertices, f_t_x_edges, f_t_y_edges, verbose=False)

    mses[n_bins] = {}
    bmses[n_bins] = {}

    for mode in range(1, 4):
        radii, radial_shaper = exact_radii_interp(
            num_2d_shaper,
            f_t_x_centres,
            f_t_y_centres,
            normalisation="multiply",
            bin_samples=0.05,
            mode=mode,
        )
        # Calculate the shaper function explicitly at the same radii.
        analytical_shaper = gen_rad_shaper_exact(
            radii,
            vertices=bound_name if bound_name in bound_map else vertices,
            verbose=False,
        )
        analytical_shaper *= radii

        # Only the shapes are compared: the analytical curve is rescaled onto the
        # numerical one before computing the MSE over all radii.
        analytical_shaper = _align_to(radial_shaper, analytical_shaper)
        mses[n_bins][mode] = _mse(analytical_shaper, radial_shaper)

        # Same comparison restricted to radii > HIGH_RADIUS, re-aligned on that
        # subset alone. The full arrays are left untouched.
        high = radii > HIGH_RADIUS
        high_analytical = _align_to(radial_shaper[high], analytical_shaper[high])
        bmses[n_bins][mode] = _mse(high_analytical, radial_shaper[high])

In [None]:
# One row per bin count and one column per interpolation mode, tagged with the
# radius range the MSE was computed over ("all" radii vs. the "high" subset).
df1 = pd.DataFrame(mses).T.assign(cat="all")
df2 = pd.DataFrame(bmses).T.assign(cat="high")

In [None]:
# Summary statistics of the MSEs over all radii, one column per `mode`.
df1.describe()

In [None]:
# Summary statistics of the MSEs restricted to radii > 0.5, one column per `mode`.
df2.describe()

In [None]:
# Stack both MSE tables row-wise; every bin count now appears once per `cat`.
df = pd.concat([df1, df2], axis=0)

In [None]:
# Both panels show the same data: the left panel zooms in on small MSEs and the
# right one shows the full range. `df` holds every n_bins twice (cat "all" and
# "high"), so each category is plotted as its own line. Otherwise the line would
# zig-zag between the two different metrics.
fig, axes = plt.subplots(1, 2, figsize=(14, 4))
for ax in axes:
    for cat, cat_df in df.groupby("cat"):
        cat_sorted = cat_df.sort_index()
        for col in [1, 2, 3]:
            ax.plot(
                cat_sorted.index.values,
                cat_sorted[col].values,
                label=f"mode {col} ({cat})",
                marker="x",
            )
    ax.set_xlabel("n_bins")
    ax.set_ylabel("MSE")
axes[0].set_ylim(0, 0.003)
axes[0].set_title("MSE vs. n_bins (zoomed)")
axes[1].set_title("MSE vs. n_bins")
_ = axes[0].legend(loc="best")

In [None]:
# The index holds n_bins; flag odd bin counts so they can be compared with even ones.
df["odd"] = np.mod(df.index.to_numpy(), 2) == 1
df.head()

In [None]:
# Box plots of the per-mode MSEs for every (odd, cat) combination. There is one
# column per group; outliers are shown in the top row and hidden in the bottom row.
groupby = df.groupby(["odd", "cat"])
# squeeze=False keeps `axes` 2D even when there is only a single group.
fig, axes = plt.subplots(
    2, len(groupby), sharex=True, figsize=(18, 10), sharey="row", squeeze=False
)

# `col_axes` (not `axes`) so the outer figure grid is not shadowed.
for ((odd, cat), grouped), col_axes in zip(groupby, axes.T):
    for ax, show_fliers in zip(col_axes, [True, False]):
        grouped[[1, 2, 3]].boxplot(ax=ax, showfliers=show_fliers)
        ax.set_title(
            f"{'odd' if odd else 'even'} n_bins, cat={cat}, "
            f"fliers {'shown' if show_fliers else 'hidden'}"
        )

fig.tight_layout()

### Calculate the shaper function both ways: numerically from the 2D grid and analytically

In [None]:
bound_name = "square"

vertices = bound_map[bound_name]()

# NOTE(review): 21 bins is below the 41-202 range explored in the binning
# investigation above; confirm this coarse grid is intentional.
n_bins = 21
lim = 1.5
f_t_x_edges = f_t_y_edges = np.linspace(-lim, lim, n_bins + 1)
f_t_x_centres = f_t_y_centres = get_centres(f_t_x_edges)

num_2d_shaper = gen_shaper2D(vertices, f_t_x_edges, f_t_y_edges)

# Gridded (2D) shaper drawn on the bin edges.
grid_fig, grid_ax = plt.subplots()
mesh = grid_ax.pcolormesh(f_t_x_edges, f_t_y_edges, num_2d_shaper)
grid_ax.axis("scaled")
grid_fig.colorbar(mesh, ax=grid_ax)
grid_ax.set_xlabel("x")
grid_ax.set_ylabel("y")
_ = grid_ax.set_title("Gridded Shaper")

# Extract the radial shaper from the 2D values, once with radial binning
# (bin_samples=0.05) and once without (bin_samples=None).
radii, radial_shaper = exact_radii_interp(
    num_2d_shaper,
    f_t_x_centres,
    f_t_y_centres,
    normalisation="multiply",
    bin_samples=0.05,
)
radii2, radial_shaper2 = exact_radii_interp(
    num_2d_shaper,
    f_t_x_centres,
    f_t_y_centres,
    normalisation="multiply",
    bin_samples=None,
)

# Evaluate the analytical shaper at evenly spaced radii, up to the largest
# binned radius.
shaper_radii = np.linspace(0, np.max(radii), 100)
shaper_rad = gen_rad_shaper_exact(
    shaper_radii, vertices=bound_name if bound_name in bound_map else vertices
)

radial_fig, radial_ax = plt.subplots(figsize=(15, 8))
radial_ax.plot(
    radii, normalise(radii, radial_shaper), label="Radially interpolated", marker="x"
)
radial_ax.plot(
    radii2,
    normalise(radii2, radial_shaper2),
    label="Raw interpolated",
    marker="x",
    linestyle="",
)
# As in the binning investigation, the analytical shaper is multiplied by the
# radius before it is compared with the numerically extracted curves.
radial_ax.plot(
    shaper_radii,
    normalise(shaper_radii, shaper_rad * shaper_radii),
    label="Analytical",
)
radial_ax.set_xlabel("radius")
radial_ax.set_ylabel("normalised shaper")
radial_ax.legend(loc="best")
_ = radial_ax.set_title(f"Shaper - {bound_name}")

### Calculate the shaper function using the 2D approach for the "weird" boundary

In [None]:
bound_name = "weird"

vertices = bound_map[bound_name]()

n_bins = 200
lim = 1.5
f_t_x_edges = f_t_y_edges = np.linspace(-lim, lim, n_bins + 1)
f_t_x_centres = f_t_y_centres = get_centres(f_t_x_edges)

num_2d_shaper = gen_shaper2D(vertices, f_t_x_edges, f_t_y_edges)

# Gridded (2D) shaper drawn on the bin edges.
grid_fig, grid_ax = plt.subplots()
mesh = grid_ax.pcolormesh(f_t_x_edges, f_t_y_edges, num_2d_shaper)
grid_ax.axis("scaled")
grid_fig.colorbar(mesh, ax=grid_ax)
grid_ax.set_xlabel("x")
grid_ax.set_ylabel("y")
_ = grid_ax.set_title("Gridded Shaper")

# Extract the radial shaper from the 2D values.
# NOTE(review): unlike the square case, `bin_samples` is left at its default
# here; confirm this is intended.
radii, radial_shaper = exact_radii_interp(
    num_2d_shaper, f_t_x_centres, f_t_y_centres, normalisation="multiply"
)

radial_fig, radial_ax = plt.subplots()
radial_ax.plot(radii, normalise(radii, radial_shaper), label="Radially interpolated")
radial_ax.set_xlabel("radius")
radial_ax.set_ylabel("normalised shaper")
radial_ax.legend(loc="best")
_ = radial_ax.set_title(f"Shaper - {bound_name}")