In [180]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import astropy.units as u
import astropy.coordinates as coord
from astropy.table import Table
from astropy.io import ascii

In [351]:
#Initiation of data

def load_data(file,flagset):
    """Read a CSV catalogue and drop every row raised by a chosen set of flags.

    Parameters
    ----------
    file : str
        Path passed to ``astropy.io.ascii.read`` with ``format='fast_csv'``.
    flagset : str
        Which flag columns to apply:

        * ``'flag_dup'``          -> ``flag_dup``
        * ``'flag_any'``          -> ``flag_any``
        * ``'flag_any-lowlogg'``  -> ``flag_dup``, ``flag_N``, ``flag_outlier``,
          ``flag_pole``
        * ``'None'`` (the string) -> no filtering
        * ``'custom'``            -> prompts on stdin for a comma-separated
          list of column names

    Returns
    -------
    data : astropy.table.Table
        Copy of the catalogue with all rows removed for which any selected
        flag column is non-zero.
    data_raw : astropy.table.Table
        The catalogue exactly as read from ``file``.

    Raises
    ------
    ValueError
        If ``flagset`` is not one of the recognised names.
    """
    presets = {
        'flag_dup': ['flag_dup'],
        'flag_any': ['flag_any'],
        'flag_any-lowlogg': ['flag_dup','flag_N','flag_outlier','flag_pole'],
        'None': [],
    }

    # Resolve the flag list before touching the file so that a typo in
    # `flagset` fails immediately instead of after the catalogue is parsed.
    if isinstance(flagset, str) and flagset == 'custom':
        answer = str(input('Tell me which flags to remove: '))
        # Tolerate "a, b" style input and ignore empty entries, which would
        # otherwise be looked up as (non-existent) column names.
        flag_list = [name.strip() for name in answer.split(',') if name.strip()]
    elif isinstance(flagset, str) and flagset in presets:
        flag_list = presets[flagset]
    else:
        raise ValueError('Not a valid flagset')

    data_raw = ascii.read(file, format='fast_csv')
    data = Table(data_raw, copy=True)

    # Collect the indices of all flagged rows as integers from the start;
    # union1d keeps the result sorted and free of duplicates.
    bad_rows = np.array([], dtype=int)
    for name in flag_list:
        bad_rows = np.union1d(bad_rows, np.nonzero(data[name])[0])

    data.remove_rows(bad_rows)

    return data,data_raw


flagset = 'flag_any'

# Number of leading catalogue rows used for this test run.
n_rows = 10000

# Load the catalogue only once per kernel session; re-running this cell reuses
# the tables already in memory. NOTE: changing `flagset` therefore has no
# effect until `data_clean` is deleted or the kernel is restarted.
try:
    data_clean
except NameError:

    data_clean,data_raw = load_data('Distances_PJM2017.csv',flagset)

# Slice the flag-filtered table. The previous version sliced `data_raw`,
# which silently discarded the filtering requested through `flagset`.
data = data_clean[0:n_rows]

RA = data['RAdeg']*u.degree
DEC = data['DEdeg']*u.degree
pm_RA = data['pmRA_TGAS']*u.mas/u.yr
pm_DEC = data['pmDE_TGAS']*u.mas/u.yr

# Build ICRS coordinates with proper motions and convert them to the
# Galactic frame, which provides l, b, pm_l_cosb and pm_b.
sample = coord.ICRS(ra = RA, dec = DEC, pm_ra_cosdec = pm_RA, pm_dec = pm_DEC)

sample = sample.transform_to(coord.Galactic)

# NOTE(review): A and B look like Oort constants; they are subtracted below
# directly from proper motions in mas/yr -- confirm the units are consistent.
A = 15.3
B = -11.9

mul_obs = sample.pm_l_cosb.value
mub_obs = sample.pm_b.value

bvals = sample.b.value
lvals = sample.l.value

b = np.deg2rad(bvals)
l = np.deg2rad(lvals)
cosl = np.cos(l)
cosb = np.cos(b)
sinl = np.sin(l)
sinb = np.sin(b)

# NOTE(review): `s` scales proper motion into the velocity-like vector below;
# a single constant is used for every star -- presumably a placeholder
# distance/conversion factor, confirm.
s = 0.1

# Proper motions with the A/B terms removed.
mul = mul_obs - A*np.cos(2*l)-B
mub = mub_obs + A*np.sin(2*l)*cosb*sinb

# Cartesian components built from (mul, mub), one column per star.
pvals = s*np.array([-sinl*cosb*mul - cosl*sinb*mub,
                 cosl*cosb*mul - sinl*sinb*mub,
                 cosb*mub])

# One row per star: line-of-sight unit vectors and tangential vectors.
rhatvals = np.array([cosb*cosl, cosb*sinl, sinb]).T
pvals = pvals.T

#Some test-values

# Lower corner of the velocity grid, number of cells and cell width per axis.
vmin = np.array([-200,-200,-200])

n = np.array([8,8,8])

dv = np.array([50,50,50])

In [343]:
def calc_K(pk,rhat,vmin,dv,n):
    """Calculate the values of K simultaneously for all bins.

    For the straight line ``v = pk + vr*rhat`` this returns, for every cell of
    a regular 3-D grid, the length of the line inside that cell divided by the
    cell volume ``dvx*dvy*dvz``.

    Parameters
    ----------
    pk : array_like, shape (3,)
        A point on the line.
    rhat : array_like, shape (3,)
        Direction of the line; must have unit length.
    vmin : array_like, shape (3,)
        Lower corner of the grid.
    dv : array_like, shape (3,)
        Cell width along each axis.
    n : array_like of int, shape (3,)
        Number of cells along each axis.

    Returns
    -------
    K : ndarray, shape (nx, ny, nz)
        Zero everywhere except in the cells crossed by the line. All zeros if
        the line does not pass through the grid.

    Raises
    ------
    ValueError
        If ``rhat`` is not a unit vector.
    """
    pk = np.asarray(pk, dtype=float)
    rhat = np.asarray(rhat, dtype=float)
    vmin = np.asarray(vmin, dtype=float)
    dv = np.asarray(dv, dtype=float)
    n_cells = np.array([int(k) for k in n])

    K = np.zeros(tuple(n_cells))

    # np.round(norm) != 1 would accept any norm in roughly [0.5, 1.5).
    if not np.isclose(np.linalg.norm(rhat), 1.0):
        raise ValueError('rhat must be a unit vector')

    vmax = vmin + n_cells*dv

    # Parameter range [vr_lo, vr_hi] for which the line is inside the grid,
    # and the parameter values at which it crosses any cell face.
    vr_lo, vr_hi = -np.inf, np.inf
    crossings = []

    for axis in range(3):
        if rhat[axis] == 0:
            # The line runs parallel to this axis' faces: it never crosses
            # them, and it is inside the grid along this axis for all vr or
            # for none.
            if not (vmin[axis] <= pk[axis] < vmax[axis]):
                return K
            continue

        # linspace guarantees exactly n+1 edges even for non-integer steps.
        edges = np.linspace(vmin[axis], vmax[axis], n_cells[axis]+1)
        vr_axis = (edges - pk[axis]) / rhat[axis]

        vr_lo = max(vr_lo, vr_axis.min())
        vr_hi = min(vr_hi, vr_axis.max())
        crossings.append(vr_axis)

    # The line misses the grid (or only touches a corner/edge).
    if vr_hi <= vr_lo:
        return K

    vr = np.concatenate(crossings)
    vr = np.sort(vr[(vr >= vr_lo) & (vr <= vr_hi)])

    # Consecutive crossings bound one segment, which lies in a single cell.
    # Zero-length segments appear when the line passes through an edge or
    # corner (the same vr on several axes) and are dropped.
    line_len = vr[1:] - vr[:-1]
    keep = line_len > 0
    line_len = line_len[keep]
    vr_mid = ((vr[:-1] + vr[1:]) / 2)[keep]

    # The segment midpoint identifies the cell the segment belongs to.
    v_mid = pk + vr_mid[:, None]*rhat
    line_bins = np.floor((v_mid - vmin)/dv).astype(int)
    # Guard against round-off pushing a midpoint index just outside the grid.
    line_bins = np.clip(line_bins, 0, n_cells - 1)

    # Accumulate rather than assign so that two segments landing in the same
    # cell (possible through round-off) add up instead of overwriting.
    np.add.at(K, (line_bins[:,0], line_bins[:,1], line_bins[:,2]),
              line_len/np.prod(dv))

    return K

In [344]:
def calc_sigma2(rhat,pvals):
    """Estimate per-axis velocity dispersions from tangential vectors.

    Parameters
    ----------
    rhat : array_like, shape (N, 3)
        Unit line-of-sight vectors, one row per star.
    pvals : array_like, shape (N, 3)
        Tangential vectors, one row per star.

    Returns
    -------
    sigma2 : ndarray, shape (3,)
        ``(3/14) * M @ <p'^2>`` with ``M = [[9,-1,-1],[-1,9,-1],[-1,-1,9]]``
        and ``p' = p - (I - rhat rhat^T) @ <v>``.

    Raises
    ------
    numpy.linalg.LinAlgError
        If the sample-averaged projection matrix is singular, e.g. when all
        stars share the same line of sight.

    Notes
    -----
    The mean velocity ``<v>`` is obtained from ``<A> <v> = <p>`` where
    ``A = I - rhat rhat^T``. Each individual ``A`` is singular (it maps
    ``rhat`` to zero), so it must not be inverted star by star; only the
    average over many directions is invertible.
    """
    rhat = np.asarray(rhat, dtype=float)
    pvals = np.asarray(pvals, dtype=float)

    pmean = np.mean(pvals, axis=0)

    # Projection matrices I - rhat rhat^T, shape (N, 3, 3).
    proj = np.identity(3) - rhat[:,:,None]*rhat[:,None,:]

    # Solve <A> <v> = <p> for the mean velocity.
    proj_mean = np.mean(proj, axis=0)
    v_mean = np.linalg.solve(proj_mean, pmean)

    # Remove the projected mean motion from every star: shape (N, 3).
    pp = pvals - np.dot(proj, v_mean)

    pp2mean = np.mean(pp*pp, axis=0)

    # NOTE(review): this mixing matrix and the 3/14 prefactor are taken over
    # unchanged from the original code -- confirm against the derivation.
    mix = np.array([[9,-1,-1],[-1,9,-1],[-1,-1,9]])

    sigma2 = (3/14)*np.dot(mix,pp2mean)

    return sigma2

In [345]:
def nl_delta(n,l):

    """Return the finite-difference weight linking cell n to cell l.

    -2 when n and l are the same cell, 1 when n is exactly one step away from
    l along a single axis, and 0 for every other pair.
    """

    # Same cell.
    if np.array_equal(n,l):
        return -2

    # Try the six face neighbours of l: one step forwards or backwards
    # along each of the three axes.
    for step in np.identity(3, dtype=int):
        for direction in (1, -1):
            if np.array_equal(n, l + direction*step):
                return 1

    return 0
        
def calc_xhi(line_bins,sigma2,hx,hy,hz,nx,ny,nz):
    """Second-difference weights for each cell and its six face neighbours.

    Parameters
    ----------
    line_bins : array_like of int, shape (M, 3)
        Cell indices ``l``.
    sigma2 : array_like, shape (3,)
        Per-axis values divided by the squared cell widths to form the weight.
    hx, hy, hz : float
        Cell widths.
    nx, ny, nz : int
        Number of cells along each axis; valid indices are ``0 .. n-1``.

    Returns
    -------
    xhi : ndarray, shape (M, 7)
        Column order is ``l, l+e_x, l-e_x, l+e_y, l-e_y, l+e_z, l-e_z``.
        The entry is ``sum(sigma2/h**2)`` times -2 for the cell itself and
        times 1 for a neighbour (the values ``nl_delta`` returns for these
        pairs); entries whose cell lies outside the grid are 0.
    """
    h2 = np.array([hx**2,hy**2,hz**2])
    n_bins = np.array([nx,ny,nz])

    cells = np.asarray(line_bins).reshape(-1, 3)

    offsets = np.array([[ 0, 0, 0],
                        [ 1, 0, 0], [-1, 0, 0],
                        [ 0, 1, 0], [ 0,-1, 0],
                        [ 0, 0, 1], [ 0, 0,-1]])
    delta = np.array([-2, 1, 1, 1, 1, 1, 1])

    # All seven candidate cells for every input cell: shape (M, 7, 3).
    candidates = cells[:, None, :] + offsets[None, :, :]

    # Index n_bins itself is already outside the grid, hence the strict '<'.
    inside = np.all((candidates >= 0) & (candidates < n_bins), axis=2)

    weight = np.sum(np.asarray(sigma2)/h2)

    xhi = np.where(inside, weight*delta, 0.0)

    return xhi

In [346]:
def sec_der(phi,sigma2,dv):

    """Estimates the second derivative for ln(f(v_l)) given a sample of stars.

    For every cell this evaluates the weighted three-point second difference

        sum_i (sigma2_i / dv_i**2) * (phi[l - e_i] + phi[l + e_i] - 2*phi[l])

    treating cells outside the grid as phi = 0.

    Parameters
    ----------
    phi : ndarray, shape (nx, ny, nz)
        Values on the grid.
    sigma2 : array_like, shape (3,)
        Per-axis weights.
    dv : array_like, shape (3,)
        Cell widths per axis.

    Returns
    -------
    ndarray, shape (nx, ny, nz)
    """

    phi = np.asarray(phi, dtype=float)

    # Cell widths come from the `dv` argument (the original unpacked an
    # undefined name `h` here).
    h2 = np.asarray(dv, dtype=float)**2
    kappa = np.asarray(sigma2, dtype=float)/h2

    # Zero-padded copy so that every cell has two neighbours on each axis.
    phip = np.zeros(tuple(size + 2 for size in phi.shape))
    phip[1:-1,1:-1,1:-1] = phi

    # Sum of the two neighbours along each axis.
    pair_x = phip[:-2,1:-1,1:-1] + phip[2:,1:-1,1:-1]
    pair_y = phip[1:-1,:-2,1:-1] + phip[1:-1,2:,1:-1]
    pair_z = phip[1:-1,1:-1,:-2] + phip[1:-1,1:-1,2:]

    phi_arr = kappa[0]*pair_x + kappa[1]*pair_y + kappa[2]*pair_z - 2*np.sum(kappa)*phi

    return phi_arr

In [347]:
def phi_guess(vmin,dv,n):
    """Build an initial grid of phi values from one Gaussian profile per axis.

    Each axis gets a normalised 1-D Gaussian evaluated at the cell centres,
    with mean ``thin0[axis]`` and standard deviation ``thin_disp[axis]``.
    The returned grid is the sum of the three profiles,

        phi[i, j, k] = gx[i] + gy[j] + gz[k].

    Parameters
    ----------
    vmin : array_like, shape (3,)
        Lower corner of the grid.
    dv : array_like, shape (3,)
        Cell width per axis.
    n : array_like of int, shape (3,)
        Number of cells per axis (need not be equal).

    Returns
    -------
    phi : ndarray, shape (nx, ny, nz)
    """
    vmin = np.asarray(vmin, dtype=float)
    dv = np.asarray(dv, dtype=float)
    n_cells = [int(k) for k in n]

    # Means and dispersions of the single Gaussian component used here.
    thin0 = np.array([0,215,0])
    thin_disp  = np.array([30,20,17])

    profiles = []
    for axis in range(3):
        # n+1 cell edges, then the n cell centres between them.
        edges = vmin[axis] + dv[axis]*np.arange(n_cells[axis] + 1)
        centres = (edges[1:] + edges[:-1])/2

        mean = thin0[axis]
        disp = thin_disp[axis]
        profiles.append(np.exp(-((centres - mean)**2)/(2*disp**2))
                        / (disp*np.sqrt(2*np.pi)))

    gx, gy, gz = profiles

    # Broadcast so that axis 0 follows x, axis 1 follows y and axis 2
    # follows z, matching the (nx, ny, nz) layout used by calc_K.
    # NOTE(review): the profiles are *added*, as in the original code. Since
    # get_L exponentiates phi, the log of their product may have been
    # intended -- confirm.
    phi = gx[:, None, None] + gy[None, :, None] + gz[None, None, :]

    return phi

In [348]:
def get_L(alpha,pvals,rhat,vmn,dv,n):

    """Test to obtain our L_tilde function. Call each function for a given sample -> get L.

    Parameters
    ----------
    alpha : float
        Weight of the smoothing (second-derivative) penalty term.
    pvals : ndarray, shape (N, 3)
        Tangential vectors, one row per star.
    rhat : ndarray, shape (N, 3)
        Unit line-of-sight vectors, one row per star.
    vmn : array_like, shape (3,)
        Lower corner of the velocity grid.
    dv : array_like, shape (3,)
        Cell width per axis.
    n : array_like of int, shape (3,)
        Number of cells per axis.

    Returns
    -------
    L : float
    """

    dvx, dvy, dvz = dv

    N = len(pvals)

    # Use the grid passed in as `vmn`; the original body read a global
    # called `vmin` and silently ignored this argument.
    phi = phi_guess(vmn,dv,n)

    exphi = np.exp(phi)

    sigma2 = calc_sigma2(rhat,pvals)

    phixhi = sec_der(phi,sigma2,dv)

    # Sum of exp(phi)*K over all cells and all stars.
    K_sum = 0

    for i in range(N):
        K = calc_K(pvals[i],rhat[i],vmn,dv,n)
        K_sum += np.sum(exphi*K)

    # NOTE(review): the logarithm is taken of the total over all stars and
    # then divided by N; a per-star log summed over stars may have been
    # intended. The sign of the sum(exphi) term should also be confirmed.
    L = np.log(K_sum)/N + np.sum(exphi)-(alpha/(2*dvx*dvy*dvz))*np.sum(phixhi**2)

    return L

In [352]:
#%timeit get_L(1,pvals,rhatvals,vmin,dv,n)

2.62 s ± 9.56 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [189]:
#def max_L(alpha,*args):
    
    
    
    #TODO: make an initial guess of alpha and phi (write a new phi function)
    #TODO: maximize L over these two parameters for a given sample of stars
    #prosper