TODO:
- Update the label class and its plot function.
- Add a test script that generates (x, v) samples with Agama.
- Write two tutorials: one on fitting the density, one on fitting a label.

In [None]:
import copy
import os

from astropy.constants import G
import astropy.table as at
import astropy.coordinates as coord
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt

%matplotlib inline
import numpy as np

# gala
import gala.coordinates as gc
import gala.dynamics as gd
import gala.potential as gp
import gala.integrate as gi
from gala.units import galactic

import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

from empaf import DensityOrbitModel
from empaf.plot import plot_data_models_residual

In [None]:
import agama

# Run Agama in a Msun / kpc / Myr unit system, the same length/time/mass units as
# gala's `galactic` system used for the phase-space data in this notebook.
agama.setUnits(mass=u.Msun, length=u.kpc, time=u.Myr)

In [None]:
# Load the gala Milky Way model and re-express each of its components as an Agama
# potential component, so Agama can be run in (approximately) the same potential.
gala_pot = gp.load("../../gaia-actions/potentials/MilkyWayPotential2022.yml")

# Disk: gala's three-Miyamoto-Nagai approximation -> three Agama Miyamoto-Nagai disks.
agama_components = [
    dict(
        type="miyamotonagai",
        mass=mn_disk.parameters["m"].value,
        scaleradius=mn_disk.parameters["a"].value,
        scaleheight=mn_disk.parameters["b"].value,
    )
    for mn_disk in gala_pot["disk"].get_three_potentials().values()
]

# Bulge and nucleus: Dehnen profiles with gamma = 1 (presumably mirroring the
# Hernquist components in gala, whose scale radius is the "c" parameter).
agama_components.extend(
    dict(
        type="dehnen",
        mass=gala_pot[name].parameters["m"].value,
        scaleradius=gala_pot[name].parameters["c"].value,
        gamma=1.0,
    )
    for name in ("bulge", "nucl")
)

# Halo: NFW profile with the same mass and scale-radius parameters.
halo = gala_pot["halo"]
agama_components.append(
    dict(
        type="nfw",
        mass=halo.parameters["m"].value,
        scaleradius=halo.parameters["r_s"].value,
    )
)

agama_pot = agama.Potential(*agama_components)

In [None]:
# Earlier input files, kept for reference:
# xv = np.load('../test-data/agama-galaxymodel-df-small.npy')
# w0 = gd.PhaseSpacePosition.from_w(xv.T, units=galactic)
# tbl = at.QTable.read('../test-data/agama-galaxymodel-particles.fits')
# Particle table from an Agama galaxy model. Later cells read its columns x, y, z,
# v_x, v_y, v_z and the Agama quantities J_phi, J_z, theta_z, Omega_z.
tbl = at.QTable.read("../test-data/agama-galaxymodel-particles-qIso.fits")

In [None]:
# Keep only particles in a narrow window of azimuthal action J_phi around that of a
# circular orbit with v_c = 229 km/s at R = 8.3 kpc; the half-width is 1 kpc * v_c.
v_circ = 229 * u.km / u.s
Jphi0 = v_circ * 8.3 * u.kpc
dJphi = 1 * u.kpc * v_circ

mask = np.abs(tbl["J_phi"] - Jphi0) < dJphi
print(mask.sum(), len(tbl))
sub_tbl = tbl[mask]

# Cartesian positions and velocities of the selected particles, one row per particle
# (shape (N, 6)), as plain arrays in the table's stored units.
xv = np.column_stack(
    [sub_tbl[col].value for col in ("x", "y", "z", "v_x", "v_y", "v_z")]
)
w0 = gd.PhaseSpacePosition.from_w(xv.T, units=galactic)

In [None]:
# act_finder = agama.ActionFinder(agama_pot)
# agama_act, agama_ang, agama_freq = act_finder(xv, angles=True)
# agama_aaf = at.Table({
#     "J_z": agama_act[:, 1],
#     "theta_z": agama_ang[:, 1],
#     "Omega_z": agama_freq[:, 1],
#     "T_z": 2*np.pi / agama_freq[:, 1]
# })

In [None]:
# tbl = at.QTable()
# tbl['z'] = xv[:, 2] * u.kpc
# tbl['vz'] = xv[:, 5] * u.kpc/u.Myr

# Log-scaled 2D histogram of the selected particles in the (v_z, z) plane.
vz_edges = np.linspace(-0.1, 0.1, 151)
z_edges = np.linspace(-2.5, 2.5, 151)
bins = (vz_edges, z_edges)

plt.hist2d(w0.v_z.value, w0.z.value, bins=bins, norm=mpl.colors.LogNorm())
plt.xlim(vz_edges[0], vz_edges[-1])
plt.ylim(z_edges[0], z_edges[-1])
plt.xlabel("$v_z$")
plt.ylabel("$z$")

In [None]:
from empaf.model_helpers import (
    custom_tanh_func_alt,
    monotonic_poly_func_alt,
    monotonic_quadratic_spline,
)
from jax_cosmo.scipy.interpolate import InterpolatedUnivariateSpline

In [None]:
# e_params0 = {}
# e_bounds = {}

# # def e2_func(rzp, f1):
# #     return f1 * rzp**2

# # e_params0[2] = {"f1": 0.1}
# # e_bounds[2] = {"f1": (0, 0.8)}


# # poly:
# # def e2_func(rzp, f1, alpha, x0):
# #     return monotonic_poly_func_alt(
# #         rzp, f0=0.0, fx=f1, alpha=alpha, x0=x0, xval=1.0
# #     )


# # e_params0[2] = {"f1": 0.1, "alpha": 0.33, "x0": 3.0}
# # e_bounds[2] = {"f1": (0, 0.8), "alpha": (0.2, 0.5), "x0": (2, 30.0)}


# # tanh:
# # def e2_func(rzp, f1, ln_alpha, ln_x0):
# #     return custom_tanh_func_alt(
# #         rzp, f_xval=f1, alpha=jnp.exp(ln_alpha), x0=jnp.exp(ln_x0), xval=1.0
# #     )


# # e_params0[2] = {"f1": 0.4, "ln_alpha": np.log(1e1), "ln_x0": 0.0}
# # e_bounds[2] = {"f1": (0, 0.8), "ln_alpha": (-3, 5), "ln_x0": (np.log(1e-2), np.log(1e2))}


# # my spline:
# # def e2_func(rzp, derivs):
# #     xs = np.linspace(0, 1.0, 11)** 2
# #     ys = jnp.concatenate((jnp.array([0.]), np.sum(derivs) - np.cumsum(derivs)))
# #     return monotonic_quadratic_spline(xs, ys, rzp)


# # e_params0[2] = {"derivs": np.full(10, 0.1)}
# # e_bounds[2] = {"derivs": (np.full(10, 0), np.full(10, 0.3))}

# # regular spline
# n_e2_knots = 21
# def e2_func(rzp, vals):
#     xs = np.linspace(0, 1.0, n_e2_knots+1) ** 2
#     ys = jnp.concatenate((jnp.array([0.]), np.cumsum(vals)))
#     return InterpolatedUnivariateSpline(xs, ys, k=1)(rzp)


# e_params0[2] = {"vals": np.full(n_e2_knots, 0.4) / n_e2_knots}
# e_bounds[2] = {"vals": (np.full(n_e2_knots, 0), np.full(n_e2_knots, 0.3))}

# # -------------------------------------------------------------------------------------

# # poly:
# # def e4_func(rzp, f1, alpha, x0):
# #     return monotonic_poly_func_alt(rzp, f0=0.0, fx=f1, alpha=alpha, x0=x0, xval=1.0)


# # e_params0[4] = {"f1": -0.02, "alpha": 0.45, "x0": 3.0}
# # e_bounds[4] = {"f1": (-0.3, 0), "alpha": (0.2, 0.5), "x0": (2, 30.0)}

# # tanh:
# # def e4_func(rzp, f1, alpha, x0):
# #     return custom_tanh_func_alt(
# #         rzp, f_xval=f1, alpha=alpha, x0=x0, xval=1.0
# #     )
# # e_params0[4] = {"f1": -0.1, "alpha": 0.45, "x0": 1.0}
# # e_bounds[4] = {"f1": (-0.3, 0), "alpha": (0.2, 0.5), "x0": (1e-3, 100.0)}

# def e4_func(rzp, f1):
#     return rzp * f1


# e_params0[4] = {"f1": -0.1}
# e_bounds[4] = {"f1": (-0.3, 0)}

In [None]:
from empaf.model_helpers import monotonic_quadratic_spline

n_e2_knots = 11
n_e4_knots = 5


def _e_spline(n_knots, rzp, knot_vals):
    """Evaluate ``monotonic_quadratic_spline`` on ``n_knots`` knots uniform in sqrt(r_z').

    The knots span r_z' in [0, 1]. The first entry of the spline's ``y`` argument is
    fixed to 0; ``knot_vals`` (length ``n_knots - 1``) supplies the rest.
    NOTE(review): presumably empaf interprets y[1:] as derivatives at the knots, which
    would make non-negative ``knot_vals`` give a non-decreasing curve. Confirm.
    """
    knots = jnp.linspace(0, 1.0, n_knots) ** 2
    y = jnp.concatenate((jnp.array([0.0]), knot_vals))
    return monotonic_quadratic_spline(knots, y, rzp)


def e2_func(rzp, e2_vals):
    """Return e_2(r_z') from the ``n_e2_knots - 1`` spline coefficients ``e2_vals``."""
    return _e_spline(n_e2_knots, rzp, e2_vals)


def e4_func(rzp, e4_vals):
    """Return e_4(r_z'): the same spline construction as e_2, with the sign flipped."""
    return -_e_spline(n_e4_knots, rzp, e4_vals)


# Initial guesses and (lower, upper) box bounds for the e_m spline coefficients.
# Alternative flat e_2 start: {"e2_vals": np.full(n_e2_knots - 1, 0.2)}
e_params0 = {
    2: {"e2_vals": np.linspace(1.5, 0.2, n_e2_knots - 1) / 0.6 * 0.2},
    4: {"e4_vals": np.full(n_e4_knots - 1, 0.08)},
}
e_bounds = {
    2: {"e2_vals": (np.full(n_e2_knots - 1, 0), np.full(n_e2_knots - 1, 10))},
    4: {"e4_vals": (np.full(n_e4_knots - 1, 0), np.full(n_e4_knots - 1, 10))},
}

In [None]:
# Visual check of the initial e_2 and e_4 curves on r_z' in [0, 1].
grid = np.linspace(0, 1, 128)
e2_init = e2_func(grid, **e_params0[2])
e4_init = e4_func(grid, **e_params0[4])
plt.plot(grid, e2_init)
plt.plot(grid, e4_init)
# plt.plot(grid, e2_func(grid, derivs=np.array([1., 2., 0.4, 1., 1., 1., 1., 1.])))

In [None]:
# def ln_dens_func(rzp, ln_dens_vals):
#     # xs = np.geomspace(1e-3, 3., 5)
#     xs = np.linspace(0, 1.0, 7) ** 2
#     ys = jnp.cumsum(jnp.concatenate((ln_dens_vals[0:1], -ln_dens_vals[1:])))
#     return InterpolatedUnivariateSpline(xs, ys, k=2)(rzp)

# ln_dens_bounds = {
#     "ln_dens_vals": (
#         jnp.concatenate((jnp.array([-1]), jnp.full(6, 0.0))),
#         jnp.concatenate((jnp.array([15.0]), jnp.full(6, 5.0))),
#     )
# }

n_dens_knots = 15


def ln_dens_func(rz, ln_dens_vals):
    """Log-density as a function of r_z, built with ``monotonic_quadratic_spline``.

    Knots are uniform in sqrt(r_z) over [0, 1]. ``ln_dens_vals[0]`` is passed to the
    spline unchanged. The remaining ``n_dens_knots - 1`` entries are negated and
    cumulatively summed, so non-negative inputs give a non-increasing sequence.
    NOTE(review): presumably empaf treats spline y[1:] as knot derivatives. Confirm.
    """
    knots = np.linspace(0, 1.0, n_dens_knots) ** 2
    leading = ln_dens_vals[0:1]
    running = jnp.cumsum(-ln_dens_vals[1:])
    return monotonic_quadratic_spline(knots, jnp.concatenate((leading, running)), rz)


# (lower, upper) bounds for ln_dens_vals: the first entry in [-1, 20], the remaining
# increments in [0, 10].
_n_dens_incr = n_dens_knots - 1
ln_dens_bounds = {
    "ln_dens_vals": (
        jnp.concatenate((jnp.array([-1]), jnp.full(_n_dens_incr, 0.0))),
        jnp.concatenate((jnp.array([20.0]), jnp.full(_n_dens_incr, 10.0))),
    )
}

In [None]:
# grid = np.linspace(0, 1, 128)
# plt.plot(ln_dens_func(grid, np.array([9., 1., 0.4, 1., 4., 0.3, 1])))

In [None]:
# The e_m(r_z') functions, keyed by m, plus the ln-density spline; galactic units.
e_funcs = {2: e2_func, 4: e4_func}
model = DensityOrbitModel(e_funcs=e_funcs, ln_dens_func=ln_dens_func, units=galactic)

# init_model = DensityOrbitModel(
#     ln_dens_knots=jnp.linspace(0, 1.5, 15) ** 2,
#     e_knots={
#         2: jnp.linspace(0, 2, 9),
#         4: jnp.linspace(0, 2, 9),
#         # 6: jnp.linspace(0, np.sqrt(3), 5) ** 2,
#         # 8: jnp.linspace(0, np.sqrt(3), 5) ** 2,
#     },
#     e_signs={2: 1.0, 4: -1.0}, # , 6: -1.0, 8: -1.0},
#     e_k=3,
#     ln_dens_k=3,
#     units=galactic,
# )

In [None]:
# specify some initial values by hand:

# params0 = {}
# params0["ln_Omega"] = np.log(0.06)
# params0["z0"] = 0.0
# params0["vz0"] = 0.0

# params0["e_params"] = {m: {} for m in [2, 4]}

# # params0["e_params"][2]["A"] = 0.1
# params0["e_params"][2]["f1"] = 0.1
# params0["e_params"][2]["alpha"] = 0.33
# params0["e_params"][2]["x0"] = 3.0

# # params0["e_params"][4]["A"] = 0.05
# params0["e_params"][4]["f1"] = -0.02
# params0["e_params"][4]["alpha"] = 0.45
# params0["e_params"][4]["x0"] = 3.0

# # params0["ln_dens_params"] = {"f0": 8.0, "f1": -10.0, "alpha": 0.6, "x0": 5.0}
# params0["ln_dens_params"] = {"ln_dens_vals": np.concatenate(([np.log(200)], np.random.uniform(0, 1, size=14)))}

In [None]:
# or use the estimator:

# Estimate starting parameters from the data. The ln-density guess must have exactly
# n_dens_knots entries to match ln_dens_func; it used to be hard-coded to 15, which
# silently broke whenever n_dens_knots changed.
params0 = model.get_params_init(
    w0.z, w0.v_z, ln_dens_params0={"ln_dens_vals": np.zeros(n_dens_knots)}
)
# ln_dens_bounds requires the increments (entries 1:) to be non-negative, so force
# the estimated values to be non-negative before fitting.
params0["ln_dens_params"]["ln_dens_vals"] = jnp.abs(
    params0["ln_dens_params"]["ln_dens_vals"]
)
# Use the hand-specified e-function starting values.
params0["e_params"] = e_params0

In [None]:
# Validate the initial e-functions before fitting. Raise rather than `assert`, so the
# check survives `python -O` and the failure shows the checker's second return value.
check, check_info = model.check_e_funcs(params0["e_params"])
if not check:
    raise ValueError(f"Initial e_params failed model.check_e_funcs: {check_info!r}")

In [None]:
# Bin the (z, v_z) sample onto a regular 2D grid for fitting.
im_bins = dict(
    z=np.linspace(-3, 3, 211),
    vz=np.linspace(-0.12, 0.12, 211),
)
# Alternative: scale the v_z bins by the initial frequency guess.
# im_bins['vz'] = im_bins['z'] * np.exp(params0['ln_Omega'])
data = model.get_data_im(w0.z, w0.v_z, im_bins)

In [None]:
# Shared log colour scale for the data image and the initial model density.
vlim = dict(norm=mpl.colors.LogNorm(vmax=3e4, vmin=1e-1), shading="auto")

initial_density = np.exp(model.ln_density(z=data["z"], vz=data["vz"], params=params0))
panels = [("data", data["H"]), ("initial model", initial_density)]

fig, axes = plt.subplots(
    1, 2, figsize=(11, 5), sharex=True, sharey=True, constrained_layout=True
)
for ax, (title, image) in zip(axes, panels):
    cs = ax.pcolormesh(data["vz"], data["z"], image, **vlim)
    ax.set_title(title)
fig.colorbar(cs, ax=axes[:2])

In [None]:
# Box bounds for every fitted parameter. ln_Omega is bracketed by sqrt(4 pi G rho)
# for midplane densities rho between 0.01 and 2 Msun / pc^3.
_dens0 = [0.01, 2] * u.Msun / u.pc**3
bounds = {
    "ln_Omega": np.log(np.sqrt(_dens0 * 4 * np.pi * G).to_value(1 / u.Myr)),
    "z0": (-0.05, 0.05),
    "vz0": (-0.02, 0.02),
    "e_params": e_bounds,
    "ln_dens_params": ln_dens_bounds,
}

In [None]:
# fig, axes = plt.subplots(1, 2, figsize=(13, 5), sharex=True, sharey=True)

# bounds_l, bounds_r = model.unpack_bounds(bounds)
# for ax, p in zip(axes, [bounds_l, bounds_r]):
#     cs = ax.pcolormesh(
#         data_H["vz"],
#         data_H["z"],
#         np.exp(model.ln_density(z=data_H["z"], vz=data_H["vz"], params=p)),
#         norm=mpl.colors.LogNorm()
#     )
#     fig.colorbar(cs, ax=ax)

In [None]:
# model.objective(bounds_l, **data_H)
# model.objective(bounds_r, **data_H)

In [None]:
# Objective evaluated at the initial parameters, as a sanity check before optimizing.
model.objective(params0, **data)

In [None]:
# Inspect the initial parameter dictionary.
params0

In [None]:
# Fit the binned data, starting from params0 and constrained to `bounds`.
optimizer_kwargs = {"options": {"maxls": 1000, "disp": False}, "tol": 1e-8}
res = model.optimize(
    **data, params0=params0, bounds=bounds, jaxopt_kwargs=optimizer_kwargs
)
res.state

In [None]:
# Inspect the optimized parameters.
res.params

In [None]:
# Compare the data with the initial (params0) and fitted (res.params) models,
# presumably including residual panels (smoothed with smooth_residual=1.0).
fig, axes = plot_data_models_residual(
    data,
    model,
    params0,
    res.params,
    smooth_residual=1.0,
)

In [None]:
# model.compute_action_angle([0.]*u.kpc, [0]*u.km/u.s, res.params, 101)

In [None]:
# model.compute_action_angle(np.zeros(1000)*u.kpc, np.zeros(1000)*u.km/u.s, res.params, 101)

In [None]:
# Best-fit e_m curves on a fine r_z' grid over [0, 1].
plot_rz = np.linspace(0, 1, 301)
es = model.get_es(plot_rz, res.params["e_params"])
for curve in es.values():
    plt.plot(plot_rz, curve, marker="")

In [None]:
# JR = 0.
# Jphi = np.mean(w0.angular_momentum()[2]).decompose(galactic).value

# grid_Jzs = np.geomspace(1e-4, 0.5, 32)
# grid_Omega_zs = np.zeros_like(grid_Jzs)
# for i, Jz in enumerate(grid_Jzs):
#     act = np.array([JR, Jz, Jphi])
#     torus_mapper = agama.ActionMapper(agama_pot, act)

#     xv_torus = torus_mapper([0., 0., 0.])
#     *_, tmp_freq = act_finder(xv_torus, angles=True)
#     grid_Omega_zs[i] = tmp_freq[1]

In [None]:
# act = np.array([JR, 0., Jphi])
# torus_mapper = agama.ActionMapper(agama_pot, act)
# xv_torus = torus_mapper([0., 0., 0.])
# asym_freq = np.sqrt(gala_pot.density(xv_torus[:3])[0] * 4*np.pi * G).to(u.rad/u.Myr, u.dimensionless_angles())

In [None]:
# Fitted model vs. Agama along r_z: vertical frequency, e_m functions, and density.
pars = res.params

# Agama reference values for the selected particles. `agama_aaf` was only defined in a
# commented-out cell (agama.ActionFinder), so this cell raised NameError. sub_tbl
# already carries Agama's J_z / theta_z / Omega_z columns (see the final comparison
# cell), row-aligned with `xv`, so use it directly.
agama_aaf = sub_tbl

plot_rz = np.linspace(1e-3, 0.55, 101)
dens = np.exp(model.get_ln_dens(plot_rz, pars))
# z and r_z' evaluated point by point along the r_z grid.
tmp_z = np.array([model.get_z(rz, np.pi / 2, pars) for rz in plot_rz])
tmp_rzp = np.array([model.get_rz_prime(rz, 0.0, pars["e_params"]) for rz in plot_rz])

tmp_aaf = model.compute_action_angle(
    tmp_z * u.kpc, np.zeros_like(tmp_z) * u.km / u.s, pars, 101
)
es = model.get_es(tmp_rzp, pars["e_params"])

# r_z of every selected particle, inferred from its (z, v_z) under the fitted model.
# Distinct names so the grid-based tmp_rzp above is not silently overwritten.
particle_rzp, particle_tzp = model.z_vz_to_rz_theta_prime(xv[:, 2], xv[:, 5], pars)
agama_rz = model.get_rz(particle_rzp, particle_tzp, pars["e_params"])

fig, axes = plt.subplots(1, 3, figsize=(16, 5), constrained_layout=True, sharex=True)

ax = axes[0]
ax.plot(plot_rz, tmp_aaf["Omega_z"].value, zorder=100)
ax.plot(
    agama_rz,
    agama_aaf["Omega_z"].value,
    ls="none",
    marker="o",
    mew=0,
    alpha=0.2,
    ms=1.0,
    zorder=1,
)
ax.set_ylabel(r"$\Omega_z$ " + f"[{tmp_aaf['Omega_z'].unit:latex_inline}]")
# ax.axhline(np.exp(pars["ln_Omega"]), color="tab:orange", ls="--", alpha=0.4)
# ax.axhline(asym_freq.value, color="tab:green", ls="--", alpha=0.4)

for n in es:
    axes[1].plot(plot_rz, es[n], label=f"$e_{n}$")
axes[1].set_ylabel("$e_m(r_z')$")
axes[1].legend()

axes[2].plot(plot_rz, dens)
axes[2].set_yscale("log")
axes[2].set_ylabel("model density")

for ax in axes:
    ax.set_xlabel(r"$r_z$")

In [None]:
# plt.scatter(es[2], tmp_aaf['Omega_z'])

In [None]:
# Reference shape: tanh on [0, 3].
plot_rz = np.linspace(0, 3, 101)
tanh_curve = np.tanh(plot_rz)
plt.plot(plot_rz, tanh_curve)

In [None]:
# Best-fit e_m(r_z') curves. `get_es` takes the e-function parameters as its second
# argument (every other call in this notebook passes them); it was omitted here.
plot_rz = np.linspace(0, 1, 101)
es = model.get_es(plot_rz, res.params["e_params"])
for n, ee in es.items():
    plt.plot(plot_rz, ee, marker="", label=f"$e_{n}$")
plt.xlabel(r"$r_z'$")
plt.legend()

In [None]:
# Fitted model vs. Agama as a function of sqrt(J_z).
pars = res.params

# Agama reference values for the selected particles. `agama_aaf` was only defined in a
# commented-out cell (NameError); sub_tbl carries the same Agama J_z / Omega_z columns.
agama_aaf = sub_tbl

# Reference frequency sqrt(4 pi G rho) at the midplane of the circular orbit whose
# L_z matches the sample mean |L_z|. `asym_freq` was only defined in a commented-out
# cell that located that orbit with agama.ActionMapper (NameError). Here the orbit
# radius comes from gala's rotation curve instead; the frequency is defined as before.
_Lz_mean = np.abs(np.mean(w0.angular_momentum()[2])).decompose(galactic).value
_R_grid = np.linspace(0.5, 30.0, 1024)  # kpc
_xyz_grid = np.zeros((3, _R_grid.size))
_xyz_grid[0] = _R_grid
_Lz_grid = _R_grid * gala_pot.circular_velocity(_xyz_grid * u.kpc).decompose(galactic).value
# assumes R * v_c(R) increases monotonically over the grid, as np.interp requires
_R_circ = np.interp(_Lz_mean, _Lz_grid, _R_grid)
_rho_mid = gala_pot.density([_R_circ, 0.0, 0.0] * u.kpc)[0]
asym_freq = np.sqrt(_rho_mid * 4 * np.pi * G).to(u.rad / u.Myr, u.dimensionless_angles())

sqrtOm = np.sqrt(np.exp(pars["ln_Omega"]))
plot_rzp = np.linspace(0, 3, 101) * sqrtOm
es = model.get_es(plot_rzp, pars["e_params"])

tmp_rz = model.get_rz(plot_rzp, np.zeros_like(plot_rzp), pars["e_params"])
dens = np.exp(model.get_ln_dens(tmp_rz, pars))
tmp_aaf = model.compute_action_angle(
    plot_rzp / sqrtOm * u.kpc, np.zeros_like(plot_rzp) * u.km / u.s, pars, 11
)
sqrtJz = np.sqrt(tmp_aaf["J_z"].value)

fig, axes = plt.subplots(1, 3, figsize=(16, 5), constrained_layout=True, sharex=True)

ax = axes[0]
ax.plot(sqrtJz, tmp_aaf["Omega_z"].value, zorder=100)
ax.plot(
    np.sqrt(agama_aaf["J_z"].value),
    agama_aaf["Omega_z"].value,
    ls="none",
    marker="o",
    mew=0,
    alpha=0.2,
    ms=1.0,
    zorder=1,
)
ax.set_ylabel(r"$\Omega_z$ " + f"[{tmp_aaf['Omega_z'].unit:latex_inline}]")
ax.axhline(np.exp(pars["ln_Omega"]), color="tab:orange", ls="--", alpha=0.4)
ax.axhline(asym_freq.value, color="tab:green", ls="--", alpha=0.4)

for n in es:
    axes[1].plot(sqrtJz, es[n], label=f"$e_{n}$")
axes[1].set_ylabel("$e_m$")
axes[2].plot(sqrtJz, dens)
axes[2].set_yscale("log")
axes[2].set_ylabel("model density")

axes[1].legend()

for ax in axes:
    ax.set_xlabel(r"$\sqrt{J_z}'$")

In [None]:
# Actions/angles and r_z on the data-image grid under the best-fit model.
# `data_H` was never defined in this notebook (NameError); the binned image returned
# by model.get_data_im is `data`.
grid_aaf = model.compute_action_angle(
    data["z"].ravel() * model.units["length"],
    data["vz"].ravel() * model.units["length"] / model.units["time"],
    params=res.params,
    N_grid=11,
)

_rzp, _tzp = model.z_vz_to_rz_theta_prime(
    data["z"].ravel(), data["vz"].ravel(), res.params
)
grid_rz = model.get_rz(_rzp, _tzp, res.params["e_params"])

In [None]:
# Contours of the model r_z and of sqrt(J_z) over the binned data image.
# Uses `data` (the image from model.get_data_im); the undefined `data_H` raised
# NameError.
fig, axes = plt.subplots(
    1, 2, figsize=(10, 5), sharex=True, sharey=True, constrained_layout=True
)

for ax in axes:
    ax.pcolormesh(
        data["vz"], data["z"], data["H"], cmap="Blues", norm=mpl.colors.LogNorm()
    )

contour_panels = [
    (axes[0], grid_rz, "$r_z$"),
    (axes[1], np.sqrt(grid_aaf["J_z"].value), r"$\sqrt{J_z}$"),
]
for ax, values, title in contour_panels:
    cs = ax.contour(
        data["vz"],
        data["z"],
        values.reshape(data["z"].shape),
        colors="k",
        levels=np.linspace(0, 0.5, 11),
    )
    ax.clabel(cs, cs.levels, inline=True, fontsize=10)
    ax.set_title(title)

Zoom in on the region near the origin of the (z, v_z) plane (small |z| and |v_z|):

In [None]:
# Finer (z, v_z) grid around the origin for a zoomed-in comparison
# (z in kpc, v_z in kpc/Myr).
zgrid, vzgrid = np.meshgrid(
    np.linspace(-0.2, 0.2, 128), np.linspace(-0.014, 0.014, 128)
)
z_flat = zgrid.ravel()
vz_flat = vzgrid.ravel()

zoom_grid_aaf = model.compute_action_angle(
    z_flat * u.kpc,
    vz_flat * u.kpc / u.Myr,
    params=res.params,
    N_grid=21,
)

_rzp, _tzp = model.z_vz_to_rz_theta_prime(z_flat, vz_flat, res.params)
zoom_grid_rz = model.get_rz(_rzp, _tzp, res.params["e_params"])

In [None]:
# Zoomed-in contours of r_z and sqrt(J_z) over the binned data image.
# Uses `data` (the image from model.get_data_im); the undefined `data_H` raised
# NameError.
fig, axes = plt.subplots(
    1, 2, figsize=(10, 5), sharex=True, sharey=True, constrained_layout=True
)

for ax in axes:
    ax.pcolormesh(
        data["vz"],
        data["z"],
        data["H"],
        cmap="Blues",
        norm=mpl.colors.LogNorm(),
        zorder=-100,
    )

cs = axes[0].contour(
    vzgrid,
    zgrid,
    zoom_grid_rz.reshape(zgrid.shape),
    colors="k",
    levels=np.linspace(0, 0.07, 15),
    zorder=100,
)
axes[0].clabel(cs, cs.levels, inline=True, fontsize=10)
axes[0].set_title("$r_z$")

cs = axes[1].contour(
    vzgrid,
    zgrid,
    np.sqrt(zoom_grid_aaf["J_z"].value).reshape(zgrid.shape),
    colors="k",
    levels=np.linspace(0, 0.07, 15),
    zorder=100,
)
axes[1].clabel(cs, cs.levels, inline=True, fontsize=10, zorder=100)
axes[1].set_title(r"$\sqrt{J_z}$")

# Limit the view to the zoom grid (axes share x and y).
axes[0].set_xlim(vzgrid.min(), vzgrid.max())
axes[0].set_ylim(zgrid.min(), zgrid.max())

In [None]:
# plot_rz = model.get_rz(plot_rzp, np.zeros_like(plot_rzp), res.params['e_params'])
# # plt.plot(plot_rzp, np.sqrt(tmp_aaf['J_z']))
# plt.plot(plot_rz, np.sqrt(tmp_aaf['J_z']))
# plt.xlabel('$r_z$')
# plt.ylabel(r'$\sqrt{J_z}$')

In [None]:
# plt.plot(np.sqrt(grid_Jzs), grid_Omega_zs)
# plt.plot(sqrtJz, tmp_aaf["Omega_z"].value)

In [None]:
# model_H = np.exp(model.ln_density(z=data_H["z"], vz=data_H["vz"], params=res.params))

In [None]:
# H, xe = np.histogram(np.sqrt(agama_aaf['J_z'].value), bins=np.linspace(0, 0.5, 101))
# xc = 0.5 * (xe[:-1] + xe[1:])

# huh = np.diff(data_H['vz'][0])[0] * np.diff(data_H['z'][:, 0])[0]

# plt.plot(xc, H / np.diff(xe) * huh / np.sqrt(xc))
# plt.yscale('log')

In [None]:
# thp = np.linspace(0, 2*np.pi, 256)
# for rzp in np.linspace(0, 3.0, 16):
#     plt.plot(thp, model.get_rz(rzp, thp), marker='')

In [None]:
# Model actions/angles/frequencies for the first 100,000 selected particles.
n_compare = 100_000
z_sample = w0.z.astype(np.float64)[:n_compare]
vz_sample = w0.v_z.astype(np.float64)[:n_compare]
model_aaf = model.compute_action_angle(z_sample, vz_sample, res.params, 21)
model_aaf[:3]

In [None]:
# One-to-one comparison of Agama and empaf J_z, Omega_z, and cos(theta_z) for the
# particles processed above (the first len(model_aaf) rows of sub_tbl).
n_model = len(model_aaf)
agamas = [
    sub_tbl["J_z"][:n_model].value,
    sub_tbl["Omega_z"][:n_model].value,
    np.cos(sub_tbl["theta_z"][:n_model]),
]
models = [
    model_aaf["J_z"].value,
    model_aaf["Omega_z"].value,
    np.cos(model_aaf["theta_z"]),
]
labels = ["$J_z$", r"$\Omega_z$", r"$\cos\theta_z$"]
lims = [(0, 0.15), (0, 0.1), (-1, 1)]

fig, axes = plt.subplots(1, 3, figsize=(16, 5), constrained_layout=True)
for ax, lim, agama_vals, model_vals, label in zip(axes, lims, agamas, models, labels):
    lo, hi = lim
    ax.hist2d(
        agama_vals,
        model_vals,
        bins=np.linspace(lo, hi, 128),
        cmap="Greys",
        norm=mpl.colors.LogNorm(vmin=5e-1),
    )

    # Dashed 1:1 reference line.
    diag = np.linspace(lo, hi, 10)
    ax.plot(diag, diag, marker="", color="tab:green", ls="--", alpha=0.3)
    ax.set_xlim(lo, hi)
    ax.set_ylim(lo, hi)

    ax.set_xlabel(f"Agama {label}")
    ax.set_ylabel(f"empaf {label}")