TODO:
- compare inferred acceleration trends - which selection recovers local acceleration the best?
- also do different spatial/velocity/R selection and rerun

In [None]:
import copy
import os

from astropy.constants import G
import astropy.table as at
import astropy.coordinates as coord
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt

%matplotlib inline
import numpy as np

# gala
import gala.coordinates as gc
import gala.dynamics as gd
import gala.potential as gp
import gala.integrate as gi
from gala.units import galactic

import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax_cosmo.scipy.interpolate import InterpolatedUnivariateSpline

from empaf import DensityOrbitModel
from empaf.plot import plot_data_models_residual
from empaf.model_helpers import generalized_logistic_func_alt
from empaf.model_helpers import monotonic_quadratic_spline

# Load test data

Load some particle data generated in an equilibrium galaxy model:

In [None]:
# Full particle catalog generated in the equilibrium galaxy model. The
# commented line reads an older version of the test data.
# particle_data = at.QTable.read('../test-data/agama-galaxymodel-particles.fits')
tbl = at.QTable.read(
    "../test-data/agama-galaxymodel-particles-qIso.fits"
)

In [None]:
# Select particles near the reference radius R0. Two cuts are applied:
#   - angular momentum within dR_tol * v_circ0 of Jphi0 = v_circ0 * R0;
#   - cylindrical radius within dR_tol of R0.
# The constants are named once so both cuts stay consistent.
v_circ0 = 229 * u.km / u.s
R0 = 8.3 * u.kpc
Jphi0 = v_circ0 * R0
dR_tol = 0.5 * u.kpc

R = np.sqrt(tbl["xyz"][:, 0] ** 2 + tbl["xyz"][:, 1] ** 2)
# Cylindrical radial velocity; only used by the (currently disabled) v_R cut.
v_R = (
    tbl["xyz"][:, 0] * tbl["v_xyz"][:, 0] + tbl["xyz"][:, 1] * tbl["v_xyz"][:, 1]
) / R
mask = (
    (np.abs(tbl["J_phi"] - Jphi0) < (dR_tol * v_circ0))
    & (np.abs(R - R0) < dR_tol)
    #     & (np.abs(v_R) < 15*u.km/u.s)
)
print(mask.sum())

particle_data = tbl[mask]

In [None]:
# Histogram the selected particles in vertical phase space (z, v_z). Both
# coordinates are converted to plain floats in the `galactic` unit system.
z_bins = np.linspace(-2.5, 2.5, 155)
vz_bins = np.linspace(-0.1, 0.1, 155)
data = DensityOrbitModel.get_data_im(
    z=particle_data["xyz"][:, 2].decompose(galactic).value,
    vz=particle_data["v_xyz"][:, 2].decompose(galactic).value,
    bins={"z": z_bins, "vz": vz_bins},
)

In [None]:
# Quick look at the binned (z, v_z) histogram on a logarithmic color scale.
log_norm = mpl.colors.LogNorm()
plt.pcolormesh(data["vz"], data["z"], data["H"], cmap="magma", norm=log_norm)
plt.xlabel("$v_z$")
plt.ylabel("$z$")

In [None]:
# Upper edge of the r_z (and r_z') range used for spline knots and plot grids.
max_rz = 0.75

n_dens_knots = 19
# Knot locations for the log-density spline, spaced linearly in r_z on
# [0, max_rz]. (The trailing "** 2" comment is a leftover alternative spacing.)
ln_dens_knots = jnp.linspace(0, max_rz, n_dens_knots)  #  ** 2

# def ln_dens_func(rz, ln_dens_vals):
#     spl = InterpolatedUnivariateSpline(ln_dens_knots, ln_dens_vals, k=2)
#     return spl(rz)


def ln_dens_func(rz, ln_dens_vals):
    """Return ln n(r_z) from a monotonic quadratic spline on ``ln_dens_knots``.

    ``ln_dens_vals`` are the spline parameters passed straight to
    ``monotonic_quadratic_spline``. Judging by how ``ln_dens_vals0`` and
    ``ln_dens_bounds`` are built, they are presumably the value at the first
    knot followed by slopes at the remaining knots; confirm against
    ``empaf.model_helpers``.
    """
    return monotonic_quadratic_spline(ln_dens_knots, ln_dens_vals, rz)

In [None]:
# ln_dens_bounds = {
#     "ln_dens_vals": (
#         jnp.full(n_dens_knots, -5.0),
#         jnp.full(n_dens_knots, 25.0)
#     )
# }

# Box bounds on the log-density spline parameters:
#   - first entry in [0, 12];
#   - remaining n_dens_knots - 1 entries in [-30, 0].
_n_rest = n_dens_knots - 1
ln_dens_bounds = {
    "ln_dens_vals": (
        np.array([0.0] + [-30.0] * _n_rest),
        np.array([12.0] + [0.0] * _n_rest),
    )
}

In [None]:
# Number of spline knots for each e_m(r_z') function. The m=6 term
# deliberately reuses the m=4 knot count; it is named explicitly so that it can
# be changed independently.
n_e2_knots = 9
n_e4_knots = 5
n_e6_knots = n_e4_knots

# Knot locations, spaced equally in sqrt(r_z') over [0, max_rz].
_sqrt_knot_max = np.sqrt(max_rz)
e2_knots = jnp.linspace(0, _sqrt_knot_max, n_e2_knots) ** 2
e4_knots = jnp.linspace(0, _sqrt_knot_max, n_e4_knots) ** 2
e6_knots = jnp.linspace(0, _sqrt_knot_max, n_e6_knots) ** 2


def _zero_anchored_spline(knots, free_vals, rzp):
    """Evaluate a monotonic quadratic spline whose first parameter is fixed to 0.

    ``free_vals`` holds the remaining ``len(knots) - 1`` spline parameters. A
    leading 0 is prepended before calling ``monotonic_quadratic_spline``;
    presumably this pins e_m(r_z'=0) = 0 (confirm against empaf).
    """
    all_vals = jnp.concatenate((jnp.array([0.0]), free_vals))
    return monotonic_quadratic_spline(knots, all_vals, rzp)


def e2_func(rzp, e2_vals):
    """Return e_2(r_z'): zero-anchored spline on ``e2_knots``."""
    return _zero_anchored_spline(e2_knots, e2_vals, rzp)


def e4_func(rzp, e4_vals):
    """Return e_4(r_z'): zero-anchored spline on ``e4_knots``, negated."""
    return -_zero_anchored_spline(e4_knots, e4_vals, rzp)


def e6_func(rzp, e6_vals):
    """Return e_6(r_z'): zero-anchored spline on ``e6_knots``."""
    return _zero_anchored_spline(e6_knots, e6_vals, rzp)

In [None]:
# Initial values and box bounds for the e_m spline parameters, keyed by m.
# Each e_m spline gets (number of knots - 1) free values because its first
# value is fixed to 0 inside the e_m function. The m=6 term is disabled.
n_e2_free = n_e2_knots - 1
n_e4_free = n_e4_knots - 1

e_params0 = {
    # e2 starts as a linearly decreasing ramp (1.5 -> 0.2), rescaled by 0.2/0.6
    2: {"e2_vals": np.linspace(1.5, 0.2, n_e2_free) / 0.6 * 0.2},
    4: {"e4_vals": np.full(n_e4_free, 0.08)},
    # 6: {"e6_vals": np.full(n_e4_knots - 1, 0.03)},
}
e_bounds = {
    2: {"e2_vals": (np.full(n_e2_free, 0), np.full(n_e2_free, 10))},
    4: {"e4_vals": (np.full(n_e4_free, 0), np.full(n_e4_free, 10))},
    # 6: {"e6_vals": (np.full(n_e4_knots-1, 0), np.full(n_e4_knots-1, 10))},
}

In [None]:
# Visual check of the initial e_m(r_z') curves before fitting.
rzp_grid = np.linspace(0, max_rz, 128)
for m, e_func in ((2, e2_func), (4, e4_func)):
    plt.plot(rzp_grid, e_func(rzp_grid, **e_params0[m]))
# plt.plot(rzp_grid, e6_func(rzp_grid, **e_params0[6]))
plt.xlabel("$r_z'$")
plt.ylabel("initial $e_m(r_z')$")

In [None]:
# Density model built from the log-density spline and the e_m functions.
# Keys of the e_m mapping are the m values; m=6 is currently disabled.
orbit_e_funcs = {2: e2_func, 4: e4_func}  # , 6: e6_func}
model = DensityOrbitModel(
    ln_dens_func=ln_dens_func, e_funcs=orbit_e_funcs, units=galactic
)

In [None]:
# Initialize the log-density parameters from an empirical (k=1) spline of the
# particle data. The first entry is the spline value at xx[0]; the remaining
# entries are the spline's slopes at the density knots after the first.
xx, yy, ln_dens_spl = model.get_data_ln_dens_func(
    particle_data["xyz"][:, 2], particle_data["v_xyz"][:, 2], spl_k=1
)
spl_y0 = ln_dens_spl(xx[0])
spl_slopes = ln_dens_spl.derivative()(ln_dens_knots[1:])
ln_dens_vals0 = np.concatenate(([spl_y0], spl_slopes))

params0 = model.get_params_init(
    particle_data["xyz"][:, 2],
    particle_data["v_xyz"][:, 2],
    ln_dens_params0=dict(ln_dens_vals=ln_dens_vals0),
)
params0

In [None]:
# Initial log-density profile ln n(r_z) implied by params0.
plot_rz = np.linspace(0, max_rz, 128)
ln_dens_init = model.get_ln_dens(plot_rz, params0)

fig, ax = plt.subplots()
ax.plot(plot_rz, ln_dens_init)
ax.set_xlim(plot_rz[0], plot_rz[-1])
ax.set_xlabel("$r_z$", labelpad=20)
ax.set_ylabel(r"$\ln n(r_z)$")

In [None]:
# Attach the initial e_m parameters. Note: this stores the e_params0 dict
# object itself (not a copy) in params0.
params0["e_params"] = e_params0

In [None]:
# Data histogram vs. initial model density, sharing one log color scale.
vlim = dict(norm=mpl.colors.LogNorm(vmax=3e4, vmin=1e-1), shading="auto")

fig, axes = plt.subplots(
    1, 2, figsize=(11, 5), sharex=True, sharey=True, constrained_layout=True
)

init_model_dens = np.exp(
    model.ln_density(z=data["z"], vz=data["vz"], params=params0)
)
panels = [("data", data["H"]), ("initial model", init_model_dens)]
for ax, (_, image) in zip(axes, panels):
    cs = ax.pcolormesh(data["vz"], data["z"], image, **vlim)
fig.colorbar(cs, ax=axes[:2])

for ax, (panel_title, _) in zip(axes, panels):
    ax.set_title(panel_title)

In [None]:
# Box bounds for the density-model fit. The ln_Omega range follows from
# Omega = sqrt(4 pi G rho) for rho between 0.01 and 2 Msun / pc^3.
_dens0 = [0.01, 2] * u.Msun / u.pc**3
bounds = {
    "ln_Omega": np.log(np.sqrt(_dens0 * 4 * np.pi * G).to_value(1 / u.Myr)),
    "z0": (-0.05, 0.05),
    "vz0": (-0.02, 0.02),
    "e_params": e_bounds,
    "ln_dens_params": ln_dens_bounds,
}

In [None]:
# Objective value at the initial parameters (sanity check before optimizing).
model.objective(params0, data["z"], data["vz"], data["H"])

In [None]:
# Fit the density model. Every entry of `data` is forwarded as a keyword
# argument alongside the initial parameters and bounds.
fit_kwargs = dict(params0=params0, bounds=bounds, jaxopt_kwargs=dict(tol=1e-10))
res = model.optimize(**fit_kwargs, **data)
res.state

In [None]:
# Short alias for the fitted parameters, used by later cells.
pars = res.params
pars

In [None]:
# Check the fitted e_m functions over the full r_z' range spanned by their
# spline knots, [0, max_rz], instead of a hard-coded 0.5 that skipped the
# outer knots.
model.check_e_funcs(res.params["e_params"], rz_prime_max=max_rz)[0]

In [None]:
# Plot the data against the initial (params0) and fitted (res.params) models
# with empaf's residual plotting helper.
fig, axes = plot_data_models_residual(data, model, params0, res.params)

In [None]:
plot_rz = np.linspace(0, max_rz, 301)

fig, axes = plt.subplots(1, 2, figsize=(12, 6), sharex=True, constrained_layout=True)

# --- Left: fitted e_m(r_z') curves, with their spline knots marked ---
ax = axes[0]
es = model.get_es(plot_rz, res.params["e_params"])
for m, ee in es.items():
    ax.plot(plot_rz, ee, marker="", label=f"$m={m}$")
# Build the legend once, after all curves are drawn.
ax.legend(fontsize=16)

for m, knots in {2: e2_knots, 4: e4_knots}.items():
    ax.scatter(knots, model.get_es(knots, res.params["e_params"])[m])

ax.set_ylabel("$e_m(r_z')$ for $m=2,4$")

# --- Right: fitted log-density profile, with its spline knots marked ---
ax = axes[1]
ax.plot(plot_rz, model.get_ln_dens(plot_rz, res.params), marker="")
ax.scatter(ln_dens_knots, model.get_ln_dens(ln_dens_knots, res.params))
ax.set_xlim(plot_rz.min(), plot_rz.max())
ax.set_xlabel("$r_z$", labelpad=20)
ax.set_ylabel(r"$\ln n(r_z)$")

# Top axis: height z at v_z = 0 for the largest plotted r_z.
# NOTE(review): twiny maps r_z -> z linearly between the endpoints. If z(r_z)
# is nonlinear, the intermediate ticks are only approximate.
ax2 = ax.twiny()
ax2.set_xlim(0, model.get_z(plot_rz[-1], np.pi / 2, res.params))
ax2.set_xlabel("$z$ (at $v_z=0$) [kpc]", labelpad=20)

In [None]:
# # Compute model predicted density:
# plot_rz = np.linspace(1e-3, 0.55, 101)
# model_dens = np.exp(model.get_ln_dens(plot_rz, pars))

# # Compute rz values at image pixel locations:
# tmp_rzp, tmp_tzp = model.z_vz_to_rz_theta_prime(
#     data["z"].astype(np.float64), data["vz"].astype(np.float64), pars
# )
# im_rz = model.get_rz(tmp_rzp, tmp_tzp, pars["e_params"])

# # Compute model implicit Omega_z vs. r_z function:
# tmp_z = np.array(
#     [model.get_z(plot_rz[n], np.pi / 2, pars) for n in range(len(plot_rz))]
# )
# tmp_rzp = np.array(
#     [model.get_rz_prime(plot_rz[n], 0.0, pars["e_params"]) for n in range(len(plot_rz))]
# )

# tmp_aaf = model.compute_action_angle(
#     tmp_z * u.kpc, np.zeros_like(tmp_z) * u.km / u.s, pars, 101
# )
# model_Omega_z = tmp_aaf['Omega_z']
# model_J_z = tmp_aaf['J_z']

# # Compute Omega_z at image pixel locations:
# tmp_aaf = model.compute_action_angle(
#     data["z"].ravel() * u.kpc, data["vz"].ravel() * u.kpc / u.Myr, pars, 25
# )
# im_Omega_z = tmp_aaf['Omega_z']

Compare acceleration with truth:

In [None]:
# Load the gala potential stored with the test data. It is used as the "truth"
# for the acceleration comparisons below (presumably the potential that
# generated the particles; confirm).
pot = gp.load("../test-data/agama-galaxymodel-gala_pot.yml")

In [None]:
# ztmp, vztmp = np.meshgrid(np.linspace(0, 2, 64), np.linspace(0, 0.1, 64))
# empaf_az = model.get_az(ztmp * u.kpc, vztmp * u.kpc / u.Myr, res.params).to_value(
#     u.pc / u.Myr**2
# )

# empaf's inferred vertical acceleration a_z on 64 heights z in [0, 3] kpc,
# expressed in pc / Myr^2.
ztmp = np.linspace(0, 3, 64)
empaf_az = model.get_az(ztmp * u.kpc, res.params).to_value(u.pc / u.Myr**2)

In [None]:
# True a_z from the loaded potential on the same heights. Positions are taken
# at fixed cylindrical R = R0 on the x-axis (y = 0). They are passed as plain
# floats, which gala interprets in the potential's unit system (presumably
# kpc; confirm).
xyz = np.zeros((3,) + ztmp.shape)
xyz[0] = R0.to_value(u.kpc)
title = "constant cylindrical R"

# Alternative: hold the spherical radius fixed instead.
# xyz[0] = np.sqrt(R0**2 - (ztmp*u.kpc)**2).to_value(u.kpc)
# title = "constant spherical R"

xyz[2] = ztmp
acc = pot.acceleration(xyz)
true_az = acc[2].to_value(u.pc / u.Myr**2)

In [None]:
# fig, axes = plt.subplots(
#     1, 3, figsize=(20, 6.1), sharex=True, sharey=True, constrained_layout=True
# )

# true_max = np.abs(empaf_az).max()
# levels = -np.sqrt(np.linspace(0, true_max**2, 32))[::-1]
# cs0 = axes[0].contour(vztmp, ztmp, empaf_az, levels=levels)
# cs1 = axes[1].contour(vztmp, ztmp, true_az, levels=levels)
# cs1.set_clim(cs0.get_clim())
# cb1 = fig.colorbar(cs0, ax=axes[:2])
# cb1.set_label(f"constant $a_z$ [{u.pc/u.Myr**2:latex_inline}]")

# cs2 = axes[2].pcolormesh(
#     vztmp,
#     ztmp,
#     (empaf_az - true_az),
#     cmap="RdBu",
#     vmin=-true_max / 10,
#     vmax=true_max / 10,
# )
# cb2 = fig.colorbar(cs2, ax=axes[2])
# cb2.set_label(f"residual [{u.pc/u.Myr**2:latex_inline}]")

# axes[0].set_title("empaf")
# axes[1].set_title("true potential model")
# axes[2].set_title("residual")

# for ax in axes:
#     ax.set_xlabel(f"$v_z$ [{u.kpc/u.Myr:latex_inline}]")
# axes[0].set_ylabel(f"$z$ [{u.kpc:latex_inline}]")

# fig.suptitle(title, fontsize=26)


# Compare empaf's a_z(z) with the true potential's a_z on the same heights,
# and show their difference. Both arrays are in pc / Myr^2.
az_unit = u.pc / u.Myr**2
fig, axes = plt.subplots(1, 2, figsize=(13, 6.1), sharex=True, constrained_layout=True)

axes[0].plot(ztmp, empaf_az, label="empaf")
axes[0].plot(ztmp, true_az, label="true")
axes[0].set_ylim(-5, 0.2)
axes[0].legend()
axes[0].set_ylabel(f"$a_z$ [{az_unit:latex_inline}]")

axes[1].plot(ztmp, (empaf_az - true_az))
axes[1].set_ylim(-0.5, 0.5)
axes[1].set_title("residual")
axes[1].set_ylabel(f"empaf $-$ true $a_z$ [{az_unit:latex_inline}]")

for ax in axes:
    ax.set_xlabel(f"$z$ [{u.kpc:latex_inline}]")

fig.suptitle(title, fontsize=26)

In [None]:
# fig, axes = plt.subplots(
#     1, 3, figsize=(20, 6.1), sharex=True, sharey=True, constrained_layout=True
# )

# true_max = np.abs(empaf_az).max()
# levels = -np.sqrt(np.linspace(0, true_max**2, 32))[::-1]
# cs0 = axes[0].contour(vztmp, ztmp, empaf_az, levels=levels)
# cs1 = axes[1].contour(vztmp, ztmp, true_az, levels=levels)
# cs1.set_clim(cs0.get_clim())
# cb1 = fig.colorbar(cs0, ax=axes[:2])
# cb1.set_label(f"constant $a_z$ [{u.pc/u.Myr**2:latex_inline}]")

# cs2 = axes[2].pcolormesh(
#     vztmp,
#     ztmp,
#     (empaf_az - true_az) / true_az,
#     cmap="RdBu",
#     vmin=-0.25,
#     vmax=0.25,
# )
# cb2 = fig.colorbar(cs2, ax=axes[2])
# cb2.set_label(f"fractional residual [{u.pc/u.Myr**2:latex_inline}]")

# axes[0].set_title("empaf")
# axes[1].set_title("true potential model")
# axes[2].set_title("fractional residual")

# for ax in axes:
#     ax.set_xlabel(f"$v_z$ [{u.kpc/u.Myr:latex_inline}]")
# axes[0].set_ylabel(f"$z$ [{u.kpc:latex_inline}]")

# fig.suptitle(title, fontsize=26)

In [None]:
# fig, axes = plt.subplots(1, 2, figsize=(12, 5.5), sharex=True)

# axes[0].plot(ztmp[i, :], true_az[i], color="k", lw=2, marker="")

# ax = axes[1]
# for i in range(1, 31+1, 5):
#     tmp_vz = vztmp[i, 0] * u.kpc / u.Myr
#     ax.plot(ztmp[i, :], (empaf_az[i] - true_az[i]) / true_az[i], marker="", alpha=0.5)
#     axes[0].plot(
#         ztmp[i, :],
#         empaf_az[i],
#         marker="",
#         alpha=0.5,
#         label=f"$v_z = {tmp_vz.to_value(u.kpc/u.Myr):.2f}$ {u.kpc/u.Myr:latex_inline}",
#     )

# ax.axhline(0, zorder=-10, color="k", lw=2)

# axes[0].set(
#     xlabel="$z$ [kpc]", ylabel="$a_z$"
# )
# axes[1].set(
#     ylim=(-0.2, 0.2), xlim=(-0.1, 2), xlabel="$z$ [kpc]", ylabel="fractional force error"
# )

# axes[0].legend(loc="best")

## Grid for Larry:

In [None]:
# pot = gp.load("../test-data/agama-galaxymodel-gala_pot.yml")

# Rgrid = np.arange(0, 16 + 1e-3, 0.05) * u.kpc
# zgrid = np.arange(0, 4 + 1e-3, 0.05) * u.kpc
# Rgrid, zgrid = np.meshgrid(Rgrid, zgrid)
# xyz = np.zeros((3,) + Rgrid.shape) * u.kpc
# xyz[0] = Rgrid
# xyz[2] = zgrid

# Phi = pot.energy(xyz)
# acc = pot.acceleration(xyz)

# tbl = at.QTable(
#     {
#         "R": Rgrid.ravel(),
#         "z": zgrid.ravel(),
#         "potential": Phi.ravel(),
#         "a_R": acc[0].ravel(),
#         "a_z": acc[2].ravel(),
#     }
# )
# tbl.write('../test-data/agama-galaxymodel-pot-grid.fits', overwrite=True)

In [None]:
# Scratch spot-check of vz_func / dvz_dz_func at (0.5, 0.5) with the fitted
# parameters.
# NOTE(review): vz_func and dvz_dz_func are not defined in any visible cell;
# this raises NameError unless they are provided elsewhere. Confirm.
print(vz_func(np.array([0.5]), np.array([0.5]), res.params))
print(dvz_dz_func(np.array([0.5]), np.array([0.5]), res.params))

In [None]:
# Evaluate tmp_get_vz at z=0.5, rz=0.5, and its gradient with respect to its
# first positional argument (jax.grad's default argnums=0, i.e. z, given how
# kw is ordered).
# NOTE(review): tmp_get_vz is not defined in any visible cell. Confirm.
kw = dict(z=0.5, rz=0.5, params=res.params)
vz = tmp_get_vz(**kw)
dvz_dz = jax.grad(tmp_get_vz)(*kw.values())

In [None]:
# v_z * dv_z/dz = d(v_z^2 / 2)/dz; compare with the potential's a_z below.
vz * dvz_dz

In [None]:
# True a_z at R = R0 (on the x-axis) and height args[0].
# NOTE(review): `args` is not defined in any visible cell; presumably it is
# meant to be the z used above (kw["z"]). Confirm before running.
xyz = np.zeros(3)
xyz[0] = R0.to_value(u.kpc)
xyz[2] = args[0]
pot.acceleration(xyz)[2, 0]

In [None]:
# Evaluate help_rootfind for each trial v_z in vzs, with the other arguments
# held fixed.
# NOTE(review): help_rootfind is not defined in any visible cell. Confirm.
vzs = np.linspace(0, 0.3, 128)
test = np.array([help_rootfind(vvz, 0.5, 0.5, res.params) for vvz in vzs])

In [None]:
# Plot the help_rootfind scan against the trial v_z values.
plt.plot(vzs, test)

Collisionless Boltzmann equation (CBE) calculation:

In [None]:
# Draw (z, v_z) points uniformly over a box in phase space. A subset is then
# kept by rejection sampling against the fitted model density (used for the
# histogram below).
Nsamples = 100_000
rng = np.random.default_rng(seed=42)
zgrid = rng.uniform(-2, 2, size=Nsamples)
vzgrid = rng.uniform(-0.08, 0.08, size=Nsamples)

# Rejection sampling: accept a point when u * n_max < n(z, v_z) with
# u ~ U(0, 1), i.e. ln(u) + ln(n_max) < ln n(z, v_z). ln(u) is drawn as -E with
# E ~ Exp(1). Drawing thresholds uniformly in ln n instead would accept points
# with probability linear in ln n, not proportional to n.
# The peak is assumed to be at (0, 0).
# NOTE(review): if the fitted z0/vz0 offsets move the peak elsewhere, points
# above n(0, 0) are slightly undersampled.
lndensmax = model.ln_density(0.0, 0.0, pars)
lndensgrid = np.asarray(lndensmax) - rng.exponential(size=Nsamples)
lndens = model.ln_density(zgrid, vzgrid, pars)

In [None]:
# Accepted samples. A distinct name is used so the particle-selection `mask`
# defined at the top of the notebook is not overwritten.
sample_mask = lndensgrid < lndens
print(sample_mask.sum())

plt.hist2d(vzgrid[sample_mask], zgrid[sample_mask], bins=64);  # , norm=mpl.colors.LogNorm());

In [None]:
def func(z, vz):
    """Return the fitted model's r_z at the phase-space point (z, v_z).

    (z, v_z) is mapped to the primed coordinates (r_z', theta_z'), then through
    the fitted e_m functions to r_z, using the global ``pars``.
    """
    rz_prime, theta_prime = model.z_vz_to_rz_theta_prime(z, vz, pars)
    return model.get_rz(rz_prime, theta_prime, pars["e_params"])


# Element-wise (vmapped) partial derivatives of r_z with respect to z and v_z.
drz_dz, drz_dvz = [jax.vmap(jax.grad(func, argnums=k)) for k in (0, 1)]


def cbe(z, vz):
    """Return dPhi/dz implied by the collisionless Boltzmann equation.

    For a steady state f = f(r_z), v_z df/dz = dPhi/dz df/dv_z. The common
    factor f'(r_z) cancels, leaving dPhi/dz = v_z (dr_z/dz) / (dr_z/dv_z).
    """
    ratio = drz_dz(z, vz) / drz_dvz(z, vz)
    return ratio * vz

In [None]:
# CBE-implied dPhi/dz at every uniformly drawn (z, v_z) point, not only the
# accepted samples.
dphi_dz = cbe(zgrid, vzgrid)

In [None]:
# Action-angle quantities along the v_z = 0 line, at 256 heights in [-2, 2] kpc.
zz = np.linspace(-2, 2.0, 256) * u.kpc
vv = np.zeros_like(zz.value) * (u.km / u.s)
aaf = model.compute_action_angle(zz, vv, pars)

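TODO: what can you learn about an orbit from having the functions $z(v_z)$ or $v_z(z)$?

$E_z = \frac{1}{2} v_z^2 + \Phi(z)$

$f(E_z) = f\left(\frac{1}{2} v_z^2 + \Phi(z)\right)$

$\frac{\partial f}{\partial z} = f'(E_z)\,\frac{\partial \Phi}{\partial z}, \qquad \frac{\partial f}{\partial v_z} = f'(E_z)\, v_z$

Dividing the two eliminates $f'(E_z)$: $\frac{\partial \Phi}{\partial z} = v_z\,\frac{\partial f/\partial z}{\partial f/\partial v_z}$. This is the ratio computed by `cbe` above, where $f$ depends on $(z, v_z)$ only through $r_z$.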

In [None]:
# Energy conservation along an orbit: Phi(a) - Phi(b) = 1/2 (vz(b)^2 - vz(a)^2)

In [None]:
# equivalently: 1/2 (vz(b)^2 - vz(a)^2) = Phi(a) - Phi(b)

In [None]:
# Squared vertical frequency along the v_z = 0 line, plotted against array index.
plt.plot(aaf["Omega_z"].value ** 2)

In [None]:
fig, ax = plt.subplots(figsize=(6, 6))
# CBE-inferred dPhi/dz at the uniformly drawn points. Assumes the model's
# galactic units (kpc / Myr^2); confirm.
ax.plot(zgrid, dphi_dz, marker="o", ms=1, ls="none", alpha=0.25)

# True dPhi/dz at R = R0 from the loaded potential, converted explicitly to the
# same units instead of plotting a Quantity in whatever units gala returns.
xyz = np.zeros((3, 128))
xyz[0] = R0.to_value(u.kpc)
xyz[2] = np.linspace(zgrid.min(), zgrid.max(), xyz.shape[1])
true_dphi_dz = pot.gradient(xyz)[2].to_value(u.kpc / u.Myr**2)
ax.plot(xyz[2], true_dphi_dz, marker="", color="tab:red", lw=3)

ax.set_xlabel("$z$ [kpc]")
ax.set_ylabel(r"$\partial\Phi/\partial z$ [kpc / Myr$^2$]")
ax.set_ylim(-0.005, 0.005)

In [None]:
# Scan trial v_z values at fixed height z = zz and evaluate the CBE-implied
# dPhi/dz at each point. This shows how the estimate depends on the v_z at
# which it is evaluated.
zz = 0.2
fucs = np.linspace(0, 0.1, 256)  # trial v_z values (name used by the next cell)

# The vectorized `cbe` computes drz_dz / drz_dvz * v_z at every point.
# Computing per-point gradients into names like `drz_dz` / `drz_dvz` would
# rebind the vmapped derivative functions that `cbe` depends on.
vals = cbe(np.full_like(fucs, zz), fucs)

pot.gradient([R0.to_value(u.kpc), 0, zz])[2, 0]

In [None]:
# CBE-implied dPhi/dz vs. trial v_z at fixed z = zz, against the true value.
truth = pot.gradient([R0.to_value(u.kpc), 0, zz])[2, 0].value
vals_arr = np.asarray(vals)

plt.plot(fucs, vals)
plt.axhline(truth)
# NOTE(review): zz is a height, but this x-axis is v_z (fucs spans [0, 0.1]).
# The line lands at v_z = 0.2, outside the scanned range. Confirm intent.
plt.axvline(zz)

# Trial v_z at which the CBE estimate comes closest to the truth.
print(fucs[np.argmin(np.abs(vals_arr - truth))])

plt.figure()
plt.plot(fucs, vals_arr / truth)

In [None]:
# Fitted r_z at r_z' = 0.05 and theta' = pi/2.
rz = model.get_rz(0.05, np.pi / 2, pars["e_params"])

In [None]:
# Evaluate model.get_vz at the r_z above, with angle argument pi/2.
model.get_vz(rz, np.pi / 2, pars)

In [None]:
# Compute r_z for every selected particle. Positions and velocities are
# converted to plain floats in the model's galactic units. This matches how the
# histogram data were built and how z_vz_to_rz_theta_prime is called elsewhere
# (plain arrays), rather than passing Quantities in their stored units.
particle_z = particle_data["xyz"][:, 2].decompose(galactic).value.astype(np.float64)
particle_vz = particle_data["v_xyz"][:, 2].decompose(galactic).value.astype(np.float64)
tmp_rzp, tmp_tzp = model.z_vz_to_rz_theta_prime(particle_z, particle_vz, pars)
particle_rz = model.get_rz(tmp_rzp, tmp_tzp, pars["e_params"])

# Actions, angles and frequencies for the first 10,000 particles.
# compute_action_angle is given Quantities, as elsewhere in this notebook.
particle_aaf = model.compute_action_angle(
    particle_data["xyz"].astype(np.float64)[:10_000, 2],
    particle_data["v_xyz"].astype(np.float64)[:10_000, 2],
    res.params,
    21,
)
particle_aaf[:3]

In [None]:
# Fitted model density along v_z = 0 at heights tmp_z.
# NOTE(review): tmp_z is only assigned in a commented-out cell above. This
# raises NameError unless that cell (or another definition) is run first.
dens = np.exp(model.ln_density(tmp_z, np.zeros_like(tmp_z), res.params))

In [None]:
# Left: Omega_z vs. r_z, model curve and per-particle values.
# Right: model density and data histogram vs. r_z.
# NOTE(review): plot_rz (101-point version), model_Omega_z, tmp_aaf and im_rz
# are assigned only in a commented-out cell above. Run that first, or this
# cell fails or mismatches lengths.
fig, axes = plt.subplots(1, 2, figsize=(12, 5), constrained_layout=True, sharex=True)

ax = axes[0]
ax.plot(plot_rz, model_Omega_z.value, zorder=100, marker="")
# Per-particle Omega_z; particle_rz is truncated to the particles that have
# action-angle values.
ax.plot(
    particle_rz[: len(particle_aaf)],
    particle_aaf["Omega_z"].value,
    ls="none",
    marker="o",
    mew=0,
    alpha=0.2,
    ms=3.0,
    zorder=1000,
)
# ax.plot(
#     im_rz.ravel(),
#     im_Omega_z.value,
#     ls="none",
#     marker="o"
# )
ax.set_ylabel(r"$\Omega_z$ " + f"[{tmp_aaf['Omega_z'].unit:latex_inline}]")

axes[1].plot(plot_rz, dens, zorder=100)
axes[1].plot(im_rz.ravel(), data["H"].ravel())
axes[1].set_yscale("log")

for ax in axes:
    ax.set_xlabel(r"$r_z$", fontsize=18)

Finally, we can plot a map of the inferred orbit shapes over the phase-space distribution:

In [None]:
# Evaluate the fitted model at every (z, v_z) pixel coordinate of the data
# image: action-angle quantities and r_z.
pix_z = data["z"].ravel()
pix_vz = data["vz"].ravel()
length_unit = model.units["length"]
grid_aaf = model.compute_action_angle(
    pix_z * length_unit,
    pix_vz * length_unit / model.units["time"],
    params=res.params,
    N_grid=25,
)

_rzp, _tzp = model.z_vz_to_rz_theta_prime(pix_z, pix_vz, res.params)
grid_rz = model.get_rz(_rzp, _tzp, res.params["e_params"])

In [None]:
fig, axes = plt.subplots(
    1, 2, figsize=(10, 5), sharex=True, sharey=True, constrained_layout=True
)

# Background in both panels: the binned data on a log color scale.
for ax in axes:
    ax.pcolormesh(
        data["vz"], data["z"], data["H"], cmap="Blues", norm=mpl.colors.LogNorm()
    )

# Foreground: contours of r_z (left) and sqrt(J_z) (right), same level set.
contour_panels = [
    (grid_rz, "$r_z$"),
    (np.sqrt(grid_aaf["J_z"].value), r"$\sqrt{J_z}$"),
]
for ax, (values, panel_title) in zip(axes, contour_panels):
    cs = ax.contour(
        data["vz"],
        data["z"],
        values.reshape(data["z"].shape),
        colors="k",
        levels=np.linspace(0, 0.5, 11),
    )
    ax.clabel(cs, cs.levels, inline=True, fontsize=10)
    ax.set_title(panel_title)

# Compute Actions, Angles, Frequencies with the fitted model

With the model fitted to the orbital phase-space distribution, we can now use the model to compute empirical actions, angles, and frequencies for all (or a subset, for speed) stars that went into the initial histogram that we fit to:

In [None]:
# Empirical actions, angles and frequencies for (up to) the first 100,000
# selected particles.
n_aaf = 100_000
model_aaf = model.compute_action_angle(
    particle_data["xyz"].astype(np.float64)[:n_aaf, 2],
    particle_data["v_xyz"].astype(np.float64)[:n_aaf, 2],
    res.params,
    21,
)
model_aaf[:3]

Let's compare our empirically-derived values to the "truth" from Agama:

In [None]:
# empaf vs. Agama J_z, Omega_z and cos(theta_z) for the particles with empaf
# values (the first len(model_aaf) particles).
n_cmp = len(model_aaf)
agamas = [
    particle_data["J_z"][:n_cmp].value,
    particle_data["Omega_z"][:n_cmp].value,
    np.cos(particle_data["theta_z"][:n_cmp]),
]
models = [
    model_aaf["J_z"].value,
    model_aaf["Omega_z"].value,
    np.cos(model_aaf["theta_z"]),
]
labels = ["$J_z$", r"$\Omega_z$", r"$\cos\theta_z$"]
lims = [(0, 0.15), (0, 0.1), (-1, 1)]

fig, axes = plt.subplots(1, 3, figsize=(16, 5), constrained_layout=True)
for ax, (lo, hi), x_agama, x_empaf, label in zip(axes, lims, agamas, models, labels):
    ax.hist2d(
        x_agama,
        x_empaf,
        bins=np.linspace(lo, hi, 128),
        cmap="Greys",
        norm=mpl.colors.LogNorm(vmin=5e-1),
    )

    # One-to-one reference line
    one_to_one = np.linspace(lo, hi, 10)
    ax.plot(one_to_one, one_to_one, marker="", color="tab:green", ls="--", alpha=0.3)
    ax.set_xlim(lo, hi)
    ax.set_ylim(lo, hi)

    ax.set_xlabel(f"Agama {label}")
    ax.set_ylabel(f"empaf {label}")

Nice - those look great! 

In [None]:
# Residuals (empaf - Agama) vs. the Agama value for J_z, Omega_z and cos(theta_z).
n_cmp = len(model_aaf)
agamas = [
    particle_data["J_z"][:n_cmp].value,
    particle_data["Omega_z"][:n_cmp].value,
    np.cos(particle_data["theta_z"][:n_cmp]),
]
models = [
    model_aaf["J_z"].value,
    model_aaf["Omega_z"].value,
    np.cos(model_aaf["theta_z"]),
]
labels = ["$J_z$", r"$\Omega_z$", r"$\cos\theta_z$"]
lims = [(0, 0.15), (0, 0.1), (-1, 1)]

fig, axes = plt.subplots(1, 3, figsize=(16, 5), constrained_layout=True)
for ax, (lo, hi), x_agama, x_empaf, label in zip(axes, lims, agamas, models, labels):
    resid_bins = np.linspace(-0.5 * hi, 0.5 * hi, 128)
    ax.hist2d(
        x_agama,
        (x_empaf - x_agama),
        bins=(np.linspace(lo, hi, 128), resid_bins),
        cmap="Greys",
        norm=mpl.colors.LogNorm(vmin=5e-1),
    )

    ax.axhline(0, marker="", color="tab:green", ls="--", alpha=0.3)
    ax.set_xlim(lo, hi)
    ax.set_ylim(-0.3 * hi, 0.3 * hi)

    ax.set_xlabel(f"Agama {label}")
    ax.set_ylabel(f"(empaf - Agama) {label}")

# Fixed residual range for the cos(theta_z) panel
axes[2].set_ylim(-0.3, 0.3)

In [None]:
# Residuals (empaf - Agama) of J_z, Omega_z and cos(theta_z) against Agama's J_R.
JRs = particle_data["J_R"][: len(model_aaf)].value

labels = ["$J_z$", r"$\Omega_z$", r"$\cos\theta_z$"]
# Per-quantity ranges set the residual (y) axes only. The x-axis is J_R in
# every panel, so it gets one shared, data-driven range instead of reusing the
# J_z / Omega_z / cos(theta_z) limits.
lims = [(0, 0.15), (0, 0.1), (-1, 1)]
JR_lim = (0, np.nanpercentile(JRs, 99))
JR_bins = np.linspace(*JR_lim, 128)

fig, axes = plt.subplots(1, 3, figsize=(16, 5), constrained_layout=True)
for ax, lim, x1, x2, label in zip(axes, lims, agamas, models, labels):
    ax.hist2d(
        JRs,
        (x2 - x1),
        bins=(JR_bins, np.linspace(-0.5 * lim[1], 0.5 * lim[1], 128)),
        cmap="Greys",
        norm=mpl.colors.LogNorm(vmin=5e-1),
    )

    ax.axhline(0, marker="", color="tab:green", ls="--", alpha=0.3)
    ax.set_xlim(*JR_lim)
    ax.set_ylim(-0.3 * lim[1], 0.3 * lim[1])

    ax.set_xlabel("Agama $J_R$")
    ax.set_ylabel(f"(empaf - Agama) {label}")

axes[2].set_ylim(-0.3, 0.3)

There is some bias at both large and small frequency. At large $\Omega_z$ (small $J_z$), the distribution function used to generate the particle data is nearly flat near $J_z \sim 0$, so there is little constraining power on the shapes of the density contours. In the opposite regime (large $J_z$, small $\Omega_z$), the method is limited by particle (shot) noise.

# Fitting a Label Model

In [None]:
from empaf.model import LabelOrbitModel
from empaf.plot import plot_data_models_label_residual

In [None]:
# Binned label image in (z, v_z). Heights and vertical velocities come from the
# same xyz / v_xyz columns used for the density histogram, so both models see
# identical phase-space coordinates.
label_data = LabelOrbitModel.get_data_im(
    z=particle_data["xyz"][:, 2].decompose(galactic).value,
    vz=particle_data["v_xyz"][:, 2].decompose(galactic).value,
    # NOTE(review): assumes the particle table has an "MG_FE" label column. Confirm.
    label=particle_data["MG_FE"],
    bins={"z": np.linspace(-2.5, 2.5, 155), "vz": np.linspace(-0.1, 0.1, 155)},
)

In [None]:
# Binned label image in (z, v_z).
plt.figure(figsize=(6, 5))
label_mesh = plt.pcolormesh(
    label_data["vz"], label_data["z"], label_data["label"], cmap="magma_r"
)
plt.xlabel("$v_z$")
plt.ylabel("$z$")
cb = plt.colorbar()

In [None]:
n_label_knots = 9


def label_func(rz, label_vals):
    """Return the label as a function of r_z via a quadratic (k=2) spline.

    The ``n_label_knots`` knots span r_z in [0, 1] and are spaced equally in
    sqrt(r_z). ``label_vals`` are the spline values at those knots.
    """
    sqrt_knots = jnp.linspace(0, 1.0, n_label_knots)
    spline = InterpolatedUnivariateSpline(sqrt_knots**2, label_vals, k=2)
    return spline(rz)

In [None]:
# Box bounds on the label spline values: [-5, 5] at every knot.
_label_lo = jnp.full(n_label_knots, -5.0)
_label_hi = jnp.full(n_label_knots, 5.0)
label_bounds = dict(label_vals=(_label_lo, _label_hi))

In [None]:
# Label model that reuses the density model's e_m functions (m = 2, 4).
label_e_funcs = {2: e2_func, 4: e4_func}
label_model = LabelOrbitModel(
    label_func=label_func, e_funcs=label_e_funcs, units=galactic
)

In [None]:
label_params0 = label_model.get_params_init(
    vz=label_data["vz"] * u.kpc / u.Myr,
    z=label_data["z"] * u.kpc,
    label=label_data["label"],
    label_params0={"label_vals": np.zeros(n_label_knots)},
)

# Start from the density model's initial e_m parameters, as an independent
# copy. params0["e_params"] is the shared e_params0 dict, so aliasing it would
# let an in-place edit of one parameter set silently change the other.
label_params0["e_params"] = copy.deepcopy(params0["e_params"])

In [None]:
# Label data vs. initial label model, sharing one linear color scale.
vlim = dict(vmin=0, vmax=0.25)

fig, axes = plt.subplots(
    1, 2, figsize=(11, 5), sharex=True, sharey=True, constrained_layout=True
)

init_label = label_model.label(
    z=label_data["z"], vz=label_data["vz"], params=label_params0
)
for ax, image in zip(axes, (label_data["label"], init_label)):
    cs = ax.pcolormesh(label_data["vz"], label_data["z"], image, **vlim)
fig.colorbar(cs, ax=axes[:2])

for ax, panel_title in zip(axes, ("data", "initial model")):
    ax.set_title(panel_title)

In [None]:
# Fit bounds for the label model. The ln_Omega, z0 and vz0 ranges match the
# density model; the label-spline bounds replace the density ones.
_dens0 = [0.01, 2] * u.Msun / u.pc**3
label_model_bounds = {
    "ln_Omega": np.log(np.sqrt(_dens0 * 4 * np.pi * G).to_value(1 / u.Myr)),
    "z0": (-0.05, 0.05),
    "vz0": (-0.02, 0.02),
    "e_params": e_bounds,
    "label_params": label_bounds,
}

In [None]:
# Label-model objective at the initial parameters.
# NOTE(review): this uses the full label_data, while the fit below uses only
# finite pixels (clean_label_data). The value may be NaN if any label pixel is
# non-finite.
label_model.objective(params=label_params0, **label_data)

In [None]:
# Keep only pixels whose label and label uncertainty are both finite.
finite_label = np.isfinite(label_data["label"])
finite_err = np.isfinite(label_data["label_err"])
clean_mask = finite_label & finite_err
clean_label_data = {}
for key, arr in label_data.items():
    clean_label_data[key] = arr[clean_mask]

In [None]:
# Fit the label model using only the finite pixels.
label_res = label_model.optimize(
    params0=label_params0, bounds=label_model_bounds, **clean_label_data
)
label_res.state

In [None]:
# Compare the label data with the initial and fitted label models.
plot_data_models_label_residual(
    label_data, label_model, label_params0, label_res.params
);