In [None]:
import pathlib

import astropy.coordinates as coord
import astropy.table as at
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from scipy.spatial import cKDTree
from scipy.stats import binned_statistic
from scipy.interpolate import interp1d
from tqdm import tqdm

# gala
import gala.coordinates as gc
import gala.dynamics as gd
import gala.integrate as gi
import gala.potential as gp
from gala.units import galactic

from totoro.config import galcen_frame, elem_names, plot_config as pc
from totoro.abundances import elem_to_label

In [None]:
# Map every element name to the location of its cached optimization results.
cache_paths = {
    elem: pathlib.Path(f"../../cache/optimize-results-{elem}.csv")
    for elem in elem_names
}

# Read only the result tables that actually exist on disk. Elements without a
# cached CSV are silently left out, so `tbls` may hold fewer entries than
# `elem_names`.
tbls = {
    elem: at.Table.read(path)
    for elem, path in cache_paths.items()
    if path.exists()
}

In [None]:
# Reference values drawn as dashed guide lines in every panel.
fiducials = dict(mdisk_f=1., zsun=20.8, vzsun=7.78)

# (x column, y column) shown in each of the three panels.
colcols = [
    ('mdisk_f', 'zsun'),
    ('mdisk_f', 'vzsun'),
    ('zsun', 'vzsun'),
]

# Axis label for each table column, shared by the x and y axes.
axis_labels = {
    'mdisk_f': r'${\rm M}_{\rm disk} / {\rm M}_{\rm disk}^\star$',
    'zsun': r'$z_\odot$ [pc]',
    'vzsun': r'$v_{z,\odot}$ ' + f'[{u.km/u.s:latex_inline}]',
}

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

for ax, (col1, col2) in zip(axes, colcols):
    # One marker set per element; rows of each table are individual points.
    for elem, tbl in tbls.items():
        ax.plot(tbl[col1], tbl[col2],
                ls='none', marker='o', mew=0, ms=4,
                label=elem_to_label(elem))

    ax.set_xlabel(axis_labels[col1])
    ax.set_ylabel(axis_labels[col2])

    # Dashed cross-hair at the fiducial values, drawn behind the data.
    ax.axvline(fiducials[col1], zorder=-10, color='#aaaaaa', linestyle='--')
    ax.axhline(fiducials[col2], zorder=-10, color='#aaaaaa', linestyle='--')

# The element legend is only shown on the first panel.
axes[0].legend()

fig.set_facecolor('w')
fig.tight_layout()

In [None]:
# Mean and standard deviation of the disk-mass factor over the 'MG_FE' results.
mg_mdisk_f = tbls['MG_FE']['mdisk_f']
np.mean(mg_mdisk_f), np.std(mg_mdisk_f)

In [None]:
# Mean and standard deviation of the `zsun` column over the 'MG_FE' results.
mg_zsun = tbls['MG_FE']['zsun']
np.mean(mg_zsun), np.std(mg_zsun)

In [None]:
# Mean and standard deviation of the `vzsun` column over the 'MG_FE' results.
mg_vzsun = tbls['MG_FE']['vzsun']
np.mean(mg_vzsun), np.std(mg_vzsun)

---

Overplot the fitted `zsun` and `vzsun` for each element on its binned median-abundance map in the $(v_z, z)$ plane:

In [None]:
from scipy.stats import binned_statistic_2d
from totoro.data import load_apogee_sample

In [None]:
# Parent-sample file, relative to this notebook's directory.
apogee_sample_path = '../../data/apogee-parent-sample.fits'

# `t` is indexed by element name and `c` is transformed between coordinate
# frames further down — presumably a table and matching sky coordinates;
# confirm against `load_apogee_sample`.
t, c = load_apogee_sample(apogee_sample_path)

In [None]:
# Start from the project's solar velocity but zero out its vertical component,
# and place the Sun at z = 0, so the transformed z and v_z carry no assumed
# solar offset.
vsun = galcen_frame.galcen_v_sun.d_xyz.copy()
vsun[2] = 0

# NOTE(review): only `galcen_v_sun` and `z_sun` are overridden here; every other
# frame parameter falls back to the astropy default rather than `galcen_frame`.
tmp_galcen_frame = coord.Galactocentric(z_sun=0*u.pc,
                                        galcen_v_sun=vsun)

# Vertical position and velocity as plain floats in the plotting units.
galcen = c.transform_to(tmp_galcen_frame)
z, vz = (galcen.z.to_value(pc['zunit']),
         galcen.v_z.to_value(pc['vunit']))

In [None]:
# Bin widths along v_z and z (in the same units as `vz` and `z`).
vstep = 2.
zstep = 50 / 1e3  # presumably 50 pc expressed in kpc — confirm against pc['zunit']

# Symmetric bin edges about zero; the small 1e-3 pad keeps the upper limit
# inside the half-open range produced by `np.arange`.
vz_edges = np.arange(-pc['vlim'], pc['vlim']+1e-3, vstep)
z_edges = np.arange(-pc['zlim'], pc['zlim']+1e-3, zstep)

# (v_z edges, z edges) — the order expected by the 2D binning calls below.
vzz_bins = (vz_edges, z_edges)

In [None]:
# Median-abundance maps in the (v_z, z) plane, one panel per element, with the
# fitted solar offset from the cached optimization results overplotted.
fig, axes = plt.subplots(2, 4, figsize=(14, 8.7),
                         sharex=True, sharey=True,
                         constrained_layout=True)

for ax, elem_name in zip(axes.flat, elem_names):
    e = t[elem_name]
    # Keep only abundance values inside (-3, 3); this also drops NaNs.
    mask = (e > -3) & (e < 3)

    stat = binned_statistic_2d(vz[mask], z[mask], e[mask],
                               statistic='median',
                               bins=vzz_bins)
    # Colour scale spans the central 16th-84th percentile of the abundances.
    vmin, vmax = np.nanpercentile(e[mask], [16, 84])

    counts, *_ = np.histogram2d(vz[mask], z[mask],
                                bins=vzz_bins)

    # Blank out bins holding fewer than two stars.
    im = stat.statistic.copy()
    im[counts < 2] = np.nan

    cs = ax.pcolormesh(stat.x_edge, stat.y_edge, im.T,
                       cmap='magma', vmin=vmin, vmax=vmax,
                       rasterized=True)

    cb = fig.colorbar(cs, ax=ax, location='top', aspect=10)
    elem_label = elem_to_label(elem_name, dollar=False)
    cb.set_label(rf'$\langle {elem_label} \rangle$', labelpad=12)

    # `tbls` only holds elements whose cached results file exists, so an
    # element can have a map but no fitted values: draw the map, skip the
    # marker instead of raising KeyError.
    if elem_name not in tbls:
        continue

    # Mean and scatter of the fitted solar offsets. The sign is flipped
    # because the maps are built in a frame with z_sun = 0 and zero solar v_z.
    # zsun is divided by 1e3, presumably pc -> kpc to match pc['zunit'] —
    # confirm.
    zsun, zsun_err = (-np.mean(tbls[elem_name]['zsun']) / 1e3,
                      np.std(tbls[elem_name]['zsun']) / 1e3)
    vzsun, vzsun_err = (-np.mean(tbls[elem_name]['vzsun']),
                        np.std(tbls[elem_name]['vzsun']))
    ax.errorbar(vzsun, zsun,
                xerr=vzsun_err, yerr=zsun_err,
                marker='o', ls='none',
                ms=6, mew=0, zorder=100, color='tab:green', alpha=0.75)

# y labels on the left column only, x labels on the bottom row only.
for left_ax in axes[:, 0]:
    left_ax.set_ylabel(f'$z$ [{pc["zunit"]:latex_inline}]')

for bottom_ax in axes[-1, :]:
    bottom_ax.set_xlabel(f'$v_z$ [{pc["vunit"]:latex_inline}]')

# All panels share x and y, so setting ticks and limits on one axes applies to
# every panel. Use an explicit axes rather than the leaked loop variable.
ref_ax = axes[0, 0]
ref_ax.set_xticks(np.arange(-pc['vlim'], pc['vlim']+1e-3, 40))
ref_ax.set_yticks(np.arange(-pc['zlim'], pc['zlim']+1e-3, 0.75))

ref_ax.set_xlim(-pc['vlim'], pc['vlim'])
ref_ax.set_ylim(-pc['zlim'], pc['zlim'])

fig.set_facecolor('w')