TODO:
- compare inferred acceleration trends - which selection recovers local acceleration the best?
- also do different spatial/velocity/R selection and rerun

In this notebook, we will go over how to use the `DensityOrbitModel` to fit the vertical phase-space density of stars with a flexible model for the vertical orbit structure.

In [None]:
import copy
import os

from astropy.constants import G
import astropy.table as at
import astropy.coordinates as coord
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np

# gala
import gala.coordinates as gc
import gala.dynamics as gd
import gala.potential as gp
import gala.integrate as gi
from gala.units import galactic

import jax
jax.config.update('jax_enable_x64', True)
import jax.numpy as jnp
from jax_cosmo.scipy.interpolate import InterpolatedUnivariateSpline

from empaf import DensityOrbitModel
from empaf.plot import plot_data_models_residual
from empaf.model_helpers import generalized_logistic_func_alt

# Load test data

Load some particle data generated in an equilibrium galaxy model:

In [None]:
# particle_data = at.QTable.read('../test-data/agama-galaxymodel-particles.fits')
tbl = at.QTable.read('../test-data/agama-galaxymodel-particles-qIso.fits')

In [None]:
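# Select a local sample: angular momentum within 1 kpc x 229 km/s of a circular orbit
# at R = 8.3 kpc (assuming v_circ = 229 km/s), cylindrical radius within 0.5 kpc of
# 8.3 kpc, and small radial velocity
Jphi0 = 229*u.km/u.s * 8.3*u.kpc
R = np.sqrt(tbl['xyz'][:, 0]**2 + tbl['xyz'][:, 1]**2)
v_R = (tbl['xyz'][:, 0]*tbl['v_xyz'][:, 0] + tbl['xyz'][:, 1]*tbl['v_xyz'][:, 1]) / R
mask = (
    (np.abs(tbl['J_phi'] - Jphi0) < (1*u.kpc * 229*u.km/u.s))
    & (np.abs(R - 8.3*u.kpc) < 0.5*u.kpc)
    & (np.abs(v_R) < 15*u.km/u.s)
)
print(mask.sum())  # number of selected particles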

particle_data = tbl[mask]

Bin the data and return bin locations and number counts:

In [None]:
data = DensityOrbitModel.get_data_im(
    z=particle_data["xyz"][:, 2].decompose(galactic).value,
    vz=particle_data["v_xyz"][:, 2].decompose(galactic).value,
    bins={"z": np.linspace(-2.5, 2.5, 155), "vz": np.linspace(-0.1, 0.1, 155)},
)

In [None]:
plt.pcolormesh(
    data["vz"], data["z"], data["H"], cmap="magma", norm=mpl.colors.LogNorm()
)
plt.xlabel("$v_z$")
plt.ylabel("$z$")

# Set up the internal model functions

To recap the math behind this method: our model of the vertical kinematics fits either the phase-space density with a function $n(r_z)$, or the variation of a stellar label over the vertical kinematic phase space with a function $f(r_z)$. The argument $r_z$ is constant along a density contour and serves as a proxy for the vertical action (it is closely related to the square root of the vertical action, $r_z \sim \sqrt{J_z}$). This "proxy action" radius $r_z$ is a latent quantity for each star or pixel when fitting the model to data, but it is computed internally by the model from the elliptical (polar) coordinates $(r_z', \theta_z')$, which we can compute from the vertical position $z$ and velocity $v_z$ given an axis ratio $\Omega_0$:

$$
\begin{align}
r_z' &= \sqrt{z^2\,\Omega_0 + v_z^2 / \Omega_0} \\
\tan\theta_z' &= \frac{z}{v_z} \, \Omega_0
\end{align}
$$
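As a quick check of the transform above, here is a NumPy sketch of $(r_z', \theta_z')$. The helper name `elliptical_coords` is hypothetical (the model's own method for this step, used later in this notebook, is `z_vz_to_rz_theta_prime`, whose conventions may differ); it assumes $z$ and $v_z$ are already measured relative to the zero point and uses `arctan2` to pick the quadrant of $\theta_z'$. For a harmonic oscillator with frequency $\Omega_0$, $r_z'$ is constant along the orbit and equals $\sqrt{2 J_z}$, which is why it works as a proxy for $\sqrt{J_z}$:

```python
import numpy as np


def elliptical_coords(z, vz, Omega0):
    """Hypothetical helper: (r_z', theta_z') from z, v_z and axis ratio Omega0.

    Assumes z and vz are already measured relative to the zero point (z0, vz0).
    """
    z = np.asarray(z, dtype=float)
    vz = np.asarray(vz, dtype=float)
    rzp = np.sqrt(z**2 * Omega0 + vz**2 / Omega0)
    # tan(theta') = (z / vz) * Omega0; arctan2 resolves the quadrant
    tzp = np.arctan2(z * Omega0, vz)
    return rzp, tzp


# Check on a harmonic-oscillator orbit with action J and frequency Omega0:
# z = sqrt(2 J / Omega0) sin(theta), vz = sqrt(2 J Omega0) cos(theta)
J, Omega0 = 0.01, 0.07
theta = np.linspace(-3.0, 3.0, 7)
z = np.sqrt(2 * J / Omega0) * np.sin(theta)
vz = np.sqrt(2 * J * Omega0) * np.cos(theta)
rzp, tzp = elliptical_coords(z, vz, Omega0)
print(np.allclose(rzp, np.sqrt(2 * J)), np.allclose(tzp, theta))  # True True
```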

The "proxy action" radius $r_z$ is then assumed to be a Fourier distortion away from the elliptical polar radius as:

$$
r_z = r_z' \, \left(1 + \sum_m e_m(r_z') \, \cos\left(m \, \theta_z'\right)\right)
$$

where $\Omega_0$ and the parameters of the functions $e_m(r_z')$ have to be determined from the data.
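The distortion step can be sketched in the same way. `distorted_rz` is a hypothetical helper (the notebook later calls `model.get_rz` for this step); `e_funcs` maps each $m$ to a callable $e_m(r_z')$, and the toy coefficient functions below are made up for illustration:

```python
import numpy as np


def distorted_rz(rzp, tzp, e_funcs):
    """Hypothetical helper: r_z = r_z' * (1 + sum_m e_m(r_z') cos(m theta_z'))."""
    rzp = np.asarray(rzp, dtype=float)
    tzp = np.asarray(tzp, dtype=float)
    factor = np.ones_like(rzp)
    for m, e_m in e_funcs.items():
        factor = factor + e_m(rzp) * np.cos(m * tzp)
    return rzp * factor


# Toy coefficient functions that vanish at r_z' = 0 (e_2 rising, e_4 falling)
toy_e_funcs = {
    2: lambda r: 0.3 * r,
    4: lambda r: -0.1 * r,
}
rzp = np.array([0.0, 0.2, 0.2])
tzp = np.array([0.0, 0.0, np.pi / 2])
rz = distorted_rz(rzp, tzp, toy_e_funcs)
```

With no distortion terms (an empty `e_funcs`), $r_z$ reduces to $r_z'$.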

Both the density function $n(r_z)$ (or label function $f(r_z)$) and the functions $e_m(r_z')$ are flexible and left up to the user of `empaf`; they must be specified as a first step when constructing an orbit model instance. For the density function, we could use a rigid (few-parameter) model like an exponential or power law, but here we will take a more flexible approach and represent the (log) density function as a quadratic spline with fixed knot locations:

In [None]:
n_dens_knots = 15
def ln_dens_func(rz, ln_dens_vals):
    # Knot locations, spaced equally in sqrt(r_z)
    xs = jnp.linspace(0, 1.0, n_dens_knots) ** 2
    
    spl = InterpolatedUnivariateSpline(xs, ln_dens_vals, k=2)
    return spl(rz)

So this model for the density function has 15 (`n_dens_knots=15`) parameters that specify the spline function values at the locations of the (fixed) knots in $r_z$.
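To experiment with this parameterization outside of JAX, SciPy's `InterpolatedUnivariateSpline` (a class of the same name as the `jax_cosmo` one used above) gives a NumPy analog. The knot values here are made up (ln-density linear in $r_z$, i.e. an exponential density); the sketch only shows that the spline passes through the knots and that squaring a uniform grid places the knots uniformly in $\sqrt{r_z}$, concentrating them at small $r_z$:

```python
import numpy as np
from scipy.interpolate import InterpolatedUnivariateSpline

n_dens_knots = 15
# Knots uniform in sqrt(r_z), so they cluster toward r_z = 0
xs = np.linspace(0, 1.0, n_dens_knots) ** 2

# Made-up ln-density values at the knots (exponential fall-off in r_z)
ln_dens_vals = 10.0 - 12.0 * xs

# Quadratic (k=2) interpolating spline through the knot values
spl = InterpolatedUnivariateSpline(xs, ln_dens_vals, k=2)
rz_grid = np.linspace(0, 1, 200)
ln_dens_grid = spl(rz_grid)
```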

For optimizing later on, we will need to specify parameter bounds for the knot values. Here, I allow the (ln) density knot values to be between -5 and 25.

In [None]:
ln_dens_bounds = {
    "ln_dens_vals": (
        jnp.full(n_dens_knots, -5.0),
        jnp.full(n_dens_knots, 25.0)
    )
}

We now need to specify functions that control how the distortion coefficients $e_m(r_z')$ depend on $r_z'$. We have found that using just the $m=2$ and $m=4$ terms can produce good representations of realistic phase-space densities, so we will work with these two terms. In general, both distortion coefficient functions should go to zero at $r_z'=0$, so $e_2(0) = e_4(0) = 0$. For disk galaxy vertical kinematics, the $m=2$ coefficient function should increase with $r_z'$, and the $m=4$ coefficient function should decrease (to negative values) at larger $r_z'$.
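One simple way to build coefficient functions that satisfy these constraints is to parameterize non-negative increments between knots and accumulate them. The sketch below does this with piecewise-linear interpolation (`np.interp`) as a stand-in; the helper `monotone_from_increments` is hypothetical, it is not `empaf`'s `monotonic_quadratic_spline` (whose internals are not reproduced here), and the knot increments are made up:

```python
import numpy as np


def monotone_from_increments(rzp, knots, increments):
    """Piecewise-linear function equal to 0 at knots[0] and non-decreasing.

    `increments` (one per knot interval) are clipped to be >= 0 before summing.
    """
    increments = np.clip(np.asarray(increments, dtype=float), 0.0, None)
    knot_vals = np.concatenate(([0.0], np.cumsum(increments)))
    return np.interp(rzp, knots, knot_vals)


knots = np.linspace(0, 1.0, 5) ** 2
grid = np.linspace(0, 1.0, 101)
# e_2: rises from 0; e_4: negated, so it falls from 0 to negative values
e2 = monotone_from_increments(grid, knots, [0.05, 0.05, 0.1, 0.1])
e4 = -monotone_from_increments(grid, knots, [0.01, 0.02, 0.02, 0.03])
```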

Below we will specify Python functions for the $e_2(r_z')$ and $e_4(r_z')$ functions, and we will set initial parameter values and bounds for the parameters of these functions. We will use splines for these functions as well, but we will use a custom implementation of a monotonic quadratic spline:

In [None]:
from empaf.model_helpers import monotonic_quadratic_spline

n_e2_knots = 11
n_e4_knots = 5


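def e2_func(rzp, e2_vals):
    # Knots spaced equally in sqrt(r_z'); the value at the first knot (r_z'=0)
    # is fixed to 0 so that e_2(0) = 0
    e2_knots = jnp.linspace(0, 1.0, n_e2_knots) ** 2
    vals = monotonic_quadratic_spline(
        e2_knots, jnp.concatenate((jnp.array([0.0]), e2_vals)), rzp
    )
    return vals


def e4_func(rzp, e4_vals):
    # Same construction as e2_func, but negated so that e_4(0) = 0 and e_4
    # decreases (to negative values) with r_z', as expected for disks
    e4_knots = jnp.linspace(0, 1.0, n_e4_knots) ** 2
    vals = monotonic_quadratic_spline(
        e4_knots, jnp.concatenate((jnp.array([0.0]), e4_vals)), rzp
    )
    return -vals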

In [None]:
e_params0 = {}
e_bounds = {}
# e_params0[2] = {"e2_vals": np.full(n_e2_knots - 1, 0.2)}
# e_params0[2] = {"e2_vals": np.linspace(1.5, 0.2, n_e2_knots - 1) / 0.6 * 0.2}
# e_params0[4] = {"e4_vals": np.full(n_e4_knots - 1, 0.08)}
e_params0[2] = {"e2_vals": np.zeros(n_e2_knots - 1)}
e_params0[4] = {"e4_vals": np.zeros(n_e4_knots - 1)}
e_bounds[2] = {"e2_vals": (np.full(n_e2_knots-1, 0), np.full(n_e2_knots-1, 10))}
e_bounds[4] = {"e4_vals": (np.full(n_e4_knots-1, 0), np.full(n_e4_knots-1, 10))}

Let's visualize the functions at the initial parameter values:

In [None]:
grid = np.linspace(0, 1, 128)
plt.plot(grid, e2_func(grid, **e_params0[2]))
plt.plot(grid, e4_func(grid, **e_params0[4]))
plt.xlabel("$r_z'$")
plt.ylabel("initial $e_m(r_z')$")

# Define the model

With functions specified for the log-density and $e_2$, $e_4$ coefficients, we are now ready to initialize a `DensityOrbitModel` instance:

In [None]:
model = DensityOrbitModel(
    ln_dens_func=ln_dens_func,
    e_funcs={2: e2_func, 4: e4_func},  # the keys are the "m" values
    unit_sys=galactic,
)

With functions defined and a unit system specified, we can use the model instance to estimate initial values for the density function parameters and other nuisance parameters of the model:

In [None]:
params0 = model.get_params_init(
    particle_data['xyz'][:, 2], particle_data['v_xyz'][:, 2], ln_dens_params0={'ln_dens_vals': np.zeros(n_dens_knots)}
)
params0

As noted in the warning above, when passing in custom `e_funcs` you must define your own initial values for the function parameters. We did that above, so now let's store those values in the `params0` dictionary:

In [None]:
params0['e_params'] = e_params0

The other parameter values returned by the `get_params_init()` method (`z0`, `vz0`, and `ln_Omega`) control other aspects of the model:
- `z0` and `vz0` are the zero-point location of the peak phase-space density (interpretable as the solar height and vertical velocity)
- `ln_Omega` is the natural log of $\Omega_0$, the asymptotic vertical frequency at the midplane (i.e. $z=0$)
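The bounds on `ln_Omega` set later in this notebook come from the harmonic approximation near the midplane: for a thin disk, neglecting the radial terms in Poisson's equation gives $\partial^2\Phi/\partial z^2 = 4\pi G\rho_0$ at $z=0$, so small-amplitude vertical oscillations have frequency $\nu = \sqrt{4\pi G\rho_0}$. The sketch below evaluates this in the notebook's `galactic` units (1/Myr) with hard-coded constants instead of `astropy` so it runs standalone; the helper `midplane_frequency` is hypothetical and the density values are illustrative:

```python
import numpy as np

# Gravitational constant in pc (km/s)^2 / Msun, and 1 km/s expressed in pc/Myr
G_PC_KMS2_PER_MSUN = 4.300917e-3
KMS_TO_PC_PER_MYR = 1.0227122


def midplane_frequency(rho0_msun_pc3):
    """Harmonic vertical frequency nu = sqrt(4 pi G rho0), returned in 1/Myr."""
    rho0 = np.asarray(rho0_msun_pc3, dtype=float)
    nu_kms_per_pc = np.sqrt(4 * np.pi * G_PC_KMS2_PER_MSUN * rho0)
    return nu_kms_per_pc * KMS_TO_PC_PER_MYR  # (km/s)/pc -> 1/Myr


nu = midplane_frequency([0.01, 0.1, 2.0])
# ln(Omega) range for midplane densities between 0.01 and 2 Msun/pc^3
ln_Omega_range = np.log(nu[[0, 2]])
```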

Let's now visualize our initial model compared to the data: 

In [None]:
vlim = dict(
    norm=mpl.colors.LogNorm(vmax=3e4, vmin=1e-1), shading="auto"
)

fig, axes = plt.subplots(
    1, 2, figsize=(11, 5), sharex=True, sharey=True, constrained_layout=True
)

cs = axes[0].pcolormesh(data["vz"], data["z"], data["H"], **vlim)

cs = axes[1].pcolormesh(
    data["vz"],
    data["z"],
    np.exp(model.ln_density(z=data["z"], vz=data["vz"], params=params0)),
    **vlim
)
fig.colorbar(cs, ax=axes[:2])

axes[0].set_title("data")
axes[1].set_title("initial model")

Not a terrible initial guess, but clearly there are differences in the shapes of these distributions! Let's optimize the model. To do that, we have to specify bounds for all parameters in the model. We already specified bounds for the log-density function parameters and the $e_m$ function parameters, so below we will specify bounds for the other nuisance parameters (zero-point location and midplane frequency):

In [None]:
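bounds = {}

# Bound the midplane frequency using the harmonic relation Omega = sqrt(4 pi G rho0)
# for an assumed range of midplane densities rho0
_dens0 = [0.01, 2] * u.Msun / u.pc**3
bounds["ln_Omega"] = np.log(np.sqrt(_dens0 * 4 * np.pi * G).to_value(1 / u.Myr))
# Zero-point offsets, in the galactic unit system (kpc, kpc/Myr)
bounds["z0"] = (-0.05, 0.05)
bounds["vz0"] = (-0.02, 0.02)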

bounds["e_params"] = e_bounds
bounds["ln_dens_params"] = ln_dens_bounds

Let's make sure the model evaluates to a finite value at our initial parameter guess:

In [None]:
model.objective(params0, data['z'], data['vz'], data['H'])

Now we are ready to optimize! 

In [None]:
res = model.optimize(
    params0=params0, bounds=bounds, jaxopt_kwargs={"tol": 1e-10}, **data
)
res.state

In [None]:
res.params

In [None]:
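# Alternative: unconstrained first-order optimization with Adam (this ignores `bounds`)
import jaxopt
import optax

opt = optax.adam(1e-3)
solver = jaxopt.OptaxSolver(opt=opt, fun=model.objective, maxiter=8192)
# Pass the same data arguments as in the `model.objective(...)` check above
res_adam = solver.run(params0, data["z"], data["vz"], data["H"])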

It looks like the optimizer succeeded! But does the fitted model look like a better representation of the phase-space density? Let's plot the data, initial model, fitted model, and residuals using a built-in convenience function:

In [None]:
fig, axes = plot_data_models_residual(data, model, params0, res.params)

What do the $e_m$ functions look like after fitting?

In [None]:
plot_rz = np.linspace(0, 1, 301)
es = model.get_es(plot_rz, res.params['e_params'])
for m, ee in es.items():
    plt.plot(plot_rz, ee, marker='', label=f"$m={m}$")

plt.legend(fontsize=16)
plt.xlabel("$r_z'$")
plt.ylabel("$e_m(r_z')$ for $m=2,4$")

In [None]:
pars = res.params

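Next, let's make a diagnostic plot of the fitted model's vertical frequency $\Omega_z$ and density as functions of $r_z$, compared with the particles and the binned data: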

In [None]:
# Compute model predicted density:
plot_rz = np.linspace(1e-3, 0.55, 101)
model_dens = np.exp(model.get_ln_dens(plot_rz, pars))

# Compute rz values at image pixel locations:
tmp_rzp, tmp_tzp = model.z_vz_to_rz_theta_prime(
    data["z"].astype(np.float64), data["vz"].astype(np.float64), pars
)
im_rz = model.get_rz(tmp_rzp, tmp_tzp, pars["e_params"])

# Compute model implicit Omega_z vs. r_z function:
tmp_z = np.array(
    [model.get_z(plot_rz[n], np.pi / 2, pars) for n in range(len(plot_rz))]
)
tmp_rzp = np.array(
    [model.get_rz_prime(plot_rz[n], 0.0, pars["e_params"]) for n in range(len(plot_rz))]
)

tmp_aaf = model.compute_action_angle(
    tmp_z * u.kpc, np.zeros_like(tmp_z) * u.km / u.s, pars, 101
)
model_Omega_z = tmp_aaf['Omega_z']
model_J_z = tmp_aaf['J_z']

# Compute Omega_z at image pixel locations:
tmp_aaf = model.compute_action_angle(
    data["z"].ravel() * u.kpc, data["vz"].ravel() * u.kpc / u.Myr, pars, 25
)
im_Omega_z = tmp_aaf['Omega_z']

In [None]:
# compute rz values for all particles:
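tmp_rzp, tmp_tzp = model.z_vz_to_rz_theta_prime(
    # convert to the model's unit system (kpc, kpc/Myr), as for the binned data
    particle_data["xyz"][:, 2].decompose(galactic).value.astype(np.float64),
    particle_data["v_xyz"][:, 2].decompose(galactic).value.astype(np.float64),
    pars,
)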
particle_rz = model.get_rz(tmp_rzp, tmp_tzp, pars["e_params"])

particle_aaf = model.compute_action_angle(
    particle_data["xyz"].astype(np.float64)[:10_000, 2],
    particle_data["v_xyz"].astype(np.float64)[:10_000, 2],
    res.params,
    21,
)
particle_aaf[:3]

In [None]:
dens = np.exp(model.ln_density(tmp_z, np.zeros_like(tmp_z), res.params))

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(12, 5), constrained_layout=True, sharex=True)

ax = axes[0]
ax.plot(plot_rz, model_Omega_z.value, zorder=100, marker='')
ax.plot(
    particle_rz[:len(particle_aaf)],
    particle_aaf['Omega_z'].value,
    ls="none",
    marker="o",
    mew=0,
    alpha=0.2,
    ms=3.0,
    zorder=1000,
)
# ax.plot(
#     im_rz.ravel(),
#     im_Omega_z.value,
#     ls="none",
#     marker="o"
# )
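ax.set_ylabel(r"$\Omega_z$ " + f"[{model_Omega_z.unit:latex_inline}]")

axes[1].plot(plot_rz, dens, zorder=100, marker="")
# binned counts at each pixel's r_z, drawn as points rather than a connected line
axes[1].plot(im_rz.ravel(), data['H'].ravel(), ls="none", marker="o", ms=2.0, alpha=0.2)
axes[1].set_yscale("log")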

for ax in axes:
    ax.set_xlabel(r"$r_z$", fontsize=18)

We can also plot a map of the inferred orbit shapes over the phase-space distribution, comparing contours of $r_z$ with contours of $\sqrt{J_z}$:

In [None]:
grid_aaf = model.compute_action_angle(
    data["z"].ravel() * model.unit_sys["length"],
    data["vz"].ravel() * model.unit_sys["length"] / model.unit_sys["time"],
    params=res.params,
    N_grid=25,
)

_rzp, _tzp = model.z_vz_to_rz_theta_prime(
    data["z"].ravel(), data["vz"].ravel(), res.params
)
grid_rz = model.get_rz(_rzp, _tzp, res.params["e_params"])

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(10, 5), sharex=True, sharey=True, constrained_layout=True)

for ax in axes:
    ax.pcolormesh(
        data['vz'],
        data['z'],
        data['H'],
        cmap='Blues',
        norm=mpl.colors.LogNorm()
    )

cs = axes[0].contour(
    data["vz"], data["z"], grid_rz.reshape(data["z"].shape),
    colors='k',
    levels=np.linspace(0, 0.5, 11)
)
axes[0].clabel(cs, cs.levels, inline=True, fontsize=10)
axes[0].set_title('$r_z$')

cs = axes[1].contour(
    data["vz"], data["z"], np.sqrt(grid_aaf["J_z"].value).reshape(data["z"].shape),
    colors='k',
    levels=np.linspace(0, 0.5, 11)
)
axes[1].clabel(cs, cs.levels, inline=True, fontsize=10)
axes[1].set_title(r'$\sqrt{J_z}$')
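To see how the two panels are related, recall that the vertical action of an orbit is the phase-space area it encloses divided by $2\pi$, $J_z = \frac{1}{2\pi}\oint v_z\,dz$. The map $(z, v_z)\mapsto(z\sqrt{\Omega_0}, v_z/\sqrt{\Omega_0})$ behind the primed coordinates preserves area, so a contour of constant $r_z$ encloses $J_z = \frac{1}{4\pi}\int_0^{2\pi} r_z'(\theta_z')^2\,d\theta_z'$. The sketch below solves for $r_z'(\theta_z')$ on one contour by bisection and integrates numerically. It is a hypothetical illustration under the definitions in this notebook, not `empaf`'s `compute_action_angle`, and the toy $e_m$ functions are made up. With no distortion it gives $J_z = r_z^2/2$, so $r_z = \sqrt{2J_z}$ in that limit:

```python
import numpy as np


def contour_action(rz, e_funcs, n_theta=512, n_bisect=60):
    """Hypothetical helper: J_z enclosed by the contour r_z = const.

    Solves r_z' (1 + sum_m e_m(r_z') cos(m theta')) = r_z for r_z' at each angle,
    assuming the left-hand side increases with r_z', then uses
    J_z = area / (2 pi) = (1 / 4 pi) * integral of r_z'^2 dtheta'.
    """
    theta = np.linspace(0, 2 * np.pi, n_theta, endpoint=False)

    def rz_of(rzp):
        factor = np.ones_like(rzp)
        for m, e_m in e_funcs.items():
            factor = factor + e_m(rzp) * np.cos(m * theta)
        return rzp * factor

    # Bracket the root at every angle by doubling, then bisect
    lo = np.zeros(n_theta)
    hi = np.full(n_theta, max(rz, 1e-12))
    while np.any(rz_of(hi) < rz):
        hi = np.where(rz_of(hi) < rz, 2 * hi, hi)
    for _ in range(n_bisect):
        mid = 0.5 * (lo + hi)
        below = rz_of(mid) < rz
        lo = np.where(below, mid, lo)
        hi = np.where(below, hi, mid)
    rzp = 0.5 * (lo + hi)

    # A uniform-grid mean is an accurate quadrature for this periodic integrand
    integral = 2 * np.pi * np.mean(rzp**2)
    return integral / (4 * np.pi)


J_undistorted = contour_action(0.3, {})  # equals 0.3**2 / 2
J_toy = contour_action(0.3, {2: lambda r: 0.3 * r, 4: lambda r: -0.1 * r})
```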

# Compute Actions, Angles, Frequencies with the fitted model

With the model fitted to the vertical phase-space distribution, we can now use it to compute empirical actions, angles, and frequencies for all stars (or a subset, for speed) that went into the histogram we fit:

In [None]:
model_aaf = model.compute_action_angle(
    particle_data["xyz"].astype(np.float64)[:100_000, 2],
    particle_data["v_xyz"].astype(np.float64)[:100_000, 2],
    res.params,
    21,
)
model_aaf[:3]
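The Agama "truth" values are computed from the galaxy model's known potential. For intuition about what these quantities mean in the simplest case, the sketch below evaluates the standard definitions $J_z = \frac{1}{2\pi}\oint v_z\,dz$ and $\Omega_z = 2\pi/T_z$ for a separable, symmetric 1D vertical potential $\Phi(z)$ with $\Phi(0)=0$. The helper `vertical_action_frequency` is hypothetical (not part of `empaf` or Agama); it substitutes $z = z_\mathrm{max}\sin\phi$ and uses Gauss-Legendre nodes so the turning points, where $v_z = 0$, are never evaluated. The harmonic potential in the usage line is a test case with the known answer $J_z = E/\Omega$, $\Omega_z = \Omega$:

```python
import numpy as np
from scipy.optimize import brentq


def vertical_action_frequency(z, vz, potential, n_nodes=64):
    """Hypothetical helper: (J_z, Omega_z) for scalar z, vz in a 1D potential.

    `potential` must be vectorized, symmetric, zero at z=0 and increasing for z>0.
    """
    E = 0.5 * vz**2 + potential(z)
    # Turning point z_max where potential(z_max) = E
    hi = max(abs(z), 1e-8)
    while potential(hi) < E:
        hi *= 2
    z_max = brentq(lambda zz: potential(zz) - E, 0.0, hi, xtol=1e-14)

    # Gauss-Legendre nodes mapped to phi in (-pi/2, pi/2), with z = z_max sin(phi)
    x, w = np.polynomial.legendre.leggauss(n_nodes)
    phi = 0.5 * np.pi * x
    w = 0.5 * np.pi * w
    dz_dphi = z_max * np.cos(phi)
    v = np.sqrt(np.clip(2 * (E - potential(z_max * np.sin(phi))), 0, None))

    J_z = np.sum(w * v * dz_dphi) / np.pi   # (1 / 2 pi) * closed-loop integral
    period = 2 * np.sum(w * dz_dphi / v)   # up and back between turning points
    return J_z, 2 * np.pi / period


harmonic = lambda zz: 0.5 * 2.0**2 * zz**2  # Omega = 2
J_h, Om_h = vertical_action_frequency(0.3, 0.1, harmonic)  # E / Omega, Omega
```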

Let's compare our empirically derived values to the "truth" from Agama:

In [None]:
agamas = [
    particle_data["J_z"][:len(model_aaf)].value, 
    particle_data["Omega_z"][:len(model_aaf)].value, 
    np.cos(particle_data["theta_z"][:len(model_aaf)])
]
models = [
    model_aaf["J_z"].value,
    model_aaf["Omega_z"].value,
    np.cos(model_aaf["theta_z"]),
]
labels = ["$J_z$", r"$\Omega_z$", r"$\cos\theta_z$"]

fig, axes = plt.subplots(1, 3, figsize=(16, 5), constrained_layout=True)
lims = [(0, 0.15), (0, 0.1), (-1, 1)]
for ax, lim, x1, x2, label in zip(axes, lims, agamas, models, labels):
    ax.hist2d(
        x1,
        x2,
        bins=np.linspace(*lim, 128),
        cmap="Greys",
        norm=mpl.colors.LogNorm(vmin=5e-1),
    )

    xx = np.linspace(*lim, 10)
    ax.plot(xx, xx, marker="", color="tab:green", ls="--", alpha=0.3)
    ax.set_xlim(*lim)
    ax.set_ylim(*lim)

    ax.set_xlabel(f"Agama {label}")
    ax.set_ylabel(f"empaf {label}")

Nice - those look great! 

In [None]:
agamas = [
    particle_data["J_z"][:len(model_aaf)].value, 
    particle_data["Omega_z"][:len(model_aaf)].value, 
    np.cos(particle_data["theta_z"][:len(model_aaf)])
#     particle_data["theta_z"][:len(model_aaf)].value
]
models = [
    model_aaf["J_z"].value,
    model_aaf["Omega_z"].value,
    np.cos(model_aaf["theta_z"]),
#     model_aaf["theta_z"].value % (2*np.pi)
]
labels = ["$J_z$", r"$\Omega_z$", r"$\cos\theta_z$"]

fig, axes = plt.subplots(1, 3, figsize=(16, 5), constrained_layout=True)
lims = [(0, 0.15), (0, 0.1), (-1, 1)]
for ax, lim, x1, x2, label in zip(axes, lims, agamas, models, labels):
    ax.hist2d(
        x1,
        (x2 - x1),
        bins=(
            np.linspace(*lim, 128),
            np.linspace(-0.5*lim[1], 0.5*lim[1], 128)
        ),
        cmap="Greys",
        norm=mpl.colors.LogNorm(vmin=5e-1),
    )

    ax.axhline(0, marker="", color="tab:green", ls="--", alpha=0.3)
    ax.set_xlim(*lim)
    # ax.set_ylim(-0.5, 0.5)
    ax.set_ylim(-0.3*lim[1], 0.3*lim[1])

    ax.set_xlabel(f"Agama {label}")
    ax.set_ylabel(f"(empaf - Agama) {label}")
    
axes[2].set_ylim(-0.3, 0.3)

In [None]:
JRs = particle_data["J_R"][:len(model_aaf)].value
# J_R has its own range, so use a separate x-axis limit for all panels
JR_lim = (0, np.nanpercentile(JRs, 99))

labels = ["$J_z$", r"$\Omega_z$", r"$\cos\theta_z$"]

fig, axes = plt.subplots(1, 3, figsize=(16, 5), constrained_layout=True)
lims = [(0, 0.15), (0, 0.1), (-1, 1)]
for ax, lim, x1, x2, label in zip(axes, lims, agamas, models, labels):
    ax.hist2d(
        JRs,
        (x2 - x1),
        bins=(
            np.linspace(*JR_lim, 128),
            np.linspace(-0.5*lim[1], 0.5*lim[1], 128)
        ),
        cmap="Greys",
        norm=mpl.colors.LogNorm(vmin=5e-1),
    )

    ax.axhline(0, marker="", color="tab:green", ls="--", alpha=0.3)
    ax.set_xlim(*JR_lim)
    ax.set_ylim(-0.3*lim[1], 0.3*lim[1])

    ax.set_xlabel("Agama $J_R$")
    ax.set_ylabel(f"(empaf - Agama) {label}")

axes[2].set_ylim(-0.3, 0.3)

There is some bias at both large and small frequencies. At large $\Omega_z$ (small $J_z$), the distribution function we used to generate the particle data is nearly flat near $J_z \sim 0$, so there is little constraining power on the shapes of the density contours. In the opposite regime (large $J_z$, small $\Omega_z$), the method is limited by particle (shot) noise.

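# Fitting a Label Model

Instead of the phase-space density, we can also fit how a stellar label varies over the vertical phase space, using a `LabelOrbitModel` with a label function $f(r_z)$ and the same $e_m$ distortion functions. Here we use the `MG_FE` column of the particle data as the label.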

In [None]:
from empaf.model import LabelOrbitModel
from empaf.plot import plot_data_models_label_residual

In [None]:
label_data = LabelOrbitModel.get_data_im(
    z=particle_data["xyz"][:, 2].decompose(galactic).value,
    vz=particle_data["v_xyz"][:, 2].decompose(galactic).value,
    label=particle_data["MG_FE"],
    bins={"z": np.linspace(-2.5, 2.5, 155), "vz": np.linspace(-0.1, 0.1, 155)},
)

In [None]:
plt.figure(figsize=(6, 5))
plt.pcolormesh(
    label_data["vz"], label_data["z"], label_data["label"], cmap="magma_r", 
)
plt.xlabel("$v_z$")
plt.ylabel("$z$")
cb = plt.colorbar()

In [None]:
n_label_knots = 9
def label_func(rz, label_vals):
    # Knot locations, spaced equally in sqrt(r_z)
    xs = jnp.linspace(0, 1.0, n_label_knots) ** 2
    
    spl = InterpolatedUnivariateSpline(xs, label_vals, k=2)
    return spl(rz)

In [None]:
label_bounds = {
    "label_vals": (
        jnp.full(n_label_knots, -5.0),
        jnp.full(n_label_knots, 5.0)
    )
}

In [None]:
label_model = LabelOrbitModel(
    label_func=label_func,
    e_funcs={2: e2_func, 4: e4_func},
    unit_sys=galactic,
)

In [None]:
label_params0 = label_model.get_params_init(
    vz=label_data["vz"] * u.kpc/u.Myr, z=label_data["z"] * u.kpc, label=label_data['label'],
    label_params0={"label_vals": np.zeros(n_label_knots)}
)

label_params0['e_params'] = params0['e_params']

In [None]:
vlim = dict(
    vmin=0, vmax=0.25
)

fig, axes = plt.subplots(
    1, 2, figsize=(11, 5), sharex=True, sharey=True, constrained_layout=True
)

cs = axes[0].pcolormesh(label_data["vz"], label_data["z"], label_data["label"], **vlim)

cs = axes[1].pcolormesh(
    label_data["vz"],
    label_data["z"],
    label_model.label(z=label_data["z"], vz=label_data["vz"], params=label_params0),
    **vlim
)
fig.colorbar(cs, ax=axes[:2])

axes[0].set_title("data")
axes[1].set_title("initial model")

In [None]:
label_model_bounds = {}

_dens0 = [0.01, 2] * u.Msun / u.pc**3
label_model_bounds["ln_Omega"] = np.log(np.sqrt(_dens0 * 4 * np.pi * G).to_value(1 / u.Myr))
label_model_bounds["z0"] = (-0.05, 0.05)
label_model_bounds["vz0"] = (-0.02, 0.02)

label_model_bounds["e_params"] = e_bounds
label_model_bounds["label_params"] = label_bounds

In [None]:
label_model.objective(params=label_params0, **label_data)

In [None]:
clean_mask = np.isfinite(label_data['label']) & np.isfinite(label_data['label_err'])
clean_label_data = {k: v[clean_mask] for k, v in label_data.items()}

In [None]:
label_res = label_model.optimize(
    params0=label_params0,
    bounds=label_model_bounds,
    **clean_label_data
)
label_res.state

In [None]:
plot_data_models_label_residual(label_data, label_model, label_params0, label_res.params);