In [None]:
from os import path

# Third-party
import astropy.coordinates as coord
from astropy.table import Table, vstack
from astropy.io import fits
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

from pyia import GaiaData

import gala.coordinates as gc
import gala.dynamics as gd
from scipy.stats import binned_statistic
from scipy.special import logsumexp

import emcee
import corner

In [None]:
g = GaiaData('../data/gd1-with-masks.fits')
# Keep stars that pass both boolean selections stored in the file
# (pm_mask, and presumably a g-i colour-magnitude cut in gi_cmd_mask).
stream_mask = g.pm_mask & g.gi_cmd_mask
stream = g[stream_mask]

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(15, 4))

# Every star passing the stream selection, shown in the stream frame (phi1, phi2).
ax.plot(stream.phi1, stream.phi2,
        marker='o', linestyle='none', ms=3, alpha=0.6, c='k')

ax.set_xlim(-36, 0)
ax.set_ylim(-4, 4)
ax.set_xlabel(r'$\phi_1$')
ax.set_ylabel(r'$\phi_2$')

# Selection box for the dense part of the stream: lower-left corner (phi1, phi2),
# then width and height. The next cell reads `r` to build the boolean mask.
r = mpl.patches.Rectangle((-17.5, -0.7), 8, 1.2)
ax.add_patch(r);

In [None]:
# Boolean mask: True for stream stars lying strictly inside the rectangle `r`.
box_x0, box_y0 = r.get_x(), r.get_y()
box_x1 = box_x0 + r.get_width()
box_y1 = box_y0 + r.get_height()

phi1_vals = stream.phi1.value
phi2_vals = stream.phi2.value
box = ((box_x0 < phi1_vals) & (phi1_vals < box_x1) &
       (box_y0 < phi2_vals) & (phi2_vals < box_y1))
# Number of selected stars (cell output).
box.sum()

In [None]:
# Subset of the stream stars that fall inside the selection rectangle.
dense_part = stream[box]

In [None]:
# Per-star astrometric covariance matrices from pyia, one matrix per star; the
# ordering is (ra, dec, parallax, pmra, pmdec, radial_velocity).
C = dense_part.get_cov()

# Data vector per star: (pmra, pmdec), shape (n_stars, 2).
y = np.vstack((dense_part.pmra.value, dense_part.pmdec.value)).T

# (pmra, pmdec) sub-block of each covariance matrix, shape (n_stars, 2, 2).
# NOTE(review): assumes get_cov returns this block in the same units as
# pmra/pmdec.value above -- confirm against the pyia defaults.
cov = C[:, 3:5, 3:5]

# np.linalg.inv broadcasts over the leading (star) axis, so no Python loop is
# needed; an empty selection yields a (0, 2, 2) array instead of shape (0,).
ivar = np.linalg.inv(cov)

In [None]:
def lnlike(p, y, ivar):
    """Gaussian log-likelihood, up to an additive constant, of a common mean.

    Parameters
    ----------
    p : array_like, shape (D,)
        Trial mean vector shared by all stars.
    y : ndarray, shape (N, D)
        Per-star measurements.
    ivar : ndarray, shape (N, D, D)
        Per-star inverse covariance matrices.

    Returns
    -------
    float
        ``-0.5 * sum_n (y_n - p)^T ivar_n (y_n - p)``; normalisation terms
        that do not depend on ``p`` are omitted.
    """
    resid = y - np.asarray(p)
    # Contract each inverse covariance with its residual, then dot with the residual again.
    weighted = np.einsum('nij,ni->nj', ivar, resid)
    chi2 = np.einsum('nj,nj->n', resid, weighted).sum()
    return -0.5 * chi2

In [None]:
# Ensemble sampler over the mean vector; one free parameter per column of `y`.
nwalkers = 128
n_params = y.shape[1]
sampler = emcee.EnsembleSampler(nwalkers, n_params, lnlike, args=(y, ivar))

In [None]:
# Seed both the walker initialisation and the sampler's internal RNG so that
# re-running the notebook reproduces the same chains.
MCMC_SEED = 42
rng = np.random.RandomState(MCMC_SEED)

# Start the walkers in a tight ball (width 1e-2) around a rough guess for (pmra, pmdec).
p0 = rng.normal([-7, -7.], 1e-2, size=(nwalkers, sampler.dim))

# Burn-in run; only the final walker positions are kept.
pos, *_ = sampler.run_mcmc(p0, 1024, rstate0=rng.get_state())
# Drop the burn-in samples, then run the production chain from the burned-in positions.
sampler.reset()
_ = sampler.run_mcmc(pos, 2048)

In [None]:
fig, axes = plt.subplots(sampler.dim, 1, figsize=(6, 8), sharex=True)

# One panel per parameter; each panel overlays the trace of every walker.
# moveaxis puts the parameter axis first: (n_params, n_walkers, n_steps).
per_param_chains = np.moveaxis(sampler.chain, -1, 0)
for param_ax, walker_traces in zip(axes, per_param_chains):
    for trace in walker_traces:
        param_ax.plot(trace, marker='', drawstyle='steps-mid',
                      color='k', alpha=0.2)

In [None]:
# Median, over walkers, of the fraction of accepted proposals (shown as cell output).
np.median(sampler.acceptance_fraction)

In [None]:
# Flatten the production chain to (n_samples, n_params). After sampler.reset()
# the chain holds only the production run; a further 256 steps are dropped from
# its start, every 16th step is kept, and walkers are stacked one after another.
n_discard, thin = 256, 16
kept_steps = sampler.chain[:, n_discard::thin]
flatchain = kept_steps.reshape(-1, kept_steps.shape[-1])

In [None]:
# Posterior corner plot. The parameter order follows the columns of `y` (pmra, pmdec).
# The figure is bound to a named variable rather than `_`, so IPython's
# last-output variable is not shadowed.
corner_fig = corner.corner(flatchain, labels=['pmra_cosdec', 'pmdec'])

In [None]:
# Posterior summaries of the mean (pmra, pmdec) from the flattened chain.
med_y = np.median(flatchain, axis=0)
med_y_cov = np.cov(flatchain.T)

# Robust width per parameter: the median absolute deviation times
# 1 / Phi^-1(3/4) ~= 1.4826. With this factor it is a consistent estimator of
# the standard deviation for Gaussian-distributed samples.
MAD_TO_STD = 1.4826
med_y_std = MAD_TO_STD * np.median(np.abs(flatchain - med_y), axis=0)

In [None]:
# Posterior-median mean proper motion (pmra, pmdec), shown as cell output.
med_y

In [None]:
# Per-star proper motions in the box, with the fitted posterior-median mean overlaid.
fig_pm, ax_pm = plt.subplots(figsize=(6, 6))
ax_pm.scatter(y[:, 0], y[:, 1], label='stars in box')
ax_pm.scatter(med_y[0], med_y[1], label='posterior median')
ax_pm.set_xlabel('pmra')
ax_pm.set_ylabel('pmdec')
ax_pm.legend();

In [None]:
# Sky positions of the selected stars, with their simple mean position overlaid.
fig_sky, ax_sky = plt.subplots()
ax_sky.scatter(dense_part.ra, dense_part.dec, label='stars in box')
ax_sky.scatter(np.mean(dense_part.ra), np.mean(dense_part.dec), label='mean position')
ax_sky.set_xlabel('ra')
ax_sky.set_ylabel('dec')
ax_sky.legend();

In [None]:
# Simple arithmetic mean of the selected stars' coordinates.
ra_mean = np.mean(dense_part.ra)
dec_mean = np.mean(dense_part.dec)
print('ra, dec = {:.3f}, {:.3f}'.format(ra_mean, dec_mean))

In [None]:
# Report each posterior median together with its MAD-based width.
pm_report = (
    ('pmra_cosdec = {:.2f} +/- {:.2f} ', 0),
    ('pmdec = {:.2f} +/- {:.2f} ', 1),
)
for template, idx in pm_report:
    print(template.format(med_y[idx], med_y_std[idx]))