In [1]:
import numpy as np
import cupy as cp
print("cupy: ", cp.__version__)

cupy:  7.3.0


In [116]:
def xy_coords(coords_1d, width=4096):
    """Unravel flattened (row-major) pixel indices into x / y coordinates.

    Parameters
    ----------
    coords_1d : int or array-like of int
        Flattened index/indices, i.e. ``y * width + x``. Anything supporting
        ``%`` and ``//`` works (Python ints, NumPy arrays, CuPy arrays).
    width : int, optional
        Number of pixels per row of the 2D grid. Defaults to 4096, the CCD
        width used throughout this notebook.

    Returns
    -------
    (x, y) : tuple
        Column and row coordinates, same type/shape as ``coords_1d``.
    """
    x = coords_1d % width    # position within the row
    y = coords_1d // width   # row number
    return x, y


@cp.fuse()
def dist_mat(x1, x2, y1, y2):
    """Fused elementwise kernel: squared Euclidean distance in 2D.

    Computes ``(x1 - x2)**2 + (y1 - y2)**2`` elementwise (with
    broadcasting). No square root is taken, so the result is the
    *squared* distance.
    """
    delta_x = x1 - x2
    delta_y = y1 - y2
    return delta_x * delta_x + delta_y * delta_y



@cp.fuse()
def dist_mat_mink(x1, x2):
    """Fused kernel: sum of squared differences over the last axis.

    ``x1`` and ``x2`` are broadcast against each other; the squared
    difference is then reduced along ``axis=-1``. The result is a squared
    distance (no square root is applied).
    """
    delta = x1 - x2
    return cp.sum(delta * delta, axis=-1)


@cp.fuse()
def dist_mat_dot(x1, x1b, y1, y1b):
    """Fused elementwise kernel: squared distance as a displacement dot product.

    The displacement between two points is ``(dx, dy) = (x1 - x1b, y1 - y1b)``.
    Its dot product with itself, ``dx*dx + dy*dy``, is the squared Euclidean
    distance. Inputs broadcast against each other elementwise.

    The previous body returned ``cp.dot(x1 - x1b, x2)``, which referenced the
    undefined name ``x2`` and ignored the y components entirely.
    """
    dx = x1 - x1b
    dy = y1 - y1b
    # Written out elementwise so the whole expression stays a simple fused
    # kernel instead of calling a matrix product inside the fused function.
    return dx * dx + dy * dy



def dist_matrix(coords, coords2=None):
    """Squared-distance matrix for flattened CCD pixel coordinates.

    Parameters
    ----------
    coords : cupy.ndarray of int, shape (n,)
        Flattened (1D-from-2D-ravelled) pixel indices on the 4096-wide CCD.
    coords2 : cupy.ndarray of int, shape (m,), optional
        Second set of flattened indices. When omitted, distances are
        computed between every pair of points in ``coords``.

    Returns
    -------
    numpy.ndarray, shape (n, m)  (or (n, n) when ``coords2`` is None)
        Entry ``[i, j]`` is the *squared* Euclidean distance between
        ``coords[i]`` and ``coords2[j]``. The result is copied from the GPU
        to host memory via ``.get()``.
    """
    # Unravel the 2D coordinate system of the CCD. Plain % and // were found
    # to be faster than the unravel() helpers.
    x, y = xy_coords(coords)
    if coords2 is None:
        x_other, y_other = x, y
    else:
        x_other, y_other = xy_coords(coords2)

    # Broadcasting a column (n, 1) against a row (m,) yields the full (n, m)
    # grid of pairwise differences; this seemed faster than meshgrid.
    x_col = x[:, cp.newaxis]
    y_col = y[:, cp.newaxis]

    sq_dist = dist_mat(x_other, x_col, y_other, y_col)

    return sq_dist.get()


def dist_matrix_mink(coords):
    """Pairwise squared-distance matrix via a stacked (x, y) point array.

    ``coords`` holds flattened (1D-from-2D-ravelled) pixel indices on the
    4096-wide CCD. The points are stacked into an (n, 2) array, broadcast
    against themselves, and reduced over the coordinate axis by
    ``dist_mat_mink``. The (n, n) result is fetched from the GPU with
    ``.get()`` before being returned.
    """
    # Unravel with % and // (faster than the unravel() helpers) and
    # transpose so that each row is one (x, y) point.
    points = cp.array([coords % 4096, coords // 4096]).T

    # (n, 1, 2) against (1, n, 2) broadcasts to every pair of points.
    as_rows = points[:, cp.newaxis, :]
    as_cols = points[cp.newaxis, :, :]

    return dist_mat_mink(as_rows, as_cols).get()


def dist_matrix_dot(coords):
    """Pairwise squared-distance matrix computed through a matrix product.

    Uses the identity ``|a - b|^2 = |a|^2 + |b|^2 - 2 * (a . b)``: the
    ``a . b`` terms for all pairs come from a single ``cp.dot`` call.

    Parameters
    ----------
    coords : cupy.ndarray of int, shape (n,)
        Flattened (1D-from-2D-ravelled) pixel indices on the 4096-wide CCD.

    Returns
    -------
    numpy.ndarray, shape (n, n)
        Squared Euclidean distance between every pair of points, copied to
        host memory via ``.get()``.

    Notes
    -----
    The previous body read ``xy1`` / ``xy2`` (never defined inside the
    function, so it only ran when those names leaked in from another cell)
    and called the four-argument ``dist_mat_dot`` with two arguments.
    """
    # Unravel the 2D coordinate system of the CCD (4096 x 4096); rows of
    # `xy` are the x and y coordinates, so its shape is (2, n).
    xy = cp.array([coords % 4096, coords // 4096])

    # (n, n) matrix of dot products between every pair of points.
    gram = cp.dot(xy.T, xy)

    # Squared norm of each point, shape (n,).
    sq_norm = cp.sum(xy * xy, axis=0)

    sq_dist = sq_norm[:, cp.newaxis] + sq_norm[cp.newaxis, :] - 2 * gram

    return sq_dist.get()



def where_near(coords1, coords2, distance):
    """Flag the elements of ``coords1`` lying within ``distance`` of ``coords2``.

    Parameters
    ----------
    coords1, coords2 : cupy.ndarray of int
        Flattened (1D-from-2D-ravelled) pixel indices on the 4096-wide CCD.
    distance : number
        Non-negative Euclidean distance threshold, in pixels.

    Returns
    -------
    numpy.ndarray of bool, same length as ``coords1``
        True where the corresponding point has a neighbour within
        ``distance``, considering only the strict upper triangle of the
        pair matrix (see the note below).

    Notes
    -----
    Fixes relative to the previous body: ``cupy`` was an undefined name (the
    module is imported as ``cp``); ``dist_matrix`` was called with two
    arguments although it accepted one; the matrix holds *squared*
    distances but was compared against the unsquared threshold; and
    ``.get()`` was applied after the data had already left the GPU.
    """
    x1, y1 = xy_coords(coords1)
    x2, y2 = xy_coords(coords2)

    # (n1, n2) squared distances, kept on the GPU until the final .get().
    dmat = dist_mat(x2, x1[:, cp.newaxis], y2, y1[:, cp.newaxis])

    # dist_mat returns squared distances, so square the threshold too.
    near_mask = dmat <= distance * distance

    # Keep only pairs with column index > row index, which drops the
    # diagonal (self matches) and mirrored duplicates.
    # NOTE(review): this is only meaningful when coords1 and coords2 are the
    # same set; for two different sets it discards valid pairs -- confirm
    # the intended use.
    near_mask_tri = cp.triu(near_mask, k=1)

    return near_mask_tri.any(axis=1).get()

In [76]:
# Number of synthetic spike positions used for the timing experiments below.
nspikes = 10_000
# Random flattened pixel indices for a 4096 x 4096 grid, generated on the GPU
# as int32. No seed is set, so the values differ from run to run.
spikes = cp.random.randint(1, high=(4096*4096)-1, size=nspikes, dtype=np.int32)

In [31]:
# Time the broadcasting-based implementation (the result includes the GPU -> host copy).
%time dm = dist_matrix(spikes)

CPU times: user 93.5 ms, sys: 41.3 ms, total: 135 ms
Wall time: 134 ms


In [115]:
# Time the stacked-(x, y) implementation for comparison with dist_matrix above.
%time dm = dist_matrix_mink(spikes)

CPU times: user 171 ms, sys: 71.2 ms, total: 242 ms
Wall time: 241 ms


In [105]:
# Stack x and y coordinates as two rows: shape (2, nspikes).
xy = cp.array([spikes%4096, spikes//4096])
# (nspikes, nspikes) matrix of dot products between every pair of points.
d = cp.dot(xy.T, xy)

In [120]:
# One (x, y) point per row: shape (nspikes, 2). This rebinds `xy`, which an
# earlier cell defined with the transposed layout.
xy = cp.array([spikes%4096, spikes//4096]).T
# Views with an extra axis so the two broadcast to every pair of points:
# xy1 is (nspikes, 1, 2) and xy2 is (1, nspikes, 2).
xy1 = xy[:, cp.newaxis, :]
xy2 = xy[cp.newaxis, :, :]

In [127]:
# NOTE(review): xy2 is xy with a leading length-1 axis, so this subtracts each
# point from itself and yields shape (1, nspikes, 2). Presumably `xy1 - xy2`
# (shape (nspikes, nspikes, 2)) was intended -- confirm.
dxy = xy-xy2
dxy.shape

(1, 10000, 2)