## 简单实现KNN

In [2]:
import numpy as np
import time

# 简单KNN 实现 （线性扫描）
def knn_linear_scan(X_train, y_train, x_query, k=3):
    """Classify ``x_query`` by majority vote among its k nearest training samples.

    Brute-force (linear scan) search: the Euclidean distance from ``x_query``
    to every row of ``X_train`` is evaluated, the ``k`` closest rows are
    selected, and the most frequent label among them is returned.

    Args:
        X_train: 2-D array-like with one training sample per row.
        y_train: 1-D array-like of labels aligned with the rows of ``X_train``.
        x_query: 1-D array broadcastable against the rows of ``X_train``.
        k: Number of neighbours that vote. Values larger than the number of
            training samples fall back to using every sample.

    Returns:
        The most frequent label among the selected neighbours. On a tie the
        smallest label wins (``np.unique`` returns sorted labels and
        ``np.argmax`` picks the first maximum).

    Raises:
        ValueError: If ``k`` is smaller than 1 or the training set is empty.
    """
    X_train = np.asarray(X_train)
    y_train = np.asarray(y_train)
    n_samples = X_train.shape[0]

    if n_samples == 0:
        raise ValueError("X_train must contain at least one sample")
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}")

    # Squared distances rank neighbours exactly like true Euclidean distances
    # (sqrt is monotonic), so the square root is skipped.
    squared_distances = np.sum((X_train - x_query) ** 2, axis=1)

    if k < n_samples:
        # argpartition selects the k smallest entries in O(n) without fully
        # sorting all distances as argsort would.
        nearest_indices = np.argpartition(squared_distances, k - 1)[:k]
    else:
        nearest_indices = np.arange(n_samples)

    # Majority vote over the labels of the selected neighbours.
    nearest_labels = y_train[nearest_indices]
    labels, counts = np.unique(nearest_labels, return_counts=True)
    return labels[np.argmax(counts)]

# 测试数据
def generate_data(n_samples):
    """Create a random binary-classification data set.

    Args:
        n_samples: Number of samples to generate.

    Returns:
        A ``(X, y)`` tuple: ``X`` has shape ``(n_samples, 10)`` with uniform
        values in ``[0, 1)``; ``y`` holds ``n_samples`` integer labels drawn
        from ``{0, 1}``. Both come from NumPy's global random state.
    """
    feature_shape = (n_samples, 10)  # 10 features per sample
    features = np.random.random_sample(feature_shape)
    labels = np.random.randint(2, size=n_samples)  # class 0 or 1
    return features, labels

# 运行比较
# Time a single linear-scan prediction on a small and a very large training set.
# NOTE: with float64 features the 100,000,000 x 10 training matrix alone needs
# about 8 GB, and the distance computation allocates temporaries of similar
# size; lower the second entry on machines with less memory.
for size in [100,100000000]:
    X_train, y_train = generate_data(size)
    x_query = np.random.rand(10)  # random query point

    # perf_counter() is a monotonic, high-resolution clock meant for measuring
    # short durations; time.time() can jump when the system clock is adjusted.
    start_time = time.perf_counter()
    pred = knn_linear_scan(X_train, y_train, x_query, k=3)
    time_used = time.perf_counter() - start_time
    print(f"训练集大小: {size:>6} | 预测结果: {pred} | 耗时: {time_used *1000:.3f} ms")


训练集大小:    100 | 预测结果: 0 | 耗时: 0.674 ms
训练集大小: 100000000 | 预测结果: 1 | 耗时: 91836.480 ms


当训练集规模增大之后，单次预测的耗时明显增长：线性扫描需要计算查询点到每一个训练样本的距离，耗时随样本数量（以及特征维度）线性增加。
大模型的参数量动辄以亿计，相应的数据规模同样庞大，显然线性扫描不是KNN的最优解法。

## 使用KD树实现 KNN

In [5]:
from sklearn.neighbors import KDTree

# KD-Tree 版 KNN
def knn_kdtree(X_train, y_train, x_query, k=3, tree=None):
    """Classify ``x_query`` by majority vote among its k nearest neighbours found via a KD-tree.

    Args:
        X_train: 2-D array with one training sample per row. Only used to
            build the tree when ``tree`` is not supplied.
        y_train: 1-D array of labels aligned with the rows of ``X_train``.
        x_query: 1-D array holding the query point.
        k: Number of neighbours that vote.
        tree: Optional prebuilt index over ``X_train`` exposing
            ``query(X, k=...)`` that returns ``(distances, indices)``, for
            example ``KDTree(X_train)``. It must have been built from the
            same rows that ``y_train`` labels. When omitted, a new
            ``KDTree`` is built on every call.

    Returns:
        The most frequent label among the k nearest neighbours. On a tie the
        smallest label wins (``np.unique`` returns sorted labels and
        ``np.argmax`` picks the first maximum).
    """
    if tree is None:
        # Construction has to process every training sample, so callers that
        # issue many queries should build the tree once and pass it in.
        tree = KDTree(X_train)

    # query() expects a 2-D batch of points, hence the reshape to (1, n_features).
    _, idx = tree.query(x_query.reshape(1, -1), k=k)
    nearest_labels = y_train[idx[0]]

    # Majority vote over the neighbours' labels.
    labels, counts = np.unique(nearest_labels, return_counts=True)
    return labels[np.argmax(counts)]

# 数据生成
def generate_data(n_samples):
    """Create a random binary-classification data set.

    Args:
        n_samples: Number of samples to generate.

    Returns:
        A ``(X, y)`` tuple: ``X`` has shape ``(n_samples, 10)`` with uniform
        values in ``[0, 1)``; ``y`` holds ``n_samples`` integer labels drawn
        from ``{0, 1}``. Both come from NumPy's global random state.
    """
    feature_shape = (n_samples, 10)  # 10 features per sample
    features = np.random.random_sample(feature_shape)
    labels = np.random.randint(2, size=n_samples)  # class 0 or 1
    return features, labels

# 对比线性扫描和KD-Tree
from math import log10

# Compare one linear-scan prediction with one KD-tree prediction per data size.
for size in [100, 100000]:
    X_train, y_train = generate_data(size)
    x_query = np.random.rand(10)

    # Linear scan.
    # perf_counter() is a monotonic, high-resolution clock meant for measuring
    # short durations; time.time() can jump when the system clock is adjusted.
    start_time = time.perf_counter()
    pred_linear = knn_linear_scan(X_train, y_train, x_query, k=3)
    time_linear = (time.perf_counter() - start_time) * 1000

    # KD-Tree.
    # knn_kdtree builds its KD-tree inside this call, so the measurement
    # covers tree construction plus the query, not the query alone.
    start_time = time.perf_counter()
    pred_kdtree = knn_kdtree(X_train, y_train, x_query, k=3)
    time_kdtree = (time.perf_counter() - start_time) * 1000

    print(f"训练集大小: {size:>6} | 线性扫描: {time_linear:.3f} ms | KD-Tree: {time_kdtree:.3f} ms")

训练集大小:    100 | 线性扫描: 0.686 ms | KD-Tree: 3.102 ms
训练集大小: 100000 | 线性扫描: 8.967 ms | KD-Tree: 26.503 ms


In [None]:
from sklearn.neighbors import KDTree

# KD-Tree 版 KNN
def knn_kdtree(X_train, y_train, x_query, k=3, tree=None):
    """Classify ``x_query`` by majority vote among its k nearest neighbours found via a KD-tree.

    Args:
        X_train: 2-D array with one training sample per row. Only used to
            build the tree when ``tree`` is not supplied.
        y_train: 1-D array of labels aligned with the rows of ``X_train``.
        x_query: 1-D array holding the query point.
        k: Number of neighbours that vote.
        tree: Optional prebuilt index over ``X_train`` exposing
            ``query(X, k=...)`` that returns ``(distances, indices)``, for
            example ``KDTree(X_train)``. It must have been built from the
            same rows that ``y_train`` labels. When omitted, a new
            ``KDTree`` is built on every call.

    Returns:
        The most frequent label among the k nearest neighbours. On a tie the
        smallest label wins (``np.unique`` returns sorted labels and
        ``np.argmax`` picks the first maximum).
    """
    if tree is None:
        # Construction has to process every training sample, so callers that
        # issue many queries should build the tree once and pass it in.
        tree = KDTree(X_train)

    # query() expects a 2-D batch of points, hence the reshape to (1, n_features).
    _, idx = tree.query(x_query.reshape(1, -1), k=k)
    nearest_labels = y_train[idx[0]]

    # Majority vote over the neighbours' labels.
    labels, counts = np.unique(nearest_labels, return_counts=True)
    return labels[np.argmax(counts)]

# 数据生成
def generate_data(n_samples):
    """Create a random binary-classification data set.

    Args:
        n_samples: Number of samples to generate.

    Returns:
        A ``(X, y)`` tuple: ``X`` has shape ``(n_samples, 10)`` with uniform
        values in ``[0, 1)``; ``y`` holds ``n_samples`` integer labels drawn
        from ``{0, 1}``. Both come from NumPy's global random state.
    """
    feature_shape = (n_samples, 10)  # 10 features per sample
    features = np.random.random_sample(feature_shape)
    labels = np.random.randint(2, size=n_samples)  # class 0 or 1
    return features, labels

# 对比线性扫描和KD-Tree
from math import log10

# Compare one linear-scan prediction with one KD-tree prediction per data size.
for size in [100, 1000000]:
    X_train, y_train = generate_data(size)
    x_query = np.random.rand(10)

    # Linear scan.
    # perf_counter() is a monotonic, high-resolution clock meant for measuring
    # short durations; time.time() can jump when the system clock is adjusted.
    start_time = time.perf_counter()
    pred_linear = knn_linear_scan(X_train, y_train, x_query, k=3)
    time_linear = (time.perf_counter() - start_time) * 1000

    # KD-Tree.
    # knn_kdtree builds its KD-tree inside this call, so the measurement
    # covers tree construction plus the query, not the query alone.
    start_time = time.perf_counter()
    pred_kdtree = knn_kdtree(X_train, y_train, x_query, k=3)
    time_kdtree = (time.perf_counter() - start_time) * 1000

    print(f"训练集大小: {size:>6} | 线性扫描: {time_linear:.3f} ms | KD-Tree: {time_kdtree:.3f} ms")

In [4]:
from sklearn.neighbors import KDTree

# KD-Tree 版 KNN
def knn_kdtree(X_train, y_train, x_query, k=3, tree=None):
    """Classify ``x_query`` by majority vote among its k nearest neighbours found via a KD-tree.

    Args:
        X_train: 2-D array with one training sample per row. Only used to
            build the tree when ``tree`` is not supplied.
        y_train: 1-D array of labels aligned with the rows of ``X_train``.
        x_query: 1-D array holding the query point.
        k: Number of neighbours that vote.
        tree: Optional prebuilt index over ``X_train`` exposing
            ``query(X, k=...)`` that returns ``(distances, indices)``, for
            example ``KDTree(X_train)``. It must have been built from the
            same rows that ``y_train`` labels. When omitted, a new
            ``KDTree`` is built on every call.

    Returns:
        The most frequent label among the k nearest neighbours. On a tie the
        smallest label wins (``np.unique`` returns sorted labels and
        ``np.argmax`` picks the first maximum).
    """
    if tree is None:
        # Construction has to process every training sample, so callers that
        # issue many queries should build the tree once and pass it in.
        tree = KDTree(X_train)

    # query() expects a 2-D batch of points, hence the reshape to (1, n_features).
    _, idx = tree.query(x_query.reshape(1, -1), k=k)
    nearest_labels = y_train[idx[0]]

    # Majority vote over the neighbours' labels.
    labels, counts = np.unique(nearest_labels, return_counts=True)
    return labels[np.argmax(counts)]

# 数据生成
def generate_data(n_samples):
    """Create a random binary-classification data set.

    Args:
        n_samples: Number of samples to generate.

    Returns:
        A ``(X, y)`` tuple: ``X`` has shape ``(n_samples, 10)`` with uniform
        values in ``[0, 1)``; ``y`` holds ``n_samples`` integer labels drawn
        from ``{0, 1}``. Both come from NumPy's global random state.
    """
    feature_shape = (n_samples, 10)  # 10 features per sample
    features = np.random.random_sample(feature_shape)
    labels = np.random.randint(2, size=n_samples)  # class 0 or 1
    return features, labels

# 对比线性扫描和KD-Tree
from math import log10

# Compare one linear-scan prediction with one KD-tree prediction per data size.
# NOTE: with float64 features the 100,000,000 x 10 training matrix alone needs
# about 8 GB, before counting distance temporaries and the KD-tree itself;
# lower the second entry on machines with less memory.
for size in [100, 100000000]:
    X_train, y_train = generate_data(size)
    x_query = np.random.rand(10)

    # Linear scan.
    # perf_counter() is a monotonic, high-resolution clock meant for measuring
    # short durations; time.time() can jump when the system clock is adjusted.
    start_time = time.perf_counter()
    pred_linear = knn_linear_scan(X_train, y_train, x_query, k=3)
    time_linear = (time.perf_counter() - start_time) * 1000

    # KD-Tree.
    # knn_kdtree builds its KD-tree inside this call, so the measurement
    # covers tree construction plus the query, not the query alone.
    start_time = time.perf_counter()
    pred_kdtree = knn_kdtree(X_train, y_train, x_query, k=3)
    time_kdtree = (time.perf_counter() - start_time) * 1000

    print(f"训练集大小: {size:>6} | 线性扫描: {time_linear:.3f} ms | KD-Tree: {time_kdtree:.3f} ms")

训练集大小:    100 | 线性扫描: 1.419 ms | KD-Tree: 2.268 ms
训练集大小: 100000000 | 线性扫描: 76761.729 ms | KD-Tree: 160128.061 ms
