In [None]:
import os

# Third-party
from astropy.constants import G
import astropy.coordinates as coord
from astropy.io import ascii, fits
import astropy.table as atbl
import astropy.units as u
import h5py
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
# Apply the style sheet 'notebook.mplstyle' to every figure drawn below.
# NOTE(review): the name is given without a directory, so it is presumably
# resolved relative to the notebook's working directory — confirm the file
# ships alongside this notebook.
plt.style.use('notebook.mplstyle')
# Render figures inline in the notebook output.
%matplotlib inline
import corner

import gala.dynamics as gd
import gala.coordinates as gc
import gala.potential as gp
from gala.units import galactic

In [None]:
# Read the table stored in HDU 1 of the FITS file (presumably the TGAS x RAVE
# cross-match, going by the filename).
tgas_rave = fits.getdata('../data/tgas-rave.fits', 1)

SN_cut = 5  # minimum parallax signal-to-noise ratio

# Each ratio is computed once and reused for both the threshold and the
# finiteness test (a zero or NaN uncertainty yields inf/NaN here).
parallax_snr = tgas_rave['parallax'] / tgas_rave['parallax_error']
hrv_snr = tgas_rave['HRV'] / tgas_rave['eHRV']

# The parallax S/N is deliberately kept *signed*: a strongly negative parallax
# would satisfy an |S/N| cut, but it cannot be inverted into a physical
# distance in the next cell.
clean_idx = ((parallax_snr > SN_cut) &
             np.isfinite(parallax_snr) &
             # radial-velocity uncertainty below 10 (same unit as HRV,
             # which is tagged km/s further down)
             (tgas_rave['eHRV'] < 10.) &
             np.isfinite(hrv_snr))

clean_tgas_rave = tgas_rave[clean_idx]
len(clean_tgas_rave)

In [None]:
# --- Observables with units attached -------------------------------------
rv = clean_tgas_rave['HRV'] * u.km / u.s
pm = np.vstack([clean_tgas_rave['pmra'],
                clean_tgas_rave['pmdec']]) * u.mas / u.yr

# Invert the parallax (mas) into a distance (pc) via astropy's equivalency.
dist_pc = (clean_tgas_rave['parallax'] * u.mas).to(u.pc,
                                                   equivalencies=u.parallax())
c = coord.ICRS(ra=clean_tgas_rave['ra'] * u.deg,
               dec=clean_tgas_rave['dec'] * u.deg,
               distance=dist_pc)

# --- Galactocentric Cartesian positions and velocities -------------------
xyz = c.transform_to(coord.Galactocentric).cartesian.xyz
# vlsr is passed as zero here.
vxyz = gc.vhel_to_gal(c, pm=pm, rv=rv, vlsr=[0, 0, 0] * u.km / u.s)

# Keep only stars whose Galactocentric |z| is below 100 pc.
_ix = np.abs(xyz[2]) < (100 * u.pc)
w0 = gd.CartesianPhaseSpacePosition(pos=xyz[:, _ix], vel=vxyz[:, _ix])
print(w0.shape)

# From here on `xyz` / `vxyz` refer to the thin-slab subsample only:
# positions in kpc, velocities in km/s with the per-component median removed.
xyz = w0.pos.to(u.kpc)
vel_kms = w0.vel.to(u.km / u.s)
vxyz = vel_kms - np.median(vel_kms, axis=1)[:, None]

In [None]:
# Convert the median-subtracted Cartesian velocities into cylindrical
# components at each star's position. The plotting cells below treat
# vcyl[0], vcyl[1], vcyl[2] as v_R, v_phi, v_z respectively.
# NOTE(review): units of the result are whatever gala returns — confirm they
# are still km/s, as the axis labels further down assume.
vcyl = gc.cartesian_to_cylindrical(xyz, vxyz)

In [None]:
# One histogram per cylindrical velocity component.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Build the bin edges once so all three panels share exactly the same binning.
vel_bins = np.linspace(-150, 150, 64)
xlabels = ('$v_R$ [km s$^{-1}$]',
           r'$v_\phi$ [km s$^{-1}$]',
           r'$v_z$ [km s$^{-1}$]')

for k, (ax, xlabel) in enumerate(zip(axes, xlabels)):
    ax.hist(vcyl[k], bins=vel_bins)
    ax.set_xlabel(xlabel)

# Trailing semicolon keeps the Text repr out of the cell output.
axes[0].set_ylabel('number of stars');

In [None]:
# Face-on (x, y) maps of the |z| < 100 pc subsample, one panel per cylindrical
# velocity component, all coloured on the same diverging scale.
fig, axes = plt.subplots(1, 3, figsize=(15, 5), sharex=True, sharey=True)

style = dict(marker='.', alpha=0.5, cmap='RdBu', s=50, vmin=-50, vmax=50)
panel_titles = ("color: $v_R$", r"color: $v_\phi$", "color: $v_z$")

# The scatter handle is deliberately NOT named `c`: that name holds the ICRS
# coordinate object built earlier in the notebook, and rebinding it here would
# leave the kernel in a state where earlier cells no longer mean what they say.
for k, (ax, title) in enumerate(zip(axes, panel_titles)):
    vel_scatter = ax.scatter(xyz[0], xyz[1], c=vcyl[k], **style)
    ax.set_xlabel("$x$ [kpc]")
    ax.set_title(title)
axes[0].set_ylabel("$y$ [kpc]")

# Centre the window on the Sun's Galactocentric x, taken from the same default
# frame that produced `xyz`, rather than hard-coding -8.3 kpc. The axes are
# shared, so setting the limits on the first panel applies to all three.
x_sun = -coord.Galactocentric().galcen_distance.to(u.kpc).value
half_width = 0.25  # kpc
axes[0].set_xlim(x_sun - half_width, x_sun + half_width)
axes[0].set_ylim(-half_width, half_width);

# fig.colorbar(vel_scatter)