In [3]:
import numpy as np
from numba import jit

@jit
def count_weighted_pairs_3d(x1, y1, z1, w1, x2, y2, z2, w2, rbins_squared, result):
    """Accumulate weighted pair counts between two 3-D point sets into cumulative bins.

    For every pair (i, j), with i from sample 1 and j from sample 2, the squared
    Euclidean separation ``dsq`` is computed (plain distance, no periodic
    wrapping). The weight product ``w1[i] * w2[j]`` is added to ``result[k - 1]``
    for each edge index ``k`` in ``1 .. nbins - 1`` with
    ``dsq <= rbins_squared[k]``. Edges are scanned from the outermost inward,
    and the scan stops at the first edge that ``dsq`` exceeds. The result is
    therefore the full cumulative count only if ``rbins_squared`` is sorted
    ascending (assumed, not checked).

    Parameters
    ----------
    x1, y1, z1, w1 : 1-D arrays
        Coordinates and weights of sample 1, indexed in lockstep over
        ``range(len(x1))`` (expected to have equal length).
    x2, y2, z2, w2 : 1-D arrays
        Coordinates and weights of sample 2, indexed in lockstep over
        ``range(len(x2))`` (expected to have equal length).
    rbins_squared : 1-D array
        ``nbins`` squared bin edges. ``rbins_squared[0]`` is never compared
        against. With fewer than two edges nothing is counted.
    result : 1-D array, modified in place
        Accumulator with at least ``nbins - 1`` entries. Existing contents are
        added to, not reset. Entries at index ``nbins - 1`` and beyond are never
        touched.

    Returns
    -------
    result
        The accumulator passed in, returned for convenience.

    Raises
    ------
    ValueError
        If ``result`` has fewer than ``nbins - 1`` entries.
    """
    n1 = x1.shape[0]
    n2 = x2.shape[0]
    nbins = rbins_squared.shape[0]

    # The loop below writes result[0] .. result[nbins - 2]. Compiled numba code
    # does not bounds-check indexing by default, so a too-short accumulator
    # would corrupt memory instead of raising. Check once, up front.
    if result.shape[0] < nbins - 1:
        raise ValueError("result must have at least len(rbins_squared) - 1 elements")

    for i in range(n1):
        px = x1[i]
        py = y1[i]
        pz = z1[i]
        pw = w1[i]
        for j in range(n2):
            dx = px - x2[j]
            dy = py - y2[j]
            dz = pz - z2[j]
            dsq = dx * dx + dy * dy + dz * dz
            wprod = pw * w2[j]

            # Walk edges from the outermost inward, crediting result[k - 1] for
            # every edge the pair lies within. Testing `k > 0` before indexing
            # keeps k in 1..nbins-1, so nbins < 2 can never produce the
            # rbins_squared[-1] / result[-1] accesses.
            k = nbins - 1
            while k > 0 and dsq <= rbins_squared[k]:
                result[k - 1] += wprod
                k -= 1

    return result


In [4]:
# Seed a dedicated generator so Restart & Run All regenerates the same points,
# and hence the same pair counts shown below.
SEED = 42
rng = np.random.default_rng(SEED)

n1, n2 = 1000, 1000
Lbox = 1000.

# Sample 1: positions uniform in [0, Lbox)^3, weights uniform in [0, 1).
x1 = rng.uniform(0, Lbox, size=n1)
y1 = rng.uniform(0, Lbox, size=n1)
z1 = rng.uniform(0, Lbox, size=n1)
w1 = rng.uniform(0, 1, size=n1)

# Sample 2: drawn the same way, independently of sample 1.
x2 = rng.uniform(0, Lbox, size=n2)
y2 = rng.uniform(0, Lbox, size=n2)
z2 = rng.uniform(0, Lbox, size=n2)
w2 = rng.uniform(0, 1, size=n2)

# nbins log-spaced, ascending bin edges from rmin to rmax. The kernel compares
# squared distances, so pass it squared edges.
nbins = 20
rmin, rmax = 0.1, 40
rbins = np.logspace(np.log10(rmin), np.log10(rmax), nbins)
rbins_squared = rbins**2

# One accumulator slot per edge. The kernel only writes indices
# 0 .. nbins - 2, so the final slot stays 0.
result_f32 = np.zeros(nbins, dtype='f4')
result_f64 = np.zeros(nbins, dtype='f8')



In [5]:
# The kernel adds into `result` in place, so allocate fresh zeroed accumulators
# here. Re-running this cell then reproduces the same answer instead of stacking
# a second pass on top of the first.
result_f32 = np.zeros(rbins_squared.shape[0], dtype='f4')
result_f64 = np.zeros(rbins_squared.shape[0], dtype='f8')

# Same inputs at two precisions. Every array, including the bin edges, is cast
# so the kernel sees inputs and accumulator at the target dtype.
kernel_inputs = (x1, y1, z1, w1, x2, y2, z2, w2, rbins_squared)
count_weighted_pairs_3d(*[arr.astype('f4') for arr in kernel_inputs], result_f32)
count_weighted_pairs_3d(*[arr.astype('f8') for arr in kernel_inputs], result_f64)

# Does float32 accumulation agree with float64 within np.allclose's default tolerances?
print(np.allclose(result_f32, result_f64))

True


In [6]:
# float32 run. With ascending edges, result_f32[k - 1] is the summed weight of
# pairs with d**2 <= rbins_squared[k]. The last slot is never written by the kernel.
result_f32

array([ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.28704327,  0.5060246 ,
        2.9829164 , 10.560754  , 27.483397  , 62.219048  ,  0.        ],
      dtype=float32)

In [7]:
# float64 reference run, same layout as result_f32. Compare the two element by
# element to see the float32 rounding error.
result_f64

array([ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.28704326,  0.50602458,
        2.98291623, 10.56075547, 27.48339284, 62.21903872,  0.        ])