#### Imports

In [1]:
from pathlib import Path

import numpy as np
from scipy.stats import mode
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML, Image
plt.style.use('ggplot')
# Enlarge legend, title, axis-label and tick text, and make axis labels bold.
params = {
    'legend.fontsize': '18',
    'axes.labelsize': '20',
    'axes.labelweight': 'bold',
    'axes.titlesize': '20',
    'xtick.labelsize': '18',
    'ytick.labelsize': '18',
}
plt.rcParams.update(params)

#### Functions

In [2]:
def closest_node(node, nodes, n=1):
    """
    Find the indices of the ``n`` points in ``nodes`` closest to a query point.

    Points are ranked by Euclidean distance. Squared distances are compared,
    which gives the same ordering without a square root.

    Parameters
    ----------
    node : array_like
        Query point; must broadcast against each row of ``nodes``.
    nodes : array_like, 2-D
        Points to compare against, one point per row.
    n : int, optional
        Number of closest points to return (default 1). If ``n`` exceeds the
        number of rows in ``nodes``, all indices are returned.

    Returns
    -------
    numpy.ndarray of int
        Indices into ``nodes`` of the ``n`` closest points, nearest first.
        Equidistant points are ordered by their position in ``nodes``
        (lower index first).
    """
    nodes = np.asarray(nodes)
    node = np.asarray(node)
    sq_dist = np.sum((nodes - node) ** 2, axis=1)

    # A stable sort makes tie-breaking deterministic; the default quicksort
    # gives no guarantee about the relative order of equal distances.
    return np.argsort(sq_dist, kind='stable')[:n]

def mk_fig():
    """
    Create the square canvas used for the k-NN animation.

    Both axes span -1 to 11 and are labelled 'X1' and 'X2'.

    Returns
    -------
    fig, axes
        The new Figure and its single Axes.
    """
    fig, ax = plt.subplots(figsize=(7, 7))
    ax.set(xlim=(-1, 11), ylim=(-1, 11), xlabel='X1', ylabel='X2')
    return fig, ax

#### Data

In [3]:
# Reproducible toy data: 9 points with integer coordinates in [0, 9], a 0/1 label each.
np.random.seed(11)
# First column drawn first, then second column, same RNG order as stacking rows and transposing.
X = np.column_stack([np.random.randint(0, 10, 9) for _ in range(2)])
y = np.random.randint(0, 2, 9)  # class labels
xq = np.array([6, 3])  # query point

#### Create and save the animation

In [4]:
colors = ['#E24A33', '#348ABD']  # indexed by class label: colors[0] for class 0, colors[1] for class 1
mask = (y == 0)  # True for class-0 training points
fig, axes = mk_fig()

def init():
    """
    Draw the static background of the animation on the shared ``axes``.

    This draws the training points coloured by class and the query point
    ``xq`` as an unfilled marker, then adds a legend. Lines and scatter
    collections left over from an earlier render are removed first, so the
    same animation can be rendered more than once (e.g. ``ani.save`` and then
    ``ani.to_jshtml``) without the previous run's final state appearing from
    frame 0. Axis limits and labels on ``axes`` are left untouched.
    """
    for artist in [*axes.lines, *axes.collections]:
        artist.remove()

    axes.scatter(X[mask, 0], X[mask, 1], s=100, c=colors[0], label='Class 0', zorder=2)
    axes.scatter(X[~mask, 0], X[~mask, 1], s=100, c=colors[1], label='Class 1', zorder=2)
    axes.scatter(xq[0], xq[1], s=200, facecolor='None', edgecolor='k', lw=2, label='Unknown', zorder=2)
    axes.legend(facecolor='#F0F0F0', framealpha=1)

def animate(i):
    """
    Draw frame ``i`` of the k-NN animation on the shared ``axes``.

    Even frames draw nothing new, so frame 0 shows only what ``init`` drew.
    On odd frames the ``k = i`` training points nearest to ``xq`` are used:
    - Lines are drawn from ``xq`` to the neighbours added since the previous
      step (one at ``i == 1``, two after that).
    - ``xq`` is repainted with the majority class among all ``k`` neighbours.
      A tied vote goes to the class of the single nearest neighbour.

    Parameters
    ----------
    i : int
        Frame index passed in by ``FuncAnimation``.
    """
    if i % 2 == 0:
        return

    k = closest_node(xq, X, i)
    # Neighbours new since the previous odd frame: just k[0] when i == 1, else the last two.
    new = k[-2:]
    # Each column of these (2, len(new)) arrays is one segment from xq to a new neighbour.
    # Using np.vstack keeps the single-neighbour case a regular (2, 1) array rather than
    # a ragged [scalar, 1-element array] list.
    axes.plot(np.vstack([np.full(len(new), xq[0]), X[new, 0]]),
              np.vstack([np.full(len(new), xq[1]), X[new, 1]]),
              'k-', zorder=1)

    votes = np.bincount(y[k], minlength=len(colors))
    winners = np.flatnonzero(votes == votes.max())
    label = winners[0] if winners.size == 1 else y[k[0]]
    # NOTE(review): purpose of the +0.01 y-offset is unclear; kept unchanged.
    axes.scatter(xq[0], xq[1]+0.01, s=200, facecolor=colors[label], edgecolor='k', lw=2, zorder=2)
            
# Close the figure; presumably so the static canvas is not also shown as this cell's output.
plt.close(fig)
ani = animation.FuncAnimation(fig,
                              animate,
                              init_func=init,
                              # Frame 0 is the bare query point and odd frame i uses the i nearest
                              # neighbours, so len(X) + 1 frames never ask for more than len(X).
                              frames=len(X) + 1,
                              interval=600)
gif_path = '../gif/knn/knn.gif'
# Create the output directory if needed so saving also works on a fresh checkout.
Path(gif_path).parent.mkdir(parents=True, exist_ok=True)
ani.save(gif_path, writer='imagemagick', fps=1, dpi=75)
# HTML(ani.to_jshtml())  # inline preview alternative to the saved GIF

#### View the animation

In [None]:
# Show the saved GIF by linking to it (IPython's Image does not embed data when given a url).
Image(url="../gif/knn/knn.gif")