In [None]:
import pathlib
import pickle

import astropy.table as at
import astropy.coordinates as coord
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from scipy.stats import binned_statistic_2d

# gala
import gala.coordinates as gc
import gala.dynamics as gd
import gala.potential as gp
import gala.integrate as gi
from gala.units import galactic

import agama
# Make Agama work in Msun / kpc / Myr, the same base units as Gala's
# `galactic` system. Gala parameters and Agama outputs are then exchanged
# below as plain floats.
agama.setUnits(mass=u.Msun, length=u.kpc, time=u.Myr)

### Set up matching Gala and Agama potentials, then use Agama to sample positions and velocities

In [None]:
# Gala's Milky Way model with the disk parameters overridden.
gala_pot = gp.MilkyWayPotential(disk={"m": 6.565e10, "a": 3.0, "b": 0.25})

# Mirror each Gala component with the equivalent Agama profile. Parameters are
# converted explicitly to the units passed to agama.setUnits (Msun, kpc)
# instead of taking `.value`, which would silently depend on the unit system
# Gala happens to store them in.
# NOTE(review): the Hernquist-like nucleus/bulge rely on Agama's "dehnen" type
# with its default inner slope (gamma=1) — confirm against the Agama docs.
agama_pot = agama.Potential(
    dict(
        type="miyamotonagai",
        mass=gala_pot["disk"].parameters["m"].to_value(u.Msun),
        scaleradius=gala_pot["disk"].parameters["a"].to_value(u.kpc),
        scaleheight=gala_pot["disk"].parameters["b"].to_value(u.kpc),
    ),
    dict(
        type="dehnen",
        mass=gala_pot["nucleus"].parameters["m"].to_value(u.Msun),
        scaleradius=gala_pot["nucleus"].parameters["c"].to_value(u.kpc),
    ),
    dict(
        type="dehnen",
        mass=gala_pot["bulge"].parameters["m"].to_value(u.Msun),
        scaleradius=gala_pot["bulge"].parameters["c"].to_value(u.kpc),
    ),
    dict(
        type="nfw",
        mass=gala_pot["halo"].parameters["m"].to_value(u.Msun),
        scaleradius=gala_pot["halo"].parameters["r_s"].to_value(u.kpc),
    ),
)

In [None]:
# Sanity check: the Gala and Agama potentials should give the same
# acceleration at a point near the Sun (coordinates in kpc).
test_xyz = np.array([-8.3, 0, 0.208])
gala_acc = gala_pot.acceleration(test_xyz)[:, 0].to_value(u.kpc / u.Myr**2)
agama_acc = np.asarray(agama_pot.force(*test_xyz))
# Residual relative to |a|; a raw difference in kpc/Myr^2 is hard to judge.
(gala_acc - agama_acc) / np.linalg.norm(gala_acc)

In [None]:
# Circular speed at the test point and the z-angular momentum of a circular
# orbit there. Jphi0 (kpc^2/Myr) is where the DF below is centred in J_phi.
vcirc = gala_pot.circular_velocity(test_xyz)[0]
print(vcirc)
# Cylindrical radius of test_xyz, so R_sun cannot drift from the test point.
Rsun = np.hypot(test_xyz[0], test_xyz[1]) * u.kpc
Jphi0 = (vcirc * Rsun).decompose(galactic).value

## Sample from DF:

In [None]:
# DF widths, in Agama's action units (kpc^2/Myr).
dJphi = Jphi0 * 0.05      # Gaussian width in J_phi: 5% of Jphi0
dJr   = 0.05 * 1. * 0.05  # Gaussian width in J_R (= 0.0025)
dJz   = 0.04 * 0.5        # exponential scale in |J_z| (= 0.02)

# Number of stars drawn from the DF; the sampled (N, 6) array is large.
N     = 20_000_000


def df(J):
    """Unnormalised action-space distribution function.

    Gaussian in J_R about 0 and in J_phi about Jphi0, exponential in |J_z|.
    ``J`` has one row per point, with columns (J_R, J_z, J_phi).
    """
    Jr, Jz, Jphi = J.T
    exponent = (
        -0.5*Jr**2/dJr**2
        - 0.5*((Jphi-Jphi0)/dJphi)**2
        - np.abs(Jz)/dJz
    )
    return np.exp(exponent)


gm = agama.GalaxyModel(agama_pot, df)
# Keep only the first element of sample()'s output: the phase-space coordinates.
xv = gm.sample(N)[0]

In [None]:
# (v_z, z) bin edges in Agama's units: kpc/Myr and kpc, 150 bins per axis.
bins = (np.linspace(-0.1, 0.1, 151), np.linspace(-2.5, 2.5, 151))

# Vertical phase-space density of the full DF sample (log colour scale).
fig, ax = plt.subplots()
ax.hist2d(
    xv[:, 5],
    xv[:, 2],
    bins=bins,
    norm=mpl.colors.LogNorm(),
)
ax.set(
    xlim=(bins[0].min(), bins[0].max()),
    ylim=(bins[1].min(), bins[1].max()),
    xlabel="$v_z$ [kpc/Myr]",
    ylabel="$z$ [kpc]",
    title="All DF samples",
);

In [None]:
# Build one action finder for the potential. It is reused later (with
# angles=True) on the selected subsample.
act_finder = agama.ActionFinder(agama_pot)
agama_act = act_finder(xv)  # columns: J_R, J_z, J_phi

In [None]:
# Select stars with a small radial action and J_phi close to the circular
# value Jphi0. Actions are in kpc^2/Myr.
Jr_max = 1e-3         # absolute upper limit on J_R (not a fractional cut)
Jphi_frac_tol = 0.02  # keep |J_phi - Jphi0| below 2% of Jphi0
mask = (
    (agama_act[:, 0] < Jr_max)
    & (np.abs(agama_act[:, 2] - Jphi0) < Jphi_frac_tol * Jphi0)
)
sub_xv = xv[mask]
mask.sum()

In [None]:
# Trace a few purely vertical orbits (J_R = 0, J_phi = Jphi0) with Agama's
# torus mapper. They are overlaid on the phase-space density below.
Norbits = 12
orbits = []

Nt = 1024
zeros = np.zeros(Nt)

# J_z values in kpc^2/Myr, spaced uniformly in sqrt(J_z).
Jzs = np.linspace(1.5e-2, np.sqrt(0.12), Norbits) ** 2

# One full cycle of the vertical angle; including the endpoint closes each
# curve. This grid is the same for every orbit, so build it once.
thz = np.linspace(0, 2 * np.pi, Nt)
angles = np.stack((zeros, thz, zeros)).T  # (theta_R, theta_z, theta_phi)

Omzs = []
for Jz in Jzs:
    act = np.array([0.0, Jz, Jphi0])  # (J_R, J_z, J_phi)
    torus_mapper = agama.ActionMapper(agama_pot, act)
    Omzs.append(torus_mapper.Omegaz)
    # Columns 2 and 5 of the mapped phase-space coordinates are z and v_z.
    z, vz = torus_mapper(angles)[:, [2, 5]].T
    orbits.append((z, vz))

## Make binned representation:

In [None]:
# (v_z, z) bin edges, restated so this section can be re-run on its own:
# 150 bins per axis, in kpc/Myr and kpc respectively.
bins = tuple(
    np.linspace(lo, hi, 151) for lo, hi in ((-0.1, 0.1), (-2.5, 2.5))
)

In [None]:
# Counts of the selected stars in the (v_z, z) plane. Following the
# np.histogram2d convention, H[i, j] is v_z bin i and z bin j.
vz_sub = sub_xv[:, 5]
z_sub = sub_xv[:, 2]
H, xe, ye = np.histogram2d(vz_sub, z_sub, bins=bins)

In [None]:
fig, axes = plt.subplots(
    1, 2, figsize=(10, 5), sharex=True, sharey=True, constrained_layout=True
)

# Axes are shared, so these limits apply to both panels.
axes[0].set(
    xlim=(bins[0].min(), bins[0].max()),
    ylim=(bins[1].min(), bins[1].max()),
    ylabel="$z$ [kpc]",
)

# Both panels show the binned density of the selected stars. H is indexed
# [v_z, z], so transpose it for pcolormesh (rows follow the y axis).
for ax in axes:
    ax.pcolormesh(
        xe, ye, H.T,
        norm=mpl.colors.LogNorm(),
    )
    ax.set_xlabel("$v_z$ [kpc/Myr]")

# The right panel also overlays the torus-mapped vertical orbits.
for z, vz in orbits:
    axes[1].plot(
        vz, z, marker="", ls="-", lw=1, color='c'
    )

## Fake element abundances:

In [None]:
# Synthetic Mg/Fe label for the selected stars. The mean rises as
# sqrt(0.15) * sqrt(J_z), with Gaussian scatter of 0.04. Trend and dispersion
# are eyeballed from APOGEE; the generator is seeded for reproducibility.
rng = np.random.default_rng(seed=42)

Jz_sel = agama_act[mask, 1]
mgfe_mean = np.sqrt(0.15) * Jz_sel**0.5
mgfe = rng.normal(loc=mgfe_mean, scale=0.04)

In [None]:
# Mean synthetic Mg/Fe in each (v_z, z) bin, on the same grid as H.
mgfe_stat = binned_statistic_2d(
    sub_xv[:, 5], sub_xv[:, 2], mgfe, statistic="mean", bins=bins
)

fig, ax = plt.subplots(1, 1, figsize=(6, 5), constrained_layout=True)

# The statistic is indexed [v_z bin, z bin]; transpose so rows follow z.
# Empty bins are NaN and are left blank.
cs = ax.pcolormesh(
    mgfe_stat.x_edge,
    mgfe_stat.y_edge,
    mgfe_stat.statistic.T,
    cmap="cividis_r",
    vmin=0,
    vmax=0.15,
)
ax.set(xlabel="$v_z$ [kpc/Myr]", ylabel="$z$ [kpc]")
cb = fig.colorbar(cs, label="mean [Mg/Fe]")

## Save test data

Save both the particle data and the binned data.

In [None]:
# Output directory for the generated fixtures: ../test-data relative to the
# current working directory, made absolute.
test_data_path = pathlib.Path('../test-data').absolute()
# exist_ok so re-running the notebook does not fail once the directory exists.
test_data_path.mkdir(exist_ok=True)

In [None]:
# Re-run the action finder on the selected stars only, this time also
# returning angles and frequencies. Each array has (R, z, phi) columns.
sub_act, sub_ang, sub_freq = act_finder(sub_xv, angles=True)

In [None]:
# Use a QTable (rather than a Table) so the Quantity columns assigned below
# keep their astropy units.
test_data = at.QTable()

In [None]:
# Cartesian positions (kpc) and velocities (kpc/Myr) of the selected stars.
# Columns 0-2 of sub_xv are x, y, z; columns 3-5 are the matching velocities.
for col, axis in enumerate(["x", "y", "z"]):
    pos = sub_xv[:, col] * u.kpc
    vel = sub_xv[:, col + 3] * u.kpc/u.Myr
    test_data[axis] = pos
    test_data[f"v_{axis}"] = vel

In [None]:
# Actions (kpc^2/Myr), angles (rad) and frequencies (rad/Myr), stored in
# Agama's (R, z, phi) column order.
components = ("R", "z", "phi")
for col, comp in enumerate(components):
    act_col, ang_col, freq_col = (arr[:, col] for arr in (sub_act, sub_ang, sub_freq))
    test_data[f"J_{comp}"] = act_col * u.kpc**2/u.Myr
    test_data[f"theta_{comp}"] = ang_col * u.rad
    test_data[f"Omega_{comp}"] = freq_col * u.rad/u.Myr

In [None]:
# Synthetic Mg/Fe per selected star (plain column, no unit). Rows line up with
# sub_xv because both come from the same boolean mask.
test_data['MG_FE'] = mgfe

In [None]:
particle_file = test_data_path / 'agama-galaxymodel-particles.fits'
potential_file = test_data_path / 'agama-galaxymodel-gala_pot.yml'

# Particle table (overwritten on re-runs) plus the Gala potential used to
# generate it, so consumers can rebuild the same model.
test_data.write(particle_file, overwrite=True)
gp.save(gala_pot, potential_file)

In [None]:
# Binned phase-space density of the selected stars: the (v_z, z) bin edges
# plus the counts, transposed so rows follow z and columns follow v_z.
ps_density_data = {
    'vz_bins': bins[0],
    'z_bins': bins[1],
    'H': H.T,
}
np.savez(test_data_path / "binned-density.npz", **ps_density_data)

In [None]:
# Binned mean Mg/Fe: bin edges plus the per-bin mean, transposed so rows
# follow z. NOTE(review): the edge arrays here are keyed 'vz'/'z', whereas the
# density file uses 'vz_bins'/'z_bins'. The keys are kept as-is because
# consumers may read them.
label_data = dict(
    vz=mgfe_stat.x_edge,
    z=mgfe_stat.y_edge,
    label_H=mgfe_stat.statistic.T,
)
np.savez(test_data_path / "binned-label.npz", **label_data)