In [2]:
%load_ext Cython

In [3]:
%%cython
import numpy as np
cimport numpy as np
from scipy.spatial.distance import cdist

from scipy.spatial.distance import cdist, pdist

cdef double dense_dist_mat_at_ij(double[:] dist, int i, int j, int n):
    """Return the distance between observations ``i`` and ``j`` from a condensed matrix.

    ``dist`` is read using the layout documented for ``scipy.spatial.distance.pdist``
    on ``n`` observations: the pair ``(a, b)`` with ``a < b`` lives at index
    ``n*a - a*(a+1)//2 + (b - a - 1)``.  The lookup is symmetric in ``i`` and ``j``
    and returns 0.0 when ``i == j`` (the diagonal is not stored).
    """
    # Py_ssize_t keeps the index arithmetic from overflowing a 32-bit int for large n.
    cdef Py_ssize_t lo, hi
    if i == j:
        return 0.0
    if i < j:
        lo, hi = i, j
    else:
        lo, hi = j, i
    # The column offset must be ADDED to the row start; subtracting it yields
    # wrong (often negative, wrapped-around) indices whenever hi > lo + 1.
    return dist[lo * n - lo * (lo + 1) // 2 + (hi - lo - 1)]

cpdef tuple well_scattered_points(int n_rep, np.ndarray[np.double_t, ndim=1] mean, np.ndarray[np.double_t, ndim=2] data):
    """Pick ``n_rep`` well-scattered rows of ``data`` by greedy farthest-point selection.

    The first pick is the row farthest from ``mean``; every following pick is the
    row whose minimum distance to the rows already picked is largest.  This is the
    naive baseline: the minimum is recomputed over all picked rows for every
    candidate.

    Returns a tuple ``(points, indices)``.  ``points`` is a list of rows of
    ``data``.  ``indices`` is a list of row indices, except when ``data`` has at
    most ``n_rep`` rows: then every row is returned and ``indices`` is
    ``np.arange(n)``.
    """
    cdef int n = data.shape[0]
    # With no more rows than requested representatives, every row is a representative.
    if n <= n_rep:
        return list(data), np.arange(data.shape[0])

    # Condensed pairwise distances, computed once and looked up by (i, j).
    cdef double[:] distances = pdist(data)

    # Seed: the row farthest from the mean.
    cdef int idx = np.argmax(np.linalg.norm(data - mean, axis=1))

    cdef int i, j
    # Initialised so the value is never an undefined C int.
    cdef int max_point = idx
    # double (not float): distances are float64 and must not be truncated before comparing.
    cdef double max_dist, min_dist
    cdef list scatter_idx = [idx]
    for i in range(1, n_rep):
        # Distances are >= 0, so starting below zero guarantees a candidate is
        # always selected, even when all remaining points coincide with picks.
        max_dist = -1.0
        for j in range(n):
            # Distance from candidate j to its nearest already-picked row.
            min_dist = min([dense_dist_mat_at_ij(distances, s, j, n) for s in scatter_idx])
            if min_dist > max_dist:
                max_dist = min_dist
                max_point = j

        scatter_idx.append(max_point)

    return [data[i] for i in scatter_idx], scatter_idx

In [None]:
import numpy as np
# Time the Cython baseline: 1000 representatives out of 2000 random 10-D points.
# The random data is created inside the timed expression, so its cost is included
# and every timing loop sees different data.
%timeit well_scattered_points(1000, np.zeros((10,)), np.random.rand(2000, 10).astype(np.float64))

In [65]:
from scipy.spatial.distance import cdist, pdist

def dense_dist_mat_at_ij(dist, i, j, n):
    """Return the distance between observations ``i`` and ``j`` from a condensed matrix.

    ``dist`` is read using the layout documented for ``scipy.spatial.distance.pdist``
    on ``n`` observations: the pair ``(a, b)`` with ``a < b`` lives at index
    ``n*a - a*(a+1)//2 + (b - a - 1)``.  The lookup is symmetric in ``i`` and ``j``
    and returns 0.0 when ``i == j`` (the diagonal is not stored).
    """
    if i == j:
        return 0.0
    lo, hi = (i, j) if i < j else (j, i)
    # The column offset must be ADDED to the row start; subtracting it yields
    # wrong (often negative, wrapped-around) indices whenever hi > lo + 1.
    idx = int(lo * n - lo * (lo + 1) // 2 + (hi - lo - 1))
    return dist[idx]

def py_well_scattered_points(n_rep: int, mean: np.ndarray, data: np.ndarray):
    """Pick ``n_rep`` well-scattered rows of ``data`` by greedy farthest-point selection.

    Pure-Python reference implementation.  The first pick is the row farthest
    from ``mean``; every following pick is the row whose minimum distance to the
    rows already picked is largest.

    Returns a tuple ``(points, indices)``.  ``points`` is a list of rows of
    ``data``.  ``indices`` is a list of row indices, except when ``data`` has at
    most ``n_rep`` rows: then every row is returned and ``indices`` is
    ``np.arange(n)``.
    """
    n = data.shape[0]
    # With no more rows than requested representatives, every row is a representative.
    if n <= n_rep:
        return list(data), np.arange(data.shape[0])

    # Condensed pairwise distances, computed once and looked up by (i, j).
    distances = pdist(data)

    # Seed: the row farthest from the mean.
    idx = np.argmax(np.linalg.norm(data - mean, axis=1))

    scatter_idx = [idx]
    for _ in range(1, n_rep):
        # Distances are >= 0, so starting below zero guarantees `max_point` is
        # bound on every pass, even when all remaining points coincide with picks.
        max_dist = -1.0
        for j in range(n):
            # Distance from candidate j to its nearest already-picked row.
            min_dist = min(dense_dist_mat_at_ij(distances, s, j, n) for s in scatter_idx)
            if min_dist > max_dist:
                max_dist = min_dist
                max_point = j

        scatter_idx.append(max_point)

    return [data[i] for i in scatter_idx], scatter_idx

In [None]:
import numpy as np
# Time the pure-Python reference on the same problem size (1000 of 2000 points).
# -n1 -r2 limits it to one loop per run and two runs.
# The random data is created inside the timed expression, so its cost is included.
%timeit -n1 -r2 py_well_scattered_points(1000, np.zeros((10,)), np.random.rand(2000, 10).astype(np.float64))

# Faster variant: keep track of each point's distances to the scatter points already chosen

In [69]:
%%cython
import numpy as np
cimport numpy as np
from cpython cimport array
import array
from scipy.spatial.distance import cdist

from scipy.spatial.distance import cdist, pdist

cdef double dense_dist_mat_at_ij(double[:] dist, int i, int j, int n):
    """Return the distance between observations ``i`` and ``j`` from a condensed matrix.

    ``dist`` is read using the layout documented for ``scipy.spatial.distance.pdist``
    on ``n`` observations: the pair ``(a, b)`` with ``a < b`` lives at index
    ``n*a - a*(a+1)//2 + (b - a - 1)``.  The lookup is symmetric in ``i`` and ``j``
    and returns 0.0 when ``i == j`` (the diagonal is not stored).
    """
    # Py_ssize_t keeps the index arithmetic from overflowing a 32-bit int for large n.
    cdef Py_ssize_t lo, hi
    if i == j:
        return 0.0
    if i < j:
        lo, hi = i, j
    else:
        lo, hi = j, i
    # The column offset must be ADDED to the row start; subtracting it yields
    # wrong (often negative, wrapped-around) indices whenever hi > lo + 1.
    return dist[lo * n - lo * (lo + 1) // 2 + (hi - lo - 1)]

cpdef tuple wsp_fast(int n_rep, np.ndarray[np.double_t, ndim=1] mean, np.ndarray[np.double_t, ndim=2] data):
    """Pick ``n_rep`` well-scattered rows of ``data`` by greedy farthest-point selection.

    The first pick is the row farthest from ``mean``; every following pick is the
    row whose minimum distance to the rows already picked is largest.  For each
    row a running minimum distance to the picked set is kept, so adding a pick
    costs one pass over the rows instead of a re-scan of all earlier picks.

    Returns a tuple ``(points, indices)``.  ``points`` is a list of rows of
    ``data``.  ``indices`` is a list of row indices, except when ``data`` has at
    most ``n_rep`` rows: then every row is returned and ``indices`` is
    ``np.arange(n)``.
    """
    cdef int n = data.shape[0]

    # With no more rows than requested representatives, every row is a representative.
    if n <= n_rep:
        return list(data), np.arange(data.shape[0])

    # Condensed pairwise distances, computed once and looked up by (i, j).
    cdef double[:] distances = pdist(data)

    # Seed: the row farthest from the mean.
    cdef int idx = np.argmax(np.linalg.norm(data - mean, axis=1))

    # nearest[j] = smallest distance from row j to any row picked so far.
    # Starts at +inf so the first real distance always replaces it.
    cdef double[:] nearest = np.full(n, np.inf, dtype=np.float64)

    # Picked row indices, relative to data.
    cdef list scatter_idx = [idx]
    cdef int i, j
    cdef int latest = idx      # most recent pick, kept as a C int for the inner loop
    cdef int max_point = idx   # initialised so the value is never an undefined C int
    cdef double max_dist, dist

    for i in range(n_rep - 1):
        # Distances are >= 0, so starting below zero guarantees a candidate is
        # always selected, even when all remaining points coincide with picks.
        max_dist = -1.0
        for j in range(n):
            # Fold the distance to the latest pick into the running minimum.
            dist = dense_dist_mat_at_ij(distances, latest, j, n)
            if dist < nearest[j]:
                nearest[j] = dist
            if nearest[j] > max_dist:
                max_dist = nearest[j]
                max_point = j
        scatter_idx.append(max_point)
        latest = max_point

    return [data[i] for i in scatter_idx], scatter_idx

In [None]:
import numpy as np
# Shared inputs, created once so both implementations are timed on identical data.
data = np.random.rand(2000, 10).astype(np.float64)
mean = np.zeros((10,)).astype(np.float64)
# Pure-Python reference left disabled here.
# %timeit -n2 -r2 py_well_scattered_points(100, mean, data)
# Cython baseline vs. the distance-tracking variant, 1000 representatives each.
%timeit -n2 -r2 well_scattered_points(1000, mean, data)
%timeit -n2 -r2 wsp_fast(1000, mean, data)

In [72]:
import numpy as np
# Fresh random inputs for an equivalence check between the two Cython implementations.
data = np.random.rand(2000, 10).astype(np.float64)
mean = np.zeros((10,)).astype(np.float64)
# Keep only the selected indices (second tuple element) from each implementation.
_, idx1 = well_scattered_points(100, mean, data)
_, idx2 = wsp_fast(100, mean, data)

In [73]:
# Stack both index lists as two rows so they can be compared by eye.
print(np.vstack((idx1, idx2)))

[[1489  961  706  435  305 1126 1267  551  972  400  516 1500 1640  744
   107  887  187  270 1531 1193  534  169  526 1610  335  896  948 1481
    33  985 1079  440  765  680 1320 1507 1611    5 1061 1973 1334   52
   699  203  701 1791  418  515  520 1851 1107  799  884 1440  224  158
   264 1969  899 1200  634 1333  952  106  556 1431  674 1665 1024  986
  1143 1981  471 1649  925  501  386 1986 1508  770  384  764 1591 1301
   162 1170 1769  105 1230 1412  601  686 1711  681  628 1776  333 1821
  1064 1475]
 [1489  961  706  435  305 1126 1267  551  972  400  516 1500 1640  744
   107  887  187  270 1531 1193  534  169  526 1610  335  896  948 1481
    33  985 1079  440  765  680 1320 1507 1611    5 1061 1973 1334   52
   699  203  701 1791  418  515  520 1851 1107  799  884 1440  224  158
   264 1969  899 1200  634 1333  952  106  556 1431  674 1665 1024  986
  1143 1981  471 1649  925  501  386 1986 1508  770  384  764 1591 1301
   162 1170 1769  105 1230 1412  601  686 1711  681