In [None]:
import pickle
import os

import astropy.coordinates as coord
import astropy.table as at
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np

# gala
import gala.coordinates as gc
import gala.dynamics as gd
import gala.integrate as gi
import gala.potential as gp
from gala.units import galactic

from thriftshop.potentials import potentials
from thriftshop.config import vcirc, rsun

In [None]:
# APOGEE parent sample, restricted to GAIA_PARALLAX > 0.5 (in mas, given the
# 1000 / parallax -> pc conversion used for distances later in the notebook).
parent = at.Table.read('../data/apogee-parent-sample.fits')
parent = parent[parent['GAIA_PARALLAX'] > 0.5]

# One cached table per potential (providing the 'actions' column used below),
# joined onto the parent sample by APOGEE_ID.
aafs = {
    name: at.join(at.QTable.read(f'../cache/aaf-{name}.fits'), parent,
                  keys='APOGEE_ID')
    for name in potentials
}

In [None]:
# Heliocentric 6D coordinates of the parent sample (distance = 1000 / parallax
# in pc, i.e. parallax in mas), then Galactocentric height z [kpc] and
# vertical velocity v_z [km/s] as plain arrays aligned with `parent` rows.
# NOTE(review): this uses astropy's default Galactocentric frame parameters,
# not the `rsun` / `vcirc` imported from thriftshop.config -- confirm intended.
c = coord.SkyCoord(
    ra=parent['RA'] * u.deg,
    dec=parent['DEC'] * u.deg,
    distance=1000 / parent['GAIA_PARALLAX'] * u.pc,
    pm_ra_cosdec=parent['GAIA_PMRA'] * u.mas / u.yr,
    pm_dec=parent['GAIA_PMDEC'] * u.mas / u.yr,
    radial_velocity=parent['VHELIO_AVG'] * u.km / u.s,
)
galcen = c.transform_to(coord.Galactocentric)

z, vz = galcen.z.to_value(u.kpc), galcen.v_z.to_value(u.km / u.s)

In [None]:
def _load_pickle(path):
    """Return the object unpickled from ``path``.

    Only use this on trusted local cache files: unpickling can execute
    arbitrary code.
    """
    with open(path, 'rb') as fh:
        return pickle.load(fh)


# Orbit initial conditions and their actions, both keyed by potential name.
w0s = _load_pickle('../cache/w0s.pkl')
w0s_actions = _load_pickle('../cache/w0s-actions.pkl')

# Integrate each potential's initial conditions from t=0 to 6 Gyr in 0.5 Myr steps.
orbits = {
    k: potentials[k].integrate_orbit(w0, dt=0.5*u.Myr, t1=0, t2=6*u.Gyr)
    for k, w0 in w0s.items()
}

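### Pairwise projections of the 3 actions of all stars in each potential, over-plotted with the values for these orbits (each panel keeps only stars whose remaining action is within 20% of the orbit's):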

In [None]:
# Plot actions in scaled units of (1 kpc)(30 km/s), the unit the axis limits
# below are written in.
action_unit = 1*u.kpc * 30*u.km/u.s
action_labels = ['$J_R$', r'$J_\phi$', '$J_z$']
action_lims = [(0, 7.5), (-90, -30), (0, 3)]
# Panel k shows the action pair action_pairs[k]. Stars are restricted to those
# whose remaining action (index held_out[k]) is within held_out_tol
# (fractionally) of the orbit's value.
action_pairs = [(0, 1), (0, 2), (1, 2)]
held_out = [2, 1, 0]
held_out_tol = 0.2

# The star actions do not depend on the orbit index, so convert them once per potential.
star_actions = {name: aafs[name]['actions'].to(action_unit)
                for name in potentials}

for n in range(w0s['fiducial'].shape[0]):
    for name in potentials.keys():
        act = star_actions[name]
        w0_act = w0s_actions[name][n].to(act.unit)

        fig, axes = plt.subplots(1, 3, figsize=(15, 5),
                                 constrained_layout=True)
        for ax, (i, j), m in zip(axes, action_pairs, held_out):
            # The orbit's own actions, drawn on top of the stars.
            ax.scatter(w0_act[i], w0_act[j], color='tab:red', zorder=100)

            frac_diff = (act[:, m] - w0_act[m]) / w0_act[m]
            mask = np.abs(frac_diff) < held_out_tol
            ax.plot(act[mask, i], act[mask, j],
                    marker='o', ls='none', ms=1.5, mew=0, alpha=0.4)

            ax.set_xlim(action_lims[i])
            ax.set_ylim(action_lims[j])
            ax.set_xlabel(action_labels[i])
            ax.set_ylabel(action_labels[j])

        fig.suptitle(f"potential: {name},    orbit: {n}", fontsize=20)
        # Render and release each figure now. Otherwise all of them stay open
        # until the cell ends, and matplotlib warns once more than 20 are open.
        plt.show()

In [None]:
def get_action_box(tbl, orbit_actions, tol=(0.4, 0.2, 0.2)):
    """Select stars whose actions lie in a fractional box around each orbit.

    Parameters
    ----------
    tbl : table-like
        Must provide an ``'actions'`` column indexable as ``[:, k]`` for
        k = 0, 1, 2 (J_R, J_phi, J_z in the ordering used in this notebook).
    orbit_actions : sequence
        One length-3 action vector per orbit. If these are astropy
        Quantities, their units only need to be convertible to those of
        ``tbl['actions']``.
    tol : sequence of 3 floats, optional
        Maximum allowed ``|J_star / J_orbit - 1|`` for each action. The
        default ``(0.4, 0.2, 0.2)`` allows 40% in J_R and 20% in J_phi and J_z.

    Returns
    -------
    list of boolean ndarrays
        ``masks[n]`` is True for the rows of ``tbl`` inside orbit n's box.
    """
    actions = tbl['actions']

    masks = []
    for orbit_J in orbit_actions:
        mask = np.ones(len(actions), dtype=bool)
        for k, max_frac in enumerate(tol):
            ratio = actions[:, k] / orbit_J[k]
            # With Quantities the ratio may be a scaled dimensionless unit
            # (e.g. km Myr / (kpc s)); reduce it to a pure number first.
            if hasattr(ratio, 'decompose'):
                ratio = ratio.decompose()
            mask &= np.abs(ratio - 1) < max_frac
        masks.append(mask)

    return masks

In [None]:
# Potential names ordered by ascending `m` parameter of their 'disk'
# component (presumably the disk mass).
sorted_keys = sorted(orbits, key=lambda pot_name: potentials[pot_name]['disk'].parameters['m'])

In [None]:
zlim = 1.75  # kpc
vlim = 75.  # km/s: vz and the orbit v_z below are both plotted in km/s

# One panel per potential, in sorted_keys order. squeeze=False keeps `axes`
# indexable even when there is only one potential.
n_panels = len(sorted_keys)
fig, axes = plt.subplots(1, n_panels, figsize=(5 * n_panels, 5.5),
                         sharex=True, sharey=True, squeeze=False)
axes = axes[0]

# Each mask from get_action_box refers to rows of aafs[name]. That table came
# from `at.join`, which can drop unmatched rows and reorders its output by the
# join key. So the mask does not line up with z / vz, which follow the row
# order of `parent`. Translate each mask into a parent-aligned one via APOGEE_ID.
parent_ids = np.asarray(parent['APOGEE_ID'])

for ax, name in zip(axes, sorted_keys):
    aaf = aafs[name]
    masks = get_action_box(aaf, w0s_actions[name])
    for n, mask in enumerate(masks):
        in_box = np.isin(parent_ids, np.asarray(aaf['APOGEE_ID'][mask]))
        ax.plot(vz[in_box], z[in_box],
                marker='o', mew=0, ls='none', ms=3, alpha=0.5)
        # Orbit n integrated in this potential, drawn faintly behind the stars.
        orbit = orbits[name][:, n]
        ax.plot(orbit.v_z.to_value(u.km/u.s), orbit.z.to_value(u.kpc),
                marker='', color='#aaaaaa', alpha=0.2, zorder=-100)

    ax.set_title(name)
    ax.set_xlabel(f'$v_z$ [{u.km/u.s:latex_inline}]')
axes[0].set_ylabel(f'$z$ [{u.kpc:latex_inline}]')

# The axes are shared, so limits set on one panel apply to all panels.
axes[0].set_xlim(-vlim, vlim)
axes[0].set_ylim(-zlim, zlim)

fig.tight_layout()