In [29]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from collections import Counter
import tensorflow_datasets as tfds

In [3]:
data_train, data_test = tfds.load('cifar10', split=['train', 'test'], shuffle_files=True)

In [4]:
data_train = tfds.as_dataframe(data_train)
data_test=  tfds.as_dataframe(data_test)

2022-10-17 16:17:25.601598: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


In [5]:
X_train = data_train.image
X_test = data_test.image
y_train = data_train.label
y_test = data_test.label
X_train = np.array(X_train)
X_test = np.array(X_test)
y_train = np.array(y_train)
y_test = np.array(y_test)

In [6]:
def convert_to_array(ndarr):
    n = ndarr.shape[0]
    ndarr = ndarr.tolist()
    ndarr = [img.tolist() for img in ndarr]
    ndarr = np.array(ndarr)
    ndarr = ndarr.reshape((n, -1))
    return ndarr

In [7]:
X_train = convert_to_array(X_train)
X_test = convert_to_array(X_test)

In [8]:
class KNN:
    def __init__(self, X_train, y_train, n_neighbors=3, p=2):
        self.X_train = X_train
        self.y_train = y_train
        self.p = p
        self.n = n_neighbors
    
    def predict(self, X):
        knn_list = []
        # Initialization
        for i in range(self.n):
            dist = np.linalg.norm(X - self.X_train[i], ord=self.p)
            knn_list.append((dist, self.y_train[i]))
        
        for i in range(self.n, len(self.X_train)):
            # Retrieve the index of the point with the maximum distance
            max_index = knn_list.index(max(knn_list, key=lambda x: x[0]))
            dist = np.linalg.norm(X - self.X_train[i], ord=self.p)
            if knn_list[max_index][0] > dist:
                knn_list[max_index] = (dist, self.y_train[i])
        
        # After calculating all distances, we have to get the label with the maximum frequency
        knn = [k[-1] for k in knn_list]
        count = Counter(knn)
        max_count = sorted(count.items(), key=lambda x: x[1])[-1][0]
        # Note here we use count.items() rather than count in order to make the stuff iterable
        return max_count
    
    def score(self, X_test, y_test):
        right_count = 0
        for i in range(len(X_test)):
            label = self.predict(X_test[i])
            if label == y_test[i]:
                right_count += 1
        return str(np.round(right_count / len(X_test) * 100, 2)) + "%"

In [9]:
clf = KNN(X_train, y_train)

In [24]:
clf.score(X_test, y_test)

KeyboardInterrupt: 

In [11]:
# Using Class from Scikit-learn
from sklearn.neighbors import KNeighborsClassifier
clf_sk = KNeighborsClassifier(n_neighbors=1)
clf_sk.fit(X_train, y_train)
clf_sk.score(X_test, y_test)

0.3539

In [75]:
# Define KdTree
class KdNode(object):
    
    def __init__(self, split, dom_elt, left, right):
        self.split = split
        self.dom_elt = dom_elt
        self.left = left
        self.right = right
    
class KdTree(object):
    
    def __init__(self, data):
        
        # Get the dimension
        k = len(data[0])
        
        # Create the KdTree using resursion
        def CreateNode(split, data_set):
            if not data_set:
                return None
            # Sort the dataset
            data_set.sort(key=lambda x: x[split])
            split_pos = len(data_set) // 2
            median = data_set[split_pos]
            split_next = (split + 1) % k 
            
            return KdNode(split, median,
                          CreateNode(split_next, data_set[:split_pos]),
                          CreateNode(split_next, data_set[split_pos + 1:]))
            
        self.root = CreateNode(0, data)
        
# Define Preorder Function
def preorder(root):
    print(root.dom_elt)
    if root.left:
        preorder(root.left)
    if root.right:
        preorder(root.right)

In [13]:
# Define KdTree Search Algorithm
from math import sqrt
from collections import namedtuple

result = namedtuple("Result",
                    "nearest_point  nearest_dist  nodes_visited")

def find_nearest(tree, point):
    k = len(point)
    
    def travel(kd_node, target, max_dist):
        
        if not kd_node:
            return result([0] * k, float('inf'), 0)
        # Note here we cannot simply return None, since this will rule the 
        # special situation out of the function (kd_node is null)
        # Understand the essence of recursion: There is a hidden general stack!!!
        nodes_visited = 1
        s = kd_node.split
        pivot = kd_node.dom_elt
        
        # Determine the next search iteration
        if target[s] < pivot[s]:
            nearer_node = kd_node.left
            further_node = kd_node.right
        else:
            nearer_node = kd_node.right
            further_node = kd_node.left
        
        # Keep searching
        temp1 = travel(nearer_node, target, max_dist)
        
        # Note this is midorder!!! IF THERE IS A NEED, WE SEARCH THE FURTHER ONE!
        nearest = temp1.nearest_point
        dist = temp1.nearest_dist
        
        nodes_visited += temp1.nodes_visited
        
        if dist < max_dist:
            max_dist = dist # This is the radius
            
        # Now we have to calculate whether the radius will reach to another region
        temp_dist = abs(pivot[s] - target[s])
        # if no intersaction, this means that IN THIS DIMENSION the nearest point is what we got
        if temp_dist > max_dist:
            return result(nearest, dist, nodes_visited)
        # Note: temp_dist > max_dist indicates that the hyperball is totally in that region
        # if not, there are still two possibilities
        # 1.pivot is the nearest point
        # 2. the nearest point is in the other region (the offspring of further node)
        
        temp_dist = sqrt(sum((p1 - p2) ** 2 for p1, p2 in zip(pivot, target)))
        
        if temp_dist < dist: # pivot is nearer
            nearest = pivot
            dist = temp_dist
            max_dist = dist
        # Search in another region
        temp2 = travel(further_node,target, max_dist)
        
        if temp2.nearest_dist < dist:
            nearest = temp2.nearest_point
            dist = temp2.nearest_dist
            
        nodes_visited += temp2.nodes_visited
            
        return result(nearest, dist, nodes_visited)

    return travel(tree.root, point, float('inf'))  

In [14]:
data = [[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]]
kd = KdTree(data)
preorder(kd.root)

[7, 2]
[5, 4]
[2, 3]
[4, 7]
[9, 6]
[8, 1]


In [15]:
from time import perf_counter
from random import random

def random_point(k):
    return [random() for _ in range(k)]

def random_points(k, n):
    return [random_point(k) for _ in range(n)]

In [16]:
find_nearest(kd, [3, 4.5])

Result(nearest_point=[2, 3], nearest_dist=1.8027756377319946, nodes_visited=4)

In [20]:
N = 1000000
t0 = perf_counter()
kd2 = KdTree(random_points(5, N))
ret2 = find_nearest(kd2, [0.2, 0.2, 0.8, 0.5, 0.3])
t1 = perf_counter()
print(ret2)
print(t1 - t0)

Result(nearest_point=[0.17874394397764803, 0.21678990438838575, 0.7998362553852564, 0.5196189220081592, 0.30228129223618416], nearest_dist=0.033523931963963666, nodes_visited=343)
11.919583375000002


### Realization of KNN accelerated algorithm

In [77]:
class KdNode(object):
    
    def __init__(self, split, dom_elt, left, right):
        self.split = split
        self.dom_elt = dom_elt
        self.left = left
        self.right = right
        
class KdTree(object):
    
    def __init__(self, data):
        
        k = len(data[0])
        
        def CreateTree(split, data_set):
            if not data_set:
                return None

            data_set.sort(key=lambda x: x[split])
            split_pos = len(data_set) // 2
            median = data_set[split_pos]
            split_next = (split + 1) % k
            
            return KdNode(split, median,
                          CreateTree(split_next, data_set[:split_pos]),
                          CreateTree(split_next, data_set[split_pos + 1:]))
            
        self.root = CreateTree(0, data)
        
def preorder(root):
    print(root.dom_elt)
    if root.left:
        preorder(root.left)
    if root.right:
        preorder(root.right)

In [96]:
from math import sqrt

class result(object):
    
    def __init__(self, nearest_point, nearest_dist):
        self.nearest_point = nearest_point
        self.nearest_dist = nearest_dist

def find_k_nearest(tree, point, k):
    
    p = len(point)
    
    def travel(kd_node, target, max_dist):
        
        if not kd_node:
            return result([[0 for _ in range(p)] for _ in range(k)],
                          [float('inf') for _ in range(k)])
        
        s = kd_node.split
        pivot = kd_node.dom_elt
        
        if target[s] <= pivot[s]:
            nearer = kd_node.left
            further = kd_node.right
        else:
            nearer = kd_node.right
            further = kd_node.left
        
        temp1 = travel(nearer, target, max_dist)
        
        nearest_point = temp1.nearest_point
        nearest_dist = temp1.nearest_dist
        
        # iteration needed
        max_index = nearest_dist.index(sorted(nearest_dist)[-1])
        max_point = nearest_point[max_index]
        cur_max_dist = nearest_dist[max_index]
        
        if cur_max_dist < max_dist:
            max_dist = cur_max_dist
        
        temp_dist = abs(target[s] - pivot[s])
        
        if temp_dist > max_dist:
            return result(nearest_point, nearest_dist)

        temp_dist = sqrt(sum((p1 - p2) ** 2 for p1, p2 in zip(target, pivot)))
        
        if temp_dist < cur_max_dist:
            nearest_point[max_index] = pivot
            nearest_dist[max_index] = temp_dist
            temp_dist = max_dist
            
        temp2 = travel(further, target, max_dist)
        
        # select the k nearest points from these two sets
        
        temp_nearest_point = nearest_point + temp2.nearest_point
        temp_nearest_dist = nearest_dist + temp2.nearest_dist
        temp_set = list(zip(temp_nearest_point, temp_nearest_dist))
        
        k_min_set = sorted(temp_set, key=lambda x: x[1])[:k]
        nearest_point = [point for point, dist in k_min_set]
        nearest_dist = [dist for point, dist in k_min_set]
        
        return result(nearest_point, nearest_dist)
    
    return travel(tree.root, point, float('inf'))
        

In [147]:
data = [[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]]

In [148]:
kd = KdTree(data)

In [156]:
ret = find_k_nearest(kd, [6.1, 7.2], 3)

In [157]:
ret.nearest_point

[[4, 7], [9, 6], [5, 4]]

In [158]:
ret.nearest_dist

[2.109502310972898, 3.1384709652950433, 3.383784863137726]

In [93]:
from random import random
def random_point(p):
    return [random() for _ in range(p)]
def random_points(p, n):
    return [random_point(p) for _ in range(n)]

In [100]:
ret1 = find_k_nearest(kd, [3, 4.5], 2)

In [98]:
from random import random

# 产生一个k维随机向量，每维分量值在0~1之间
def random_point(k):
    return [random() for _ in range(k)]
 
# 产生n个k维随机向量 
def random_points(k, n):
    return [random_point(k) for _ in range(n)] 

In [99]:
N = 400000
kd2 = KdTree(random_points(3, N)) 
ret2 = find_k_nearest(kd2, [0.1, 0.5, 0.8], 10)    


In [132]:
from collections import Counter
def knn_train(k, X_train, y_train, X_test, y_test):
    right_count = 0
    X_train_list = X_train.tolist()
    X_test_list = X_test.tolist()
    kd = KdTree(X_train_list)
    for i in range(len(X_test)):
        ret = find_k_nearest(kd, X_test_list[i], k)
        train_label = []
        for j in range(k):
            idx = X_train_list.index(ret.nearest_point[j])
            label = y_train[idx]
            train_label.append(label)
        count = Counter(train_label)
        label = sorted(count.items(), key=lambda x: x[1])[-1][0]
        if label == y_test[i]:
            right_count += 1
    return np.round(right_count / len(X_test), 4)          

In [145]:
knn_train(1, X_train[:2000], y_train[:2000], X_test[:100], y_test[:100])

0.12

In [172]:
from sklearn.neighbors import KDTree
train_data = np.array([(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)])
kd = KdTree(train_data.tolist())
ret = find_k_nearest(kd, [3, 4.5], 1)
ret.nearest_point

[[2, 3]]