In [13]:
import numpy as np
from scipy import linalg as la
from scipy.spatial import KDTree
from scipy.stats import mode

In [2]:
def exhaustive_search(X, z):
    """Solve the nearest neighbor search problem with an exhaustive search.

    Every row of X is compared against z, so the work is proportional to
    the size of X.

    Parameters:
        X ((m,k) ndarray): a training set of m k-dimensional points
            (must contain at least one row).
        z ((k, ) ndarray): a k-dimensional target point.

    Returns:
        ((k,) ndarray) the element (row) of X that is nearest to z.
        (float) The Euclidean distance from the nearest neighbor to z.
    """
    # Row-wise Euclidean distance from each training point to z.
    dist = la.norm(X - z, axis=1)
    # Locate the minimum once and reuse that index for both return values,
    # instead of scanning the distances a second time with builtin min().
    nearest = np.argmin(dist)
    return X[nearest], dist[nearest]

In [18]:
def test_exhaustive():
    """Sanity-check exhaustive_search on small hand-computed cases."""
    z = np.array([6, 5, 4])

    # Single-point training set: the only row must be returned.
    A = np.array([[1, 2, 3]])
    point, dist = exhaustive_search(A, z)
    assert np.allclose(point, np.array([1, 2, 3]))
    # sqrt(5**2 + 3**2 + 1**2) = sqrt(35). Compare with a tolerance rather
    # than exact float equality, which depends on rounding inside the norm.
    assert np.isclose(dist, np.sqrt(35))

    # Multi-point training set: the nearest row is not the first one.
    B = np.array([[1, 2, 3], [6, 5, 5], [0, 0, 0]])
    point, dist = exhaustive_search(B, z)
    assert np.allclose(point, np.array([6, 5, 5]))
    assert np.isclose(dist, 1.0)

In [22]:
# Run the exhaustive-search sanity checks; an AssertionError means a check failed.
test_exhaustive()

In [3]:
class KDTnode:
    """A single node of a k-d tree.

    Parameters:
        x ((k,) ndarray): vector of data stored at this node.

    Attributes:
        value (np.ndarray): vector of data.
        left (KDTnode): left child node (None until assigned).
        right (KDTnode): right child node (None until assigned).
        pivot (int): index (mod k) of the coordinate this node splits on.
            None until the owning tree assigns it.
    """
    def __init__(self, x):
        """Initialize value, left, right and pivot.

        Raises:
            ValueError: if x is not a NumPy array.
        """
        # isinstance (rather than an exact type comparison) also accepts
        # ndarray subclasses.
        if not isinstance(x, np.ndarray):
            raise ValueError("x must be ndarray")
        self.value = x
        self.left = None
        self.right = None
        self.pivot = None

In [4]:
class KDT:
    """A k-dimensional binary tree for solving the nearest neighbor problem.

    Attributes:
        root (KDTnode): the root node of the tree. Like all other nodes in
            the tree, the root has a NumPy array of shape (k,) as its value.
        k (int): the dimension of the data in the tree.
        z (ndarray): the target point of the most recent query (None until
            query() is called); read by _KDsearch.
    """
    def __init__(self):
        """Initialize the root, k and z attributes."""
        self.root = None
        self.k = None
        self.z = None

    def find(self, data):
        """Return the node containing the data. If there is no such node in
        the tree, or if the tree is empty, raise a ValueError.
        """
        current = self.root
        while current is not None:
            if np.allclose(data, current.value):
                return current                      # Data found.
            # Follow the same branching rule insert() uses.
            if data[current.pivot] < current.value[current.pivot]:
                current = current.left
            else:
                current = current.right
        raise ValueError(str(data) + " is not in the tree")

    # Problem 3
    def insert(self, data):
        """Insert a new node containing the specified data.

        Parameters:
            data ((k,) ndarray): a k-dimensional point to insert into the tree.

        Raises:
            ValueError: if data is not an ndarray, if it does not have the
                same dimensions as other values in the tree, or if it is
                already in the tree (as judged by find()).
        """
        if not isinstance(data, np.ndarray):
            raise ValueError("Not an array")
        if self.k is not None and self.k != len(data):
            raise ValueError("Data is not in R^k")

        # Empty tree: data becomes the root, fixing k and splitting on axis 0.
        if self.root is None:
            self.root = KDTnode(data)
            self.k = len(data)
            self.root.pivot = 0
            return

        # Reject duplicates; find() raises ValueError when data is absent.
        try:
            self.find(data)
        except ValueError:
            pass
        else:
            raise ValueError("No Duplicates")

        # Walk down to the empty child slot where data belongs. A loop (rather
        # than recursion) avoids Python's recursion limit on unbalanced trees.
        current = self.root
        while True:
            i = current.pivot
            goes_left = data[i] < current.value[i]
            child = current.left if goes_left else current.right
            if child is None:
                break
            current = child

        new_node = KDTnode(data)
        # Children split on the next coordinate, cycling through 0..k-1.
        new_node.pivot = (current.pivot + 1) % self.k
        if goes_left:
            current.left = new_node
        else:
            current.right = new_node

    def _KDsearch(self, current, nearest, d):
        """Recursively search the subtree rooted at current for the node
        nearest to self.z.

        Parameters:
            current (KDTnode): root of the subtree to search (may be None).
            nearest (KDTnode): best node found so far.
            d (float): distance from nearest to self.z.

        Returns:
            (KDTnode, float): the updated nearest node and its distance.
        """
        if current is None:
            return nearest, d
        x = current.value
        i = current.pivot

        dist = la.norm(x - self.z)
        if dist < d:
            nearest, d = current, dist

        # Search the side of the splitting hyperplane that contains z first,
        # then the other side only if the ball of radius d around z crosses it.
        if self.z[i] < x[i]:
            nearest, d = self._KDsearch(current.left, nearest, d)
            if self.z[i] + d >= x[i]:
                nearest, d = self._KDsearch(current.right, nearest, d)
        else:
            nearest, d = self._KDsearch(current.right, nearest, d)
            if self.z[i] - d <= x[i]:
                nearest, d = self._KDsearch(current.left, nearest, d)
        return nearest, d

    # Problem 4
    def query(self, z):
        """Find the value in the tree that is nearest to z.

        Parameters:
            z ((k,) ndarray): a k-dimensional target point.

        Returns:
            ((k,) ndarray) the value in the tree that is nearest to z.
            (float) The Euclidean distance from the nearest neighbor to z.

        Raises:
            ValueError: if the tree is empty.
        """
        if self.root is None:
            raise ValueError("Cannot query an empty tree")
        self.z = z
        nearest, d = self._KDsearch(self.root, self.root,
                                    la.norm(self.root.value - z))
        return nearest.value, d

    def __str__(self):
        """String representation: a hierarchical list of nodes and their axes.

        Nodes are listed in breadth-first order. Example:

                    [5,5]
                   /     \\
               [3,2]     [8,4]
                   \\         \\
                  [2,6]     [7,5]

        is rendered as

            KDT(k=2)
            [5 5]   pivot = 0
            [3 2]   pivot = 1
            [8 4]   pivot = 1
            [2 6]   pivot = 0
            [7 5]   pivot = 0
        """
        from collections import deque

        if self.root is None:
            return "Empty KDT"
        # deque gives O(1) pops from the front; list.pop(0) is O(n).
        nodes, strs = deque([self.root]), []
        while nodes:
            current = nodes.popleft()
            strs.append("{}\tpivot = {}".format(current.value, current.pivot))
            for child in (current.left, current.right):
                if child is not None:
                    nodes.append(child)
        return "KDT(k={})\n".format(self.k) + "\n".join(strs)

In [5]:
class KNeighborsClassifier:
    """A k-nearest-neighbors classifier backed by scipy's KDTree.

    Attributes:
        n (int): the number of voting neighbors.
        tree (KDTree): the KD-tree built from the training points.
        label (ndarray): the training labels, aligned with the rows of the
            training data.
    """
    def __init__(self, n_neighbors):
        """Initialize the class with the integer representing the
        number of neighbors that vote on each prediction.
        """
        self.n = n_neighbors
        self.tree = None
        self.label = None

    def fit(self, X, y):
        """Load the training data into a KDTree and store its labels.

        Parameters:
            X ((m,k) ndarray): the training points.
            y ((m,) array-like): the label of each training point.
        """
        self.tree = KDTree(X)
        self.label = y

    def predict(self, z):
        """Return the most common label among the n nearest neighbors of z.

        Parameters:
            z ((k,) ndarray): the point to classify.

        Returns:
            The winning label. Ties are broken in favor of the smallest label.
        """
        _, indices = self.tree.query(z, k=self.n)
        # With k=1 KDTree.query returns a scalar index; promote to 1-D.
        indices = np.atleast_1d(indices)
        # When k exceeds the number of training points, missing neighbors
        # are reported with index tree.n; drop them before indexing labels.
        indices = indices[indices < self.tree.n]
        neighbor_labels = np.asarray(self.label)[indices]
        # Count votes with np.unique rather than scipy.stats.mode. The
        # result shape of mode changed in SciPy 1.11 (keepdims=False), which
        # breaks the old mode(...)[0][0] indexing. np.unique returns sorted
        # values and argmax picks the first maximum, so ties resolve to the
        # smallest label exactly as mode does.
        values, counts = np.unique(neighbor_labels, return_counts=True)
        return values[np.argmax(counts)]

In [6]:
def TestCase(n_neighbors, filename="mnist_subset.npz"):
    """Extract the data from the given file. Load a KNeighborsClassifier with
    the training data and the corresponding labels. Use the classifier to
    predict labels for the test data. Return the classification accuracy, the
    fraction of predictions that match the test labels.

    Parameters:
        n_neighbors (int): the number of neighbors to use for classification.
        filename (str): the name of the data file. Should be an npz file with
            keys 'X_train', 'y_train', 'X_test', and 'y_test'.

    Returns:
        (float): the classification accuracy, in [0, 1].
    """
    # Honor the filename argument (previously ignored) and close the npz
    # archive once the arrays have been read into memory. np.float64 replaces
    # the np.float alias, which was removed in NumPy 1.24.
    with np.load(filename) as data:
        X_train = data["X_train"].astype(np.float64)
        y_train = data["y_train"]
        X_test = data["X_test"].astype(np.float64)
        y_test = data["y_test"]

    classifier = KNeighborsClassifier(n_neighbors)
    classifier.fit(X_train, y_train)

    num_correct = sum(1 for x, label in zip(X_test, y_test)
                      if classifier.predict(x) == label)
    return num_correct / len(X_test)

In [10]:
# Classification accuracy with 10 voting neighbors on the default file mnist_subset.npz.
TestCase(10)

0.906