**Supervised 4**

# K-Nearest Neighbours Classification

## Part 1. A simple introduction to machine learning classification

<br>

**What is classification?** Classification is about putting things into categories. In policy analysis, we may want to classify:
- Which programs will succeed or fail
- Which regions have similar characteristics  
- Which interventions are most needed

<br>
<br>

This notebook introduces K-Nearest Neighbours (KNN), one of the simplest classification algorithms.

**The KNN Idea:** To classify something new, look at its nearest neighbours and vote.
- If most of your neighbours are Category A, you're probably Category A too
- It's like saying "birds of a feather flock together"

<br>

---

<br>

**Setup:** Import required libraries

In [1]:
import pandas as pd         # For data manipulation
import numpy as np          # For numerical operations
import altair as alt        # For plotting our results
alt.data_transformers.disable_max_rows()  # To avoid Altair row limits
import matplotlib.pyplot as plt     # For plotting with Matplotlib 

from sklearn.neighbors import KNeighborsClassifier, NearestNeighbors
from sklearn.model_selection import train_test_split


---

## Our Data: US States Policy Outcomes

We'll use a simple question: Can we tell if a state is from the **South** or **Northeast** based on just two factors?
- Median income
- Firearm death rate

(This is deliberately oversimplified to focus on learning the method!)

In [2]:
data = pd.read_csv('https://raw.githubusercontent.com/RDeconomist/RDeconomist.github.io/main/charts/usa/data_USsocioEconomic.csv')
data.head()

| Unnamed: 0 | State | StateInitials | Gini | DeathRate | Firearms_vs_avg | medIncome | Income_vs_med | ImprisonmentRate | PrisonRate | ImprisonmentRate.1 | FirearmDeaths | GeographicDivision |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | Alabama | AL | 0.472 | 21.5 | 1.647005 | 47221 | 0.799827 | 736 | 1.264605 | 0.736 | 0.215 | South |
| 1 | Alaska | AK | 0.422 | 23.3 | 1.784894 | 75723 | 1.282593 | 376 | 0.646048 | 0.376 | 0.233 | West |
| 2 | Arizona | AZ | 0.455 | 15.2 | 1.164394 | 57100 | 0.967157 | 764 | 1.312715 | 0.764 | 0.152 | West |
| 3 | Arkansas | AR | 0.458 | 17.8 | 1.363567 | 45907 | 0.777571 | 763 | 1.310997 | 0.763 | 0.178 | South |
| 4 | California | CA | 0.471 | 7.9 | 0.605178 | 66637 | 1.128695 | 430 | 0.738832 | 0.43 | 0.079 | West |


In [3]:
print(data['GeographicDivision'].unique())

['South' 'West' 'Northeast' 'Midwest']


<br>
<br>

Keep only Southern and Northeastern states:

In [4]:
# Keep only South and Northeast states
states = data[data['GeographicDivision'].isin(['South', 'Northeast'])].copy()

# Show what we're working with
print(f"We have {len(states)} states total")
print(f"  • {sum(states['GeographicDivision'] == 'South')} Southern states")
print(f"  • {sum(states['GeographicDivision'] == 'Northeast')} Northeastern states")

We have 25 states total
  • 16 Southern states
  • 9 Northeastern states


<br>
<br>

---

<br>
<br>

## Step 1: Visualise the data

Let's see if Southern and Northeastern states naturally cluster together:

In [5]:
# Create a scatter plot
base = alt.Chart(data).encode(
    x=alt.X('medIncome:Q').scale(zero=False, padding=20).title('Median Income ($)'),
    y=alt.Y('DeathRate:Q').scale(zero=False, padding=20).title('Firearm Deaths per 100,000 people'),
    color=alt.Color('GeographicDivision:N').scale(domain=['South', 'Northeast', 'West', 'Midwest'], range=['#e74c3c', '#3498db', "#f1b23d", "#4cd540"]),
    tooltip=['State:N', 'GeographicDivision:N']
).properties(
    width=300, height=300
)

# Left plot will be all the states
scatter_all = base.mark_circle(size=100, opacity=0.8)

# Right plot will filter to Southern and Northeastern states
scatter_filter = base.mark_circle(size=100, opacity=0.8).transform_filter(
    alt.FieldOneOfPredicate(field='GeographicDivision', oneOf=['South', 'Northeast'])
)

alt.hconcat(scatter_all, scatter_filter).properties(
    title='Can we tell regions apart by income and firearm deaths?'
)

<br>

**What do we notice?** The Southern and Northeastern states seem to cluster.

<br>
<br>

## Step 2: Prepare data

ML algorithms need data in a specific format:
- **X**: The features (things we measure)
- **y**: The labels (what we're trying to predict)

In [6]:
# Select our features (X) and labels (y)
features = ['medIncome', 'DeathRate']
X = states[features]
y = states['GeographicDivision']

In [7]:
print("Our features (X) - what we know:")
print(X.head())
print(f"\nShape: {X.shape} means {X.shape[0]} states and {X.shape[1]} features")

print("\n" + "="*50)
print("Our labels (y) - what we want to predict:")
print(y.head())

Our features (X) - what we know:
   medIncome  DeathRate
0      47221       21.5
3      45907       17.8
6      75923        4.6
7      58046       11.0
8      51176       12.6

Shape: (25, 2) means 25 states and 2 features

==================================================
Our labels (y) - what we want to predict:
0        South
3        South
6    Northeast
7        South
8        South
Name: GeographicDivision, dtype: object


<br>
<br>
<br>

## Step 3: Train KNN classifier

K-Nearest Neighbours is very simple:
1. To classify a new point, find its K nearest neighbours
2. Let them vote on what category it belongs to
3. The majority wins!
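
The three steps above can be sketched from scratch with NumPy. This is only an illustration of the idea, not how scikit-learn implements it; the helper name `knn_predict` and the toy points are made up for this example.

```python
from collections import Counter

import numpy as np


def knn_predict(X_train, y_train, new_point, k=3):
    """Classify new_point by majority vote among its k nearest training points."""
    # Step 1: Euclidean distance from the new point to every training point
    distances = np.sqrt(((X_train - new_point) ** 2).sum(axis=1))
    # Indices of the k smallest distances
    nearest = np.argsort(distances)[:k]
    # Steps 2 and 3: let the neighbours vote and return the most common label
    votes = Counter(y_train[i] for i in nearest)
    return votes.most_common(1)[0][0]


# Toy data (made up): two well-separated groups in 2D
X_toy = np.array([[1.0, 1.0], [1.5, 2.0], [2.0, 1.5],
                  [8.0, 8.0], [8.5, 9.0], [9.0, 8.5]])
y_toy = np.array(['A', 'A', 'A', 'B', 'B', 'B'])

print(knn_predict(X_toy, y_toy, np.array([2.0, 2.0])))  # close to the 'A' group
print(knn_predict(X_toy, y_toy, np.array([8.0, 9.0])))  # close to the 'B' group
```

Choosing an odd `k` for a two-class problem avoids tied votes.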

Let's start with `K=3` (look at the 3 nearest neighbours):

In [8]:
# Create and train the KNN classifier
knn = KNeighborsClassifier(n_neighbors=3)
knn.fit(X, y)

print(f"Model fitted on {len(X)} states")

Model fitted on 25 states


<br>
<br>

## Step 4: Make predictions

Now let's see if our model can correctly classify the states:

In [9]:
# Make predictions on the input data
predictions = knn.predict(X)

# Add the predictions back to our main dataframe, and a True/False column
states['Prediction'] = predictions
states['Correct'] = states['GeographicDivision'] == states['Prediction']

print(predictions)
states.sample(3)

['South' 'South' 'Northeast' 'South' 'South' 'South' 'South' 'South'
 'South' 'Northeast' 'Northeast' 'South' 'Northeast' 'Northeast'
 'Northeast' 'South' 'South' 'Northeast' 'Northeast' 'South' 'South'
 'South' 'Northeast' 'Northeast' 'South']


| Unnamed: 0 | State | StateInitials | Gini | DeathRate | Firearms_vs_avg | medIncome | Income_vs_med | ImprisonmentRate | PrisonRate | ImprisonmentRate.1 | FirearmDeaths | GeographicDivision | Prediction | Correct |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 37 | Pennsylvania | PA | 0.461 | 12.0 | 0.919258 | 60979 | 1.03286 | 484 | 0.831615 | 0.484 | 0.12 | Northeast | Northeast | True |
| 38 | Rhode Island | RI | 0.467 | 4.1 | 0.31408 | 61528 | 1.042159 | 239 | 0.410653 | 0.239 | 0.041 | Northeast | Northeast | True |
| 28 | New Hampshire | NH | 0.425 | 9.3 | 0.712425 | 76260 | 1.291689 | 262 | 0.450172 | 0.262 | 0.093 | Northeast | Northeast | True |


<br>

Calculate an accuracy score

In [10]:
# Calculate accuracy
accuracy = states['Correct'].mean()
print(f"\nModel accuracy: {accuracy:.1%}")

print(f"(correctly classified {sum(states['Correct'])} out of {len(states)} states)")


Model accuracy: 88.0%
(correctly classified 22 out of 25 states)
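
The same kind of accuracy figure can also be obtained from scikit-learn directly, either with `accuracy_score` or with the classifier's own `score` method. A minimal sketch on a handful of made-up points (not the states data), showing that both agree with the manual mean of correct predictions:

```python
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.neighbors import KNeighborsClassifier

# Made-up toy data: one feature, two classes
X_toy = np.array([[1.0], [2.0], [3.0], [7.0], [8.0], [9.0]])
y_toy = np.array(['A', 'A', 'A', 'B', 'B', 'B'])

model = KNeighborsClassifier(n_neighbors=3).fit(X_toy, y_toy)
preds = model.predict(X_toy)

manual = (preds == y_toy).mean()          # share of correct predictions
helper = accuracy_score(y_toy, preds)     # same calculation via sklearn.metrics
method = model.score(X_toy, y_toy)        # predicts and scores in one call

print(manual, helper, method)
```

All three numbers are computed on the data the model was fitted on, so they describe how well the model reproduces known labels rather than how it would do on new data.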


<br>
<br>
<br>

## Step 5: How KNN makes decisions

In [30]:
# Pick an example state to examine
example_idx = X.index[0]       # The first state in the data (there is no separate test set yet)
example_state = states.loc[example_idx, 'State']
example_features = X.loc[example_idx].values.reshape(1, -1) # This gives a 2D array with one row
# Add back feature names and convert to dataframe (not strictly necessary, but cleaner to read)
example_features = pd.DataFrame(example_features, columns=features)

# Find its 3 nearest neighbors among the states the model was fitted on
distances, indices = knn.kneighbors(example_features, n_neighbors=3+1)  # +1 because the state itself comes back at distance 0

print(f"How has KNN classified {example_state}:")
print(f"   Income: ${X.loc[example_idx, 'medIncome']:,.0f}")
print(f"   Firearm Deaths: {X.loc[example_idx, 'DeathRate']:.1f} per 100k")

How has KNN classified Alabama:
   Income: $47,221
   Firearm Deaths: 21.5 per 100k


In [32]:
print(f"\nIts 3 nearest neighbors are (+ itself):")
for i, (idx, dist) in enumerate(zip(indices[0], distances[0])):
    neighbor_state = states.loc[X.index[idx], 'State']
    neighbor_region = y.iloc[idx]
    print(f"  {i+1}. {neighbor_state} ({neighbor_region}) - distance: {dist:.0f}")


Its 3 nearest neighbors are (+ itself):
  1. Alabama (South) - distance: 0
  2. Arkansas (South) - distance: 1314
  3. Kentucky (South) - distance: 1852
  4. West Virginia (South) - distance: 2867


In [29]:
print(f"\nMajority vote → Predicted as: {predictions[0]}")
print(f"Actually is: {y.iloc[0]}")


Majority vote → Predicted as: South
Actually is: South
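
The vote itself can be inspected with `predict_proba`: with scikit-learn's default uniform weights, it returns the share of the K neighbours that belong to each class. A small sketch on made-up one-dimensional data (not the states data):

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

# Made-up toy data: one feature, two classes
X_toy = np.array([[1.0], [2.0], [3.0], [6.0], [7.0], [8.0]])
y_toy = np.array(['A', 'A', 'A', 'B', 'B', 'B'])

model = KNeighborsClassifier(n_neighbors=3).fit(X_toy, y_toy)

# A new point whose 3 nearest neighbours are 3.0 (A), 2.0 (A) and 6.0 (B)
query = np.array([[4.0]])
vote_shares = model.predict_proba(query)[0]

# classes_ gives the label order used for the columns of predict_proba
for label, share in zip(model.classes_, vote_shares):
    print(f"{label}: {share:.2f} of the votes")
print("Predicted:", model.predict(query)[0])
```

Two of the three neighbours are class A, so A wins the vote even though one neighbour disagrees.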


<br>
<br>
<br>


## Step 6: Visualise predictions

Let's check how each state is being classified by visualising its neighbours

(Don't worry about knowing all this code, we'll just use it here to create a nice interactive visualisation)

In [15]:
# Create connections dataframe for visualisation
# For each state, find its k nearest neighbours
k=3
nbrs = NearestNeighbors(n_neighbors=k+1)  # +1 because it includes itself. NOTE: this just searches for our nearest neighbours
nbrs.fit(X)
distances, indices = nbrs.kneighbors(X)

# Build connections dataframe
connections = []
for i, state_name in enumerate(states['State'].values):
    # Skip the first neighbour (itself) and take the next k
    for j in range(1, k+1):
        neighbour_idx = indices[i][j]
        neighbour_distance = distances[i][j]
        
        connections.append({
            'origin': state_name,
            'origin_income': states.iloc[i]['medIncome'],
            'origin_deathrate': states.iloc[i]['DeathRate'],
            'origin_region': states.iloc[i]['GeographicDivision'],
            'destination': states.iloc[neighbour_idx]['State'],
            'destination_income': states.iloc[neighbour_idx]['medIncome'],
            'destination_deathrate': states.iloc[neighbour_idx]['DeathRate'],
            'destination_region': states.iloc[neighbour_idx]['GeographicDivision'],
            'distance': neighbour_distance,
            'neighbour_rank': j  # 1st, 2nd, or 3rd nearest neighbour
        })

connections_df = pd.DataFrame(connections)

Create the chart, with an interactive hover showing each state's `k` closest neighbours, found using the same Euclidean distances that our KNN classifier uses.

In [16]:
# Create a selection that captures the hovered state NAME (not the State field)
hover_selection = alt.selection_point(
    on='pointerover',
    nearest=True,
    fields=['origin'],  # Use 'origin' to match connections_df
    empty=False
)

# Transform the states data to have an 'origin' field for consistency
states_for_hover = states.copy()
states_for_hover['origin'] = states_for_hover['State']

# Base points with selection
hover_highlight = alt.Chart(states_for_hover).mark_circle(
    size=80,
    # fill=alt.expr("datum.GeographicDivision == datum.Prediction ? 'red' : 'green'")
    fill=alt.expr("datum.GeographicDivision == 'South' ? '#e41a1c' : '#377eb8'"),
    stroke=alt.expr("datum.GeographicDivision == datum.Prediction ? datum.GeographicDivision == 'South' ? '#e41a1c' : '#377eb8' : datum.GeographicDivision == 'South' ? '#377eb8' : '#e41a1c'")
).encode(
    x=alt.X('medIncome:Q').scale(zero=False).title('Median Household Income ($)'),
    y=alt.Y('DeathRate:Q').title('Firearm Death Rate (per 100,000)'),
    color=alt.Color('GeographicDivision:N').scale(
            domain=['South', 'Northeast'], range=['#e41a1c', '#377eb8']).title('Region').legend(symbolStrokeWidth=0),
    # stroke=alt.Stroke('GeographicDivision:N').scale(
    #         domain=['South', 'Northeast'], range=['#e41a1c', '#377eb8']).legend(None),
    size=alt.condition(hover_selection, alt.value(250), alt.value(80)),
    tooltip=[
        alt.Tooltip('State:N', title='State'),
        alt.Tooltip('GeographicDivision:N', title='Actual Region'),
        alt.Tooltip('Prediction:N', title='Predicted Region'),
        alt.Tooltip('medIncome:Q', title='Median Income', format='$,.0f'),
        alt.Tooltip('DeathRate:Q', title='Death Rate', format='.1f')
    ]
).add_params(hover_selection)

# Lines filtered by selection
lines = alt.Chart(connections_df).mark_rule(strokeWidth=2, opacity=0.6).encode(
    x='origin_income:Q',
    y='origin_deathrate:Q',
    x2='destination_income:Q',
    y2='destination_deathrate:Q',
    size=alt.Size('neighbour_rank:O').scale(
        domain=[1, 2, 3],
        range=[2.5, 1.5, 0.5]
        ).legend(None),
    stroke=alt.Stroke('neighbour_rank:O').scale(domain=[1, 2, 3], range=['#2b2b2b', '#666666', '#999999']).legend(title='Neighbour Rank')
).transform_filter(
    hover_selection  # Works because both datasets share the 'origin' field
)

# Combine
chart = (lines + hover_highlight).resolve_scale(size='independent').properties(
    title={
        "text": "KNN Nearest Neighbour results",
        "subtitle": ["Hover over a state to see its 3 nearest neighbours", "Outlines indicate predicted region, fill indicates actual region."], 
    }
)
chart.display()

<br>

Interact with the points above - **why are the states that look closest on the chart not always listed as the nearest neighbours?**

<br>

We've discovered a fundamental issue in machine learning - **the importance of feature scaling.**

The "3 nearest neighbours" on the chart don't look visually closest because KNN isn't using visual distance: it's using the raw feature values.

Since median income ranges from roughly $40,000 to $76,000 while death rates range only from about 3 to 22, the income differences completely dominate the distance calculation.

<br>

Think of it this way:
- A $5,000 difference in income squared = 25,000,000
- A 5-point difference in death rate squared = 25

> The income term is a million times larger, so KNN essentially ignores death rates and just finds states with similar incomes. (This is why states are effectively classified in vertical bands along the median income axis.)
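
This can be checked by hand using the Alabama and Arkansas values shown earlier. The sketch below splits the squared Euclidean distance into its income and death-rate parts; the variable names are only for illustration.

```python
import numpy as np

# Values taken from the data shown earlier: [medIncome, DeathRate]
alabama = np.array([47221, 21.5])
arkansas = np.array([45907, 17.8])

# Squared difference for each feature separately
squared_diffs = (alabama - arkansas) ** 2
income_part, deathrate_part = squared_diffs

# Euclidean distance is the square root of the summed squared differences
distance = np.sqrt(squared_diffs.sum())

print(f"Income contribution:     {income_part:,.0f}")
print(f"Death rate contribution: {deathrate_part:,.2f}")
print(f"Distance: {distance:.0f}")   # matches the 1314 reported for Arkansas above
```

The death-rate difference barely moves the result: the distance is almost exactly the $1,314 income gap between the two states.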

<br>
<br>
<br>

##### What does this mean for our analysis?
- Most of the time, we need to introduce **scaling** into our analysis by preprocessing our data before fitting the model.
- We also likely have an **overfitting** problem: the model was scored on the same states it was fitted on, so its performance, though it looks good, may not generalise well to unseen data.

<br>

**Solutions**:
- Introduce scaling to pre-process data (see docs [here](https://scikit-learn.org/stable/modules/preprocessing.html)).
- Introduce train-test split to exclude some data when we train our model: we then use the unseen data when testing the model performance (see docs [here](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html)).
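
A minimal sketch of how both solutions might be combined, using a pipeline so that the scaler is fitted on the training rows only. The eight rows are copied by hand from the tables shown earlier in this notebook; with so few rows the score itself is not meaningful, and `test_size`, `random_state` and the variable names are illustrative choices rather than recommendations.

```python
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Small hand-copied sample of the states data (4 South, 4 Northeast)
sample = pd.DataFrame({
    'medIncome': [47221, 45907, 75923, 58046, 51176, 60979, 61528, 76260],
    'DeathRate': [21.5, 17.8, 4.6, 11.0, 12.6, 12.0, 4.1, 9.3],
    'GeographicDivision': ['South', 'South', 'Northeast', 'South', 'South',
                           'Northeast', 'Northeast', 'Northeast'],
})
X_s = sample[['medIncome', 'DeathRate']]
y_s = sample['GeographicDivision']

# Hold back 25% of the rows; stratify keeps both regions in each split
X_train, X_test, y_train, y_test = train_test_split(
    X_s, y_s, test_size=0.25, stratify=y_s, random_state=42
)

# StandardScaler rescales each feature to mean 0 and standard deviation 1,
# so income no longer dominates the distance calculation
model = make_pipeline(StandardScaler(), KNeighborsClassifier(n_neighbors=3))
model.fit(X_train, y_train)

print(f"Train: {len(X_train)} states, Test: {len(X_test)} states")
print(f"Test accuracy: {model.score(X_test, y_test):.1%}")
```

Because the scaler sits inside the pipeline, the test rows are scaled with the mean and standard deviation learned from the training rows, which keeps information from the held-out data out of the fitting step.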


<br>
<br>
<br>
<br>

---

<br>
<br>

