In [None]:
import sys
import os
from os import path
_tmp = os.path.abspath('../../pal5s-biggest-fan/notebooks')
if _tmp not in sys.path:
    sys.path.append(_tmp)

import pickle

import emcee
import corner
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline

import astropy.coordinates as coord
from astropy.coordinates.matrix_utilities import rotation_matrix
from astropy.table import Table, vstack, join
from astropy.io import fits, ascii
import astropy.units as u
from astroML.utils import log_multivariate_gaussian
from scipy.optimize import minimize
from scipy.special import logsumexp
from xdgmm import XDGMM
from schwimmbad import MultiPool
from tqdm import trange

import gala.coordinates as gc
import gala.dynamics as gd
import gala.potential as gp
import gala.mpl_style

from pyia import GaiaData

# Galactocentric frame used for the Galactic-position transforms further down;
# only the Sun-to-Galactic-center distance (8.1 kpc) is overridden here.
galcen_frame = coord.Galactocentric(galcen_distance=8.1*u.kpc)

In [None]:
# Sky window (RA, Dec) and maximum distance used by the sample selections below.
ra_lim = (215, 255) * u.deg
dec_lim = (-15, 10) * u.deg
dist_lim = 35. * u.kpc

In [None]:
# Globular-cluster catalog, split into the Pal 5 row and every other cluster
# that lies inside the RA/Dec window (those are masked out by skip_mask()).
globs = Table.read('/Users/apricewhelan/data/Misc/Vasiliev-globclust.txt', format='ascii.fixed_width')

_is_pal5 = globs['Name'] == 'Pal 5'
_in_ra_window = (globs['RA'] > ra_lim[0].value) & (globs['RA'] < ra_lim[1].value)
_in_dec_window = (globs['DEC'] > dec_lim[0].value) & (globs['DEC'] < dec_lim[1].value)

skip_globs = globs[_in_ra_window & _in_dec_window & ~_is_pal5]
vasiliev_pal5 = globs[_is_pal5]

In [None]:
# Show the catalog row(s) selected for Pal 5.
vasiliev_pal5

In [None]:
# Catalog position and proper motion of Pal 5, expressed in the Pal 5 stream frame.
_pal5_icrs = coord.SkyCoord(
    ra=vasiliev_pal5['RA'] * u.deg,
    dec=vasiliev_pal5['DEC'] * u.deg,
    pm_ra_cosdec=vasiliev_pal5['PMRA'] * u.mas/u.yr,
    pm_dec=vasiliev_pal5['PMDEC'] * u.mas/u.yr,
)
# [0] takes the first (presumably only) matching row so the result is a scalar coordinate.
vasiliev_pal5_c = _pal5_icrs.transform_to(gc.Pal5)[0]

In [None]:
def skip_mask(ra, dec):
    """Flag positions that are clear of the other globular clusters.

    Parameters
    ----------
    ra, dec
        Coordinates accepted by ``coord.SkyCoord(ra, dec)``; ``len(ra)`` sets
        the length of the returned mask.

    Returns
    -------
    Boolean array, True where the position is more than 0.5 deg from every
    cluster listed in the module-level ``skip_globs`` table.
    """
    targets = coord.SkyCoord(ra, dec)
    clusters = coord.SkyCoord(skip_globs['RA']*u.deg, skip_globs['DEC']*u.deg)

    keep = np.ones(len(ra), dtype=bool)
    for cluster in clusters:
        far_enough = targets.separation(cluster) > 0.5*u.deg
        keep = keep & far_enough

    return keep

In [None]:
# Vertices (ICRS RA/Dec) of the Pal 5 sky path; the matplotlib Path built from
# them is used below via contains_points() to flag stars inside the path.
sky_path = Table.read(
    '../../pal5s-biggest-fan/data/pal5_skypath_icrs.txt',
    format='ascii.basic',
    names=['ra', 'dec'],
)
_path_vertices = np.column_stack((sky_path['ra'], sky_path['dec']))
_path = mpl.path.Path(_path_vertices)

### Load data

In [None]:
# Read the input catalog and give the position columns the plain names
# ('ra', 'dec') that the rest of the notebook uses.
_tbl = Table.read('../data/PS1+SOS.gdr2.newDps1.C10QUESTIDs.csv.gz', format='ascii.csv')

for _old_name, _new_name in (('ra_1', 'ra'), ('dec_1', 'dec')):
    _tbl.rename_column(_old_name, _new_name)

In [None]:
# Stars from the "inside canonical footprint" file; rows whose proper motions
# are exactly zero are dropped (presumably placeholders for missing values).
g_stream = GaiaData('../data/pal5_rrls_inside_canonical_footprint.D18_25_pm_cuts.csv')
g_stream = g_stream[(g_stream.pmra.value != 0.) & (g_stream.pmdec.value != 0.)]

# Full catalog restricted to: the RA/Dec window, positions away from the other
# globular clusters, finite and non-zero proper motions, positive D_ps1, and
# D_kpc below the distance limit.
# NOTE(review): the distance cut uses D_kpc while the distances attached to the
# coordinates use D_ps1 - confirm this mix is intended.
g_all = GaiaData(_tbl)
g_all = g_all[(g_all.ra > ra_lim[0]) & (g_all.ra < ra_lim[1]) &
              (g_all.dec > dec_lim[0]) & (g_all.dec < dec_lim[1]) &
              skip_mask(g_all.ra, g_all.dec) & 
              np.isfinite(g_all.pmra) & np.isfinite(g_all.pmdec) &
              (g_all.pmra.value != 0.) & (g_all.pmdec.value != 0.) & 
              (g_all.D_ps1 > 0) & 
              (g_all.D_kpc < dist_lim.value)] 
c_all = g_all.get_skycoord(distance=g_all.D_ps1*u.kpc)

# Background ("no stream") sample, with the same cluster / proper-motion /
# distance cuts, plus its coordinates in the Pal 5 stream frame.
g_nostream = GaiaData('../data/rrls_in_pal5_bkg.m5_ngc5634_removed.csv')
g_nostream = g_nostream[skip_mask(g_nostream.ra, g_nostream.dec) & 
                        (g_nostream.pmra.value != 0.) & (g_nostream.pmdec.value != 0.) & 
                        (g_nostream.D_ps1 > 0) & 
                        (g_nostream.D_kpc < dist_lim.value)]
c_nostream = g_nostream.get_skycoord(distance=g_nostream.D_ps1*u.kpc)
c_nostream_pal5 = c_nostream.transform_to(gc.Pal5)

# Subset of g_all whose (RA, Dec) falls inside the sky path.
# Note: stream_sky_mask is re-computed for the merged sample in a later cell.
stream_sky_mask = _path.contains_points(np.stack((g_all.ra.value, 
                                                  g_all.dec.value)).T)
g_stream_track = g_all[stream_sky_mask]
c_stream_track = g_stream_track.get_skycoord(distance=g_stream_track.D_ps1*u.kpc)

# Sample sizes: (all, inside sky path, background).
len(g_all), len(g_stream_track), len(g_nostream)

In [None]:
# Quick-look figures for the background sample as a function of RA:
# pmdec, pmra and distance (D_ps1), one figure each with its own y-range.
_quicklook_panels = [
    (g_nostream.pmdec, (-25, 10)),
    (g_nostream.pmra, (-25, 10)),
    (g_nostream.D_ps1, (0, 35)),
]

for _yvals, _ylim in _quicklook_panels:
    plt.figure()
    plt.plot(g_nostream.ra, _yvals,
             ls='none', marker='o', mew=0, ms=1.5)
    plt.ylim(*_ylim)

In [None]:
# Two Odenkirchen tables (2009 and 2002, per the file names), stacked into a
# single GaiaData object after dropping three columns from each.
_dropped_cols = ['epoch_photometry_url', 'astrometric_primary_flag', 'duplicated_source']
t1 = Table.read('/Users/apricewhelan/data/Streams/Pal5/Odenkirchen2009_gaia.csv')
t2 = Table.read('/Users/apricewhelan/data/Streams/Pal5/Odenkirchen2002_gaia.csv')
for _t in (t1, t2):
    _t.remove_columns(_dropped_cols)
g_oden = GaiaData(vstack((t1, t2)))

# Two successive cuts: a loose velocity + proper-motion cut, then a tighter
# velocity window combined with an RA limit.
_loose_cut = (g_oden.vr_a > -80) & (g_oden.vr_a < -20) & (g_oden.pmra > -5*u.mas/u.yr)
g_oden = g_oden[_loose_cut]
_tight_cut = (g_oden.vr_a > -66) & (g_oden.vr_a < -50) & (g_oden.ra < 232*u.deg)
g_oden = g_oden[_tight_cut]

# These stars have no distance of their own: assign a fixed 21 kpc to both
# distance columns so they can be stacked with the main catalog.
for _dist_col in ('D_kpc', 'D_ps1'):
    g_oden.data[_dist_col] = 21. # HACK
# TODO: could xmatch to starhorse to get distances for the RGB stars...or fit my own

In [None]:
# Stack g_all and the Odenkirchen stars (g_all rows first), then flag which
# merged stars fall inside the sky path.
g_merged = GaiaData(vstack((g_all.data, g_oden.data)))

_merged_radec = np.column_stack((g_merged.ra.value, g_merged.dec.value))
merged_sky_mask = _path.contains_points(_merged_radec)

In [None]:
# Inside-path flag for the merged sample. This replaces the earlier
# stream_sky_mask, which was computed for g_all only.
_radec = np.column_stack((g_merged.ra.value, g_merged.dec.value))
stream_sky_mask = _path.contains_points(_radec)
stream_sky_mask.sum()

In [None]:
# Overview versus RA: proper motions, radial velocity and distance.
fig, axes = plt.subplots(4, 1, figsize=(8, 10), sharex=True)

_points = dict(ls='none', marker='o')

# Proper-motion panels. g_stream is drawn before g_oden on each axes, so the
# two samples get the same default colors in both panels.
for _ax, _pm_name in zip(axes[:2], ('pmra', 'pmdec')):
    for _sample in (g_stream, g_oden):
        _ax.errorbar(_sample.ra.value, getattr(_sample, _pm_name).value,
                     yerr=getattr(_sample, _pm_name + '_error').value,
                     **_points)

# Radial velocities are plotted for the Odenkirchen stars only.
axes[2].errorbar(g_oden.ra.value, g_oden.vr_a,
                 yerr=g_oden.vr_a_err,
                 color='tab:blue', **_points)

# Distances are plotted for g_stream only, with a 3% fractional error bar.
axes[3].errorbar(g_stream.ra.value, g_stream.D_kpc,
                 yerr=0.03 * g_stream.D_kpc,
                 **_points)

# RA increases to the left.
axes[0].set_xlim(250, 215)

_ylabels = (r'$\mu_\alpha$', r'$\mu_\delta$', r'$v_r$', r'$D$ [kpc]')
for _ax, _ylabel in zip(axes, _ylabels):
    _ax.set_ylabel(_ylabel)
axes[-1].set_xlabel('RA [deg]')

fig.set_facecolor('w')
fig.tight_layout()

In [None]:
# Coordinates of both samples in the Pal 5 stream frame. g_stream carries its
# D_kpc distances; the Odenkirchen stars are transformed without a distance.
c_stream = g_stream.get_skycoord(distance=g_stream.D_kpc*u.kpc).transform_to(gc.Pal5)
c_oden = g_oden.get_skycoord(distance=False).transform_to(gc.Pal5)

In [None]:
# Proper-motion plane for g_stream (left) and g_oden (right), colored by RA,
# with error bars drawn underneath the points.
fig, axes = plt.subplots(1, 2, figsize=(10, 5),
                         sharex=True, sharey=True)

for _ax, _sample in zip(axes, (g_stream, g_oden)):
    _pmra = _sample.pmra.value
    _pmdec = _sample.pmdec.value
    _ax.scatter(_pmra, _pmdec, c=_sample.ra.value)
    _ax.errorbar(_pmra, _pmdec,
                 xerr=_sample.pmra_error.value,
                 yerr=_sample.pmdec_error.value,
                 marker='', ls='', zorder=-100, ecolor='#aaaaaa')

# Axes are shared, so setting the limits once applies to both panels.
axes[0].set_xlim(-3.7, -2.0)
axes[0].set_ylim(-3.5, -1.5)

In [None]:
# Proper motions of stars within 5 kpc of 22 kpc: full sample (left) versus the
# background sample (right).
fig, axes = plt.subplots(1, 2, figsize=(10, 5.2),
                         sharex=True, sharey=True)

# (sample, distance column used for the cut, panel title).
# The left panel cuts on D_ps1, the right panel on D_kpc.
_panel_specs = [
    (g_all, g_all.D_ps1, 'all w/ dist cut'),
    (g_nostream, g_nostream.D_kpc, 'excluding pal5 footprint'),
]
for ax, (_sample, _dist, _title) in zip(axes, _panel_specs):
    _d_mask = np.abs(_dist - 22.) < 5
    ax.plot(_sample.pmra.value[_d_mask],
            _sample.pmdec.value[_d_mask],
            marker='o', ls='', ms=1.5, alpha=1.)
    ax.set_title(_title)

# Shared axes: limits set on the last panel apply to both.
ax.set_xlim(-5, 5)
ax.set_ylim(-5, 5)

for _ax in axes:
    _ax.set_xlabel(r'$\mu_\alpha$')
axes[0].set_ylabel(r'$\mu_\delta$')

fig.set_facecolor('w')
fig.tight_layout()

In [None]:
# huh = g_all[(np.abs(g_all.D_kpc - 22.) < 5) & 
#             (np.abs(g_all.pmra.value - 2) < 1) & 
#             (np.abs(g_all.pmdec.value - -1) < 1)]
# plt.scatter(huh.ra, huh.dec)
# plt.xlim(250, 215)
# plt.ylim(-10, 10)

In [None]:
# Training data for the background model: one row per background star with
# (pm_phi1*cos(phi2), pm_phi2, distance).
_bg_pm1 = c_nostream_pal5.pm_phi1_cosphi2.value
_bg_pm2 = c_nostream_pal5.pm_phi2.value
X = np.stack((_bg_pm1, _bg_pm2, g_nostream.D_ps1)).T

# Proper-motion covariance (the [3:5, 3:5] sub-block of get_cov()) transformed
# to the stream frame.
_bg_pm_cov = g_nostream.get_cov()[:, 3:5, 3:5]
C_pm = gc.transform_pm_cov(c_nostream, _bg_pm_cov, gc.Pal5)

# Full 3x3 covariance per star; the distance variance assumes an 8% fractional error.
# NOTE(review): other cells use 3% for distance uncertainties - confirm 8% is intended here.
C = np.zeros((len(C_pm), 3, 3))
C[:, :2, :2] = C_pm
C[:, 2, 2] = (0.08 * X[:, 2]) ** 2

In [None]:
# Distance distribution (third column of X) of the background sample.
plt.hist(X[:, 2], bins=64);

In [None]:
# xdgmm = XDGMM()

# param_range = np.arange(5, 10+1, 1)

# # Loop over component numbers, fitting XDGMM model and computing the BIC:
# bic, optimal_n_comp, lowest_bic = xdgmm.bic_test(X, C, param_range)

# fig, ax = plt.subplots(1, 1, figsize=(6, 6))
# ax.plot(param_range, bic, marker='', drawstyle='steps-mid')

# Number of mixture components for the background model, hard-coded here.
# See BIC stuff above (the commented-out scan over 5-10 components).
optimal_n_comp = 6

### GMM / XD

In [None]:
# Fit the background mixture model once and cache it on disk; later runs
# re-use the cached fit. Delete the cache file to force a re-fit.
gmm_cache_file = 'bg_gmm_pal5.pkl'
if os.path.exists(gmm_cache_file):
    # Unpickling runs arbitrary code: only load cache files you created yourself.
    with open(gmm_cache_file, 'rb') as f:
        gmm = pickle.load(f)
else:
    gmm = XDGMM(n_components=optimal_n_comp, tol=1e-8, n_iter=2048, method='Bovy')
    _ = gmm.fit(X, C)
    with open(gmm_cache_file, 'wb') as f:
        pickle.dump(gmm, f)

In [None]:
# Second XDGMM object built directly from the fitted parameters of `gmm`
# (same components, means, covariances and weights). This is the object passed
# to the sampler and likelihood calls below.
# NOTE(review): the name suggests it evaluates faster than `gmm`; not verified here.
faster_gmm = XDGMM(n_components=gmm.n_components, 
                   mu=gmm.mu, V=gmm.V, weights=gmm.weights, 
                   method='Bovy')

In [None]:
# Draw as many mock points from the fitted background model as there are
# stars in g_all (random draw; no seed is set in this notebook).
X_sample = gmm.sample(size=len(g_all))

In [None]:
# Merged sample (g_all rows first, then the Odenkirchen stars) and its
# coordinates in the Pal 5 stream frame.
g_merged = GaiaData(vstack((g_all.data, g_oden.data)))
merged_c = g_merged.get_skycoord(distance=g_merged.D_ps1*u.kpc)
merged_c_pal5 = merged_c.transform_to(gc.Pal5)

# Data vector per star: (pm_phi1*cos(phi2), pm_phi2, distance).
# X_all = np.stack((g_all.pmra.value, g_all.pmdec.value, g_all.D_ps1)).T
X_all = np.stack((merged_c_pal5.pm_phi1_cosphi2.value, 
                  merged_c_pal5.pm_phi2.value, 
                  merged_c_pal5.distance.value)).T
# Per-star 3x3 covariance: transformed proper-motion block plus a distance variance.
# NOTE(review): the frame is written gc.Pal5PriceWhelan18 here but gc.Pal5
# elsewhere in the notebook - confirm both names refer to the same frame.
C_pm = g_merged.get_cov()[:, 3:5, 3:5]
C_all = np.zeros((C_pm.shape[0], 3, 3))
C_all[:, :2, :2] = gc.transform_pm_cov(merged_c, C_pm, gc.Pal5PriceWhelan18)
# First len(g_all) rows come from g_all: 3% fractional distance error.
C_all[:len(g_all), 2, 2] = (0.03 * X_all[:len(g_all), 2]) ** 2
# Remaining rows are the Odenkirchen stars: fixed distance sigma of 10
# (same units as X_all[:, 2]), since their distance was hard-coded above.
C_all[len(g_all):, 2, 2] = 10 ** 2  # RGB

# X_oden = np.stack((g_oden.pmra.value, g_oden.pmdec.value, np.full(len(g_oden), 22.5))).T
# _C = g_oden.get_cov()[:, 3:5, 3:5]
# C_oden = np.zeros((_C.shape[0], 3, 3))
# C_oden[:, :2, :2] = _C
# C_oden[:, 2, 2] = 8. ** 2

# XX = np.vstack((X_all, X_oden))
# CC = np.vstack((C_all, C_oden))

# Names used by the likelihood / sampler cells.
XX = X_all
CC = C_all

# phi1 offset of every star from the reference point phi1_0 (degrees).
phi1_0 = 0.
dphi1 = merged_c_pal5.phi1.degree - phi1_0

In [None]:
# Distance histograms: background-model draws versus the merged data, using the
# bin edges (`bib`) returned by the first histogram for both.
_, bib, _ = plt.hist(X_sample[:, 2], bins=64);
plt.hist(X_all[:, 2], bins=bib, alpha=0.5);

In [None]:
# Proper-motion plane, left to right: background training data, draws from the
# fitted background model, and the merged data.
_pm_datasets = (X, X_sample, X_all)

fig, axes = plt.subplots(1, 3, figsize=(15, 5),
                         sharex=True, sharey=True)
for ax, _data in zip(axes, _pm_datasets):
    ax.plot(_data[:, 0], _data[:, 1],
            marker='o', ls='', ms=1.5, alpha=0.5)

# Shared axes: limits set on the last panel apply to all three.
ax.set_xlim(0, 10)
ax.set_ylim(-5, 5)

# --
# Same comparison, restricted to points within 4 kpc of 22 kpc.

fig, axes = plt.subplots(1, 3, figsize=(15, 5),
                         sharex=True, sharey=True)
for ax, _data in zip(axes, _pm_datasets):
    _near_22 = np.abs(_data[:, 2] - 22) < 4
    ax.plot(_data[_near_22, 0],
            _data[_near_22, 1],
            marker='o', ls='', ms=2.5, alpha=1)

# Mark the catalog proper motion of Pal 5 on the data panel.
ax.scatter(vasiliev_pal5_c.pm_phi1_cosphi2,
           vasiliev_pal5_c.pm_phi2,
           color='tab:red', zorder=100)

ax.set_xlim(0, 10)
ax.set_ylim(-5, 5)

### Model specification

In [None]:
# from fitting to model
# Each array is ordered (quadratic, linear, constant): the MCMC initial guess
# below takes index 1 as the linear term and index 0 as the quadratic term.
model_pm1_coeffs = np.array([1.64332475e-03, 3.11068762e-02, 3.07719668e+00])
model_pm2_coeffs = np.array([0.00083942, 0.03776296, 0.64002269])
model_dist_coeffs = np.array([-1.01983328e-02, -1.91206970e-01,  2.24739655e+01])

In [None]:
def ln_normal(x, mu, var):
    """Log-density of a 1-D normal distribution.

    Parameters
    ----------
    x : value(s) at which to evaluate.
    mu : mean.
    var : variance (not the standard deviation).

    Returns
    -------
    ln N(x | mu, var), broadcast over the inputs.
    """
    const_term = -0.5*np.log(2*np.pi)
    width_term = 0.5*np.log(var)
    chi2_term = 0.5 * (x-mu)**2 / var
    return const_term - width_term - chi2_term

# Quadratic stream track:
def ln_prior(p):
    """Log-prior for the quadratic stream-track mixture model.

    ``p`` unpacks to 13 values: the track values (pm1, pm2, dist), linear terms
    (b_*), quadratic terms (c_*), pivot locations in phi1 (x_*), and ``lnf``,
    the log of the stream fraction.

    Returns ``-np.inf`` when a hard bound is violated, otherwise the summed
    log-prior.
    """
    (pm1, pm2, dist,
     b_pm1, b_pm2, b_dist,
     c_pm1, c_pm2, c_dist,
     x_pm1, x_pm2, x_dist,
     lnf) = p

    # Hard bounds: proper motions in (-5, 5), pivots in (-20, 5), and f <= 1.
    if not all(-5 < value < 5 for value in (pm1, pm2)):
        return -np.inf
    if not all(-20 < value < 5 for value in (x_pm1, x_pm2, x_dist)):
        return -np.inf
    if lnf > 0:
        return -np.inf

    # Gaussian terms as (value, mean, variance), accumulated in a fixed order.
    gaussian_terms = [
        # linear coefficients: zero-mean, variance 1e-2
        (b_pm1, 0, 1e-2), (b_pm2, 0, 1e-2), (b_dist, 0, 1e-2),
        # quadratic coefficients: zero-mean, variance 1e-3 (sign is not constrained)
        (c_pm1, 0, 1e-3), (c_pm2, 0, 1e-3), (c_dist, 0, 1e-3),
        # track values: centered on the catalog proper motion of Pal 5 ...
        (pm1, vasiliev_pal5_c.pm_phi1_cosphi2.value, 0.1),  # Vasiliev
        (pm2, vasiliev_pal5_c.pm_phi2.value, 0.1),  # Vasiliev
        # ... and on a distance of 21.1 with unit variance
        (dist, 21.1, 1.),  # me
    ]

    lp = 0
    for value, mean, var in gaussian_terms:
        lp += ln_normal(value, mean, var)

    # uniform in f: the change of variables to lnf contributes +lnf
    lp += lnf

    return lp

def ln_likelihood(p, gmm, X, Cov, dphi1, bg_prob):
    """Mixture log-likelihood: stream track plus a fixed background model.

    Parameters
    ----------
    p : 13 model parameters, unpacked in the same order as in ``ln_prior``.
    gmm : unused in this function; kept for a stable call signature.
    X : data rows of (pm1, pm2, distance).
    Cov : per-row covariance matrices, added to the intrinsic stream covariance.
    dphi1 : phi1 offset of each row, used to evaluate the quadratic track.
    bg_prob : precomputed background log-probability of each row.

    Returns
    -------
    (total, (ll_bg, ll_fg)) : the summed log-likelihood and the per-row
    background and stream terms, each already weighted by (1-f) and f.
    """
    (pm1, pm2, dist,
     b_pm1, b_pm2, b_dist,
     c_pm1, c_pm2, c_dist,
     x_pm1, x_pm2, x_dist,
     lnf) = p
    f = np.exp(lnf)

    def _track(value0, slope, curvature, pivot):
        # Quadratic in phi1 about `pivot`.
        return value0 + slope * (dphi1 - pivot) + curvature * (dphi1 - pivot)**2

    mu = np.zeros_like(X)
    mu[:, 0] = _track(pm1, b_pm1, c_pm1, x_pm1)
    mu[:, 1] = _track(pm2, b_pm2, c_pm2, x_pm2)
    mu[:, 2] = _track(dist, b_dist, c_dist, x_dist)

    # Fixed intrinsic stream width (diagonal sigmas 0.02, 0.02, 0.5), squared to a covariance.
    V = np.diag([0.02, 0.02, 0.5]) ** 2 # HACK: MAGIC NUMBERs

    total_cov = Cov + V
    stream_logprob = log_multivariate_gaussian(X, mu, total_cov)

    ll_bg = bg_prob + np.log(1-f)
    ll_fg = stream_logprob + np.log(f)

    return np.logaddexp(ll_bg, ll_fg).sum(), (ll_bg, ll_fg)

def ln_posterior(p, gmm, X, Cov, dphi1, bg_prob, null_blob, blobs=True):
    """Log-posterior = ln_prior + ln_likelihood, with optional blobs.

    When ``blobs`` is True the return value is ``(lnprob, blob)``, where
    ``blob`` is the per-row (background, stream) tuple from ``ln_likelihood``
    or ``null_blob`` if the point is rejected. When False, only ``lnprob`` is
    returned. A non-finite prior or likelihood gives ``-np.inf``.
    """
    rejected = (-np.inf, null_blob) if blobs else -np.inf

    lnp = ln_prior(p)
    if not np.isfinite(lnp):
        return rejected

    lnl, blob = ln_likelihood(p, gmm, X, Cov, dphi1, bg_prob)
    if not np.isfinite(lnl):
        return rejected

    total = lnp + lnl
    return (total, blob) if blobs else total

In [None]:
# Run (or load) the MCMC over the stream-track mixture model.
#   - no cache file:               initialize, burn in, sample, then cache the sampler
#   - cache file, continue_=False: just load the cached sampler
#   - cache file, continue_=True:  start a fresh run from the cached walkers' last
#                                  positions (the new run overwrites the cache)
filename = 'membership_samples-quadrat-in-track-pm12.pkl'
continue_ = False

# QUADRATIC MODEL: initial guess, in the order unpacked by ln_prior / ln_likelihood:
# (pm1, pm2, dist), linear terms, quadratic terms, pivots, lnf.
p0 = [vasiliev_pal5_c.pm_phi1_cosphi2.value, vasiliev_pal5_c.pm_phi2.value, 21.10097608,
      model_pm1_coeffs[1], model_pm2_coeffs[1], model_dist_coeffs[1],
      model_pm1_coeffs[0], model_pm2_coeffs[0], model_dist_coeffs[0],
      -0.1, -10, -0.1,
      -1.5]
x = p0

# Only stars inside the sky path are passed to the sampler.
null_blob = (np.full(len(XX[stream_sky_mask]), -np.inf),
             np.full(len(XX[stream_sky_mask]), -np.inf))
# Background log-probability per star, summed over mixture components.
bg_prob = logsumexp(gmm.logprob_a(XX, CC), axis=-1)

nwalkers = 8 * len(x)
pre_loaded = path.exists(filename)
if pre_loaded:
    # Unpickling runs arbitrary code: only load cache files you created yourself.
    with open(filename, 'rb') as f:
        sampler = pickle.load(f)
    # Resume point: the last stored position of every walker, shape (nwalkers, ndim).
    # Previously the 1-D initial guess was passed to run_mcmc here, which is
    # not a valid set of walker positions.
    p0 = sampler.chain[:, -1]
else:
    p0 = emcee.utils.sample_ball(x, 1e-3 * np.abs(x), size=nwalkers)

if not pre_loaded or continue_:
    with MultiPool() as pool:
        sampler = emcee.EnsembleSampler(
            nwalkers, len(x), ln_posterior,
            args=(faster_gmm, XX[stream_sky_mask], CC[stream_sky_mask],
                  dphi1[stream_sky_mask], bg_prob[stream_sky_mask],
                  null_blob, False),
            pool=pool)

        if pre_loaded:
            pos = p0

        else:
            pos, prob, state = sampler.run_mcmc(p0, 1024, progress=True)
            print('done initialization')

            # Re-initialize in a small ball around the median position, then burn in.
            pos0 = emcee.utils.sample_ball(np.median(pos, axis=0),
                                           1e-3 * np.abs(x),
                                           size=nwalkers)
            sampler.reset()
            pos, prob, state = sampler.run_mcmc(pos0, 1024, progress=True)
            print('done burn-in')

            sampler.reset()

        pos, prob, state = sampler.run_mcmc(pos, 131072, progress=True)

    # Clear these two attributes before pickling the sampler.
    sampler.lnprobfn = None
    sampler.pool = None
    with open(filename, 'wb') as f:
        pickle.dump(sampler, f)

In [None]:
# One trace plot per parameter: every walker's chain, plus a horizontal line at
# the median of the walkers' final positions.
for k in range(sampler.ndim):
    fig = plt.figure()
    _traces = sampler.chain[..., k]
    plt.plot(_traces.T,
             marker='', color='k',
             drawstyle='steps-mid', alpha=0.1, lw=0.5)

    plt.axhline(np.median(_traces[:, -1], axis=0))

In [None]:
# acor = sampler.acor

In [None]:
# Implied number of stream members among the in-path stars, N_in_path * f, for
# thinned posterior samples of the last parameter (lnf), taking every 8th step.
len(XX[stream_sky_mask]) * np.concatenate(np.exp(sampler.chain[:, ::8, -1]))

In [None]:
# Membership probability of every merged star, averaged over the final two
# steps of every walker: P = exp(ll_fg - logaddexp(ll_fg, ll_bg)).
_chain = sampler.chain[:, -2:]
# Flatten to (n_samples, ndim), ordered step-major then walker.
_last_samples = np.swapaxes(_chain, 0, 1).reshape(-1, _chain.shape[-1])

post_prob = np.zeros(XX.shape[0])
for _sample in _last_samples:
    _, (ll_bg, ll_fg) = ln_likelihood(_sample, faster_gmm, XX, CC, dphi1, bg_prob)
    post_prob += np.exp(ll_fg - np.logaddexp(ll_fg, ll_bg))

norm = float(len(_last_samples))
post_prob /= norm
post_prob.sum()

In [None]:
# Summed membership probability of in-path stars, excluding any star whose
# source_id appears in the Odenkirchen sample.
post_prob[stream_sky_mask & (~np.isin(g_merged.source_id, g_oden.source_id))].sum()

In [None]:
# Thin the chain (every 1024th step) and flatten walkers x steps into rows.
_thinned = sampler.chain[:, ::1024]
flatchain = _thinned.reshape(-1, _thinned.shape[-1])
flatchain.shape

In [None]:
# Posterior median of each parameter from the thinned, flattened chain.
np.median(flatchain, axis=0)

In [None]:
# Corner plot of all sampled parameters.
_ = corner.corner(flatchain)

### Corner plot for paper:

In [None]:
# Posterior samples of the first three parameters (pm1, pm2, dist).
flat_pm1_samples = flatchain[:, 0]
flat_pm2_samples = flatchain[:, 1]
flat_dist_samples = flatchain[:, 2]

# Place the samples at the stream-frame origin (phi1 = phi2 = 0) and convert to
# ICRS so the proper motions can be compared in (pm_ra_cosdec, pm_dec).
flat_samples_c = coord.SkyCoord(phi1=0*u.deg, phi2=0*u.deg, 
                                distance=flat_dist_samples*u.kpc, frame=gc.Pal5,
                                pm_phi1_cosphi2=flat_pm1_samples*u.mas/u.yr,
                                pm_phi2=flat_pm2_samples*u.mas/u.yr).icrs

# Columns: (distance, pm_ra_cosdec, pm_dec).
flat_samples_X = np.stack((flat_samples_c.distance.value,
                           flat_samples_c.pm_ra_cosdec.value,
                           flat_samples_c.pm_dec.value)).T

# Samples for the comparison ("prior") distribution in the same three columns.
# `fac` only inflates the distance sigma now; the catalog-error lines are commented out.
fac = 3
# era = fac * vasiliev_pal5['ePMRA']
# edec = fac * vasiliev_pal5['ePMDEC']
# Proper-motion variance of 0.1 per component, matching the variance passed to
# ln_normal for pm1/pm2 in ln_prior, with the catalog correlation coefficient.
era = np.sqrt([0.1])
edec = np.sqrt([0.1])
_cross = era * edec * vasiliev_pal5['corrPM']
_cov = np.array([[era**2, _cross],
                 [_cross, edec**2]])[..., 0]
_mu = np.array([vasiliev_pal5['PMRA'], vasiliev_pal5['PMDEC']])[:, 0]
_pm_samples = np.random.multivariate_normal(_mu, _cov, size=1000000).T
# NOTE(review): distances are drawn from N(23.6, (3*0.8)^2), whereas ln_prior
# uses N(21.1, var=1) - confirm which distribution this figure should show.
# No random seed is set, so these draws differ from run to run.
_d_samples = np.random.normal(23.6, fac * 0.8, size=_pm_samples.shape[1])
flat_prior_samples = np.stack((_d_samples, _pm_samples[0], _pm_samples[1])).T

In [None]:
# Contour levels enclosing the 1- and 2-sigma mass of a 2-D Gaussian.
_sigma_levels = (1-np.exp(-0.5), 1-np.exp(-0.5*(2**2)))

# Comparison distribution first (green), then the posterior overplotted on the
# same figure. Each call gets its own hist_kwargs dict.
fig = corner.corner(flat_prior_samples,
                    color='tab:green',
                    bins=50,
                    levels=_sigma_levels,
                    hist_kwargs=dict(density=True),
                    plot_density=False,
                    plot_datapoints=False)
fig = corner.corner(flat_samples_X,
                    fig=fig,
                    bins=32,
                    range=[[15, 32], [-3.7, -1.7], [-3.7, -1.7]],
                    levels=_sigma_levels,
                    hist_kwargs=dict(density=True),
                    plot_density=False,
                    plot_datapoints=False)

### Write prob table

In [None]:
# Output table: merged catalog plus membership probability and the inside-path
# flag, keeping only stars with P > 0.01 that are not in the Odenkirchen sample.
prob_tbl = g_merged.data.copy()
prob_tbl['member_prob'] = post_prob
prob_tbl['inside_stream_track'] = merged_sky_mask

_not_oden = ~np.isin(g_merged.source_id, g_oden.source_id)
_keep_rows = (post_prob > 0.01) & _not_oden
prob_tbl = prob_tbl[_keep_rows]
# prob_tbl.write('../data/RRL-with-prob.csv', overwrite=True)
# prob_tbl.write('../data/RRL-with-prob.fits', overwrite=True)
len(prob_tbl)

In [None]:
# Number of table rows with membership probability > 0.5 that are inside the sky path.
((prob_tbl['member_prob'] > 0.5) & prob_tbl['inside_stream_track']).sum()

In [None]:
# Maps of stars with membership probability above min_prob, colored by that
# probability: first in ICRS (RA, Dec), then in the Pal 5 stream frame.
min_prob = 0.05
_shown = post_prob > min_prob

plt.figure(figsize=(8, 6))

plt.scatter(g_merged.ra[_shown],
            g_merged.dec[_shown],
            c=post_prob[_shown],
            cmap='magma_r', alpha=0.45, linewidth=0)

# Overlay the sky path.
plt.plot(sky_path['ra'], sky_path['dec'], marker='')

# RA increases to the left.
plt.xlim(250, 215)
plt.ylim(-10, 10)

plt.xlabel('RA')
plt.ylabel('Dec.')

# ---

plt.figure(figsize=(8, 6))

plt.scatter(merged_c_pal5.phi1.degree[_shown],
            merged_c_pal5.phi2.degree[_shown],
            c=post_prob[_shown],
            cmap='magma_r', alpha=0.75, linewidth=0)

plt.xlim(-20, 15)
plt.ylim(-10, 10)

# This panel is in stream-frame coordinates, so label it phi1/phi2
# (it was previously labelled RA/Dec).
plt.xlabel(r'$\phi_1$ [deg]')
plt.ylabel(r'$\phi_2$ [deg]')

### Plot sky position over MS stars

In [None]:
# Photometric catalog and its color-magnitude masks; keep only rows that pass
# both the 'gr_mask' and 'grz_mask' selections.
ana = fits.getdata('../../pal5s-biggest-fan/data/pal5_ls_lite_grz.fits')
ana_masks = fits.getdata('../../pal5s-biggest-fan/data/cmd_masks_orig.fits')

_passes_cmd = ana_masks['gr_mask'] & ana_masks['grz_mask']
ana = ana[_passes_cmd]

In [None]:
def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=256):
    """Return a new colormap covering only part of ``cmap``.

    Parameters
    ----------
    cmap : colormap to sample; must be callable and have a ``name``.
    minval, maxval : fractional range of ``cmap`` (0-1) to keep.
    n : number of colors sampled across that range.

    Returns
    -------
    A ``LinearSegmentedColormap`` named ``'trunc(<name>,<minval>,<maxval>)'``.
    """
    label = 'trunc({n},{a:.2f},{b:.2f})'.format(n=cmap.name, a=minval, b=maxval)
    sampled_colors = cmap(np.linspace(minval, maxval, n))
    return mpl.colors.LinearSegmentedColormap.from_list(label, sampled_colors)

# 'GnBu' with the top 10% of the range cut off.
cmap = truncate_colormap(plt.get_cmap('GnBu'), minval=0, maxval=0.9)

In [None]:
# Figure: stars colored by membership probability, drawn over a 2-D histogram
# of the photometric catalog, with a zoomed inset around the cluster.
min_prob = 0.1
prob_mask = (post_prob > min_prob)
# True for stars whose source_id is NOT in the Odenkirchen sample.
rrl_mask =  ~np.isin(g_merged.source_id, g_oden.source_id)

fig, ax = plt.subplots(1, 1, figsize=(8, 4.5), 
                         sharex=True, sharey=True,
                         constrained_layout=True)

# Background density map in 0.2 deg bins over the full RA/Dec window.
*_, hh = ax.hist2d(ana['ra'], ana['dec'],
                   bins=(np.arange(ra_lim.value[0], ra_lim.value[1]+1e-3, 0.2),
                         np.arange(dec_lim.value[0], dec_lim.value[1]+1e-3, 0.2)),
                   cmap='Greys', vmin=4, vmax=44, rasterized=True)
# hh.set_edgecolor('face')

# Stars above the probability threshold: larger markers, drawn on top.
cs = ax.scatter(g_merged.ra[prob_mask & rrl_mask], 
                g_merged.dec[prob_mask & rrl_mask], 
                c=post_prob[prob_mask & rrl_mask], 
                cmap=cmap, alpha=0.8, 
                linewidth=0.5, edgecolor='#777777', s=20,
                label='RRL', zorder=100, vmin=0, vmax=1)

# Stars below the threshold: small markers underneath. `cs` is re-bound here,
# so the colorbar below is built from this second scatter (same cmap and limits).
cs = ax.scatter(g_merged.ra[~prob_mask & rrl_mask], 
                g_merged.dec[~prob_mask & rrl_mask], 
                c=post_prob[~prob_mask & rrl_mask], 
                cmap=cmap, alpha=0.9, 
                linewidth=0.2, edgecolor='#777777', s=4,
                label='RRL', zorder=50, vmin=0, vmax=1)

ax.plot(sky_path['ra'], sky_path['dec'],
        marker='', alpha=0.4, color='tab:purple')

# RA increases to the left.
ax.set_xlim(250, 215)
ax.set_ylim(-11, 10)

ax.set_xlabel('RA [deg]')
ax.set_ylabel('Dec. [deg]')

ax.xaxis.tick_bottom()
ax.yaxis.tick_left()

# inset plot
# Square zoom of half-width zoom_size (deg) centered on (229.02, -0.11).
zoom_size = 0.4
zoom_xlim = [229.02-zoom_size, 229.02+zoom_size]
zoom_ylim = [-0.11-zoom_size, -0.11+zoom_size]
# ax.add_patch(mpl.patches.Rectangle((zoom_xlim[0], zoom_ylim[0]), 
#                                    width=zoom_size*2, height=zoom_size*2,
#                                    facecolor='none', edgecolor='tab:orange', linewidth=0.5))
axins = ax.inset_axes([0.63, 0.55, 0.42, 0.42], zorder=500)
axins.hist2d(ana['ra'], ana['dec'],
             bins=(np.arange(zoom_xlim[0], zoom_xlim[1]+1e-3, 0.03),
                   np.arange(zoom_ylim[0], zoom_ylim[1]+1e-3, 0.03)),
             cmap='Greys', rasterized=True, zorder=501, vmin=0, vmax=5)
axins.scatter(g_merged.ra[prob_mask & rrl_mask], 
              g_merged.dec[prob_mask & rrl_mask], 
              c=post_prob[prob_mask & rrl_mask], 
              cmap=cmap, alpha=0.7, 
              linewidth=0.5, edgecolor='#777777', s=30,
              label='RRL', vmin=0, vmax=1, zorder=522)
# Circle of radius 11.2 arcmin, labelled as the Jacobi radius.
axins.add_patch(mpl.patches.Circle((229.02, -0.11), radius=(11.2*u.arcmin).to_value(u.deg), 
                                   zorder=510, facecolor='none', edgecolor='tab:red', alpha=0.5))
axins.text(228.84, -0.22, 'Jacobi\nradius', 
           zorder=509, va='top', color='tab:red', alpha=0.7)
axins.set_aspect('equal')
# Reverse the x-limits so RA increases to the left, as in the main panel.
axins.set_xlim(zoom_xlim[::-1])
axins.set_ylim(zoom_ylim)
axins.xaxis.set_visible(False)
axins.yaxis.set_visible(False)

cb = fig.colorbar(cs, ax=fig.axes[0], aspect=30)
cb.ax.yaxis.set_ticks(np.arange(0, 1+1e-3, 0.25));
cs.set_clim(0, 1)
cb.set_label('membership probability')

fig.set_facecolor('w')

fig.savefig('../plots/members.pdf', dpi=300)

In [None]:
# Stars in the merged sample that are not in the Odenkirchen sample, and how many there are.
rrl_mask = ~np.isin(g_merged.source_id, g_oden.source_id)
rrl_mask.sum()

In [None]:
# Summed membership probability of the rrl_mask stars inside the sky path.
post_prob[stream_sky_mask & rrl_mask].sum()

### N in cluster, and jacobi radius

In [None]:
# Likely members (P > 0.5) inside the sky path, excluding the Odenkirchen
# stars; also keep their coordinates in ICRS and in the stream frame.
_not_oden = ~np.isin(g_merged.source_id, g_oden.source_id)
_mask = (post_prob > 0.5) & stream_sky_mask & _not_oden

in_track_members = g_merged[_mask]
in_track_members_c = merged_c[_mask]
in_track_members_pal5 = in_track_members_c.transform_to(gc.Pal5)
len(in_track_members)

In [None]:
# Adopted cluster position and distance.
pal5_c = coord.SkyCoord(ra=229.018*u.degree, dec=-0.124*u.degree,
                        distance=20.9*u.kpc)
# Milky Way mass enclosed at the cluster's Galactocentric position.
xyz = pal5_c.transform_to(galcen_frame).data.xyz
Menc = gp.MilkyWayPotential().mass_enclosed(xyz)[0]
# Adopted cluster mass.
m_pal5 = 1.2e4*u.Msun

# Jacobi radius as r_J = (m / M_enc)^(1/3) * r_galcen, then converted to an
# angle as seen from the Sun.
# NOTE(review): some definitions include a numerical factor on M_enc - confirm
# the convention intended here.
rjac = (m_pal5 / Menc)**(1/3) * pal5_c.transform_to(galcen_frame).data.norm()
rjac_arcmin = (rjac / pal5_c.distance).to(u.arcmin, u.dimensionless_angles())
rjac_arcmin

In [None]:
# Stars within the angular Jacobi radius of the cluster: among the in-track
# members, and among the whole merged sample.
jacobi_rad_mask = in_track_members_c.separation(pal5_c) < rjac_arcmin
all_jacobi_mask = merged_c.separation(pal5_c) < rjac_arcmin

In [None]:
# In-track members (inside, outside) the Jacobi radius.
jacobi_rad_mask.sum(), (~jacobi_rad_mask).sum()

In [None]:
# Counts of in-track members per value of the best_Type column.
np.unique(in_track_members.best_Type, return_counts=True)

In [None]:
# Same counts, restricted to members inside the Jacobi radius.
np.unique(in_track_members.best_Type[jacobi_rad_mask], return_counts=True)

In [None]:
# Same counts, restricted to members outside the Jacobi radius.
np.unique(in_track_members.best_Type[~jacobi_rad_mask], return_counts=True)

In [None]:
# (number of rrl_mask stars with P > 0.5, number of those also inside the sky path)
(post_prob[rrl_mask] > 0.5).sum(), ((post_prob > 0.5) & rrl_mask & stream_sky_mask).sum()

In [None]:
# Counts per value of the PS1_type column for the P > 0.5 stars inside the sky path.
rrl_types = g_merged.data[((post_prob > 0.5) & rrl_mask & stream_sky_mask)]['PS1_type']
np.unique(rrl_types, return_counts=True)

### Make track samples from MCMC samples

In [None]:
# Evaluate the quadratic track of every posterior sample on a phi1 grid and
# summarize each quantity by percentiles across samples.
_grid = np.linspace(-20, 15, 512)

# percentiles = [16, 50, 84]
percentiles = [5, 50, 95]

# For quantity k (0: pm1, 1: pm2, 2: dist) the flatchain columns are
#   k -> value, 3+k -> linear term, 6+k -> quadratic term, 9+k -> pivot.
_track_samples = []
for _k in range(3):
    x = _grid[None] - flatchain[:, 9+_k:10+_k]
    _track_samples.append(flatchain[:, _k:_k+1]
                          + flatchain[:, 3+_k:4+_k]*x
                          + flatchain[:, 6+_k:7+_k]*x**2)

pm1_samples, pm2_samples, dist_samples = _track_samples
pm1_pctl, pm2_pctl, dist_pctl = (np.percentile(_samples, percentiles, axis=0)
                                 for _samples in _track_samples)

In [None]:
# Value of each track (pm1, pm2, dist) at phi1 = 0 for every posterior sample;
# print the median and 1.5 x the median absolute deviation as a robust spread
# (for a Gaussian, sigma is about 1.48 x MAD).
for i in range(3):
    x = 0 - flatchain[:, 9+i]
    x_samples = flatchain[:, 0+i] + flatchain[:, 3+i]*x + flatchain[:, 6+i]*x**2
    print(np.median(x_samples), 1.5 * np.median(np.abs(x_samples - np.median(x_samples))))

In [None]:
# Convert a stream-frame proper motion at the frame origin to ICRS.
# The two values are typed in by hand.
coord.SkyCoord(0*u.deg, 0*u.deg,
               pm_phi1_cosphi2=3.783*u.mas/u.yr,
               pm_phi2=0.715*u.mas/u.yr,
               frame=gc.Pal5).icrs

In [None]:
# Catalog row for Pal 5 again, for comparison with the values above.
vasiliev_pal5

How different is the implied tangential velocity?

In [None]:
# Speed implied by distance x total proper motion, converted to km/s, for two
# hand-entered (distance, proper motion) pairs; the cell shows their difference.
old_v = (23.5 * u.kpc * np.sqrt(2.296**2 + 2.257**2)*u.mas/u.yr).to(u.km/u.s, u.dimensionless_angles())
new_v = (20.5 * u.kpc * np.sqrt(3.783**2 + 0.72**2)*u.mas/u.yr).to(u.km/u.s, u.dimensionless_angles())
old_v - new_v

In [None]:
# Zoom on the distance track within 1 deg of the cluster in phi1.
#
# The line style and star selection are defined locally. This cell used to rely
# on `med_style` and `other_ax_mask`, which are only created in the later
# "Plot kinematics vs. phi1" cell, so it failed on a fresh top-to-bottom run.
# The values below are the same as in that cell.
_zoom_med_style = dict(color='tab:orange', alpha=0.8, marker='', zorder=-2)
_zoom_mask = (prob_mask
              & (merged_c_pal5.phi2.degree > -2.5)
              & (merged_c_pal5.phi2.degree < 5.))

# Median track and the band between the outer percentiles.
plt.plot(_grid, dist_pctl[1], **_zoom_med_style)
plt.fill_between(_grid, dist_pctl[0], dist_pctl[2],
                 color='tab:blue', alpha=0.5, zorder=-10, linewidth=0)

# Stars above the probability threshold, with 3% distance error bars.
plt.errorbar(merged_c_pal5.phi1.degree[_zoom_mask],
             g_merged.D_ps1[_zoom_mask],
             yerr=(0.03 * g_merged.D_ps1)[_zoom_mask],
             marker='o', ls='none', ecolor='#aaaaaa',
             alpha=0.75, zorder=-10, elinewidth=0.75)

plt.xlim(-1, 1)
plt.ylim(17, 24)
# Vertical guides at +/- 12 arcmin from the cluster.
plt.axvline(12/60)
plt.axvline(-12/60)

plt.xlabel(r'$\phi_1$')
plt.ylabel('$d$ [kpc]')

### Plot kinematics vs. phi1

In [None]:
# Stream-frame coordinates after gala's solar reflex-motion correction; used
# for the proper-motion arrows in the next figure.
merged_c_pal5_reflex = gc.reflex_correct(merged_c_pal5)

In [None]:
# Shared plot styles for the track figure:
#   estyle      - error bars only (no marker), drawn underneath
#   style       - stars above the probability threshold
#   style_sm    - stars below the threshold (smaller, fainter, lower zorder)
#   fill_style  - percentile band of the fitted track
#   med_style   - median of the fitted track
estyle = dict(marker='', ls='none', ecolor='#aaaaaa', 
              alpha=0.75, zorder=-10, elinewidth=0.75)
style = dict(vmin=0., vmax=1, cmap=cmap, s=15,
             marker='o', alpha=0.85, zorder=10,
             linewidth=0.5, edgecolor='#777777')
style_sm = style.copy()
style_sm['s'] = 4
style_sm['alpha'] = 0.4
style_sm['zorder'] = 1

fill_style = dict(color='tab:orange', alpha=0.2, zorder=-10, linewidth=0)
med_style = dict(color='tab:orange', alpha=0.8, marker='', zorder=-2)

fig, axes = plt.subplots(2, 2, figsize=(10, 6), 
                         sharex=True, constrained_layout=True)
# Flat iterator so the four panels can be indexed 0..3.
axes = axes.flat

# Panel 0: sky positions in the stream frame, colored by membership probability.
ax = axes[0]

# `cs` (the above-threshold scatter) is the mappable used for the colorbar later.
cs = ax.scatter(merged_c_pal5.phi1.degree[prob_mask], 
           merged_c_pal5.phi2.degree[prob_mask],
           c=post_prob[prob_mask],
           **style)

ax.scatter(merged_c_pal5.phi1.degree[~prob_mask], 
           merged_c_pal5.phi2.degree[~prob_mask],
           c=post_prob[~prob_mask],
           **style_sm)

# Arrows: reflex-corrected proper motions of the above-threshold stars.
ax.quiver(merged_c_pal5_reflex.phi1[prob_mask], 
          merged_c_pal5_reflex.phi2[prob_mask],
          merged_c_pal5_reflex.pm_phi1_cosphi2[prob_mask].value, 
          merged_c_pal5_reflex.pm_phi2[prob_mask].value, 
          headwidth=4, headlength=6, width=0.1, scale=1.,
          units='xy', zorder=-10, color='#777777', alpha=0.5,
          rasterized=True)

ax.yaxis.set_ticks(np.arange(-2, 5+1, 2))
ax.set_xlim(-20, 15)
ax.set_ylim(-3, 5)
ax.set_ylabel(r'$\phi_2$ [${\rm deg}$]')

# ---

# Panels 1-3 only show stars in the phi2 band (-2.5, 5) deg, split by whether
# they pass the probability threshold.
_in_phi2_band = (merged_c_pal5.phi2.degree > -2.5) & (merged_c_pal5.phi2.degree < 5.)
other_ax_mask = prob_mask & _in_phi2_band
other_ax_mask_sm = ~prob_mask & _in_phi2_band

_phi1_deg = merged_c_pal5.phi1.degree

# One entry per panel:
# (axes, y values, y error bars, fitted-track percentiles, y-limits, y-label).
# The proper-motion panels use the catalog pmra / pmdec errors as error bars.
_track_panels = [
    (axes[1], g_merged.D_ps1, 0.03 * g_merged.D_ps1,
     dist_pctl, (13, 26),
     '$d$ [{:latex_inline}]'.format(u.kpc)),
    (axes[2], merged_c_pal5.pm_phi1_cosphi2.value, g_merged.pmra_error.value,
     pm1_pctl, (1, 6),
     r'$\mu_1$ [{:latex_inline}]'.format(u.mas/u.yr)),
    (axes[3], merged_c_pal5.pm_phi2.value, g_merged.pmdec_error.value,
     pm2_pctl, (-2.5, 2.5),
     r'$\mu_2$ [{:latex_inline}]'.format(u.mas/u.yr)),
]

for ax, _yvals, _yerr, _pctl, _ylim, _ylabel in _track_panels:
    # Error bars and points for the above-threshold stars ...
    ax.errorbar(_phi1_deg[other_ax_mask],
                _yvals[other_ax_mask],
                yerr=_yerr[other_ax_mask],
                **estyle)
    ax.scatter(_phi1_deg[other_ax_mask],
               _yvals[other_ax_mask],
               c=post_prob[other_ax_mask],
               **style)
    # ... small points for the rest ...
    ax.scatter(_phi1_deg[other_ax_mask_sm],
               _yvals[other_ax_mask_sm],
               c=post_prob[other_ax_mask_sm],
               **style_sm)
    # ... and the fitted track: percentile band plus median line.
    ax.fill_between(_grid, _pctl[0], _pctl[2],
                    **fill_style)
    ax.plot(_grid, _pctl[1], **med_style)

    ax.set_ylim(*_ylim)
    ax.set_ylabel(_ylabel)

# Only the distance panel gets explicit tick positions.
axes[1].yaxis.set_ticks(np.arange(13, 25+1e-3, 2))

for ax in axes[2:]:
    ax.set_xlabel(r'$\phi_1$ [${\rm deg}$]')

# Single horizontal colorbar (location='bottom') shared by every panel of the
# figure, showing the membership probability used to colour the points.
cb = fig.colorbar(cs, ax=fig.axes, aspect=10, shrink=0.6,
                  location='bottom')
# Fix the colour range first, then set the ticks through the colorbar API so
# they land on its long (horizontal) axis. Setting them on `cb.ax.yaxis`
# targets the short axis of a horizontal colorbar.
cs.set_clim(0, 1)
cb.set_ticks(np.arange(0, 1+1e-3, 0.25))
cb.set_label('membership probability', fontsize=18)

# from matplotlib.font_manager import FontProperties
# fig.suptitle('Pal 5 stream track from RR Lyrae', 
#              fontproperties=FontProperties(family='serif', weight='bold', size=18), 
#              ha='center')

fig.set_facecolor('w')
fig.savefig('../plots/tracks.pdf', dpi=300)

### Plot proper motion distribution

In [None]:
# Sky coordinates of the stream-track selection, built without distances.
c_stream_track = g_stream_track.get_skycoord(distance=False)

# True for stars whose angular separation from `pal5_c` is below
# `rjac_arcmin` (presumably the cluster's Jacobi radius -- confirm where it
# is defined).
_sep_from_pal5 = c_stream_track.separation(pal5_c)
stream_track_jacobi_mask = _sep_from_pal5 < rjac_arcmin

In [None]:
# Three-panel proper-motion figure (shared axes):
#   left   - all RR Lyrae with 18 < D_ps1 < 25
#   middle - stars near the stream track, plus literature values for Pal 5
#   right  - the sample with the stream excluded
# Circles are the default marker; squares mark stars selected by
# `stream_track_jacobi_mask`.
style = dict(marker='o', ls='', ms=4., alpha=0.9)
style_jac = dict(marker='s', ls='', ms=4., alpha=0.9)

fig, axes = plt.subplots(1, 3, figsize=(13, 5),
                         sharex=True, sharey=True)

# --- left panel: full RR Lyrae sample in the distance window
ax = axes[0]
_D_mask = (g_all.D_ps1 > 18) & (g_all.D_ps1 < 25)
ax.plot(g_all.pmra.value[_D_mask],
        g_all.pmdec.value[_D_mask],
        **style)

# --- middle panel: stream-track stars, split by the Jacobi mask
ax = axes[1]
# The distance window |D - 23| < 4 is common to both selections, so compute
# it once instead of once per plotted array.
_track_D_mask = np.abs(g_stream_track.D_ps1 - 23) < 4
_track_in_jac = _track_D_mask & stream_track_jacobi_mask
_track_out_jac = _track_D_mask & ~stream_track_jacobi_mask

ax.plot(g_stream_track.pmra.value[_track_in_jac],
        g_stream_track.pmdec.value[_track_in_jac],
        color='tab:blue', **style_jac)
ax.plot(g_stream_track.pmra.value[_track_out_jac],
        g_stream_track.pmdec.value[_track_out_jac],
        color='k', **style)

# Literature proper motion of Pal 5 from the Vasiliev table row.
ax.scatter(vasiliev_pal5['PMRA'], vasiliev_pal5['PMDEC'],
           label='Vasiliev19',
           color='tab:red', s=40, alpha=0.75, linewidth=0, zorder=-5)
ax.errorbar(vasiliev_pal5['PMRA'], vasiliev_pal5['PMDEC'],
            xerr=vasiliev_pal5['ePMRA'],
            yerr=vasiliev_pal5['ePMDEC'],
            color='tab:red', marker='', ls='none', alpha=0.3)

# Second literature value; the numbers are hard-coded here.
ax.scatter(-2.296, -2.257, label='Fritz15',
           color='tab:green', s=40, alpha=0.75, linewidth=0, zorder=-5)
ax.errorbar(-2.296, -2.257,
            xerr=0.186,
            yerr=0.181,
            color='tab:green', marker='', ls='none', alpha=0.3)

ax.annotate('Pal 5', xy=(-2.7, -3.6), xytext=(-2.7, -4.5),
            ha='center', va='top', color='k',
            arrowprops=dict(arrowstyle="->", color='k'),
            fontsize=17)

# Hand-placed reference error bar with hard-coded sizes, labelled
# "median error".
ax.text(-3.5, 0 + 0.25, 'median error', fontsize=16,
        ha='center', va='bottom')
ax.errorbar(-3.5, 0,
            xerr=0.25, yerr=0.2,
            marker='', ls='',
            alpha=1, color='#888888')

# --- right panel: sample with the stream removed, same distance window
ax = axes[2]
_nostream_D_mask = np.abs(g_nostream.D_ps1 - 23) < 4
ax.plot(g_nostream.pmra.value[_nostream_D_mask],
        g_nostream.pmdec.value[_nostream_D_mask],
        **style)

# Axes are shared, so these limits apply to all three panels.
ax.set_xlim(-5, 1)
ax.set_ylim(-5, 1)

for ax in axes:
    ax.set_xlabel(r'$\mu_\alpha$ [{:latex_inline}]'.format(u.mas/u.yr))
    ax.set_xticks(np.arange(-4, 1+1e-3, 1))
axes[0].set_ylabel(r'$\mu_\delta$ [{:latex_inline}]'.format(u.mas/u.yr))

# The original title string ended in '}}$': the stray closing brace is not
# valid math text, so it is removed here.
axes[0].set_title(r'RRL: $18 < d < 25\,{\rm kpc}$')
axes[1].set_title('near stream track')
axes[2].set_title('excluding stream')

axes[1].legend(loc='lower right', fontsize=14)

fig.set_facecolor('w')

fig.tight_layout()

fig.savefig('../plots/proper-motion.pdf')

### Plot Galactocentric properties:

In [None]:
# Monte Carlo over the distance uncertainty of the in-track members: draw
# `n_dist_samples` Gaussian distances per star with a 3% fractional sigma,
# then transform every draw to the Galactocentric frame.
n_dist_samples = 256

# Dedicated, seeded generator so Restart & Run All reproduces the same draws
# (the seed value itself is arbitrary).
_dist_rng = np.random.RandomState(42)

# Convert explicitly to kpc before stripping the unit, since the samples are
# re-labelled as kpc below.
_member_d_kpc = in_track_members.distance.to(u.kpc).value

# Drawn as (n_dist_samples, n_stars) and transposed to (n_stars, n_dist_samples).
d_samples = _dist_rng.normal(_member_d_kpc,
                             0.03 * _member_d_kpc,
                             size=(n_dist_samples, len(in_track_members))).T * u.kpc

# ra/dec get a trailing axis so each star's position broadcasts across its
# distance draws.
in_track_c_samples = coord.SkyCoord(in_track_members.ra[:, None],
                                    in_track_members.dec[:, None],
                                    distance=d_samples)

members_galcen = in_track_c_samples.transform_to(galcen_frame)

In [None]:
# Quick look at the sampled members in Galactocentric spherical angles,
# with longitude wrapped at 180 deg. Arrays are transposed before plotting.
_galcen_sph = members_galcen.spherical
_galcen_lon = _galcen_sph.lon.wrap_at(180*u.deg).T
_galcen_lat = _galcen_sph.lat.T
plt.plot(_galcen_lon, _galcen_lat,
         marker='o', ls='none', mew=0);

Find the cluster's mean distance and sky position:

In [None]:
# Select in-track members within 0.2 deg of phi1 = 0 (phi1 wrapped at 180 deg)
# and take them as the cluster stars.
_near_cluster = np.abs(in_track_members_pal5.phi1.wrap_at(180*u.deg).degree) < 0.2
cl_in_track = in_track_members[_near_cluster]

# Inverse-variance weighted mean distance, with a 3% fractional error per star.
d = cl_in_track.distance
d_err = 0.03 * d
_inv_var = 1 / d_err**2
mean_d = np.sum(d / d_err**2) / np.sum(_inv_var)
_mean_d_err = np.sqrt(1 / np.sum(_inv_var))
print('{:.1f} +/- {:.1f}'.format(mean_d, _mean_d_err))

In [None]:
# Reference point for the cluster: median sky position of the selected
# cluster stars at the weighted-mean distance, expressed in the
# Galactocentric frame.
_cl_ra = np.median(cl_in_track.ra)
_cl_dec = np.median(cl_in_track.dec)
zeropt = coord.SkyCoord(_cl_ra, _cl_dec,
                        distance=mean_d).transform_to(galcen_frame)

In [None]:
from astropy.coordinates.matrix_utilities import rotation_matrix, matrix_product

In [None]:
# Build the rotation applied to the Galactocentric member positions.
# R1: rotate about z by the cluster's Galactocentric longitude.
# R2: rotate about y by minus the cluster's Galactocentric latitude.
# (Presumably R2 @ R1 puts the cluster direction on the new x-axis -- confirm
# against the rotation_matrix sign convention.)
R1 = rotation_matrix(zeropt.spherical.lon, 'z')
R2 = rotation_matrix(-zeropt.spherical.lat, 'y')
# R3: extra roll about x by a hard-coded -17 deg. NOTE(review): looks
# hand-tuned; set R3 = np.eye(3) to disable it.
R3 = rotation_matrix(-17*u.deg, 'x')

# R1 is applied first, then R2, then R3. The `@` chain is equivalent to
# matrix_product(R3, R2, R1), which is deprecated in recent astropy releases.
R = R3 @ R2 @ R1

In [None]:
# Apply the rotation R to the Galactocentric member positions and express
# the result in spherical coordinates.
rep = (members_galcen.data
       .transform(R)
       .represent_as(coord.SphericalRepresentation))

In [None]:
# Scatter of the rotated member samples: longitude (wrapped at 180 deg)
# against latitude, both in degrees, with a reference line at latitude 0.
_rot_lon_deg = rep.lon.wrap_at(180*u.deg).degree.T
_rot_lat_deg = rep.lat.degree.T
_rot_pt_kw = dict(marker='o', ls='none', mew=0,
                  alpha=0.4, ms=3)

plt.figure(figsize=(6, 6))
plt.plot(_rot_lon_deg, _rot_lat_deg, **_rot_pt_kw);
plt.axhline(0.)

# Square +/-15 deg window around the origin.
plt.xlim(-15, 15)
plt.ylim(-15, 15)
plt.xlabel('galcen lon')
plt.ylabel('galcen lat')

plt.gcf().set_facecolor('w')
plt.tight_layout()

---

### Plot graveyard

In [None]:
# Graveyard figure: distance, pmra and pmdec against phi1 for the
# `rrl_prob_sky_mask` selection (default markers) and for the `g_stream`
# sample (open red markers), with percentile bands overplotted.
fig, axes = plt.subplots(3, 1, figsize=(6, 12), sharex=True)

# phi1 grid for the percentile tracks.
# NOTE(review): assumes the *_pctl arrays were evaluated on a matching
# 128-point grid -- confirm.
_grid = np.linspace(-20, 15, 128)

# Keyword sets shared by the three panels.
_sky_kw = dict(marker='o', ls='none')
_stream_kw = dict(marker='o', ls='none', mec='tab:red',
                  mew=1, mfc='none', ms=5, ecolor='tab:red',
                  zorder=-10)
_band_kw = dict(color='tab:blue', alpha=0.5, zorder=-10)
_line_kw = dict(color='tab:blue', zorder=10, marker='')

_phi1_sky = merged_c_pal5.phi1.degree[rrl_prob_sky_mask]
_phi1_stream = c_stream_pal5.phi1.degree

# --- bottom panel: distance
ax = axes[2]
# NOTE(review): the values come from D_kpc but the error bars are 3% of
# D_ps1 -- confirm that mix is intended.
ax.errorbar(_phi1_sky,
            g_merged.D_kpc[rrl_prob_sky_mask],
            yerr=0.03 * g_merged.D_ps1[rrl_prob_sky_mask],
            **_sky_kw)
ax.errorbar(_phi1_stream,
            g_stream.D_kpc,
            yerr=0.03 * g_stream.D_kpc,
            **_stream_kw)
ax.fill_between(_grid, dist_pctl[0], dist_pctl[2], **_band_kw)
ax.plot(_grid, dist_pctl[1], **_line_kw)
ax.set(xlim=(-20, 15), ylim=(17, 25),
       xlabel=r'$\phi_1$', ylabel='D_kpc')

# --- top panel: proper motion in RA
ax = axes[0]
ax.errorbar(_phi1_sky,
            g_merged.pmra.value[rrl_prob_sky_mask],
            yerr=g_merged.pmra_error.value[rrl_prob_sky_mask],
            **_sky_kw)
ax.errorbar(_phi1_stream,
            g_stream.pmra.value,
            yerr=g_stream.pmra_error.value,
            **_stream_kw)
ax.set(xlim=(-20, 15), ylim=(-3.5, -1.5), ylabel=r'$\mu_\alpha$')

# --- middle panel: proper motion in Dec
ax = axes[1]
ax.errorbar(_phi1_sky,
            g_merged.pmdec.value[rrl_prob_sky_mask],
            yerr=g_merged.pmdec_error.value[rrl_prob_sky_mask],
            **_sky_kw)
ax.errorbar(_phi1_stream,
            g_stream.pmdec.value,
            yerr=g_stream.pmdec_error.value,
            **_stream_kw)
ax.fill_between(_grid, pmdec_pctl[0], pmdec_pctl[2], **_band_kw)
ax.plot(_grid, pmdec_pctl[1], **_line_kw)
# NOTE(review): with sharex=True this (-1, 1) window replaces the (-20, 15)
# limits set on the other two panels.
ax.set(xlim=(-1, 1), ylim=(-3.5, -1.5), ylabel=r'$\mu_\delta$')

fig.tight_layout()

In [None]:
# Graveyard figure: pmra, pmdec and distance against phi1 for stars with
# membership probability above 0.1, coloured by that probability, with the
# percentile bands overplotted.
min_prob_mask = post_prob > 0.1

estyle = dict(marker='', ls='none', ecolor='#aaaaaa',
              alpha=0.75, zorder=-10)
style = dict(vmin=0., vmax=1, cmap='magma_r',
             marker='o', linewidth=0, alpha=0.75, zorder=10)
_band_kw = dict(color='tab:blue', alpha=0.5, zorder=-10, linewidth=0)
_line_kw = dict(color='tab:blue', zorder=10, marker='')

fig, axes = plt.subplots(3, 1, figsize=(6, 12), sharex=True)

_phi1_sel = merged_c_pal5.phi1.degree[min_prob_mask]
_prob_sel = post_prob[min_prob_mask]

# One entry per panel, top to bottom:
# (values, errors, percentile array, median curve, y-limits, y-label).
# The pmra median is broadcast to a flat line over `_grid`; the other two
# medians are plotted as given. Distance errors are 3% of D_ps1.
_panels = [
    (g_merged.pmra.value, g_merged.pmra_error.value, pmra_pctl,
     np.full_like(_grid, pmra_pctl[1]), (-3.5, -1.5), r'$\mu_\alpha$'),
    (g_merged.pmdec.value, g_merged.pmdec_error.value, pmdec_pctl,
     pmdec_pctl[1], (-3.5, -1.5), r'$\mu_\delta$'),
    (g_merged.D_ps1, 0.03 * g_merged.D_ps1, dist_pctl,
     dist_pctl[1], (17, 25), 'D_kpc'),
]

for ax, (_y, _yerr, _pctl, _median, _ylim, _ylabel) in zip(axes, _panels):
    ax.errorbar(_phi1_sel, _y[min_prob_mask],
                yerr=_yerr[min_prob_mask],
                **estyle)
    ax.scatter(_phi1_sel, _y[min_prob_mask],
               c=_prob_sel,
               **style)
    ax.fill_between(_grid, _pctl[0], _pctl[2], **_band_kw)
    ax.plot(_grid, _median, **_line_kw)
    ax.set_ylim(*_ylim)
    ax.set_ylabel(_ylabel)

# `ax` is now the bottom panel; the x-axis is shared, so the limits apply to
# all three panels.
ax.set_xlim(-20, 15)
ax.set_xlabel(r'$\phi_1$')

fig.tight_layout()
