In [None]:
import pathlib

import astropy.table as at
import astropy.units as u
import gala.potential as gp
import matplotlib as mpl
import matplotlib.pyplot as plt

%matplotlib inline
import numpy as np
from gala.units import galactic
from scipy.stats import binned_statistic_2d

Make test data by sampling particles in a simple harmonic oscillator potential:

In [None]:
# Frequency of the harmonic oscillator used to generate the test particles.
Omega = u.Quantity(0.05, u.rad / u.Myr)

# Velocity scale of the population. Dividing it by sqrt(Omega) gives `sz`,
# whose square is used below as the scale of the exponential action distribution.
scale_vz = u.Quantity(50, u.km / u.s)
_sz_raw = scale_vz / np.sqrt(Omega)
sz = _sz_raw.decompose(galactic)

In [None]:
N = 200_000
rng = np.random.default_rng(42)

# Radians are treated as dimensionless for every conversion in this cell, so
# that products and ratios involving Omega convert to plain length/time units.
with u.set_enabled_equivalencies(u.dimensionless_angles()):
    # The order of the random draws is part of the fixture: actions first,
    # then angles, then (further down) the labels.
    _Jz_draws = rng.exponential(scale=sz.value**2, size=N) * sz.unit**2
    Jzs = _Jz_draws.to(galactic["length"] ** 2 / galactic["time"])
    thzs = rng.uniform(0, 2 * np.pi, size=N) * u.rad

    # Map (Jz, theta_z) to phase-space coordinates (z, vz).
    _z = np.sqrt(2 * Jzs / Omega) * np.sin(thzs)
    _vz = np.sqrt(2 * Jzs * Omega) * np.cos(thzs)
    pdata = {
        "z": _z.to(galactic["length"]),
        "vz": _vz.to(galactic["length"] / galactic["time"]),
        "Jz": Jzs,
        "thetaz": thzs,
    }

    # Elliptical radius in the (z, vz) plane.
    pdata["r_e"] = np.sqrt(pdata["z"] ** 2 * Omega + pdata["vz"] ** 2 / Omega)

    # Label whose mean grows with sqrt(Jz), with Gaussian scatter of 0.04.
    pdata["label"] = rng.normal(np.sqrt(0.15) * pdata["Jz"].value ** 0.5 + 0.025, 0.04)

Bin the particle data to make 2D arrays to save for tests:

In [None]:
# Axis limits for the binned test data. The cells that use these bins pass
# vz in km/s and z in kpc.
vzlim = (-100, 100)
zlim = (-3, 3)
Nbins = 128

# np.linspace returns Nbins edges per axis, i.e. Nbins - 1 bins.
# Order is (vz edges, z edges).
bins = tuple(np.linspace(lo, hi, Nbins) for lo, hi in (vzlim, zlim))

In [None]:
# Particle coordinates expressed in the same units as the bin edges.
_vz_kms = pdata["vz"].to_value(u.km / u.s)
_z_kpc = pdata["z"].to_value(u.kpc)

# Number counts per (vz, z) bin; only the histogram itself is kept.
H_dens_testdata = np.histogram2d(_vz_kms, _z_kpc, bins=bins)[0]

# Per-bin statistic of the label (scipy's default statistic is the mean);
# only the statistic array is kept.
H_abun_testdata = binned_statistic_2d(_vz_kms, _z_kpc, pdata["label"], bins=bins)[0]

In [None]:
# Bin centres along each axis, expanded to 2D grids (xc: vz, yc: z).
_vz_edges, _z_edges = bins
xc, yc = np.meshgrid(
    0.5 * (_vz_edges[:-1] + _vz_edges[1:]),
    0.5 * (_z_edges[:-1] + _z_edges[1:]),
)

# Coordinates are stored as bare values in the `galactic` unit system. The
# histograms are transposed so that their axes line up with the meshgrid
# output above.
test_data = dict(
    z=(np.array(yc) * u.kpc).decompose(galactic).value,
    vz=(np.array(xc) * u.km / u.s).decompose(galactic).value,
    H_dens=np.array(H_dens_testdata.T),
    H_label=np.array(H_abun_testdata.T),
)
np.savez("test-data.npz", **test_data)

In [None]:
# Visual sanity check of the saved test arrays: number counts on the left,
# binned label on the right.
fig, axes = plt.subplots(1, 2, figsize=(10, 5), constrained_layout=True)

axes[0].pcolormesh(
    test_data["vz"],
    test_data["z"],
    test_data["H_dens"],
    cmap="Greys",
    # vmin below 1 so that bins containing a single particle remain visible on
    # the logarithmic colour scale.
    norm=mpl.colors.LogNorm(vmin=0.5),
)
axes[0].set_title("Number counts")

axes[1].pcolormesh(
    test_data["vz"], test_data["z"], test_data["H_label"], cmap="magma_r"
)
axes[1].set_title("Mean label per bin")

# The stored coordinates are bare values in the `galactic` unit system, so the
# axis labels are built from that unit system directly.
_vz_unit = galactic["length"] / galactic["time"]
_z_unit = galactic["length"]
for ax in axes:
    ax.set_xlabel(f"$v_z$ [{_vz_unit}]")
    ax.set_ylabel(f"$z$ [{_z_unit}]")

---

# Test prototyping

TODO: the code in this section should be moved into `test_model.py`.

In [None]:
from functools import partial

import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import jaxopt
import torusimaging as oti

In [None]:
# The archive written earlier in this notebook contains only plain numeric
# arrays, so object unpickling is not needed. Keeping it disabled means a
# replaced or tampered file on disk cannot execute arbitrary code on load.
test_data = np.load("test-data.npz", allow_pickle=False)

In [None]:
# Same oscillator frequency that was used to generate the particle data.
Omega = u.Quantity(0.05, u.rad / u.Myr)

# Outermost spline-knot radius: the elliptical radius sqrt(z**2 * Omega + vz**2 / Omega)
# evaluated at z = 3, vz = 0, i.e. at the edge of the binned z range.
_sqrt_Omega = np.sqrt(Omega.value)
max_re = 3.0 * _sqrt_Omega
max_re

## Density

In [None]:
# Bin edges for the density fit, this time carrying explicit units.
vzlim = (-100, 100)
zlim = (-3, 3)
Nbins = 128
bins = dict(
    vel=np.linspace(vzlim[0], vzlim[1], Nbins) * u.km / u.s,
    pos=np.linspace(zlim[0], zlim[1], Nbins) * u.kpc,
)

# Wrap the particle positions/velocities and bin them into number counts.
data = oti.data.OTIData(pdata["z"], pdata["vz"])
bdata = data.get_binned_counts(bins=bins)

In [None]:
n_dens_knots = 10

# Knot radii are equally spaced in r_z. Note that the spline itself is
# evaluated in r_z**2 (see `ln_dens_func`), so the knots are not equally spaced
# in the spline's own independent variable.
ln_dens_knots = jnp.linspace(0, max_re, n_dens_knots)


def ln_dens_func(rz, ln_dens_vals):
    """Evaluate the log-density spline at radius ``rz``.

    The monotonic quadratic spline is defined on the squared knot radii and is
    evaluated at ``rz**2``.

    Parameters
    ----------
    rz
        Radius (or array of radii) at which to evaluate the function.
    ln_dens_vals
        Spline parameter values, one per knot in ``ln_dens_knots``.

    Returns
    -------
    The value returned by ``oti.model_helpers.monotonic_quadratic_spline``.
    """
    knots_sq = ln_dens_knots**2
    return oti.model_helpers.monotonic_quadratic_spline(
        knots_sq, ln_dens_vals, rz**2
    )


# Bounds on the spline parameters: the first entry is allowed in [-5, 25], the
# remaining entries in [-25, 0].
# NOTE(review): presumably the first entry is the central value and the rest are
# slopes constrained to be non-positive -- confirm against the spline helper.
_n_outer_dens = n_dens_knots - 1
_ln_dens_lower = np.concatenate(([-5], jnp.full(_n_outer_dens, -25.0)))
_ln_dens_upper = np.concatenate(([25], jnp.full(_n_outer_dens, 0.0)))
ln_dens_bounds = {"ln_dens_vals": (_ln_dens_lower, _ln_dens_upper)}

In [None]:
# NOTE: ln_Omega is exactly degenerate with the e2 value at small rz, so only a
# few knots are used.

n_knots = {2: 5}  # , 4: 9}
e_knots = {
    m: jnp.linspace(0, ln_dens_knots.max(), n_m) for m, n_m in n_knots.items()
}
# Overall sign applied to each order's spline.
e_signs = {2: 1.0, 4: -1.0}


def e_func_base(rzp, vals, m):
    """Evaluate the order-``m`` distortion function at ``rzp``.

    A leading 0.0 is prepended to ``vals`` before they are passed to the
    monotonic quadratic spline on ``e_knots[m]``; the spline output is then
    multiplied by ``e_signs[m]``.
    """
    spline_vals = jnp.concatenate((jnp.array([0.0]), vals))
    spline = oti.model_helpers.monotonic_quadratic_spline(
        e_knots[m], spline_vals, rzp
    )
    return e_signs[m] * spline


# One callable per order, with `m` bound in advance.
e_funcs = {m: partial(e_func_base, m=m) for m in e_knots}

# One free value per knot except the first (whose value is fixed to 0.0 above).
e_params0 = {m: {"vals": np.zeros(n_knots[m] - 1)} for m in e_funcs}
e_bounds = {
    m: {
        "vals": (
            np.full(n_knots[m] - 1, 0),
            np.full(n_knots[m] - 1, 1),
        )
    }
    for m in e_funcs
}

In [None]:
def reg_func(params):
    """L1 regularization penalty on the distortion-function parameters.

    For every entry of ``params["e_params"]`` the absolute values of its
    ``"vals"`` array are divided by 0.1 and summed; the per-entry sums are
    added together. The ``ln_dens_vals`` parameters are not penalized.

    Returns 0.0 when ``params["e_params"]`` is empty.
    """
    per_order = (
        jnp.sum(jnp.abs(e_par["vals"]) / 0.1)
        for e_par in params["e_params"].values()
    )
    return sum(per_order, 0.0)


# Assemble the density-based orbit model from the log-density function, the
# distortion functions and the L1 regularizer defined in this notebook;
# `galactic` is passed as the model's unit system.
model = oti.DensityOrbitModel(
    ln_dens_func=ln_dens_func,
    e_funcs=e_funcs,
    regularization_func=reg_func,
    unit_sys=galactic,
)

In [None]:
# Initial parameter values for the density fit.
params0 = {
    # Start slightly away from the frequency used to generate the data.
    "ln_Omega0": np.log(Omega.value) * 1.03,
    "pos0": 0.0,
    "vel0": 0.0,
    "e_params": e_params0,
    "ln_dens_params": {"ln_dens_vals": np.zeros(n_dens_knots)},
}

In [None]:
# NOTE(review): `_dens0` is not used anywhere in the visible notebook -- it
# looks like a leftover and can probably be removed.
_dens0 = [0.01, 2] * u.Msun / u.pc**3

# (lower, upper) bounds for every parameter of the density fit.
bounds = {
    "ln_Omega0": (np.log(1e-4), np.log(1e0)),
    "pos0": (-0.5, 0.5),
    "vel0": (-0.05, 0.05),
    "e_params": e_bounds,
    "ln_dens_params": ln_dens_bounds,
}

In [None]:
# Sanity check: evaluate the objective at the initial parameters.
# The coordinates are explicitly decomposed into the `galactic` unit system
# (the model's `unit_sys`) before the units are stripped, matching how the
# label model is called later in this notebook. This is a no-op if the binned
# data are already in galactic units.
model.objective(
    params0,
    bdata["pos"].decompose(galactic).value,
    bdata["vel"].decompose(galactic).value,
    bdata["counts"],
)

In [None]:
# Fit the density model to the binned counts.
# Positions and velocities are decomposed into the `galactic` unit system (the
# model's `unit_sys`) before the units are stripped, matching how the label
# model is called later in this notebook. This is a no-op if the binned data
# are already in galactic units.
res = model.optimize(
    params0=params0,
    bounds=bounds,
    jaxopt_kwargs={"tol": 1e-10},
    pos=bdata["pos"].decompose(galactic).value,
    vel=bdata["vel"].decompose(galactic).value,
    dens=bdata["counts"],
)
# Show whether the optimizer reports success and how many iterations it took.
res.state.success, res.state.iter_num

In [None]:
# Display the optimized parameter values of the density fit.
res.params

In [None]:
# Plot the binned data together with the model at the initial and the optimized
# parameters, plus the residual (per the helper's name).
fig, axes = oti.plot.plot_data_models_residual(bdata, model, params0, res.params)

## Label

In [None]:
# Same binning as for the density fit, redefined here so that the label
# section does not depend on the `bins` left over from the cells above.
vzlim = (-100, 100)
zlim = (-3, 3)
Nbins = 128
bins = dict(
    vel=np.linspace(vzlim[0], vzlim[1], Nbins) * u.km / u.s,
    pos=np.linspace(zlim[0], zlim[1], Nbins) * u.kpc,
)

# Particle data including the per-particle label values.
label_data = oti.data.OTIData(pdata["z"], pdata["vz"], label=pdata["label"])

In [None]:
def _binned_label_err(x):
    """Per-bin label uncertainty.

    The sample variance of the labels in the bin is combined in quadrature with
    a 0.02 floor, and the result is divided by sqrt(number of labels in the bin).
    """
    return np.sqrt(0.02**2 + np.var(x)) / np.sqrt(len(x))


# Binned label values using the default statistic of `get_binned_label`.
label_bdata = label_data.get_binned_label(bins=bins)

# Binned label uncertainty using the custom statistic above.
# (Alternatives that were tried: a scaled median absolute deviation, a
# fractional error model, and dividing by the binned counts explicitly.)
label_err_bdata = label_data.get_binned_label(bins=bins, statistic=_binned_label_err)

# HACK: attach the uncertainty to the binned-label dictionary in place.
label_bdata["label_err"] = label_err_bdata["label"]

In [None]:
# Show the binned label (left) and its uncertainty (right) on shared axes.
fig, axes = plt.subplots(
    1, 2, figsize=(10, 6), constrained_layout=True, sharex=True, sharey=True
)

# Grid coordinates converted once, in the units used for the axis labels.
_vel_kms = label_bdata["vel"].to_value(u.km / u.s)
_pos_kpc = label_bdata["pos"].to_value(u.kpc)

_panels = (("label", "binned label"), ("label_err", "binned label uncertainty"))
for ax, (key, cb_label) in zip(axes, _panels):
    cs = ax.pcolormesh(_vel_kms, _pos_kpc, label_bdata[key])
    cb = fig.colorbar(cs, ax=ax, orientation="horizontal")
    cb.set_label(cb_label)
    ax.set_xlabel(r"$v_z$ [km/s]")

# The y axis is shared, so only the left panel needs a label.
axes[0].set_ylabel(r"$z$ [kpc]")

In [None]:
n_label_knots = 5

# Knot locations, spaced equally in r_z out to `max_re`, the maximum radius
# defined at the top of the test-prototyping section. The previous name,
# `max_rzp`, is not defined anywhere in this notebook and raised a NameError on
# a fresh kernel.
label_knots = jnp.linspace(0, max_re, n_label_knots)


def label_func(rz, label_vals):
    """Evaluate the label spline at radius ``rz``.

    Parameters
    ----------
    rz
        Radius (or array of radii) at which to evaluate the function.
    label_vals
        Spline parameter values, one per knot in ``label_knots``.

    Returns
    -------
    The value returned by ``oti.model_helpers.monotonic_quadratic_spline``
    evaluated on ``label_knots``.
    """
    return oti.model_helpers.monotonic_quadratic_spline(label_knots, label_vals, rz)


# Bounds on the spline parameters: the first entry is allowed in [-1, 1], the
# remaining entries in [0, 5].
# NOTE(review): presumably the first entry is the central value and the rest are
# slopes constrained to be non-negative -- confirm against the spline helper.
label_func_bounds = {
    "label_vals": (
        np.concatenate(([-1.0], jnp.full(n_label_knots - 1, 0.0))),
        np.concatenate(([1.0], jnp.full(n_label_knots - 1, 5.0))),
    )
}

In [None]:
# NOTE: ln_Omega is exactly degenerate with the e2 value at small rz, so only a
# few knots are used.
# This cell redefines the distortion-function objects (same names as in the
# density section), now with the knot range taken from `label_knots`.

n_knots = {2: 5}  # , 4: 9}
e_knots = {
    m: jnp.linspace(0, label_knots.max(), n_m) for m, n_m in n_knots.items()
}
# Overall sign applied to each order's spline.
e_signs = {2: 1.0, 4: -1.0}


def e_func_base(rzp, vals, m):
    """Evaluate the order-``m`` distortion function at ``rzp``.

    A leading 0.0 is prepended to ``vals`` before they are passed to the
    monotonic quadratic spline on ``e_knots[m]``; the spline output is then
    multiplied by ``e_signs[m]``.
    """
    spline_vals = jnp.concatenate((jnp.array([0.0]), vals))
    spline = oti.model_helpers.monotonic_quadratic_spline(
        e_knots[m], spline_vals, rzp
    )
    return e_signs[m] * spline


# One callable per order, with `m` bound in advance.
e_funcs = {m: partial(e_func_base, m=m) for m in e_knots}

# One free value per knot except the first (whose value is fixed to 0.0 above).
e_params0 = {m: {"vals": np.zeros(n_knots[m] - 1)} for m in e_funcs}
e_bounds = {
    m: {
        "vals": (
            np.full(n_knots[m] - 1, 0),
            np.full(n_knots[m] - 1, 1),
        )
    }
    for m in e_funcs
}

In [None]:
# Assemble the label-based orbit model from the label function and the
# distortion functions defined in this section; `galactic` is passed as the
# model's unit system. Regularization is currently disabled (commented out).
label_model = oti.LabelOrbitModel(
    label_func=label_func,
    e_funcs=e_funcs,
    # regularization_func=reg_func,
    unit_sys=galactic,
)

In [None]:
# Orbit parameters used for the initial (z, vz) -> (r_z, theta') transform.
# They are built here rather than taken from `label_params0`, because
# `label_params0` is only defined in the next cell and itself depends on `rz0`;
# that circular dependency made this cell fail on a fresh kernel.
# NOTE(review): assumes the transform reads only the frequency, the zero-point
# offsets and the distortion-function parameters -- confirm against the model.
_rz0_params = {
    "ln_Omega": np.log(Omega.value),
    "z0": 0.0,
    "vz0": 0.0,
    "e_params": e_params0,
}
rz0, _ = label_model.z_vz_to_rz_theta_prime(
    pdata["z"].value, pdata["vz"].value, _rz0_params
)
# plt.plot(rz0, pdata['label'], marker='o', ls='none')

In [None]:
# Initial guess for the central label value: the mean label of the particles
# with rz0 < 0.05, ignoring NaNs.
_central_label0 = np.nanmean(pdata["label"][rz0 < 0.05])

# Initial parameter values for the label fit.
label_params0 = {
    "ln_Omega": np.log(Omega.value),  # * 1.03
    "z0": 0.0,
    "vz0": 0.0,
    "e_params": e_params0,
    "label_params": {
        "label_vals": np.concatenate(
            ([_central_label0], np.full(n_label_knots - 1, 0.3))
        )
    },
}

In [None]:
# NOTE(review): `_dens0` is not used anywhere in the visible notebook -- it
# looks like a leftover and can probably be removed.
_dens0 = [0.01, 2] * u.Msun / u.pc**3

# (lower, upper) bounds for every parameter of the label fit.
label_bounds = {
    "ln_Omega": (np.log(1e-4), np.log(1e0)),
    "z0": (-0.5, 0.5),
    "vz0": (-0.05, 0.05),
    "e_params": e_bounds,
    "label_params": label_func_bounds,
}

In [None]:
# Sanity check: evaluate the objective at the initial parameters.
# Only bins with a finite binned label are passed in (the same selection that
# is used for the optimization), so that bins whose label is NaN cannot turn
# the whole objective value into NaN.
mask = np.isfinite(label_bdata["label"])

label_model.objective(
    label_params0,
    z=label_bdata["pos"].decompose(galactic).value[mask],
    vz=label_bdata["vel"].decompose(galactic).value[mask],
    label=label_bdata["label"][mask],
    label_err=label_bdata["label_err"][mask],
)

In [None]:
# Restrict the fit to bins whose binned label is finite.
mask = np.isfinite(label_bdata["label"])

# Data arrays for the fit: coordinates as bare values in the `galactic` unit
# system, restricted to the selected bins.
_label_fit_data = dict(
    z=label_bdata["pos"].decompose(galactic).value[mask],
    vz=label_bdata["vel"].decompose(galactic).value[mask],
    label=label_bdata["label"][mask],
    label_err=label_bdata["label_err"][mask],
)

label_res = label_model.optimize(
    params0=label_params0,
    bounds=label_bounds,
    jaxopt_kwargs={"tol": 1e-14},
    **_label_fit_data,
)
# Show whether the optimizer reports success and how many iterations it took.
label_res.state.success, label_res.state.iter_num

In [None]:
# Display the optimized parameter values of the label fit.
label_res.params

In [None]:
# Data dictionary for the residual plot: grid coordinates from the saved test
# data, labels from the binned label data.
# NOTE(review): the binned label is transposed here while the coordinate grids
# are not -- confirm that the orientation matches `test_data["z"]` / `["vz"]`.
tmp = dict(
    z=test_data["z"],
    vz=test_data["vz"],
    label=label_bdata["label"].T,
)
fig, axes = oti.plot.plot_data_models_label_residual(
    tmp,
    label_model,
    label_params0,
    label_res.params,
)