In [None]:
import os

# Third-party
from astropy.constants import G
import astropy.coordinates as coord
from astropy.io import ascii, fits
import astropy.table as atbl
import astropy.units as u
import h5py
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
plt.style.use('notebook.mplstyle')
%matplotlib inline
import corner

import gala.dynamics as gd
import gala.coordinates as gc
import gala.potential as gp
from gala.units import galactic

In [None]:
# Absolute path of the TGAS HDF5 file, resolved against the current working directory.
tgas_path = os.path.abspath(os.path.join("..", "data", "tgas.hdf5"))

In [None]:
with h5py.File(tgas_path, 'r') as f:
    tgas = f['tgas'][:]

# Parallax -> distance is only meaningful for strictly positive parallaxes: a zero
# parallax maps to an infinite distance and a negative one to a negative distance,
# which recent astropy versions refuse when building a SkyCoord. Measured parallaxes
# can scatter to <= 0 (and NaN also fails this test), so the 3D coordinates are built
# from the positive-parallax subset only; `tgas` itself is left complete.
has_positive_parallax = tgas['parallax'] > 0
tgas_plx_pos = tgas[has_positive_parallax]

all_tgas_c = coord.SkyCoord(ra=tgas_plx_pos['ra']*u.degree,
                            dec=tgas_plx_pos['dec']*u.degree,
                            distance=(tgas_plx_pos['parallax']*u.mas).to(u.kpc, u.parallax()))

In [None]:
# Galactocentric Cartesian positions; all_xyz[0], [1], [2] are x, y, z (used as such below).
all_galcen = all_tgas_c.transform_to(coord.Galactocentric)
all_xyz = all_galcen.cartesian.xyz

In [None]:
# Edge-on (x, z) view of every star lying within 0.1 kpc of y = 0.
# NOTE(review): the x window is centred on -8.3 kpc, presumably the Sun's x in the
# default Galactocentric frame of the astropy version used — confirm.
thin_y_slab = np.abs(all_xyz[1]) < 0.1*u.kpc

fig, ax = plt.subplots(figsize=(10, 10))
ax.plot(all_xyz[0][thin_y_slab], all_xyz[2][thin_y_slab],
        linestyle='none', marker=',', alpha=0.4)
ax.set_xlim(-8.3 - 2, -8.3 + 2)
ax.set_ylim(-1, 1)
ax.set_xlabel('$x$ [kpc]')
ax.set_ylabel('$z$ [kpc]')

In [None]:
# Photometry and astrometry (with uncertainties) for the full TGAS table, as Quantities.
Gmag = tgas['phot_g_mean_mag']

parallax = tgas['parallax'] * u.mas
parallax_err = tgas['parallax_error'] * u.mas

# Proper motions stacked as rows: index 0 is pmra, index 1 is pmdec.
pm = np.stack((tgas['pmra'], tgas['pmdec'])) * (u.mas / u.yr)
pm_err = np.stack((tgas['pmra_error'], tgas['pmdec_error'])) * (u.mas / u.yr)

In [None]:
# Total proper-motion amplitude per star: sqrt(pmra**2 + pmdec**2), rows of `pm`.
pm_mag = np.sqrt(pm[0]**2 + pm[1]**2)

In [None]:
# 2D histogram of the full TGAS sample in (ln parallax, ln total proper motion).
# np.log is the natural log; only stars with a positive parallax and a non-zero proper
# motion have a finite logarithm, so the rest are dropped before binning instead of
# producing NaN / -inf (and RuntimeWarnings).
plx_mas = parallax.to(u.mas).value
pm_mag_masyr = pm_mag.to(u.mas/u.yr).value
has_finite_log = (plx_mas > 0) & (pm_mag_masyr > 0)

H, xedges, yedges = np.histogram2d(np.log(plx_mas[has_finite_log]),
                                   np.log(pm_mag_masyr[has_finite_log]),
                                   bins=(np.linspace(-5, 7, 128), np.linspace(-3, 9, 128)))

fig, ax = plt.subplots(figsize=(10, 10))
# histogram2d returns H[x_bin, y_bin]; transpose so rows run along y, draw with
# origin='lower', and map the image onto the bin edges via `extent` so the ticks read
# in log units (otherwise the axes show pixel indices and y increases downward).
# np.ma.log masks empty bins (count 0) rather than yielding -inf.
ax.imshow(np.ma.log(H.T), cmap='Blues', origin='lower',
          extent=(xedges[0], xedges[-1], yedges[0], yedges[-1]))
ax.set_xlabel(r'$\ln(\varpi\,/\,\mathrm{mas})$')
ax.set_ylabel(r'$\ln(\mu\,/\,\mathrm{mas}\,\mathrm{yr}^{-1})$');

---

In [None]:
tgas_rave = fits.getdata('../data/tgas-rave.fits', 1)

# Quality cuts for the TGAS x RAVE cross-match.
MIN_PARALLAX_SNR = 2   # parallax must be positive and at least this many sigma
MIN_RV_SNR = 2
MIN_FE_H = -0.5

# The parallax cut is deliberately one-sided: |parallax / error| > 2 would also keep
# parallaxes more than 2 sigma *below* zero, which map to negative distances when the
# SkyCoord for this sample is built from parallax.
clean_idx = ((tgas_rave['parallax'] / tgas_rave['parallax_error'] > MIN_PARALLAX_SNR) &
             # NOTE(review): this keeps only stars whose |HRV| is > 2 sigma from zero, i.e.
             # it removes low-|HRV| stars, which may bias the kinematics. If a precision
             # cut was intended, something like eHRV < threshold may be wanted — confirm.
             (np.abs(tgas_rave['HRV'] / tgas_rave['eHRV']) > MIN_RV_SNR) &
             (tgas_rave['Fe_H'] > MIN_FE_H))

clean_tgas_rave = tgas_rave[clean_idx]
len(clean_tgas_rave)

In [None]:
# Metallicity distribution of the cleaned sample; the lower bin edge matches the
# Fe_H > -0.5 cut applied when the sample was built.
fig, ax = plt.subplots()
ax.hist(clean_tgas_rave['Fe_H'], bins=np.linspace(-0.5, 1., 32))
ax.set_xlabel('[Fe/H]')
ax.set_ylabel('number of stars');

In [None]:
# Sky positions with parallax-derived distances for the cleaned TGAS x RAVE sample.
tgas_rave_c = coord.SkyCoord(ra=clean_tgas_rave['ra']*u.deg,
                             dec=clean_tgas_rave['dec']*u.deg,
                             distance=(clean_tgas_rave['parallax']*u.mas).to(u.kpc, equivalencies=u.parallax()))

# Transform to the Galactocentric frame once and derive both representations from it
# rather than repeating the same transformation.
tgas_rave_galcen = tgas_rave_c.transform_to(coord.Galactocentric)
xyz = tgas_rave_galcen.represent_as(coord.CartesianRepresentation).xyz  # x, y, z along axis 0
cyl = tgas_rave_galcen.represent_as(coord.CylindricalRepresentation)

In [None]:
# Proper motions (row 0: pmra, row 1: pmdec) and heliocentric radial velocities for the
# cleaned TGAS x RAVE sample.
# NOTE(review): this rebinds the global `pm` that earlier cells define for the full TGAS
# table; re-running those downstream cells (e.g. pm_mag) after this one would silently
# use the RAVE subset.
pm = np.stack((clean_tgas_rave['pmra'], clean_tgas_rave['pmdec'])) * (u.mas / u.yr)
rv = clean_tgas_rave['HRV'] * (u.km / u.s)

In [None]:
# Convert heliocentric proper motions + radial velocities to Galactocentric velocities;
# the plots below treat UVW[0], UVW[1], UVW[2] as v_x, v_y, v_z.
# The LSR velocity is explicitly zeroed; other solar-motion parameters are gala's defaults.
UVW = gc.vhel_to_gal(tgas_rave_c,
                     pm=pm,
                     rv=rv,
                     vlsr=np.zeros(3) * u.km / u.s)

In [None]:
# Face-on (x, y) map of the cleaned sample, coloured by v_z.
# Positions and velocities are converted explicitly to kpc and km/s so that the axis
# limits and the +/-50 colour range are in those units regardless of what units the
# coordinate/velocity transformations return.
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
sc = ax.scatter(xyz[0].to(u.kpc).value, xyz[1].to(u.kpc).value,
                c=UVW[2].to(u.km/u.s).value, s=4, marker=',',
                cmap='plasma', vmin=-50, vmax=50)
ax.set_xlim(-9.2, -7.4)
ax.set_ylim(-0.9, 0.9)
ax.set_xlabel("$x$ [kpc]")
ax.set_ylabel("$y$ [kpc]")
fig.colorbar(sc, ax=ax, label="$v_z$ [km/s]");

In [None]:
# Edge-on (x, z) view of stars with |y| < 0.1 kpc, one panel per velocity component.

# NOTE(review): assumed to be the circular velocity that vhel_to_gal folds into v_y
# (only vlsr is overridden in that call) — confirm against the gala version in use.
VCIRC = 220 * u.km/u.s

in_slab = np.abs(xyz[1]) < 0.1*u.kpc
slab_x = xyz[0][in_slab].to(u.kpc).value
slab_z = xyz[2][in_slab].to(u.kpc).value

# (colour values, title) per panel. v_y has VCIRC removed so all three panels share the
# same +/-50 km/s colour range; values are converted to km/s so vmin/vmax mean km/s.
panels = [(UVW[0][in_slab], "$v_x$"),
          (UVW[1][in_slab] - VCIRC, "$v_y$"),
          (UVW[2][in_slab], "$v_z$")]

fig, axes = plt.subplots(1, 3, figsize=(15, 5), sharex=True, sharey=True)
for ax, (v_panel, title) in zip(axes, panels):
    sc = ax.scatter(slab_x, slab_z, c=v_panel.to(u.km/u.s).value, s=2, marker=',', alpha=0.5,
                    cmap='RdBu', vmin=-50, vmax=50)
    ax.set_title(title)
    ax.set_xlabel("$x$ [kpc]")

# Axes are shared, so limits set on the first panel apply to all three.
axes[0].set_xlim(-9.2, -7.4)
axes[0].set_ylim(-0.9, 0.9)
axes[0].set_ylabel("$z$ [kpc]")

fig.tight_layout()
fig.colorbar(sc, ax=list(axes), label="velocity [km/s]");