In [1]:
import numpy as np
from sklearn import datasets
from collections import Counter

In [2]:
def vote(neighbors):
    """Return the most frequent class label among ``neighbors``.

    Each neighbor is a sequence whose element at position 2 is its class
    label, such as the 3-tuples produced by ``get_neighbors``. When several
    labels share the highest count, the one encountered first wins.
    An empty ``neighbors`` raises IndexError.
    """
    tally = Counter(entry[2] for entry in neighbors)
    top = tally.most_common(1)  # [(label, count)], or [] if nothing was counted
    return top[0][0]

def distance(instance1, instance2):
    """Return ``np.linalg.norm`` of the difference between two instances.

    For 1-D inputs this is the Euclidean (L2) distance. Lists and tuples
    are accepted as well as NumPy arrays, because both arguments are passed
    through ``np.array`` before subtracting.
    """
    difference = np.array(instance1) - np.array(instance2)
    return np.linalg.norm(difference)

def get_neighbors(training_set, 
                  labels, 
                  test_instance, 
                  k,
                  distance_fn=None):
    """Return the ``k`` training instances nearest to ``test_instance``.

    Parameters
    ----------
    training_set : sequence of instances
        Candidate neighbours, accessed positionally.
    labels : sequence
        ``labels[index]`` is the class label of ``training_set[index]``.
    test_instance : instance
        The point whose neighbours are wanted.
    k : int
        Number of neighbours to return. ``0`` yields an empty list; a ``k``
        larger than ``len(training_set)`` returns every instance.
    distance_fn : callable, optional
        Called as ``distance_fn(test_instance, candidate)`` and expected to
        return a sortable number. Defaults to the module-level ``distance``
        function (``np.linalg.norm`` of the difference).

    Returns
    -------
    list of 3-tuples ``(instance, dist, label)`` in ascending ``dist``
    order, where
        instance is ``training_set[index]`` itself (not its index),
        dist     is the distance between ``test_instance`` and ``instance``,
        label    is ``labels[index]``.
    Ties keep training-set order because the sort is stable.

    Raises
    ------
    ValueError
        If ``k`` is negative. A negative slice would otherwise silently
        drop elements from the end instead of selecting neighbours.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")
    if distance_fn is None:
        distance_fn = distance  # module-level default metric
    # Index into `labels` (rather than zipping) so a too-short labels
    # sequence still fails loudly instead of truncating silently.
    scored = [
        (candidate, distance_fn(test_instance, candidate), labels[index])
        for index, candidate in enumerate(training_set)
    ]
    scored.sort(key=lambda entry: entry[1])
    return scored[:k]

In [11]:
# Hold out the last N_TEST samples of a seeded random permutation as a test
# set and classify each one by majority vote among its K nearest neighbours.
N_TEST = 10  # number of held-out test samples
K = 10       # number of neighbours consulted per vote

digits = datasets.load_digits()
digits_X = digits.data
digits_Y = digits.target

# Legacy global seed, kept so the permutation (and the printed results
# below) stay reproducible.
np.random.seed(0)
indices = np.random.permutation(len(digits_X))

train_idx, test_idx = indices[:-N_TEST], indices[-N_TEST:]
digits_X_train = digits_X[train_idx]
digits_y_train = digits_Y[train_idx]
digits_X_test  = digits_X[test_idx]
digits_y_test  = digits_Y[test_idx]

for i, (test_instance, true_label) in enumerate(zip(digits_X_test, digits_y_test)):
    neighbors = get_neighbors(digits_X_train, 
                              digits_y_train, 
                              test_instance, 
                              K)
    print("index: ", i, 
          ", result of vote: ", vote(neighbors), 
          ", label: ", true_label)

index:  0 , result of vote:  1 , label:  1
index:  1 , result of vote:  4 , label:  4
index:  2 , result of vote:  8 , label:  8
index:  3 , result of vote:  4 , label:  4
index:  4 , result of vote:  5 , label:  5
index:  5 , result of vote:  3 , label:  3
index:  6 , result of vote:  3 , label:  3
index:  7 , result of vote:  7 , label:  7
index:  8 , result of vote:  7 , label:  7
index:  9 , result of vote:  8 , label:  8
