TODO:
* Repeat the analysis for all abundance ratios (or a curated subset), not only [Mg/Fe].

In [None]:
import pickle
import os

import astropy.coordinates as coord
import astropy.table as at
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np

# gala
import gala.coordinates as gc
import gala.dynamics as gd
import gala.integrate as gi
import gala.potential as gp
from gala.units import galactic

from thriftshop.potentials import potentials
from thriftshop.config import vcirc, rsun

In [None]:
# Output directory for the saved figures; no error if it already exists.
os.makedirs(name='../plots', exist_ok=True)

In [None]:
# Parent sample, restricted to stars with GAIA_PARALLAX > 0.5 (the parallax is
# inverted as 1000 / plx to get a distance in pc further down).
parent = at.Table.read('../data/apogee-parent-sample.fits')
parent = parent[parent['GAIA_PARALLAX'] > 0.5]

# For every potential, attach the cached action-angle table to the parent-sample
# rows by APOGEE_ID (astropy's default inner join: stars missing from either
# side are dropped).
aafs = {
    name: at.join(at.QTable.read(f'../cache/aaf-{name}.fits'),
                  parent,
                  keys='APOGEE_ID')
    for name in potentials
}

In [None]:
# Heliocentric phase-space coordinates of the parent sample. The distance is
# the naive inverse parallax: 1000 / GAIA_PARALLAX, in pc.
c = coord.SkyCoord(
    ra=parent['RA']*u.deg,
    dec=parent['DEC']*u.deg,
    distance=1000 / parent['GAIA_PARALLAX'] * u.pc,
    pm_ra_cosdec=parent['GAIA_PMRA']*u.mas/u.yr,
    pm_dec=parent['GAIA_PMDEC']*u.mas/u.yr,
    radial_velocity=parent['VHELIO_AVG']*u.km/u.s,
)

# NOTE(review): this uses astropy's default Galactocentric frame parameters, not
# the `vcirc` / `rsun` imported from thriftshop.config -- confirm this matches
# the frame used to build the cached orbits/actions.
galcen = c.transform_to(coord.Galactocentric())

# Height above the plane [kpc] and vertical velocity [km/s], in `parent` row order.
z = galcen.z.to(u.kpc).value
vz = galcen.v_z.to(u.km/u.s).value

In [None]:
# Cached initial conditions and their actions, keyed by potential name.
# NOTE: pickle can execute arbitrary code on load -- only read trusted local caches.
with open('../cache/w0s.pkl', 'rb') as pkl:
    w0s = pickle.load(pkl)

with open('../cache/w0s-actions.pkl', 'rb') as pkl:
    w0s_actions = pickle.load(pkl)

# Integrate each potential's initial conditions in that same potential,
# from t=0 to 6 Gyr with a 0.5 Myr step.
orbits = {
    key: potentials[key].integrate_orbit(w0, dt=0.5*u.Myr, t1=0, t2=6*u.Gyr)
    for key, w0 in w0s.items()
}

### The three actions of all stars in each potential, overplotted with the actions of these orbits:

In [None]:
# Loop-invariant plot configuration. Panel k shows the action pair pairs[k]
# (indices into (J_R, J_phi, J_z)). Only stars whose remaining action
# not_in[k] lies within 20% of the orbit's value are drawn.
pairs = [(0, 1), (0, 2), (1, 2)]
not_in = [2, 1, 0]
lims = [(0, 7.5), (-90, -30), (0, 3)]
labels = ['$J_R$', r'$J_\phi$', '$J_z$']

# Convert each potential's action table to units of 30 kpc km/s once, instead
# of once per orbit inside the loop below.
acts = {name: aafs[name]['actions'].to(1*u.kpc * 30*u.km/u.s)
        for name in potentials}

# NOTE(review): assumes every potential has at least as many reference orbits
# as 'fiducial'.
for n in range(w0s['fiducial'].shape[0]):
    for name in potentials:
        print(n, name)
        act = acts[name]
        w0_act = w0s_actions[name][n].to(act.unit)

        fig, axes = plt.subplots(1, 3, figsize=(15, 5),
                                 constrained_layout=True)
        for k, (i, j) in enumerate(pairs):
            ax = axes[k]

            # The reference orbit's actions, drawn on top of the stars.
            ax.scatter(w0_act[i], w0_act[j], color='tab:red', zorder=100)

            held = not_in[k]
            mask = np.abs((act[:, held] - w0_act[held]) / w0_act[held]) < 0.2
            ax.plot(act[mask, i], act[mask, j],
                    marker='o', ls='none', ms=1.5, mew=0, alpha=0.4)

            ax.set_xlim(lims[i])
            ax.set_ylim(lims[j])
            ax.set_xlabel(labels[i])
            ax.set_ylabel(labels[j])

        fig.suptitle(f"potential: {name},    orbit: {n}", fontsize=20)

In [None]:
def get_action_box(tbl, orbit_actions, jr_factor=2, jp_tol=0.2, jz_tol=0.2):
    """Select the stars whose actions fall in a box around each reference orbit.

    Parameters
    ----------
    tbl : table-like
        Must provide an ``'actions'`` column of shape (n_stars, 3), ordered
        (J_R, J_phi, J_z). It is expected to be an astropy Quantity, because
        ``.decompose()`` is called on the action ratios.
    orbit_actions : sequence
        One length-3 action vector per reference orbit, in the same order.
    jr_factor : float, optional
        One-sided radial cut: keep stars with ``J_R < jr_factor * J_R(orbit)``.
    jp_tol, jz_tol : float, optional
        Maximum fractional difference ``|J / J(orbit) - 1|`` allowed in J_phi
        and J_z respectively.

    Returns
    -------
    list of boolean arrays
        ``masks[n]`` selects the rows of ``tbl`` inside the box of orbit ``n``.
    """
    actions = tbl['actions']

    masks = []
    for orbit in orbit_actions:
        JR_mask = actions[:, 0] < jr_factor * orbit[0]
        # .decompose() reduces each ratio to a plain dimensionless value when
        # the star and orbit actions carry different but compatible units.
        Jp_mask = np.abs(actions[:, 1] / orbit[1] - 1).decompose() < jp_tol
        Jz_mask = np.abs(actions[:, 2] / orbit[2] - 1).decompose() < jz_tol
        masks.append(JR_mask & Jp_mask & Jz_mask)

    return masks

In [None]:
# Potential names ordered by increasing mass parameter 'm' of each
# potential's 'disk' component.
sorted_keys = sorted(
    orbits,
    key=lambda pot_name: potentials[pot_name]['disk'].parameters['m'],
)

In [None]:
zlim = 1.75  # kpc
vlim = 75.  # km/s (the v_z axis below is plotted in km/s)

# z / vz follow the row order of `parent`. Each aafs[name], however, is the
# output of an astropy join, which does not preserve `parent`'s row order (it
# sorts by the key) and drops unmatched stars. So a mask built on aafs[name]
# must not index z / vz positionally. Instead, carry z and vz through the same
# APOGEE_ID join, so that the kinematics stay row-aligned with the actions.
# NOTE(review): assumes APOGEE_ID is unique within `parent`.
kinematics = at.Table({'APOGEE_ID': parent['APOGEE_ID'],
                       'z_galcen_kpc': z,
                       'vz_galcen_kms': vz})

# One panel per potential, ordered by disk mass. squeeze=False keeps `axes`
# indexable even for a single potential.
fig, axes = plt.subplots(1, len(sorted_keys),
                         figsize=(5 * len(sorted_keys), 5.5),
                         sharex=True, sharey=True, squeeze=False)
axes = axes[0]

for ax, name in zip(axes, sorted_keys):
    tbl = at.join(aafs[name], kinematics, keys='APOGEE_ID')
    masks = get_action_box(tbl, w0s_actions[name])
    print(masks[0].sum())
    for n, mask in enumerate(masks):
        ax.plot(tbl['vz_galcen_kms'][mask], tbl['z_galcen_kpc'][mask],
                marker='o', mew=0, ls='none', ms=3, alpha=0.5)
        # The n-th reference orbit in the same (v_z, z) plane, drawn behind the stars.
        ax.plot(orbits[name][:, n].v_z.to_value(u.km/u.s),
                orbits[name][:, n].z.to_value(u.kpc), marker='',
                color='#aaaaaa', alpha=0.2, zorder=-100)

    ax.set_xlabel(f'$v_z$ [{u.km/u.s:latex_inline}]')
axes[0].set_ylabel(f'$z$ [{u.kpc:latex_inline}]')

# Axes are shared, so limits set on the last panel apply to all of them.
ax.set_xlim(-vlim, vlim)
ax.set_ylim(-zlim, zlim)

fig.tight_layout()

In [None]:
from scipy.stats import binned_statistic

In [None]:
# 30-degree bins in the vertical angle theta_z, spanning [0, 360] degrees.
angz_bins = np.arange(0, 360+1e-3, 30)

# One row per potential, one column per reference orbit (first two orbits only).
# squeeze=False keeps `axes` 2-D even when there is a single potential.
fig, axes = plt.subplots(len(aafs), 2, figsize=(12, 4*len(aafs)),
                         sharex=True, sharey=True, squeeze=False)

for i, (name, t) in enumerate(aafs.items()):
    masks = get_action_box(t, w0s_actions[name])

    # theta_z (third angle) wrapped into [0, 360) degrees.
    angz = coord.Angle(t['angles'][:, 2]).wrap_at(360*u.deg).degree

    for n, ax in enumerate(axes[i]):
        in_box = masks[n]
        ax.plot(angz[in_box],
                t['MG_FE'][in_box],
                marker='o', mew=0, ls='none',
                ms=4, alpha=0.4)

        # Mean [Mg/Fe] per theta_z bin, drawn as a step curve.
        stat = binned_statistic(angz[in_box], t['MG_FE'][in_box],
                                bins=angz_bins, statistic='mean')
        ctr = 0.5 * (stat.bin_edges[:-1] + stat.bin_edges[1:])
        ax.plot(ctr, stat.statistic, drawstyle='steps-mid', marker='')

        ax.set_title(name)

    axes[i, 0].set_ylabel("[Mg/Fe]")

axes[0, 0].set_ylim(-0.1, 0.15)
axes[-1, 0].set_xlabel(r'$\theta_z$')
axes[-1, 1].set_xlabel(r'$\theta_z$')
fig.set_facecolor('w')

In [None]:
from scipy.spatial import cKDTree
from scipy.stats import binned_statistic

In [None]:
action_unit = 30*u.km/u.s * 1*u.kpc
# 5-degree bins in theta_z, expressed in radians over [0, 2 pi].
angz_bins = np.arange(0, 2*np.pi+1e-4, np.radians(5))
# Number of action-space neighbours whose mean [Mg/Fe] is the local reference.
n_neighbors = 32

angzs = {}
d_elems = {}
stats = {}
for name in aafs.keys():
    X = aafs[name]['actions'].to(action_unit)
    # Stars with any non-finite action cannot be placed in the KD-tree.
    finite_mask = np.all(np.isfinite(X), axis=1)
    X = X[finite_mask]
    safe_aaf = aafs[name][finite_mask]
    print(X.shape)

    tree = cKDTree(X.value)

    # Ask for one extra neighbour, because each star normally finds itself
    # first; idx[:, 1:] then holds its n_neighbors nearest other stars.
    dists, idx = tree.query(X.value, k=n_neighbors+1)
    # idx indexes the *filtered* rows, so the neighbour abundances must come
    # from safe_aaf. Indexing the unfiltered table would average the wrong
    # stars as soon as any row is dropped.
    mg_fe = safe_aaf['MG_FE']
    d_elems[name] = mg_fe - np.mean(mg_fe[idx[:, 1:]], axis=1)

    # --- [Mg/Fe] anomaly vs. theta_z
    fig, ax = plt.subplots(1, 1, figsize=(6, 5))
    angzs[name] = coord.Angle(safe_aaf['angles'][:, 2]).wrap_at(360*u.deg).radian
    ax.plot(angzs[name], d_elems[name],
            marker='o', ls='none', mew=0, ms=2, alpha=0.3)

    # Mean anomaly per theta_z bin.
    stats[name] = binned_statistic(angzs[name], d_elems[name],
                                   bins=angz_bins)
    ctr = 0.5 * (stats[name].bin_edges[:-1] + stats[name].bin_edges[1:])
    ax.plot(ctr, stats[name].statistic,
            marker='', drawstyle='steps-mid',
            zorder=10, color='tab:red', alpha=1)

    ax.set_xlim(0, 2*np.pi)
    ax.set_ylim(-0.025, 0.025)
    ax.set_title(name)

In [None]:
def get_M(x, N=2):
    """Build the design matrix of a truncated Fourier series in ``x``.

    Column 0 is the constant 1. For each harmonic h = 1..N, columns 2h-1 and 2h
    hold cos(h x) and sin(h x). The result has shape (len(x), 2N + 1).
    """
    n_rows = len(x)
    design = np.empty((n_rows, 2 * N + 1))
    design[:, 0] = 1.

    for harmonic in range(1, N + 1):
        design[:, 2 * harmonic - 1] = np.cos(harmonic * x)
        design[:, 2 * harmonic] = np.sin(harmonic * x)

    return design

In [None]:
N_trials = 32  # bootstrap resamples per potential

# One panel per potential, ordered by disk mass. squeeze=False keeps `axes`
# indexable even for a single potential.
n_panels = len(sorted_keys)
fig, axes = plt.subplots(n_panels, 1, figsize=(8, 4*n_panels),
                         sharex=True, sharey=True, squeeze=False)
axes = axes[:, 0]
plot_x = np.linspace(0, 2*np.pi, 1024)
all_coeffs = {}
for ax, name in zip(axes, sorted_keys):
    stat = stats[name]
    angz = angzs[name]
    d_elem = np.array(d_elems[name])

    # Binned mean anomaly vs. theta_z, already computed into stats[name].
    bin_x = 0.5 * (stat.bin_edges[:-1] + stat.bin_edges[1:])
    bin_y = stat.statistic

    # Bootstrap: refit the get_M Fourier model (N=2) to resampled stars.
    all_coeffs[name] = []
    all_tmps = []
    np.random.seed(42)  # reseed so each potential's bootstrap is reproducible on its own
    for trial in range(N_trials):
        idx = np.random.choice(len(angz), size=len(angz))
        coeffs, *_ = np.linalg.lstsq(get_M(angz[idx]), d_elem[idx], rcond=None)
        # Keep only the constant and cos(2 theta_z) terms (columns 0 and 3 of
        # get_M) for the overlay curves.
        tmp = np.zeros_like(coeffs)
        tmp[[0, 3]] = coeffs[[0, 3]]
        all_tmps.append(tmp)
        all_coeffs[name].append(coeffs)

    # Show the full sample, not just one bootstrap resample.
    ax.plot(angz, d_elem, marker='o', mew=0, ls='none', alpha=0.4, ms=2.)
    ax.plot(bin_x, bin_y, marker='', drawstyle='steps-mid', color='tab:red')

    for coeffs in all_tmps:
        plot_y = get_M(plot_x) @ coeffs
        ax.plot(plot_x, plot_y, alpha=0.4, color='tab:blue', marker='')
    ax.set_title(name)

# Axes are shared, so limits/labels on the last panel apply to all of them.
ax.set_xlim(0, 2*np.pi)
ax.set_ylim(-0.01, 0.01)

ax.set_xlabel(r'vertical conjugate angle, $\theta_z$ [rad]')
axes[n_panels // 2].set_ylabel('action-local\n[Mg/Fe] anomaly')

fig.tight_layout()
fig.set_facecolor('w')
fig.savefig('../plots/anomaly-panels.png', dpi=256)

In [None]:
# Disk-mass factor per potential: 'fiducial' maps to 1; every other potential
# name must itself parse as a float.
name_map = {'fiducial': 1.}

# One row per potential. ys holds the cos(2 theta_z) coefficient (column 3 of
# get_M) from each bootstrap trial; xs repeats that potential's factor alongside.
amplitudes = [np.array(coeff_list)[:, 3] for coeff_list in all_coeffs.values()]
factors = [name_map[name] if name in name_map else float(name)
           for name in all_coeffs]
xs = np.array([np.ones_like(amp) * fac for amp, fac in zip(amplitudes, factors)])
ys = np.array(amplitudes)

In [None]:
fig, ax = plt.subplots(figsize=(8, 6))

# One line per bootstrap trial: cos(2 theta_z) amplitude vs. disk-mass factor,
# with the potentials ordered by factor so each line runs left to right.
for trial in range(xs.shape[1]):
    x_col = xs[:, trial]
    y_col = ys[:, trial]
    order = np.argsort(x_col)
    ax.plot(x_col[order], y_col[order])

# Zero-amplitude reference line, drawn behind the curves.
ax.axhline(0, zorder=-10, color='#aaaaaa', alpha=0.3)
ax.set_xlim(0.3, 1.7)
ax.set_ylim(-3e-3, 3e-3)

ax.set_xlabel(r'factor times disk mass (at constant $v_{\rm circ}$)',
              fontsize=18)
ax.set_ylabel(r'amplitude of projection onto $\cos(2\theta_z)$',
              fontsize=18)

fig.set_facecolor('w')
fig.savefig('../plots/cos2theta-amp.png', dpi=250)