In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import os
import astropy.coordinates as coord
import astropy.units as u
from astropy.table import Table, QTable, hstack
from myspace import MySpace
from sklearn.mixture import GaussianMixture
from zero_point import zpt

In [None]:
import sklearn
import jax
import numpy
import scipy

# Record the library versions this notebook was last executed with, so the
# numbers below can be compared against the current environment.
for _label, _module in (('scikit-learn', sklearn),
                        ('jax', jax),
                        ('numpy', numpy),
                        ('scipy', scipy)):
    print(_label, _module.__version__)

#Output:
#scikit-learn 0.23.2
#jax 0.2.5
#numpy 1.19.1
#scipy 1.5.0

In [None]:
# Load the Gaia radial-velocity sample as a QTable (columns carry units).
# The path is relative to the notebook's working directory.
_gaia_path = '../data/RV-all-result.fits'
gaia = QTable.read(_gaia_path, format='fits')

In [None]:
# Parallax zero-point correction using the `zero_point` package (presumably the
# Gaia EDR3 model of Lindegren et al. 2021); its lookup tables must be loaded
# before get_zpt can be evaluated.
zpt.load_tables()

# Per-source inputs to the zero-point model, stripped to plain arrays.
gmag, nueffused, psc = (gaia[_col].value for _col in ('phot_g_mean_mag',
                                                      'nu_eff_used_in_astrometry',
                                                      'pseudocolour'))
sinbeta = np.sin(np.deg2rad(gaia['ecl_lat'].value))
# NOTE(review): some releases of `zero_point` expect the ecliptic latitude in
# degrees (ecl_lat) as the 4th argument rather than sin(beta); confirm against
# the installed version's get_zpt signature.
soltype = gaia['astrometric_params_solved']

zpvals = zpt.get_zpt(gmag, nueffused, psc, sinbeta, soltype)

# Corrected parallax = catalogue parallax minus the estimated zero point.
cparallax = gaia['parallax'].value - zpvals

In [None]:
# Quality cut for the kinematic sample.
# - parallax_over_error > 4 keeps stars with a significant (raw) parallax.
# - The distance used below is 1 / cparallax with the zero-point-corrected
#   parallax, so that value must also be finite and strictly positive:
#   SkyCoord raises on negative distances, and NaN zero points (sources outside
#   the model's validity range) would otherwise propagate NaN positions.
qindx = ((gaia['parallax_over_error'] > 4.)
         & np.isfinite(cparallax)
         & (cparallax > 0))


In [None]:
# Full 6D phase-space coordinates (ICRS) for the stars passing the quality cut.
# Distance is the inverse corrected parallax; 1/parallax lands in kpc assuming
# the parallax is in mas (Gaia convention) — TODO confirm column units.
c = coord.SkyCoord(
    ra=gaia['ra'][qindx],
    dec=gaia['dec'][qindx],
    distance=(1. / cparallax[qindx]) * u.kpc,
    pm_ra_cosdec=gaia['pmra'][qindx],
    pm_dec=gaia['pmdec'][qindx],
    radial_velocity=gaia['radial_velocity'][qindx],
)



In [None]:
# Transform to the Galactic frame and switch its default representation to
# Cartesian, which exposes positions as gal.u/v/w and velocities as gal.U/V/W.
gal = c.galactic
gal.representation_type = 'cartesian'

In [None]:
# Galactic Cartesian positions (kpc) and velocities (km/s), one row per star,
# with columns (u, v, w) and (U, V, W) respectively.
xyz = np.column_stack([getattr(gal, _axis).to(u.kpc).value
                       for _axis in ('u', 'v', 'w')])
UVW = np.column_stack([getattr(gal, _axis).to(u.km/u.s).value
                       for _axis in ('U', 'V', 'W')])

# Keep stars whose total speed in this frame is below 150 km/s
# (presumably to drop halo-like kinematics — hence the "disk" name).
disk_vmask = np.linalg.norm(UVW, axis=1) < 150.

In [None]:
# Aliases and in-plane distance; these feed only the alternative selection
# described further down, and are kept so later code can still use them.
XX = xyz
VV = UVW
dist2 = np.sqrt(XX[:, 0]**2 + XX[:, 1]**2)  # sqrt(u^2 + v^2), in kpc

# Cylindrical coordinates in the Galactic frame: rho is the distance in the
# (u, v) plane and z the height along w.
_cyl = gal.represent_as('cylindrical')


def _cylinder_mask(rho_max, z_max):
    """Boolean mask of stars with rho < rho_max and |z| < z_max.

    Both limits are length Quantities (e.g. ``100*u.pc``); returns a boolean
    array with one entry per star in ``gal``.
    """
    return (_cyl.rho < rho_max) & (np.abs(_cyl.z) < z_max)


# Nested selection volumes: local (GMM fit), training, and test.
mask_r100 = _cylinder_mask(100*u.pc, 150*u.pc)
mask_r300 = _cylinder_mask(300*u.pc, 500*u.pc)
mask_r500 = _cylinder_mask(500*u.pc, 500*u.pc)
# Only a cell's last expression is displayed, so print the counts explicitly.
print('N(r100), N(r300), N(r500):',
      mask_r100.sum(), mask_r300.sum(), mask_r500.sum())

# Alternative (earlier) selection based on dist2 rather than the cylinders:
#   local: dist2 < 0.2 kpc and |w| < 0.2 kpc
#   train: 0.2 < dist2 < 0.5 kpc and |w| < 0.5 kpc

# Local sample: low-speed stars inside the 100 pc cylinder.
local_sel = disk_vmask & mask_r100
local_v = UVW[local_sel]
local_x = xyz[local_sel]

print(local_v.shape)

In [None]:
# Density model of the local velocity distribution.
# random_state fixes the mixture's initialisation so the fitted model — and
# everything downstream that depends on it — is reproducible on Restart & Run All.
local_gmm = GaussianMixture(n_components=64, random_state=42)
local_gmm.fit(local_v)

In [None]:
# MySpace model built on the local velocity GMM; which correction terms the
# 'x' and 'xv' labels enable is defined by the MySpace package — TODO document.
_xv_terms = ['x', 'xv']
myspacexv = MySpace(local_gmm, terms=_xv_terms)

In [None]:

# Training stars: low-speed stars within the 300 pc cylinder; test stars: the
# same speed cut within the 500 pc cylinder.
# NOTE(review): the volumes are nested, so the test set contains the training
# and local regions — confirm this overlap is intended.
_train_sel = disk_vmask & mask_r300
_test_sel = disk_vmask & mask_r500

train_v, train_x = UVW[_train_sel], xyz[_train_sel]
test_v, test_x = UVW[_test_sel], xyz[_test_sel]

local_v.shape, train_v.shape, test_v.shape

In [None]:
# Subsample so Hogg doesn't die of old age.
# Indices are drawn WITHOUT replacement (np.random.randint could pick the same
# star twice and silently duplicate it) from a seeded generator, so the
# subsample is reproducible. The size is capped at the number of stars available.
# NOTE: this cell overwrites train_*/test_* in place; re-running it alone
# subsamples the subsample — re-run the selection cell first.
N_TRAIN, N_TEST = 128, 42
rng = np.random.default_rng(42)

n, p = train_v.shape
I = rng.choice(n, size=min(N_TRAIN, n), replace=False)
train_x, train_v = train_x[I], train_v[I]
print(train_x.shape, train_v.shape)

n2, p = test_v.shape
I2 = rng.choice(n2, size=min(N_TEST, n2), replace=False)
test_x, test_v = test_x[I2], test_v[I2]
print(test_x.shape, test_v.shape)

In [None]:
# Fit the MySpace model on the subsampled training stars (positions first,
# then velocities); the two returned objects are unpacked below —
# presumably the optimiser result and the fitted tensors (TODO confirm).
(resxv, tensorsxv) = myspacexv.fit(train_x, train_v)

In [None]:
# Show both fit outputs on one line, separated by a space.
print(' '.join(map(str, (resxv, tensorsxv))))

In [None]:
# Apply the fitted model to the held-out test stars.
# NOTE(review): fit() above is called as fit(x, v) but this call passes (v, x);
# confirm that MySpace.get_model_v takes velocities first.
fixxv = myspacexv.get_model_v(test_v, test_x)

In [None]:
# Plain-text dump of the model output for the test stars.
print(str(fixxv))