In [13]:
import numpy as np
import cupy as cp
from cupyx import jit
import time

print(f"CuPy version {cp.__version__}")
mempool = cp.get_default_memory_pool()
mempool.free_all_blocks()
print(f"mempool.used_bytes {mempool.used_bytes()}")


@jit.rawkernel()
def foo(particle_xy: cp.ndarray, # Nx2
        beacon_xy: cp.ndarray, # Mx2
        outp: cp.ndarray, # NxM
        size: np.int32) -> None:
    tid = jit.blockIdx.x * jit.blockDim.x + jit.threadIdx.x
    ntid = jit.gridDim.x * jit.blockDim.x
    for i in range(tid, size, ntid):
        # i is the row in N
        for j in range(beacon_xy.shape[0]):
            # j is the row in M
            p_x = particle_xy[i,0]
            p_y = particle_xy[i,1]
            b_x = beacon_xy[j,0]
            b_y = beacon_xy[j,1]
            dx = p_x - b_x
            dy = p_y - b_y
            d = cp.hypot(dx, dy)
            outp[i, j] = d
    

#particle_xy = cp.array([[1,2],[3,4],[5,6]])

for PARTICLE_COUNT in range(1,10000000,1000000):


    particle_xy = cp.random.uniform(low=0, high=10, size=(PARTICLE_COUNT,2))
    beacon_xy = cp.array([[1,0],[0,1]])
    b = cp.zeros((particle_xy.shape[0], beacon_xy.shape[0]))

    #print("particle_xy")
    #print(particle_xy)
    #print("beacon_xy")
    #print(beacon_xy)
    #print("b")
    #print(b)

    t0_ns = time.time_ns()
    foo((128,),(1024,),(particle_xy,beacon_xy,b,PARTICLE_COUNT))
    t1_ns = time.time_ns()
    duration_ns = t1_ns - t0_ns
    per_particle_ns = duration_ns/PARTICLE_COUNT
    print(f" count {PARTICLE_COUNT} duration us {duration_ns/1000:.0f} per particle ns {per_particle_ns:.2f}")

    #print("b")
    #print(b)


CuPy version 13.2.0
mempool.used_bytes 288001536
 count 1 duration us 4261 per particle ns 4260565.00
 count 1000001 duration us 95 per particle ns 0.09
 count 2000001 duration us 86 per particle ns 0.04
 count 3000001 duration us 77 per particle ns 0.03
 count 4000001 duration us 77 per particle ns 0.02
 count 5000001 duration us 100 per particle ns 0.02
 count 6000001 duration us 138 per particle ns 0.02
 count 7000001 duration us 83 per particle ns 0.01
 count 8000001 duration us 118 per particle ns 0.01
 count 9000001 duration us 128 per particle ns 0.01
