In [None]:
from os import path
import astropy.coordinates as coord
from astropy.io import fits
import astropy.units as u
import numpy as np
import matplotlib.pyplot as plt
from cycler import cycler
plt.style.use('apw-notebook')
%matplotlib inline
from sqlalchemy import func
from scipy.ndimage import gaussian_filter

from comoving_rv.db import Session, Base, db_connect
from comoving_rv.db.model import (Run, Observation, TGASSource, SimbadInfo,
                                  SpectralLineInfo, SpectralLineMeasurement)

In [None]:
# Connect to the project's SQLite database and open an ORM session.
# NOTE(review): absolute, machine-specific path — adjust for other setups.
base_path = '/Volumes/ProjectData/gaia-comoving-followup/'
db_path = path.join(base_path, 'db.sqlite')
# db_connect is called before Session() is instantiated — presumably it binds
# the Session factory to this database file (TODO confirm in comoving_rv.db).
engine = db_connect(db_path)
session = Session()

In [None]:
def get_abs_mag(mag, parallax, parallax_error):
    """Absolute magnitude from an apparent magnitude and a parallax.

    The distance is taken as 1 / varpi_hat with
    varpi_hat = varpi/2 * (1 + sqrt(1 - 16/SNR**2)), SNR = varpi / sigma_varpi.
    (This matches the mode of the true-parallax posterior under a prior
    proportional to varpi**-4, i.e. uniform space density — NOTE(review):
    confirm this is the intended estimator.)

    The distance modulus is computed as 5*log10(d / pc) - 5, the same formula
    as ``astropy.coordinates.Distance.distmod``.

    Parameters
    ----------
    mag : float or array_like
        Apparent magnitude.
    parallax : float or array_like
        Parallax in mas.
    parallax_error : float or array_like
        Parallax uncertainty in mas.

    Returns
    -------
    M : float or ndarray
        ``mag`` minus the distance modulus. NaN where the estimator is
        undefined: non-positive parallax, |SNR| < 4, or NaN inputs. (Previously
        a strongly negative parallax produced a negative distance and made
        ``coord.Distance`` raise ``ValueError`` for the whole array.)
    """
    parallax = np.asarray(parallax, dtype=float)
    parallax_error = np.asarray(parallax_error, dtype=float)

    # Division by a zero error / sqrt of an invalid discriminant are handled
    # explicitly below via the `ok` mask, so silence the numpy warnings.
    with np.errstate(divide='ignore', invalid='ignore'):
        snr = parallax / parallax_error
        disc = 1 - 16 / snr**2
        # disc < 0 (|SNR| < 4): no real solution; parallax <= 0: the formula
        # would give a non-physical (negative or infinite) distance.
        ok = (parallax > 0) & (disc >= 0)
        plx_hat = parallax / 2 * (1 + np.sqrt(np.where(ok, disc, np.nan)))
        dist_pc = 1000. / plx_hat  # mas -> pc
        distmod = 5. * np.log10(dist_pc) - 5.

    M = mag - distmod
    return M

In [None]:
# TGAS table with 2MASS J-band photometry (per the filename); the path is
# relative to the notebook's working directory.
tmass_filename = '../../data/tgas_2mass_partial_j.fits.gz'
tmass = fits.getdata(tmass_filename)

In [None]:
# IDs of groups observed in the 'mdm-spring-2017' run that have more than one
# observation whose TGAS source has a J magnitude.
# NOTE(review): group_id 0 and 10 are excluded deliberately; the reason is not
# recorded here — confirm.
# Fixes vs. the earlier version:
#  * TGASSource is now joined via the Observation.tgas_source relationship, so
#    the J filter applies per observation (previously it produced an implicit
#    cartesian product with the TGAS table).
#  * count(), not length(): SQLite length() returns the number of characters of
#    one row's id, not the number of observations in the group.
#  * order_by makes the group order (and hence plot colours) deterministic.
group_ids = session.query(Observation.group_id)\
                   .join(Run)\
                   .join(Observation.tgas_source)\
                   .filter(Run.name == 'mdm-spring-2017')\
                   .filter((Observation.group_id != None) &
                           (Observation.group_id != 0) &
                           (Observation.group_id != 10))\
                   .filter(TGASSource.J != None)\
                   .group_by(Observation.group_id)\
                   .having(func.count(Observation.id) > 1)\
                   .order_by(Observation.group_id)\
                   .all()
group_ids = [x[0] for x in group_ids]
len(group_ids)

In [None]:
# For each group, collect the G-J colour and absolute G magnitude of every
# member observed in the 'mdm-spring-2017' run. A group is kept only if every
# member has both G and J photometry and there are at least two members; the
# kept values are stored as numpy arrays under the keys 'G-J' and 'M_G'.
color_mag = dict()
for gid in group_ids:
    members = (session.query(Observation)
                      .join(Run)
                      .filter(Run.name == 'mdm-spring-2017')
                      .filter(Observation.group_id == gid)
                      .all())

    g_minus_j, abs_g = [], []
    for obs in members:
        tgas = obs.tgas_source
        G, J = tgas.phot_g_mean_mag, tgas.J
        if G is None or J is None:
            break  # any member missing photometry disqualifies the group
        M_G = get_abs_mag(G, tgas.parallax, tgas.parallax_error)
        g_minus_j.append(G - J)
        abs_g.append(M_G)
    else:
        # Only reached when no member was missing photometry.
        if len(g_minus_j) >= 2:
            color_mag[gid] = {'G-J': np.array(g_minus_j),
                              'M_G': np.array(abs_g)}

In [None]:
# Display the record dtype of the loaded table (its field names and types).
tmass.dtype

In [None]:
# Background colour-magnitude density of the full catalogue: G-J colour versus
# absolute G magnitude, binned on a fixed 0.02 mag grid.
G_all = tmass['phot_g_mean_mag']
M_G_all = get_abs_mag(G_all, tmass['parallax'], tmass['parallax_error'])
G_J_all = G_all - tmass['j_m']

bin_width = 0.02
xbins = np.arange(-0.1, 2.3 + 0.01, bin_width)  # G-J edges
ybins = np.arange(-0.5, 8.5 + 0.01, bin_width)  # M_G edges

H, xedges, yedges = np.histogram2d(G_J_all, M_G_all,
                                   bins=(xbins, ybins))

In [None]:
# Line colours for the group tracks in the CMD, applied via the axes
# property cycle.
colors = [
    '#fdae61',  # light orange
    '#7f5abf',  # purple
    '#1a9641',  # green
    '#d7191c',  # red
]

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(6, 6))

# Smooth into a new name: overwriting H made this cell non-idempotent (each
# re-run smoothed the already-smoothed histogram again).
H_smooth = gaussian_filter(H, 1.)
ax.pcolormesh(xedges, yedges, np.log1p(H_smooth.T), cmap='Blues')

# Connect the members of each group with a line (colours from `colors`), then
# overplot the individual stars as small black points on top.
ax.set_prop_cycle(cycler('color', colors))
for d in color_mag.values():
    ax.plot(d['G-J'], d['M_G'], marker='',
            linestyle='-', alpha=0.65, zorder=1)
    ax.plot(d['G-J'], d['M_G'], marker='.',
            linestyle='', alpha=1., color='k', zorder=10, markersize=3)

ax.set_xlim(-0.1, 2.3)
ax.set_ylim(8.5, -0.5)  # inverted so that brighter (smaller M_G) is up

ax.set_xlabel('$G-J$ [mag]')
ax.set_ylabel('$M_G$ [mag]')
# Alternative y-label considered: r'$G - 5(\log\hat{d}-1)$ [mag]'

fig.savefig('sample_cmd.pdf')