In [None]:
import pathlib
import itertools
import pickle

import astropy.coordinates as coord
from astropy.convolution import convolve, Gaussian2DKernel
from astropy.io import fits
import astropy.table as at
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from scipy.stats import binned_statistic_2d
import corner

# gala
import gala.coordinates as gc
import gala.dynamics as gd
import gala.integrate as gi
import gala.potential as gp
from gala.units import galactic
from gala.mpl_style import hesperia, laguna

from thriftshop.config import vcirc, rsun, plot_path, fig_path, cache_path, fiducial_mdisk
from thriftshop.config import plot_config as pc
from thriftshop.config import elem_names
from thriftshop.data import load_apogee_sample
from thriftshop.potentials import potentials
from thriftshop.actions import safe_get_actions, get_w0s_with_same_actions
from thriftshop.abundances import get_elem_names, elem_to_label

# Use astropy's 'v4.0' Galactocentric frame parameters for every Galactocentric
# frame created below; the trailing ';' suppresses the cell's output display.
coord.galactocentric_frame_defaults.set('v4.0');

In [None]:
# Instantiate the Galactocentric frame with the current defaults and display
# its parameters for reference.
galcen = coord.Galactocentric()
galcen

In [None]:
# Display the references astropy records for this frame's parameter values.
galcen.frame_attribute_references

In [None]:
# Load the parent APOGEE sample: `t` is the per-star table (column access, e.g.
# t['APOGEE_ID']) and `c` the stars' astropy coordinates. They are masked
# together below, so rows are assumed aligned -- TODO confirm in thriftshop.data.
t, c = load_apogee_sample('../../data/apogee-parent-sample.fits')

In [None]:
# Cross-match the parent sample against the APOGEE DR16 allStarLite catalog.
# The catalog lives outside the repo; build its path from the user's home
# directory rather than hard-coding one user's absolute /Users/<name>/ prefix.
dr16_path = pathlib.Path.home() / 'data' / 'APOGEE_DR16' / 'allStarLite-r12-l33.fits'
dr16 = at.Table.read(dr16_path)
in_dr16 = np.isin(t['APOGEE_ID'], dr16['APOGEE_ID'])
# (N parent stars, N of them found in DR16, N_parent / N_found). The last value
# is the *inverse* of the fraction of the parent sample present in DR16.
len(in_dr16), in_dr16.sum(), len(in_dr16) / in_dr16.sum()

In [None]:
# Transform the sample into the Galactocentric frame and keep only the vertical
# position and velocity, as plain arrays in the plot-config units.
galcen = c.transform_to(coord.Galactocentric)
z, vz = (galcen.z.to_value(pc['zunit']),
         galcen.v_z.to_value(pc['vunit']))

In [None]:
# NOTE: disabled exploratory cell -- mean MG_FE and star counts in the (v_z, z)
# plane. Similar maps are produced by the active cells below; kept for reference.
# zlim = 2 # kpc
# vlim = 100. # km/s
# vstep = 4
# zstep = 75 / 1e3
# vzz_bins = (np.arange(-vlim, vlim+1e-3, vstep),
#             np.arange(-zlim, zlim+1e-3, zstep))

# fig, axes = plt.subplots(1, 2, figsize=(12, 5),
#                          constrained_layout=True)

# elem = t['MG_FE']
# stat = binned_statistic_2d(vz, z, elem, statistic='mean',
#                            bins=vzz_bins)
# vmin, vmax = np.percentile(elem, [15, 85])

# ax = axes[0]
# cs = ax.pcolormesh(stat.x_edge, stat.y_edge, stat.statistic.T, 
#                    cmap='cividis', vmin=vmin, vmax=vmax)
# cb = fig.colorbar(cs, ax=ax, aspect=40)

# ax.set_xlabel('v_z')
# ax.set_ylabel('z')

# ax = axes[1]
# H, *_ = np.histogram2d(vz, z, bins=vzz_bins)
# cs = ax.pcolormesh(stat.x_edge, stat.y_edge, H.T, 
#                    cmap='cividis', 
#                    norm=mpl.colors.LogNorm(1, 3e2))
# cb = fig.colorbar(cs, ax=ax, aspect=40)

# ax.set_xlabel('v_z')
# # ax.set_ylabel('z')

# fig.set_facecolor('w')
# fig.set_facecolor('w')

In [None]:
# Element-abundance column names detected in the table (via get_elem_names)
# vs. the configured subset `elem_names`; the cell displays both counts.
all_elem_names = get_elem_names(t)
len(all_elem_names), len(elem_names)

In [None]:
# Rank every abundance column by its median reported uncertainty, using only
# stars whose value lies in (-3, 3) (presumably this excludes flagged/sentinel
# entries -- confirm), and print them from smallest to largest error.
median_errs = [
    np.percentile(t[f'{col}_ERR'][(t[col] > -3) & (t[col] < 3)], 50)
    for col in all_elem_names
]
sorted_names = np.array(all_elem_names)[np.argsort(median_errs)]
sorted_tmp = np.sort(median_errs)
for col, err in zip(sorted_names, sorted_tmp):
    print(col, f'{err:.3f}')

In [None]:
# Median reported uncertainty of each configured element (stars with values in
# (-3, 3) only), rounded to two decimals.
median_elem_errs = {
    col: round(np.percentile(t[f'{col}_ERR'][(t[col] > -3) & (t[col] < 3)], 50), 2)
    for col in elem_names
}

In [None]:
# Bin edges for the (v_z, z) maps: symmetric about zero out to the configured
# limits. The +1e-3 keeps the upper limit itself as the final edge, since
# np.arange excludes its stop value.
vstep = 2.
zstep = 50 / 1e3
vz_edges = np.arange(-pc['vlim'], pc['vlim'] + 1e-3, vstep)
z_edges = np.arange(-pc['zlim'], pc['zlim'] + 1e-3, zstep)
vzz_bins = (vz_edges, z_edges)

In [None]:
# Grid of median-abundance maps in the (v_z, z) plane, one panel per configured
# element, each with its own 16-84th percentile color stretch.
fig, axes = plt.subplots(2, 4, figsize=(14, 8.5), sharex=True, sharey=True,
                         constrained_layout=True)

z_axis_label = f'$z$ [{pc["zunit"]:latex_inline}]'
vz_axis_label = f'$v_z$ [{pc["vunit"]:latex_inline}]'

for ax, elem_name in zip(axes.flat, elem_names):
    abun = t[elem_name]
    good = (abun > -3) & (abun < 3)
    vz_good, z_good, abun_good = vz[good], z[good], abun[good]

    n_per_bin = np.histogram2d(vz_good, z_good, bins=vzz_bins)[0]
    med = binned_statistic_2d(vz_good, z_good, abun_good,
                              statistic='median', bins=vzz_bins)
    # Blank out bins holding fewer than 2 stars.
    med_map = np.where(n_per_bin < 2, np.nan, med.statistic)
    lo, hi = np.nanpercentile(abun_good, [16, 84])

    mesh = ax.pcolormesh(med.x_edge, med.y_edge, med_map.T, cmap='magma',
                         vmin=lo, vmax=hi, rasterized=True)
    cbar = fig.colorbar(mesh, ax=ax, location='top', aspect=10)
    label = elem_to_label(elem_name, dollar=False)
    cbar.set_label(rf'$\langle {label} \rangle$', labelpad=12)

for left_ax in axes[:, 0]:
    left_ax.set_ylabel(z_axis_label)
for bottom_ax in axes[-1, :]:
    bottom_ax.set_xlabel(vz_axis_label)

# The panels share x and y, so ticks/limits set on one apply to all.
ax.set_xticks(np.arange(-pc['vlim'], pc['vlim'] + 1e-3, 50))
ax.set_yticks(np.arange(-pc['zlim'], pc['zlim'] + 1e-3, 1))
ax.set_xlim(-pc['vlim'], pc['vlim'])
ax.set_ylim(-pc['zlim'], pc['zlim'])

fig.set_facecolor('w')
fig.savefig(fig_path / 'abundance-zvz-grid.pdf', dpi=250)

---

In [None]:
# NOTE: disabled cell -- median-abundance (v_z, z) maps for every element in
# all_elem_names (4x5 grid); would save plot_path / 'abundance-z-vz-grid-all.png'.
# fig, axes = plt.subplots(4, 5, figsize=(18, 13),
#                          sharex=True, sharey=True)

# for ax, elem_name in zip(axes.flat, all_elem_names):
#     e = t[elem_name]
#     mask = (e > -3) & (e < 3)
    
#     stat = binned_statistic_2d(vz[mask], z[mask], e[mask], 
#                                statistic='median',
#                                bins=vzz_bins)
#     vmin, vmax = np.nanpercentile(e[mask], [15, 85])

#     cs = ax.pcolormesh(stat.x_edge, stat.y_edge, stat.statistic.T, 
#                        cmap='Greys', vmin=vmin, vmax=vmax)
    
#     ax.text(-90, 1.7, elem_to_label(elem_name), 
#             fontsize=18, ha='left', va='top', 
#             color='tab:blue', bbox=dict(facecolor=(1.,1.,1.,0.8)))
    
# for i in range(axes.shape[0]):
#     axes[i, 0].set_ylabel(f'$z$ [{u.kpc:latex_inline}]')
    
# for j in range(axes.shape[1]):
#     axes[-1, j].set_xlabel(f'$v_z$ [{u.km/u.s:latex_inline}]')
    
# fig.tight_layout()
# fig.set_facecolor('w')
# fig.savefig(plot_path / 'abundance-z-vz-grid-all.png', dpi=250)

### Initial conditions for our comparison orbits

In [None]:
# Two comparison orbits, both starting at x = -rsun, y = z = 0 with
# v_y = vcirc; they differ only in their initial (v_x, v_z).
_sun_x = -rsun.to_value(u.kpc)
_v_circ = vcirc.to_value(u.km/u.s)
fiducial_w0 = gd.PhaseSpacePosition(
    pos=np.array([[_sun_x, 0, 0],
                  [_sun_x, 0, 0]]).T * u.kpc,
    vel=np.array([[15, _v_circ, 20.],
                  [20, _v_circ, 45.]]).T * u.km/u.s,
)

In [None]:
# Initial conditions in each potential matching the actions of `fiducial_w0`
# (via get_w0s_with_same_actions); expensive, so cached to disk.
w0s_cache = cache_path / 'w0s.pkl'

if not w0s_cache.exists():
    w0s = get_w0s_with_same_actions(fiducial_w0, staeckel=True)

    # Dump to a temporary file and rename it into place. If the dump fails or
    # is interrupted, no truncated cache is left behind -- later runs only
    # check `exists()`, so they would otherwise trust a broken file forever.
    w0s_cache.parent.mkdir(parents=True, exist_ok=True)
    w0s_partial = w0s_cache.with_name(w0s_cache.name + '.tmp')
    with open(w0s_partial, 'wb') as f:
        pickle.dump(w0s, f)
    w0s_partial.replace(w0s_cache)

# Always read back from disk so fresh and cached runs use the same objects.
# NOTE: pickle.load can execute arbitrary code -- only load a trusted local cache.
with open(w0s_cache, 'rb') as f:
    w0s = pickle.load(f)

In [None]:
# Integrate each initial condition in its matching potential for 6 Gyr
# with a 0.5 Myr time step.
orbits = {
    key: potentials[key].integrate_orbit(w0, dt=0.5*u.Myr, t1=0, t2=6*u.Gyr)
    for key, w0 in w0s.items()
}

Confirm that the initial velocities (at fixed action values) vary smoothly with disk mass:

In [None]:
# For each comparison orbit, plot its initial v_x and v_z against the disk-mass
# scale factor (the keys of `w0s`, relative to the fiducial disk mass).
# Velocities are converted explicitly to km/s so the axis units are known.
vel_labels = {0: '$v_x$', 2: '$v_z$'}  # component index in v_xyz -> label
mdisk_scales = np.array([float(k) for k in w0s])
order = np.argsort(mdisk_scales)  # loop-invariant: sort once

for j in [0, 1]:  # orbit index
    fig, axes = plt.subplots(2, 1, figsize=(6, 8), sharex=True)

    for ax, (comp, label) in zip(axes, vel_labels.items()):
        vel = np.array([w0.v_xyz.T[j, comp].to_value(u.km/u.s)
                        for w0 in w0s.values()])
        ax.plot(mdisk_scales[order], vel[order])
        ax.set_ylabel(f'{label} [{u.km/u.s:latex_inline}]')

    axes[-1].set_xlabel(r'${\rm M}_{\rm disk} / {\rm M}_{\rm disk}^\star$')
    # Without a title the two figures are indistinguishable.
    fig.suptitle(f'comparison orbit {j}')

### Compute actions for these orbits with the Sanders & Binney method, for comparison

In [None]:
# Sanders & Binney actions (via safe_get_actions, N_max=8) for every comparison
# initial condition in each potential; expensive, so cached to disk.
w0s_actions_cache = cache_path / 'w0s-actions.pkl'

if not w0s_actions_cache.exists():
    sanders_actions = {}
    for name in potentials.keys():
        sanders_actions[name] = u.Quantity([
            safe_get_actions(potentials[name], w0s[name][n], N_max=8)['actions']
            for n in range(w0s[name].shape[0])
        ])

    # Dump to a temporary file and rename it into place. If the dump fails or
    # is interrupted, no truncated cache is left behind -- later runs only
    # check `exists()`, so they would otherwise trust a broken file forever.
    w0s_actions_cache.parent.mkdir(parents=True, exist_ok=True)
    actions_partial = w0s_actions_cache.with_name(w0s_actions_cache.name + '.tmp')
    with open(actions_partial, 'wb') as f:
        pickle.dump(sanders_actions, f)
    actions_partial.replace(w0s_actions_cache)

# NOTE: pickle.load can execute arbitrary code -- only load a trusted local cache.
with open(w0s_actions_cache, 'rb') as f:
    sanders_actions = pickle.load(f)

In [None]:
# Actions of the first and second comparison orbit in every potential, one row
# per potential (in dict order).
orb1_actions = u.Quantity([acts[0] for acts in sanders_actions.values()])
orb2_actions = u.Quantity([acts[1] for acts in sanders_actions.values()])

In [None]:
# Fractional scatter (std / mean) of each action component of orbit 1 across the
# potentials; expected to be small if the actions were matched.
orb1_actions.std(axis=0) / orb1_actions.mean(axis=0)

In [None]:
# Fractional scatter (std / mean) of each action component of orbit 2 across the
# potentials; expected to be small if the actions were matched.
orb2_actions.std(axis=0) / orb2_actions.mean(axis=0)

### Visualize the orbits:

In [None]:
# Disk-mass scale factors (relative to the fiducial disk, key '1.0') shown in
# the orbit comparison figures, in increasing order.
sorted_keys = [f'{scale:.1f}' for scale in (0.4, 1.0, 1.6)]

In [None]:
def _title_by_mdisk(ax, key):
    """Title an orbit panel 'fiducial' for key '1.0', else by its disk-mass factor."""
    if key == '1.0':
        ax.set_title('fiducial')
    else:
        ax.set_title(f'${float(key):.1f}' + r' \, {\rm M}_{\rm disk}$')


# --- v_z vs. z: one panel per disk-mass factor ---
fig, axes = plt.subplots(1, 3, figsize=(15, 5), sharex=True, sharey=True)
for key, panel in zip(sorted_keys, axes):
    orbits[key].plot(['v_z', 'z'], axes=[panel],
                     auto_aspect=False, units=[u.km/u.s, u.kpc])
    _title_by_mdisk(panel, key)
for panel in axes[1:]:
    panel.set_ylabel('')
axes[0].set_xlim(-pc['vlim'], pc['vlim'])
axes[0].set_ylim(-pc['zlim'], pc['zlim'])
fig.tight_layout()

# --- cylindrical R vs. z: one panel per disk-mass factor ---
fig, axes = plt.subplots(1, 3, figsize=(15, 5), sharex=True, sharey=True)
for key, panel in zip(sorted_keys, axes):
    orbits[key].cylindrical.plot(['rho', 'z'], axes=[panel],
                                 auto_aspect=False, units=[u.kpc, u.kpc])
    _title_by_mdisk(panel, key)
for panel in axes[1:]:
    panel.set_ylabel('')
axes[0].set_xlim(pc['Rlim'])
axes[0].set_ylim(-pc['zlim'], pc['zlim'])
fig.tight_layout()

### Plot orbits over binned mean-abundance maps

In [None]:
def plot_four_panel_zvz(z, vz, elem, vzz_bins, elem_name,
                        symmetrize=False, statistic='mean',
                        min_counts=2, stretch_q=(16, 84),
                        figsize=(17, 4.75)):
    """Map a binned abundance statistic in (v_z, z), with orbits overplotted.

    Four panels show the same color-mapped map. The first shows the map alone;
    each of the other three overplots (in white) the orbit ``orbits[k]`` for
    ``k`` in ``sorted_keys``. Both are notebook globals, as is
    ``fiducial_mdisk``, which is used in the fiducial panel's title.

    Parameters
    ----------
    z, vz : array-like
        Vertical position and velocity per star, in the units
        ``plot_config['zunit']`` / ``plot_config['vunit']``. The orbits are
        converted to these same units before they are drawn.
    elem : array-like
        Abundance value per star, aligned with ``z`` and ``vz``.
    vzz_bins : tuple
        ``(vz_edges, z_edges)`` bin specification for the 2D binning.
    elem_name : str
        Abundance column name; converted to the colorbar label.
    symmetrize : bool, optional
        If True, mirror every star into all four (+/-v_z, +/-z) quadrants
        before binning.
    statistic : str or callable, optional
        Passed to ``scipy.stats.binned_statistic_2d``.
    min_counts : int, optional
        Bins containing fewer stars than this are drawn blank (NaN).
    stretch_q : (float, float), optional
        Percentiles of ``elem`` (NaNs ignored) used as the color limits.
    figsize : tuple, optional
        Figure size passed to ``plt.subplots``.

    Returns
    -------
    fig : matplotlib Figure
    axes : numpy array of the four Axes
    """
    from thriftshop.config import plot_config as pc

    # Units of the incoming data. The orbit overlay and axis labels follow
    # them, so the overlay cannot silently disagree with the binned map.
    zunit, vunit = pc['zunit'], pc['vunit']

    if symmetrize:
        # Reflect into all four quadrants with one concatenation per array.
        # Order is that of itertools.product([-1, 1], [-1, 1]):
        # (sign of v_z, sign of z).
        signs = list(itertools.product([-1, 1], [-1, 1]))
        vz_arr = np.asarray(vz, dtype=float)
        z_arr = np.asarray(z, dtype=float)
        vz = np.concatenate([s_vz * vz_arr for s_vz, _ in signs])
        z = np.concatenate([s_z * z_arr for _, s_z in signs])
        elem = np.concatenate([np.asarray(elem, dtype=float)] * len(signs))

    stat = binned_statistic_2d(vz, z, elem,
                               statistic=statistic,
                               bins=vzz_bins)
    counts, *_ = np.histogram2d(vz, z, bins=vzz_bins)

    # Blank out sparsely populated bins.
    H = np.where(counts < min_counts, np.nan, stat.statistic)

    fig, axes = plt.subplots(1, 4, figsize=figsize,
                             constrained_layout=True,
                             sharex=True, sharey=True)

    # NaN-aware (as in the abundance-grid figure), so a single NaN abundance
    # cannot turn both color limits into NaN.
    vmin, vmax = np.nanpercentile(elem, stretch_q)

    for ax in axes:
        cs = ax.pcolormesh(stat.x_edge, stat.y_edge,
                           H.T, cmap='magma', vmin=vmin, vmax=vmax,
                           rasterized=True)
        ax.set_xlabel(f'$v_z$ [{vunit:latex_inline}]')

    cb = fig.colorbar(cs, ax=axes, aspect=40)
    elem_label = elem_to_label(elem_name, dollar=False)
    cb.set_label(rf'$\langle {elem_label} \rangle$')

    for k, color, ax in zip(sorted_keys,
                            ['w', 'w', 'w'],
                            axes[1:]):
        o = orbits[k]
        ax.plot(o.v_z.to_value(vunit), o.z.to_value(zunit),
                marker='', color=color, alpha=0.6)

        if k == '1.0':
            # Convert explicitly: the title states the mass in 1e10 Msun.
            mdisk_1e10 = fiducial_mdisk.to_value(u.Msun) / 1e10
            ax.set_title(r'${\rm M}_{\rm disk}^\star = ' +
                         f'{mdisk_1e10:.3f}' +
                         r'\times 10^{10}\,{\rm M}_{\odot}$',
                         pad=11, fontsize=22)
        else:
            ax.set_title(r'${\rm M}_{\rm disk} =' + f' {float(k):.1f}' +
                         r' \, {\rm M}_{\rm disk}^\star$',
                         pad=11, fontsize=22)

    # Axes are shared, so limits and ticks set on the first panel apply to all.
    ax = axes[0]
    ax.set_xlim(-pc['vlim'], pc['vlim'])
    ax.set_xticks(pc['vticks'])
    ax.set_xticks(pc['vminorticks'], minor=True)
    ax.set_ylim(-pc['zlim'], pc['zlim'])
    ax.set_yticks(pc['zticks'])
    ax.set_yticks(pc['zminorticks'], minor=True)

    axes[0].set_ylabel(f'$z$ [{zunit:latex_inline}]')

    fig.set_facecolor('w')

    return fig, axes

In [None]:
# Same (v_z, z) binning as the abundance-grid figure, restated here (presumably
# so this section can be re-run on its own). The +1e-3 keeps each upper limit
# as the final edge.
vstep, zstep = 2., 50 / 1e3
vzz_bins = tuple(
    np.arange(-lim, lim + 1e-3, step)
    for lim, step in ((pc['vlim'], vstep), (pc['zlim'], zstep))
)

In [None]:
# Showcase figure: mean MG_FE in bins of (v_z, z), with the comparison orbits
# overplotted.
elem_name = 'MG_FE'
elem_plot_path = plot_path / elem_name
elem_plot_path.mkdir(parents=True, exist_ok=True)

elem = t[elem_name]
elem_mask = (elem > -3) & (elem < 3)
elem = elem[elem_mask]
elem_z = z[elem_mask]
elem_vz = vz[elem_mask]

fig, axes = plot_four_panel_zvz(elem_z,
                                elem_vz,
                                elem,
                                vzz_bins=vzz_bins,
                                elem_name=elem_name,
                                statistic='mean', 
                                figsize=(18, 5))

# NOTE: the title text is specific to MG_FE; update it if elem_name changes.
fig.suptitle("Mean [Mg/Fe] Abundance of APOGEE Stars in Bins of Vertical Kinematics", 
             fontsize=28)

# The abundance maps are rasterized inside the PDF. Save at the same dpi as the
# other exported figures rather than matplotlib's default.
fig.savefig(fig_path / f'zvz-mean-{elem_name}.pdf', dpi=250)

---

In [None]:
# For every configured element, save the orbit/abundance figure twice: as
# observed and four-fold symmetrized, under plot_path / <element>/.
for elem_name in elem_names:
    elem_plot_path = plot_path / elem_name
    elem_plot_path.mkdir(parents=True, exist_ok=True)

    elem = t[elem_name]
    elem_mask = (elem > -3) & (elem < 3)
    elem = elem[elem_mask]
    elem_z = z[elem_mask]
    elem_vz = vz[elem_mask]

    for symmetrize, suffix in [(False, ''), (True, '-sym')]:
        fig, axes = plot_four_panel_zvz(elem_z,
                                        elem_vz,
                                        elem,
                                        vzz_bins=vzz_bins,
                                        elem_name=elem_name,
                                        symmetrize=symmetrize,
                                        figsize=(18, 4.5))
        fig.savefig(elem_plot_path / f'orbits-mean-abun-z-vz{suffix}.pdf',
                    dpi=250)
        # Two figures per element: close each one to keep memory bounded.
        plt.close(fig)