In [1]:
import pandas as pd
import numpy as np

from sklearn.datasets import fetch_openml
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import cross_val_score, cross_val_predict, GridSearchCV, KFold
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import ConfusionMatrixDisplay
import matplotlib.pyplot as plt
from scipy import ndimage

RANDOM_SEED = 2023

mnist = fetch_openml('mnist_784', as_frame=False, parser='auto')


X, y = mnist.data, mnist.target
print(f'X shape: {X.shape}')
print(f'y shape: {y.shape}')

X shape: (70000, 784)
y shape: (70000,)


In [11]:
def shift_image(img, direction):
    shift_val = []
    img_shaped = img.reshape(28, 28)
    match direction:
        case 'up':
            shift_val = [-1, 0]
        case 'down':
            shift_val = [1, 0]
        case 'left':
            shift_val = [0, -1]
        case 'right':
            shift_val = [0, 1]
    
    new_img = ndimage.shift(img_shaped, shift=shift_val, cval=0)
    return new_img.reshape(784)


def plot_digit(image_data):
    image = image_data.reshape(28, 28)
    plt.imshow(image, cmap='binary')
    plt.axis("off")

# some_digit = X[0] # first image
# plot_digit(some_digit)
# plt.show()

# print(f'The image label: {y[0]}')
        




In [22]:
directions = ['up', 'down', 'left', 'right']
new_arrays = [X]
for d in directions:
    new_images = np.apply_along_axis(shift_image, 1, X, direction=d)
    new_arrays.append(new_images)

print(len(new_arrays))
X_aug = np.concatenate(tuple(new_arrays), axis=0)
print(X_aug.shape)

Y_aug = np.concatenate((y, y, y, y, y), axis=0)
print(Y_aug.shape)

5
(350000, 784)
(3920,)
