In [None]:
import pickle
import pathlib

import astropy.coordinates as coord
import astropy.table as at
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from scipy.spatial import cKDTree
from scipy.stats import binned_statistic
from scipy.interpolate import interp1d

# gala
import gala.coordinates as gc
import gala.dynamics as gd
import gala.integrate as gi
import gala.potential as gp
from gala.units import galactic

from thriftshop.config import vcirc, rsun, plot_path, fig_path, cache_path
from thriftshop.data import load_apogee_sample
from thriftshop.potentials import potentials
from thriftshop.abundances import get_elem_names, elem_to_label

coord.galactocentric_frame_defaults.set('v4.0');

In [None]:
# Load the APOGEE parent sample; the helper returns two objects that the rest
# of the notebook refers to as `t` and `c`.
# NOTE(review): presumably `t` is the catalog table and `c` the matching sky
# coordinates -- confirm against thriftshop.data.load_apogee_sample.
t, c = load_apogee_sample('../data/apogee-parent-sample.fits')

TEST:

In [None]:
from thriftshop.galpy_helpers import get_staeckel_aaf
from thriftshop.potentials import galpy_potentials
from thriftshop.config import galcen_frame
from thriftshop.actions_multiproc import action_worker, compute_actions_multiproc

from schwimmbad import MultiPool
from schwimmbad.utils import batch_tasks

In [None]:
# The bare name `comp` that used to be the only content of this cell is not
# defined by any cell above, so it raised NameError under Restart & Run All.
# It looks like a leftover tab-completion fragment and has been removed.

In [None]:
# Run compute_actions_multiproc on the full table `t` for one potential,
# using a local pool of 4 worker processes.
potential_name = '1.0'

with MultiPool(processes=4) as pool:
    # Pass the variable instead of repeating the literal, so the name set
    # above is the single place to change the potential.
    aaf = compute_actions_multiproc(t,
                                    potential_name=potential_name,
                                    pool=pool)

In [None]:
# Hand-entered (x, y) pairs, plotted to eyeball how y grows with x.
# NOTE(review): presumably (number of stars, run time) -- confirm and label.
scaling = np.array([
    (1000, 8),
    (10000, 19.9),
    (75000, 113),
])

# Rows of the transpose are the x and y columns, in that order.
plt.plot(*scaling.T)

In [None]:
# Build 6D sky coordinates from the catalog columns, transform them to
# `galcen_frame`, and wrap the result as a gala PhaseSpacePosition.
# Note: this rebinds `c`, replacing the object returned by load_apogee_sample.
# NOTE(review): 1000 / parallax -> pc assumes GAIA_PARALLAX is in mas -- confirm.
distance = 1000 / t['GAIA_PARALLAX'] * u.pc
pm_ra_cosdec = t['GAIA_PMRA'] * u.mas/u.yr
pm_dec = t['GAIA_PMDEC'] * u.mas/u.yr
radial_velocity = t['VHELIO_AVG'] * u.km/u.s

c = coord.SkyCoord(
    ra=t['RA'] * u.deg,
    dec=t['DEC'] * u.deg,
    distance=distance,
    pm_ra_cosdec=pm_ra_cosdec,
    pm_dec=pm_dec,
    radial_velocity=radial_velocity,
)
galcen = c.transform_to(galcen_frame)
w0s = gd.PhaseSpacePosition(galcen.data)

In [None]:
from galpy.actionAngle import estimateDeltaStaeckel, actionAngleStaeckel
from thriftshop.config import rsun as ro, vcirc as vo
from thriftshop.galpy_helpers import gala_to_galpy_orbit

In [None]:
# Load the line_profiler IPython extension, which provides the %lprun magic.
%load_ext line_profiler

In [None]:
# Line-profile estimateDeltaStaeckel while get_staeckel_aaf runs on the first
# phase-space position in the '1.0' galpy potential.
%lprun -f estimateDeltaStaeckel get_staeckel_aaf(w0s[0], galpy_potentials['1.0'])

In [None]:
# Call estimateDeltaStaeckel once per star at that star's (R, z), with both
# coordinates expressed in units of `ro`.
pot = galpy_potentials['1.0']

# Convert every position in a single vectorized step. Indexing the
# PhaseSpacePosition star-by-star inside the loop repeated the representation
# conversion for each star and yielded the same numbers.
R_all = w0s.cylindrical.rho.to_value(ro)
z_all = w0s.z.to_value(ro)

delta_staeckels = [estimateDeltaStaeckel(pot, R_i, z_i)
                   for R_i, z_i in zip(R_all, z_all)]

In [None]:
from scipy.stats import binned_statistic_2d

# Mean of `delta_staeckels` in a regular (rho, z) grid, drawn as an image.
rho_edges = np.arange(8-2, 8+2, 0.05)
z_edges = np.arange(-1.5, 1.5, 0.05)

stat = binned_statistic_2d(
    w0s.cylindrical.rho.to_value(u.kpc),
    w0s.z.to_value(u.kpc),
    delta_staeckels,
    statistic='mean',
    bins=(rho_edges, z_edges),
)

# `stat.statistic` is indexed [x_bin, y_bin]; transpose so rows run along y.
plt.pcolormesh(stat.x_edge, stat.y_edge, stat.statistic.T,
               vmin=2.5, vmax=6)

In [None]:
from scipy.interpolate import NearestNDInterpolator

In [None]:
# Bin-centre coordinates of the 2D grid, flattened into an (n_cells, 2) array
# of (x, y) pairs.
xcen = 0.5 * (stat.x_edge[:-1] + stat.x_edge[1:])
ycen = 0.5 * (stat.y_edge[:-1] + stat.y_edge[1:])
# np.stack needs a real sequence; a lazy `map` object is rejected by current
# NumPy releases, so build a list explicitly.
xycens = np.stack([mesh.ravel() for mesh in np.meshgrid(xcen, ycen)]).T

In [None]:
# Scatter the bin centres coloured by the flattened statistic -- presumably a
# visual check that `xycens` and `stat.statistic.T.ravel()` are ordered alike.
plt.scatter(xycens[:, 0], xycens[:, 1], c=stat.statistic.T.ravel())

In [None]:
# Nearest-neighbour lookup from a bin centre (x, y) to that bin's statistic.
# NOTE(review): bins that contain no stars presumably hold NaN, which this
# interpolator would hand back unchanged -- confirm whether they should be
# dropped first.
delta_interp = NearestNDInterpolator(xycens, stat.statistic.T.ravel())

In [None]:
# Evaluate the interpolator at the first star's (rho, z), both in kpc.
delta_interp(w0s[0].cylindrical.rho.to_value(u.kpc),
             w0s[0].z.to_value(u.kpc))

In [None]:
# Pick the first entry of `w0s` with |z| > 1 kpc as a test case.
derp = w0s[np.abs(w0s.z) > 1*u.kpc][0]

In [None]:
# Evaluate actionAngleStaeckel for the test star, with delta set to the mean
# of the per-star estimates.
# NOTE(review): multiplying by ro * vo presumably converts galpy's internal
# units to physical ones -- confirm.
mean_delta = np.mean(delta_staeckels)
o = gala_to_galpy_orbit(derp)
aAS = actionAngleStaeckel(pot=pot, delta=mean_delta)
np.squeeze(aAS(o)) * ro * vo

In [None]:
# Same evaluation as the previous cell, but with a fixed delta=2.8 instead of
# the sample mean, to compare the two results.
o = gala_to_galpy_orbit(derp)
aAS = actionAngleStaeckel(pot=pot, delta=2.8)
np.squeeze(aAS(o)) * ro * vo

In [None]:
# Split the first 1000 rows of `t` into batches and run action_worker on each
# batch through an 8-process pool, collecting the results in a list.
potential_name = '1.0'

with MultiPool(processes=8) as pool:
    n_batches = max(1, pool.size - 1)
    tasks = batch_tasks(n_batches=n_batches,
                        arr=t[:1000],
                        args=(potential_name, ))

    all_data = list(pool.map(action_worker, tasks))

In [None]:
# Display how batch_tasks splits 1000 tasks into batches.
# Note: `pool` is the object bound by the `with` statement in the previous
# cell, whose block has already exited; only its `size` attribute is read.
tasks = batch_tasks(n_batches=max(1, pool.size - 1),
                    n_tasks=1000)
tasks

In [None]:
# Back-of-envelope: 3.73 ms x 4000 / 4, expressed in seconds.
# NOTE(review): presumably per-call time x number of calls / number of
# workers -- confirm where 3.73 ms, 4000 and 4 come from.
(3.73*u.millisecond * 4000 / 4).to(u.second)

In [None]:
# Time one get_staeckel_aaf call on the first phase-space position.
%timeit get_staeckel_aaf(w0s[0], galpy_potentials['1.0'])

---

In [None]:
# For every potential that has a cached `aaf-<name>.fits` file, read it and
# join it with the parent table `t` on APOGEE_ID. Potentials without a cache
# file are skipped.
aafs = {}
for name in potentials:
    cached_file = cache_path / f'aaf-{name}.fits'
    if cached_file.exists():
        aafs[name] = at.join(at.QTable.read(cached_file),
                             t,
                             keys='APOGEE_ID')

In [None]:
# Transform `c` (whatever was last bound to that name above) into
# `galcen_frame` and pull out plain-float arrays for plotting.
galcen = c.transform_to(galcen_frame)

z = galcen.z.to_value(u.kpc)  # kpc
vz = galcen.v_z.to_value(u.km/u.s)  # km/s

In [None]:
# Load the cached initial conditions and their actions, then integrate each
# set of initial conditions in the potential of the same name.
# Caution: pickle.load can execute arbitrary code from the file -- only load
# cache files you produced yourself.
# Note: `w0s` is rebound here to a dict (it is iterated with .items() below),
# replacing the PhaseSpacePosition of the same name defined earlier.
with open('../cache/w0s.pkl', 'rb') as f:
    w0s = pickle.load(f)

with open('../cache/w0s-actions.pkl', 'rb') as f:
    w0s_actions = pickle.load(f)

orbits = {
    k: potentials[k].integrate_orbit(w0, dt=0.5*u.Myr, t1=0, t2=6*u.Gyr)
    for k, w0 in w0s.items()
}

In [None]:
# Potential names ordered by the disk component's 'm' parameter, restricted to
# those for which a cached action table was loaded into `aafs`.
sorted_keys = [
    key
    for key in sorted(orbits,
                      key=lambda k: potentials[k]['disk'].parameters['m'])
    if key in aafs
]

In [None]:
# Row mask that is True only where all action components are finite in every
# potential.
# NOTE(review): AND-ing the masks assumes every table in `aafs` has the same
# rows in the same order -- confirm.
all_finite_mask = None
for name in sorted_keys:
    row_is_finite = np.isfinite(aafs[name]['actions']).all(axis=1)
    if all_finite_mask is None:
        all_finite_mask = row_is_finite
    else:
        all_finite_mask = all_finite_mask & row_is_finite

print(all_finite_mask.sum())

In [None]:
def get_stat(actions, theta_z, elem, tree_K=64,
             action_unit=30*u.km/u.s * 1*u.kpc,
             angz_bins=np.arange(0, 2*np.pi+1e-4, np.radians(5))):
    """Bin the action-local abundance anomaly as a function of vertical angle.

    For each star, the anomaly is its `elem` value minus the mean `elem` of
    its `tree_K` nearest neighbours in action space (actions are scaled by
    `action_unit` before distances are computed).

    Parameters
    ----------
    actions : quantity, one row per star
    theta_z : angle-like, one value per star
    elem : array-like, one value per star
    tree_K : int
        Number of neighbours averaged for the local mean.
    action_unit : quantity
        Unit the actions are converted to before building the tree.
    angz_bins : array
        Bin edges in radians.

    Returns
    -------
    stat : result of scipy.stats.binned_statistic (default statistic)
    angz : vertical angle wrapped at 360 deg, in radians
    d_elem : per-star anomaly
    """
    # Wrapped angles in radians, and unit-less action coordinates.
    angz = coord.Angle(theta_z).wrap_at(360*u.deg).radian
    X = actions.to_value(action_unit)

    # Query one extra neighbour and drop the first column: the closest match
    # to each query point is normally the point itself.
    _, nbr_idx = cKDTree(X).query(X, k=tree_K+1)
    local_mean = np.mean(elem[nbr_idx[:, 1:]], axis=1)

    d_elem = elem - local_mean
    stat = binned_statistic(angz, d_elem, bins=angz_bins)

    return stat, angz, d_elem


def get_boostrap_stats(actions, theta_z, elem,
                       statistic='mean',
                       bootstrap_K=64,
                       tree_K=64,
                       action_unit=30*u.km/u.s * 1*u.kpc,
                       angz_bins=np.arange(0, 2*np.pi+1e-4, np.radians(5)),
                       seed=42):
    """Binned action-local abundance anomaly, plus bootstrap resamples.

    The per-star anomaly is `elem` minus the mean `elem` of the `tree_K`
    nearest neighbours in action space. It is binned in vertical angle once
    for the data as given and then `bootstrap_K` more times for resamples
    drawn with replacement.

    Parameters
    ----------
    actions : quantity, one row per star
    theta_z : angle-like, one value per star
    elem : array-like, one value per star
    statistic : str or callable
        Passed through to scipy.stats.binned_statistic.
    bootstrap_K : int
        Number of bootstrap resamples.
    tree_K : int
        Number of neighbours averaged for the local mean.
    action_unit : quantity
        Unit the actions are converted to before building the tree.
    angz_bins : array
        Bin edges in radians.
    seed : int or None
        Seed for numpy's global random state; None leaves the state alone.

    Returns
    -------
    stats : list of binned_statistic results, length bootstrap_K + 1
        Entry 0 is the un-resampled data.
    counts : list of per-bin star counts, matching `stats`
    """
    # Actions without units:
    X = actions.to_value(action_unit)
    angz = coord.Angle(theta_z).wrap_at(360*u.deg).radian

    # Query one extra neighbour and drop the first column: the closest match
    # to each query point is normally the point itself.
    tree = cKDTree(X)
    _, nbr_idx = tree.query(X, k=tree_K+1)

    d_elem = elem - np.mean(elem[nbr_idx[:, 1:]], axis=1)

    # Compare against None so that seed=0 is honoured; a plain truthiness
    # test silently skipped seeding for 0.
    if seed is not None:
        np.random.seed(seed)

    n_stars = len(angz)
    stats = []
    counts = []
    for k in range(bootstrap_K + 1):
        if k == 0:
            # First pass: the data exactly as given.
            x, y = angz, d_elem
        else:
            # Later passes: resample stars with replacement.
            resample = np.random.choice(n_stars, size=n_stars)
            x, y = angz[resample], d_elem[resample]

        stats.append(binned_statistic(x, y, bins=angz_bins,
                                      statistic=statistic))
        counts.append(np.histogram(x, bins=angz_bins)[0])

    return stats, counts

In [None]:
# For each potential: compute the [Mg/Fe] anomaly against vertical angle,
# store the results by potential name, and plot the per-star points with the
# binned statistic over them.
angzs = {}
d_elems = {}
stats = {}
for name in sorted_keys:
    aaf = aafs[name][all_finite_mask]
    binned, theta, resid = get_stat(
        aaf['actions'], aaf['angles'][:, 2], aaf['MG_FE'],
        tree_K=64)
    stats[name] = binned
    angzs[name] = theta
    d_elems[name] = resid

    fig, ax = plt.subplots(1, 1, figsize=(6, 5))

    # Per-star anomalies.
    ax.plot(theta, resid,
            marker='o', ls='none', mew=0, ms=2, alpha=0.3)

    # Binned statistic, drawn at the bin centres.
    edges = binned.bin_edges
    ctr = 0.5 * (edges[:-1] + edges[1:])
    ax.plot(ctr, binned.statistic,
            marker='', drawstyle='steps-mid',
            zorder=10, color='tab:red', alpha=1)

    ax.set_xlim(0, 2*np.pi)
    ax.set_ylim(-0.025, 0.025)
    ax.set_title(name)

In [None]:
# Variance of the per-star [Mg/Fe] anomaly as a function of the potential
# name (read as a float), with a parabola fitted around the grid minimum.
mdisks = []
vars_ = []
for name in sorted_keys:
    aaf = aafs[name][all_finite_mask]
    # Bind to a fresh name: unpacking into `d_elems` would overwrite the
    # per-potential dict of that name built earlier in the notebook.
    *_, d_elem_arr = get_stat(
        aaf['actions'], aaf['angles'][:, 2], aaf['MG_FE'],
        tree_K=64)

    mdisks.append(float(name))
    vars_.append(np.var(d_elem_arr))

plt.plot(mdisks, vars_)

# Fit a parabola through the grid minimum and its neighbours. The window is
# clamped so that a minimum at either end of the grid still gets three points;
# an unclamped i-1:i+2 slice is empty when the minimum is at index 0.
i = int(np.argmin(vars_))
lo = min(max(i - 1, 0), max(len(vars_) - 3, 0))
coeffs = np.polyfit(mdisks[lo:lo+3], vars_[lo:lo+3], deg=2)
grid = np.linspace(0.7, 1.5, 1024)
parabola = np.poly1d(coeffs)(grid)
plt.plot(grid, parabola)

# Grid value at which the fitted parabola is smallest.
grid[parabola.argmin()]

In [None]:
def get_M(x, N=2):
    """Design matrix for a truncated Fourier series evaluated at `x`.

    Column layout: ``[1, cos(x), sin(x), cos(2x), sin(2x), ...]`` up to
    harmonic `N`, giving shape ``(len(x), 1 + 2*N)``.

    Parameters
    ----------
    x : numpy array of angles
    N : int
        Highest harmonic included.
    """
    design = np.empty((len(x), 2*N + 1), dtype=float)
    design[:, 0] = 1.

    for harmonic in range(1, N + 1):
        phase = harmonic * x
        # Odd columns hold the cosine terms, even columns the sine terms.
        design[:, 2*harmonic - 1] = np.cos(phase)
        design[:, 2*harmonic] = np.sin(phase)

    return design

In [None]:
def get_stats_coeffs(aaf, elem_name, **kwargs):
    """Bootstrap the binned anomaly and fit Fourier coefficients to each draw.

    Parameters
    ----------
    aaf : table with 'actions', 'angles' and `elem_name` columns
    elem_name : str
        Column holding the abundance to analyse.
    **kwargs
        Forwarded to get_boostrap_stats.

    Returns
    -------
    stats, counts : as returned by get_boostrap_stats
    all_coeffs : array with one row of get_M coefficients per entry in `stats`
    """
    stats, counts = get_boostrap_stats(aaf['actions'],
                                       aaf['angles'][:, 2],
                                       aaf[elem_name],
                                       **kwargs)

    def _fit_one(stat, count):
        # Least-squares fit of the Fourier design matrix at the bin centres.
        centers = 0.5 * (stat.bin_edges[:-1] + stat.bin_edges[1:])
        # NOTE(review): the binned statistic is divided by the per-bin count
        # before fitting; with statistic='mean' this is mean/count -- confirm
        # that this is intended.
        solution = np.linalg.lstsq(get_M(centers), stat.statistic / count,
                                   rcond=None)
        return solution[0]

    all_coeffs = [_fit_one(stat, count) for stat, count in zip(stats, counts)]

    return stats, counts, np.array(all_coeffs)

In [None]:
# For every element and potential: bootstrap the binned anomaly, fit Fourier
# coefficients, store them in all_coeffs[elem_name][potential_name], and save
# one diagnostic figure per (element, statistic, potential).
all_coeffs = {}
# NOTE(review): `elem_names` is not assigned by any cell above; it presumably
# comes from thriftshop.abundances.get_elem_names -- define it before running.
for elem_name in elem_names:
    # `os` is never imported in this notebook, so os.makedirs raised
    # NameError; pathlib (imported in the first cell) does the same job.
    pathlib.Path(f'../plots/{elem_name}').mkdir(parents=True, exist_ok=True)
    all_coeffs[elem_name] = {}
    # Only the mean is used for now; other binned_statistic names such as
    # 'median' or 'std' can be added to this list.
    for statistic in ['mean']:
        plot_x = np.linspace(0, 2*np.pi, 1024)
        # The design matrix on the plotting grid is the same for every curve.
        plot_M = get_M(plot_x)
        for name in sorted_keys:
            # Local names chosen so the notebook-level `stats` dict and the
            # coordinates object `c` are not overwritten by this cell.
            boot_stats, boot_counts, coeffs = get_stats_coeffs(
                aafs[name][all_finite_mask],
                elem_name,
                statistic=statistic)

            if statistic == 'mean':
                all_coeffs[elem_name][name] = coeffs

            fig, ax = plt.subplots(1, 1, figsize=(8, 5))

            # Entry 0 of the bootstrap lists is the un-resampled data.
            edges = boot_stats[0].bin_edges
            bin_x = 0.5 * (edges[:-1] + edges[1:])
            # NOTE(review): statistic divided by per-bin count, matching what
            # get_stats_coeffs fits -- confirm that this is intended.
            bin_y = boot_stats[0].statistic / boot_counts[0]
            ax.plot(bin_x, bin_y, marker='', drawstyle='steps-mid',
                    color='tab:red')

            # One fitted curve per bootstrap draw (all coefficients).
            for boot_coeffs in coeffs:
                ax.plot(plot_x, plot_M @ boot_coeffs,
                        alpha=0.4, color='tab:blue', marker='')
            ax.set_title(name)

            ax.set_xlim(0, 2*np.pi)

            # auto-set ylim: data range padded by 25% on each side
            init_ylim = (np.nanmin(bin_y), np.nanmax(bin_y))
            yspan = init_ylim[1] - init_ylim[0]
            ax.set_ylim(init_ylim[0] - 0.25*yspan,
                        init_ylim[1] + 0.25*yspan)
            ax.axhline(0, zorder=-100, color='#aaaaaa')

            ax.set_xlabel(r'vertical conjugate angle, $\theta_z$ [rad]')
            ax.set_ylabel(f'action-local\n{elem_to_label(elem_name)} anomaly')

            fig.tight_layout()
            fig.set_facecolor('w')
            fig.savefig(f'../plots/{elem_name}/anomaly-panels-{statistic}-{name}.png', dpi=256)
            # Close each figure so the loop does not accumulate open figures.
            plt.close(fig)

In [None]:
# For each element: plot the fitted cos(theta_z), sin(theta_z) and
# cos(2 theta_z) coefficients against the potential name (read as a float),
# and record where the cos(2 theta_z) coefficient crosses zero.
statistic = 'mean'

summary = {}
for elem_name in all_coeffs.keys():
    fig, axes = plt.subplots(3, 1, figsize=(10, 10), sharex=True)

    elem_coeffs = all_coeffs[elem_name]
    # Order potentials numerically; a plain string sort only matches numeric
    # order while every name has the same number of digits.
    potential_names = sorted(elem_coeffs.keys(), key=float)
    # Rows per coefficient array (data + bootstrap draws), read from the first
    # potential present rather than from a hard-coded '1.0' key.
    n_draws = len(elem_coeffs[potential_names[0]])
    shape = (len(potential_names), n_draws)

    # Columns 1, 2, 3 of the get_M fit: cos(x), sin(x), cos(2x).
    for ax, coeff_idx in zip(axes, [1, 2, 3]):
        xs = np.full(shape, np.nan)
        ys = np.full(shape, np.nan)
        for i, potential_name in enumerate(potential_names):
            ys[i] = np.array(elem_coeffs[potential_name])[:, coeff_idx]
            xs[i] = float(potential_name)

        # Draw 0 (un-resampled data) on top, bootstrap draws faintly behind.
        ax.plot(xs[:, 0], ys[:, 0], alpha=1., lw=3, zorder=100)
        ax.plot(xs[:, 1:], ys[:, 1:], alpha=0.4)

        if coeff_idx == 3:  # only for the cos(2 theta_z) term
            # Invert coefficient(x) to x(coefficient) and evaluate at 0,
            # separately for each draw.
            zero_cross = np.array([interp1d(ys[:, k], xs[:, k], fill_value="extrapolate")(0.)
                                   for k in range(ys.shape[1])])
            summary[elem_name] = [np.mean(zero_cross), np.std(zero_cross)]

        ax.axhline(0, zorder=-10, color='#aaaaaa', alpha=0.3)
        # Data range padded by 20% on each side.
        ylim = (ys.min(), ys.max())
        ylim = (ylim[0] - 0.2 * (ylim[1]-ylim[0]),
                ylim[1] + 0.2 * (ylim[1]-ylim[0]))
        ax.set_ylim(ylim)

    # The x axis is shared, so setting the limits on one panel is enough.
    axes[-1].set_xlim(0.4, 1.6)

    axes[-1].set_xlabel(r'factor times disk mass (at constant $v_{\rm circ}$)',
                        fontsize=18)

    axes[0].set_ylabel('amplitude of projection\n' + r'onto $\cos(\theta_z)$',
                       fontsize=18)
    axes[1].set_ylabel('amplitude of projection\n' + r'onto $\sin(\theta_z)$',
                       fontsize=18)
    axes[2].set_ylabel('amplitude of projection\n' + r'onto $\cos(2\theta_z)$',
                       fontsize=18)

    fig.set_facecolor('w')
    fig.tight_layout()

    fig.savefig(f'../plots/{elem_name}/cos2theta-amp-{statistic}.png', dpi=250)
    plt.close(fig)

In [None]:
# Unpack `summary` ({name: [value, error]}) into parallel lists for plotting.
names = list(summary)
vals = [val for val, _ in summary.values()]
errs = [err for _, err in summary.values()]
idxs = list(range(len(names)))

TODO: we can't combine all of these as independent measurements, but we can combine M/H and A/M, and all of the /Fe ratios.

In [None]:
# One point per element: the summary value with its error bar, plus a
# reference line at 1.
fig, ax = plt.subplots(figsize=(12, 6))
ax.errorbar(idxs, vals, errs,
            marker='o', ls='none', ecolor='#aaaaaa')

tick_labels = [elem_to_label(x) for x in names]
ax.set_xticks(idxs)
ax.set_xticklabels(tick_labels)

ax.set(xlim=(min(idxs)-0.5, max(idxs)+0.5),
       ylim=(0.4, 2.),
       ylabel=r'$M_{\rm disk}$ relative to fiducial')
ax.axhline(1., zorder=-100, color='tab:green', alpha=0.5)

fig.set_facecolor('w')
fig.tight_layout()
fig.savefig(plot_path / 'mdisk-vs-elems.png', dpi=250)

# Verification / test plots

## Plot of the 3 actions of all stars in each potential, over-plotted with the values for these orbits:

In [None]:
# For each orbit and three of the potentials: plot pairs of actions for stars
# whose third action lies within 20% of the orbit's value, with the orbit's
# own actions marked in red.
action_unit = 1*u.kpc * 30*u.km/u.s
labels = ['$J_R$', r'$J_\phi$', '$J_z$']
lims = [(0, 7.5), (-90, -30), (0, 3)]
# (x-axis action, y-axis action, action used for the selection)
panels = [(0, 1, 2), (0, 2, 1), (1, 2, 0)]

for n in range(w0s['fiducial'].shape[0]):
    for name in ['0.4', 'fiducial', '1.6']:
        print(n, name)
        act = aafs[name]['actions'].to(action_unit)
        w0_act = w0s_actions[name][n].to(act.unit)

        fig, axes = plt.subplots(1, 3, figsize=(15, 5),
                                 constrained_layout=True)
        for ax, (i, j, held) in zip(axes, panels):
            ax.scatter(w0_act[i], w0_act[j], color='tab:red', zorder=100)

            frac_diff = (act[:, held] - w0_act[held]) / w0_act[held]
            mask = np.abs(frac_diff) < 0.2
            ax.plot(act[mask, i], act[mask, j],
                    marker='o', ls='none', ms=1.5, mew=0, alpha=0.4)

            ax.set_xlim(lims[i])
            ax.set_ylim(lims[j])
            ax.set_xlabel(labels[i])
            ax.set_ylabel(labels[j])

        fig.suptitle(f"potential: {name},    orbit: {n}", fontsize=20)

## Make sure action-selected stars fall near orbits

In [None]:
def get_action_box(tbl, orbit_actions):
    """Select the stars whose actions lie in a box around each orbit's actions.

    For every entry of `orbit_actions`, a star is selected when its first
    action is below twice the orbit's value and its second and third actions
    are each within 20% of the orbit's values.

    Parameters
    ----------
    tbl : mapping with an 'actions' entry of shape (n_stars, 3)
    orbit_actions : sequence of length-3 action vectors

    Returns
    -------
    list of boolean masks over the rows of ``tbl['actions']``, one per orbit
    """
    actions = tbl['actions']
    J_first, J_second, J_third = actions[:, 0], actions[:, 1], actions[:, 2]

    def _within(values, reference, tol=0.2):
        # Fractional deviation from the reference value is below `tol`.
        return np.abs(values / reference - 1).decompose() < tol

    masks = []
    for n in range(len(orbit_actions)):
        ref = orbit_actions[n]
        in_box = ((J_first < 2 * ref[0])
                  & _within(J_second, ref[1])
                  & _within(J_third, ref[2]))
        masks.append(in_box)

    return masks

In [None]:
# For each potential: plot the action-selected stars in the (v_z, z) plane
# with the integrated orbits drawn faintly behind them.
zlim = 1.75  # kpc
# The x axis below is plotted and labelled in km/s, so this limit is in km/s
# (the original comment said pc/Myr).
vlim = 75.  # km/s

for name in sorted_keys:
    fig, ax = plt.subplots(1, 1, figsize=(5, 5))

    masks = get_action_box(aafs[name], w0s_actions[name])
    print(masks[0].sum())
    # NOTE(review): the masks index rows of aafs[name] (a join on APOGEE_ID),
    # while `vz` and `z` come from `galcen`; applying one to the other assumes
    # both have the same length and row order -- confirm.
    for n, mask in enumerate(masks):
        ax.plot(vz[mask], z[mask],
                marker='o', mew=0, ls='none', ms=3, alpha=0.5)
        orbit = orbits[name][:, n]
        ax.plot(orbit.v_z.to_value(u.km/u.s),
                orbit.z.to_value(u.kpc), marker='',
                color='#aaaaaa', alpha=0.2, zorder=-100)

    ax.set_xlabel(f'$v_z$ [{u.km/u.s:latex_inline}]')
    ax.set_ylabel(f'$z$ [{u.kpc:latex_inline}]')

    ax.set_xlim(-vlim, vlim)
    ax.set_ylim(-zlim, zlim)

    fig.tight_layout()

## Abundance vs angles in action boxes:

In [None]:
# For each potential (rows) and the first two orbits (columns): [Mg/Fe]
# against vertical angle for the stars in that orbit's action box, with the
# binned mean drawn over the points.
angz_bins = np.arange(0, 360+1e-3, 30)  # degrees

# squeeze=False keeps `axes` two-dimensional even when there is one potential,
# so the axes[i, n] indexing below always works.
fig, axes = plt.subplots(len(aafs), 2, figsize=(12, 4*len(aafs)),
                         sharex=True, sharey=True, squeeze=False)

for i, name in enumerate(aafs):
    # Use a local name: binding this table to `t` would replace the parent
    # sample table that earlier cells read from `t`.
    tbl = aafs[name]
    masks = get_action_box(tbl, w0s_actions[name])

    angz = coord.Angle(tbl['angles'][:, 2]).wrap_at(360*u.deg).degree

    for n in range(2):
        sel = masks[n]
        axes[i, n].plot(angz[sel],
                        tbl['MG_FE'][sel],
                        marker='o', mew=0, ls='none',
                        ms=4, alpha=0.4)

        stat = binned_statistic(angz[sel], tbl['MG_FE'][sel],
                                bins=angz_bins, statistic='mean')
        ctr = 0.5 * (stat.bin_edges[:-1] + stat.bin_edges[1:])
        axes[i, n].plot(ctr, stat.statistic, drawstyle='steps-mid', marker='')

        axes[i, n].set_title(name)

    axes[i, 0].set_ylabel("[Mg/Fe]")

axes[0, 0].set_ylim(-0.1, 0.15)
axes[-1, 0].set_xlabel(r'$\theta_z$')
axes[-1, 1].set_xlabel(r'$\theta_z$')
fig.set_facecolor('w')

OLD!

In [None]:
# Older version of the anomaly-panel figure: fits the Fourier coefficients to
# the per-star anomalies directly (not to the binned values) and bootstraps
# the fit.
# NOTE(review): `N_trials` is not assigned by any cell above, so this cell
# raises NameError on a fresh kernel.
# NOTE(review): this rebinds `all_coeffs` to {potential_name: [coeffs, ...]},
# replacing the per-element dict built earlier; it also expects `stats`,
# `angzs` and `d_elems` to still be the per-potential dicts.
plot_x = np.linspace(0, 2*np.pi, 1024)
all_coeffs = {}
for i, name in enumerate(sorted_keys):
    fig, ax = plt.subplots(1, 1, figsize=(8, 5))
    
    stat = stats[name]
    angz = angzs[name]
    d_elem = d_elems[name]
    
    bin_x = 0.5 * (stat.bin_edges[:-1] + stat.bin_edges[1:])
    bin_y = stat.statistic
    
    all_coeffs[name] = []
    all_tmps = []
    # Re-seed for every potential so each one sees the same resampling draws.
    np.random.seed(42)
    for trial in range(N_trials):
        if trial > 0:
            # Bootstrap: resample stars with replacement.
            idx = np.random.choice(len(angz), size=len(angz))
            x = angz[idx]
            y = d_elem[idx]
        else:
            # First trial uses the data as-is.
            x = angz
            y = d_elem
    
        coeffs, *_ = np.linalg.lstsq(get_M(x), y, rcond=None)
        # Keep only columns 0 and 3 of the get_M fit (constant and cos(2x))
        # for the plotted curves.
        tmp = np.zeros_like(coeffs)
        tmp[[0, 3]] = coeffs[[0, 3]]
        all_tmps.append(tmp)
        all_coeffs[name].append(coeffs)
    
    # `x` and `y` here are whatever the last trial left behind, i.e. the last
    # bootstrap resample when N_trials > 1.
    ax.plot(x, y, marker='o', mew=0, ls='none', alpha=0.4, ms=2.)
    ax.plot(bin_x, bin_y, marker='', drawstyle='steps-mid', color='tab:red')
    
    for coeffs in all_tmps:
        plot_y = get_M(plot_x) @ coeffs
        ax.plot(plot_x, plot_y, alpha=0.4, color='tab:blue', marker='')
    ax.set_title(name)

    ax.set_xlim(0, 2*np.pi)
    ax.set_ylim(-0.01, 0.01)

    ax.set_xlabel(r'vertical conjugate angle, $\theta_z$ [rad]')
    ax.set_ylabel('action-local\n[Mg/Fe] anomaly')

    fig.tight_layout()
    fig.set_facecolor('w')
#     fig.savefig(f'../plots/anomaly-panels-{name}.png', dpi=256)