# Nearest neighbour classifier

In [None]:
# Import necessary libraries
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from IPython import display

In [None]:
# load the butterfly data from a CSV file
butterflies = pd.read_csv("data/butterflies.csv")

# later cells use these columns, so fail fast if the file does not provide them
assert {"Width", "Height", "Species"}.issubset(butterflies.columns), (
    "data/butterflies.csv must have Width, Height and Species columns"
)

Here is our 2D dataset (Width and Height), with 3 different classes (species).

In [None]:
# plot the training data, coloured by species
fig, ax = plt.subplots(figsize=(10, 6))
sns.scatterplot(data=butterflies, x="Width", y="Height", hue="Species", ax=ax)
# trailing ";" suppresses the Text repr so only the figure is shown
ax.set_title("Butterfly training data");

Each datapoint has some <span style="color: deepskyblue;">features</span> (Width, Height) and a <span style="color: coral;">class label</span> (Species).

In [None]:
# one training example whose parts (features + label) we will point out
sample = butterflies.iloc[40]

plt.figure(figsize=(10, 6))

# training data: colour and marker shape both encode the species
ax = sns.scatterplot(
    data=butterflies,
    x="Width",
    y="Height",
    hue="Species",
    style="Species",
    s=50,
)

# label with a black arrow ending just beside the chosen example
ax.annotate(
    text="(Width, Height), (Species)",
    xy=(sample.Width + 0.1, sample.Height - 0.05),
    xytext=(sample.Width + 0.5, sample.Height - 0.3),
    fontsize=12,
    arrowprops=dict(
        width=1,
        headwidth=6,
        headlength=6,
        edgecolor="black",
        facecolor="black",
    ),
);

Given a <span style="color: deepskyblue;">new</span> datapoint, how can we determine its <span style="color: coral">class</span>?

In [None]:
# a new, unlabelled point given as [Width, Height]
test_point = [
    3.8,  # Width
    1.6,  # Height
]

In [None]:
# show the coordinates of the test point; f-strings avoid the double space
# that print()'s separator would add after a label already ending in a space
print(f"Test Width: {test_point[0]}")
print(f"Test Height: {test_point[1]}")

In [None]:
plt.figure(figsize=(10, 6))

# training data: colour and marker shape both encode the species
ax = sns.scatterplot(
    data=butterflies,
    x="Width",
    y="Height",
    hue="Species",
    style="Species",
    s=50,
)

# overlay the unlabelled test point in blue
ax.scatter(x=[test_point[0]], y=[test_point[1]], color="deepskyblue")

# blue "New" caption whose arrow stops just short of the test point
ax.annotate(
    text="New",
    xy=(test_point[0] - 0.05, test_point[1] + 0.02),
    xytext=(test_point[0] - 1, test_point[1] + 0.3),
    fontsize=12,
    color="deepskyblue",
    arrowprops=dict(
        width=1,
        headwidth=6,
        headlength=6,
        edgecolor="deepskyblue",
        facecolor="deepskyblue",
    ),
);

Here we will find the point in our dataset that is nearest to a given test point, by comparing the test point with every datapoint.

In [None]:
fig, ax = plt.subplots(figsize=(10, 6))

# training data: colour and marker shape both encode the species
sns.scatterplot(
    data=butterflies,
    x="Width",
    y="Height",
    hue="Species",
    style="Species",
    s=50,
    ax=ax,
)

# draw a faint grey line from the test point to each training point
for butterfly in butterflies.itertuples():
    ax.plot(
        [test_point[0], butterfly.Width],
        [test_point[1], butterfly.Height],
        color="gray",
        linewidth=1,
        alpha=0.5,
        zorder=3,
    )

ax.set_title("Comparing the test point with every training point")

# plot the test point; zorder=1 places it beneath the zorder=3 lines.
# The trailing ";" hides the PathCollection repr so only the figure is shown.
ax.scatter(x=[test_point[0]], y=[test_point[1]], color="deepskyblue", s=100, zorder=1);

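Computing "similarity" between two points: we use the Euclidean distance, so a smaller distance means two points are more similar.

$$d(p, q) = \sqrt{(p_1 - q_1)^2 + (p_2 - q_2)^2}$$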

In [None]:
# filename= makes IPython read the image file now, so a wrong path fails in
# this cell instead of IPython guessing whether the string is a path or data
display.Image(filename="images/euclidean_distance.png")

In [None]:
# Implementation of distance between two points
def compute_distance(point1, point2):
    return np.sqrt((point1[0] - point2[0]) ** 2 + (point1[1] - point2[1]) ** 2)

In [None]:
# distance from the test point to each training example, collected into a
# Series that shares the row index of `butterflies`
distance = pd.Series(
    [
        compute_distance(test_point, [width, height])
        for width, height in zip(butterflies["Width"], butterflies["Height"])
    ],
    index=butterflies.index,
)

distance

In [None]:
# find the point in the dataset that is closest to the test point and record
# its distance (argmin returns a position, hence positional .iloc)
closest_point = butterflies.iloc[distance.argmin()]
distance_to_closest = distance.min()
print(f"Distance to closest point: {distance_to_closest}")

In [None]:
fig, ax = plt.subplots(figsize=(10, 6))

# plot training dataset
sns.scatterplot(
    data=butterflies,
    x="Width",
    y="Height",
    hue="Species",
    ax=ax,
)

# plot the test point as an 'x'
sns.scatterplot(
    x=[test_point[0]],
    y=[test_point[1]],
    marker="x",
    label="test point",
    color="black",
    ax=ax,
)

# plot a ring around the nearest datapoint; s=200 makes the ring larger than
# the default-sized training markers so it encircles the point rather than
# just tracing its outline
sns.scatterplot(
    x=[closest_point["Width"]],
    y=[closest_point["Height"]],
    marker="o",
    s=200,
    label="nearest training point",
    edgecolor="black",
    facecolor="none",
    ax=ax,
)

# trailing ";" suppresses the Text repr so only the figure is shown
ax.set_title("Nearest training point to the test point");

In [None]:
# assume the test point is the same class as the datapoint it is closest to
predicted_species = closest_point["Species"]
# f-string avoids the double space print()'s separator would add after the label
print(f"Predicted species: {predicted_species}")

# Nearest Neighbour Algorithm

1. Given a test point $x$
2. Compute the distance between $x$ and every datapoint in the training set
3. Assign $x$ the class of the closest datapoint

Let's implement the algorithm as a Python function

In [None]:
def nearest_neighbour(test_point, data=None):
    """Classify ``test_point`` with the 1-nearest-neighbour rule.

    Parameters
    ----------
    test_point : sequence of numbers
        ``[Width, Height]`` of the point to classify.
    data : pandas.DataFrame, optional
        Labelled training examples with ``Width``, ``Height`` and
        ``Species`` columns. Defaults to the notebook's ``butterflies``
        frame, looked up each time the function is called.

    Returns
    -------
    predicted_species
        The ``Species`` value of the closest training example.
    closest_point : pandas.Series
        The full row of that closest training example.

    Raises
    ------
    ValueError
        If ``data`` is empty.
    """
    if data is None:
        data = butterflies
    # argmin on an empty Series would fail with a less helpful message
    if data.empty:
        raise ValueError("cannot find a nearest neighbour in an empty dataset")

    # compute the distance from the test point to every example in the dataset
    distance = data.apply(
        lambda row: compute_distance(test_point, [row["Width"], row["Height"]]),
        axis=1,
    )

    # find the point in the dataset that is closest to the test point
    # (argmin returns a position, so index rows positionally with .iloc)
    closest_point = data.iloc[distance.argmin()]

    # assume the test point is the same class as the datapoint it is closest to
    predicted_species = closest_point["Species"]

    return predicted_species, closest_point

Let's try a different test point:

In [None]:
# classify a second test point
test_point = [2.1, 0.7]
predicted_species, closest_point = nearest_neighbour(test_point)

fig, ax = plt.subplots(figsize=(10, 6))

# plot training dataset
sns.scatterplot(
    data=butterflies,
    x="Width",
    y="Height",
    hue="Species",
    ax=ax,
)

# plot the test point as an 'x'
sns.scatterplot(
    x=[test_point[0]],
    y=[test_point[1]],
    marker="x",
    label="test point",
    color="black",
    ax=ax,
)

# plot a ring around the nearest datapoint; s=200 makes the ring larger than
# the default-sized training markers so it encircles the point
sns.scatterplot(
    x=[closest_point["Width"]],
    y=[closest_point["Height"]],
    marker="o",
    s=200,
    label="nearest training point",
    edgecolor="black",
    facecolor="none",
    ax=ax,
)

# show the prediction on the figure; ";" suppresses the Text repr
ax.set_title(f"Predicted species: {predicted_species}");

If we colour every point in the space with the class of the datapoint it is closest to, we get the regions that the nearest neighbour classifier assigns to each class:

In [None]:
# filename= makes IPython read the image file now, so a wrong path fails in
# this cell instead of IPython guessing whether the string is a path or data
display.Image(filename="images/nn_classification.png")