In [None]:
import copy
import os

from astropy.constants import G
import astropy.table as at
import astropy.coordinates as coord
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np

# gala
import gala.coordinates as gc
import gala.dynamics as gd
import gala.potential as gp
import gala.integrate as gi
from gala.units import galactic

import jax
jax.config.update('jax_enable_x64', True)
import jax.numpy as jnp

from empaf import DensityOrbitModel
from empaf.plot import plot_data_models_residual

In [None]:
import os
import sys

# Agama is imported from a local source checkout. Its location can be set with
# the AGAMA_PATH environment variable, e.g.
#   AGAMA_PATH=/mnt/home/apricewhelan/downloads/Agama-zone/
# and otherwise falls back to the path this notebook was originally run with.
AGAMA_PATH = os.environ.get(
    "AGAMA_PATH", "/Users/apricewhelan/projects/others/Agama-zone/"
)
# Guard the append so re-running this cell does not keep growing sys.path.
if AGAMA_PATH not in sys.path:
    sys.path.append(AGAMA_PATH)
import agama

# Work in Msun / kpc / Myr so Agama inputs and outputs share the kpc, kpc/Myr
# units that the phase-space table in this notebook is built with.
agama.setUnits(mass=u.Msun, length=u.kpc, time=u.Myr)

In [None]:
# Reference Milky Way model in gala.
gala_pot = gp.MilkyWayPotential()

# Agama potential assembled from its components. The parameter values appear
# to mirror gala's MilkyWayPotential above.
# NOTE(review): confirm they match the installed gala version.
agama_components = [
    # Miyamoto-Nagai disk
    dict(type='miyamotonagai', mass=6.8e10, scaleradius=3.0, scaleheight=0.28),
    # presumably the bulge
    dict(type='dehnen', mass=5.00e9, scaleradius=1.0),
    # presumably the nucleus
    dict(type='dehnen', mass=1.71e9, scaleradius=0.07),
    # NFW halo
    dict(type='nfw', mass=5.4e11, scaleradius=15.62),
]
agama_pot = agama.Potential(*agama_components)

In [None]:
# Phase-space sample. Columns 2 and 5 are used as z and v_z further down, so
# the array is presumably (N, 6) = (x, y, z, vx, vy, vz). TODO confirm.
XV_PATH = '../test-data/agama-galaxymodel-df-small.npy'
xv = np.load(XV_PATH)

In [None]:
# Actions, angles, and frequencies from Agama in the true potential.
# Column 1 of each output is taken as the vertical (z) component.
# NOTE(review): relies on Agama's ordering being (R, z, phi); confirm.
act_finder = agama.ActionFinder(agama_pot)
agama_act, agama_ang, agama_freq = act_finder(xv, angles=True)

iz = 1
Omega_z_agama = agama_freq[:, iz]
agama_aaf = at.Table({
    "J_z": agama_act[:, iz],
    "theta_z": agama_ang[:, iz],
    "Omega_z": Omega_z_agama,
    # vertical period
    "T_z": 2 * np.pi / Omega_z_agama,
})

In [None]:
# Vertical position and velocity of each star, with units attached. The
# velocity unit is kpc/Myr to match the agama.setUnits(...) configuration.
tbl = at.QTable()
tbl['z'] = xv[:, 2] * u.kpc
tbl['vz'] = xv[:, 5] * u.kpc/u.Myr

# Log-scaled 2D histogram of the raw data in the (v_z, z) plane.
vz_edges = np.linspace(-0.1, 0.1, 151)
z_edges = np.linspace(-2.5, 2.5, 151)
bins = (vz_edges, z_edges)

plt.hist2d(tbl['vz'].value, tbl['z'].value, bins=bins, norm=mpl.colors.LogNorm())
plt.xlim(vz_edges[0], vz_edges[-1])
plt.ylim(z_edges[0], z_edges[-1])
plt.xlabel("$v_z$")
plt.ylabel("$z$")

In [None]:
# Sign of each e_m component, keyed by m. Only m = 2 and m = 4 are enabled;
# m = 6 and m = 8 (both with sign -1.0) were tried before.
e_signs = {2: 1.0, 4: -1.0}
init_model = DensityOrbitModel(e_signs=e_signs, unit_sys=galactic)

In [None]:
# Initial guess for the model parameters. It has the same nested layout
# (e_params keyed by m, ln_dens_params, ln_Omega, z0, vz0) as the bounds
# dictionaries passed to the optimizer.
params0 = {
    "e_params": {
        2: {"f1": 0.1, "alpha": 0.33, "x0": 3.0},
        4: {"f1": 0.02, "alpha": 0.45, "x0": 3.0},
    },
    "ln_dens_params": {"f0": 8.0, "f1": -10.0, "alpha": 0.6, "x0": 5.0},
    "ln_Omega": np.log(0.06),
    "z0": 0.0,
    "vz0": 0.0,
}

In [None]:
# model0 = init_model.get_params_init(tbl['z'], tbl['vz'])
# model0 = init_model.copy()
# model0.set_state(params0)

# model0.state = params0.copy()
# model0.state['Omega'] = np.exp(model0.state['ln_Omega'])

In [None]:
# Image grid for binning the data. The v_z edges are the z edges scaled by the
# initial-guess frequency Omega = exp(ln_Omega).
z_edges_im = np.linspace(-2, 2, 211)
im_bins = {
    'z': z_edges_im,
    'vz': z_edges_im * np.exp(params0['ln_Omega']),
}
data_H = init_model.get_data_im(tbl['z'], tbl['vz'], im_bins)

In [None]:
# One LogNorm is shared by both panels so data and model use the same colors.
vlim = dict(norm=mpl.colors.LogNorm(vmax=3e4, vmin=1e-1), shading="auto")

fig, axes = plt.subplots(
    1, 2, figsize=(11, 5), sharex=True, sharey=True, constrained_layout=True
)

model0_im = np.exp(
    init_model.ln_density(z=data_H["z"], vz=data_H["vz"], params=params0)
)
for ax, image in zip(axes, [data_H["H"], model0_im]):
    cs = ax.pcolormesh(data_H["vz"], data_H["z"], image, **vlim)

# Single colorbar for both panels (valid because the norm is shared).
fig.colorbar(cs, ax=axes[:2])

axes[0].set_title("data")
axes[1].set_title("initial model")

In [None]:
# Quick check of the empaf actions/angles/frequencies for the first few rows
# under the initial parameters. Slice before casting so only those rows are
# converted, not the whole column.
n_check = 3
init_model.get_aaf(
    tbl["z"][:n_check].astype(np.float64),
    tbl["vz"][:n_check].astype(np.float64),
    params0,
    101,  # NOTE(review): presumably a grid/resolution size; confirm in get_aaf
)

In [None]:
def _ln_omega_for_density(rho_msun_pc3):
    """Return ln(Omega) with Omega = sqrt(4 pi G rho) in 1/Myr, for rho in Msun/pc^3."""
    return np.log(
        np.sqrt(rho_msun_pc3 * u.Msun / u.pc**3 * 4 * np.pi * G).to_value(1 / u.Myr)
    )


# Lower / upper bounds on every parameter. These have the same nested layout
# as params0. The ln_Omega bounds come from midplane densities of 0.01 and 2
# Msun/pc^3.
bounds_l = {
    "ln_Omega": _ln_omega_for_density(0.01),
    "z0": -0.05,
    "vz0": -0.02,
    "e_params": {
        2: {"f1": 0.0, "alpha": 0.25, "x0": 2.0},
        4: {"f1": 0.0, "alpha": 0.25, "x0": 2.0},
    },
    "ln_dens_params": {"f0": 0, "f1": -40.0, "alpha": 0.3, "x0": 3.0},
}

bounds_r = {
    "ln_Omega": _ln_omega_for_density(2),
    # The z0 / vz0 bounds are symmetric about zero.
    "z0": abs(bounds_l["z0"]),
    "vz0": abs(bounds_l["vz0"]),
    "e_params": {
        2: {"f1": 0.7, "alpha": 0.5, "x0": 30.0},
        4: {"f1": 0.3, "alpha": 0.5, "x0": 30.0},
    },
    "ln_dens_params": {"f0": 12.0, "f1": -5.0, "alpha": 0.7, "x0": 30.0},
}

In [None]:
# Model density evaluated at the lower and upper parameter bounds. This is
# presumably a visual sanity check that the bounds give sensible images.
fig, axes = plt.subplots(1, 2, figsize=(13, 5), sharex=True, sharey=True)

panels = zip(axes, [bounds_l, bounds_r], ["lower bounds", "upper bounds"])
for ax, p, title in panels:
    cs = ax.pcolormesh(
        data_H["vz"],
        data_H["z"],
        np.exp(init_model.ln_density(z=data_H["z"], vz=data_H["vz"], params=p)),
        norm=mpl.colors.LogNorm(),
    )
    # Each panel gets its own colorbar: the two images use independent norms.
    fig.colorbar(cs, ax=ax)
    ax.set_title(title)
    ax.set_xlabel("$v_z$")
axes[0].set_ylabel("$z$")

In [None]:
# Objective value with every parameter at its lower bound.
objective_at_lower = init_model.objective(bounds_l, **data_H)
objective_at_lower

In [None]:
# Objective value with every parameter at its upper bound.
objective_at_upper = init_model.objective(bounds_r, **data_H)
objective_at_upper

In [None]:
# Objective value at the initial guess.
objective_at_init = init_model.objective(params0, **data_H)
objective_at_init

In [None]:
# Fit the model to the binned data, starting from params0 and constrained to
# the (bounds_l, bounds_r) box. The options are forwarded through
# jaxopt_kwargs (maxls presumably caps the line-search steps; confirm).
solver_options = dict(maxls=1000, disp=False)
res = init_model.optimize(
    params0=params0,
    bounds=(bounds_l, bounds_r),
    jaxopt_kwargs=dict(options=solver_options),
    **data_H,
)
res.state

In [None]:
# Compare the data with the model at the initial guess (params0) and at the
# fitted parameters (res.params). Residual panels are expected from the
# helper's name; confirm in empaf.plot.
fig, axes = plot_data_models_residual(data_H, init_model, params0, res.params)

In [None]:
# vals, treedef = jax.tree_util.tree_flatten(res.params)
# vals = [float(x) for x in vals]
# pars = jax.tree_util.tree_unflatten(treedef, vals)
# fig, axes = plot_data_models_residual(data_H, init_model, params0, pars)

In [None]:
# empaf actions/angles/frequencies at z = 0, v_z = 0 under the fitted params.
z_mid = np.zeros(1) * u.kpc
vz_mid = np.zeros(1) * u.km / u.s
init_model.get_aaf(z_mid, vz_mid, res.params, 101)

In [None]:
# Fitted e_n components evaluated on a grid of the scaled coordinate r_z'
# (the same kind of argument get_es receives as plot_rzp elsewhere).
plot_rz = np.linspace(0, 1, 301)
es = init_model.get_es(plot_rz, res.params['e_params'])

fig, ax = plt.subplots()
for n, ee in es.items():
    ax.plot(plot_rz, ee, marker='', label=f"$e_{n}$")
ax.set_xlabel("$r_z'$")
ax.set_ylabel("$e_n$")
ax.legend();

In [None]:
# Diagnostics of the fitted model along the v_z = 0 line: Omega_z, the e_n
# components, and the density, each plotted against sqrt(J_z).
# NOTE(review): `pars` is the same object as res.params (not a copy).
# Deep-copy it before overriding entries, e.g. ln_Omega or
# e_params[2]['alpha'], or the fit result is mutated as well.
pars = res.params

Omega_fit = np.exp(pars["ln_Omega"])
sqrtOm = np.sqrt(Omega_fit)
plot_rzp = np.linspace(0, 3, 101) * sqrtOm
es = init_model.get_es(plot_rzp, pars["e_params"])
dens = np.exp(init_model.get_ln_dens(plot_rzp, pars))

# Map r_z' back to z = r_z' / sqrt(Omega) at v_z = 0 to get the actions and
# frequencies at the same points.
z_grid = plot_rzp / sqrtOm * u.kpc
vz_grid = np.zeros_like(plot_rzp) * u.km / u.s
tmp_aaf = init_model.get_aaf(z_grid, vz_grid, pars, 101)
sqrtJz = np.sqrt(tmp_aaf["J_z"].value)

fig, axes = plt.subplots(1, 3, figsize=(16, 5), constrained_layout=True, sharex=True)
ax_freq, ax_e, ax_dens = axes

# Panel 1: model Omega_z curve drawn over the Agama values for the sample,
# with the fitted Omega marked as a dashed line.
ax_freq.plot(sqrtJz, tmp_aaf["Omega_z"].value, zorder=100)
ax_freq.plot(
    np.sqrt(agama_aaf["J_z"].value),
    agama_aaf["Omega_z"].value,
    ls="none", marker="o", mew=0, alpha=0.2, ms=1.0, zorder=1,
)
ax_freq.set_ylabel(r"$\Omega_z$ " + f"[{tmp_aaf['Omega_z'].unit:latex_inline}]")
ax_freq.axhline(Omega_fit, color="tab:green", ls="--", alpha=0.4)

# Panel 2: e_n components; panel 3: density on a log scale.
for n in es:
    ax_e.plot(sqrtJz, es[n], label=f"$e_{n}$")
ax_dens.plot(sqrtJz, dens)
ax_dens.set_yscale("log")

ax_e.legend()

for ax in axes:
    ax.set_xlabel(r"$\sqrt{J_z}'$")

In [None]:
# model_H = np.exp(init_model.ln_density(z=data_H["z"], vz=data_H["vz"], params=res.params))

In [None]:
# H, xe = np.histogram(np.sqrt(agama_aaf['J_z'].value), bins=np.linspace(0, 0.5, 101))
# xc = 0.5 * (xe[:-1] + xe[1:])

# huh = np.diff(data_H['vz'][0])[0] * np.diff(data_H['z'][:, 0])[0]

# plt.plot(xc, H / np.diff(xe) * huh / np.sqrt(xc))
# plt.yscale('log')

In [None]:
# thp = np.linspace(0, 2*np.pi, 256)
# for rzp in np.linspace(0, 3.0, 16):
#     plt.plot(thp, model.get_rz(rzp, thp), marker='')

### Compute actions, angles, and frequencies (AAF)

In [None]:
# Number of rows for which to compute empaf actions/angles/frequencies
# (1_000_000 was also used previously). Slicing before casting avoids
# converting the full columns just to keep the first N_AAF rows.
N_AAF = 100_000
model_aaf = init_model.get_aaf(
    tbl["z"][:N_AAF].astype(np.float64),
    tbl["vz"][:N_AAF].astype(np.float64),
    res.params,
    101,
)
model_aaf[:3]

In [None]:
# empaf (y) vs. Agama (x) actions, frequencies, and angles for the same rows.
# Agama is truncated to the len(model_aaf) rows that empaf was run on.
n_compare = len(model_aaf)
agamas = [
    agama_aaf["J_z"][:n_compare],
    agama_aaf["Omega_z"][:n_compare],
    np.cos(agama_aaf["theta_z"][:n_compare]),
]
models = [
    model_aaf["J_z"].value,
    model_aaf["Omega_z"].value,
    np.cos(model_aaf["theta_z"]),
]
labels = ["$J_z$", r"$\Omega_z$", r"$\cos\theta_z$"]
lims = [(0, 0.15), (0, 0.1), (-1, 1)]

fig, axes = plt.subplots(1, 3, figsize=(16, 5), constrained_layout=True)
for ax, lim, x1, x2, label in zip(axes, lims, agamas, models, labels):
    lo, hi = lim
    ax.hist2d(
        x1, x2,
        bins=np.linspace(lo, hi, 128),
        cmap="Greys",
        norm=mpl.colors.LogNorm(vmin=1e0),
    )

    # One-to-one reference line.
    xx = np.linspace(lo, hi, 10)
    ax.plot(xx, xx, marker="", color="tab:green", ls="--", alpha=0.3)
    ax.set_xlim(lo, hi)
    ax.set_ylim(lo, hi)

    ax.set_xlabel(f"Agama {label}")
    ax.set_ylabel(f"empaf {label}")

---

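Compare the Agama actions, angles, and frequencies in a perturbed potential (disk mass ×1.1, disk scale height ×0.9, i.e. ~10% different) to those in the true potential: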

In [None]:
# Perturbed Agama potential. The components are the same as agama_pot, except
# the disk mass is scaled by 1.1 and the disk scale height by 0.9.
DISK_MASS_FACTOR = 1.1
DISK_HEIGHT_FACTOR = 0.9
agama_pot2 = agama.Potential(
    dict(
        type='miyamotonagai',
        mass=6.8e10 * DISK_MASS_FACTOR,
        scaleradius=3.0,
        scaleheight=0.28 * DISK_HEIGHT_FACTOR,
    ),
    dict(type='dehnen', mass=5.00e9, scaleradius=1.0),
    dict(type='dehnen', mass=1.71e9, scaleradius=0.07),
    dict(type='nfw', mass=5.4e11, scaleradius=15.62),
)

# Same vertical-component extraction as for the true potential (column 1).
act_finder2 = agama.ActionFinder(agama_pot2)
agama_act2, agama_ang2, agama_freq2 = act_finder2(xv, angles=True)
Omega_z2 = agama_freq2[:, 1]
agama_aaf2 = at.Table({
    "J_z": agama_act2[:, 1],
    "theta_z": agama_ang2[:, 1],
    "Omega_z": Omega_z2,
    "T_z": 2 * np.pi / Omega_z2,
})

In [None]:
# Agama results in the true potential (x) vs. the perturbed potential (y).
# This shows how far the perturbation moves actions, frequencies, and angles.
agamas = [
    agama_aaf['J_z'],
    agama_aaf['Omega_z'],
    np.cos(agama_aaf['theta_z']),
]
models = [
    agama_aaf2['J_z'],
    agama_aaf2['Omega_z'],
    np.cos(agama_aaf2['theta_z']),
]
labels = ["$J_z$", r"$\Omega_z$", r"$\cos\theta_z$"]
lims = [(0, 0.15), (0, 0.1), (-1, 1)]

fig, axes = plt.subplots(1, 3, figsize=(16, 5), constrained_layout=True)
for ax, lim, x1, x2, label in zip(axes, lims, agamas, models, labels):
    lo, hi = lim
    ax.hist2d(
        x1, x2,
        bins=np.linspace(lo, hi, 128),
        cmap='Greys',
        norm=mpl.colors.LogNorm(vmin=1e1),
    )

    # One-to-one reference line.
    xx = np.linspace(lo, hi, 10)
    ax.plot(xx, xx, marker='', color='tab:green', ls='--', alpha=0.3)
    ax.set_xlim(lo, hi)
    ax.set_ylim(lo, hi)

    ax.set_xlabel(f"agama {label}")
    ax.set_ylabel(f"agama2 {label}")