In [None]:
import astropy.coordinates as coord
import astropy.table as at
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from scipy.spatial import cKDTree
from scipy.stats import binned_statistic
from scipy.interpolate import interp1d

# gala
import gala.coordinates as gc
import gala.dynamics as gd
import gala.integrate as gi
import gala.potential as gp
from gala.units import galactic

from totoro.config import galcen_frame
from totoro.data import load_apogee_sample
from totoro.potentials import potentials, galpy_potentials

In [None]:
# Load the APOGEE parent sample: `t` is the table and `c` the coordinates
# (presumably row-aligned with `t`, since both come from the same file).
t, c = load_apogee_sample('../data/apogee-parent-sample.fits')

# Sort by APOGEE_ID so rows can be matched against the cached action table.
# The SAME permutation must be applied to `c`: `w0s`, and hence the Staeckel
# actions, are built from `c`. Sorting only `t` would silently misalign those
# actions with anything aligned to `t`.
order_by_id = np.argsort(t['APOGEE_ID'])
t = t[order_by_id]
c = c[order_by_id]

In [None]:
# Move every star into the Galactocentric frame configured in totoro.config,
# then wrap the resulting representation as gala phase-space positions.
galcen = c.transform_to(galcen_frame)
w0s = gd.PhaseSpacePosition(pos=galcen.data)

In [None]:
from galpy.actionAngle import estimateDeltaStaeckel, actionAngleStaeckel
from totoro.config import rsun as ro, vcirc as vo
from totoro.galpy_helpers import gala_to_galpy_orbit

### Compute the Staeckel focal distance $\Delta$ on an $(R, z)$ grid

In [None]:
# Regular grid in kpc: R spans 8 +/- 2.5 and z spans +/- 2.5, both with a
# 0.05 step. The +1e-3 makes the upper endpoint inclusive.
Rz_grids = (
    np.arange(8 - 2.5, 8 + 2.5 + 1e-3, 0.05),
    np.arange(-2.5, 2.5 + 1e-3, 0.05),
)
# One (R, z) row per grid node, in meshgrid ('xy') ravel order.
Rz_grid = np.column_stack([mesh.ravel() for mesh in np.meshgrid(*Rz_grids)])

In [None]:
# Map galpy's estimated Staeckel focal distance across the (R, z) grid for each
# of the three potentials, one figure per potential.
for pot_name in ['0.4', '1.0', '1.6']:
    pot = galpy_potentials[pot_name]

    # Grid coordinates are converted from kpc to units of `ro` before being
    # passed to galpy.
    delta_staeckels = []
    for i in range(Rz_grid.shape[0]):
        R = (Rz_grid[i, 0] * u.kpc).to_value(ro)
        z = (Rz_grid[i, 1] * u.kpc).to_value(ro)
        delta_staeckels.append(estimateDeltaStaeckel(pot, R, z))

    fig, ax = plt.subplots()
    cs = ax.scatter(Rz_grid[:, 0], Rz_grid[:, 1],
                    c=delta_staeckels,
                    vmin=2, vmax=6, s=8, marker='s')
    # NOTE(review): the units of estimateDeltaStaeckel's output depend on whether
    # the galpy potential has physical output enabled; the 2-6 colour range
    # suggests kpc -- confirm.
    fig.colorbar(cs, ax=ax, label=r'Staeckel $\Delta$')
    ax.set(title=pot_name, xlabel='$R$ [kpc]', ylabel='$z$ [kpc]')

In [None]:
from scipy.interpolate import NearestNDInterpolator

In [None]:
# Potential used for the Staeckel action calculation below.
pot = galpy_potentials['1.0']

# Focal distance at every grid node; (R, z) are converted from kpc to units of
# `ro` before being handed to galpy.
delta_staeckels = [
    estimateDeltaStaeckel(pot,
                          (R_kpc * u.kpc).to_value(ro),
                          (z_kpc * u.kpc).to_value(ro))
    for R_kpc, z_kpc in Rz_grid
]

# Nearest-neighbour lookup: any (R, z) in kpc maps to the value at the closest
# grid node.
delta_interp = NearestNDInterpolator(Rz_grid, delta_staeckels)

In [None]:
# def fast_actions():

# Per-star focal distance from the tabulated grid. The interpolator is nearest-
# node, so stars outside the grid silently take the value of the closest edge
# node.
star_R_kpc = w0s.cylindrical.rho.to_value(u.kpc)
star_z_kpc = w0s.z.to_value(u.kpc)
deltas = delta_interp(star_R_kpc, star_z_kpc)

# Staeckel-fudge actions; galpy returns (J_R, L_z, J_z). The `* ro * vo` factor
# presumably converts from galpy natural units -- confirm.
o = gala_to_galpy_orbit(w0s)
aAS = actionAngleStaeckel(pot=pot, delta=deltas)
actions = np.squeeze(aAS(o)).T * ro * vo

### Compare to Sanders & Binney actions

In [None]:
# Cached Sanders & Binney actions for the same potential: keep only stars in the
# parent sample, then order rows by APOGEE_ID so they line up with `t`.
sb_aaf = at.Table.read('../cache_new_zsun/aaf-1.0.fits')
in_parent_sample = np.isin(sb_aaf['APOGEE_ID'], t['APOGEE_ID'])
sb_aaf = sb_aaf[in_parent_sample]
assert len(sb_aaf) == len(t)

id_order = np.argsort(sb_aaf['APOGEE_ID'])
sb_aaf = sb_aaf[id_order]
assert np.all(t['APOGEE_ID'] == sb_aaf['APOGEE_ID'])

In [None]:
# Sanders & Binney action array, row-aligned with `t` via the APOGEE_ID check.
sb_actions = sb_aaf.columns['actions']

In [None]:
# Quick look: both action arrays are expected to have one row per star.
(actions.shape, sb_actions.shape)

In [None]:
from scipy.stats import binned_statistic
from astropy.stats import median_absolute_deviation

In [None]:
# Fractional difference between the Staeckel-fudge actions and the cached
# Sanders & Binney actions, for action indices 0 and 2. galpy returns
# (J_R, L_z, J_z), so these are J_R and J_z, presumably in the same order in the
# cached table -- confirm.
action_labels = {0: '$J_R$', 2: '$J_z$'}

for k, label in action_labels.items():
    sb_J = sb_actions[:, k]
    J = actions[:, k]
    # log10 and the fractional difference both need a strictly positive
    # reference action; also drop NaN/inf.
    mask = np.isfinite(sb_J) & np.isfinite(J) & (sb_J > 0)
    sb_J = sb_J[mask]
    J = J[mask]
    frac_diff = (J - sb_J) / sb_J

    # Robust scatter per log10(J) bin: 1.5 * MAD roughly approximates a
    # Gaussian sigma.
    stat = binned_statistic(np.log10(sb_J),
                            frac_diff,
                            statistic=lambda x: 1.5 * median_absolute_deviation(x),
                            bins=np.arange(-1, 3, 0.1))
    # Bins are uniform in log10(J), so take the midpoint in log space. This
    # keeps the curve centred on each bin along the log x-axis.
    bincen = 10 ** (0.5 * (stat.bin_edges[:-1] + stat.bin_edges[1:]))

    fig, ax = plt.subplots()
    # An explicit marker is required: with ls='none' and no marker, nothing
    # is drawn.
    ax.plot(sb_J, frac_diff, marker='o',
            alpha=0.1, ls='none', ms=2, mew=0)
    ax.plot(bincen, stat.statistic)

    ax.set_xscale('log')
    ax.set_xlim(0.1, 2000)
    ax.set_ylim(-1, 1)
    ax.set(title=label,
           xlabel=f'Sanders & Binney {label}',
           ylabel='(Staeckel - S&B) / S&B')

    fig.set_facecolor('w')

---