In [None]:
import pathlib
import itertools
import pickle

import astropy.coordinates as coord
from astropy.convolution import convolve, Gaussian2DKernel
from astropy.io import fits
import astropy.table as at
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from scipy.stats import binned_statistic
from scipy.spatial import cKDTree
from scipy.interpolate import interp1d
from tqdm.notebook import trange, tqdm

# gala
import gala.coordinates as gc
import gala.dynamics as gd
import gala.integrate as gi
import gala.potential as gp
from gala.units import galactic
from gala.mpl_style import hesperia, laguna

from thriftshop.config import vcirc, rsun, plot_path, fig_path, cache_path, fiducial_mdisk
from thriftshop.config import plot_config as pc
from thriftshop.config import elem_names, galcen_frame
from thriftshop.data import load_apogee_sample
from thriftshop.potentials import potentials
from thriftshop.actions import safe_get_actions, get_w0s_with_same_actions
from thriftshop.abundances import get_elem_names, elem_to_label
from thriftshop.atm import (AbundanceTorusMaschine, 
                            run_bootstrap_coeffs, 
                            get_cos2th_zerocross)

In [None]:
# Name of the abundance that the single-element fits and figures in this
# notebook are run on; it is passed to the AbundanceTorusMaschine methods
# and to run_bootstrap_coeffs in the cells that follow.
main_elem = 'MG_FE'

In [None]:
# Load the APOGEE parent sample. `t` is the table that the cached
# per-potential tables are joined against (on APOGEE_ID) further down.
# `c` is presumably the matching coordinates object -- TODO confirm against
# thriftshop.data.load_apogee_sample.
# NOTE(review): the path is relative, so this only works when the notebook
# is run from its own directory.
t, c = load_apogee_sample('../../data/apogee-parent-sample.fits')

In [None]:
# For every potential model that has a cached table on disk, read that table
# and join it onto the parent sample via the APOGEE_ID column. Models without
# a cache file are simply left out of `aafs`.
aafs = {}
for name in potentials:
    filename = cache_path / f'aaf-{name}.fits'
    if filename.exists():
        cached_tbl = at.QTable.read(filename)
        aafs[name] = at.join(cached_tbl, t, keys='APOGEE_ID')

In [None]:
# Read the cached initial conditions (one entry per potential model).
# NOTE(review): pickle.load can execute arbitrary code; only point this at
# cache files produced by this project.
with open(cache_path / 'w0s.pkl', 'rb') as f:
    w0s = pickle.load(f)

# Integrate each set of initial conditions in the potential stored under the
# same key, from t=0 to 6 Gyr with a 0.5 Myr step.
orbits = {
    model_name: potentials[model_name].integrate_orbit(
        model_w0, dt=0.5*u.Myr, t1=0, t2=6*u.Gyr
    )
    for model_name, model_w0 in w0s.items()
}

In [None]:
def _disk_mass(model_name):
    """Return the 'm' parameter of the 'disk' component of a potential model."""
    return potentials[model_name]['disk'].parameters['m']


# Order all integrated models by disk mass, then keep only those that also
# have a cached table in `aafs`.
keys_by_disk_mass = sorted(orbits, key=_disk_mass)
sorted_keys = [name for name in keys_by_disk_mass if name in aafs]

In [None]:
# Build a row mask that is True only where the 'actions' column is finite in
# every potential model's table (all components of a row must be finite).
all_finite_mask = None
for name in sorted_keys:
    model_mask = np.isfinite(aafs[name]['actions']).all(axis=1)
    all_finite_mask = (model_mask if all_finite_mask is None
                       else all_finite_mask & model_mask)

# Number of rows that survive the cut
print(all_finite_mask.sum())

In [None]:
# For each element, print how many rows of the '1.0' model table pass the
# finite-actions mask and have a value above -3 (presumably a cut against
# missing-value sentinels -- TODO confirm).
table_1p0 = aafs['1.0']
for elem_name in elem_names:
    n_above = (table_1p0[elem_name][all_finite_mask] > -3).sum()
    print(elem_name, n_above)

In [None]:
# Restrict every table to the rows with finite actions in all models.
# NOTE(review): this overwrites the tables in place, so the cell is not
# idempotent -- re-running it without reloading `aafs` applies a mask of the
# original length to already-shortened tables.
for name, full_tbl in list(aafs.items()):
    aafs[name] = full_tbl[all_finite_mask]

Visualize results from fitting the full sample for each potential model

In [None]:
# One panel per potential model (ordered by disk mass): individual-star
# abundance anomalies versus vertical angle, their binned means with error
# bars, and the fitted model curve.
# NOTE: `coeffs`, `coeff_cov` and `d_elem` from the last iteration are read
# again by cells near the end of the notebook, so those names are kept.
fig, axes = plt.subplots(3, 5, figsize=(15, 10),
                         sharex=True, sharey=True)
theta_grid = np.linspace(0, 2*np.pi, 1024)

for i, name in enumerate(sorted_keys):
    ax = axes.flat[i]
    atm = AbundanceTorusMaschine(aafs[name])

    coeffs, coeff_cov = atm.get_coeffs_for_elem(main_elem)
    angz, d_elem, d_elem_errs = atm.get_theta_z_anomaly(main_elem)
    bin_ctr, bin_mean, bin_mean_err = atm.get_binned_anomaly(main_elem)

    # individual stars
    ax.plot(angz, d_elem,
            ls='none', marker='o', ms=2, mew=0, alpha=0.3)

    # binned anomaly: step curve plus error bars
    ax.plot(bin_ctr, bin_mean,
            drawstyle='steps-mid', marker='',
            color='tab:red', alpha=1, zorder=10)
    ax.errorbar(bin_ctr, bin_mean, bin_mean_err,
                ls='none', marker='o', ecolor='tab:red',
                alpha=0.6, zorder=5)

    # fitted model evaluated on a dense grid of vertical angle
    model_curve = atm.get_M(theta_grid) @ coeffs
    ax.plot(theta_grid, model_curve,
            marker='', lw=2, color='tab:blue', zorder=100)

    ax.set_title(name, fontsize=18, pad=11)

# Axes are shared, so setting the limits on the last panel used sets them all.
ax.set_xlim(0, 2*np.pi)
ax.set_ylim(-0.025, 0.025)

In [None]:
# Bootstrap the coefficient fits for `main_elem` over the tables in `aafs`.
# The result is indexed by potential-model name below (all_bs_coeffs[name]),
# and each entry is iterated as a sequence of coefficient vectors.
# bootstrap_K is presumably the number of bootstrap resamples -- TODO confirm.
# NOTE(review): no random seed is set in this notebook, so the bootstrap
# results may differ from run to run.
all_bs_coeffs = run_bootstrap_coeffs(aafs, main_elem, bootstrap_K=128)

In [None]:
# Three-panel figure (disk-mass factors 0.4, 1.0, 1.6): binned abundance
# anomaly versus vertical angle, the model curve for every bootstrap sample,
# and the cos(2 theta_z) term of the mean (bootstrap-averaged) model.
fig, axes = plt.subplots(1, 3, figsize=(15, 5), sharex=True, sharey=True)

plot_x = np.linspace(0, 2*np.pi, 1024)
for i, name in enumerate(['0.4', '1.0', '1.6']):
    atm = AbundanceTorusMaschine(aafs[name])
    bin_x, bin_anom, bin_anom_err = atm.get_binned_anomaly(main_elem)

    # The design matrix only depends on the angle grid, so build it once per
    # model instead of once per bootstrap sample.
    plot_M = atm.get_M(plot_x)
    model_bs_coeffs = all_bs_coeffs[name]

    ax = axes[i]

    ax.plot(bin_x, bin_anom, 
            marker='', drawstyle='steps-mid', color='k', lw=2,
            label='mean abundance deviation (data)')

    # The loop variable is deliberately not called `c`: that name holds the
    # second return value of load_apogee_sample from the top of the notebook.
    for j, sample_coeffs in enumerate(model_bs_coeffs):
        # Only the first curve gets a legend entry
        label = 'bootstrap samples (model)' if j == 0 else None
        ax.plot(plot_x, plot_M @ sample_coeffs, 
                alpha=0.25, color='tab:blue', marker='', lw=1.,
                label=label)

    # The red curve is labelled as the *mean* model, so average the bootstrap
    # coefficients rather than reusing whichever sample the loop ended on.
    # Keep only indices 0 and 3 -- the terms that make up the curve labelled
    # cos(2 theta_z) below.
    mean_coeffs = np.mean(model_bs_coeffs, axis=0)
    cos2th_coeffs = np.zeros_like(mean_coeffs)
    cos2th_coeffs[[0, 3]] = mean_coeffs[[0, 3]]
    ax.plot(plot_x, plot_M @ cos2th_coeffs, 
            alpha=0.8, color='tab:red', marker='', lw=2, zorder=100, 
            label=r'$\cos(2\,\theta_z)$ term of mean model')

    title = (r'${\rm M}_{\rm disk} / {\rm M}_{\rm disk}^\star =' 
             + f' {float(name):.1f}$')
    ax.set_title(title, pad=11, fontsize=22)

    ax.axhline(0, zorder=-100, color='#aaaaaa')
    ax.set_xlabel(r'vertical angle, $\theta_z$ [rad]')

# Axes are shared, so limits and ticks set on the first panel apply to all.
ax = axes[0]
ax.set_ylim(-0.01, 0.01)
ax.set_ylabel(f'mean {elem_to_label(main_elem)}\ndeviation, ' + 
              r'$\Delta^{[{\rm Mg}/{\rm Fe}]}$')

ax.set_xlim(0, 2*np.pi)
ax.set_xticks(np.arange(0, 2+1e-3, 1) * np.pi)
ax.set_xticklabels(['0'] + [rf'${x:.0f}\,\pi$' for x in np.arange(1, 2+1e-3, 1)])
ax.set_xticks(np.arange(0, 2+1e-3, 0.5) * np.pi, minor=True)

axes[2].legend(loc='lower right', fontsize=14)

fig.tight_layout()
fig.set_facecolor('w')

# fig.savefig(fig_path / 'sinusoid-fits.pdf')

---

In [None]:
# Summarize the bootstrap coefficients across potential models.
# `summary` is indexed by coefficient number below (0-4), each entry providing
# 'mdisk', 'y' and 'y_err'. `zero_cross` / `zero_cross_err` are plotted below
# as a point with an x error bar at y=0 on the cos(2 theta_z) panel.
summary, zero_cross, zero_cross_err = get_cos2th_zerocross(all_bs_coeffs)

In [None]:
# Display the uncertainty on the zero-crossing returned by get_cos2th_zerocross.
zero_cross_err

In [None]:
# Fitted coefficient amplitudes versus disk-mass factor, one panel for each of
# entries 1-4 of `summary`. Entry 0 is only used to set the x-axis range.
fig, axes = plt.subplots(2, 2, figsize=(12, 9), 
                         sharex=True, sharey=True)

for j, i in enumerate(range(1, 5)):
    ax = axes.flat[j]

    ax.errorbar(summary[i]['mdisk'], 
                summary[i]['y'], 
                yerr=summary[i]['y_err'],
                marker='o', ecolor='#666666', ls='')

    # Sort by disk mass before drawing the connecting line, so it cannot
    # zigzag if the models are not stored in mass order (the per-element
    # version of this figure further down does the same).
    _idx = np.argsort(summary[i]['mdisk'])
    ax.plot(np.array(summary[i]['mdisk'])[_idx], 
            np.array(summary[i]['y'])[_idx], 
            marker='', ls='-', color='#aaaaaa', zorder=-10)

    ax.axhline(0, marker='', ls='--', zorder=-100, color='tab:green', alpha=0.5)

# Mark the zero-crossing (with its x uncertainty) on the cos(2 theta_z) panel.
# ls='none' rather than None: None falls back to the default line style.
axes[1, 0].errorbar(zero_cross, [0], 
                    xerr=np.array(zero_cross_err)[None].T, 
                    ls='none', marker='s', color='tab:blue',
                    alpha=0.75, mew=0, ms=6)

axes[0, 0].set_ylabel(r'$\cos\,\theta_z$ amplitude, $a_1$')
axes[0, 1].set_ylabel(r'$\sin\,\theta_z$ amplitude, $b_1$')
axes[1, 0].set_ylabel(r'$\cos\,2\theta_z$ amplitude, $a_2$')
axes[1, 1].set_ylabel(r'$\sin\,2\theta_z$ amplitude, $b_2$')
axes[1, 0].set_xlabel(r'${\rm M}_{\rm disk} / {\rm M}_{\rm disk}^\star$')
axes[1, 1].set_xlabel(r'${\rm M}_{\rm disk} / {\rm M}_{\rm disk}^\star$')

# x range and ticks: pad the sampled disk-mass range by 0.1 on either side,
# major ticks every 0.2 and minor ticks every 0.1 (axes are shared).
mdisk_min = min(summary[0]['mdisk'])
mdisk_max = max(summary[0]['mdisk'])
ax = axes[1, 0]
xlim = (mdisk_min - 0.1, mdisk_max + 0.1)
ax.set_xlim(xlim)
ax.set_xticks(np.arange(mdisk_min, mdisk_max + 1e-3, 0.2))
ax.set_xticks(np.arange(mdisk_min, mdisk_max + 1e-3, 0.1), 
              minor=True)

axes[0, 0].set_title('(sensitive to solar motion in $v_z$)', pad=11, fontsize=20)
axes[0, 1].set_title('(sensitive to solar position in $z$)', pad=11, fontsize=20)
axes[1, 0].set_title('(sensitive to local density)', pad=11, fontsize=20)
axes[1, 1].set_title('(verboten)', pad=11, fontsize=20)

fig.tight_layout()
fig.subplots_adjust(hspace=0.3)

# fig.savefig(fig_path / 'coeff-vs-mdisk.pdf')

## Run for all elements:

In [None]:
from thriftshop.atm import zerocross_worker
from schwimmbad import MultiPool, SerialPool

In [None]:
# One task per element: the worker receives the full set of tables together
# with the element name.
tasks = [(aafs, elem_name) for elem_name in elem_names]

# Run serially; swap in MultiPool to spread the elements over processes.
# with MultiPool() as pool:
with SerialPool() as pool:
    results = list(pool.map(zerocross_worker, tasks))

In [None]:
# Display the per-element worker output. The plotting cell below reads each
# entry r as r[0] (element name) and r[1][0], r[1][1] (value and its errors).
results

In [None]:
# Unpack the per-element results: name, inferred value and its error pair.
labels = [res[0] for res in results]
xs = list(range(len(results)))
ys = np.array([res[1][0] for res in results])
yerrs = np.array([res[1][1] for res in results])

fig, ax = plt.subplots(1, 1, figsize=(8.5, 6))

# All elements, then the element at index 4 redrawn in blue on top.
# NOTE(review): index 4 is presumably the position of `main_elem` in
# `elem_names` -- confirm before reordering the element list.
highlight = 4
ax.errorbar(xs, ys, yerr=yerrs.T,
            ls='none', marker='s', ms=6,
            ecolor='#666666', zorder=10)
ax.errorbar(xs[highlight], ys[highlight],
            yerr=yerrs.T[:, highlight:highlight + 1],
            ls='none', marker='s', ms=6,
            color='tab:blue', zorder=100)
# ax.axhline(1., zorder=-10, color='#aaaaaa', alpha=0.5)

# Inverse-variance weighted combination, using the mean of the two error
# values of each element as a symmetric uncertainty.
sym_errs = yerrs.mean(axis=1)
weights = 1 / sym_errs**2
comb_y = np.sum(ys * weights) / np.sum(weights)
comb_err = np.sqrt(1 / np.sum(weights))
ax.axhline(comb_y, color='tab:purple', zorder=-5, alpha=1)
ax.axhspan(comb_y - comb_err, comb_y + comb_err,
           color='tab:purple', lw=0, alpha=0.5, zorder=-5)

ax.set_xlim(-0.75, len(results)-1 + 0.75)
ax.set_ylim(0, 2)

ax.set_ylabel(r'inferred ${\rm M}_{\rm disk} / {\rm M}_{\rm disk}^\star$')

# No x ticks: each point is annotated with its element label just below the
# lower end of its error bar instead.
ax.set_xticks([])
ax.set_xticklabels([])
# ax.set_xticklabels([elem_to_label(x) for x in labels])
for x_pos, y_val, err_pair, elem in zip(xs, ys, yerrs, labels):
    ax.text(x_pos, y_val - err_pair[0] - 0.12, elem_to_label(elem),
            ha='center', va='top', fontsize=18, color='#555555')

fig.tight_layout()
fig.set_facecolor('w')

fig.savefig(fig_path / 'mdisk-vs-elem.pdf')

In [None]:
# Per-panel y label and title, in axes.flat order (row by row).
panel_text = [
    (r'$\cos\,\theta_z$ amplitude', '(sensitive to solar motion in $v_z$)'),
    (r'$\sin\,\theta_z$ amplitude', '(sensitive to solar position in $z$)'),
    (r'$\cos\,2\theta_z$ amplitude', '(sensitive to local density)'),
    (r'$\sin\,2\theta_z$ amplitude', '(verboten)'),
]

# For every element: bootstrap the fits, summarize the coefficients, and save
# a 2x2 figure of coefficient amplitude versus disk-mass factor.
for elem_name in elem_names:
    _all_bs_coeffs = run_bootstrap_coeffs(aafs, elem_name)
    elem_summary, _zc, _zc_err = get_cos2th_zerocross(_all_bs_coeffs)

    fig, axes = plt.subplots(2, 2, figsize=(12, 9),
                             sharex=True, sharey='row')

    for ax, coeff_idx in zip(axes.flat, range(1, 5)):
        entry = elem_summary[coeff_idx]

        ax.errorbar(entry['mdisk'], entry['y'], yerr=entry['y_err'],
                    marker='o', ecolor='#666666', ls='')

        # connect the points in order of increasing disk mass
        order = np.argsort(entry['mdisk'])
        ax.plot(np.array(entry['mdisk'])[order],
                np.array(entry['y'])[order],
                marker='', ls='-', color='#aaaaaa', zorder=-10)

        ax.axhline(0, marker='', zorder=-100, color='tab:green', alpha=0.5)

    for ax, (ylabel, title) in zip(axes.flat, panel_text):
        ax.set_ylabel(ylabel)
        ax.set_title(title, pad=11, fontsize=20)

    for ax in axes[1]:
        ax.set_xlabel(r'${\rm factor} \times {\rm M}_{\rm disk}$')

    # x range padded by 0.1 on each side; major ticks every 0.2, minor 0.1
    mdisk_lo = min(elem_summary[0]['mdisk'])
    mdisk_hi = max(elem_summary[0]['mdisk'])
    axes[1, 0].set_xlim((mdisk_lo - 0.1, mdisk_hi + 0.1))
    axes[1, 0].set_xticks(np.arange(mdisk_lo, mdisk_hi + 1e-3, 0.2))
    axes[1, 0].set_xticks(np.arange(mdisk_lo, mdisk_hi + 1e-3, 0.1),
                          minor=True)

    fig.tight_layout()
    fig.subplots_adjust(hspace=0.3)

    fig.savefig(plot_path / f'coeffs-vs-mdisk-{elem_name}.png', dpi=250)
    # close each figure so they do not accumulate in memory
    plt.close(fig)

---

### Note: the cells below show results from a test comparing bootstrap-estimated coefficient errors to the propagated uncertainties

Showing that the d_elem distribution has an intrinsic width of ~0.04 dex

In [None]:
# Mean, median and standard deviation of `d_elem`, which is left over from
# the last potential model handled in the per-model fit loop earlier on.
tuple(stat(d_elem) for stat in (np.mean, np.median, np.std))

In [None]:
# Histogram of the per-star anomalies on a fixed grid spanning +/- 0.2
hist_bins = np.linspace(-0.2, 0.2, 64)
plt.hist(d_elem, bins=hist_bins);

If we run the cells below but replace `d_elem_errs` with 0.04, the ratio of bootstrap to propagated coefficient uncertainties goes to 1

In [None]:
# Compare the bootstrap mean of coefficient 3 (the one plotted as the
# cos(2 theta_z) term above) with the value from the direct fit.
# `coeffs` is left over from the last iteration of the per-model fit loop,
# i.e. the model sorted_keys[-1], so take the bootstrap samples of that same
# model. `bs_coeffs` was not defined anywhere in the notebook before.
bs_coeffs = np.asarray(all_bs_coeffs[sorted_keys[-1]])
np.mean(bs_coeffs[:, 3]), coeffs[3]

In [None]:
# For each of the five coefficients, print the bootstrap scatter, the
# uncertainty propagated through the fit covariance, and their ratio.
# `coeff_cov` is left over from the last iteration of the per-model fit loop
# (model sorted_keys[-1]); the bootstrap samples are taken for the same model.
bs_coeffs = np.asarray(all_bs_coeffs[sorted_keys[-1]])
for i in range(5):
    bs_std = np.std(bs_coeffs[:, i])
    fit_std = np.sqrt(coeff_cov[i, i])
    print(bs_std, fit_std, bs_std / fit_std)