In [None]:
import pathlib

import astropy.coordinates as coord
import astropy.table as at
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from scipy.spatial import cKDTree
from scipy.stats import binned_statistic
from scipy.interpolate import interp1d
from tqdm import tqdm

from totoro.config import galcen_frame, cache_path, plot_path
from totoro.data import datasets, elem_names
from totoro.potentials import potentials, galpy_potentials
from totoro.objective import TorusImagingObjective
from totoro.abundance_helpers import elem_to_label

In [None]:
def find_nearest_square(num):
    """Return the side length of the smallest square grid holding ``num`` panels.

    Parameters
    ----------
    num : int
        Number of panels that must fit in the grid.

    Returns
    -------
    side : int
        The smallest integer ``side >= 2`` such that ``side**2 >= num``.
        A minimum of 2 is enforced so that ``plt.subplots(side, side)``
        always returns a 2D array of axes.
    """
    min_side = 2
    if num <= min_side ** 2:
        return min_side

    side = int(np.ceil(np.sqrt(num)))
    # Guard against floating-point round-off in the square root for very
    # large inputs: never return a grid that is too small.
    while side ** 2 < num:
        side += 1
    return side

In [None]:
# data_name = 'apogee-rgb-loalpha'
# d = datasets[data_name]
# for data_name, d in datasets.items():
for data_name in ['galah-ms-loalpha']:
    d = datasets[data_name]
    this_cache_path = cache_path / data_name
    these_elem_names = elem_names[data_name]
    n_elems = len(these_elem_names)

    # ---

    # Plot all elements in one figure: one panel per element, laid out on
    # the smallest square grid that fits them all.
    sq = find_nearest_square(n_elems)
    fig, axes = plt.subplots(sq, sq, figsize=(16, 12),
                             sharex=True, sharey=False,
                             constrained_layout=True)

    for ax, elem_name in zip(axes.flat, these_elem_names):
        # Collapse the cached optimizer results to one value per parameter
        # (NaN-aware mean of each column), in column order.
        tbl = at.Table.read(
            this_cache_path / f"optimize-results-{elem_name}.csv")
        best_pars = [np.nanmean(tbl[col]) for col in tbl.colnames]

        obj = TorusImagingObjective(d, elem_name)
        atm = obj.get_atm(*best_pars)

        Jz = atm.aaf['actions'][:, 2].value
        elem = atm.aaf[elem_name]

        # Comparisons against NaN are False, so this cut also removes
        # missing abundance values.
        elem_mask = (elem > -3) & (elem < 3)
        elem = elem - np.nanmedian(elem[elem_mask])

        # log10 is -inf / NaN for Jz <= 0 or NaN; binned_statistic rejects
        # non-finite input, so drop those rows explicitly. The median above
        # is deliberately computed from the abundance cut alone.
        with np.errstate(divide='ignore', invalid='ignore'):
            log_Jz = np.log10(Jz)
        mask = elem_mask & np.isfinite(log_Jz)

        elem = elem[mask]
        log_Jz = log_Jz[mask]

        # Symmetric y-range: twice the largest of |1st|, |99th| percentile.
        pcl = np.nanpercentile(elem, [1, 99])
        lim = 2 * np.max(np.abs(pcl))

        ax.hist2d(log_Jz, elem,
                  bins=(np.arange(-3, 3, 0.02),
                        np.linspace(-lim, lim, 128)),
                  norm=mpl.colors.LogNorm(vmin=5e-1),
                  cmap='Greys')

        # Overlay the binned mean abundance (binned_statistic's default
        # statistic) as a step curve.
        stat = binned_statistic(log_Jz, elem,
                                bins=np.arange(-3, 3, 0.1))
        ctr = 0.5 * (stat.bin_edges[:-1] + stat.bin_edges[1:])
        ax.plot(ctr, stat.statistic,
                marker='', drawstyle='steps-mid',
                lw=2, color='tab:blue')

        ax.set_ylim(-lim, lim)

        lbl = elem_to_label(elem_name)
        ax.text(-1.9, lim-0.05*lim, lbl,
                ha='left', va='top', fontsize=18)

    # The x-axis is shared, so setting the limits on one panel sets them all.
    axes.flat[0].set_xlim(-2, 2)

    # Hide the unused panels at the end of the grid. Slicing by the number
    # of elements (rather than the leaked loop index) also works when the
    # element list is empty.
    for unused_ax in axes.flat[n_elems:]:
        unused_ax.set_visible(False)

    # Label the bottom-most visible panel of every column. sharex=True hides
    # tick labels on all but the last grid row, so re-enable them wherever
    # the panel below is hidden.
    for col_axes in axes.T:
        visible_axes = [ax for ax in col_axes if ax.get_visible()]
        if not visible_axes:
            continue
        bottom_ax = visible_axes[-1]
        bottom_ax.set_xlabel(r'$\log_{10} J_z$')
        bottom_ax.tick_params(axis='x', which='both', labelbottom=True)

    for ax in axes[:, 0]:
        ax.set_ylabel(r'$X - \langle X \rangle$')

    # Make sure the per-dataset output directory exists before saving.
    this_plot_path = pathlib.Path(plot_path) / data_name
    this_plot_path.mkdir(parents=True, exist_ok=True)
    fig.savefig(this_plot_path / 'elem-gradients-Jz.png', dpi=250)
    plt.close(fig)