In [None]:
import pickle

import astropy.coordinates as coord
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np

# gala
import gala.coordinates as gc
import gala.dynamics as gd
import gala.potential as gp
from gala.units import galactic

from totoro.config import cache_path
from totoro.potentials import potentials, galpy_potentials
from totoro.actions_o2gf import get_o2gf_aaf
from totoro.actions_staeckel import get_staeckel_aaf

from totoro.data import datasets

In [None]:
# Load the cached initial conditions from the project cache directory.
# NOTE(review): pickle.load can execute arbitrary code -- only point this at a
# cache file that this project wrote itself.
w0s_cache = cache_path / 'w0s.pkl'
with open(w0s_cache, 'rb') as cache_file:
    w0s = pickle.load(cache_file)

# Rebuild entries 0 and 1 of every cached value with each position component
# shifted by a tiny amount (velocities are untouched), then recombine them.
# NOTE(review): presumably this avoids coordinates that are exactly zero -- confirm.
tiny_shift = 1e-5 * u.pc
for key in w0s:
    pair = w0s[key]
    shifted = tuple(
        gd.PhaseSpacePosition(pair[i].pos.xyz + tiny_shift, pair[i].vel)
        for i in (0, 1)
    )
    w0s[key] = gd.combine(shifted)

In [None]:
# Integrate the cached initial conditions stored under key '1.0' in the
# potential with the same key, then display the maximum |z| taken along the
# first axis of orbit.z.
# NOTE(review): axis 0 is presumably the time axis -- confirm.
k = '1.0'
integrate_kwargs = dict(dt=0.5, n_steps=6000)
orbit = potentials[k].integrate_orbit(w0s[k], **integrate_kwargs)
abs_z = np.abs(orbit.z)
abs_z.max(axis=0)

In [None]:
# Transform the 'apogee-rgb-loalpha' dataset coordinates to a default
# Galactocentric frame.
d = datasets['apogee-rgb-loalpha']
galcen_frame = coord.Galactocentric()
galcen_d = d.c.transform_to(galcen_frame)

# Flag entries whose |z| exceeds the threshold.
# NOTE(review): 3 * 280 pc looks like three times a 280 pc scale height -- confirm.
z_threshold = 3 * 280 * u.pc
zmask = np.abs(galcen_d.z) > z_threshold

# Fraction of entries beyond the threshold (displayed as the cell output).
zmask.sum() / len(zmask)

In [None]:
# For entry j (0, then 1) of every cached initial condition, compute the
# 'actions' with both solvers for every potential key and print the largest
# absolute fractional difference, measured relative to the O2GF result.
# After the loop finishes, o2gf_actions / stae_actions hold the j = 1 results.
for j in range(2):
    o2gf_actions, stae_actions = [], []

    for name, w0 in w0s.items():
        gala_pot = potentials[name]

        o2gf_result = get_o2gf_aaf(gala_pot, w0[j])
        stae_result = get_staeckel_aaf(
            galpy_potentials[name],
            w0[j],
            gala_potential=gala_pot,
        )

        o2gf_actions.append(o2gf_result['actions'])
        stae_actions.append(stae_result['actions'])

    # Stack the per-potential results into single Quantity arrays
    o2gf_actions = u.Quantity(o2gf_actions)
    stae_actions = u.Quantity(stae_actions)

    frac_diff = (o2gf_actions - stae_actions) / o2gf_actions
    print(np.abs(frac_diff).max())

In [None]:
# Keys '0.5', '0.6', ..., '1.5' (one-decimal strings), built from integer tenths
derp = [f'{tenths / 10:.1f}' for tenths in range(5, 16)]

# Boolean selector, in w0s iteration order, marking the keys listed above
mask = np.array([key in derp for key in w0s])

In [None]:
# Fractional difference between the two solvers, relative to the O2GF result.
# This reads o2gf_actions / stae_actions exactly as the comparison loop left
# them, i.e. the values from that loop's final iteration.
diff = (o2gf_actions - stae_actions) / o2gf_actions

# Largest absolute difference, in percent, for each of the first three columns
# over the rows selected by `mask`.
# NOTE(review): the three columns are presumably the three actions -- confirm.
for col in range(3):
    worst_percent = 100 * np.abs(diff[mask, col]).max()
    print(f"{worst_percent:.2f}")