In [None]:
import glob
import pathlib

import astropy.coordinates as coord
import astropy.table as at
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np

from pyia import GaiaData

In [None]:
main_filename = pathlib.Path('../../data/TESS-Gaia-main.fits')

# Build the merged catalog once from the individual `gaiatess*` files and cache
# it as a single FITS file; later runs just read the cached file.
if not main_filename.exists():
    # Sort the inputs: glob order is filesystem-dependent, and the row order of
    # the stacked table decides which duplicate source_id survives the
    # "keep first occurrence" de-duplication done later with np.unique.
    chunk_filenames = sorted(glob.glob('../../TESS-Gaia/gaiatess*'))
    if not chunk_filenames:
        raise FileNotFoundError(
            "No input files match '../../TESS-Gaia/gaiatess*'; "
            f"cannot build {main_filename}")

    main_tbl = at.vstack([at.Table.read(filename)
                          for filename in chunk_filenames])

    # Table.write does not create missing directories.
    main_filename.parent.mkdir(parents=True, exist_ok=True)
    main_tbl.write(main_filename)

else:
    main_tbl = at.Table.read(main_filename)

In [None]:
# dr16 = at.Table.read('/Users/apricewhelan/data/APOGEE_DR16/allStarLite-r12-l33.fits')

In [None]:
g = GaiaData(main_tbl)

# Quality cuts: RUWE < 1.4, parallax > 1 mas, and BP/RP magnitudes that are
# not exactly 0 mag.
is_good = ((g.get_ruwe() < 1.4)
           & (g.parallax > 1. * u.mas)
           & (g.phot_bp_mean_mag != 0 * u.mag)
           & (g.phot_rp_mean_mag != 0 * u.mag))
g = g[is_good]

# De-duplicate on source_id. np.unique's return_index gives the index of the
# first occurrence of each unique id, so the result is ordered by source_id.
_, first_idx = np.unique(g.source_id, return_index=True)
g = g[first_idx]

# Keep only low-reddening sources, E(B-V) < 0.1 as returned by get_ebv().
g = g[g.get_ebv() < 0.1]

# Disabled APOGEE DR16 cross-match (needs the commented-out `dr16` table):
# g = g[np.isin(g.source_id, dr16['GAIA_SOURCE_ID'])]

In [None]:
# sub_dr16 = dr16[np.isin(dr16['GAIA_SOURCE_ID'], g.source_id)]
# plt.hist(sub_dr16['M_H'], bins=np.linspace(-2.5, 0.7, 64));

In [None]:
# CMD coordinates from pyia's extinction-corrected photometry. The uncorrected
# variant would use g.phot_g_mean_mag, g.phot_bp_mean_mag, g.phot_rp_mean_mag.
G0 = g.get_G0()
mg = G0 - g.distmod          # absolute G magnitude
BP0, RP0 = g.get_BP0(), g.get_RP0()
bprp = BP0 - RP0             # BP - RP colour

In [None]:
# Colour-magnitude diagram of the filtered sample: log-scaled 2D histogram of
# extinction-corrected BP-RP colour vs. absolute G, with spectral-type colour
# ranges (Mamajek's dwarf table) shaded underneath.
BPRP_LIM = (-1, 5)          # plotted colour range, also the histogram range
MG_LIM = (-6, 15)           # plotted M_G range (bright, faint)
BPRP_BIN = 0.02
MG_BIN = 0.04
SPTYPE_LABEL_Y = -5 + 0.2   # M_G at which the spectral-type letters are drawn

fig, ax = plt.subplots(figsize=(6, 6))

H, xe, ye = np.histogram2d(bprp.value, mg.value,
                           bins=(np.arange(*BPRP_LIM, BPRP_BIN),
                                 np.arange(*MG_LIM, MG_BIN)))
# histogram2d returns H indexed as [x, y]; pcolormesh wants [y, x], hence H.T.
ax.pcolormesh(xe, ye, H.T,
              norm=mpl.colors.LogNorm(),
              cmap='cividis')
ax.set_xlim(*BPRP_LIM)
ax.set_ylim(MG_LIM[1], MG_LIM[0])  # inverted so bright stars are at the top

# The plotted quantities are dereddened (get_BP0/get_RP0/get_G0), hence the _0.
ax.set_xlabel(r'$(G_{\rm BP} - G_{\rm RP})_0$')
ax.set_ylabel(r'$M_{G,0}$')

# from Mamajek: http://www.pas.rochester.edu/~emamajek/EEM_dwarf_UBVIJHK_colors_Teff.txt
sptypes = {
    'B': ['tab:purple', (-1, -0.087)],
    'A': ['tab:blue', (-0.037, 0.327)],
    'F': ['tab:green', (0.377, 0.767)],
    'G': ['y', (0.782, 0.950)],
    'K': ['tab:orange', (0.98, 1.78)],
    'M': ['tab:red', (1.84, 5)]
}
for sptype, (color, (lo, hi)) in sptypes.items():
    ax.axvspan(lo, hi, lw=0, color=color, alpha=0.25, zorder=-100)

    # The B band starts at the left plot edge, so anchor its label to the
    # band's right edge; every other label hangs off the band's left edge.
    if sptype == 'B':
        label_x, label_ha = hi - 0.02, 'right'
    else:
        label_x, label_ha = lo + 0.01, 'left'
    ax.text(label_x, SPTYPE_LABEL_Y,
            sptype, fontsize=16, color=color, alpha=1,
            zorder=10, ha=label_ha, va='bottom')

fig.set_facecolor('w')