In [1]:
import numpy as np
import os
from sklearn.datasets import fetch_openml
from sklearn.linear_model import SGDClassifier
from sklearn.svm import SVC
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import cross_val_predict
from sklearn.metrics import accuracy_score
from sklearn.neighbors import KNeighborsClassifier
from scipy.ndimage.interpolation import shift, rotate

In [41]:
# Index where MNIST's test portion starts (rows [:split_num] are training rows).
split_num = 60_000
# Seed the legacy global NumPy RNG so the training-set shuffle is repeatable;
# this cell must run before the shuffle cell for that to hold.
np.random.seed(42)

In [2]:
def sort_by_target(mnist, split=None):
    """Sort the train and test portions of ``mnist`` by label, in place.

    Rows ``[:split]`` and ``[split:]`` are sorted independently, so no sample
    crosses the train/test boundary. Rows with equal labels keep their original
    relative order (stable sort), exactly like sorting ``(target, index)`` pairs.

    Parameters
    ----------
    mnist : object with ``data`` and ``target`` attributes
        Row-aligned arrays supporting positional slicing and fancy indexing
        (e.g. the Bunch returned by ``fetch_openml(..., as_frame=False)``).
        Both attributes are modified in place.
    split : int, optional
        Index of the first test row. Defaults to the notebook-level
        ``split_num`` so existing ``sort_by_target(mnist)`` calls still work.
    """
    if split is None:
        split = split_num
    n_rows = len(mnist.target)
    for start, stop in ((0, split), (split, n_rows)):
        # A stable argsort is equivalent to sorting (target, original_index)
        # tuples, and also handles an empty portion gracefully.
        order = start + np.argsort(mnist.target[start:stop], kind="stable")
        # Fancy indexing returns a copy, so writing back into the same slice
        # cannot read already-overwritten rows.
        mnist.data[start:stop] = mnist.data[order]
        mnist.target[start:stop] = mnist.target[order]

In [3]:
# Fetch MNIST and sort each split by label.
# as_frame=False: since scikit-learn 1.2, fetch_openml defaults to
# as_frame="auto" and returns a pandas DataFrame, which breaks the positional
# row indexing used by sort_by_target and by the shuffle below.
mnist = fetch_openml('mnist_784', version=1, cache=True, as_frame=False)
# Labels are not numeric as fetched; cast them to small integers.
mnist.target = mnist.target.astype(np.int8)
sort_by_target(mnist)

In [4]:
# Pixel feature matrix and int8 label vector for the whole (sorted) dataset.
X, y = mnist["data"], mnist["target"]

In [5]:
# Split data into train and test sets, then shuffle the training rows
# (sort_by_target left them grouped by digit).
X_train, X_test = X[:split_num], X[split_num:]
y_train, y_test = y[:split_num], y[split_num:]
shuffle_index = np.random.permutation(split_num)
X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]

In [16]:
# Hyper-parameters chosen by a previous grid search; shared by both KNN models
# so the only difference between them is the training data.
knn_params = dict(
    n_neighbors=4,
    weights='distance',  # closer neighbours get a larger vote
)

In [11]:
# Baseline for comparison: KNN trained on the original, non-augmented data.
# fit() returns the estimator itself, which is displayed as the cell output.
knn = KNeighborsClassifier(**knn_params).fit(X_train, y_train)
knn

KNeighborsClassifier(n_neighbors=4, weights='distance')

In [12]:
# Score the baseline model on the held-out test set.
y_pred = knn.predict(X_test)
accuracy = accuracy_score(y_true=y_test, y_pred=y_pred)
print(f'Accuracy without augmentation: {accuracy}')

Accuracy without augmentation: 0.9714


In [32]:
# Helper functions for shifting and rotating flattened images
def shift_image(image, dx, dy, shape=(28, 28)):
    """Shift a flattened image by ``dx`` columns and ``dy`` rows.

    ``image`` is reshaped to ``shape`` (28x28 by default), shifted with pixels
    entering from outside the frame set to 0, and returned flattened with the
    same number of elements as the input.
    """
    image = np.asarray(image).reshape(shape)
    # ndimage.shift takes offsets in axis order: (rows, cols) == (dy, dx).
    shifted_image = shift(image, [dy, dx], cval=0, mode="constant")
    return shifted_image.reshape(-1)


def rotate_image(image, deg, shape=(28, 28)):
    """Rotate a flattened image by ``deg`` degrees about its centre.

    ``reshape=False`` keeps the output canvas the same size as the input.
    scipy's default (``reshape=True``) enlarges the canvas for angles that are
    not multiples of 90 degrees, so the flattened result would no longer match
    the other training rows. For multiples of 90 degrees on a square image, the
    output is the same as with the default.
    """
    image = np.asarray(image).reshape(shape)
    rotated_image = rotate(image, deg, reshape=False, cval=0, mode="constant")
    return rotated_image.reshape(-1)

In [33]:
# Build the augmented training set: every original image, followed by its
# shifted and rotated copies (same row order as appending per image).
pxl_amt = 1
shift_vals = ((pxl_amt, 0), (-pxl_amt, 0), (0, pxl_amt), (0, -pxl_amt))
# NOTE(review): a 180-degree rotation turns a 6 into something resembling a 9
# (and vice versa) while keeping the original label, and 90/270-degree
# rotations are not upright digits; small angles (e.g. +/-10 degrees) are the
# more usual choice for MNIST -- confirm these angles are intended.
rotate_vals = [90, 180, 270]

n_train = len(X_train)
n_copies = 1 + len(shift_vals) + len(rotate_vals)  # original + augmented copies

# Preallocate rather than appending hundreds of thousands of small arrays to
# Python lists and copying them all again with np.array (~half the peak memory).
X_train_aug = np.empty((n_train * n_copies, X_train.shape[1]), dtype=X_train.dtype)
y_train_aug = np.empty(n_train * n_copies, dtype=y_train.dtype)
X_train_aug[:n_train] = X_train
y_train_aug[:n_train] = y_train

row = n_train
for image, label in zip(X_train, y_train):
    # A transform that changes the image size fails here with a clear error
    # instead of silently producing a ragged object array.
    for dx, dy in shift_vals:
        X_train_aug[row] = shift_image(image, dx, dy)
        y_train_aug[row] = label
        row += 1
    for deg in rotate_vals:
        X_train_aug[row] = rotate_image(image, deg)
        y_train_aug[row] = label
        row += 1

assert row == len(X_train_aug), "augmented set was not completely filled"

In [37]:
# Same hyper-parameters as the baseline, trained on the augmented set.
# fit() returns the estimator itself, which is displayed as the cell output.
knn_aug = KNeighborsClassifier(**knn_params).fit(X_train_aug, y_train_aug)
knn_aug

KNeighborsClassifier(n_neighbors=4, weights='distance')

In [40]:
# Score the augmented model on the same untouched test set as the baseline.
y_pred_aug = knn_aug.predict(X_test)
accuracy_aug = accuracy_score(y_true=y_test, y_pred=y_pred_aug)
print(f'Accuracy with augmentation: {accuracy_aug}')

Accuracy with augmentation: 0.9721
