In [None]:
import pathlib

from astropy.constants import G
import astropy.table as at
from astropy.convolution import convolve, Gaussian2DKernel
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt

%matplotlib inline
import numpy as np
import gala.dynamics as gd
from gala.units import galactic
from scipy.ndimage import gaussian_filter
from scipy.stats import binned_statistic_2d
from scipy.interpolate import interp1d
import cmasher as cmr

from pyia import GaiaData

from config import galcen_frame, R0, vc0

In [None]:
# Output directory for the paper figures, resolved to an absolute path
# relative to the notebook's working directory.
figure_path = (pathlib.Path("..") / "tex" / "figures").resolve()

In [None]:
# Site-specific location of the APOGEE DR17 catalogs.
apogee_dr17_dir = "/mnt/home/apricewhelan/data/APOGEE/DR17"

# Gaia DR3 cross-match for the APOGEE stars (per the filename), and the
# allStarLite catalog itself.
apogee_gaia_xm = at.Table.read(f"{apogee_dr17_dir}/allStar-dr17-synspec-gaiadr3-gaiasourcelite.fits")
apogee_allstar = at.Table.read(f"{apogee_dr17_dir}/allStarLite-dr17-synspec_rev1.fits")

# Join on APOGEE_ID (allStarLite columns first), then keep only the first row
# for any APOGEE_ID that occurs more than once.
apogee_joined = at.unique(
    at.join(apogee_allstar, apogee_gaia_xm, keys="APOGEE_ID"),
    keys="APOGEE_ID",
)
apogee = GaiaData(apogee_joined)

In [None]:
# Selection cuts for the APOGEE sample: each entry is a (min, max) range handed
# to GaiaData.filter (None presumably leaves that side open — confirm in pyia).
# Alternative, looser gravity cut: LOGG=(0.5, 3).
apogee_cuts = {
    "LOGG": (1.5, 3),
    "TEFF": (3000, 5500),
    "FE_H": (-0.8, 0.5),
    "MG_FE": (-0.2, 0.22),
    # Positive parallaxes only — presumably so get_skycoord below can turn
    # parallax into a distance.
    "parallax": (0.0 * u.mas, None),
}
apogee = apogee.filter(**apogee_cuts)

In [None]:
# Selected APOGEE sample in the ([Mg/H], -[Mg/Fe]) plane, where
# [Mg/H] = [Mg/Fe] + [Fe/H] and -[Mg/Fe] = [Fe/Mg].
mg_h = apogee['MG_FE'] + apogee['FE_H']
fe_mg = -apogee['MG_FE']

fig, ax = plt.subplots()
ax.hist2d(mg_h, fe_mg, bins=128, norm=mpl.colors.LogNorm())

# Overlay the FE_H = -0.8 and MG_FE = 0.22 selection limits from the filter
# cell, mapped into these coordinates. With x = [Mg/H] and y = -[Mg/Fe],
# [Fe/H] = x + y, so [Fe/H] = -0.8 is the line y = -0.8 - x, and
# [Mg/Fe] = 0.22 is the horizontal line y = -0.22.
# The axline anchor (-0.8, 0.0) lies on that line and inside the plotted
# region, so it does not stretch the autoscaled limits.
ax.axline((-0.8, 0.0), slope=-1)
ax.axhline(-0.22)

ax.set_xlabel(r"$[{\rm Mg}/{\rm H}]$")
ax.set_ylabel(r"$-[{\rm Mg}/{\rm Fe}]$");

In [None]:
# Attach the APOGEE VHELIO_AVG values (treated as heliocentric RVs in km/s) to
# the sky coordinates, move to the Galactocentric frame from config, and wrap
# the result as a gala phase-space position.
rv_ap = apogee.VHELIO_AVG * u.km / u.s
c_ap = apogee.get_skycoord(radial_velocity=rv_ap)
galcen_ap = c_ap.transform_to(galcen_frame)
w_ap = gd.PhaseSpacePosition(galcen_ap.data)

In [None]:
# Reference solar angular momentum, using a fixed 229 km/s circular speed
# instead of vc0 from config (the vc0 version would be R0 * vc0).
v_circ_sun = 229 * u.km / u.s
Lz_sun = R0 * v_circ_sun

# |L_z| of every star; keep those within 15% of the solar value.
Lz = np.abs(w_ap.angular_momentum()[2]).to(u.kpc * u.km / u.s)
Lz_mask = np.abs(Lz - Lz_sun) < 0.15 * Lz_sun
Lz_mask.sum()

In [None]:
# (v_z, z) grid for the vertical phase-space maps: 127 x 127 bins spanning
# -100..100 km/s in v_z and -2.5..2.5 kpc in z.
bins = (
    np.linspace(-100, 100, 128),
    np.linspace(-2.5, 2.5, 128),
)

# Vertical velocity [km/s] and height [kpc] of the stars inside the L_z window.
vz_ap = galcen_ap.v_z.to_value(u.km / u.s)[Lz_mask]
z_ap = galcen_ap.z.to_value(u.kpc)[Lz_mask]

# Mean [Mg/Fe] per (v_z, z) bin, and the raw number of stars per bin.
stat = binned_statistic_2d(vz_ap, z_ap, apogee.MG_FE[Lz_mask], statistic='mean', bins=bins)
counts = np.histogram2d(vz_ap, z_ap, bins=bins)[0]

In [None]:
# Stars near the solar L_z with |v_z| < 10 km/s are close to their vertical
# turning point, so their |z| is used below as a proxy for z_max.
zmax_mask = Lz_mask & (np.abs(galcen_ap.v_z) < 10 * u.km / u.s)
print(zmax_mask.sum())

xlim = (0.05, 2.5)  # |z| range [kpc] of the right panel (log-spaced bins)
ylim = (-0.15, 0.22)  # [Mg/Fe] range of the right panel

# Reference midplane [Mg/Fe]: median over the z_max-proxy sample with |z| < 200 pc.
hline_val = np.nanmedian(apogee.MG_FE[zmax_mask & (np.abs(galcen_ap.z) < 200 * u.pc)])

fig, axes = plt.subplots(1, 2, figsize=(10.5, 6.5), constrained_layout=True)

# --- Left panel: mean [Mg/Fe] in the (v_z, z) plane ---
ax = axes[0]

cs = ax.pcolormesh(
    stat.x_edge,
    stat.y_edge,
    stat.statistic.T,  # statistic is indexed (v_z bin, z bin); pcolormesh wants (row=z, col=v_z)
    cmap='magma',
    vmin=-0.05,
    vmax=ylim[1],
    rasterized=True
)

ax.set_xlabel(f"vertical velocity, $v_z$ [{u.km/u.s:latex_inline}]")
ax.set_ylabel(f"vertical position, $z$ [{u.kpc:latex_inline}]")

cb = fig.colorbar(cs, ax=ax, orientation='horizontal')
cb.set_label(r"mean $[{\rm Mg}/{\rm Fe}]$")

# --- Right panel: [Mg/Fe] distribution vs. |z| (z_max proxy) ---
ax = axes[1]

H, xe, ye = np.histogram2d(
    np.abs(galcen_ap.z.to_value(u.kpc)[zmax_mask]),
    apogee.MG_FE[zmax_mask],
    bins=(np.geomspace(*xlim, 51), np.linspace(*ylim, 51)),
)
# Normalize each |z| column to unit sum; empty columns become NaN (0/0) and are
# drawn blank, so silence the expected divide/invalid warnings.
with np.errstate(divide="ignore", invalid="ignore"):
    H_colnorm = H.T / H.T.sum(axis=0)[None]
ax.pcolormesh(xe, ye, H_colnorm, vmax=0.08, cmap="Greys", rasterized=True)

# Linear trend (in linear z) through (z_anchor, hline_val) that reaches
# mgfe_at_xmax at the right edge of the panel, xlim[1].
z_anchor = 0.15  # [kpc]
mgfe_at_xmax = 0.17
slope = (mgfe_at_xmax - hline_val) / (xlim[1] - z_anchor)
offset = hline_val - slope * z_anchor

# Legend label with an explicit sign so a negative intercept does not render
# as "+ -0.0xx"; the rf-string keeps LaTeX's "\," thin space without an
# invalid-escape warning.
offset_sign = "+" if offset >= 0 else "-"
line_label = (
    r"$[{\rm Mg}/{\rm Fe}] = "
    + rf"{slope:0.3f}\,z_{{\rm max}}"
    + f" {offset_sign} {abs(offset):0.3f}$"
)

grid = np.linspace(*xlim, 256)
ax.plot(
    grid,
    slope * grid + offset,
    ls="--",
    marker='',
    alpha=0.7,
    color="tab:orange",
    lw=1.5,
    label=line_label,
)

ax.legend(loc='lower right', fontsize=15)
ax.set_xlim(xlim)
ax.set_xscale("log")
ax.set_ylim(ylim)

ax.set_xlabel(r"proxy for $z_{\rm max}$ " + f"[{u.kpc:latex_inline}]")
ax.set_ylabel(r"[Mg/Fe]")

fig.suptitle("Example Element Abundance Gradient\nin the Vertical Phase Space", fontsize=24)

# Create the output directory if needed so savefig does not fail on a fresh checkout.
figure_path.mkdir(parents=True, exist_ok=True)
fig.savefig(figure_path / "mgfe-zvz.pdf", dpi=300)

---

In [None]:
# Smoothed, contoured version of the mean-[Mg/Fe] map. Bins holding fewer than
# 8 stars are blanked (NaN) first; astropy's convolve fills NaN pixels from
# their neighbours by default.
nsigma = 3.2  # Gaussian kernel standard deviation, in pixels

# Kernel width converted to data units: km/s along v_z, kpc along z.
print(nsigma * np.diff(stat.x_edge)[0], nsigma * np.diff(stat.y_edge)[0])

H = np.where(counts.T < 8, np.nan, stat.statistic.T)

vz_centers = 0.5 * (stat.x_edge[:-1] + stat.x_edge[1:])
z_centers = 0.5 * (stat.y_edge[:-1] + stat.y_edge[1:])
H_smooth = convolve(H, Gaussian2DKernel(nsigma), boundary="extend")

plt.contourf(
    vz_centers,
    z_centers,
    H_smooth,
    levels=np.geomspace(0.015, 0.1, 12),
    cmap='Blues_r',
)
plt.colorbar()

---

In [None]:
# Xiang & Rix (2022) subgiant catalog with full parameters. An alternative input
# is the Gaia-DR3-joined file XiangRix2022-subgiants-joined-gaiadr3.fits.
xiang_table = at.Table.read("/mnt/home/apricewhelan/data/misc/XiangRix2022-subgiants-fullparam.fits")

# Normalize every column name to lower case before wrapping in GaiaData.
lowercase_names = {name: name.lower() for name in xiang_table.colnames}
for old_name, new_name in lowercase_names.items():
    xiang_table.rename_column(old_name, new_name)

xiang = GaiaData(xiang_table)

In [None]:
# Selection cuts for the Xiang & Rix sample, using the catalog's vlos column
# (in km/s, as converted in the sky-coordinate cell) as the radial velocity.
# The Gaia-RV variant of this cut was:
#   xiang.filter(parallax=(0*u.mas, None), radial_velocity=(-900, 900)*u.km/u.s)
xiang_cuts = dict(
    parallax=(0 * u.mas, None),
    vlos=(-900, 900),
    feh=(-0.8, 0.5),
    alpha_fe=(-0.1, 0.15),
)
xiang = xiang.filter(**xiang_cuts)

In [None]:
# [alpha/Fe] vs. [Fe/H] for the filtered Xiang & Rix sample, with the feh
# lower limit and alpha_fe upper limit of the selection cut overlaid. These
# values must match the filter cell; the upper alpha_fe limit is 0.15, so a
# line at 0.2 would sit above every selected star.
plt.hist2d(
    xiang.feh,
    xiang.alpha_fe,
    bins=128,
    norm=mpl.colors.LogNorm()
);
plt.axvline(-0.8)  # feh lower limit
plt.axhline(0.15)  # alpha_fe upper limit
plt.xlabel(r"$[{\rm Fe}/{\rm H}]$")
plt.ylabel(r"$[\alpha/{\rm Fe}]$");

In [None]:
# Sky coordinates with the catalog's line-of-sight velocities attached (vlos in km/s).
xiang_rv = xiang.vlos * u.km / u.s
c_xiang = xiang.get_skycoord(radial_velocity=xiang_rv)

In [None]:
# Transform to the Galactocentric frame from config and wrap the underlying
# representation as a gala phase-space position.
galcen_xiang = c_xiang.transform_to(galcen_frame)
xiang_galcen_repr = galcen_xiang.data
w_xiang = gd.PhaseSpacePosition(xiang_galcen_repr)

In [None]:
# |L_z| for the Xiang sample. This selection uses a wider window around Lz_sun
# (20%) than the 15% window used for the APOGEE sample.
Lz_signed = w_xiang.angular_momentum()[2]
Lz = np.abs(Lz_signed).to(u.kpc * u.km / u.s)
Lz_window = 0.2 * Lz_sun
Lz_mask_xiang = np.abs(Lz - Lz_sun) < Lz_window
Lz_mask_xiang.sum()

In [None]:
# Same (v_z, z) grid as the APOGEE maps.
bins = (np.linspace(-100, 100, 128), np.linspace(-2.5, 2.5, 128))

# Vertical velocity [km/s] and height [kpc] for stars inside the L_z window.
vz_xiang = galcen_xiang.v_z.to_value(u.km / u.s)[Lz_mask_xiang]
z_xiang = galcen_xiang.z.to_value(u.kpc)[Lz_mask_xiang]

# Per-bin mean (binned_statistic_2d's default statistic) of age and of [alpha/Fe].
stat_x_age, stat_x_alpha = (
    binned_statistic_2d(vz_xiang, z_xiang, values[Lz_mask_xiang], bins=bins)
    for values in (xiang.age, xiang.alpha_fe)
)

In [None]:
# Mean age in the (v_z, z) plane, lightly smoothed with a 1.5-pixel Gaussian kernel.
age_smoothed = convolve(stat_x_age.statistic.T, Gaussian2DKernel(1.5), boundary="extend")
plt.pcolormesh(stat_x_age.x_edge, stat_x_age.y_edge, age_smoothed, vmin=0, vmax=14)

In [None]:
# Mean [alpha/Fe] (left) and mean age (right) in the (v_z, z) plane for the
# Xiang & Rix subgiants. Each panel draws its own binned statistic on that
# statistic's own bin edges.
fig, axes = plt.subplots(1, 2, figsize=(12, 4.6), constrained_layout=True)

# (axis, binned statistic, colour-scale max, colorbar label)
panels = [
    (axes[0], stat_x_alpha, 0.15, r"[$\alpha/{\rm Fe}$]"),
    (axes[1], stat_x_age, 12, "age [Gyr]"),
]
for ax, binned, vmax, label in panels:
    cs = ax.pcolormesh(binned.x_edge, binned.y_edge, binned.statistic.T, vmin=0, vmax=vmax)
    cb = fig.colorbar(cs, ax=ax)
    cb.set_label(label)
    ax.set_xlabel(f"$v_z$ [{u.km/u.s:latex_inline}]")

axes[0].set_ylabel(f"$z$ [{u.kpc:latex_inline}]")