In [None]:
import pathlib

import astropy.coordinates as coord
import astropy.table as at
import astropy.units as u
from astropy.stats import median_absolute_deviation as mad
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from scipy.spatial import cKDTree
from scipy.stats import binned_statistic_2d
from scipy.interpolate import interp1d
from tqdm import tqdm

# gala
import gala.coordinates as gc
import gala.dynamics as gd
import gala.integrate as gi
import gala.potential as gp
from gala.mpl_style import center_deemph
from gala.units import galactic

from totoro.data import elem_names, datasets
from totoro.config import galcen_frame, plot_config as pc, cache_path
from totoro.abundance_helpers import elem_to_label
from totoro.objective import TorusImagingObjective

In [None]:
# Dataset and abundance column examined throughout this notebook.
data_name = 'apogee-rgb-loalpha'
elem_name = 'MG_FE'

# Dataset object looked up by name, and the per-dataset location under
# `cache_path` from which results are read below.
d = datasets[data_name]
this_cache_path = cache_path / data_name

In [None]:
# Read the cached CSV named "optimize-results-<elem_name>.csv" for this dataset.
path = pathlib.Path(
    this_cache_path / f"optimize-results-{elem_name}.csv"
)
tbl = at.Table.read(path)

In [None]:
# Column-wise mean of every column in the results table, keyed by column name.
means = {k: np.mean(tbl[k]) for k in tbl.colnames}

In [None]:
# Objective for the selected dataset and element; tree_K is forwarded as-is
# (NOTE(review): presumably the neighbour count used by the anomaly
# estimator -- confirm against TorusImagingObjective).
obj = TorusImagingObjective(
    d,
    elem_name=elem_name,
    tree_K=20,
)

In [None]:
# Build the object returned by `obj.get_atm`, passing the column-wise means of
# the results table as keyword arguments (one per table column).
atm = obj.get_atm(**means)

In [None]:
%%timeit
angz, d_elem, d_elem_errs = atm.get_theta_z_anomaly(elem_name)

In [None]:
# Transform the dataset coordinates (`d.c`) into `galcen_frame`.
galcen = d.c.transform_to(galcen_frame)

# Bin edges: first axis -60..60 in steps of 1.5 (paired with v_z in km/s),
# second axis -1.75..1.75 in steps of 0.05 (paired with z in kpc). The small
# 1e-3 offset makes the upper edge inclusive.
zvz_bins = (
    np.arange(-60, 60+1e-3, 1.5),
    np.arange(-1.75, 1.75+1e-3, 0.05),
)

# Scale used for the symmetric colour limits of the maps further down:
# 1.5 times the median absolute deviation of the anomaly values.
std = 1.5 * mad(d_elem)

# Mean anomaly in each (v_z, z) pixel.
stat = binned_statistic_2d(
    x=galcen.v_z.to_value(u.km/u.s),
    y=galcen.z.to_value(u.kpc),
    values=d_elem,
    statistic='mean',
    bins=zvz_bins,
)

std

In [None]:
from scipy.spatial import cKDTree

def get_theta_z_anomaly_funny(self, elem_name, action_unit=30*u.km/u.s*u.kpc):
    """Abundance anomaly from a local linear fit over neighbours in action space.

    For every row, the ``self.tree_K`` nearest neighbours in unitless action
    space are projected onto the direction that points from the row towards
    the centroid of its neighbours. An unweighted least-squares straight line
    is fit to the neighbours' abundances along that 1D coordinate, and its
    value at coordinate 0 (the row's own position) is taken as the local
    prediction. The anomaly is the row's abundance minus that prediction.

    Parameters
    ----------
    self : object
        Must expose ``aaf`` (indexable by ``'actions'``, ``'angles'`` and
        ``elem_name``) and an integer ``tree_K``.
    elem_name : str
        Key of the abundance column in ``self.aaf``.
    action_unit : quantity-like, optional
        Unit the actions are converted to (and stripped of) before the
        KD-tree is built.

    Returns
    -------
    angz : array
        ``self.aaf['angles'][:, 2]`` wrapped at 360 deg, in radians.
    d_elem : array-like
        ``elem - prediction`` for every row.
    errs : None
        No uncertainty is computed; ``None`` fills the third slot of the
        returned tuple. The ``{elem_name}_ERR`` column is therefore not read.
    """
    action_unit = u.Quantity(action_unit)

    # Actions without units:
    X = self.aaf['actions'].to_value(action_unit)
    angz = coord.Angle(self.aaf['angles'][:, 2]).wrap_at(360*u.deg).radian

    # element abundance
    elem = self.aaf[elem_name]

    # Query K+1 neighbours and drop column 0, which is expected to be the
    # query point itself (assumes no exactly coincident points -- TODO confirm).
    tree = cKDTree(X)
    _, idx = tree.query(X, k=self.tree_K+1)
    nbr_idx = idx[:, 1:]
    nbr_X = X[nbr_idx]

    # Direction from each point to the centroid of its neighbours, and the
    # neighbours' offsets projected onto that direction. `xhat` is left
    # un-normalised: rescaling x does not change the fitted intercept.
    xhat = np.mean(nbr_X, axis=1) - X
    dx = nbr_X - X[:, None]
    x = np.einsum('nij,nj->ni', dx, xhat)
    y = elem[nbr_idx]

    # Closed-form least-squares intercept written as a weighted mean:
    #   w_i = sum(x^2) - x_i * sum(x),  intercept = sum(w_i y_i) / sum(w_i)
    w = np.sum(x**2, axis=1)[:, None] - x * np.sum(x, axis=1)[:, None]
    means = np.sum(y * w, axis=1) / np.sum(w, axis=1)

    d_elem = elem - means

    return angz, d_elem, None

In [None]:
%%timeit
_, funky_d_elem, _ = get_theta_z_anomaly_funny(atm, elem_name)

In [None]:
# Mean of the alternative anomaly estimate in the same (v_z, z) pixels.
funky_stat = binned_statistic_2d(
    x=galcen.v_z.to_value(u.km/u.s),
    y=galcen.z.to_value(u.kpc),
    values=funky_d_elem,
    statistic='mean',
    bins=zvz_bins,
)

In [None]:
# Side-by-side (v_z, z) maps of the binned mean anomaly from the two
# estimators, drawn on a common symmetric colour scale of +/- std.
fig, axes = plt.subplots(1, 2, figsize=(12, 5.5),
                         sharex=True, sharey=True,
                         constrained_layout=True)

panels = [
    (axes[0], stat, 'Default estimator'),
    (axes[1], funky_stat, 'Local linear fit'),
]
for ax, panel_stat, title in panels:
    cs = ax.pcolormesh(
        panel_stat.x_edge, panel_stat.y_edge,
        panel_stat.statistic.T,  # transpose: pcolormesh wants (ny, nx)
        vmin=-std, vmax=std,
        cmap=center_deemph, rasterized=True)
    ax.set_title(title)
    # Label x on both panels; the shared y label goes on the left one only.
    ax.set_xlabel(f'$v_z$ [{u.km/u.s:latex_inline}]')

ax = axes[0]
ax.set_xlim(zvz_bins[0].min(), zvz_bins[0].max())
ax.set_ylim(zvz_bins[1].min(), zvz_bins[1].max())
ax.set_ylabel(f'$z$ [{u.kpc:latex_inline}]')

# Both panels share vmin/vmax and cmap, so one colourbar serves both.
cb = fig.colorbar(cs, ax=axes, aspect=30)
cb.set_label('mean abundance anomaly')

fig.set_facecolor('w')

In [None]:
# Histograms of the per-row anomaly from the two estimators, with the mean
# (solid line) and mean +/- one standard deviation (dashed lines) marked.
fig, axes = plt.subplots(1, 2, figsize=(12, 5.5),
                         sharex=True, sharey=True,
                         constrained_layout=True)

_bins = np.linspace(-0.1, 0.1, 64)
panels = [
    (axes[0], d_elem, 'Default estimator'),
    (axes[1], funky_d_elem, 'Local linear fit'),
]
for ax, values, title in panels:
    # Mean and scatter use all values, including any outside the plotted bins.
    center = np.mean(values)
    scatter = np.std(values)

    ax.hist(values, bins=_bins)
    ax.axvline(center)
    ax.axvline(center + scatter, ls='--')
    ax.axvline(center - scatter, ls='--')

    ax.set_title(title)
    ax.set_xlabel('abundance anomaly')

axes[0].set_ylabel('count');