In [1]:
import numpy as np
import time
from numba import jit, njit

In [2]:
np.random.seed(0)
A = np.random.rand(50000,3)*500

x_bins = np.linspace(0, 500, num=50)

@jit
def count_3D_least_loops(A, x_bins, y_bins, z_bins):
    
    hist = np.zeros((len(x_bins)-1, len(y_bins)-1, len(z_bins)-1))
    
    for i in range(len(x_bins)-1):
        for j in range(len(y_bins)-1):
            for k in range(len(z_bins)-1):
                hist[i,j,k] = np.sum((A[:,0] >= x_bins[i]) * (A[:,0] < x_bins[i+1]) * 
                                     (A[:,1] >= y_bins[j]) * (A[:,1] < y_bins[j+1]) * 
                                     (A[:,2] >= z_bins[k]) * (A[:,2] < z_bins[k+1]))
                
    return hist

@jit
def count_3D_least_loops_optimized(A, x_bins, y_bins, z_bins):

    hist = np.zeros((len(x_bins)-1, len(y_bins)-1, len(z_bins)-1))

    for i in range(len(x_bins)-1):
        a = A[:,0] >= x_bins[i]
        np.less(A[:,0], x_bins[i+1], out=a, where=a)
        
        for j in range(len(y_bins)-1):
            b = a.copy()
            np.greater_equal(A[:,1], y_bins[j], out=b, where=b)
            np.less(A[:,1], y_bins[j+1], out=b, where=b)
            
            for k in range(len(z_bins)-1):
                c = b.copy()
                np.greater_equal(A[:,2], z_bins[k], out=c, where=c)
                np.less(A[:,2], z_bins[k+1], out=c, where=c)
                
                hist[i,j,k] = np.sum(c)

    return hist

@jit
def count_3D_least_loops_hist_outside(A, hist, x_bins, y_bins, z_bins):

    for i in range(len(x_bins)-1):
        a = A[:,0] >= x_bins[i]
        np.less(A[:,0], x_bins[i+1], out=a, where=a)
        
        for j in range(len(y_bins)-1):
            b = a.copy()
            np.greater_equal(A[:,1], y_bins[j], out=b, where=b)
            np.less(A[:,1], y_bins[j+1], out=b, where=b)
            
            for k in range(len(z_bins)-1):
                c = b.copy()
                np.greater_equal(A[:,2], z_bins[k], out=c, where=c)
                np.less(A[:,2], z_bins[k+1], out=c, where=c)
                
                hist[i,j,k] = np.sum(c)

    return hist

@jit
def count_3D_atoms(A, x_bins, y_bins, z_bins):
    
    hist = np.zeros((len(x_bins)-1, len(y_bins)-1, len(z_bins)-1))
    
    for i in range(A.shape[0]):
        for x in range(len(x_bins)-1):
            if A[i,0] >= x_bins[x]:
                if A[i,0] < x_bins[x+1]:
                    a = x
                    break
        for y in range(len(y_bins)-1):
            if A[i,1] >= y_bins[y]:
                if A[i,1] < y_bins[y+1]:
                    b = y
                    break
        for z in range(len(z_bins)-1):
            if A[i,2] >= z_bins[z]:
                if A[i,2] < z_bins[z+1]:
                    c = z
                    break
                
        hist[a,b,c] += 1

    return hist

def count_3D_atoms_no_jit(A, x_bins, y_bins, z_bins):

    hist = np.zeros((len(x_bins)-1, len(y_bins)-1, len(z_bins)-1))
    
    for i in range(A.shape[0]):
        for x in range(len(x_bins)-1):
            if A[i,0] >= x_bins[x]:
                if A[i,0] < x_bins[x+1]:
                    a = x
                    break
        for y in range(len(y_bins)-1):
            if A[i,1] >= y_bins[y]:
                if A[i,1] < y_bins[y+1]:
                    b = y
                    break
        for z in range(len(z_bins)-1):
            if A[i,2] >= z_bins[z]:
                if A[i,2] < z_bins[z+1]:
                    c = z
                    break
                
        hist[a,b,c] += 1

    return hist

@jit
def count_3D_atoms_hist_outside(A, hist, x_bins, y_bins, z_bins):
    
    for i in range(A.shape[0]):
        for x in range(len(x_bins)-1):
            if A[i,0] >= x_bins[x]:
                if A[i,0] < x_bins[x+1]:
                    a = x
                    break
        for y in range(len(y_bins)-1):
            if A[i,1] >= y_bins[y]:
                if A[i,1] < y_bins[y+1]:
                    b = y
                    break
        for z in range(len(z_bins)-1):
            if A[i,2] >= z_bins[z]:
                if A[i,2] < z_bins[z+1]:
                    c = z
                    break
                
        hist[a,b,c] += 1

    return hist

In [3]:
t0 = time.time()

hist1 = count_3D_least_loops(A, x_bins, x_bins, x_bins)

print('Seconds to complete:', time.time()-t0)
print(np.allclose(hist1,hist1))

Seconds to complete: 21.405205249786377
True


In [4]:
t0 = time.time()

hist2 = count_3D_least_loops_optimized(A, x_bins, x_bins, x_bins)

print('Seconds to complete:', time.time()-t0)
print(np.allclose(hist1,hist2))

Seconds to complete: 12.735662460327148
True


In [5]:
t0 = time.time()

hist3 = np.zeros((len(x_bins)-1, len(x_bins)-1, len(x_bins)-1))

hist3 = count_3D_least_loops_hist_outside(A, hist3, x_bins, x_bins, x_bins)

print('Seconds to complete:', time.time()-t0)
print(np.allclose(hist1,hist3))

Seconds to complete: 12.51089859008789
True


In [6]:
t0 = time.time()

hist4 = count_3D_atoms(A, x_bins, x_bins, x_bins)

print('Seconds to complete:', time.time()-t0)
print(np.allclose(hist1,hist4))

Seconds to complete: 0.2802243232727051
True


In [7]:
t0 = time.time()

hist5 = count_3D_atoms_no_jit(A, x_bins, x_bins, x_bins)

print('Seconds to complete:', time.time()-t0)
print(np.allclose(hist1,hist5))

Seconds to complete: 1.8271210193634033
True


In [8]:
t0 = time.time()

hist6 = np.zeros((len(x_bins)-1, len(x_bins)-1, len(x_bins)-1))

hist6 = count_3D_atoms_hist_outside(A, hist6, x_bins, x_bins, x_bins)

print('Seconds to complete:', time.time()-t0)
print(np.allclose(hist1,hist6))

Seconds to complete: 0.22334837913513184
True


In [9]:
hist1

array([[[0., 3., 0., ..., 0., 0., 1.],
        [0., 1., 0., ..., 1., 1., 1.],
        [0., 2., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 2., ..., 2., 0., 1.],
        [1., 0., 0., ..., 1., 1., 2.],
        [1., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 4., ..., 2., 1., 1.],
        [0., 1., 0., ..., 1., 1., 0.],
        [1., 1., 1., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 1., 1.],
        [0., 0., 0., ..., 0., 1., 0.],
        [1., 0., 1., ..., 1., 1., 0.]],

       [[2., 1., 0., ..., 0., 0., 0.],
        [1., 0., 1., ..., 0., 0., 0.],
        [0., 0., 1., ..., 0., 0., 0.],
        ...,
        [1., 1., 0., ..., 0., 1., 1.],
        [0., 0., 0., ..., 1., 0., 1.],
        [0., 1., 0., ..., 1., 0., 0.]],

       ...,

       [[0., 1., 0., ..., 0., 0., 0.],
        [0., 0., 1., ..., 0., 0., 0.],
        [0., 2., 0., ..., 0., 0., 2.],
        ...,
        [0., 0., 1., ..., 0., 1., 1.],
        [0., 0., 0., ..., 1., 1., 0.],
        [0., 0., 1., ..., 0., 1.