In [None]:
import os

import astropy.table as at
import astropy.coordinates as coord
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np

# gala
import gala.coordinates as gc
import gala.dynamics as gd
import gala.potential as gp
import gala.integrate as gi
from gala.units import galactic

import jax
jax.config.update('jax_enable_x64', True)
import jax.numpy as jnp

from empaf import DensityOrbitModel
from empaf.plot import plot_data_models_residual

In [None]:
import sys
# NOTE(review): machine-specific absolute path, presumably the location of a
# local Agama build; it is added to sys.path before `import agama` and will
# need to be changed on any other machine.
sys.path.append('/mnt/home/apricewhelan/downloads/Agama-zone/')
import agama
# Tell Agama to work in Msun / kpc / Myr.
agama.setUnits(mass=u.Msun, length=u.kpc, time=u.Myr)

### Use Agama to sample z,vz:

In [None]:
# gala's built-in Milky Way potential (used further down for a density check).
gala_pot = gp.MilkyWayPotential()

# Four-component Agama potential: a Miyamoto-Nagai disk, two Dehnen spheroids
# and an NFW halo.
# NOTE(review): presumably chosen to mirror gala's MilkyWayPotential - confirm.
agama_pot = agama.Potential(
    {"type": "miyamotonagai", "mass": 6.8e10, "scaleradius": 3.0, "scaleheight": 0.28},
    {"type": "dehnen", "mass": 5.00e9, "scaleradius": 1.0},
    {"type": "dehnen", "mass": 1.71e9, "scaleradius": 0.07},
    {"type": "nfw", "mass": 5.4e11, "scaleradius": 15.62},
)

In [None]:
# Circular velocity and radius whose product sets the central azimuthal action.
vcirc = 229 * u.km/u.s
Rsun = 8.275 * u.kpc

# Central J_phi as a bare number in the `galactic` unit system.
Jphi0 = (vcirc * Rsun).decompose(galactic).value
# Scale parameters of the distribution function df() in each action.
# NOTE(review): the individual factors are undocumented magic numbers - TODO
# record what each one stands for.
dJphi = 0.22 * 8. * 0.04
dJr   = 0.05 * 1. * 0.04
dJz   = 0.04 * 0.5

# Number of phase-space points requested from the Agama sampler.
N     = 50_000_000
def df(J):
    """Unnormalised distribution function evaluated on an array of actions.

    Parameters
    ----------
    J : array
        Actions with the three components along the last axis, ordered
        (J_r, J_z, J_phi); ``J.T`` is unpacked into those three.

    Returns
    -------
    array
        ``exp`` of the sum of a Gaussian term in J_r (width ``dJr``), a
        Gaussian term in J_phi centred on ``Jphi0`` (width ``dJphi``) and an
        exponential term in ``|J_z|`` (scale ``dJz``).  The scale parameters
        are read from the notebook's global namespace.
    """
    act_r, act_z, act_phi = J.T
    radial_term = -0.5*act_r**2/dJr**2
    azimuthal_term = 0.5*((act_phi-Jphi0)/dJphi)**2
    vertical_term = np.abs(act_z)/dJz
    return np.exp(radial_term - azimuthal_term - vertical_term)

# On-disk cache for the sampled phase-space points, so that the expensive
# sampling step only runs when the file is missing.
xv_cache_path = '../test-data/agama-galaxymodel-df.npy'

if os.path.exists(xv_cache_path):
    xv = np.load(xv_cache_path)
else:
    # Create the cache directory *before* sampling: otherwise np.save would
    # raise after the N-point sampling has finished and the result is lost.
    os.makedirs(os.path.dirname(xv_cache_path), exist_ok=True)
    gm = agama.GalaxyModel(agama_pot, df)
    # Only the first element of what sample() returns is kept.
    xv = gm.sample(N)[0]
    np.save(xv_cache_path, xv)

# Work with the first million points for the rest of the notebook.
xv = xv[:1_000_000]

In [None]:
# Actions, angles and frequencies of the sampled points in the Agama potential.
act_finder = agama.ActionFinder(agama_pot)
agama_act, agama_ang, agama_freq = act_finder(xv, angles=True)

# Keep only column 1 of each output (stored under the *_z names) and derive
# the corresponding period from the frequency.
agama_aaf = at.Table()
agama_aaf["J_z"] = agama_act[:, 1]
agama_aaf["theta_z"] = agama_ang[:, 1]
agama_aaf["Omega_z"] = agama_freq[:, 1]
agama_aaf["T_z"] = 2*np.pi / agama_freq[:, 1]

In [None]:
from astropy.constants import G
# Evaluate nu^2 / (4 pi G) for nu = 0.075 / Myr, expressed in Msun / pc^3.
# NOTE(review): presumably the density implied by a vertical frequency nu,
# to compare against the potential's density in the next cell - confirm.
nu = (0.075/u.Myr)
(nu**2 / G).to(u.Msun/u.pc**3) / (4*np.pi)

In [None]:
# Density of the gala Milky Way potential at position (8, 0, 0), in Msun / pc^3.
# NOTE(review): the position carries no units - presumably interpreted in the
# potential's own length unit; confirm.
gala_pot.density([8., 0, 0]).to(u.Msun/u.pc**3)

In [None]:
# w = gd.PhaseSpacePosition.from_w(xv.T, units=galactic)
# w.plot();

In [None]:
# for shift in [0, 3]:
#     fig, axes = plt.subplots(
#         1, 3, figsize=(15, 5), sharex=True, sharey=True, constrained_layout=True
#     )
#     for i, ax in enumerate(axes):
#         ax.hist(xv[:, i + shift], bins=151)

In [None]:
# plt.hist2d(xv[:, 5], xv[:, 2], bins=(np.linspace(-0.08, 0.08, 101), np.linspace(-2, 2, 101)));
# plt.xlabel('$v_z$')
# plt.ylabel('$z$')

In [None]:
# tbl = at.Table.read('../scripts/zvz-random.fits')
# tbl = tbl.filled()
# tbl = tbl[np.abs(tbl['vz']) < 0.08]

# tbl['z'].unit = u.kpc
# tbl['vz'].unit = u.kpc/u.Myr
# tbl = at.QTable(tbl)

# Table of vertical positions and velocities with units attached, taken from
# columns 2 and 5 of the sampled phase-space array.
tbl = at.QTable({
    'z': xv[:, 2] * u.kpc,
    'vz': xv[:, 5] * u.kpc/u.Myr,
})

# Histogram edges along v_z (x axis) and z (y axis).
vz_edges = np.linspace(-0.1, 0.1, 151)
z_edges = np.linspace(-2.5, 2.5, 151)
bins = (vz_edges, z_edges)

# Log-scaled number counts in the (v_z, z) plane.
plt.hist2d(tbl['vz'].value, tbl['z'].value, bins=bins, norm=mpl.colors.LogNorm())
plt.xlim(vz_edges.min(), vz_edges.max())
plt.ylim(z_edges.min(), z_edges.max())
plt.xlabel("$v_z$")
plt.ylabel("$z$")

In [None]:
# Model template.  All knot grids are squares of evenly spaced values, so the
# knots are denser near zero.
# Orders 6 and 8 (5 knots each on the same range, sign -1.0) were tried
# previously and are currently disabled.
init_model = DensityOrbitModel(
    ln_dens_knots=jnp.linspace(0, np.sqrt(1.5), 21) ** 2,
    e_knots={
        order: jnp.linspace(0, np.sqrt(3), n_knots) ** 2
        for order, n_knots in ((2, 15), (4, 9))
    },
    e_signs={2: 1.0, 4: -1.0},
    unit_sys=galactic,
)

In [None]:
# Model returned by get_params_init() for the sampled z, v_z data; it is used
# below as the "initial model".
model0 = init_model.get_params_init(tbl['z'], tbl['vz'])
# Separate copy that gets optimised, presumably so model0 stays available as
# the initial reference.
model = model0.copy()
# Starting parameter values handed to the objective and the optimiser.
params0 = model0.get_params()
# Display the initial state.
model0.state

In [None]:
# Image bin edges: a fixed grid in z, and the same grid multiplied by the
# initial model's `nu` for v_z.
im_z_edges = np.linspace(-2, 2, 151)
im_bins = {'z': im_z_edges, 'vz': im_z_edges * model0.state['nu']}

# Binned representation of the data used for fitting and plotting.
data_H = model0.get_data_im(tbl['z'], tbl['vz'], im_bins)

In [None]:
# Quick check of the action/angle/frequency computation on the first 10 rows.
# The slice is taken before the dtype conversion so that only 10 values are
# converted instead of the whole column; the result is the same.
# NOTE(review): the third positional argument (101) is undocumented here -
# confirm its meaning against get_aaf's signature.
model0.get_aaf(
    tbl["z"][:10].astype(np.float64),
    tbl["vz"][:10].astype(np.float64),
    101,
)

In [None]:
# Lower and upper parameter bounds for the optimiser.  Array-valued entries
# are filled to match the shape of the corresponding entry in model0.state.
# (A lower e_vals bound of -0.4 was tried previously; 0 is used now.)
bounds_l = dict(
    vz0=-0.1,
    z0=-0.5,
    ln_dens_vals=np.full_like(model0.state["ln_dens_vals"], -5.0),
    ln_nu=-5.0,
    e_vals={
        m: np.full_like(model0.state["e_vals"][m], 0.) for m in model0.e_knots
    },
)

bounds_r = dict(
    vz0=0.1,
    z0=0.5,
    ln_dens_vals=np.full_like(model0.state["ln_dens_vals"], 15.0),
    ln_nu=0.0,
    e_vals={
        m: np.full_like(model0.state["e_vals"][m], 0.4) for m in model0.e_knots
    },
)

In [None]:
# Shared colour normalisation for both panels (log scale, fixed limits).
vlim = {
    "norm": mpl.colors.LogNorm(vmax=3e4, vmin=1e-1),
    "shading": "auto",
}

fig, axes = plt.subplots(
    nrows=1, ncols=2, figsize=(11, 5), sharex=True, sharey=True, constrained_layout=True
)

# Left: binned data.  Right: initial model evaluated on the same grid.
model0_H = np.exp(model0.ln_density(z=data_H["z"], vz=data_H["vz"]))
for ax, image in zip(axes, (data_H["H"], model0_H)):
    cs = ax.pcolormesh(data_H["vz"], data_H["z"], image, **vlim)
fig.colorbar(cs, ax=axes[:2])

axes[0].set_title("data")
axes[1].set_title("initial model")

In [None]:
# Evaluate the objective at the initial parameters on the binned data.
model.objective(params0, **data_H)

In [None]:
# Bounded optimisation starting from the initial parameters; the binned data
# are passed through as keyword arguments.
optimizer_options = {"maxls": 1000, "disp": False}
res = model.optimize(
    params0=params0,
    bounds=(bounds_l, bounds_r),
    jaxopt_kwargs={"options": optimizer_options},
    **data_H,
)
# Display the state attached to the optimisation result.
res.state

In [None]:
# Parameter values stored on the optimisation result.
res.params

In [None]:
# Plot the binned data together with the initial model (model0) and `model`.
# NOTE(review): this relies on model.optimize() having updated `model` in
# place; if it does not, `model` is still an unoptimised copy of model0 and
# the fitted values live only in `res` - confirm.
fig, axes = plot_data_models_residual(data_H, model0, model)

In [None]:
# Action/angle/frequency values from `model` for a single point with z = 0
# and v_z = 0.
model.get_aaf([0.]*u.kpc, [0]*u.km/u.s, 101)

In [None]:
# Evaluate the functions of `model` on a grid of 51 values between 0 and 4.
plot_rz = np.linspace(0, 4, 51)
es = model.get_es(plot_rz)
dens = np.exp(model.get_ln_dens(plot_rz))
# Actions and frequencies for points placed at z = plot_rz with zero velocity;
# sqrt(J_z) from these is used as the common x axis of all three panels.
tmp_aaf = model.get_aaf(plot_rz * u.kpc, np.zeros_like(plot_rz) * u.km / u.s, 101)
sqrtJz = np.sqrt(tmp_aaf["J_z"].value)

fig, axes = plt.subplots(1, 3, figsize=(16, 5), constrained_layout=True, sharex=True)

# Panel 1: model frequency curve drawn on top of the Agama values.
ax = axes[0]
ax.plot(sqrtJz, tmp_aaf["Omega_z"].value, zorder=100)
ax.plot(
    np.sqrt(agama_aaf["J_z"].value),
    agama_aaf["Omega_z"].value,
    ls="none",
    marker="o",
    mew=0,
    alpha=0.2,
    ms=1.0,
    zorder=1,
)
ax.set_ylabel(r"$\Omega_z$ " + f"[{tmp_aaf['Omega_z'].unit:latex_inline}]")
# Horizontal reference line at the model's `nu`.
ax.axhline(model.state['nu'], color='tab:green', ls='--', alpha=0.4)

# Panel 2: one curve per entry of `es`.
for n in es:
    # The subscript is wrapped in braces so that a multi-digit order is
    # subscripted as a whole; without them only its first digit would be.
    axes[1].plot(sqrtJz, es[n], label=f"$e_{{{n}}}$")
axes[1].set_ylabel("$e_n$")
axes[1].legend()

# Panel 3: exp of the model's ln-density function, on a log axis.
axes[2].plot(sqrtJz, dens)
axes[2].set_yscale("log")
axes[2].set_ylabel("density")

for ax in axes:
    ax.set_xlabel(r"$\sqrt{J_z}'$")

In [None]:
# One curve of model.get_rz over a full turn of the angle grid for each of 16
# evenly spaced first-argument values between 0 and 3.
thp = np.linspace(0, 2*np.pi, 256)
for rzp in np.linspace(0, 3.0, 16):
    rz_curve = model.get_rz(rzp, thp)
    plt.plot(thp, rz_curve, marker='')

In [None]:
# thp = np.linspace(0, 2*np.pi, 256)
# for rz in np.linspace(0, 1.0, 16):
#     plt.plot(thp, model.get_rz_prime(rz, thp), marker='')

In [None]:
# Action/angle/frequency table from `model` for every row of tbl.
model_aaf = model.get_aaf(
    tbl["z"].astype(np.float64),
    tbl["vz"].astype(np.float64),
    101,
)
# Preview the first three rows.
model_aaf[:3]

In [None]:
# Quantities to compare, panel by panel: Agama values on x, empaf model values
# on y, each with a shared axis range.
labels = ["$J_z$", r"$\Omega_z$", r"$\cos\theta_z$"]
agamas = [agama_aaf["J_z"], agama_aaf["Omega_z"], np.cos(agama_aaf["theta_z"])]
models = [
    model_aaf["J_z"].value,
    model_aaf["Omega_z"].value,
    np.cos(model_aaf["theta_z"]),
]
lims = [(0, 0.15), (0, 0.1), (-1, 1)]

fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 5), constrained_layout=True)
for i, ax in enumerate(axes):
    lim = lims[i]
    label = labels[i]

    # Log-scaled 2D histogram on a square grid spanning the panel's range.
    ax.hist2d(
        agamas[i],
        models[i],
        bins=np.linspace(*lim, 128),
        cmap="Greys",
        norm=mpl.colors.LogNorm(vmin=1e1),
    )

    # Dashed one-to-one line for reference.
    xx = np.linspace(*lim, 10)
    ax.plot(xx, xx, marker="", color="tab:green", ls="--", alpha=0.3)

    ax.set_xlim(*lim)
    ax.set_ylim(*lim)
    ax.set_xlabel(f"agama {label}")
    ax.set_ylabel(f"empaf {label}")

In [None]:
# Second Agama potential: same components as agama_pot, but with the disk mass
# multiplied by 1.1 and the disk scale height multiplied by 0.9.
agama_pot2 = agama.Potential(
    {
        "type": "miyamotonagai",
        "mass": 6.8e10 * 1.1,
        "scaleradius": 3.0,
        "scaleheight": 0.28 * 0.9,
    },
    {"type": "dehnen", "mass": 5.00e9, "scaleradius": 1.0},
    {"type": "dehnen", "mass": 1.71e9, "scaleradius": 0.07},
    {"type": "nfw", "mass": 5.4e11, "scaleradius": 15.62},
)

# Actions, angles and frequencies of the same sampled points in that potential.
act_finder2 = agama.ActionFinder(agama_pot2)
agama_act2, agama_ang2, agama_freq2 = act_finder2(xv, angles=True)

# Keep column 1 of each output (stored under the *_z names) plus the period.
agama_aaf2 = at.Table()
agama_aaf2["J_z"] = agama_act2[:, 1]
agama_aaf2["theta_z"] = agama_ang2[:, 1]
agama_aaf2["Omega_z"] = agama_freq2[:, 1]
agama_aaf2["T_z"] = 2*np.pi / agama_freq2[:, 1]

In [None]:
# Same three-panel comparison, now between the two Agama potentials: values
# from agama_pot on x, values from agama_pot2 on y.
labels = ["$J_z$", r"$\Omega_z$", r"$\cos\theta_z$"]
lims = [(0, 0.15), (0, 0.1), (-1, 1)]
agamas = [agama_aaf['J_z'], agama_aaf['Omega_z'], np.cos(agama_aaf['theta_z'])]
models = [agama_aaf2['J_z'], agama_aaf2['Omega_z'], np.cos(agama_aaf2['theta_z'])]

fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 5), constrained_layout=True)
for ax, (lim, x1, x2, label) in zip(axes, zip(lims, agamas, models, labels)):
    # Log-scaled 2D histogram on a square grid spanning the panel's range.
    ax.hist2d(
        x1,
        x2,
        bins=np.linspace(*lim, 128),
        cmap='Greys',
        norm=mpl.colors.LogNorm(vmin=1e1),
    )

    # Dashed one-to-one line for reference.
    xx = np.linspace(*lim, 10)
    ax.plot(xx, xx, marker='', color='tab:green', ls='--', alpha=0.3)

    ax.set_xlim(*lim)
    ax.set_ylim(*lim)
    ax.set_xlabel(f"agama {label}")
    ax.set_ylabel(f"agama2 {label}")