In [None]:
import pathlib
import itertools
import pickle

import astropy.coordinates as coord
from astropy.convolution import convolve, Gaussian2DKernel
from astropy.io import fits
import astropy.table as at
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from scipy.stats import binned_statistic
from scipy.spatial import cKDTree
from scipy.interpolate import interp1d
from tqdm.notebook import trange, tqdm

# gala
import gala.coordinates as gc
import gala.dynamics as gd
import gala.integrate as gi
import gala.potential as gp
from gala.units import galactic
from gala.mpl_style import hesperia, laguna

from thriftshop.config import vcirc, rsun, plot_path, fig_path, cache_path, fiducial_mdisk
from thriftshop.config import plot_config as pc
from thriftshop.data import load_apogee_sample
from thriftshop.potentials import potentials
from thriftshop.actions import safe_get_actions, get_w0s_with_same_actions
from thriftshop.abundances import get_elem_names, elem_to_label

# Select astropy's 'v4.0' Galactocentric frame defaults for all coordinate
# transforms in this notebook; the trailing ';' hides the returned object's repr.
coord.galactocentric_frame_defaults.set('v4.0');

In [None]:
# Abundance column used for the single-element figures below.
main_elem = 'MG_FE'

In [None]:
# APOGEE parent sample: `t` is the table joined against below; `c` is the second
# return value of load_apogee_sample (presumably matching coordinates — TODO confirm).
apogee_sample_file = '../../data/apogee-parent-sample.fits'
t, c = load_apogee_sample(apogee_sample_file)

In [None]:
# For every potential with a cached action-angle table, read it and attach the
# APOGEE sample columns by APOGEE_ID. Potentials without a cache file are skipped.
aafs = {}
for name in potentials:
    aaf_file = cache_path / f'aaf-{name}.fits'
    if aaf_file.exists():
        aafs[name] = at.join(at.QTable.read(aaf_file), t, keys='APOGEE_ID')

In [None]:
# NOTE: pickle.load can execute arbitrary code — only load cache files you created.
with open(cache_path / 'w0s.pkl', 'rb') as f:
    w0s = pickle.load(f)

# Integrate each potential's initial conditions for 6 Gyr with 0.5 Myr steps.
# NOTE(review): within this notebook only `orbits.keys()` is used (to build
# sorted_keys); if nothing else needs the orbits, this integration could be skipped.
orbits = {
    pot_name: potentials[pot_name].integrate_orbit(
        pot_w0, dt=0.5*u.Myr, t1=0, t2=6*u.Gyr
    )
    for pot_name, pot_w0 in w0s.items()
}

In [None]:
# Potentials that have both integrated orbits and an action-angle table,
# ordered by disk mass (sorted() is stable, so filtering first keeps the same order).
sorted_keys = sorted(
    (pot_name for pot_name in orbits if pot_name in aafs),
    key=lambda pot_name: potentials[pot_name]['disk'].parameters['m'],
)

In [None]:
# Keep only stars whose actions are finite in *every* potential, so that all
# potentials are compared on the same set of stars.
# The per-potential masks are combined row by row, which is only meaningful if
# every table lists the same stars in the same order — check that explicitly.
assert len(sorted_keys) > 0, "no potentials with both orbits and action-angle tables"
ref_name = sorted_keys[0]
ref_ids = aafs[ref_name]['APOGEE_ID']

finite_masks = []
for name in sorted_keys:
    aaf = aafs[name]
    assert np.array_equal(aaf['APOGEE_ID'], ref_ids), (
        f"action-angle table for {name!r} is not row-aligned with {ref_name!r}")
    finite_masks.append(np.all(np.isfinite(aaf['actions']), axis=1))

all_finite_mask = np.logical_and.reduce(finite_masks)
int(all_finite_mask.sum())

In [None]:
# Apply the shared finite-action mask to every table.
# Not idempotent: once rows have been dropped, re-running this cell no longer
# matches the mask length — re-run from the cell that builds `aafs` instead.
aafs = {pot_name: pot_aaf[all_finite_mask] for pot_name, pot_aaf in aafs.items()}

In [None]:
class AbundanceTorusMaschine:
    """
    Measure how an element's abundance anomaly varies with vertical angle.

    A star's "anomaly" is its abundance minus the inverse-variance weighted
    mean abundance of its ``tree_K`` nearest neighbors in action space. The
    anomaly can then be fit with a truncated Fourier series in theta_z
    (`get_coeffs_for_elem`) or averaged in theta_z bins
    (`get_binned_anomaly`).
    """

    def __init__(self, aaf, tree_K=64, sinusoid_K=2):
        """
        Parameters
        ----------
        aaf : `astropy.table.Table`
            Table of actions, angles, and abundances. Must provide
            ``'actions'`` and ``'angles'`` columns plus, for each element
            analyzed, ``elem_name`` and ``f"{elem_name}_ERR"`` columns.
        tree_K : int (optional)
            The number of neighbors used to estimate the action-local
            mean abundances.
        sinusoid_K : int (optional)
            The number of cos/sin terms in the sinusoid fit to the
            abundance anomaly variations with angle.
        """

        self.aaf = aaf

        # config
        self.tree_K = int(tree_K)
        self.sinusoid_K = int(sinusoid_K)

    def get_theta_z_anomaly(self, elem_name, action_unit=30*u.km/u.s*u.kpc):
        """
        Compute each star's action-local abundance anomaly.

        Parameters
        ----------
        elem_name : str
            Name of the abundance column; ``f"{elem_name}_ERR"`` must hold
            its uncertainties.
        action_unit : `astropy.units.Quantity` (optional)
            Unit the actions are converted to (and made dimensionless in)
            before the nearest-neighbor search.

        Returns
        -------
        angz : `numpy.ndarray`
            Third angle component (vertical angle) in radians, wrapped to
            [0, 2 pi).
        d_elem : array-like
            Abundance minus the inverse-variance weighted mean abundance of
            the star's ``tree_K`` nearest action-space neighbors.
        d_elem_errs : array-like
            The star's abundance error and the neighbor-mean error added in
            quadrature.
        """
        action_unit = u.Quantity(action_unit)

        # Actions without units:
        X = self.aaf['actions'].to_value(action_unit)
        angz = coord.Angle(self.aaf['angles'][:, 2]).wrap_at(360*u.deg).radian

        # element abundance
        elem = self.aaf[elem_name]
        elem_errs = self.aaf[f"{elem_name}_ERR"]
        ivar = 1 / elem_errs**2

        # Ask for tree_K + 1 neighbors because each star's closest match is
        # (normally) itself; that first column is dropped below.
        # NOTE(review): bootstrap resamples drawn with replacement contain
        # duplicate stars at zero distance. Only one copy is dropped, so the
        # other copies enter the star's own neighbor mean and pull its anomaly
        # toward zero — confirm whether this bias matters for the bootstrap.
        tree = cKDTree(X)
        _dists, idx = tree.query(X, k=self.tree_K+1)
        neighbors = idx[:, 1:]
        neighbor_ivar = ivar[neighbors]

        # compute action-local abundance anomaly (inverse-variance weighted
        # mean of the neighbors, and its uncertainty)
        errs = np.sqrt(1 / np.sum(neighbor_ivar, axis=1))
        means = np.sum(elem[neighbors] * neighbor_ivar, axis=1) * errs**2

        d_elem = elem - means
        d_elem_errs = np.sqrt(elem_errs**2 + errs**2)
        # Test variant: replace the propagated errors with a constant width.
#         d_elem_errs = np.full_like(d_elem, 0.04)  # MAGIC NUMBER

        return angz, d_elem, d_elem_errs

    def get_M(self, x):
        """
        Design matrix of the truncated Fourier series evaluated at angles ``x``.

        Columns are ``[1, cos(x), sin(x), cos(2x), sin(2x), ...]`` up to
        order ``sinusoid_K``, so there are ``1 + 2*sinusoid_K`` columns.
        """
        M = np.full((len(x), 1 + 2*self.sinusoid_K), np.nan)
        M[:, 0] = 1.

        for n in range(self.sinusoid_K):
            M[:, 1 + 2*n] = np.cos((n+1) * x)
            M[:, 2 + 2*n] = np.sin((n+1) * x)

        return M

    def get_coeffs(self, M, y, yerr):
        """
        Weighted linear least-squares fit of ``y ~ M @ coeffs``.

        Solves ``(M^T C^-1 M) coeffs = M^T C^-1 y`` with diagonal
        ``C = diag(yerr**2)``.

        Returns
        -------
        coeffs : `numpy.ndarray`
            Best-fit coefficients, one per column of ``M``.
        coeffs_cov : `numpy.ndarray`
            Their covariance matrix, ``(M^T C^-1 M)^-1``.
        """
        Cinv_diag = 1 / yerr**2
        MT_Cinv = M.T * Cinv_diag[None]
        MT_Cinv_M = MT_Cinv @ M
        coeffs = np.linalg.solve(MT_Cinv_M, MT_Cinv @ y)
        coeffs_cov = np.linalg.inv(MT_Cinv_M)
        return coeffs, coeffs_cov

    def get_coeffs_for_elem(self, elem_name, **anomaly_kwargs):
        """
        Fit the Fourier series to the theta_z dependence of the anomaly.

        Extra keyword arguments (e.g. ``action_unit``) are passed to
        `get_theta_z_anomaly`. Returns ``(coeffs, coeffs_cov)`` as in
        `get_coeffs`.
        """
        tz, d_elem, d_elem_errs = self.get_theta_z_anomaly(elem_name,
                                                           **anomaly_kwargs)
        M = self.get_M(tz)
        return self.get_coeffs(M, d_elem, d_elem_errs)

    def get_binned_anomaly(self, elem_name, theta_z_step=5*u.deg,
                           **anomaly_kwargs):
        """
        Inverse-variance weighted mean abundance anomaly in bins of theta_z.

        Parameters
        ----------
        elem_name : str
            Name of the abundance column.
        theta_z_step : `astropy.units.Quantity` [angle] (optional)
            Width of the vertical-angle bins. Bin edges run from 0 in steps
            of this size, up to at most 2 pi.
        **anomaly_kwargs
            Passed to `get_theta_z_anomaly` (e.g. ``action_unit``).

        Returns
        -------
        binx : `numpy.ndarray`
            Bin centers, in radians.
        means : `numpy.ndarray`
            Weighted mean anomaly per bin (NaN for empty bins).
        errs : `numpy.ndarray`
            Uncertainty of each weighted mean (inf for empty bins).
        """
        theta_z_step = coord.Angle(theta_z_step)
        angz_bins = np.arange(0, 2*np.pi+1e-4,
                              theta_z_step.to_value(u.rad))
        theta_z, d_elem, d_elem_errs = self.get_theta_z_anomaly(
            elem_name, **anomaly_kwargs)
        d_elem_ivar = 1 / d_elem_errs**2
        # Test variant: constant per-star weights instead of propagated errors.
#         d_elem_ivar = np.full_like(d_elem, 1 / 0.04**2)  # MAGIC NUMBER

        stat1 = binned_statistic(theta_z, d_elem * d_elem_ivar,
                                 bins=angz_bins,
                                 statistic='sum')
        stat2 = binned_statistic(theta_z, d_elem_ivar,
                                 bins=angz_bins,
                                 statistic='sum')

        binx = 0.5 * (angz_bins[:-1] + angz_bins[1:])
        # Empty bins have zero total weight: 0/0 -> NaN mean, 1/0 -> inf error.
        # These are expected, so silence the corresponding numpy warnings.
        with np.errstate(divide='ignore', invalid='ignore'):
            means = stat1.statistic / stat2.statistic
            errs = np.sqrt(1 / stat2.statistic)

        return binx, means, errs

Visualize the results of fitting the full sample in each potential model:

In [None]:
# One panel per potential (5 per row), showing for `main_elem` the per-star
# anomalies, the binned weighted means, and the full-sample sinusoid fit.
ncols = 5
nrows = max(1, int(np.ceil(len(sorted_keys) / ncols)))
fig, axes = plt.subplots(nrows, ncols, figsize=(15, 10 * nrows / 3),
                         sharex=True, sharey=True, squeeze=False)

plot_x = np.linspace(0, 2*np.pi, 1024)
for ax, name in zip(axes.flat, sorted_keys):
    atm = AbundanceTorusMaschine(aafs[name])

    # Compute the anomaly once and fit it directly; this is what
    # get_coeffs_for_elem does internally, minus a second KD-tree build.
    angz, d_elem, d_elem_errs = atm.get_theta_z_anomaly(main_elem)
    coeffs, coeffs_cov = atm.get_coeffs(atm.get_M(angz), d_elem, d_elem_errs)
    binx, bin_anom, bin_anom_err = atm.get_binned_anomaly(main_elem)

    ax.plot(angz, d_elem,
            marker='o', ls='none', mew=0, ms=2, alpha=0.3)

    ax.plot(binx, bin_anom,
            marker='', drawstyle='steps-mid',
            zorder=10, color='tab:red', alpha=1)
    ax.errorbar(binx, bin_anom, bin_anom_err,
                marker='o', ls='none', ecolor='tab:red',
                zorder=5, alpha=0.6)

    plot_y = atm.get_M(plot_x) @ coeffs
    ax.plot(plot_x, plot_y, marker='', lw=2, color='tab:blue', zorder=100)

    ax.set_title(name, fontsize=18, pad=11)

for ax in axes[-1]:
    ax.set_xlabel(r'$\theta_z$ [rad]')
for ax in axes[:, 0]:
    ax.set_ylabel(f'{elem_to_label(main_elem)} anomaly')

# Axes are shared, so setting limits on one panel applies to all of them.
ax.set_xlim(0, 2*np.pi)
ax.set_ylim(-0.025, 0.025);

In [None]:
def run_bootstrap_coeffs(aafs, elem_name, bootstrap_K=128, seed=42, overwrite=False,
                         names=None):
    """
    Bootstrap the sinusoid coefficients of the theta_z abundance anomaly.

    For each potential, stars are resampled with replacement ``bootstrap_K``
    times and `AbundanceTorusMaschine.get_coeffs_for_elem` is run on each
    resample. Results are pickled to ``cache_path`` and re-used on later calls.

    Parameters
    ----------
    aafs : dict
        Maps potential name -> table of actions, angles and abundances.
    elem_name : str
        Abundance column to analyze (e.g. ``'MG_FE'``).
    bootstrap_K : int (optional)
        Number of bootstrap resamples per potential.
    seed : int or None (optional)
        Seed for the resampling. The generator is re-seeded for every
        potential, so equal-length tables get identical bootstrap indices.
        ``None`` draws from numpy's global random state instead.
    overwrite : bool (optional)
        Recompute even if a cache file exists.
    names : iterable of str (optional)
        Potential names to process, in order. Defaults to the notebook-level
        ``sorted_keys`` (mass-ordered potentials).

    Returns
    -------
    all_bs_coeffs : dict
        Maps potential name -> array of shape ``(bootstrap_K, n_coeffs)``.

    Notes
    -----
    The cache file name depends only on ``bootstrap_K`` and ``elem_name``;
    after changing ``seed``, ``names`` or the input tables, pass
    ``overwrite=True``.
    """
    coeffs_cache = cache_path / f'coeffs-bootstrap{bootstrap_K}-{elem_name}.pkl'

    if coeffs_cache.exists() and not overwrite:
        # NOTE: pickle.load can execute arbitrary code; only load caches you created.
        with open(coeffs_cache, 'rb') as f:
            return pickle.load(f)

    if names is None:
        names = sorted_keys

    all_bs_coeffs = {}
    for name in tqdm(names):
        aaf = aafs[name]

        # A private RandomState seeded per potential yields exactly the same
        # draws as np.random.seed(seed) followed by np.random.choice, without
        # mutating numpy's global random state.
        rng = np.random if seed is None else np.random.RandomState(seed)

        bs_coeffs = []
        for k in range(bootstrap_K):
            bootstrap_idx = rng.choice(len(aaf), size=len(aaf))
            atm = AbundanceTorusMaschine(aaf[bootstrap_idx])
            coeffs, _cov = atm.get_coeffs_for_elem(elem_name)
            bs_coeffs.append(coeffs)

        all_bs_coeffs[name] = np.array(bs_coeffs)

    with open(coeffs_cache, 'wb') as f:
        pickle.dump(all_bs_coeffs, f)

    return all_bs_coeffs

In [None]:
# Bootstrap coefficients for the main element in every potential (cached on disk).
all_bs_coeffs = run_bootstrap_coeffs(aafs, elem_name=main_elem)

In [None]:
# Shared x-axis label for the theta_z figures below.
angle_label = r'vertical angle, $\theta_z$ [rad]'

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(15, 5), sharex=True, sharey=True)

plot_x = np.linspace(0, 2*np.pi, 1024)
for ax, name in zip(axes, ['0.4', '1.0', '1.6']):
    atm = AbundanceTorusMaschine(aafs[name])
    bin_x, bin_anom, bin_anom_err = atm.get_binned_anomaly(main_elem)

    # The design matrix depends only on plot_x, so build it once per panel.
    plot_M = atm.get_M(plot_x)
    name_bs_coeffs = all_bs_coeffs[name]

    ax.plot(bin_x, bin_anom,
            marker='', drawstyle='steps-mid', color='k', lw=2,
            label='mean abundance anomaly (data)')

    # One thin curve per bootstrap sample; only the first gets a legend entry.
    # (The loop variable is deliberately not `c`, which would clobber the
    # second value returned by load_apogee_sample.)
    for j, bs_c in enumerate(name_bs_coeffs):
        ax.plot(plot_x, plot_M @ bs_c,
                alpha=0.25, color='tab:blue', marker='', lw=1.,
                label='bootstrap samples (model)' if j == 0 else None)

    # Only the cos(2 theta_z) term (column 3 of get_M), using the mean over
    # bootstrap samples so the curve matches its "mean model" legend label.
    cos2_coeffs = np.zeros(name_bs_coeffs.shape[1])
    cos2_coeffs[3] = np.mean(name_bs_coeffs[:, 3])
    ax.plot(plot_x, plot_M @ cos2_coeffs,
            alpha=0.8, color='tab:red', marker='', lw=2, zorder=100,
            label=r'$\cos(2\,\theta_z)$ term of mean model')

    if name == '1.0':
        title = r'${\rm M}_{\rm disk}$'
    else:
        title = f'${float(name):.1f}' + r' \, {\rm M}_{\rm disk}$'
    ax.set_title(title, pad=11, fontsize=22)

    ax.axhline(0, zorder=-100, color='#aaaaaa')

    ax.set_xlabel(angle_label)

ax = axes[0]
ax.set_ylim(-0.01, 0.01)
ax.set_ylabel(f'action-local\n{elem_to_label(main_elem)} anomaly')

ax.set_xlim(0, 2*np.pi)
ax.set_xticks(np.arange(0, 2+1e-3, 1) * np.pi)
ax.set_xticklabels(['0'] + [rf'${x:.0f}\,\pi$' for x in np.arange(1, 2+1e-3, 1)])
ax.set_xticks(np.arange(0, 2+1e-3, 0.5) * np.pi, minor=True)

axes[2].legend(loc='lower right', fontsize=14)

fig.tight_layout()
fig.set_facecolor('w')

fig.savefig(fig_path / 'sinusoid-fits.pdf')

---

In [None]:
from collections import defaultdict

In [None]:
def get_cos2th_zerocross(coeffs):
    """
    Summarize bootstrap coefficients vs. disk mass and locate the disk-mass
    factor where the cos(2 theta_z) amplitude crosses zero.

    Parameters
    ----------
    coeffs : dict
        Maps potential name (a string parseable as a float, used as the
        disk-mass factor) -> array of bootstrap coefficients with shape
        ``(n_bootstrap, n_coeffs)``, as returned by ``run_bootstrap_coeffs``.
        Column 3 must be the cos(2 theta_z) coefficient.

    Returns
    -------
    summary : dict
        ``summary[i]`` holds lists ``'mdisk'``, ``'y'`` (bootstrap mean of
        coefficient ``i``) and ``'y_err'`` (bootstrap standard deviation),
        one entry per potential in ``coeffs`` iteration order.
    zero_cross : float
        Disk-mass factor where the mean cos(2 theta_z) amplitude is zero,
        from linear interpolation (or extrapolation) of mdisk as a function
        of the amplitude.
    zero_cross_err : float
        Half the distance between the zero crossings of amplitude + std and
        amplitude - std. Always non-negative.

    Raises
    ------
    ValueError
        If ``coeffs`` is empty or has fewer than 4 coefficient columns.
    """
    if not coeffs:
        raise ValueError("coeffs must contain bootstrap coefficients for at least one potential")
    n_coeffs = np.shape(next(iter(coeffs.values())))[1]
    if n_coeffs < 4:
        raise ValueError(f"need at least 4 coefficients to read the cos(2 theta_z) term, got {n_coeffs}")

    summary = defaultdict(lambda: defaultdict(list))
    for i in range(n_coeffs):
        for k, bs in coeffs.items():
            summary[i]['mdisk'].append(float(k))
            summary[i]['y'].append(np.mean(bs[:, i]))
            summary[i]['y_err'].append(np.std(bs[:, i]))

    # cos(2 theta_z) term: column 3 of the design matrix (1, cos, sin, cos2, sin2, ...)
    s = summary[3]
    mdisk = np.asarray(s['mdisk'])
    y = np.asarray(s['y'])
    y_err = np.asarray(s['y_err'])

    def zero_of(amplitude):
        # Invert amplitude(mdisk) by interpolating mdisk as a function of the
        # amplitude and evaluating at 0. Assumes the amplitude is monotonic in
        # mdisk over the grid (otherwise the inverse is ill-defined) — TODO confirm.
        return interp1d(amplitude, mdisk, fill_value="extrapolate")(0.)

    zero_cross = zero_of(y)
    zero_cross1 = zero_of(y + y_err)
    zero_cross2 = zero_of(y - y_err)

    # Which shifted curve crosses zero first depends on the sign of the slope
    # of the amplitude vs. mdisk, so use the absolute difference.
    zero_cross_err = 0.5 * np.abs(zero_cross2 - zero_cross1)
    return summary, zero_cross, zero_cross_err

In [None]:
# Per-coefficient summary vs. disk mass, and the cos(2 theta_z) zero crossing.
summary, zero_cross, zero_cross_err = get_cos2th_zerocross(coeffs=all_bs_coeffs)

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(12, 9),
                         sharex=True, sharey='row')

# (coefficient index, y-axis label, panel title) for the four non-constant
# Fourier terms, in row-major panel order.
panels = [
    (1, r'$\cos\,\theta_z$ amplitude', '(sensitive to solar motion in $v_z$)'),
    (2, r'$\sin\,\theta_z$ amplitude', '(sensitive to solar position in $z$)'),
    (3, r'$\cos\,2\theta_z$ amplitude', '(sensitive to local density)'),
    (4, r'$\sin\,2\theta_z$ amplitude', '(verboten)'),
]

for ax, (i, ylabel, panel_title) in zip(axes.flat, panels):
    coeff_summary = summary[i]
    ax.errorbar(coeff_summary['mdisk'], coeff_summary['y'], coeff_summary['y_err'],
                marker='o', ecolor='#666666', ls='')
    ax.plot(coeff_summary['mdisk'], coeff_summary['y'],
            marker='', ls='-', color='#aaaaaa', zorder=-10)
    ax.axhline(0, marker='', zorder=-100, color='tab:green', alpha=0.5)
    ax.set_ylabel(ylabel)
    ax.set_title(panel_title, pad=11, fontsize=20)

mdisk_label = r'${\rm factor} \times {\rm M}_{\rm disk}$'
for ax in axes[1]:
    ax.set_xlabel(mdisk_label)

# x range / ticks follow the grid of disk-mass factors (shared across panels).
mdisk_min = min(summary[0]['mdisk'])
mdisk_max = max(summary[0]['mdisk'])
ax = axes[1, 0]
ax.set_xlim((mdisk_min - 0.1, mdisk_max + 0.1))
ax.set_xticks(np.arange(mdisk_min, mdisk_max + 1e-3, 0.2))
ax.set_xticks(np.arange(mdisk_min, mdisk_max + 1e-3, 0.1), minor=True)

fig.tight_layout()
fig.subplots_adjust(hspace=0.3)

fig.savefig(fig_path / 'coeff-vs-mdisk.pdf')

Run the bootstrap and $\cos 2\theta_z$ zero-crossing analysis for every element:

In [None]:
# Abundance columns to run the bootstrap / zero-crossing analysis on.
elem_names = [
    'M_H', 'FE_H',
    'NI_FE', 'MN_FE',
    'ALPHA_M', 'MG_FE', 'SI_FE', 'C_FE',
]

In [None]:
def worker(elem_name):
    """
    Bootstrap one element in every potential and return
    ``(elem_name, [zero_cross, zero_cross_err])`` for its cos(2 theta_z) term.
    """
    bootstrap_coeffs = run_bootstrap_coeffs(aafs, elem_name)
    _summary, crossing, crossing_err = get_cos2th_zerocross(bootstrap_coeffs)
    return elem_name, [crossing, crossing_err]

In [None]:
# (elem_name, [zero_cross, zero_cross_err]) for each element, run serially.
results = [worker(name_of_elem) for name_of_elem in elem_names]

In [None]:
# Disabled parallel alternative to the serial loop in the previous cell.
# NOTE(review): `MultiPool` is not imported anywhere in this notebook
# (presumably `schwimmbad.MultiPool`) — import it before uncommenting.
# results = []
# with MultiPool() as pool:
#     for r in pool.map(worker, elem_names):
#         results.append(r)

---

### Note: the cells below compare bootstrap-estimated coefficient uncertainties with the propagated (formal) uncertainties

The distribution of the action-local anomaly `d_elem` has an intrinsic width of ~0.04 dex:

In [None]:
# Recompute the anomalies for the last potential in `sorted_keys` explicitly,
# rather than relying on `d_elem` leaking out of the plotting loop above
# (that loop iterates over sorted_keys, so it also ends on sorted_keys[-1]).
angz, d_elem, d_elem_errs = AbundanceTorusMaschine(
    aafs[sorted_keys[-1]]).get_theta_z_anomaly(main_elem)
np.mean(d_elem), np.median(d_elem), np.std(d_elem)

In [None]:
# Histogram of action-local anomalies (cf. the ~0.04 dex width noted above).
fig, ax = plt.subplots()
ax.hist(d_elem, bins=np.linspace(-0.2, 0.2, 64))
ax.set_xlabel(f'action-local {elem_to_label(main_elem)} anomaly [dex]')
ax.set_ylabel('number of stars');

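If we rerun the cells below with `d_elem_errs` replaced by a constant 0.04 dex (the commented-out `MAGIC NUMBER` lines in `AbundanceTorusMaschine`), the ratio of bootstrap to propagated uncertainty goes to 1. The bootstrap coefficients are cached, so this requires `run_bootstrap_coeffs(..., overwrite=True)`.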

In [None]:
# `bs_coeffs` was never defined at notebook level (NameError on a fresh kernel).
# Select one potential explicitly — the last in sorted_keys, matching the
# anomaly cells above — and refit the full sample for the same element.
check_name = sorted_keys[-1]
bs_coeffs = all_bs_coeffs[check_name]
coeffs, coeffs_cov = AbundanceTorusMaschine(
    aafs[check_name]).get_coeffs_for_elem(main_elem)
np.mean(bs_coeffs[:, 3]), coeffs[3]

In [None]:
# Bootstrap scatter vs. propagated (formal) uncertainty for each coefficient.
# Self-contained: the original relied on `bs_coeffs` (never defined at notebook
# level) and `coeffs_cov` (the plotting loop assigned `coeff_cov` instead).
# Columns: bootstrap std, sqrt(diag(cov)), ratio of the two.
check_name = sorted_keys[-1]
bs_coeffs = all_bs_coeffs[check_name]
coeffs_check, coeffs_cov = AbundanceTorusMaschine(
    aafs[check_name]).get_coeffs_for_elem(main_elem)

bs_std = np.std(bs_coeffs, axis=0)
formal_err = np.sqrt(np.diag(coeffs_cov))
np.column_stack([bs_std, formal_err, bs_std / formal_err])