In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import os
import astropy.coordinates as coord
import astropy.units as u
from astropy.table import Table, hstack
from myspace import MySpace
from sklearn.mixture import GaussianMixture

In [None]:
import astropy
import sklearn
import jax
import numpy
import scipy

# Record the versions of the numerical stack this notebook was run with,
# one "<package> <version>" line per library.
print(
    *(
        f'{label} {module.__version__}'
        for label, module in (
            ('astropy', astropy),
            ('scikit-learn', sklearn),
            ('jax', jax),
            ('numpy', numpy),
            ('scipy', scipy),
        )
    ),
    sep='\n',
)

In [None]:
# Read the source catalogue (FITS table) that every later cell selects from.
gaia = Table.read(
    '../data/RV-all-result.fits',
    format='fits',
)

In [None]:
# Disabled alternative: a stricter quality cut combining colour (bp_rp), G magnitude,
# parallax signal-to-noise, parallax error, visibility periods and BP/RP excess factor.
# Kept for reference; the simple parallax cut below is the one in use.
# qindx=(gaia['bp_rp']<1.5)*(gaia['phot_g_mean_mag']<14.5)*(gaia['parallax']/gaia['parallax_error']>4.)*(gaia['parallax_error']<0.1)*(gaia['visibility_periods_used']>5.)*(gaia['phot_bp_rp_excess_factor']<1.3)*(gaia['phot_bp_rp_excess_factor']>1.172)
# Row selection applied to every catalogue column used below: parallax > 0.5.
# With the 1/parallax-as-kpc distance used when the coordinates are built, this keeps
# sources at nominal distances below 2 kpc (assumes parallax is in mas -- TODO confirm).
qindx = gaia['parallax'] > 0.5

In [None]:
def make_anim2(XX,VV,tensorsx,myspacex,tensorsxv,myspacexv,gs=150):
    """Save one four-panel figure per azimuthal wedge and stitch them into a GIF.

    Stars are first restricted to an annulus (200 < rho < 500, |z| < 500, in the
    units of ``XX``) with total speed below 200.  The annulus is then scanned by
    36 overlapping wedges: wedge ``i`` covers azimuths from ``i*10`` to
    ``(i+3)*10`` degrees (measured as ``arctan2(y, x) + pi``), wrapping past
    360 degrees for the last two wedges.  For each wedge the figure shows the
    selected positions, the raw velocities, and the velocities returned by
    ``get_model_v`` for the two supplied models.

    Parameters
    ----------
    XX : array, shape (N, >=3)
        Cartesian positions.  Must be in pc: the mask radii and the plot range
        (+-500) are hard-coded in that unit.
    VV : array, shape (N, >=3)
        Cartesian velocities, row-aligned with ``XX``; the plot range (+-125)
        and axis labels assume km/s.
    tensorsx, myspacex :
        Fitted tensors and the model object used for the "x correction" panel
        (``myspacex.get_model_v(tensorsx, v, x)``).
    tensorsxv, myspacexv :
        Same, for the "xv correction" panel.
    gs : int, optional
        Number of bins per axis in every 2-D histogram.

    Returns
    -------
    None
        Writes ``gaiao00.pdf`` ... ``gaiao35.pdf`` and ``orders.gif`` into the
        current working directory.
    """
    n_steps = 36      # number of wedge start angles (10 degrees apart)
    wedge_steps = 3   # each wedge spans three steps, i.e. 30 degrees

    # Cylindrical coordinates are derived from XX itself so the masks are always
    # row-aligned with the arrays passed in, rather than read from a notebook global.
    rho = np.hypot(XX[:, 0], XX[:, 1])
    phi = np.arctan2(XX[:, 1], XX[:, 0]) + np.pi  # shifted into [0, 2*pi]
    in_annulus = (rho < 500.) & (np.abs(XX[:, 2]) < 500.) & (rho > 200.)
    slow = np.linalg.norm(VV[:, :3], axis=1) < 200.
    rindx = slow & in_annulus

    vel_range = [[-125, 125], [-125, 125]]

    for i in range(n_steps):
        in_wedge = (phi > i*np.pi/18.) & (phi < (i+wedge_steps)*np.pi/18.)
        # Wedges that run past 2*pi continue from zero; this reproduces the
        # former hard-coded i == 34 / i == 35 special cases.
        overflow = i + wedge_steps - n_steps
        if overflow > 0:
            in_wedge |= (phi > 0.) & (phi < overflow*np.pi/18.)
        wedgedex = rindx & in_wedge
        print(wedgedex.sum(),'stars in wedge',i)

        fixx = myspacex.get_model_v(tensorsx, VV[wedgedex], XX[wedgedex])
        fixxv = myspacexv.get_model_v(tensorsxv, VV[wedgedex], XX[wedgedex])

        f, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 15))

        # Top-left: where the selected stars sit in the X-Y plane.
        ax1.hist2d(XX[:, 0][wedgedex], XX[:, 1][wedgedex],
                   range=[[-500, 500], [-500., 500.]], bins=gs,
                   cmin=1.0e-50, rasterized=True, density=True)
        # The data and the plot range are in pc, so the labels say pc (was kpc).
        ax1.set_xlabel(r'$X\ (\mathrm{pc})$', fontsize=20)
        ax1.set_ylabel(r'$Y\ (\mathrm{pc})$', fontsize=20)
        ax1.set_xlim(-500., 500.)
        ax1.set_ylim(-500., 500.)
        ax1.set_title(r'$\mathrm{Selected\ area}$', fontsize=20)

        # Remaining panels: the same v_X-v_Y view for raw and corrected velocities.
        velocity_panels = (
            (ax2, VV[wedgedex], r'$\mathrm{No\ correction}$'),
            (ax3, fixx, r'$\mathrm{x\ correction}$'),
            (ax4, fixxv, r'$\mathrm{xv\ correction}$'),
        )
        for ax, vel, title in velocity_panels:
            ax.hist2d(vel[:, 0], vel[:, 1], range=vel_range, bins=gs,
                      cmin=1.0e-50, rasterized=True, density=True)
            ax.set_xlabel(r'$v_X\ (\mathrm{km\ s}^{-1})$', fontsize=20)
            ax.set_ylabel(r'$v_Y\ (\mathrm{km\ s}^{-1})$', fontsize=20)
            ax.set_xlim(-125, 125)
            ax.set_ylim(-125, 125)
            ax.set_title(title, fontsize=20)

        for ax in (ax1, ax2, ax3, ax4):
            ax.tick_params(axis='both', which='major', labelsize=15)

        f.savefig(f'gaiao{i:02d}.pdf', bbox_inches='tight')
        plt.close(f)  # release the figure; 36 open figures would pile up otherwise

    # Assemble the frames with the external `convert` tool (presumably ImageMagick).
    # The shell glob picks up every gaiao*.pdf in the working directory.
    status = os.system('convert -delay 5 -loop 0 gaiao*.pdf orders.gif')
    if status != 0:
        print('convert exited with status', status, '- orders.gif may not have been written')

In [None]:
# Quick look at the first declination value with the degree unit attached.
print(gaia['dec'][0] * u.deg)

# Full 6-D coordinates for the selected rows.  Distance is the plain inverse of
# the parallax, interpreted as kpc (assumes parallax is in mas -- TODO confirm).
c = coord.SkyCoord(
    ra=gaia['ra'][qindx] * u.deg,
    dec=gaia['dec'][qindx] * u.deg,
    distance=1. / gaia['parallax'][qindx] * u.kpc,
    pm_ra_cosdec=gaia['pmra'][qindx] * u.mas / u.yr,
    pm_dec=gaia['pmdec'][qindx] * u.mas / u.yr,
    radial_velocity=gaia['radial_velocity'][qindx] * u.km / u.s,
)
# galcen = c.transform_to(coord.Galactocentric())

In [None]:
# Move to the Galactic frame and expose Cartesian components, so the position
# (u, v, w) and velocity (U, V, W) attributes used below are available.
gal = c.transform_to('galactic')
gal.set_representation_cls('cartesian')

In [None]:
# Positions in pc, one row per star, columns (u, v, w).
xyz = np.vstack([
    getattr(gal, component).to(u.pc).value
    for component in ('u', 'v', 'w')
]).T

# Velocities in km/s, one row per star, columns (U, V, W).
UVW = np.vstack([
    getattr(gal, component).to(u.km/u.s).value
    for component in ('U', 'V', 'W')
]).T

# Keep only stars whose total speed is below 200 km/s.
disk_vmask = np.linalg.norm(UVW, axis=1) < 200.

In [None]:
# Short aliases for the position (pc) and velocity (km/s) arrays.
XX, VV = xyz, UVW
dist2 = np.sqrt(XX[:, 0] ** 2 + XX[:, 1] ** 2)

# Nested cylindrical volumes around the origin, from widest to tightest.
_cyl = gal.represent_as('cylindrical')
mask_r500 = (np.abs(_cyl.z) < 500 * u.pc) & (_cyl.rho < 500 * u.pc)
mask_r300 = (np.abs(_cyl.z) < 500 * u.pc) & (_cyl.rho < 300 * u.pc)
mask_r100 = (np.abs(_cyl.z) < 150 * u.pc) & (_cyl.rho < 100 * u.pc)
mask_r100.sum(), mask_r500.sum()
#local_mask=(dist2<0.2)*(np.fabs(XX[:,2])<0.2)
#train_mask=(dist2>0.2)*(dist2<0.5)*(np.fabs(XX[:,2])<0.5)

# Innermost volume, restricted to the speed cut: the sample the GMM is fitted to.
local_x = xyz[mask_r100 & disk_vmask]
local_v = UVW[mask_r100 & disk_vmask]

print(local_v.shape)

In [None]:
# Fit a 64-component Gaussian mixture to the velocities of the innermost sample.
# The seed is pinned: the mixture's initialisation is random, and without a fixed
# random_state the fitted mixture -- and everything derived from it below --
# would change on every run of the notebook.
local_gmm = GaussianMixture(n_components=64, random_state=42)
local_gmm.fit(local_v)

In [None]:
# Draw from the fitted mixture and view the samples in two velocity projections:
# column 0 against column 1 (left) and column 0 against column 2 (right).
local_gmm_samples, _ = local_gmm.sample(100_000)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
for _panel, _ycol in ((ax1, 1), (ax2, 2)):
    _panel.hist2d(
        local_gmm_samples[:, 0],
        local_gmm_samples[:, _ycol],
        bins=np.linspace(-150, 150, 256),
    )
fig.tight_layout()

In [None]:
# Two models built on the same local mixture, differing only in the term list:
# `myspace` uses ['x'], `myspacexv` uses ['x', 'xv'].
myspace, myspacexv = (
    MySpace(local_gmm, terms=term_list)
    for term_list in (['x'], ['x', 'xv'])
)

In [None]:
# Training sample: speed cut plus the 300 pc cylinder.
train_v, train_x = (arr[disk_vmask & mask_r300] for arr in (UVW, xyz))

# Evaluation sample: speed cut plus the 500 pc cylinder.  The 300 pc cylinder
# lies inside the 500 pc one, so this sample contains the training stars.
test_v, test_x = (arr[disk_vmask & mask_r500] for arr in (UVW, xyz))

local_v.shape, train_v.shape, test_v.shape

In [None]:
# Fit both models to the training sample (the 'x' model first, then 'x'+'xv'),
# then print the fitted tensors of each.
(res, tensors), (resxv, tensorsxv) = (
    model.fit(train_x, train_v)
    for model in (myspace, myspacexv)
)
print(tensors, tensorsxv, sep='\n')

In [None]:
# Apply each fitted model to the evaluation sample.
# NOTE(review): get_model_v is called with (tensors, v, x) while fit takes (x, v);
# confirm that argument order against the MySpace API.
fixx, fixxv = (
    model.get_model_v(fitted, test_v, test_x)
    for model, fitted in zip((myspace, myspacexv), (tensors, tensorsxv))
)

In [None]:
# Side-by-side vx-vy hexbin maps of the evaluation sample: raw velocities,
# then the output of the 'x' model, then the output of the 'x'+'xv' model.
gs=200
f, ((ax1,ax2,ax3)) = plt.subplots(1, 3, figsize=(15,5))

_panels = (
    (ax1, test_v, 'Uncorrected'),
    (ax2, fixx, 'x corrected'),
    (ax3, fixxv, 'xv corrected'),
)
for _ax, _vel, _title in _panels:
    _ax.hexbin(_vel[:,0], _vel[:,1], extent=[-125,125,-125,125],
               mincnt=1, rasterized=True, gridsize=gs)
    _ax.set_title(_title, fontsize=20)
    _ax.set_xlabel('vx (km/s)', fontsize=20)
    _ax.set_xlim(-125,125)
    _ax.set_ylim(-125,125)

# Only the left-most panel carries the y-axis label.
ax1.set_ylabel('vy (km/s)', fontsize=20)
plt.show()

In [None]:
# Render the per-wedge frames and the GIF for the full position/velocity arrays,
# using the fitted 'x' model and the fitted 'x'+'xv' model.
make_anim2(
    XX,
    VV,
    tensorsx=tensors,
    myspacex=myspace,
    tensorsxv=tensorsxv,
    myspacexv=myspacexv,
)