生成训练测试数据集：

In [1]:
import numpy as np
import random
import time

m = 10000
d = 200
k = 50
random.seed(0)
np.random.seed(0)
trains = np.zeros((m, d))
for i in range(m):
    trains[i] = np.array([random.gauss(0, 1) for z in range(d)])
test = np.array([random.gauss(0, 1) for z in range(d)])

print(trains.shape)

(10000, 200)


Python 实现 Annoy：

In [2]:
import numpy as np
from queue import PriorityQueue

def means(X):
    """
    启发式的选取两个点

    参数
    ----------
    X : 特征矩阵
    
    返回
    ----------
    两个向量点
    """
    iteration_steps = 20
    count = X.shape[0]
    i = np.random.randint(0, count)
    j = np.random.randint(0, count - 1)
    # 保证 i\j 不相同
    j += (j >= i)
    ic = 1
    jc = 1
    p = X[i]
    q = X[j]
    for l in range(iteration_steps):
        k = np.random.randint(0, count)
        di = ic * distance(p, X[k])
        dj = jc * distance(q, X[k])
        if di == dj:
            continue
        if di < dj:
            p = (p * ic + X[k]) / (ic + 1)
            ic = ic + 1
        else:
            q = (q * jc + X[k]) / (jc + 1)
            jc = jc + 1
    return p, q
        
def distance(a, b):
    """
    计算距离

    参数
    ----------
    a : 向量 a

    b : 向量 b
    
    返回
    ----------
    向量 a 与 向量 b 之间的距离
    """
    return np.linalg.norm(a - b)

class annoynode:
    """
    Annoy 树结点
    """
    
    def __init__(self, index, size, w, b, left = None, right = None):
        # 结点包含的样本点下标
        self.index = index
        # 结点及其子结点包含的样本数
        self.size = size
        # 分割超平面的系数
        self.w = w
        # 分割超平面的偏移量
        self.b = b
        # 左子树
        self.left = left
        # 右子树
        self.right = right
    
    def __lt__(self, other):
        # 结点大小比较
        return self.size < other.size

class annoytree:
    """
    Annoy 树算法实现
    
    参数
    ----------
    X : 特征矩阵

    leaf_size : 叶子节点包含的最大特征向量数量，默认为 10
    """
    
    def __init__(self, X, leaf_size = 10):
        def build_node(X_indexes):
            """
            构建结点
            
            参数
            ----------
            X_indexes : 特征矩阵下标
            """
            # 当特征矩阵小于等于指定的叶子结点的大小时，创建叶子结点并返回
            if len(X_indexes) <= leaf_size:
                return annoynode(X_indexes, len(X_indexes), None, None)
            # 当前特征矩阵
            _X = X[X_indexes, :]
            # 启发式的选取两点
            p, q = means(_X)
            # 超平面的系数
            w = p - q
            # 超平面的偏移量
            b = -np.dot((p + q) / 2, w)
            # 构建结点
            node = annoynode(None, len(X_indexes), w, b)
            # 在超平面“左”侧的特征矩阵下标
            left_index = (_X.dot(w) + b) > 0
            if left_index.any():
                # 递归的构建左子树
                node.left = build_node(X_indexes[left_index])
            # 在超平面“右”侧的特征矩阵下标
            right_index = ~left_index
            if right_index.any():
                # 递归的构建右子树
                node.right = build_node(X_indexes[right_index])
            return node
        # 根结点
        self.root = build_node(np.array(range(X.shape[0])))
        
class annoytrees:
    """
    Annoy 算法实现
    
    参数
    ----------
    X : 特征矩阵
    
    n_trees : Annoy 树的数量，默认为 10

    leaf_size : 叶子节点包含的最大特征向量数量，默认为 10
    """
    
    def __init__(self, X, n_trees = 10, leaf_size = 10):
        self._X = X
        self._trees = []
        # 循环的创建 Annoy 树
        for i in range(n_trees):
            self._trees.append(annoytree(X, leaf_size = leaf_size))
            
    def query(self, x, k = 1, search_k = -1):
        """
        查询距离最近 k 个特征向量

        参数
        ----------
        x : 目标向量

        k : 查询邻居数量

        search_k : 最少遍历出的邻居数量，默认为 Annoy 树的数量 * 查询数量
        """
        
        # 创建结点优先级队列
        nodes = PriorityQueue()
        # 先将所有根结点加入到队列中
        for tree in self._trees:
            nodes.put([float("inf"), tree.root])
        if search_k == -1:
            search_k = len(self._trees) * k
        # 待查询的邻居下标数组
        nns = []
        # 循环优先级队列
        while len(nns) < search_k and not nodes.empty():
            # 获取优先级最高的结点
            (dist, node) = nodes.get()
            # 如果是叶子结点，将下标数组加入待查询的邻居中
            if node.left is None and node.right is None:
                nns.extend(node.index)
            else:
                # 计算目标向量到结点超平面的距离
                dist = min(dist, np.abs(x.dot(node.w) + node.b))
                # 将距离做为优先级的结点加入到优先级队列中
                if node.left is not None:
                    nodes.put([dist, node.left])
                if node.right is not None:
                    nodes.put([dist, node.right])
        # 对下标数组进行排序
        nns.sort()
        prev = -1
        # 优先级队列
        nns_distance = PriorityQueue()
        for idx in nns:
            # 过滤重复的特征矩阵下标
            if idx == prev:
                continue
            prev = idx
            # 计算特征向量与目标向量的距离做为优先级
            nns_distance.put([distance(x, self._X[idx]), idx])
        nearests = []
        distances = []
        # 取前 k 个
        for i in range(k):
            if nns_distance.empty():
                break
            (dist, idx) = nns_distance.get()
            nearests.append(idx)
            distances.append(dist)
        return nearests, distances

Annoy 构建与查询：

In [3]:
from annoy import AnnoyIndex

start = time.time()
# 初始化 AnnoyIndex，使用欧式距离
t = AnnoyIndex(d, 'euclidean')
for i in range(m):
    # 添加样本点
    t.add_item(i, trains[i])
# 构建 20 棵二叉树
t.build(20)
cost = time.time() - start
print("Annoy Build: ", cost)

start = time.time()
# 查询 test 点最近 k 个样本点
nearests, distances = t.get_nns_by_vector(test, k, include_distances=True)
cost = time.time() - start
print("Annoy Search: ", cost)
print("Indexes: ", np.array(nearests))
print("Distances: ", np.array(distances))

Annoy Build:  0.16872692108154297
Annoy Search:  0.0003139972686767578
Indexes:  [4263 2711 8938 6058 6995 9254 8756 1542 6713 6640 8446 4618 5747 3473
 6107 7014 2599 8880 9649 3344 7127 6305  885 7457 7365 2740 4445 3752
 6464 6484 8103 1754 8691 7636 2617 1397 4580 8546 3016 8992 3664 3845
 7561 9854 6193 2640 8624 3976 7465 4824]
Distances:  [17.18696976 17.19728088 17.23875999 17.3003273  17.36844254 17.37878036
 17.53240776 17.54758453 17.5496769  17.60453987 17.65777779 17.69774818
 17.75228119 17.79432106 17.80599403 17.81834412 17.8206501  17.87829208
 17.90117836 17.91422272 17.92863464 17.94997787 17.96464729 17.98784828
 17.9901638  17.99071312 17.99897003 18.02326584 18.04662704 18.05654716
 18.07147408 18.076334   18.08225822 18.09370613 18.09519958 18.10449219
 18.11090088 18.12662125 18.14160156 18.16368294 18.19980431 18.21333122
 18.21651459 18.21673584 18.23332596 18.27451706 18.28308868 18.29053307
 18.30069351 18.30153465]


Ball Tree 构建与查询：

In [4]:
from sklearn.neighbors import BallTree

start = time.time()
tree = BallTree(trains)
cost = time.time() - start
print("Ball Tree Build: ", cost)

start = time.time()
distances, nearests = tree.query(np.array([test]), k)
cost = time.time() - start
print("Ball Tree Search: ", cost)
print("Indexes: ", nearests[0])
print("Distances: ", distances[0])

Ball Tree Build:  0.08020997047424316
Ball Tree Search:  0.0034809112548828125
Indexes:  [3023 4660 4848 8681 4263 2711 8938 3296 6058 6995 9254 2756 4017 5118
 9802 2292 2349 4214 7375 8756 1542 6713 5888 3159  757 7346 5538 6476
 6640 5592 8580 6515 8446 5029 4618 7520 8727 8565 1694 7146 5562 8740
 1827 5747 8337 7047 5336 1153 8531 3473]
Distances:  [16.77680894 16.81368954 16.89781552 17.12930738 17.18696959 17.1972807
 17.23875941 17.29506739 17.30032781 17.36844175 17.37878153 17.39754721
 17.40540655 17.41499656 17.45595374 17.46670573 17.47245261 17.52817008
 17.52950492 17.53240905 17.54758408 17.54967691 17.57053752 17.58606948
 17.58642659 17.59461516 17.59943184 17.6024969  17.60453848 17.62707216
 17.65294171 17.65729137 17.65777842 17.68250413 17.6977477  17.70868314
 17.71570826 17.72714918 17.72860328 17.73046251 17.73337673 17.74372632
 17.75180531 17.75228135 17.76013932 17.76091155 17.76948143 17.78045058
 17.78751779 17.79432141]


python 实现构建与查询：

In [5]:
at = annoytrees(trains, 20)
nearests, distances = at.query(test, k)
print("Indexes: ", np.array(nearests))
print("Distances: ", np.array(distances))

Indexes:  [4848 2756 2292 5888 5029 7146 5562 1827 3820 5916 4765 7279 8910 8774
 8749 9185 9074 6037 6707 1825 6922 8047 9739 4921 3962   84  324 8187
 6433   85 4008 4768 3018  360 6443 1938 3042 3771 4069 6604 7688 1503
 1314 6893 8094 9481 7436 6828 4426 8434]
Distances:  [16.89781552 17.39754721 17.46670573 17.57053752 17.68250413 17.73046251
 17.73337673 17.75180531 17.85768485 17.86940295 17.95864107 18.05522963
 18.10345792 18.10943947 18.11842027 18.1262473  18.12951272 18.15185938
 18.18700009 18.24833165 18.25292595 18.2672885  18.27565028 18.27972265
 18.30357089 18.30882151 18.3303367  18.33563371 18.34697864 18.38029483
 18.38316955 18.38962212 18.39529567 18.42886611 18.43021626 18.43573353
 18.45821091 18.46099592 18.47144542 18.47567675 18.47773606 18.48109985
 18.48129014 18.50716549 18.51928416 18.52190635 18.52265084 18.53003909
 18.5496016  18.55423501]


平衡二叉树演示：

In [6]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

%matplotlib notebook

def means_step(X, iteration_steps = 200):
    count = X.shape[0]
    i = np.random.randint(0, count)
    j = np.random.randint(0, count - 1)
    j += (j >= i)
    ic = 1
    jc = 1
    p = X[i]
    q = X[j]
    ps = np.zeros((iteration_steps, len(p)))
    qs = np.zeros((iteration_steps, len(q)))
    ts = np.zeros((iteration_steps, len(q)))
    for l in range(iteration_steps):
        p, q, ic, jc, _ = step(X, p, q, ic, jc)
        ps[l], qs[l], ts[l] = p, q, _
    return ps, qs, ts

def step(X, p, q, ic, jc):
    k = np.random.randint(0, X.shape[0])
    di = ic * distance(p, X[k])
    dj = jc * distance(q, X[k])
    if di == dj:
        return
    if di < dj:
        p = (p * ic + X[k]) / (ic + 1)
        ic = ic + 1
    else:
        q = (q * jc + X[k]) / (jc + 1)
        jc = jc + 1
    return p, q, ic, jc, X[k]

plt.rcParams['font.sans-serif'] = ['PingFang HK']  # 选择一个本地的支持中文的字体
fig, ax = plt.subplots()
ax.set_facecolor('#f8f9fa')

X = np.array([[1,1], [2,3], [3,-1], [0,0], [-1,-2], [-2, 2], [-4, 4], [3,-1], [1, -4],[0,2], [-3, 0]])
itr = 50
ps, qs, ts = means_step(X, iteration_steps = itr)
x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
plt.xlim(x_min, x_max)
plt.ylim(y_min, y_max)
x1 = X[:, 0]
y1 = X[:, 1]
plt.scatter(x1, y1, c='#e63946', marker='.')

p = plt.scatter(ps[0][0], ps[0][1], c='#457b9d', marker='$P$', s=60)
q = plt.scatter(qs[0][0], qs[0][1], c='#457b9d', marker='$Q$', s=60)
t = plt.scatter(ts[0][0], ts[0][1], c='#457b9d', marker='$T$', s=60)

def update(i):
    p.set_offsets(ps[i])
    q.set_offsets(qs[i])
    if i + 1 < itr:
        t.set_offsets(ts[i + 1])
    else:
        t.set_offsets([10, 10])
    return p, q, t

ani = animation.FuncAnimation(fig, update, range(1, itr), interval=500, blit=True, repeat=False)
ani.save('means.gif')
plt.show()

<IPython.core.display.Javascript object>