In [1]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import os
import astropy.coordinates as coord
import astropy.units as u
from astropy.table import Table, QTable, hstack
from myspace import MySpace
from sklearn.mixture import GaussianMixture
from zero_point import zpt

In [2]:
import sklearn
import jax
import numpy
import scipy

print('scikit-learn', sklearn.__version__)
print('jax', jax.__version__)
print('numpy', numpy.__version__)
print('scipy', scipy.__version__)

#Output:
#scikit-learn 0.23.2
#jax 0.2.5
#numpy 1.19.1
#scipy 1.5.0

scikit-learn 0.23.2
jax 0.2.7
numpy 1.19.4
scipy 1.5.4


In [3]:
gaia = QTable.read('../data/RV-all-result.fits', format='fits')

In [4]:
zpt.load_tables()

gmag = gaia['phot_g_mean_mag'].value
nueffused = gaia['nu_eff_used_in_astrometry'].value
psc = gaia['pseudocolour'].value
sinbeta = np.sin(np.deg2rad(gaia['ecl_lat'].value))
soltype = gaia['astrometric_params_solved']

zpvals = zpt.get_zpt(gmag, nueffused, psc, sinbeta, soltype)
cparallax=gaia['parallax'].value-zpvals

                Outside this range, there is no further interpolation, thus the values at 6 or 21 are returned.
                mag). Outside this range, the zero-point calculated can be seriously wrong.
                 The maximum corrections are reached already at 1.24 and 1.72


In [5]:
qindx=(gaia['parallax_over_error']>4.)


In [6]:
c = coord.SkyCoord(ra=gaia['ra'][qindx],dec=gaia['dec'][qindx],distance=1./cparallax[qindx]*u.kpc,pm_ra_cosdec=gaia['pmra'][qindx],pm_dec=gaia['pmdec'][qindx],radial_velocity=gaia['radial_velocity'][qindx])



In [7]:
gal = c.galactic
gal.set_representation_cls('cartesian')

In [8]:
xyz = np.vstack((gal.u.to(u.kpc).value, 
                 gal.v.to(u.kpc).value,
                 gal.w.to(u.kpc).value)).T

UVW = np.vstack((gal.U.to(u.km/u.s).value, 
                 gal.V.to(u.km/u.s).value, 
                 gal.W.to(u.km/u.s).value)).T

disk_vmask = np.linalg.norm(UVW, axis=1) < 150.

In [9]:
XX=xyz
VV=UVW
dist2=np.sqrt(XX[:,0]**2+XX[:,1]**2)

_cyl = gal.represent_as('cylindrical')
mask_r100 = (_cyl.rho < 100*u.pc) & (np.abs(_cyl.z) < 150*u.pc)
mask_r300 = (_cyl.rho < 300*u.pc) & (np.abs(_cyl.z) < 500*u.pc)
mask_r500 = (_cyl.rho < 500*u.pc) & (np.abs(_cyl.z) < 500*u.pc)
mask_r100.sum(), mask_r500.sum()
#local_mask=(dist2<0.2)*(np.fabs(XX[:,2])<0.2)
#train_mask=(dist2>0.2)*(dist2<0.5)*(np.fabs(XX[:,2])<0.5)

local_v = UVW[disk_vmask & mask_r100]
local_x = xyz[disk_vmask & mask_r100]

print(local_v.shape)

(98166, 3)


In [10]:
local_gmm = GaussianMixture(n_components=64)
local_gmm.fit(local_v)

GaussianMixture(n_components=64)

In [11]:
myspacexv = MySpace(local_gmm, terms=['x','xv'])

In [12]:

train_v = UVW[disk_vmask & mask_r300]
train_x = xyz[disk_vmask & mask_r300]

test_v = UVW[disk_vmask & mask_r500]
test_x = xyz[disk_vmask & mask_r500]

local_v.shape, train_v.shape, test_v.shape

((98166, 3), (829927, 3), (1499237, 3))

In [13]:
# subsample so Hogg doesn't die of old age
n, p = train_v.shape
I = np.random.randint(n, size=128)
train_x, train_v = train_x[I], train_v[I]
print(train_x.shape, train_v.shape)
n2, p = test_v.shape
I2 = np.random.randint(n2, size=42)
test_x, test_v = test_x[I2], test_v[I2]
print(test_x.shape, test_v.shape)

(128, 3) (128, 3)
(42, 3) (42, 3)


In [14]:
resxv, tensorsxv = myspacexv.fit(train_x, train_v)
# resxvx, tensorsxvx = myspacexvx.fit(train_x, train_v)



In [15]:
print(resxv, tensorsxv)

      fun: 15.185102574747603
 hess_inv: array([[ 1.81415145e+04, -3.76789628e+03, -3.56375081e+03, ...,
         1.49057370e+02,  1.03810335e+02, -2.73388336e+01],
       [-3.76789628e+03,  9.61178062e+03,  4.86757778e+02, ...,
        -9.82029401e+00,  2.05600194e+01, -2.49979420e+01],
       [-3.56375081e+03,  4.86757778e+02,  1.89925171e+03, ...,
        -3.14574392e+01, -6.73757909e+01,  4.77579014e+00],
       ...,
       [ 1.49057370e+02, -9.82029401e+00, -3.14574392e+01, ...,
         2.76701754e+01,  6.63240226e+00, -6.79163448e+00],
       [ 1.03810335e+02,  2.05600194e+01, -6.73757909e+01, ...,
         6.63240226e+00,  3.39315993e+01, -4.30562997e+00],
       [-2.73388336e+01, -2.49979420e+01,  4.77579014e+00, ...,
        -6.79163448e+00, -4.30562997e+00,  1.75140194e+01]])
      jac: array([ 1.21765763e-07, -1.44716630e-06, -3.17738749e-06,  3.74579550e-06,
        5.68710324e-09,  2.07309914e-06,  1.84731020e-06,  1.41224386e-06,
       -5.53900419e-07, -9.09996005e-07, 

In [16]:
fixxv=myspacexv.get_model_v(test_v,test_x)

In [17]:
print(fixxv)

[[  22.39441858  -40.55778467  -13.96741999]
 [  16.45274479   12.7960342    -5.82388435]
 [ -23.44240401   42.82291755  -14.83189223]
 [ -39.21526686  -53.44511096  -28.42970739]
 [-128.74297459  -33.19079446  -45.76426549]
 [  35.26947261   10.91073151    2.43606833]
 [ -42.04719767  -11.70941256   23.6777174 ]
 [ -63.76725511 -126.355443     -4.24693589]
 [ 156.44588111  -30.38454268  -11.69736289]
 [ -33.84381354 -200.70709137 -137.9661494 ]
 [  31.82703284  -44.67838105  -44.67088357]
 [  -4.84455988   -6.90419264   -1.12695983]
 [ -25.20859417  -25.99180988   -2.11580747]
 [  82.5474887   -44.38199011  -72.53525802]
 [  59.81999968   -5.45605187  -28.78602813]
 [ -83.57106642   95.69270732   47.18294665]
 [  42.04286818    6.53555879    0.26533847]
 [ -55.1140662   -42.95424221  -42.50883195]
 [ -24.56379223  -14.35488969  -25.60341897]
 [ -74.19444075  -58.43354829   13.5230567 ]
 [  19.18889701   -1.47984024    0.24688972]
 [   6.63285164  -27.58209466  -19.47929607]
 [ -34.035