In [None]:
import os
import itertools
import pickle

import astropy.coordinates as coord
from astropy.convolution import convolve, Gaussian2DKernel
from astropy.io import fits
import astropy.table as at
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from scipy.stats import binned_statistic_2d
import corner

# gala
import gala.coordinates as gc
import gala.dynamics as gd
import gala.integrate as gi
import gala.potential as gp
from gala.units import galactic

from thriftshop.potentials import potentials
from thriftshop.config import vcirc, rsun
from thriftshop.actions import safe_get_actions, get_w0s_with_same_actions
from thriftshop.abundances import get_elem_names, elem_to_label

# Pin astropy's Galactocentric frame parameters (e.g. galcen_distance and the
# solar velocity) to the named 'v4.0' parameter set, so the transforms below do
# not silently change when astropy updates its default parameter set.
GALCEN_FRAME_DEFAULTS = 'v4.0'
coord.galactocentric_frame_defaults.set(GALCEN_FRAME_DEFAULTS);

In [None]:
# Load the APOGEE parent sample and keep only stars with a positive, precise
# Gaia parallax (GAIA_PARALLAX is in mas: distances are later 1000 / parallax pc).
MIN_PARALLAX = 0.4     # mas
MIN_PARALLAX_SNR = 5   # GAIA_PARALLAX / GAIA_PARALLAX_ERROR

parent_sample = at.Table.read('../data/apogee-parent-sample.fits')
parallax_snr = parent_sample['GAIA_PARALLAX'] / parent_sample['GAIA_PARALLAX_ERROR']
good_astrometry = ((parent_sample['GAIA_PARALLAX'] > MIN_PARALLAX)
                   & (parallax_snr > MIN_PARALLAX_SNR))
t = parent_sample[good_astrometry]
len(t)

In [None]:
# Assemble full phase-space coordinates (SkyCoord's default ICRS frame) from the
# APOGEE + Gaia columns, then transform to the Galactocentric frame.
pm_unit = u.mas / u.yr
rv_unit = u.km / u.s
c = coord.SkyCoord(
    ra=t['RA'] * u.deg,
    dec=t['DEC'] * u.deg,
    # GAIA_PARALLAX is in mas, so 1000 / parallax is the distance in pc.
    distance=1000 / t['GAIA_PARALLAX'] * u.pc,
    pm_ra_cosdec=t['GAIA_PMRA'] * pm_unit,
    pm_dec=t['GAIA_PMDEC'] * pm_unit,
    radial_velocity=t['VHELIO_AVG'] * rv_unit,
)
galcen = c.transform_to(coord.Galactocentric)

In [None]:
# Galactocentric Cartesian positions [kpc] and vertical velocity [km/s].
# x is shifted by the Sun-Galactic-center distance so the Sun sits near x = 0.
# The shift is read from the frame itself instead of being hard-coded, so it
# stays consistent with whatever galactocentric_frame_defaults are in effect
# (8.122 kpc under 'v4.0', reproducing the previous hard-coded value exactly).
sun_galcen_distance = galcen.frame.galcen_distance.to_value(u.kpc)
x = galcen.x.to_value(u.kpc) + sun_galcen_distance
y = galcen.y.to_value(u.kpc)
z = galcen.z.to_value(u.kpc)
vz = galcen.v_z.to_value(u.km/u.s)

# [Mn/Fe]. Values outside (-3, 3) are excluded later via elem_mask
# (presumably catalog missing-value flags -- confirm against APOGEE docs).
elem = t['MN_FE']

In [None]:
# Bin stars in the (v_z, z) plane. Left: mean [Mn/Fe] per bin (valid abundances
# only). Right: number of stars per bin (all stars, log color scale).
zlim = 2  # kpc
vlim = 100.  # km/s
vstep = 4  # km/s
zstep = 75 / 1e3  # kpc
vzz_bins = (np.arange(-vlim, vlim+1e-3, vstep),
            np.arange(-zlim, zlim+1e-3, zstep))

fig, axes = plt.subplots(1, 2, figsize=(12, 5),
                         constrained_layout=True)

# Treat abundances outside (-3, 3) as invalid (presumably missing-value flags).
elem_mask = (elem > -3) & (elem < 3)
stat = binned_statistic_2d(vz[elem_mask], z[elem_mask], elem[elem_mask],
                           statistic='mean',
                           bins=vzz_bins)

# Color limits from the *valid* abundances only: including the out-of-range
# values can pull the lower percentile far outside the physical range and
# wash out the whole map.
vmin, vmax = np.percentile(elem[elem_mask], [15, 85])

ax = axes[0]
cs = ax.pcolormesh(stat.x_edge, stat.y_edge, stat.statistic.T,
                   cmap='cividis', vmin=vmin, vmax=vmax)
fig.colorbar(cs, ax=ax, aspect=40)
ax.set_xlabel('$v_z$ [km/s]')
ax.set_ylabel('$z$ [kpc]')

ax = axes[1]
# Star counts per bin; use the edges histogram2d returns for these counts.
H, v_edges, z_edges = np.histogram2d(vz, z, bins=vzz_bins)
cs = ax.pcolormesh(v_edges, z_edges, H.T,
                   cmap='cividis',
                   norm=mpl.colors.LogNorm(1, 3e2))
fig.colorbar(cs, ax=ax, aspect=40)
ax.set_xlabel('$v_z$ [km/s]')

fig.set_facecolor('w')

In [None]:
# Zoomed-in (v_z, z) maps with finer bins. The guide lines mark the
# 0 < z < 0.05 kpc, -15 < v_z < 30 km/s window used to select `mask` below.
zlim = 1  # kpc
vlim = 75.  # km/s
vstep = 1  # km/s
zstep = 25 / 1e3  # kpc
vzz_bins = (np.arange(-vlim, vlim+1e-3, vstep),
            np.arange(-zlim, zlim+1e-3, zstep))

fig, axes = plt.subplots(1, 2, figsize=(12, 5),
                         constrained_layout=True)

# elem_mask (valid abundances) is defined in the previous cell.
stat = binned_statistic_2d(vz[elem_mask], z[elem_mask], elem[elem_mask],
                           statistic='mean',
                           bins=vzz_bins)

# Color limits from the valid abundances only, so out-of-range values cannot
# drag the percentiles outside the physical range.
vmin, vmax = np.percentile(elem[elem_mask], [15, 85])

ax = axes[0]
cs = ax.pcolormesh(stat.x_edge, stat.y_edge, stat.statistic.T,
                   cmap='cividis', vmin=vmin, vmax=vmax)
fig.colorbar(cs, ax=ax, aspect=40)
ax.set_xlabel('$v_z$ [km/s]')
ax.set_ylabel('$z$ [kpc]')

ax = axes[1]
# Star counts per bin (all stars); use the edges histogram2d returns.
H, v_edges, z_edges = np.histogram2d(vz, z, bins=vzz_bins)
cs = ax.pcolormesh(v_edges, z_edges, H.T,
                   cmap='cividis',
                   norm=mpl.colors.LogNorm(1, 3e2))
fig.colorbar(cs, ax=ax, aspect=40)
ax.set_xlabel('$v_z$ [km/s]')

fig.set_facecolor('w')

# Guide lines for the selection window.
for ax in axes:
    ax.axhline(0.05)
    ax.axhline(0)
    ax.axvline(-15)
    ax.axvline(30)

In [None]:
# Two thin z-slabs sharing the same v_z window: one just above the plane and one
# at 0.2-0.25 kpc. Their [Mn/Fe] distributions are compared below.
VZ_WINDOW = (-15, 30)  # km/s; same window as the vertical guide lines on the zoomed map


def _slab_mask(zmin, zmax, vzmin=VZ_WINDOW[0], vzmax=VZ_WINDOW[1]):
    """Boolean mask of stars with zmin < z < zmax [kpc] and vzmin < v_z < vzmax [km/s].

    All bounds are strict, matching the original inline selections.
    """
    return (z < zmax) & (z > zmin) & (vz > vzmin) & (vz < vzmax)


mask = _slab_mask(0, 0.05)
mask2 = _slab_mask(0.2, 0.25)
# A higher slab explored earlier: _slab_mask(0.5, 0.55)
mask.sum(), mask2.sum()

In [None]:
# Mean [Mn/Fe] over valid-abundance stars in each slab, ordered (mask, mask2).
tuple(np.mean(elem[slab & elem_mask]) for slab in (mask, mask2))

In [None]:
# [Mn/Fe] distributions of the two slabs, each with its mean and median marked.
abund_bins = np.linspace(-0.3, 0.7, 64)
# Titles mirror the slab definitions of mask / mask2; keep them in sync.
slab_panels = ((mask, '0 < z < 0.05 kpc'),
               (mask2, '0.2 < z < 0.25 kpc'))

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

for ax, (slab, title) in zip(axes, slab_panels):
    slab_vals = elem[slab & elem_mask]
    ax.hist(slab_vals, bins=abund_bins)
    ax.axvline(np.mean(slab_vals), color='r', label='mean')
    ax.axvline(np.median(slab_vals), color='g', label='median')
    ax.set_xlabel('[Mn/Fe]')
    ax.set_ylabel('count')
    ax.set_title(title)
    ax.legend()

In [None]:
# Stars in the first slab (`mask`) with a valid [Mn/Fe] above 0.2; the cells
# below inspect their stellar parameters, fields, and positions.
tmpmask = (mask & elem_mask) & (elem > 0.2)

In [None]:
# Kiel diagram (TEFF vs LOGG) of the stars selected by tmpmask.
fig, ax = plt.subplots()
ax.scatter(t['TEFF'][tmpmask], t['LOGG'][tmpmask])
ax.set_xlabel('TEFF')
ax.set_ylabel('LOGG');

In [None]:
# Distinct observing fields of the selected stars and how many selected stars
# fall in each. Index the FIELD column directly rather than slicing the full table.
unq, counts = np.unique(t['FIELD'][tmpmask], return_counts=True)

In [None]:
# The 10 fields contributing the most selected stars, most populous first.
order = np.argsort(counts)[::-1]
unq[order[:10]]

In [None]:
# Positions of the selected stars in the Galactic plane. x is the
# Galactocentric x shifted so the Sun sits near 0.
fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(x[tmpmask], y[tmpmask], alpha=0.2)
ax.set_xlabel('x [kpc]')
ax.set_ylabel('y [kpc]');

In [None]:
# Sky positions of the selected stars.
fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(t['RA'][tmpmask], t['DEC'][tmpmask], alpha=0.2)
ax.set_xlabel('RA [deg]')
ax.set_ylabel('Dec [deg]');

In [None]:
# Parallax signal-to-noise of the selected stars. The load-time cut already
# requires S/N > 5, and values beyond the last bin edge are silently dropped by
# the histogram, so report how many stars fall outside the plotted range.
snr_bins = np.linspace(0, 10, 32)
plx_snr_sel = t['GAIA_PARALLAX'][tmpmask] / t['GAIA_PARALLAX_ERROR'][tmpmask]
n_above = int(np.sum(plx_snr_sel > snr_bins[-1]))

fig, ax = plt.subplots()
ax.hist(plx_snr_sel, bins=snr_bins)
ax.set_xlabel('GAIA_PARALLAX / GAIA_PARALLAX_ERROR')
ax.set_ylabel('count')
ax.set_title(f'{n_above} of {len(plx_snr_sel)} stars have S/N > {snr_bins[-1]:g} (not shown)');