In [1]:
import os
import sys
import math
import time

import numpy as np
from scipy.cluster import hierarchy

In [2]:
n = 1000
points = np.random.rand(n, 2)

linkage = hierarchy.linkage(list(zip(points[:,0], points[:,1])), method='ward', metric='euclidean').astype(int)

class Cluster():
    def __init__(self, p, left=None, right=None):
        self.point = p # representative point
        self.maxpoint = p
        self.left = left
        self.right = right
        self.dist = 0
        if not (left is None or right is None):
            self.dist = dist(left.point, right.point)

def get_tree(points, linkage):
    clusters = []
    for p in points:
        clusters.append(Cluster(p))
    n = linkage.shape[0]
    for i in range(n):
        left = clusters[linkage[i, 0]]
        right = clusters[linkage[i, 1]]
        p = (left.point + right.point) / 2
        cl = Cluster(p, left, right)
        cl.maxpoint = left.point
        clusters.append(cl)
        
    return clusters

def dist(p1, p2):
    return np.linalg.norm(p2 - p1)

def nn(root, query):
    current = root
    while(True):
        if current.dist == 0: break
        l, r = current.left, current.right
        print(current.dist, l.point, dist(l.point, current.point), r.point, dist(r.point, current.point))
        if dist(l.point, query) < dist(r.point, query):
            current = l
        else: current = r
            
    return current

print(linkage)
root = get_tree(points, linkage)[-1]
print(root)

query = np.array([0.5, 0.5])
nb = nn(root, query)
print(nb.point)
    

[[ 226  231    0    2]
 [ 162  460    0    2]
 [ 594  812    0    2]
 ...
 [1992 1993    6  429]
 [1994 1995    9  571]
 [1996 1997   10 1000]]
<__main__.Cluster object at 0x0000029A18779300>
0.4795017906668457 [0.52635658 0.23312015] 0.23975089533342284 [0.49046182 0.71127654] 0.23975089533342284
0.6034010786654156 [0.79183789 0.72526521] 0.3017005393327078 [0.18908576 0.69728788] 0.3017005393327078
0.34667394090258596 [0.11437411 0.54087856] 0.17333697045129304 [0.2637974 0.8536972] 0.17333697045129295
0.15735773508389364 [0.07534048 0.47256502] 0.07867886754194678 [0.15340774 0.6091921 ] 0.07867886754194688
0.16129652170320047 [0.21430828 0.55632182] 0.08064826085160022 [0.0925072  0.66206237] 0.08064826085160023
0.0876174113698957 [0.17457929 0.53786078] 0.04380870568494785 [0.25403727 0.57478286] 0.04380870568494785
0.08180207380625919 [0.27101466 0.5375718 ] 0.040901036903129594 [0.23705989 0.61199392] 0.040901036903129594
0.045080478906359384 [0.25302564 0.55115331] 0.0225402394

In [98]:
def dist(p1, p2):
    return np.linalg.norm(p2 - p1)

GLOBCNT = 0

class Cluster():
    def __init__(self, p, left=None, right=None):
        self.point = p # representative point
        self.left = left
        self.right = right
        self.dist = 0
        self.ylims = (p[0], p[0])
        self.xlims = (p[1], p[1])
        self.center = p
        self.npoints = 1
        global GLOBCNT
        self.id = GLOBCNT; GLOBCNT += 1
        
        if not (left is None or right is None):
            self.dist = dist(left.point, right.point)
            self.ylims = (min(left.ylims[0], right.ylims[0]), max(left.ylims[1], right.ylims[1]))
            self.xlims = (min(left.xlims[0], right.xlims[0]), max(left.xlims[1], right.xlims[1]))
            self.center = (left.center + right.center) / 2
            self.npoints = left.npoints + right.npoints
            
def hier_cluster(points):
    clusters = [Cluster(p) for p in points]
    while len(clusters) > 1:
        #print([cl.id for cl in clusters])
        a = clusters[0]
        pts = np.array([cl.center for cl in clusters[1:]])
        dists = np.abs(pts - a.center)
        bi = np.argmin(np.sum(dists, axis=1))
        dst = np.min(np.sum(dists, axis=1))
        b = clusters[1:][bi]
        cl = Cluster((a.center + b.center) / 2, a, b)
        #print("merging", a.id, "and", b.id, "into", cl.id, "dist =", round(dst,3))
        clusters.append(cl)
        clusters.pop(0)
        clusters.pop(bi)
    return clusters[0]
        

def get_tree(points):
    linkage = hierarchy.linkage(list(zip(points[:,0], points[:,1])), method='ward', metric='euclidean').astype(int)
    clusters = []
    for p in points:
        clusters.append(Cluster(p))
    n = linkage.shape[0]
    for i in range(n):
        left = clusters[linkage[i, 0]]
        right = clusters[linkage[i, 1]]
        #p = (left.point + right.point) / 2
        p = left.point
        cl = Cluster(p, left, right)
        clusters.append(cl)
        
    return clusters

def intersect(a, b):
    if not (a[1] > b[0] and a[0] < b[1]):
        return False
    else:
        return (max(a[0], b[0]), min(a[1], b[1]))
        

def _crawl1(root, ylims, xlims, mindist, pts):
    #if root.dist <= mindist:
    #    pts.append(root.point)
    #    return
    #print(root.npoints, root.dist)
    ly, lx = intersect(root.left.ylims, ylims), intersect(root.left.xlims, xlims)
    #print(ly, lx)
    if (ly and lx):
        _crawl(root.left, ly, lx, mindist, pts)
    elif root.left.dist <= mindist:
        pts.append(root.left.point)
    ry, rx = intersect(root.right.ylims, ylims), intersect(root.right.xlims, xlims)
    #print(ry, rx)
    if (ry and rx):
        _crawl(root.right, ry, rx, mindist, pts)
    elif root.right.dist <= mindist:
        pts.append(root.right.point)
        
def _crawl(root, ylims, xlims, mindist, pts):
    #if root.dist <= mindist:
    #    pts.append(root.point)
    #    return
    #print(root.npoints, root.dist)
    
    #print(root.id)
    
    #print(root.dist, mindist, ylims, xlims)
    
    if root.dist <= mindist:
        pts.append(root.point)
        return
    
    if ylims and xlims:
        ly, lx = intersect(root.left.ylims, ylims), intersect(root.left.xlims, xlims)
        _crawl(root.left, ly, lx, mindist, pts)
        ry, rx = intersect(root.right.ylims, ylims), intersect(root.right.xlims, xlims)
        _crawl(root.right, ry, rx, mindist, pts)

def get_local(root, ylims, xlims, mindist):
    current = root
    pts = []
    _crawl(root, ylims, xlims, mindist, pts)
    return np.stack(pts) if len(pts) > 1 else np.array([])





In [107]:
points = np.random.rand(10000, 2)
t0 = time.time()
root = hier_cluster(points)
print(time.time() - t0)
#root = get_tree(points)[-1]

22.589049816131592


In [110]:
t0 = time.time()
local = get_local(root, (0.25,0.5), (0.25, 0.5), 0)
print(time.time() - t0)
print(local)

0.003990888595581055
[[0.426441   0.49374394]
 [0.42990724 0.49634399]
 [0.43804307 0.49553265]
 ...
 [0.30436857 0.4387444 ]
 [0.3135178  0.4342421 ]
 [0.30523344 0.43419469]]
