In [1]:
import pandas as pd
import numpy as np
from scipy.stats import mode

In [18]:
X_train = pd.read_csv('TinyMNIST/trainData.csv', header = None)
y_train = pd.read_csv('TinyMNIST/trainLabels.csv', header = None).values
X_test = pd.read_csv('TinyMNIST/testData.csv', header = None).values
y_test = pd.read_csv('TinyMNIST/testLabels.csv', header = None).values

According to slides, a reasonable estimate for posterior probability is

$P_n(w_i|x) = \frac{k_i}{k}$

So there's no need to estimate probabilties and we can predict by majority voting in K-neighbors.

In [15]:

def KNN_nomral_predict(X_test, X_train, k=5, d=0):
    predicts = []
    probs = []
    for i, x in enumerate(X_test):
        points = get_points_idx(kd_tree, X_train, x, d)
        kn_idx = np.argsort(np.sum((X_train[points] - x)**2, axis=1))[:k]
        probs.append(y_train[kn_idx] == y_test[i].mean())
        predicts.append(mode(y_train[kn_idx]).mode[0])
    return predicts, probs

In [16]:
dim = 196
def make_kd_tree(points, i=0):
    if len(points) >= 10:
        while points[i].var() == 0:
            i = (i + 1) % dim
        points = points.sort_values(by = i)
        half = len(points) // 2
        new_i = (i+1)%dim
        return [
            make_kd_tree(points[: half].reset_index(drop=True), new_i),
            make_kd_tree(points[half + 1:].reset_index(drop=True), new_i),
            points[i][half],
            i,
            list(points['index'])
        ]
    else:
        return [None, None, points[i][0], i, list(points['index'])]
def get_points_idx(node, X_train, x, d):
    for d in range(d):
        if not node[0]:
            break
        i = node[3]
        pivot = node[2]
        if x[i] > pivot:
            node = node[1]
        else:
            node = node[0]
    
    return node[4]

In [21]:
from time import time
kd_tree = make_kd_tree(X_train.reset_index().copy())
for d in range(10):
    t = time()
    predicts, probs = KNN_nomral_predict(X_test, X_train.values, 1, d)
    print('D:', d, end=' ')
    print('CCR:', (predicts == y_test).mean(), end=' ')
    print('Error rate:', 1 - np.mean(probs))
    print('time:',time() - t)

D: 0 CCR: 0.1064 Error rate: 0.8936
time: 13.9174222946167
D: 1 CCR: 0.1024 Error rate: 0.8976
time: 6.694664001464844
D: 2 CCR: 0.0884 Error rate: 0.9116
time: 2.818631887435913
D: 3 CCR: 0.1568 Error rate: 0.8432
time: 1.3450729846954346
D: 4 CCR: 0.098 Error rate: 0.902
time: 0.9241905212402344
D: 5 CCR: 0.1044 Error rate: 0.8956
time: 0.7720091342926025
D: 6 CCR: 0.086 Error rate: 0.914
time: 0.6769630908966064
D: 7 CCR: 0.074 Error rate: 0.926
time: 0.6161766052246094
D: 8 CCR: 0.0736 Error rate: 0.9264
time: 0.6139168739318848
D: 9 CCR: 0.1196 Error rate: 0.8804
time: 0.557265043258667
