In [None]:
import os
import pathlib
import sys

# Point the APOGEE data cache at the shared ceph directory. This is set here,
# before the project imports further down, presumably because those modules
# read it at import time — TODO confirm.
os.environ['APOGEE_CACHE_PATH'] = "/mnt/ceph/users/apricewhelan/apogee-test/"

# Make the parent directory of the notebook importable (this is where the
# local `joaquin` package is imported from below); avoid duplicate entries.
_path = os.fspath(pathlib.Path('..').resolve())
if _path not in sys.path:
    sys.path.append(_path)

import corner
from astropy.io import fits
import astropy.coordinates as coord
import astropy.table as at
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from tqdm.auto import tqdm
from scipy.stats import binned_statistic
from astropy.stats import median_absolute_deviation as MAD

from gala.mpl_style import hesperia, laguna

from joaquin.data_helpers import get_parent_sample, phot_to_label
from joaquin.config import phot_names
from joaquin.plot import simple_corner

In [None]:
# Catalog of (presumably) binary-star candidates from the APOGEE DR16 VAC
# metadata; used below to build `binaries_mask`.
binaries_path = '/mnt/home/apricewhelan/projects/apogee-dr16-vac/catalogs/lnK-0.2_logL3.0_metadata.fits'
binaries = at.Table.read(binaries_path)

In [None]:
# Parent sample: the DR16 allStar file with Gaia EDR3 columns attached.
# DR17 alternative (currently unused):
#   '/mnt/home/apricewhelan/data/APOGEE_DR17/allStar-dr17-turbo20-beta.fits', hdu=1
allstar = at.Table.read(
    '/mnt/home/apricewhelan/data/APOGEE_DR16/allStar-r12-gaiaedr3.fits',
    hdu=1,
)

# Copy the Gaia columns to upper-case, GAIAEDR3_-prefixed names so later
# cells can use e.g. 'GAIAEDR3_PHOT_G_MEAN_MAG'.
_gaia_cols = ('phot_g_mean_mag', 'phot_bp_mean_mag', 'phot_rp_mean_mag',
              'parallax', 'parallax_error')
for _col in _gaia_cols:
    allstar['GAIAEDR3_' + _col.upper()] = allstar[_col]

In [None]:
# Attach WISE photometry to the allStar table by APOGEE_ID.
wise = at.Table.read('/mnt/home/apricewhelan/data/APOGEE_DR17/APOGEE-DR17-wise-result.fits.gz')
wise.rename_column('apogee_id', 'APOGEE_ID')

# Normalize the join key identically on BOTH tables (decode to str and strip
# surrounding whitespace) so that padding in either catalog cannot cause
# silent non-matches in the join below.
allstar['APOGEE_ID'] = np.char.strip(allstar['APOGEE_ID'].astype(str))
wise['APOGEE_ID'] = np.char.strip(wise['APOGEE_ID'].astype(str))

# Left join: every allStar row is kept; rows without a WISE match get masked
# WISE columns.
allstar = at.join(allstar, wise, keys='APOGEE_ID', join_type='left')

# Keep a single row per APOGEE_ID (at.unique keeps the first row by default).
# NOTE(review): this also collapses stars that appear more than once in
# allStar (e.g. observed with more than one telescope) — confirm intended.
allstar = at.unique(allstar, keys='APOGEE_ID')

In [None]:
# `stars` is an alias for the same Table object (not a copy), so the
# TELESCOPE update below is visible through both names.
stars = allstar

# Strip padding from telescope names so they compare equal to plain strings.
_telescopes = [tel.strip() for tel in stars['TELESCOPE']]
stars['TELESCOPE'] = np.array(_telescopes)
len(stars)

In [None]:
# Photometric quality mask for the parent sample.
# Set `restrict_telescopes = False` to keep stars from every telescope.
restrict_telescopes = True
if restrict_telescopes:
    # Keep only stars whose TELESCOPE is 'apo25m' or 'lco25m'
    # (presumably the APO and LCO 2.5m telescopes).
    phot_mask = np.isin(stars['TELESCOPE'], ['apo25m', 'lco25m'])
else:
    phot_mask = np.ones(len(stars), dtype=bool)

# Require a finite, positive magnitude fainter-limit of 22 in every band
# listed in `phot_names`.
for name in phot_names:
    phot_mask &= (np.isfinite(stars[name]) &
                  (stars[name] > 0) &
                  (stars[name] < 22))  # MAGIC NUMBER: faint magnitude limit

# 2MASS: require a positive uncertainty below 0.1 in each of J, H, K.
# TODO: this assumes 2MASS photometry is in there...
for band in ['J', 'H', 'K']:
    phot_mask &= ((stars[f'{band}_ERR'] > 0) &
                  (stars[f'{band}_ERR'] < 0.1))

# WISE: keep only stars whose `ph_qual` string starts with 'AA'.
# TODO: this assumes WISE photometry is in there...
# NOTE(review): after the left join, stars without a WISE match have masked
# `ph_qual`; confirm those rows evaluate to False here.
phot_mask &= np.char.startswith(stars['ph_qual'].astype(str), 'AA')

In [None]:
# True for stars whose APOGEE_ID is NOT in the binaries catalog.
binaries_mask = np.isin(stars['APOGEE_ID'], binaries['APOGEE_ID'], invert=True)

In [None]:
# Keep spectra with S/N strictly above this (hand-chosen) threshold.
_min_snr = 40
snr_mask = stars['SNR'] > _min_snr

In [None]:
# Cleaned sample: photometric-quality and S/N cuts. The binaries cut is
# currently NOT applied; AND in `binaries_mask` here to enable it.
_keep = phot_mask & snr_mask
clean_stars = stars[_keep]
len(stars), len(clean_stars)

### Color-color plot to prune outliers

In [None]:
# Fit a straight line to J-K as a function of G-J, then use the residuals
# from that line to find stars with anomalous colors.
G_J = clean_stars['GAIAEDR3_PHOT_G_MEAN_MAG'] - clean_stars['J']
J_K = clean_stars['J'] - clean_stars['K']

poly = np.poly1d(np.polyfit(G_J, J_K, deg=1))
xx = np.linspace(0, 10, 25)

# Residual J-K color relative to the linear relation, and its center/spread.
dcolor = J_K - poly(G_J)
dcolor_med = np.median(dcolor)
dcolor_std = np.std(dcolor)
# MAGIC NUMBER: width of the rejection band in units of std; keep in sync with
# the `dcolor_mask` cut that follows this plot.
dcolor_nsigma = 6
# Robust alternative: dcolor_med ± 8 * 1.5 * MAD(dcolor)

# Running median and approximate robust scatter (1.5 * MAD) of the residuals
# in bins of J.
bins = np.linspace(5, 15, 25)
binc = 0.5 * (bins[:-1] + bins[1:])
stat = binned_statistic(clean_stars['J'],
                        dcolor,
                        statistic='median',
                        bins=bins)
stat_std = binned_statistic(clean_stars['J'],
                            dcolor,
                            statistic=lambda x: 1.5 * MAD(x),
                            bins=bins)

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Panel 1: color-color diagram with the linear fit.
ax = axes[0]
ax.plot(xx, poly(xx), color='tab:blue', zorder=10, marker='')
ax.scatter(G_J, J_K,
           c=clean_stars['J'],
           marker='o', alpha=0.4, lw=0, s=4,
           cmap='cividis_r')
ax.set_xlabel('$G - J$')
ax.set_ylabel('$J - K$')

# Panel 2: residuals vs. J with the binned median and ±scatter overlaid.
ax = axes[1]
ax.scatter(clean_stars['J'],
           dcolor,
           marker='o', alpha=0.4, lw=0, s=4)
ax.plot(binc, stat.statistic, color='tab:orange', marker='', zorder=10)
for sign in (-1, 1):
    ax.plot(binc, stat.statistic + sign * stat_std.statistic,
            color='tab:orange', ls='--', marker='', zorder=10)
ax.set_xlabel('$J$')
ax.set_ylabel(r'$\Delta(J - K)$')

# Panel 3: residuals vs. G-J with the outlier-rejection thresholds.
ax = axes[2]
ax.scatter(G_J, dcolor,
           c=clean_stars['J'],
           marker='o', alpha=0.4, lw=0, s=4,
           cmap='cividis_r')
for sign in (-1, 1):
    ax.axhline(dcolor_med + sign * dcolor_nsigma * dcolor_std)
ax.set_xlabel('$G - J$')
ax.set_ylabel(r'$\Delta(J - K)$')

fig.tight_layout()

In [None]:
# Color-outlier rejection: drop stars whose J-K residual lies more than 6 std
# from the median residual, then apply two extra color cuts involving WISE.
# (In-place &= is kept on purpose: the WISE columns may be masked after the
# left join, and `&` vs `&=` treat masked operands differently.)
_center = np.median(dcolor)
_width = 6 * np.std(dcolor)
dcolor_mask = np.abs(dcolor - _center) < _width
dcolor_mask &= (clean_stars['H'] - clean_stars['w2mpro']) > -0.5   # H - W2 cut
dcolor_mask &= (clean_stars['w1mpro'] - clean_stars['w2mpro']) > -1  # W1 - W2 cut
dcolor_mask.sum()

In [None]:
# Photometry / colors:
plot_X = []
labels = []

colors = [
    ('GAIAEDR3_PHOT_BP_MEAN_MAG', 'GAIAEDR3_PHOT_RP_MEAN_MAG'),
    ('J', 'K'),
    ('w1mpro', 'w2mpro'),
    ('GAIAEDR3_PHOT_G_MEAN_MAG', 'J'),
    ('H', 'w2mpro')
]
for i, (p1, p2) in enumerate(colors):
    vals = (clean_stars[p1] - clean_stars[p2])[dcolor_mask]
    plot_X.append(vals)
    
    lbl1 = p1
    if p1 in phot_to_label:
        lbl1 = phot_to_label[p1]
    
    lbl2 = p2
    if p2 in phot_to_label:
        lbl2 = phot_to_label[p2]
    
    lbl = f"{lbl1} $-$ {lbl2}"
    labels.append(lbl)
    
plot_X = np.array(plot_X).T

In [None]:
# Corner plot of the selected colors for the color-outlier-pruned sample.
corner_kwargs = dict(colorbar=True, labels=labels, alpha=0.75)
fig, axes = simple_corner(plot_X, **corner_kwargs)

# White figure background.
fig.set_facecolor('w')

In [None]:
# G-J vs. H-W2 for the color-outlier-pruned sample, colored by SFD_EBV.
G_J = clean_stars['GAIAEDR3_PHOT_G_MEAN_MAG'] - clean_stars['J']
H_W2 = clean_stars['H'] - clean_stars['w2mpro']

# Per-star color uncertainties, used only by the optional errorbar overlay.
# NOTE(review): err1 stands in for the G-J uncertainty but contains only the
# J-band error; no G-band magnitude error enters here.
err1 = clean_stars['J_ERR']
err2 = np.sqrt(clean_stars['H_ERR']**2 + clean_stars['w2mpro_error']**2)

fig, ax = plt.subplots(1, 1, figsize=(7, 6))

cs = ax.scatter(G_J[dcolor_mask], H_W2[dcolor_mask],
                c=clean_stars['SFD_EBV'][dcolor_mask],
                vmin=0, vmax=0.3,
                marker='o', alpha=0.4, lw=0, s=8,
                cmap='cividis_r', zorder=100)

show_errorbars = False  # set True to overlay per-star color uncertainties
if show_errorbars:
    # Same selection as the scatter so points and error bars line up.
    ax.errorbar(G_J[dcolor_mask], H_W2[dcolor_mask],
                xerr=err1[dcolor_mask], yerr=err2[dcolor_mask],
                marker='', ls='none', color='#666666',
                alpha=0.5, elinewidth=0.5)

ax.set_xlabel('$G - J$')
ax.set_ylabel('$H - W2$')

cbar = fig.colorbar(cs)
cbar.set_label('SFD $E(B-V)$')

fig.tight_layout()