- Data tables from [Reid & Brunthaler 2020](https://ui.adsabs.harvard.edu/abs/2020ApJ...892...39R/abstract)
- Fiducial (reference) coordinates of Sgr A*, taken from the note to Table 1

In [None]:
import pathlib

import astropy.coordinates as coord
import astropy.table as at
from astropy.time import Time
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np

import arviz as az
import pymc3 as pm
import pymc3_ext as pmx

In [None]:
# Fiducial reference position; the fitted dEast/dNorth offsets below are
# measured relative to this point.
fiducial_c = coord.SkyCoord(
    ra="17:45:40.0409",
    dec="-29:00:28.118",
    unit=(u.hourangle, u.deg),
)

In [None]:
# Absolute path to ../data, resolved against the current working directory.
data_path = (pathlib.Path('..') / 'data').resolve()

In [None]:
# Load one astrometry table per calibrator source, keyed by file name.
data = {}
# Path.glob yields entries in arbitrary filesystem order; sort so that the
# dict order (and therefore colour/legend assignment in later plots, which
# zip data.keys() with a fixed colour list) is reproducible.
for filename in sorted(data_path.glob('J*')):
    # Source name. NOTE(review): assumes files carry no extension, since they
    # are later looked up as e.g. data['J1745-283'].
    name = filename.name
    tbl = at.QTable.read(filename, format='ascii.csv')
    tbl['Date'] = Time(tbl['Date'], format='jyear')
    # Every column after the first is an offset or its uncertainty, in mas.
    # NOTE(review): assumes 'Date' is the first column -- confirm.
    for colname in tbl.colnames[1:]:
        tbl[colname] *= u.mas
    # Flip the sign of the offsets (uncertainties are left untouched).
    # NOTE(review): presumably converts calibrator-relative-to-Sgr A* offsets
    # into Sgr A*-relative-to-calibrator; confirm against the paper.
    tbl['dEast'] = -tbl['dEast']
    tbl['dNorth'] = -tbl['dNorth']
    data[name] = tbl

In [None]:
# Plot the East and North offset time series of every source side by side.
style = dict(ls='none', marker='o', ms=2)

fig, axes = plt.subplots(
    1, 2,
    figsize=(12, 5),
    sharex=True,
    constrained_layout=True
)

for name, tbl in data.items():
    # Quick sanity check: smallest reported uncertainty per direction.
    print(np.min(tbl['dEast_err']), np.min(tbl['dNorth_err']))
    # Left panel: East offsets; right panel: North offsets.
    for ax, dir_ in zip(axes, ['East', 'North']):
        ax.errorbar(
            tbl['Date'].jyear,
            tbl[f'd{dir_}'].value,
            tbl[f'd{dir_}_err'].value,
            label=name,  # identify the source in the legend
            **style
        )

for ax in axes:
    ax.set_xlabel('year')
axes[0].set_ylabel(r'$\Delta\alpha$ [mas]')
axes[1].set_ylabel(r'$\Delta\delta$ [mas]')
axes[0].legend()

In [None]:
EPOCH = 2000.  # reference epoch (Julian year) at which x0 is defined
import theano.tensor as tt

def make_model(t_jyear, dx, dx_err):
    """Build a PyMC3 model for a one-dimensional offset time series.

    The mean track is the quadratic

        x(t) = acc * (t - EPOCH)**2 + pm * (t - EPOCH) + x0,

    so ``acc`` is the quadratic coefficient, i.e. *half* the physical
    acceleration (a = 2 * acc). The reported uncertainties are inflated by
    an extra jitter term ``exp(logs)`` added in quadrature.

    Parameters
    ----------
    t_jyear : array-like
        Observation epochs as Julian years.
    dx : array-like
        Observed offsets (callers in this notebook pass values in mas).
    dx_err : array-like
        Reported 1-sigma uncertainties on ``dx``, same units as ``dx``.

    Returns
    -------
    pymc3.Model
        Model with free parameters ``acc``, ``pm``, ``x0``, ``logs`` and the
        observed likelihood ``like``.
    """
    dt = t_jyear - EPOCH  # time since the reference epoch [yr]
    with pm.Model() as model:
        acc = pm.Uniform('acc', -10, 10)  # quadratic coefficient (= a/2) in mas/yr**2
        pm_ = pm.Uniform('pm', -10, 10)  # proper motion in mas/yr
        x0 = pm.Uniform('x0', -1000, 1000)  # position at EPOCH
        logs = pm.Uniform('logs', -12, 2)  # log of the extra jitter
        s = tt.exp(logs)
        err = tt.sqrt(s**2 + dx_err**2)
        true_dx = acc * dt**2 + pm_ * dt + x0
        pm.Normal('like', true_dx, err, observed=dx)

    return model

In [None]:
def _fit_direction(tbl, dir_):
    """Fit make_model to the ``d{dir_}`` offsets of ``tbl`` and return the posterior.

    Runs pmx.optimize from pm=-3, x0=0 (printing the result), then uses it to
    start 2 chains of 1000 tuning + 10000 draws. Returns ArviZ InferenceData.
    """
    with make_model(tbl['Date'].jyear, tbl[f'd{dir_}'].value, tbl[f'd{dir_}_err'].value):
        res = pmx.optimize(start={'pm': -3, 'x0': 0})
        print(res)
        return pmx.sample(
            tune=1000, draws=10000, chains=2,
            start=res, return_inferencedata=True
        )


all_samples = {}
# Independent fits for each source, keyed e.g. 'J1745-283East'.
for name, tbl in data.items():
    for dir_ in ['East', 'North']:
        all_samples[name + dir_] = _fit_direction(tbl, dir_)

# Joint fit:
tbl = at.vstack((data['J1745-283'], data['J1748-291']))
for dir_ in ['East', 'North']:
    all_samples['joint' + dir_] = _fit_direction(tbl, dir_)

In [None]:
def _posterior_mean_std(idata, var_name):
    """Return (mean, std) of a posterior variable pooled over all chains and draws."""
    flat = idata.posterior[var_name].values.ravel()
    return np.mean(flat), np.std(flat)


pm_east, pm_east_err = _posterior_mean_std(all_samples['jointEast'], 'pm')
pm_north, pm_north_err = _posterior_mean_std(all_samples['jointNorth'], 'pm')

In [None]:
# Joint-fit proper motions (posterior mean +/- std) in mas/yr.
for label, value, error in [
    ("pm_E", pm_east, pm_east_err),
    ("pm_N", pm_north, pm_north_err),
]:
    print(f"{label} = {value:.3f} +/- {error:.3f}")

In [None]:
def _position_at(idata, jyear):
    """Evaluate the fitted track at epoch ``jyear`` for every posterior sample.

    Uses the same mean model as make_model,
    x(t) = acc * (t - EPOCH)**2 + pm * (t - EPOCH) + x0, so the quadratic
    term is included. x0 and pm are correlated with acc in the posterior, so
    dropping acc would bias the prediction and understate its spread.
    """
    post = idata.posterior
    dt = jyear - EPOCH
    return post.acc * dt**2 + post.pm * dt + post.x0


pos_east_2016 = _position_at(all_samples['jointEast'], 2016)
pos_north_2016 = _position_at(all_samples['jointNorth'], 2016)

In [None]:
pos_east_2016.mean().values, pos_north_2016.mean().values  # posterior means of the 2016 offsets

In [None]:
pos_east_2016.std().values, pos_north_2016.std().values  # posterior std (ddof=0) of the 2016 offsets

In [None]:
# Apply the mean 2016 offsets to the fiducial position. The East offset is a
# tangent-plane distance (Delta-alpha * cos(dec), the same convention used for
# pm_ra_cosdec below), so it must be divided by cos(dec) before it is added
# to RA. The North offset maps directly onto Dec.
sgr_ra_2016 = fiducial_c.ra + np.mean(pos_east_2016).values * u.mas / np.cos(fiducial_c.dec)
sgr_dec_2016 = fiducial_c.dec + np.mean(pos_north_2016).values * u.mas

In [None]:
Rsun = 8.275 * u.kpc  # adopted Sun-Sgr A* distance

# Sgr A* with its measured proper motions; radial velocity fixed at zero.
sgr_pm_ra_cosdec = pm_east * u.mas / u.yr
sgr_pm_dec = pm_north * u.mas / u.yr
cc = coord.SkyCoord(
    ra=sgr_ra_2016,
    dec=sgr_dec_2016,
    distance=Rsun,
    pm_ra_cosdec=sgr_pm_ra_cosdec,
    pm_dec=sgr_pm_dec,
    radial_velocity=0 * u.km / u.s,
)

# Galactocentric frame with the solar velocity set to zero, so a velocity
# transformed into it is the velocity relative to the Sun.
galcen_frame = coord.Galactocentric(
    galcen_distance=Rsun,
    galcen_v_sun=[0, 0, 0] * u.km / u.s,
    z_sun=20.8 * u.pc,
)

In [None]:
sgr_galcen = cc.transform_to(galcen_frame)
# Negated Galactocentric Cartesian velocity of Sgr A* (displayed).
-sgr_galcen.velocity.d_xyz

---

In [None]:
import corner

In [None]:
# Corner plot of every parameter in the joint East fit.
joint_east_post = all_samples['jointEast'].posterior
fig = corner.corner(joint_east_post)

In [None]:
lw = 1.

# Descriptive axis labels for az.plot_pair (pass `labeller=ller` to use them).
# make_model fits x(t) = acc*dt**2 + pm*dt + x0, so `acc` is the quadratic
# coefficient, i.e. half the physical acceleration a.
ller = az.labels.MapLabeller(
    var_name_map={
        "acc": r"$a/2$ [mas yr$^{-2}$]",
        "pm": r"proper motion $\mu$ [mas yr$^{-1}$]",
        "x0": "J2000 epoch position [mas]"
    }
)

var_names = ['x0', 'pm', 'acc']
# Candidate axis limits for (x0, pm, acc); not currently applied.
limits = [
    (10, 15),
    (-5, -2),
    (-1e-2, 1e-2)
]

# One pair-plot grid per direction, overlaying every source's posterior
# (68% HDI contour plus marginals) in its own colour.
for dir_ in ['East', 'North']:
    axes = None  # the first call creates the grid; later calls draw into it
    for name, color in zip(data.keys(), ['tab:blue', 'tab:orange']):
        ret_axes = az.plot_pair(
            all_samples[name + dir_],
            var_names=var_names,
            figsize=(8, 8),
            kind="kde",
            kde_kwargs={
                "hdi_probs": [0.68],
                "contour_kwargs": {"colors": color, "alpha": 0.85, "fill_last": False, "linewidths": lw},
                "contourf_kwargs": {"alpha": 0},
                "label": name
            },
            marginal_kwargs={
                "color": color,
                "plot_kwargs": {"marker": "", "linewidth": lw}
            },
            marginals=True,
            ax=axes,  # reuse the grid so sources overlay instead of opening new figures
        )
        axes = ret_axes