In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from sklearn.metrics import accuracy_score
from sklearn.model_selection import KFold
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

## Read in and create training and testing sets

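Let's read in the data, build the feature matrix X (from the `age` and `interest` columns) and the label vector y (from the `success` column), and then split them into training and testing sets, holding out 20% of the rows for testing.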

In [2]:
# Load the dataset: `age` and `interest` are the input features, `success` is the label.
df = pd.read_csv("datasets/classification.csv")
feature_columns = ["age", "interest"]
X = np.column_stack([df[col] for col in feature_columns])
y = df["success"].to_numpy(copy=True)

# Hold out 20% of the rows as a test set; a fixed random_state makes the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

## K-Fold Cross Validation

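Now, rather than carving a single validation set out of the training data, let's use k-fold cross-validation with k=5: the training set is split into 5 folds, and each fold takes one turn as the validation set while the model is fit on the remaining 4. (This k, the number of folds, is unrelated to the k in k-nearest neighbors.)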

In [3]:
# Number of cross-validation folds (KFold is imported in the top import cell).
N_FOLDS = 5

# KFold is not asked to shuffle, so each fold is a consecutive block of X_train's rows.
# That is acceptable here because train_test_split (shuffle=True by default) already
# shuffled the rows before they reached X_train.
splits = KFold(n_splits=N_FOLDS)

## Hyperparameter testing with k-fold cross validation

Let's choose the number of neighbors using k-fold cross-validation: for each odd value from 1 to 15, we compute the mean validation accuracy over the 5 folds.

In [6]:
# Candidate neighbor counts: odd values 1..15. Odd k avoids tied votes when there are two
# classes (presumably `success` is binary — TODO confirm).
# NOTE(review): k-NN is distance-based and `age`/`interest` are used unscaled; if their ranges
# differ a lot, one feature will dominate the distance — consider scaling inside the CV loop.
N_NEIGHBORS_CANDIDATES = range(1, 16, 2)


def mean_cv_accuracy(n_neighbors, features, labels, splitter):
    """Return the mean validation accuracy of a k-NN classifier across the splitter's folds.

    Fold models are kept local to this function, so no per-fold estimator or prediction
    array leaks into the notebook's global namespace.

    Parameters
    ----------
    n_neighbors : int
        Number of neighbors passed to ``KNeighborsClassifier``.
    features, labels : numpy.ndarray
        Row-aligned feature matrix and label vector to cross-validate on.
    splitter : object
        Anything with a ``split(features)`` method yielding ``(train_index, val_index)``
        pairs, e.g. a ``KFold`` instance.

    Returns
    -------
    numpy.float64
        Mean of the per-fold accuracy scores.
    """
    fold_scores = []
    for train_index, val_index in splitter.split(features):
        model = KNeighborsClassifier(n_neighbors=n_neighbors)
        model.fit(features[train_index], labels[train_index])
        fold_predictions = model.predict(features[val_index])
        fold_scores.append(accuracy_score(labels[val_index], fold_predictions))
    return np.mean(fold_scores)


cv_scores = {}
for n in N_NEIGHBORS_CANDIDATES:
    cv_scores[n] = mean_cv_accuracy(n, X_train, y_train, splits)
    print(n, cv_scores[n])

# Select the best n programmatically instead of by eye. On ties, max() returns the first
# maximal key, i.e. the smallest n, because dicts iterate in insertion order.
best_n = max(cv_scores, key=cv_scores.get)

1 0.864804964539007
3 0.9323581560283689
5 0.9156028368794326
7 0.911436170212766
9 0.9156028368794326
11 0.9156028368794326
13 0.9113475177304965
15 0.9071808510638298


## Building our classifier

Now we can build our final classifier with the number of neighbors that scored best in cross-validation (n_neighbors=3, mean validation accuracy ≈ 0.932) and fit it on the full training set.

In [7]:
# Refit on the whole training set using the neighbor count chosen above:
# n_neighbors=3 had the highest mean validation accuracy (≈0.932) in the cross-validation search.
neigh = KNeighborsClassifier(n_neighbors=3).fit(X_train, y_train)  # fit() returns the estimator
neigh


## Testing our model

We're ready to evaluate our model on the testing set, which played no part in training or in choosing the number of neighbors!

In [8]:
# Score the final model once on the held-out test split.
y_predict = neigh.predict(X_test)
test_accuracy = accuracy_score(y_true=y_test, y_pred=y_predict)
print(test_accuracy)

0.9333333333333333


In [9]:
# Training-set accuracy, for comparison with the test accuracy: a large gap between the two
# would suggest overfitting. With k-NN each training point is among its own neighbors,
# so this figure tends to be optimistic.
y_predict = neigh.predict(X_train)
train_accuracy = accuracy_score(y_true=y_train, y_pred=y_predict)
print(train_accuracy)

0.9493670886075949
