In [102]:
# Imports
import random
import math
import time

In [103]:
# Parameters
N = 5000  # training set size
M = 10   # test set size
D = 4  # dimension of data
K = 3  # number of neighbours to use
SEED = 42  # random seed

In [104]:
# Helpers
def distance(X1, X2):
    """Euclidean distance between two vectors."""
    return sum((x1 - x2) ** 2 for x1, x2 in zip(X1, X2))

In [105]:
# Generate dataset
random.seed(SEED)
X_train = [[random.gauss(0, 1) 
            for d in range(D)]
           for n in range(N)]
y_train = [random.gauss(0, 1) for n in range(N)]
X_test = [[random.gauss(0, 1) 
           for d in range(D)]
          for n in range(M)]
y_pred = []

In [106]:
# Precompute distance matrix
train_dist = [
    [
        math.sqrt(distance(X_train[i], X_train[j]))
        for j in range(i+1, N)
    ]
    for i in range(N-1)
]

In [107]:
# Record start time
start = time.time()

# kNN implementation
dist_calls = 0
for X0 in X_test:
    possible = [True for i in range(N)]
    curr_neighbours = [None for i in range(K)]

    cnt = 0
    for i, (X, y) in enumerate(zip(X_train[1:], y_train[1:])):
        if not possible[i]:
            continue
        d = math.sqrt(distance(X0, X))
        dist_calls += 1
        if i > K:
            for j, X in enumerate(X_train[i+1:]):
                if not possible[j]:
                    continue
                if abs(d - train_dist[i][j]) > curr_neighbours[K-1][3]:
                    cnt += 1
                    possible[j] = False
        for k in range(K):
            if curr_neighbours[k] is None or \
                    curr_neighbours[k][3] > d:
                curr_neighbours.insert(k, (i, X, y, d))
                del curr_neighbours[K]

# Report elapsed time
elapsed = time.time() - start
print(dist_calls, "distance calls")
print(f"Completed in {elapsed:.02f} seconds")

498 distance calls
Completed in 0.12 seconds


In [108]:
# Record start time
start = time.time()

# kNN implementation
dist_calls = 0
for X0 in X_test:
    possible = [True for i in range(N)]
    curr_neighbours = [None for i in range(K)]

    for i, (X, y) in enumerate(zip(X_train[1:], y_train[1:])):
        if not possible[i]:
            continue
        d = distance(X0, X)
        dist_calls += 1
        #if i > K:
            #for j, X in enumerate(X_train[i+1:]):
            #    if abs(d - train_dist[i][j]) > curr_neighbours[K-1][3]:
            #        possible[j] = False
        for k in range(K):
            if curr_neighbours[k] is None or \
                    curr_neighbours[k][3] > d:
                curr_neighbours.insert(k, (i, X, y, d))
                del curr_neighbours[K]
                
# Report elapsed time
elapsed = time.time() - start
print(dist_calls, "distance calls")
print(f"Completed in {elapsed:.02f} seconds")

49990 distance calls
Completed in 0.08 seconds
