In [None]:
import astropy.coordinates as coord
import astropy.table as at
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from scipy.spatial import cKDTree
from scipy.stats import binned_statistic_2d
from scipy.interpolate import interp1d
from tqdm import tqdm

# gala
import gala.coordinates as gc
import gala.dynamics as gd
import gala.integrate as gi
import gala.potential as gp
from gala.units import galactic
from gala.mpl_style import hesperia_r

from totoro.config import galcen_frame, elem_names
from totoro.data import datasets
from totoro.potentials import potentials, galpy_potentials
from totoro.objective import TorusImagingObjective

In [None]:
# Select one of the preconfigured totoro datasets by name. `d` supplies the
# table (`d.t`), the sky coordinates (`d.c`) and the abundance-ratio helper
# (`d.get_elem_ratio`) used in the cells below.
data_name = 'apogee-rgb-loalpha'
d = datasets[data_name]

In [None]:
# Transform every star's coordinates into `galcen_frame` from totoro.config
# (presumably a Galactocentric frame — it provides the `z` / `v_z` components
# used in the phase-space plot further down).
galcen = d.c.transform_to(galcen_frame)

In [None]:
# Quick look at [Si/Fe] versus effective temperature for the full sample.
fig, ax = plt.subplots(figsize=(6, 5))
ax.plot(d.t['TEFF'], d.get_elem_ratio('SI_FE'),
        marker='o', ls='none', mew=0, ms=1.5, alpha=0.4)
ax.set_xlabel(r'$T_{\rm eff}$ [K]')
# Trailing semicolon suppresses the Text repr as the cell's output.
ax.set_ylabel('[Si/Fe]');

In [None]:
# Red-clump ("rc") selection in the (Teff, log g, [Fe/H]) space: keep stars with
#     1.9 <= log g <= 2.4 + 0.0018 * (Teff - Teff_ref),
#     Teff_ref = 4607 - 382.5 * [Fe/H].
# NOTE(review): this resembles the APOGEE red-clump cut of Bovy et al. (2014),
# but with different log g bounds — confirm the intended constants.
teff, logg, feh = (d.t[col] for col in ('TEFF', 'LOGG', 'FE_H'))

teff_ref = 4607 - 382.5 * feh
# `tmp` is the per-star upper bound on log g.
tmp = 2.4 + 0.0018 * (teff - teff_ref)
rc_mask = (1.9 <= logg) & (logg <= tmp)

# Display: number of selected stars and total sample size.
rc_mask.sum(), len(rc_mask)

In [None]:
# Mean [M/H] of all stars in bins of effective temperature and surface gravity.
fig, ax = plt.subplots(figsize=(5, 5))

stat = binned_statistic_2d(
    d.t['TEFF'],
    d.t['LOGG'],
    d.t['M_H'],
    statistic='mean',
    bins=(np.arange(4200, 5200, 25),
          np.arange(2, 3., 0.01)))
mesh = ax.pcolormesh(stat.x_edge, stat.y_edge,
                     stat.statistic.T)
# Without a colorbar the colour scale of the mean-metallicity map is unreadable.
fig.colorbar(mesh, ax=ax, label='mean [M/H]')

ax.set_xlabel(r'$T_{\rm eff}$ [K]')
ax.set_ylabel(r'$\log g$')
# Kiel-diagram orientation: hot stars on the left, low log g at the top.
ax.invert_xaxis()
ax.invert_yaxis()

In [None]:
def plot_kiel_counts(mask=None, title=None):
    """Plot star counts of ``d.t`` binned in the (Teff, log g) plane.

    Parameters
    ----------
    mask : array_like of bool, optional
        Row selection applied to ``d.t`` before binning. ``None`` (default)
        uses every row.
    title : str, optional
        Title for the axes.

    Returns
    -------
    fig, ax, stat
        The new figure and axes, and the ``binned_statistic_2d`` result.
    """
    if mask is None:
        mask = np.ones(len(d.t), dtype=bool)

    fig, ax = plt.subplots(figsize=(5, 5))
    # statistic='count' never reads the values array; M_H is passed only
    # because binned_statistic_2d takes it as a required positional argument.
    stat = binned_statistic_2d(
        d.t['TEFF'][mask],
        d.t['LOGG'][mask],
        d.t['M_H'][mask],
        statistic='count',
        bins=(np.arange(4200, 5200, 25),
              np.arange(1.8, 3.5, 0.01)))
    mesh = ax.pcolormesh(stat.x_edge, stat.y_edge, stat.statistic.T)
    fig.colorbar(mesh, ax=ax, label='number of stars')

    ax.set_xlabel(r'$T_{\rm eff}$ [K]')
    ax.set_ylabel(r'$\log g$')
    if title is not None:
        ax.set_title(title)
    # Kiel-diagram orientation: hot stars on the left, low log g at the top.
    ax.invert_xaxis()
    ax.invert_yaxis()
    return fig, ax, stat


# Full sample, then the red-clump selection, on identical bins for comparison.
# The last call rebinds fig/ax/stat exactly as the original second plot did.
plot_kiel_counts(title='All stars')
fig, ax, stat = plot_kiel_counts(rc_mask, title='Red-clump selection')

In [None]:
# Bin edges for the phase-space map. Despite the name, the order is
# (v_z edges, z edges), matching the x = v_z [km/s], y = z [kpc] arguments
# given to binned_statistic_2d in the phase-space cell.
# np.linspace hits both endpoints exactly; a float-step np.arange needs the
# fragile `stop + 1e-3` nudge to include the last edge.
VZ_MAX, DVZ = 90, 1.5    # km/s
Z_MAX, DZ = 1.75, 0.05   # kpc
zvz_bins = (np.linspace(-VZ_MAX, VZ_MAX, int(round(2 * VZ_MAX / DVZ)) + 1),
            np.linspace(-Z_MAX, Z_MAX, int(round(2 * Z_MAX / DZ)) + 1))

In [None]:
# Mean-subtracted abundance ratio of red-clump stars, averaged in bins of
# vertical velocity v_z (x axis) and height z (y axis).
elem_name = 'MG_FE'  # ratio name passed to d.get_elem_ratio

fig, ax = plt.subplots(1, 1, figsize=(6, 6))

# NOTE(review): despite its name, `feh` holds the ratio chosen by `elem_name`
# ([Mg/Fe] by default) and overwrites the [Fe/H] array from the red-clump
# cell. The name is kept in case later cells rely on it.
feh = d.get_elem_ratio(elem_name)
# Keep finite, non-zero values above -3 for red-clump stars only. Values at
# or below -3 and exact zeros are presumably missing-data placeholders —
# TODO confirm.
mask = ((feh > -3) & np.isfinite(feh) & (feh != 0)) & rc_mask
feh = feh - np.mean(feh[mask])

# np.ma.filled converts masked entries to NaN for masked arrays and returns
# plain ndarrays unchanged, so nanstd skips masked values either way. This
# replaces a bare try/except that would also have hidden unrelated errors.
std = np.nanstd(np.ma.filled(feh[mask], np.nan))

stat = binned_statistic_2d(
    galcen.v_z.to_value(u.km/u.s)[mask],
    galcen.z.to_value(u.kpc)[mask],
    feh[mask],
    statistic='mean',
    bins=zvz_bins)
# Symmetric colour range of +/- 1 standard deviation about the sample mean.
mesh = ax.pcolormesh(stat.x_edge, stat.y_edge,
                     stat.statistic.T,
                     vmin=-std, vmax=std,
                     cmap=hesperia_r, rasterized=True)
elem_label = '[' + elem_name.replace('_', '/').title() + ']'
fig.colorbar(mesh, ax=ax, label=elem_label + ' (mean subtracted)')

ax.set_xlim(zvz_bins[0].min(), zvz_bins[0].max())
ax.set_ylim(zvz_bins[1].min(), zvz_bins[1].max())
ax.set_xlabel(r'$v_z$ [km s$^{-1}$]')
ax.set_ylabel(r'$z$ [kpc]')

fig.tight_layout()
fig.set_facecolor('w')