In [None]:
import astropy.coordinates as coord
import astropy.table as at
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np

# gala
import gala.coordinates as gc
import gala.dynamics as gd
import gala.integrate as gi
import gala.potential as gp
from gala.units import galactic
from cmastro import cmaps

from totoro.actions import get_staeckel_aaf
from tqdm.notebook import tqdm, trange

In [None]:
# Local, absolute path to the Cantat-Gaudin et al. (2020) open-cluster catalog;
# edit this for your machine.
cg2020_path = '/Users/apricewhelan/data/GaiaDR2/Cantat-Gaudin2020.fit'
cg2020 = at.QTable.read(cg2020_path)
len(cg2020)

In [None]:
# Inspect the catalog; the cells below read its RA_ICRS, DE_ICRS, plx,
# pmRA_, and pmDE_ columns.
cg2020

In [None]:
# Build sky coordinates (position + proper motion) for every cluster in the
# catalog loaded above.
#
# FIXME: the Cantat-Gaudin et al. (2020) catalog has no radial velocities.
# As a placeholder, every cluster is assigned RV = 0 km/s so that the
# Galactocentric transform and the orbit integrations below can run at all.
# Replace this with measured RVs (e.g. from a cross-matched spectroscopic
# catalog) before trusting any orbit-derived quantity in this notebook.
placeholder_rv = np.zeros(len(cg2020)) * u.km / u.s

c = coord.SkyCoord(
    ra=cg2020['RA_ICRS'],
    dec=cg2020['DE_ICRS'],
    # Simple parallax inversion (no prior); assumes plx is in angular units.
    distance=cg2020['plx'].to(u.pc, u.parallax()),
    pm_ra_cosdec=cg2020['pmRA_'],
    pm_dec=cg2020['pmDE_'],
    radial_velocity=placeholder_rv)

In [None]:
# Move to the Galactocentric frame (astropy's default frame parameters) and
# flag the rows whose x position and v_x velocity are both finite; only those
# rows are used as initial conditions below.
galcen = c.transform_to(coord.Galactocentric())
has_finite_x = np.isfinite(galcen.x)
has_finite_vx = np.isfinite(galcen.v_x)
galcen_mask = has_finite_x & has_finite_vx

In [None]:
# Initial conditions for orbit integration: the finite Galactocentric
# positions and velocities selected by galcen_mask.
finite_galcen_data = galcen.data[galcen_mask]
w0 = gd.PhaseSpacePosition(finite_galcen_data)
w0.shape

In [None]:
# Milky Way model: gala's MilkyWayPotential with only the disk and halo
# masses overridden, using values from Price-Whelan et al. 2021.
disk_pars = dict(m=6.98e10 * u.Msun)
halo_pars = dict(m=4.82e+11 * u.Msun)
mw = gp.MilkyWayPotential(disk=disk_pars, halo=halo_pars)

# Equivalent galpy potential object for the same model.
galpy_mw = gp.gala_to_galpy_potential(mw)

In [None]:
# Integrate every cluster backwards in time for 6 Gyr in steps of 0.5 Myr.
# The step is written with explicit units; MilkyWayPotential works in gala's
# `galactic` unit system, where a bare dt=-0.5 would also mean -0.5 Myr.
orbits = mw.integrate_orbit(w0, dt=-0.5 * u.Myr, t1=0, t2=-6 * u.Gyr)

In [None]:
# gala's per-orbit z_max estimate, converted to kpc.
zmax_raw = orbits.zmax()
zmax = zmax_raw.to(u.kpc)
zmax

In [None]:
# Projections of all integrated orbits, every panel zoomed to the same
# [-30, 30] window (presumably kpc, the orbit's length unit).
fig = orbits.plot(alpha=0.4, marker='', lw=0.5)
for ax in fig.axes:
    ax.set(xlim=(-30, 30), ylim=(-30, 30))
fig.set_facecolor('w')

In [None]:
# Meridional-plane view (cylindrical radius vs. z) of all orbits on a single
# square panel; `axes` here is one Axes object handed to gala's plotter.
fig, axes = plt.subplots(figsize=(8, 8))
fig = orbits.cylindrical.plot(
    ['rho', 'z'],
    alpha=0.4, marker='', lw=0.5,
    axes=[axes],
)
for ax in fig.axes:
    ax.set(xlim=(0, 30), ylim=(-15, 15))
fig.set_facecolor('w')

In [None]:
# Catalog row of the cluster whose orbit has the largest z_max.
# `orbits` (and hence `zmax`) only covers the rows kept by `galcen_mask`, so the
# argmax must index the masked catalog, not the full one. `zmax` is reused
# from above rather than recomputing the (expensive) estimate.
cg2020[galcen_mask][zmax.argmax()]

In [None]:
# Catalog row of the cluster whose orbit reaches the largest cylindrical radius.
# Axis 0 of the orbit arrays is time, so the max over axis 0 gives one value per
# orbit. As above, orbits exist only for rows kept by `galcen_mask`, so index
# the masked catalog.
rho_max = np.max(orbits.cylindrical.rho, axis=0)
cg2020[galcen_mask][rho_max.argmax()]

In [None]:
# Per-orbit period estimates and approximate eccentricities. The absolute
# value is taken presumably because the orbits were integrated backwards
# (negative dt), which can give negative period estimates.
period_estimates = orbits.estimate_period()
P = np.abs(period_estimates)
ecc = orbits.eccentricity(approximate=True)

In [None]:
# Actions, angles, and frequencies for every cluster via the Staeckel fudge.
# Each initial condition is re-integrated forward from today (t=0) for 10 of
# its estimated periods with the DOP853 integrator. Per-time-step results are
# then reduced to one value per cluster:
#   * actions, freqs: averaged over the orbit (they should be conserved, so
#     averaging suppresses the scatter of the Staeckel approximation);
#   * angles: taken at t=0, i.e. at the present-day position w0[n]. Angles
#     advance and wrap through [0, 2pi) over many periods, so their time-mean
#     carries no information.
# Assumes the columns returned by get_staeckel_aaf have time along axis 0 --
# TODO confirm against totoro.
aafs = {k: [] for k in ('actions', 'angles', 'freqs')}
for n in trange(w0.shape[0]):
    orbit_dop = mw.integrate_orbit(w0[n], dt=1. * u.Myr, t1=0, t2=10 * P[n],
                                   Integrator=gi.DOPRI853Integrator)

    # One focal-distance parameter per orbit: median of the per-point estimates.
    Delta = np.median(gd.get_staeckel_fudge_delta(mw, orbit_dop))
    aaf = at.QTable(get_staeckel_aaf(mw, orbit_dop, delta=Delta))

    for k in aaf.colnames:
        if k == 'angles':
            # Row 0 is t=0 because the integration starts at t1=0.
            aafs[k].append(aaf[k][0])
        else:
            aafs[k].append(np.mean(aaf[k], axis=0))

# Stack the per-cluster values into one Quantity array per key.
for k in aafs:
    aafs[k] = u.Quantity(aafs[k])

In [None]:
# Characteristic action scales (velocity x length), listed in the order
# (J_R, J_phi, J_z) together with their LaTeX labels.
JR_unit = (25 * u.km / u.s) * (1 * u.kpc)
Jphi_unit = (-229 * u.km / u.s) * (8.1 * u.kpc)
Jz_unit = (15 * u.km / u.s) * (0.5 * u.kpc)
J_units = [JR_unit, Jphi_unit, Jz_unit]
J_names = [r'J_R', r'J_\phi', r'J_z']

# Approximate guiding radius R_g ~ |J_phi| / v_c. Column 1 of the actions is
# presumably J_phi (matching the J_units ordering) -- confirm. NOTE(review):
# v_c is evaluated at each cluster's present-day 3D position, not at R_g itself.
Jphi = aafs['actions'][:, 1]
vcirc = mw.circular_velocity(w0)
Rg = np.abs(Jphi / vcirc).to(u.kpc)