# MNIST Classifier using KNN

This is an exercise from chapter 3 of the book "Hands-On Machine Learning with Scikit-Learn and TensorFlow" by Aurélien Géron. The task is to use GridSearchCV to find hyperparameters for a KNN classifier that achieve over 97% accuracy on the MNIST test set, and then to augment the training data to improve that accuracy further.

Here we load the MNIST dataset.

In [None]:
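import pandas as pd
from sklearn.datasets import fetch_openml

# as_frame=True returns the pixels as a pandas DataFrame (the shifting code later uses .iloc)
mnist = fetch_openml('mnist_784', version=1, as_frame=True)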

Here we split the data into training (80%) and test (20%) sets.

In [3]:
from sklearn.model_selection import train_test_split

X = mnist['data']
y = mnist['target']

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, random_state=42)

print("X_train shape: ", X_train.shape)
print("y_train shape: ", y_train.shape)
print("X_test shape: ", X_test.shape)
print("y_test shape: ", y_test.shape)

X_train shape:  (56000, 784)
y_train shape:  (56000,)
X_test shape:  (14000, 784)
y_test shape:  (14000,)


We then train a KNN classifier on the training data, using GridSearchCV to find the best hyperparameters.

The K-Nearest Neighbors (KNN) classifier is an instance-based learning algorithm for classification. It memorizes the entire training set, and to classify a new example it finds the training instances closest to that example and lets them vote on the label. The number of neighbors consulted is set by the hyperparameter k (`n_neighbors` in scikit-learn), hence the name. This differs from model-based learning algorithms, which fit a model's parameters during training and can then make predictions without consulting the training data.
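The classify-by-neighbors idea fits in a few lines of NumPy. This is a from-scratch sketch for illustration only, not how scikit-learn implements `KNeighborsClassifier` (which can use tree-based neighbor search); `knn_predict` is a hypothetical name. The `weights` argument mirrors scikit-learn's `'uniform'` and `'distance'` options: with `'distance'`, each neighbor's vote is weighted by the inverse of its distance.

```python
import numpy as np
from collections import Counter


def knn_predict(X_train, y_train, x, k=3, weights="uniform"):
    """Classify one example x by a vote among its k nearest training rows.

    Hypothetical from-scratch sketch; "training" is just keeping X_train and
    y_train around, and all the work happens at prediction time.
    """
    # Euclidean distance from x to every stored instance
    distances = np.linalg.norm(X_train - x, axis=1)
    # Indices of the k closest instances
    nearest = np.argsort(distances)[:k]

    votes = Counter()
    for idx in nearest:
        if weights == "uniform":
            votes[y_train[idx]] += 1.0  # every neighbor gets one vote
        else:  # "distance": closer neighbors get larger votes
            d = distances[idx]
            if d == 0:
                # Simplification: an exact match decides the vote
                # (scikit-learn likewise gives zero-distance neighbors all the weight)
                return y_train[idx]
            votes[y_train[idx]] += 1.0 / d
    return votes.most_common(1)[0][0]
```

On one-dimensional toy data the two weightings can disagree: with training points 0.0 (`'a'`), 2.0 and 2.5 (`'b'`) and a query at 0.5, a uniform 3-neighbor vote picks `'b'` (two votes to one), while distance weighting picks `'a'` (2.0 versus about 1.17).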

In [4]:
from sklearn.model_selection import GridSearchCV
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

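knn = KNeighborsClassifier()

# Try every combination: 2 weighting schemes x 3 values of k
param_grid = {
    'weights': ['uniform', 'distance'],
    'n_neighbors': [3, 4, 5]
}

# 5-fold cross-validation per candidate; the best setting is then refit on all of X_train
grid_search = GridSearchCV(knn, param_grid, cv=5, verbose=3)

grid_search.fit(X_train, y_train)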

Fitting 5 folds for each of 6 candidates, totalling 30 fits
[CV 1/5] END ....n_neighbors=3, weights=uniform;, score=0.969 total time=   7.3s
[CV 2/5] END ....n_neighbors=3, weights=uniform;, score=0.969 total time=   8.9s
[CV 3/5] END ....n_neighbors=3, weights=uniform;, score=0.972 total time=  10.0s
[CV 4/5] END ....n_neighbors=3, weights=uniform;, score=0.971 total time=   9.6s
[CV 5/5] END ....n_neighbors=3, weights=uniform;, score=0.970 total time=  10.7s
[CV 1/5] END ...n_neighbors=3, weights=distance;, score=0.971 total time=  10.8s
[CV 2/5] END ...n_neighbors=3, weights=distance;, score=0.970 total time=  10.9s
[CV 3/5] END ...n_neighbors=3, weights=distance;, score=0.973 total time=  11.3s
[CV 4/5] END ...n_neighbors=3, weights=distance;, score=0.972 total time=  11.3s
[CV 5/5] END ...n_neighbors=3, weights=distance;, score=0.971 total time=  11.4s
[CV 1/5] END ....n_neighbors=4, weights=uniform;, score=0.967 total time=  10.9s
[CV 2/5] END ....n_neighbors=4, weights=uniform;,
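The log's first line, "Fitting 5 folds for each of 6 candidates, totalling 30 fits", follows directly from the grid: GridSearchCV tries every combination of the listed values (2 weightings × 3 values of k = 6 candidates) and fits each one once per fold. A candidate's cross-validation score is the mean of its fold scores, and `best_score_` is the highest such mean. A small standard-library sketch reproduces the count and averages the fold scores visible in the log; since the log rounds scores to three decimals, these means are approximate.

```python
from itertools import product

grid = {"weights": ["uniform", "distance"], "n_neighbors": [3, 4, 5]}  # same grid as above

# Every combination of the listed values is one candidate
candidates = [dict(zip(grid, values)) for values in product(*grid.values())]
n_folds = 5
print(len(candidates), "candidates,", len(candidates) * n_folds, "fits")  # 6 candidates, 30 fits

# Fold scores copied from the log for the two fully shown candidates
fold_scores = {
    ("uniform", 3): [0.969, 0.969, 0.972, 0.971, 0.970],
    ("distance", 3): [0.971, 0.970, 0.973, 0.972, 0.971],
}

# A candidate's CV score is the mean over its folds
for (weights, k), scores in fold_scores.items():
    print(k, weights, round(sum(scores) / len(scores), 4))
```

Both means (about 0.9702 and 0.9714) fall below the reported `best_score_` of 0.9722, consistent with the `n_neighbors=4`, `weights='distance'` candidate winning.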

In [5]:
from sklearn.metrics import accuracy_score

print("Best Parameters: ", grid_search.best_params_)
print("Best Score: ", grid_search.best_score_)

y_pred = grid_search.predict(X_test)
accuracy_score(y_test, y_pred)

Best Parameters:  {'n_neighbors': 4, 'weights': 'distance'}
Best Score:  0.9721964285714286


0.9731428571428572

This gives 97.3% accuracy on the test set, above the 97% target.

# Data Augmentation to Further Improve Accuracy

Here we augment the data by shifting the images in each direction by one pixel and adding the new images to the training set. We then train a new KNN classifier on the augmented data.

In [None]:
import numpy as np
from scipy.ndimage import shift  # scipy.ndimage.interpolation is deprecated


def shift_mnist(mnist):
    # Note: this shifts every image in mnist['data'], which also contains
    # the images that were split off into X_test, not only X_train.
    shifted_down = []
    shifted_up = []
    shifted_right = []
    shifted_left = []

    shifted_to_add = [shifted_down, shifted_up, shifted_right, shifted_left]

    # (rows, columns) offsets, in the same order as shifted_to_add
    shifts = {
        'down': (1, 0),
        'up': (-1, 0),
        'right': (0, 1),
        'left': (0, -1)
    }

    for i, direction in enumerate(shifts.values()):

        for j in range(len(mnist['data'])):

            image_to_shift = mnist['data'].iloc[j]
            image_to_shift = image_to_shift.values.reshape(28, 28)  # pandas row -> 28x28 NumPy array

            # Pixels shifted in from outside the image are filled with 0 (background)
            shifted = shift(image_to_shift, shift=direction, mode='constant', cval=0)

            shifted_to_add[i].append(shifted.ravel())  # ravel() flattens the image back into a 784-feature row

    return shifted_to_add  # a list of 4 lists, one with the shifted images for each direction


shift_images = shift_mnist(mnist)
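The loop above calls `.iloc` once per image and direction (280,000 calls for 70,000 images), which is slow. A vectorized sketch can shift all images at once with NumPy slicing; for one-pixel integer shifts it should produce the same images as `scipy.ndimage.shift(..., mode='constant', cval=0)` up to floating-point rounding. It also takes the arrays to augment as arguments, so it can be applied to `X_train` alone, which is what the book's exercise asks for, leaving `X_test` unseen during training. `shift_batch` and `augment_with_shifts` are hypothetical helper names.

```python
import numpy as np


def shift_batch(X, dy, dx, side=28):
    """Shift every flattened side x side image in X by (dy, dx) pixels.

    Hypothetical helper. Positive dy moves content down, positive dx moves it
    right; vacated pixels are filled with 0.
    """
    images = np.asarray(X, dtype=float).reshape(-1, side, side)
    out = np.zeros_like(images)
    # Matching source/destination slices for rows (axis 1) and columns (axis 2)
    src_r = slice(max(0, -dy), side - max(0, dy))
    dst_r = slice(max(0, dy), side - max(0, -dy))
    src_c = slice(max(0, -dx), side - max(0, dx))
    dst_c = slice(max(0, dx), side - max(0, -dx))
    out[:, dst_r, dst_c] = images[:, src_r, src_c]
    return out.reshape(len(images), side * side)


def augment_with_shifts(X, y, side=28):
    """Return X and y extended with four one-pixel-shifted copies of X (hypothetical helper)."""
    X = np.asarray(X, dtype=float)
    y = np.asarray(y)
    offsets = [(1, 0), (-1, 0), (0, 1), (0, -1)]  # down, up, right, left
    shifted = [shift_batch(X, dy, dx, side) for dy, dx in offsets]
    X_out = np.concatenate([X] + shifted, axis=0)
    # Rows are stacked block by block (original, down, up, right, left),
    # so the labels repeat block by block as well
    y_out = np.tile(y, len(offsets) + 1)
    return X_out, y_out


# Usage sketch (not executed here); only the training split is augmented:
# X_train_aug, y_train_aug = augment_with_shifts(X_train, y_train)
# knn_aug = KNeighborsClassifier(n_neighbors=4, weights='distance').fit(X_train_aug, y_train_aug)
# print(knn_aug.score(np.asarray(X_test), y_test))
```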

Here we concatenate our shifted images with the original images. Note that `mnist['data']` is the full dataset (training and test images), not just `X_train`.

In [7]:
import numpy as np

# Concatenate the original data and the shifted data
original_data = mnist['data']  # all 70,000 images, including the ones in X_test
all_shifted_data = np.concatenate(shift_images, axis=0)

augmented_data = np.concatenate((original_data, all_shifted_data), axis=0)

# Concatenate the original target values and the shifted target values
original_targets = mnist['target']
all_shifted_targets = np.tile(original_targets, 4)  # repeats the targets 4 times (one block per shift direction)

augmented_targets = np.concatenate((original_targets, all_shifted_targets), axis=0)

# Create a new dictionary with the augmented data and targets
augmented_data = {
    'data': augmented_data,
    'target': augmented_targets
}
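`np.tile` is the right choice for the labels because `np.concatenate(shift_images, axis=0)` stacks the shifted images one direction at a time: all down-shifted images, then all up-shifted ones, and so on. Tiling repeats the whole label sequence in that same block order, whereas `np.repeat` would repeat each label in place and misalign images and labels. A toy check with three labeled "images" and two directions (the variable names are illustrative):

```python
import numpy as np

# Labels of three toy "images"
labels = np.array(["5", "0", "4"])

# Shifted copies are stacked direction by direction, so row i of the stacked
# array holds a copy of image i % len(labels)
n_directions = 2
expected = np.array([labels[i % len(labels)] for i in range(n_directions * len(labels))])

tiled = np.tile(labels, n_directions)       # ['5' '0' '4' '5' '0' '4']
repeated = np.repeat(labels, n_directions)  # ['5' '5' '0' '0' '4' '4']

print(np.array_equal(tiled, expected))     # True
print(np.array_equal(repeated, expected))  # False
```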

Here we fit a new KNN classifier on the augmented data, reusing the best hyperparameters from the grid search.

In [8]:
Aug_X_Train = augmented_data['data']
Aug_y_Train = augmented_data['target']

#Best Parameters:  {'n_neighbors': 4, 'weights': 'distance'}

knn_with_aug = KNeighborsClassifier(n_neighbors=4, weights='distance')

knn_with_aug.fit(Aug_X_Train, Aug_y_Train)

Here we evaluate the new model on the test set.

In [9]:
from sklearn.metrics import accuracy_score

y_aug_predict = knn_with_aug.predict(X_test)

acc = accuracy_score(y_test, y_aug_predict)
print(acc)

1.0


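Test accuracy rises from 97.3% to 100%, but this is not a genuine improvement. The augmentation cells shift and concatenate the full `mnist['data']`, which still contains the 14,000 images in `X_test`, so every test image is present, unchanged, in the augmented training set. With `weights='distance'`, scikit-learn gives any training point at zero distance all of the voting weight, so each test image is matched to its own stored copy and its true label is returned. The score measures memorization of the test set, not generalization.

For a fair evaluation, only `X_train` (and `y_train`) should be shifted and added to the training set, as the book's exercise describes, and the model should then be scored on the untouched `X_test`. Augmentation can still be expected to help KNN: as an instance-based algorithm it classifies an example by comparing it to the closest stored instances, so shifted copies of each training digit give it more close matches for test digits drawn slightly off-center. The size of that gain has to be measured with the leak-free setup.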
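Because `X_test` was split off from `mnist['data']`, and the augmentation cells copy all of `mnist['data']` into the training matrix, a quick sanity check is to count how many test images appear verbatim among the training rows. This is a minimal sketch, assuming an exact byte-level match is the right criterion (it is for copied rows, though it will not catch near-duplicates); `count_shared_rows` is a hypothetical helper, not a scikit-learn function. It indexes the smaller evaluation set and streams over the larger training matrix to keep memory modest.

```python
import numpy as np
from collections import Counter


def count_shared_rows(X_fit, X_eval):
    """Count rows of X_eval that also appear, byte for byte, as rows of X_fit.

    Hypothetical helper. Both inputs are cast to float64 first so that, for
    example, integer and float copies of the same pixels compare equal.
    """
    X_fit = np.ascontiguousarray(X_fit, dtype=np.float64)
    X_eval = np.ascontiguousarray(X_eval, dtype=np.float64)
    # Index the evaluation rows by their raw bytes (with multiplicity)
    eval_counts = Counter(row.tobytes() for row in X_eval)
    found = set()
    for row in X_fit:
        key = row.tobytes()
        if key in eval_counts:
            found.add(key)
    # Every evaluation row whose bytes occur somewhere in X_fit
    return sum(eval_counts[key] for key in found)
```

Applied to this notebook as `count_shared_rows(augmented_data['data'], X_test)`, it should report all 14,000 test rows, since the augmented matrix begins with an unchanged copy of `mnist['data']`.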