In [1]:
import numpy as np
import numba as nb
import twopoint_counter as tpc
from matplotlib import pyplot as plt
from numba import njit, prange
import math

In [2]:
from scipy.interpolate import interp1d

def get_nbar(z, zbins, cosmo, fsky=1):

    
    rbins = cosmo.comoving_distance(zbins) * cosmo.h
    vol_shell = fsky * 4*np.pi/3 * (rbins[1:]**3-rbins[:-1]**3)
    counts, _ = np.histogram(z, bins=zbins)
    density = counts/vol_shell
    
    nbar = {'zbins': zbins, 'rbins': rbins, 'vol_shell': vol_shell, 'counts': counts, 'density': density, 'fsky': fsky}
    return nbar

def assign_nbar(z, zbins, density): 
    dens = interp1d(zbins[:-1], density, kind='zero', fill_value='extrapolate')
    return dens(z)

In [2]:
### PV catalog
NPv = int(100000)
z = np.random.rand(NPv)*0.10
pos = np.random.random((NPv,3)) * 2000 -1000
v = np.random.randn(NPv)*400

dat_vpos = pos[np.random.choice(range(pos.shape[0]), size=100, replace=False)]
ran_vpos = pos[np.random.choice(range(pos.shape[0]), size=500, replace=False)]

dat_v = np.random.choice(v,100)
ran_v = np.random.choice(v,500)

dat_vn = np.ones(100)*1e-5
ran_vn = np.ones(500)*1e-5

dat_alpha = np.ones(100)
ran_alpha = np.ones(500)
dat_vsig = np.zeros(100)
ran_vsig = np.zeros(500)
dat_vamp = ran_vamp = 100000

In [3]:
### Den catalog
Ngal = int(100000)
pos = np.random.random((Ngal,3)) * 2000 -1000

dat_dpos = pos[np.random.choice(range(pos.shape[0]), size=10000, replace=False)]
ran_dpos = pos[np.random.choice(range(pos.shape[0]), size=50000, replace=False)]

dat_dn = np.ones(10000)*1e-5
ran_dn = np.ones(10000)*1e-5
dat_damp = ran_damp = 100000

In [3]:
s_edges=np.linspace(10,120,20)
bins_s = (np.array(s_edges[1:]) + np.array(s_edges[:-1])) / 2
mu_edges = np.linspace(-1,1,100)
bins_mu = (np.array(mu_edges[1:]) + np.array(mu_edges[:-1])) / 2
edges = (s_edges,mu_edges)

In [13]:
engine = tpc.CorrelationFunctionPV(s_edges,
                                   ran_vpos,ran_vn,ran_v,ran_vsig,ran_alpha,ran_vamp,
                                   ran_vpos,ran_vn,ran_v,ran_vsig,ran_alpha,ran_vamp
                                   )


In [15]:
engine.RR(mode='psi1')

IndexError: index 1 is out of bounds for axis 0 with size 1

In [37]:
#@njit(parallel=False,cache=True)
def psi1_auto(X,nbar,v,v_sig,alpha,v_amp,bin_s_min,bin_s_max,n_s_bin):
    # return psi1_n,psi1_n,psi2_d,psi2_d

    num = np.zeros((n_s_bin))
    den = np.zeros((n_s_bin))
    eta = v/alpha
    sum_w = []
    for i in prange(X.shape[0]):
        sum_w +=  1 / (alpha[i]*v_amp*nbar[i] + v_sig[i]**2/alpha[i])
        for j in prange(i+1,X.shape[0]):
            
            deltax = X[i][0] - X[j][0]
            deltay = X[i][1] - X[j][1]
            deltaz = X[i][2] - X[j][2]
            #Norms of r, ra, rb vectors
            norm_r = math.sqrt((deltax)**2 + (deltay)**2 + (deltaz)**2)
            norm_ri = math.sqrt((X[i][0])**2 + (X[i][1])**2 + (X[i][2])**2)
            norm_rj = math.sqrt((X[j][0])**2 + (X[j][1])**2 + (X[j][2])**2) 
            #Geometry calculations
            cosAB = (X[i][0]/norm_ri * X[j][0]/norm_rj) + (X[i][1]/norm_ri * X[j][1]/norm_rj) + (X[i][2]/norm_ri * X[j][2]/norm_rj)
            #weight
            w_i = 1 / (alpha[i]*v_amp*nbar[i] + v_sig[i]**2/alpha[i])
            w_j = 1 / (alpha[j]*v_amp*nbar[j] + v_sig[j]**2/alpha[j])

            bin_s_index = (norm_r-bin_s_min)/(bin_s_max-bin_s_min)*n_s_bin
            if n_s_bin>bin_s_index>=0 :
                num[int(bin_s_index)] += w_i*w_j * cosAB * eta[i]*eta[j]
                den[int(bin_s_index)] += w_i*w_j * cosAB** 2 

    return num,den,sum_w
def psi1_cross(X1,X2,nbar1,nbar2,v1,v2,v1_sig,v2_sig,alpha1,alpha2,v1_amp,v2_amp,bin_s_min,bin_s_max,n_s_bin):
    # return psi1_n,psi1_n,psi2_d,psi2_d

    num = np.zeros((n_s_bin))
    den = np.zeros((n_s_bin))
    eta1 = v1/alpha1
    eta2 = v2/alpha2
    sum_w = []
    for i in prange(len(X1)):
        sum_w +=  1 / (alpha1[i]*v1_amp*nbar1[i] + v1_sig[i]**2/alpha1[i])
        for j in prange(len(X2)):
            
            deltax = X1[i][0] - X2[j][0]
            deltay = X1[i][1] - X2[j][1]
            deltaz = X1[i][2] - X2[j][2]
            #Norms of r, ra, rb vectors
            norm_r = math.sqrt((deltax)**2 + (deltay)**2 + (deltaz)**2)
            norm_ri = math.sqrt((X1[i][0])**2 + (X1[i][1])**2 + (X1[i][2])**2)
            norm_rj = math.sqrt((X2[j][0])**2 + (X2[j][1])**2 + (X2[j][2])**2) 
            #Geometry calculations
            cosAB = (X1[i][0]/norm_ri * X2[j][0]/norm_rj) + (X1[i][1]/norm_ri * X1[j][1]/norm_rj) + (X1[i][2]/norm_ri * X2[j][2]/norm_rj)
            #weight
            w_i = 1 / (alpha1[i]*v1_amp*nbar1[i] + v1_sig[i]**2/alpha1[i])
            w_j = 1 / (alpha2[j]*v2_amp*nbar2[j] + v2_sig[j]**2/alpha2[j])

            bin_s_index = (norm_r-bin_s_min)/(bin_s_max-bin_s_min)*n_s_bin
            if n_s_bin>bin_s_index>=0 :
                num[int(bin_s_index)] += w_i*w_j * cosAB * eta1[i]*eta2[j]
                den[int(bin_s_index)] += w_i*w_j * cosAB** 2 

    return num,den,sum_w


In [30]:
bin_s_min = s_edges[0]
bin_s_max = s_edges[-1]
n_s_bin = len(s_edges)-1

In [31]:
count = psi1_auto(dat_vpos,dat_vn,dat_v,dat_vsig,dat_alpha,dat_vamp,bin_s_min,bin_s_max,n_s_bin)

In [38]:
#@njit(parallel=True,cache=True)
def psi1_counter(X,nbar,v,v_sig,alpha,v_amp,bins_s,n_cells,indices_cells,L_cell_pair,L_cells):
    """
    Make sure numba has been initialized with the correct number 
    of threads : numba.set_num_threads(Nthread) 
    Note : dim of bins depends on the mode
    """

    bin_s_min = bins_s[0]
    bin_s_max = bins_s[-1]
    n_s_bin = len(bins_s)-1

    num = np.zeros((n_s_bin))
    den = np.zeros((n_s_bin))
    sum_w = []
    #--Count pairs within the same cell:
    indices_cells_flat = indices_cells[:,0]*(n_cells[1]*n_cells[2]) + indices_cells[:,1]*(n_cells[2]) +indices_cells[:,2]
    
    for i_cell in prange(L_cells.shape[0]):
        w = (indices_cells_flat==i_cell)
        count = psi1_auto(X[w],nbar[w],v[w],v_sig[w],alpha[w],v_amp,bin_s_min,bin_s_max,n_s_bin)
        num += count[0]
        den += count[1]
        sum_w += count[2]
    #--Count pairs between pair of cells       
    L_cell_pair_flat_1 = L_cell_pair[:,0]*(n_cells[1]*n_cells[2]) + L_cell_pair[:,1]*(n_cells[2]) +L_cell_pair[:,2]      
    L_cell_pair_flat_2 = L_cell_pair[:,3]*(n_cells[1]*n_cells[2]) + L_cell_pair[:,4]*(n_cells[2]) +L_cell_pair[:,5]      
                             
    for i_pair in prange(L_cell_pair.shape[0]):
        #-- acces the positions in one cell-pair :
        w1 = (indices_cells_flat==L_cell_pair_flat_1[i_pair])
        w2 = (indices_cells_flat==L_cell_pair_flat_2[i_pair])

        count = psi1_cross(X[w1],X[w2],nbar[w1],nbar[w2],v[w1],v[w2],v_sig[w1],v_sig[w2],alpha[w1],alpha[w2],v_amp,v_amp,bin_s_min,bin_s_max,n_s_bin)        
        num += count[0]
        den += count[1]
        sum_w += count[2]
        
    return num,den,sum_w

In [39]:
@njit(cache=True)
def get_cell_pairs(n_cells):

    #--building every possible cells:
    L_cells=np.empty((n_cells[0]*n_cells[1]*n_cells[2],3))
    i=0
    for a in prange(n_cells[0]): 
        for b in range(n_cells[1]): 
            for c in range(n_cells[2]):
                L_cells[i] = [a,b,c]
                i+=1

    L_cell_pair=[]
    for i in prange(L_cells.shape[0]):
        for j in range(i+1,L_cells.shape[0]):
            dist = L_cells[i]-L_cells[j]
            if -1<=dist[0]<=1 and -1<=dist[1]<=1 and -1<=dist[2]<=1:
                L_cell_pair.append(list(L_cells[i])+list(L_cells[j]))
    return np.array(L_cell_pair),L_cells

In [54]:
n_cells = np.int64(np.floor((np.max(ran_vpos,axis=0)-np.min(ran_vpos,axis=0))/ s_edges[-1]))
indices_cells_dat = np.int64(np.floor(n_cells * (dat_vpos-np.min(ran_dpos,axis=0))/(np.max(ran_vpos,axis=0)-np.min(ran_vpos,axis=0))))
indices_cells_ran = np.int64(np.floor(n_cells * (ran_vpos-np.min(ran_dpos,axis=0))/(np.max(ran_vpos,axis=0)-np.min(ran_vpos,axis=0))))
    
L_cell_pair,L_cells = get_cell_pairs(n_cells)
L_cells = np.int64(L_cells)
L_cell_pair = np.int64(L_cell_pair)


In [55]:
a = psi1_counter(ran_vpos,ran_vn,ran_v,ran_vsig,ran_alpha,ran_vamp,
                             s_edges,n_cells,indices_cells_ran,L_cell_pair,L_cells)

IndexError: index 1 is out of bounds for axis 0 with size 1

In [50]:
a[0]

array([      0.        ,       0.        ,       0.        ,
             0.        ,       0.        ,  -77712.29192056,
             0.        ,       0.        ,       0.        ,
             0.        , -222170.36631575,       0.        ,
             0.        ,       0.        ,       0.        ,
             0.        ,       0.        ,  183765.09192653,
             0.        ])

In [58]:
x = np.zeros((2,))

In [59]:
print(x.shape)

(2,)


In [63]:
print(x[1,1])

IndexError: too many indices for array: array is 1-dimensional, but 2 were indexed