In [None]:
import astropy.coordinates as coord
import astropy.table as at
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt

%matplotlib inline
import numpy as np
import scipy.interpolate as sci

from gala.units import galactic

import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

from empaf import LabelOrbitModel
from empaf.plot import plot_data_models_label_residual

In [None]:
# Load the test sample into a QTable: vertical position z, vertical velocity vz,
# and a per-star label (file name suggests an Mg abundance — TODO confirm).
tbl = at.QTable(
    {
        "z": np.load("../test-data/ztest.npy") * u.kpc,
        "vz": np.load("../test-data/vztest.npy") * u.kpc / u.Gyr,
        "label": np.load("../test-data/mgtest.npy"),
    }
)

In [None]:
# Unfitted model template. All knot arrays are squares of a uniform grid, so
# they are uniform in sqrt(knot) and therefore denser near zero:
#   - label_knots: 5 knots on [0, 1.5]
#   - e_knots: keyed by order m (presumably the Fourier order of the distortion
#     terms — TODO confirm); 11 knots for m=2 and 5 knots for m=4, both on [0, 3]
# e_signs={4: -1.0} presumably fixes the sign of the m=4 term.
init_model = LabelOrbitModel(
    label_knots=jnp.linspace(0, np.sqrt(1.5), 5) ** 2,
    e_knots={
        m: jnp.linspace(0, np.sqrt(3), n_knots) ** 2
        for m, n_knots in ((2, 11), (4, 5))
    },
    e_signs={4: -1.0},
    units=galactic,
)

In [None]:
# 91 grid points per axis (presumably bin edges, i.e. 90 x 90 pixels). The vz
# range is presumably in the model's velocity unit (kpc/Myr for `galactic`),
# not the kpc/Gyr the table column carries — TODO confirm in get_data_im.
im_bins = dict(
    z=np.linspace(-2.0, 2.0, 91),
    vz=np.linspace(-0.075, 0.075, 91),
)
data_H = init_model.get_data_im(tbl["z"], tbl["vz"], tbl["label"], im_bins)

In [None]:
# Binned data: the label statistic per (vz, z) pixel (left) and its uncertainty
# (right). Axis units assume the same convention as the get_params_init call
# (z in kpc, vz in kpc/Myr) — TODO confirm against get_data_im.
fig, axes = plt.subplots(
    1, 2, figsize=(13, 5), sharex=True, sharey=True, constrained_layout=True
)
panels = (
    ("label_stat", "binned label statistic"),
    ("label_stat_err", "binned label uncertainty"),
)
for ax, (key, title) in zip(axes, panels):
    cs = ax.pcolormesh(data_H["vz"], data_H["z"], data_H[key])
    # Faint guides through the phase-space origin on both panels.
    ax.axvline(0.0, alpha=0.1)
    ax.axhline(0.0, alpha=0.1)
    fig.colorbar(cs, ax=ax)
    ax.set_title(title)
    ax.set_xlabel(r"$v_z$ [kpc/Myr]")
axes[0].set_ylabel(r"$z$ [kpc]");

In [None]:
# Initial-guess model estimated from the binned data. The pixel coordinates are
# tagged with kpc and kpc/Myr units here. `model` is an independent copy so that
# `model0` stays available for the comparison plot later (this assumes
# `optimize` updates `model` in place — TODO confirm).
model0 = init_model.get_params_init(
    data_H["z"] * u.kpc,
    data_H["vz"] * u.kpc / u.Myr,
    data_H["label_stat"],
)
model = model0.copy()

In [None]:
def _make_bounds(state, e_orders, *, vz0, z0, label_val, ln_Omega, e_val):
    """Build one side (lower or upper) of the box bounds for the fit.

    Scalars are used as given. The array-valued entries are filled with a
    constant via ``np.full_like``, so they match the shapes of
    ``state["label_vals"]`` and ``state["e_vals"][m]`` for each ``m`` in
    ``e_orders``.
    """
    return {
        "vz0": vz0,
        "z0": z0,
        "label_vals": np.full_like(state["label_vals"], label_val),
        "ln_Omega": ln_Omega,
        "e_vals": {m: np.full_like(state["e_vals"][m], e_val) for m in e_orders},
    }


# Lower / upper bounds whose keys mirror model0.state (the optimizer presumably
# requires identical structure — TODO confirm in LabelOrbitModel.optimize).
bounds_l = _make_bounds(
    model0.state, model0.e_knots,
    vz0=-0.1, z0=-0.5, label_val=-5.0, ln_Omega=-5.0, e_val=0.0,
)
bounds_r = _make_bounds(
    model0.state, model0.e_knots,
    vz0=0.1, z0=0.5, label_val=5.0, ln_Omega=0.0, e_val=0.3,
)

In [None]:
# Keep only pixels with a finite label statistic and a finite, strictly
# positive uncertainty. `label_stat_err > 0` alone already rejects NaN and
# non-positive errors but lets +inf through. Presumably the errors act as
# weights in the fit (TODO confirm), where an infinite error is not meaningful.
clean_mask = (
    np.isfinite(data_H["label_stat"])
    & np.isfinite(data_H["label_stat_err"])
    & (data_H["label_stat_err"] > 0)
)
assert clean_mask.any(), "no usable pixels in the binned data"
clean_data = {k: v[clean_mask] for k, v in data_H.items()}

In [None]:
# Fit the model to the cleaned pixels within the box bounds. `jaxopt_kwargs` is
# presumably forwarded to a jaxopt/scipy optimizer (`maxls` looks like the
# L-BFGS-B line-search cap — TODO confirm). The cell displays `res.state`.
res = model.optimize(
    **clean_data,
    jaxopt_kwargs={"options": {"maxls": 1000, "disp": False}},
    bounds=(bounds_l, bounds_r),
)
res.state

In [None]:
# Parameter values returned by the fit (presumably the optimizer's OptStep
# `params` — TODO confirm).
res.params

In [None]:
# Compare the binned data with the initial model (model0) and `model`, which is
# assumed to be updated in place by `optimize` — TODO confirm. Per the function
# name, residuals are shown too. The trailing `;` suppresses the return-value repr.
plot_data_models_label_residual(data_H, model0, model);

In [None]:
# Model label evaluated at (0.0, 0.0) — presumably (z, vz) at the phase-space
# origin; TODO confirm argument order.
model.label(0.0, 0.0)

In [None]:
# Model profiles for r_z in [0, 1], each panel plotted against sqrt(J_z):
#   - the vertical frequency
#   - the distortion coefficients e_m
#   - the label
plot_rz = np.linspace(0, 1, 51)
es = model.get_es(plot_rz)
model_label = model.get_label(plot_rz)
# Evaluated at z = plot_rz [kpc] and vz = 0. The trailing 101 is passed through
# unchanged (presumably a grid size — TODO confirm).
tmp_aaf = model.get_aaf(plot_rz * u.kpc, np.zeros_like(plot_rz) * u.km / u.s, 101)
sqrtJz = np.sqrt(tmp_aaf["J_z"].value)

# The bounds used for the fit key the frequency parameter as "ln_Omega", so the
# state's "nu" key referenced previously is presumably stale. Exponentiate to
# get the frequency in model units.
Omega0 = np.exp(model.state["ln_Omega"])

fig, axes = plt.subplots(1, 3, figsize=(16, 5), constrained_layout=True, sharex=True)

ax = axes[0]
ax.plot(sqrtJz, tmp_aaf["Omega_z"].value, zorder=100)
ax.set_ylabel(r"$\Omega_z$ " + f"[{tmp_aaf['Omega_z'].unit:latex_inline}]")
ax.axhline(Omega0, color="tab:green", ls="--", alpha=0.4)

for n in es:
    axes[1].plot(sqrtJz, es[n], label=f"$e_{n}$")
axes[1].set_ylabel(r"$e_m$")
axes[1].legend()

axes[2].plot(sqrtJz, model_label)
axes[2].set_ylabel("label")

for ax in axes:
    ax.set_xlabel(r"$\sqrt{J_z}'$")

In [None]:
from astropy.constants import G

# Density implied by the fitted frequency via the harmonic relation
# Omega^2 = 4 pi G rho. The frequency parameter is stored as "ln_Omega" (the
# key bounded in the optimization cell), so "nu" is presumably a stale key.
# `nu` is kept as the variable name.
nu = np.exp(model.state["ln_Omega"]) / model.units["time"]
(nu**2 / G).to(u.Msun / u.pc**3) / (4 * np.pi)

### Check that the monotonicity constraint is met: $\mathrm{d}r_z / \mathrm{d}r_z' > 0$

In [None]:
# r_z as a function of the angle theta' over one full turn, one curve for each
# of 16 values of r_z' in [0, 2].
thp = np.linspace(0, 2 * np.pi, 256)
fig, ax = plt.subplots()
for rzp in np.linspace(0, 2.0, 16):
    ax.plot(thp, model.get_rz(rzp, thp), marker="")
ax.set_xlabel(r"$\theta_z'$ [rad]")
ax.set_ylabel(r"$r_z$");

In [None]:
# Element-wise d r_z / d r_z':
#   - jax.grad differentiates model.get_rz w.r.t. its first argument;
#   - jax.vmap maps that over paired 1-D arrays of (r_z', theta').
drz_drzp = jax.vmap(
    jax.grad(model.get_rz, argnums=0),
    in_axes=0,
)

In [None]:
# d r_z / d r_z' versus the angle theta' (only [0, pi/2] is scanned here).
# Curves that dip below the horizontal zero line violate the monotonicity
# constraint. Black curves: r_z' in [0, 1]. Gray curves: r_z' in [1, 2].
th = np.linspace(0, np.pi / 2, 128)

fig, ax = plt.subplots()
for rzp_grid, color in (
    (np.linspace(0, 1.0, 16), "k"),
    (np.linspace(1, 2.0, 8), "#aaaaaa"),
):
    for rzp in rzp_grid:
        ax.plot(th, drz_drzp(np.full_like(th, rzp), th), color=color, marker="")

ax.axhline(0)
ax.set_ylim(-2, None)

# The x-axis is the angle `th`, not r_z' as the previous label claimed.
ax.set_xlabel(r"$\theta_z'$ [rad]")
ax.set_ylabel(r"$\frac{\mathrm{d}r_z}{\mathrm{d}r_z'}$");