In [None]:
import copy
import os

from astropy.constants import G
import astropy.table as at
import astropy.coordinates as coord
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np

# gala
import gala.coordinates as gc
import gala.dynamics as gd
import gala.potential as gp
import gala.integrate as gi
from gala.units import galactic

import jax
jax.config.update('jax_enable_x64', True)
import jax.numpy as jnp
from jax_cosmo.scipy.interpolate import InterpolatedUnivariateSpline

from empaf import DensityOrbitModel
from empaf.plot import plot_data_models_residual
from empaf.model_helpers import generalized_logistic_func_alt

# Load test data

Load some particle data generated in an equilibrium galaxy model:

In [None]:
# Read the simulated particle table; its "z" and "v_z" columns are used below.
particle_data = at.QTable.read('../test-data/agama-galaxymodel-particles.fits')

Bin the data and return bin locations and number counts:

In [None]:
# Histogram the particles in (z, v_z), with both coordinates expressed in the
# `galactic` unit system and 151 equally spaced bin edges per axis.
z_bin_edges = np.linspace(-2.5, 2.5, 151)
vz_bin_edges = np.linspace(-0.1, 0.1, 151)
z_values = particle_data["z"].decompose(galactic).value
vz_values = particle_data["v_z"].decompose(galactic).value

data = DensityOrbitModel.get_data_im(
    z_values, vz_values, bins={"z": z_bin_edges, "vz": vz_bin_edges}
)

In [None]:
# Show the binned counts over the (v_z, z) plane on a logarithmic colour scale.
log_norm = mpl.colors.LogNorm()
ax = plt.gca()
ax.pcolormesh(data["vz"], data["z"], data["H"], cmap="magma", norm=log_norm)
ax.set_xlabel("$v_z$")
ax.set_ylabel("$z$")

# Set up the internal model functions

To recap the math behind this method, our model of the vertical kinematics will fit the phase-space density with a function $n(r_z)$, or the statistics of a stellar label with a function that specifies the variation of a stellar label over the vertical kinematic phase space $f(r_z)$. The argument $r_z$ is an invariant along a density contour and serves as a proxy for the vertical action (it is closely related to the square-root of the vertical action $r_z \sim \sqrt{J_z}$). This "proxy action" radius $r_z$ is a latent parameter for each star or pixel used when fitting the model to data, but it is computed internally in the model using the elliptical (polar) coordinates ($r_z', \theta_z'$), which we can compute from vertical position $z$ and velocity $v_z$ given an axis ratio $\Omega_0$:

$$
\begin{align}
r_z' &= \sqrt{z^2\,\Omega_0 + v_z^2 / \Omega_0} \\
\tan\theta_z' &= \frac{z}{v_z} \, \Omega_0
\end{align}
$$

The "proxy action" radius $r_z$ is then assumed to be a Fourier distortion away from the elliptical polar radius as:

$$
r_z = r_z' \, \left(1 + \sum_m e_m(r_z') \, \cos\left(m \, \theta_z'\right)\right)
$$

where $\Omega_0$ and the parameters of the functions $e_m(r_z')$ have to be determined from the data.

Both the density function $n(r_z)$ (or label function $f(r_z)$) and the functions $e_m(r_z')$ are flexible and left up to the user of `empaf`: These must be specified as a first step when constructing an orbit model instance. For the density function, we could put in a rigid (few parameter) model like an exponential or power law, but here we will take a more flexible approach and represent the (log) density function as a quadratic spline function with fixed knot locations:

In [None]:
n_dens_knots = 15
def ln_dens_func(rz, ln_dens_vals):
    """Evaluate the log-density model at ``rz``.

    The model is a quadratic (``k=2``) interpolating spline through
    ``ln_dens_vals`` at ``n_dens_knots`` fixed knots. The knots span [0, 1]
    and are equally spaced in sqrt(r_z), so they are denser near zero.
    """
    knot_locs = jnp.linspace(0, 1.0, n_dens_knots) ** 2
    return InterpolatedUnivariateSpline(knot_locs, ln_dens_vals, k=2)(rz)

So this model for the density function has 15 (`n_dens_knots=15`) parameters that specify the spline function values at the locations of the (fixed) knots in $r_z$.

For optimizing later on, we will need to specify parameter bounds for the knot values -- here, I allow the (ln) density knot values to be between -5 and 25.

In [None]:
# (lower, upper) limits applied to every ln-density knot value.
ln_dens_lower = jnp.full(n_dens_knots, -5.0)
ln_dens_upper = jnp.full(n_dens_knots, 25.0)
ln_dens_bounds = {"ln_dens_vals": (ln_dens_lower, ln_dens_upper)}

We now need to specify functions to control the dependence of the distortion coefficients $e_m(r_z')$. We have found that allowing $m$ to be either $m=\{2, 4\}$ can produce good representations of realistic phase-space densities, so we will work with just these two terms. In general, both of the distortion coefficient functions for these terms should go to zero at $r_z'=0$, so $e_2(0) = e_4(0) = 0$. For disk galaxy vertical kinematics, the $m=2$ distortion coefficient function should increase with $r_z'$ and the $m=4$ coefficient function should decrease (to negative values) for larger $r_z'$. 

Below we will specify Python functions for the $e_2(r_z')$ and $e_4(r_z')$ functions, and we will set initial parameter values and bounds for the parameters of these functions:

$$
A \, \left[\tanh\left(\left[\frac{x-x_0}{h}\right]^{1/\alpha}\right) \right]^\alpha
$$

In [None]:
from empaf.model_helpers import monotonic_poly_func_alt

# def e2_func(rzp, f1, alpha, ln_x0):
# #     return custom_tanh_func_alt(rzp, f_xval=f1, alpha=alpha, x0=np.exp(ln_x0), xval=1.)
#     return monotonic_poly_func_alt(rzp, f0=0.0, fx=f1, alpha=alpha, x0=jnp.exp(ln_x0), xval=1.0)

# def e2_func(rzp, rz0, f1, B, ln_C, ln_nu):
#     return generalized_logistic_func_alt(
#         rzp, t0=rz0, F1=f1, B=B, C=jnp.exp(ln_C), nu=jnp.exp(ln_nu)
#     )

def _cumulative_linear_spline(rzp, increments, n_knots, sign=1.0):
    """Evaluate a piecewise-linear function defined by cumulative increments.

    Parameters
    ----------
    rzp
        Location(s) at which to evaluate the function.
    increments
        Array of ``n_knots - 1`` increments between consecutive knot values.
    n_knots
        Number of knots. The knots are ``jnp.linspace(0, 1, n_knots) ** 2``,
        i.e. equally spaced in the square root of the argument.
    sign
        Overall factor applied to the knot values (``-1.0`` flips the sign).

    Returns
    -------
    The linear (``k=1``) interpolating spline evaluated at ``rzp``.
    """
    knots = jnp.linspace(0, 1.0, n_knots) ** 2
    # The value at the first knot is pinned to zero; every later knot value is
    # the running sum of the increments, so increments of a single sign give a
    # monotonic function.
    knot_vals = sign * jnp.cumsum(jnp.concatenate((jnp.array([0.0]), increments)))
    spl = InterpolatedUnivariateSpline(knots, knot_vals, k=1)
    return spl(rzp)


n_e2_knots = 11
def e2_func(rzp, e2_vals):
    """m=2 distortion coefficient: zero at the first knot, cumulative sum of ``e2_vals``."""
    return _cumulative_linear_spline(rzp, e2_vals, n_e2_knots)


n_e4_knots = 5
def e4_func(rzp, e4_vals):
    """m=4 distortion coefficient: zero at the first knot, minus the cumulative sum of ``e4_vals``."""
    return _cumulative_linear_spline(rzp, e4_vals, n_e4_knots, sign=-1.0)

In [None]:
# Initial values and (lower, upper) bounds for the distortion-coefficient
# parameters, keyed by the Fourier order m. Each array holds one increment per
# knot interval, i.e. (number of knots - 1) entries.

# Earlier parametrizations of the m=2 term, kept for reference:
# e_params0[2] = {"f1": 0.1, "alpha": 0.33, "ln_x0": 1.0}
# e_bounds[2] = {"f1": (0, 0.8), "alpha": (0.2, 0.5), "ln_x0": (1, 8.0)}
# e_params0[2] = {"rz0": 0.5, "f1": 0.1, "B": -2, "ln_C": -4, "ln_nu": np.log(0.4)}
# e_bounds[2] = {"rz0": (-5, 5), "f1": (0, 1.), "B": (-10, 10), "ln_C": (-10, 4), "ln_nu": (-15, 15)}
n_e2_increments = n_e2_knots - 1
n_e4_increments = n_e4_knots - 1

e_params0 = {
    2: {"e2_vals": np.full(n_e2_increments, 0.2 / n_e2_knots)},
    4: {"e4_vals": np.full(n_e4_increments, 0.08 / n_e4_knots)},
}
e_bounds = {
    2: {"e2_vals": (np.full(n_e2_increments, 0), np.full(n_e2_increments, 0.2))},
    4: {"e4_vals": (np.full(n_e4_increments, 0), np.full(n_e4_increments, 0.2))},
}

# e_params0[4] = {"f1": -0.06}
# e_bounds[4] = {"f1": (-0.5, 0)}

In [None]:
# Plot both distortion-coefficient functions at their initial parameter
# values, on a grid of r_z' between 0 and 1.
grid = np.linspace(0, 1, 128)
for m, e_func in ((2, e2_func), (4, e4_func)):
    plt.plot(grid, e_func(grid, **e_params0[m]))
plt.xlabel("$r_z'$")
plt.ylabel("initial $e_m(r_z')$")

We won't

# Define the model

In [None]:
# Assemble the model from the log-density function and the m = 2 and m = 4
# distortion-coefficient functions, using the `galactic` unit system.
e_funcs_by_order = {2: e2_func, 4: e4_func}
model = DensityOrbitModel(
    ln_dens_func=ln_dens_func, e_funcs=e_funcs_by_order, unit_sys=galactic
)

In [None]:
# Build the initial parameter set from the particle data. The seed array for
# the ln-density knot values is sized from `n_dens_knots` (rather than a
# literal 15) so that it always matches what `ln_dens_func` expects.
params0 = model.get_params_init(
    particle_data["z"],
    particle_data["v_z"],
    ln_dens_params0={"ln_dens_vals": np.zeros(n_dens_knots)},
)
# Replace the initial knot values by their absolute values.
params0["ln_dens_params"]["ln_dens_vals"] = jnp.abs(
    params0["ln_dens_params"]["ln_dens_vals"]
)
# Use the hand-chosen initial distortion-coefficient parameters.
params0["e_params"] = e_params0

In [None]:
# Validate the initial distortion-coefficient parameters before using them;
# only the first element of the returned pair is inspected here.
check, _ = model.check_e_funcs(params0["e_params"])
assert check

In [None]:
# Both panels share the same colour normalisation so they can be compared
# directly.
vlim = {
    "norm": mpl.colors.LogNorm(vmax=3e4, vmin=1e-1),
    "shading": "auto",
}  # vmin=0, vmax=30)

fig, axes = plt.subplots(
    1, 2, figsize=(11, 5), sharex=True, sharey=True, constrained_layout=True
)
data_ax, model_ax = axes

# Left panel: binned data.
cs = data_ax.pcolormesh(data["vz"], data["z"], data["H"], **vlim)

# Right panel: model density evaluated at the initial parameters.
initial_model_H = np.exp(
    model.ln_density(z=data["z"], vz=data["vz"], params=params0)
)
cs = model_ax.pcolormesh(data["vz"], data["z"], initial_model_H, **vlim)
fig.colorbar(cs, ax=axes[:2])

data_ax.set_title("data")
model_ax.set_title("initial model")

In [None]:
# Limits on ln(Omega) come from sqrt(4 pi G rho) evaluated at densities of
# 0.01 and 2 Msun / pc^3, expressed in 1 / Myr.
_dens0 = [0.01, 2] * u.Msun / u.pc**3
_omega_limits = np.sqrt(_dens0 * 4 * np.pi * G).to_value(1 / u.Myr)

bounds = {
    "ln_Omega": np.log(_omega_limits),
    "z0": (-0.05, 0.05),
    "vz0": (-0.02, 0.02),
    "e_params": e_bounds,
    "ln_dens_params": ln_dens_bounds,
}

In [None]:
# Split the nested bounds dictionary into two objects, one per side of the
# bounds (presumably lower- and upper-bound parameter sets -- TODO confirm).
bounds_l, bounds_r = model.unpack_bounds(bounds)

In [None]:
# Evaluate the objective with `bounds_l` used as the parameter set.
model.objective(bounds_l, data['z'], data['vz'], data['H'])

In [None]:
# Evaluate the objective with `bounds_r` used as the parameter set.
model.objective(bounds_r, data['z'], data['vz'], data['H'])

In [None]:
# Evaluate the objective at the initial parameters.
model.objective(params0, data['z'], data['vz'], data['H'])

In [None]:
# Optimize the model parameters, starting from `params0` and constrained by
# `bounds`. The binned data dictionary is unpacked into keyword arguments.
optimizer_kwargs = {"options": {"maxls": 1000, "disp": False}, "tol": 1e-8}
res = model.optimize(
    params0=params0, bounds=bounds, jaxopt_kwargs=optimizer_kwargs, **data
)
res.state

In [None]:
# Show the parameter values returned by the optimizer.
res.params

In [None]:
# Plot the data together with the model at the initial (`params0`) and
# optimized (`res.params`) parameters; `smooth_residual` is forwarded to the
# plotting helper.
fig, axes = plot_data_models_residual(data, model, params0, res.params, smooth_residual=1.)

In [None]:
# model.compute_action_angle([0.]*u.kpc, [0]*u.km/u.s, res.params, 101)

In [None]:
# model.compute_action_angle(np.zeros(1000)*u.kpc, np.zeros(1000)*u.km/u.s, res.params, 101)

In [None]:
# Plot each optimized distortion-coefficient function on a grid from 0 to 1.
plot_rz = np.linspace(0, 1, 301)
es = model.get_es(plot_rz, res.params["e_params"])
for n in es:
    ee = es[n]
    plt.plot(plot_rz, ee, marker="")

In [None]:
# Diagnostics of the optimized model along a grid in r_z: vertical frequency
# (left), distortion coefficients (middle) and density (right).
pars = res.params

plot_rz = np.linspace(1e-3, 0.55, 101)
dens = np.exp(model.get_ln_dens(plot_rz, pars))

# For every r_z on the grid, get the z value at angle pi/2 and the elliptical
# radius r_z' at angle 0. Iterate over the grid values directly rather than
# indexing with range(len(...)).
tmp_z = np.array([model.get_z(rz, np.pi / 2, pars) for rz in plot_rz])
tmp_rzp = np.array(
    [model.get_rz_prime(rz, 0.0, pars["e_params"]) for rz in plot_rz]
)

# Actions / angles / frequencies for points at height tmp_z with zero velocity.
tmp_aaf = model.compute_action_angle(
    tmp_z * u.kpc, np.zeros_like(tmp_z) * u.km / u.s, pars, 101
)
es = model.get_es(tmp_rzp, pars["e_params"])

# NOTE: `tmp_rzp` is rebound here to the per-particle values; `es` above was
# computed from the grid values and must stay before this statement.
tmp_rzp, tmp_tzp = model.z_vz_to_rz_theta_prime(
    particle_data["z"].astype(np.float64),
    particle_data["v_z"].astype(np.float64),
    pars,
)
agama_rz = model.get_rz(tmp_rzp, tmp_tzp, pars["e_params"])

fig, axes = plt.subplots(1, 3, figsize=(16, 5), constrained_layout=True, sharex=True)

ax = axes[0]
ax.plot(plot_rz, tmp_aaf["Omega_z"].value, zorder=100)
# ax.plot(
#     agama_rz,
#     agama_aaf["Omega_z"].value,
#     ls="none",
#     marker="o",
#     mew=0,
#     alpha=0.2,
#     ms=1.0,
#     zorder=1,
# )
ax.set_ylabel(r"$\Omega_z$ " + f"[{tmp_aaf['Omega_z'].unit:latex_inline}]")
# ax.axhline(np.exp(pars["ln_Omega"]), color="tab:orange", ls="--", alpha=0.4)
# ax.axhline(asym_freq.value, color="tab:green", ls="--", alpha=0.4)

for n in es:
    axes[1].plot(plot_rz, es[n], label=f"$e_{n}$")
axes[1].set_ylabel("$e_m(r_z')$")

axes[2].plot(plot_rz, dens)
axes[2].set_yscale("log")

axes[1].legend()

for ax in axes:
    ax.set_xlabel(r"$r_z$")

In [None]:
# plt.scatter(es[2], tmp_aaf['Omega_z'])

In [None]:
# Diagnostics of the optimized model plotted against sqrt(J_z): vertical
# frequency (left), distortion coefficients (middle) and density (right).
#
# NOTE(review): `agama_aaf` and `asym_freq` are used below but are not defined
# in any cell visible in this notebook; on a fresh kernel this cell raises
# NameError unless they are created elsewhere.
pars = res.params

# With v_z = 0 the elliptical radius is r_z' = |z| * sqrt(Omega_0), so this is
# a grid of heights 0..3 scaled by sqrt(Omega); it is divided by `sqrtOm`
# again below to recover z (presumably exp(ln_Omega) is Omega_0 -- confirm).
sqrtOm = np.sqrt(np.exp(pars["ln_Omega"]))
plot_rzp = np.linspace(0, 3, 101) * sqrtOm
es = model.get_es(plot_rzp, pars["e_params"])

# NOTE(review): the angle passed to get_rz is 0, whereas the phase-space
# points passed to compute_action_angle have v_z = 0, which by
# tan(theta_z') = z * Omega_0 / v_z corresponds to theta_z' = pi/2 -- confirm
# which angle is intended.
tmp_rz = model.get_rz(plot_rzp, np.zeros_like(plot_rzp), pars["e_params"])
dens = np.exp(model.get_ln_dens(tmp_rz, pars))
tmp_aaf = model.compute_action_angle(
    plot_rzp / sqrtOm * u.kpc, np.zeros_like(plot_rzp) * u.km / u.s, pars, 11
)
sqrtJz = np.sqrt(tmp_aaf["J_z"].value)

fig, axes = plt.subplots(1, 3, figsize=(16, 5), constrained_layout=True, sharex=True)

ax = axes[0]
# Model curve drawn on top (zorder=100) of the per-particle Agama points.
ax.plot(sqrtJz, tmp_aaf["Omega_z"].value, zorder=100)
ax.plot(
    np.sqrt(agama_aaf["J_z"].value),
    agama_aaf["Omega_z"].value,
    ls="none",
    marker="o",
    mew=0,
    alpha=0.2,
    ms=1.0,
    zorder=1,
)
ax.set_ylabel(r"$\Omega_z$ " + f"[{tmp_aaf['Omega_z'].unit:latex_inline}]")
# Reference lines: the fitted Omega (orange) and `asym_freq` (green).
ax.axhline(np.exp(pars["ln_Omega"]), color="tab:orange", ls="--", alpha=0.4)
ax.axhline(asym_freq.value, color="tab:green", ls="--", alpha=0.4)

for n in es:
    axes[1].plot(sqrtJz, es[n], label=f"$e_{n}$")
axes[2].plot(sqrtJz, dens)
axes[2].set_yscale("log")

axes[1].legend()

# NOTE(review): the trailing prime in this label may be a typo -- confirm.
for ax in axes:
    ax.set_xlabel(r"$\sqrt{J_z}'$")

In [None]:
# Actions/angles and the proxy radius r_z at every pixel of the data grid.
flat_z = data["z"].ravel()
flat_vz = data["vz"].ravel()
length_unit = model.unit_sys["length"]
time_unit = model.unit_sys["time"]

grid_aaf = model.compute_action_angle(
    flat_z * length_unit,
    flat_vz * length_unit / time_unit,
    params=res.params,
    N_grid=11,
)

_rzp, _tzp = model.z_vz_to_rz_theta_prime(flat_z, flat_vz, res.params)
grid_rz = model.get_rz(_rzp, _tzp, res.params["e_params"])

In [None]:
# Contours of r_z (left) and sqrt(J_z) (right) drawn over the binned data.
fig, axes = plt.subplots(
    1, 2, figsize=(10, 5), sharex=True, sharey=True, constrained_layout=True
)

# Same background image in both panels.
for ax in axes:
    ax.pcolormesh(
        data["vz"], data["z"], data["H"], cmap="Blues", norm=mpl.colors.LogNorm()
    )

contour_levels = np.linspace(0, 0.5, 11)
contour_fields = [
    grid_rz.reshape(data["z"].shape),
    np.sqrt(grid_aaf["J_z"].value).reshape(data["z"].shape),
]
for ax, field in zip(axes, contour_fields):
    cs = ax.contour(
        data["vz"], data["z"], field, colors="k", levels=contour_levels
    )
    ax.clabel(cs, cs.levels, inline=True, fontsize=10)

axes[0].set_title("$r_z$")
axes[1].set_title(r"$\sqrt{J_z}$")

Zoom in:

In [None]:
# Finer grid around the origin of the (z, v_z) plane.
zoom_z_axis = np.linspace(-0.2, 0.2, 128)
zoom_vz_axis = np.linspace(-0.014, 0.014, 128)
zgrid, vzgrid = np.meshgrid(zoom_z_axis, zoom_vz_axis)
flat_zoom_z = zgrid.ravel()
flat_zoom_vz = vzgrid.ravel()

zoom_grid_aaf = model.compute_action_angle(
    flat_zoom_z * u.kpc,
    flat_zoom_vz * u.kpc / u.Myr,
    params=res.params,
    N_grid=21,
)

_rzp, _tzp = model.z_vz_to_rz_theta_prime(flat_zoom_z, flat_zoom_vz, res.params)
zoom_grid_rz = model.get_rz(_rzp, _tzp, res.params["e_params"])

In [None]:
# Zoomed contours of r_z (left) and sqrt(J_z) (right) over the binned data.
fig, axes = plt.subplots(
    1, 2, figsize=(10, 5), sharex=True, sharey=True, constrained_layout=True
)

# Background image, pushed behind everything else.
for ax in axes:
    ax.pcolormesh(
        data["vz"],
        data["z"],
        data["H"],
        cmap="Blues",
        norm=mpl.colors.LogNorm(),
        zorder=-100,
    )

zoom_levels = np.linspace(0, 0.07, 15)
# Each entry: (field to contour, extra keyword arguments for clabel). The
# label z-order is set explicitly only for the second panel.
zoom_panels = [
    (zoom_grid_rz.reshape(zgrid.shape), {}),
    (np.sqrt(zoom_grid_aaf["J_z"].value).reshape(zgrid.shape), {"zorder": 100}),
]
for ax, (field, clabel_kwargs) in zip(axes, zoom_panels):
    cs = ax.contour(
        vzgrid, zgrid, field, colors="k", levels=zoom_levels, zorder=100
    )
    ax.clabel(cs, cs.levels, inline=True, fontsize=10, **clabel_kwargs)

axes[0].set_title("$r_z$")
axes[1].set_title(r"$\sqrt{J_z}$")

# The axes are shared, so limiting the first panel limits both.
axes[0].set_xlim(vzgrid.min(), vzgrid.max())
axes[0].set_ylim(zgrid.min(), zgrid.max())

In [None]:
# plot_rz = model.get_rz(plot_rzp, np.zeros_like(plot_rzp), res.params['e_params'])
# # plt.plot(plot_rzp, np.sqrt(tmp_aaf['J_z']))
# plt.plot(plot_rz, np.sqrt(tmp_aaf['J_z']))
# plt.xlabel('$r_z$')
# plt.ylabel(r'$\sqrt{J_z}$')

In [None]:
# plt.plot(np.sqrt(grid_Jzs), grid_Omega_zs)
# plt.plot(sqrtJz, tmp_aaf["Omega_z"].value)

In [None]:
# model_H = np.exp(model.ln_density(z=data["z"], vz=data["vz"], params=res.params))

In [None]:
# H, xe = np.histogram(np.sqrt(agama_aaf['J_z'].value), bins=np.linspace(0, 0.5, 101))
# xc = 0.5 * (xe[:-1] + xe[1:])

# huh = np.diff(data['vz'][0])[0] * np.diff(data['z'][:, 0])[0]

# plt.plot(xc, H / np.diff(xe) * huh / np.sqrt(xc))
# plt.yscale('log')

In [None]:
# thp = np.linspace(0, 2*np.pi, 256)
# for rzp in np.linspace(0, 3.0, 16):
#     plt.plot(thp, model.get_rz(rzp, thp), marker='')

### Compute AAF

In [None]:
# Actions, angles and frequencies from the fitted model for the first
# 100,000 particles; display the first three rows.
n_aaf_particles = 100_000
model_aaf = model.compute_action_angle(
    particle_data["z"].astype(np.float64)[:n_aaf_particles],
    particle_data["v_z"].astype(np.float64)[:n_aaf_particles],
    res.params,
    21,
)
model_aaf[:3]

In [None]:
# One-to-one comparison of Agama and empaf values of J_z, Omega_z and
# cos(theta_z), restricted to the particles present in `model_aaf`.
n_model = len(model_aaf)
agamas = [
    agama_aaf["J_z"][:n_model],
    agama_aaf["Omega_z"][:n_model],
    np.cos(agama_aaf["theta_z"][:n_model]),
]
models = [
    model_aaf["J_z"].value,
    model_aaf["Omega_z"].value,
    np.cos(model_aaf["theta_z"]),
]
labels = ["$J_z$", r"$\Omega_z$", r"$\cos\theta_z$"]
lims = [(0, 0.15), (0, 0.1), (-1, 1)]

fig, axes = plt.subplots(1, 3, figsize=(16, 5), constrained_layout=True)
for idx, ax in enumerate(axes):
    lim, label = lims[idx], labels[idx]
    x1, x2 = agamas[idx], models[idx]

    ax.hist2d(
        x1,
        x2,
        bins=np.linspace(*lim, 128),
        cmap="Greys",
        norm=mpl.colors.LogNorm(vmin=5e-1),
    )

    # Dashed one-to-one line for reference.
    xx = np.linspace(*lim, 10)
    ax.plot(xx, xx, marker="", color="tab:green", ls="--", alpha=0.3)
    ax.set_xlim(*lim)
    ax.set_ylim(*lim)

    ax.set_xlabel(f"Agama {label}")
    ax.set_ylabel(f"empaf {label}")

In [None]:
# Residuals (empaf minus Agama) of J_z, Omega_z and theta_z as a function of
# the Agama value, restricted to the particles present in `model_aaf`.
# NOTE(review): `agama_aaf` is not defined in any cell visible in this
# notebook -- confirm where it is created.
agamas = [
    agama_aaf["J_z"][:len(model_aaf)],
    agama_aaf["Omega_z"][:len(model_aaf)],
    agama_aaf["theta_z"][:len(model_aaf)],
]
models = [
    model_aaf["J_z"].value,
    model_aaf["Omega_z"].value,
    # Wrap the model angle into [0, 2 pi) before differencing.
    model_aaf["theta_z"].value % (2 * np.pi),
]
# The third panel compares the angle itself (not its cosine), so it is
# labelled theta_z.
labels = ["$J_z$", r"$\Omega_z$", r"$\theta_z$"]

fig, axes = plt.subplots(1, 3, figsize=(16, 5), constrained_layout=True)
lims = [(0, 0.15), (0, 0.1), (0, 2 * np.pi)]
for ax, lim, x1, x2, label in zip(axes, lims, agamas, models, labels):
    ax.hist2d(
        x1,
        (x2 - x1),
        bins=(
            np.linspace(*lim, 128),
            np.linspace(-0.5 * lim[1], 0.5 * lim[1], 128),
        ),
        cmap="Greys",
        norm=mpl.colors.LogNorm(vmin=5e-1),
    )

    # Zero-residual reference line.
    ax.axhline(0, marker="", color="tab:green", ls="--", alpha=0.3)
    ax.set_xlim(*lim)
    ax.set_ylim(-0.3 * lim[1], 0.3 * lim[1])

    ax.set_xlabel(f"Agama {label}")
    # The y axis shows the difference, not the empaf value itself.
    ax.set_ylabel(f"empaf $-$ Agama {label}")

# Tighter vertical range for the angle residuals.
axes[2].set_ylim(-0.3, 0.3)

---

Compare a 10% different Agama potential to truth:

In [None]:
# Build a perturbed potential: the Miyamoto-Nagai component has its mass
# scaled by 1.1 and its scale height scaled by 0.9; the other components are
# given as literal values.
# NOTE(review): `agama` is not imported and `xv` is not defined in any cell
# visible in this notebook -- this cell raises NameError on a fresh kernel
# unless both are provided elsewhere.
agama_pot2 = agama.Potential(
    dict(type='miyamotonagai', mass=6.8e10 * 1.1, scaleradius=3.0, scaleheight=0.28 * 0.9),
    dict(type='dehnen', mass=5.00e9, scaleradius=1.0),
    dict(type='dehnen', mass=1.71e9, scaleradius=0.07),
    dict(type='nfw',    mass=5.4e11, scaleradius=15.62),
)

# Actions, angles and frequencies of the same phase-space points `xv` in the
# perturbed potential.
act_finder2 = agama.ActionFinder(agama_pot2)
agama_act2, agama_ang2, agama_freq2 = act_finder2(xv, angles=True)
# Column 1 of each output is stored as the "z" quantity (presumably the
# vertical component in Agama's ordering -- TODO confirm).
agama_aaf2 = at.Table({
    "J_z": agama_act2[:, 1],
    "theta_z": agama_ang2[:, 1],
    "Omega_z": agama_freq2[:, 1],
    "T_z": 2*np.pi / agama_freq2[:, 1]
})

In [None]:
# One-to-one comparison of J_z, Omega_z and cos(theta_z) between the two
# Agama tables `agama_aaf` and `agama_aaf2`.
agamas = [
    agama_aaf["J_z"],
    agama_aaf["Omega_z"],
    np.cos(agama_aaf["theta_z"]),
]
models = [
    agama_aaf2["J_z"],
    agama_aaf2["Omega_z"],
    np.cos(agama_aaf2["theta_z"]),
]
labels = ["$J_z$", r"$\Omega_z$", r"$\cos\theta_z$"]
lims = [(0, 0.15), (0, 0.1), (-1, 1)]

fig, axes = plt.subplots(1, 3, figsize=(16, 5), constrained_layout=True)
for idx, ax in enumerate(axes):
    lim, label = lims[idx], labels[idx]
    x1, x2 = agamas[idx], models[idx]

    ax.hist2d(
        x1,
        x2,
        bins=np.linspace(*lim, 128),
        cmap="Greys",
        norm=mpl.colors.LogNorm(vmin=1e1),
    )

    # Dashed one-to-one line for reference.
    xx = np.linspace(*lim, 10)
    ax.plot(xx, xx, marker="", color="tab:green", ls="--", alpha=0.3)
    ax.set_xlim(*lim)
    ax.set_ylim(*lim)

    ax.set_xlabel(f"agama {label}")
    ax.set_ylabel(f"agama2 {label}")