In [None]:
import pathlib

import astropy.coordinates as coord
from astropy.io import fits
import astropy.table as at
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from scipy.stats import binned_statistic, binned_statistic_2d
from gala.mpl_style import hesperia, laguna

from totoro.config import fig_path, cache_path
from totoro.config import elem_names, galcen_frame
from totoro.data import datasets
from totoro.potentials import potentials
from totoro.abundance_helpers import elem_to_label
from totoro.objective import TorusImagingObjective

In [None]:
# Dataset under study. Every cell below reads `d`, and the per-element optimizer
# results are loaded from this dataset's cache directory.
data_name = 'apogee-rgb-loalpha'
this_cache_path = cache_path / data_name
d = datasets[data_name]

In [None]:
# Figure: abundance offset from the sample median vs. vertical action J_z, one
# panel per element. Actions come from the torus-imaging model evaluated at the
# column-averaged cached optimizer results for that element.
fig, axes = plt.subplots(2, 4, figsize=(14, 7.),
                         sharex=True,
                         constrained_layout=True)

for ax, elem_name in zip(axes.flat, elem_names):
    # Best-fit parameters = mean of each column of the cached optimizer results.
    # NOTE(review): assumes the CSV column order matches the positional
    # arguments of `get_atm` -- confirm.
    tbl = at.Table.read(this_cache_path / f"optimize-results-{elem_name}.csv")
    xs = [np.nanmean(tbl[col]) for col in tbl.colnames]

    obj = TorusImagingObjective(d, elem_name)
    atm = obj.get_atm(*xs)

    Jz = atm.aaf['actions'][:, 2].to_value(u.kpc*u.km/u.s)
    elem = atm.aaf[elem_name]

    # Keep plausible abundances only; NaN compares False, so NaNs are dropped too.
    mask = (elem > -3) & (elem < 3)
    elem = elem[mask] - np.nanmedian(elem[mask])
    Jz = Jz[mask]

    # Histogram y-range: twice the larger |1st / 99th percentile|.
    pcl = np.nanpercentile(elem, [1, 99])
    lim = 2 * np.max(np.abs(pcl))

    ax.hist2d(Jz, elem,
              bins=(10 ** np.arange(-2, 2.5, 0.02),
                    np.linspace(-lim, lim, 128)),
              norm=mpl.colors.LogNorm(vmin=5e-1),
              cmap='Greys', rasterized=True)

    # Binned mean (blue) and 16th / 84th percentiles (purple) in log-spaced bins.
    mean_stat = binned_statistic(Jz, elem,
                                 bins=10 ** np.arange(-3, 3, 0.1))
    ctr = 0.5 * (mean_stat.bin_edges[:-1] + mean_stat.bin_edges[1:])
    ax.plot(ctr, mean_stat.statistic,
            marker='', drawstyle='steps-mid',
            lw=1.5, color='tab:blue')

    for q in [16, 84]:
        # `q=q` binds the current percentile explicitly inside the lambda.
        q_stat = binned_statistic(Jz, elem,
                                  bins=mean_stat.bin_edges,
                                  statistic=lambda x, q=q: np.percentile(x, q))
        ax.plot(ctr, q_stat.statistic,
                marker='', drawstyle='steps-mid',
                lw=1., color='tab:purple')

    ax.set_xscale('log')

    # Fixed display range: wider for [Fe/H] than for the other ratios.
    if elem_name == 'FE_H':
        lim = 0.6
        ax.set_yticks(np.arange(-0.6, 0.6+1e-3, 0.2))
    else:
        lim = 0.3
        # Among the non-[Fe/H] panels, only these keep their y tick labels.
        if elem_name not in ['C_FE', 'MG_FE']:
            ax.set_yticklabels([])

    ax.set_yticks(np.arange(-0.6, 0.6+1e-3, 0.1), minor=True)
    ax.set_ylim(-lim, lim)

    # Typical (median) measurement uncertainty drawn as an error bar near the
    # top of the panel. nanmedian: a single missing error must not make the
    # whole marker NaN (and therefore invisible).
    ax.errorbar(1e2, lim - 0.1*lim,
                np.nanmedian(d.g.data[elem_name + '_ERR']),
                marker='o', ms=4, color='tab:red', alpha=0.75)

    ax.text(1.5e-2, lim - 0.05*lim, elem_to_label(elem_name),
            ha='left', va='top', fontsize=18)

for ax in axes[:, 0]:
    ax.set_ylabel(r'$X - \langle X \rangle$')

for ax in axes[-1, :]:
    ax.set_xlabel(f'$J_z$ [{u.kpc*u.km/u.s:latex_inline}]')

fig.set_facecolor('w')
fig.savefig(fig_path / 'elem-Jz-gradients.pdf', dpi=250)

---

In [None]:
# Exploratory variant of the figure above: inverse-variance weighted binned means
# plus a weighted degree-7 polynomial fit of the abundance offset vs. log10(J_z).
# Fitted polynomials are stored in `polys`, keyed by element name.
polys = {}
stat_bins = 10 ** np.arange(-3, 3, 0.1)

fig, axes = plt.subplots(2, 4, figsize=(14, 7.),
                         sharex=True,
                         constrained_layout=True)

for ax, elem_name in zip(axes.flat, elem_names):
    # NOTE(review): assumes the CSV column order matches `get_atm`'s arguments.
    tbl = at.Table.read(this_cache_path / f"optimize-results-{elem_name}.csv")
    xs = [np.nanmean(tbl[col]) for col in tbl.colnames]

    obj = TorusImagingObjective(d, elem_name)
    atm = obj.get_atm(*xs)

    Jz = atm.aaf['actions'][:, 2].to_value(u.kpc*u.km/u.s)
    elem = atm.aaf[elem_name]
    elem_err = atm.aaf[elem_name + '_ERR']

    # Keep plausible abundances only (NaN compares False, so NaNs are dropped).
    mask = (elem > -3) & (elem < 3)
    elem = elem[mask] - np.nanmedian(elem[mask])
    Jz = Jz[mask]
    elem_err = elem_err[mask]

    pcl = np.nanpercentile(elem, [1, 99])
    lim = 2 * np.max(np.abs(pcl))

    ax.hist2d(Jz, elem,
              bins=(10 ** np.arange(-2, 2.5, 0.02),
                    np.linspace(-lim, lim, 128)),
              norm=mpl.colors.LogNorm(vmin=5e-1),
              cmap='Greys', rasterized=True)

    # Weighted statistics need a usable uncertainty (a NaN or zero error would
    # poison the sums). The log-space fit needs J_z > 0.
    good = np.isfinite(elem_err) & (elem_err > 0) & (Jz > 0)
    Jz_g, elem_g, sigma_g = Jz[good], elem[good], elem_err[good]
    ivar_g = 1 / sigma_g**2

    # Inverse-variance weighted mean per bin: sum(x / s^2) / sum(1 / s^2).
    num = binned_statistic(Jz_g, elem_g * ivar_g, statistic='sum', bins=stat_bins)
    den = binned_statistic(Jz_g, ivar_g, statistic='sum', bins=stat_bins)
    with np.errstate(invalid='ignore'):
        y = num.statistic / den.statistic  # empty bins -> 0/0 -> NaN (not drawn)

    ctr = 0.5 * (num.bin_edges[:-1] + num.bin_edges[1:])
    ax.plot(ctr, y,
            marker='', drawstyle='steps-mid',
            lw=2, color='tab:blue')

    # np.polyfit applies `w` to the *unsquared* residuals, so inverse-variance
    # weighting of the squared residuals corresponds to w = 1 / sigma.
    logJz = np.log10(Jz_g)
    coeffs = np.polyfit(logJz, elem_g, w=1 / sigma_g, deg=7)
    poly = np.poly1d(coeffs)
    polys[elem_name] = poly

    x = np.linspace(logJz.min(), logJz.max(), 1024)
    ax.plot(10**x, poly(x))

    ax.set_xscale('log')
    ax.set_ylim(-lim, lim)

    ax.text(1.5e-2, lim - 0.05*lim, elem_to_label(elem_name),
            ha='left', va='top', fontsize=18)

for ax in axes[:, 0]:
    ax.set_ylabel(r'$X - \langle X \rangle$')

for ax in axes[-1, :]:
    ax.set_xlabel(f'$J_z$ [{u.kpc*u.km/u.s:latex_inline}]')

fig.set_facecolor('w')

In [None]:
# Sanity check on sizes left over from the last loop iteration of the previous
# cell: masked J_z, the full selection mask, and the masked abundances.
# NOTE(review): relies on loop variables leaking out of that cell.
tuple(map(len, (Jz, mask, elem)))

In [None]:
# Same figure for [X/H] instead of [X/Fe]. Actions still come from each
# element's [X/Fe] torus-imaging fit; the plotted abundance is [X/H] for the same
# stars. Degree-5 polynomials in log10(J_z) are stored in `polys`, keyed by the
# [X/H] name (e.g. 'MG_H'); the derivative plot below reads them.
polys = {}

fig, axes = plt.subplots(2, 4, figsize=(14, 7.),
                         sharex=True,
                         constrained_layout=True)

for ax, elem_name in zip(axes.flat, elem_names):
    # NOTE(review): assumes the CSV column order matches `get_atm`'s arguments.
    tbl = at.Table.read(this_cache_path / f"optimize-results-{elem_name}.csv")
    xs = [np.nanmean(tbl[col]) for col in tbl.colnames]

    # Stars with a finite, plausible [X/Fe]. The subset is used both to compute
    # the actions and to read the [X/H] values below.
    elem_d = d[np.isfinite(d.t[elem_name]) &
               (d.t[elem_name] > -3) &
               (d.t[elem_name] < 3)]
    obj = TorusImagingObjective(elem_d, elem_name)
    atm = obj.get_atm(*xs)
    Jz = atm.aaf['actions'][:, 2].to_value(u.kpc*u.km/u.s)

    # [X/Fe] -> [X/H] (e.g. 'MG_FE' -> 'MG_H'); [Fe/H] is used as-is. A separate
    # name keeps the loop variable `elem_name` unmodified.
    if elem_name == 'FE_H':
        ratio_name = elem_name
    else:
        ratio_name = '_'.join([elem_name.split('_')[0], 'H'])

    # NOTE(review): assumes rows of `atm.aaf` align with rows of `elem_d`.
    elem = elem_d.get_elem_ratio(ratio_name)
    elem = elem - np.nanmedian(elem)

    pcl = np.nanpercentile(elem, [1, 99])
    lim = 2 * np.max(np.abs(pcl))

    ax.hist2d(Jz, elem,
              bins=(10 ** np.arange(-2, 2.5, 0.02),
                    np.linspace(-lim, lim, 128)),
              norm=mpl.colors.LogNorm(vmin=5e-1),
              cmap='Greys', rasterized=True)

    # np.polyfit is not NaN-safe, and log10(J_z <= 0) is not finite, so fit
    # only stars where both coordinates are finite.
    with np.errstate(divide='ignore', invalid='ignore'):
        logJz = np.log10(Jz)
    fit = np.isfinite(logJz) & np.isfinite(elem)
    coeffs = np.polyfit(logJz[fit], elem[fit], deg=5)
    poly = np.poly1d(coeffs)
    polys[ratio_name] = poly

    x = np.linspace(logJz[fit].min(), logJz[fit].max(), 1024)
    ax.plot(10**x, poly(x))

    ax.set_xscale('log')
    ax.set_ylim(-lim, lim)

    ax.text(1.5e-2, lim - 0.05*lim, elem_to_label(ratio_name),
            ha='left', va='top', fontsize=18)

for ax in axes[:, 0]:
    ax.set_ylabel(r'$X - \langle X \rangle$')

for ax in axes[-1, :]:
    ax.set_xlabel(f'$J_z$ [{u.kpc*u.km/u.s:latex_inline}]')

fig.set_facecolor('w')

In [None]:
# Absolute slope |d[X/H] / d log10(J_z)| of each fitted polynomial in `polys`,
# evaluated on a fixed log10(J_z) grid from 0 to 2.2.
fig, ax = plt.subplots(figsize=(6, 6))

grid = np.linspace(0, 2.2, 256)
for elem_name, poly in polys.items():
    deriv = poly.deriv()(grid)
    ax.plot(grid, np.abs(deriv), marker='', label=elem_name)

ax.legend()

ax.set_xlim(0, 2.2)
ax.set_ylim(-0.05, 0.5)

ax.set_xlabel(r'$\log J_z$')
ax.set_ylabel(r'$\left| \frac{{\rm d} [{\rm X}/{\rm H}]}{{\rm d} \log J_z} \right|$')

fig.set_facecolor('w')