# Imports

In [1]:
import numpy as np
from scipy.ndimage import affine_transform, shift
from sklearn import svm
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from typing import Tuple, List
from tensorflow.keras.datasets import mnist

# Functions Definitions

## sklearn SVM

In [2]:

def load_mnist_data() -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Loads the MNIST dataset from Keras.

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        A tuple containing:
            - train_images (np.ndarray): Array of shape (60000, 28, 28)
            - train_labels (np.ndarray): Array of shape (60000,)
            - test_images  (np.ndarray): Array of shape (10000, 28, 28)
            - test_labels  (np.ndarray): Array of shape (10000,)
    """
    (train_images, train_labels), (test_images, test_labels) = mnist.load_data()
    return train_images, train_labels, test_images, test_labels

def deskew_image(img: np.ndarray) -> np.ndarray:
    """
    Deskews a single 28x28 grayscale digit image using moment analysis.
    
    Args:
        img (np.ndarray): A 28x28 grayscale image array.
        
    Returns:
        np.ndarray: A deskewed 28x28 image array.
    """
    # Compute the image moments
    y, x = np.indices(img.shape)
    total_mass = img.sum()
    if total_mass == 0:
        return img  # Avoid division by zero

    x_center = (x * img).sum() / total_mass
    y_center = (y * img).sum() / total_mass

    # Compute second-order central moments
    mu_xx = ((x - x_center) ** 2 * img).sum() / total_mass
    mu_yy = ((y - y_center) ** 2 * img).sum() / total_mass
    mu_xy = ((x - x_center) * (y - y_center) * img).sum() / total_mass

    if mu_yy == 0:
        return img  # No skew detected

    skew = mu_xy / mu_yy  # Skew factor

    # Define the transformation matrix
    M = np.array([[1, skew, -skew * x_center], [0, 1, 0]])

    # Apply affine transformation
    deskewed_img = affine_transform(img, M, offset=0, order=1, mode='constant', cval=0)
    
    return deskewed_img

def deskew_dataset(images: np.ndarray) -> np.ndarray:
    """
    Deskews each image in an entire dataset.
    
    Args:
        images (np.ndarray): A batch of images of shape (num_samples, 28, 28).
        
    Returns:
        np.ndarray: A batch of deskewed images of shape (num_samples, 28, 28).
    """
    return np.array([deskew_image(img) for img in images])

def jitter_image(img: np.ndarray, max_shift: int = 2) -> np.ndarray:
    """
    Randomly shifts (jitters) a single 28x28 grayscale digit image.
    
    Args:
        img (np.ndarray): A 28x28 grayscale image array.
        max_shift (int, optional): Maximum pixel shift in both x and y directions. Defaults to 2.
        
    Returns:
        np.ndarray: A 28x28 image array that has been shifted randomly.
    """
    shift_x = np.random.randint(-max_shift, max_shift + 1)
    shift_y = np.random.randint(-max_shift, max_shift + 1)
    
    return shift(img, shift=(shift_y, shift_x), mode='constant', cval=0)

def jitter_dataset(
    images: np.ndarray, 
    labels: np.ndarray, 
    num_copies: int = 1, 
    max_shift: int = 2
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Applies jitter augmentation to an entire dataset. 
    For each image, creates additional copies with random shifts.
    
    Args:
        images (np.ndarray): Image batch of shape (num_samples, 28, 28).
        labels (np.ndarray): Label array of shape (num_samples,).
        num_copies (int, optional): Number of jittered copies to create for each image. Defaults to 1.
        max_shift (int, optional): Maximum pixel shift in x and y directions. Defaults to 2.
        
    Returns:
        Tuple[np.ndarray, np.ndarray]: Augmented dataset and corresponding labels.
    """
    augmented_images = []
    augmented_labels = []
    
    for i in range(len(images)):
        # Add original
        augmented_images.append(images[i])
        augmented_labels.append(labels[i])
        
        # Add jittered copies
        for _ in range(num_copies):
            jittered_img = jitter_image(images[i], max_shift)
            augmented_images.append(jittered_img)
            augmented_labels.append(labels[i])
    
    return np.array(augmented_images, dtype=np.float32), np.array(augmented_labels, dtype=np.int64)

def flatten_normalize(images: np.ndarray) -> np.ndarray:
    """
    Flattens 28x28 images into 1D vectors of length 784 
    and normalizes pixel values to the range [0, 1].
    
    Args:
        images (np.ndarray): A batch of images of shape (num_samples, 28, 28).
        
    Returns:
        np.ndarray: A batch of flattened and normalized images 
                    of shape (num_samples, 784).
    """
    return images.reshape(images.shape[0], -1) / 255.0

def build_svm_classifier(kernel: str = 'poly', degree: int = 9) -> svm.SVC:
    """
    Builds and returns an SVM classifier using scikit-learn's SVC with a specified kernel.
    
    Args:
        kernel (str, optional): The kernel type to be used in the algorithm. Defaults to 'poly'.
        degree (int, optional): Degree of the polynomial kernel function ('poly'). Defaults to 9.
        
    Returns:
        svm.SVC: An SVM classifier with the given kernel parameters.
    """
    return svm.SVC(kernel=kernel, degree=degree, gamma='auto', verbose=1)


In [3]:

def main() -> None:
    """
    Main pipeline for:
        1. Loading and deskewing the MNIST dataset.
        2. Augmenting the training data with random jitter.
        3. Flattening and normalizing the image data.
        4. Building and training a polynomial SVM (degree=9).
        5. Evaluating on the test set.
    """
    # 1. Load MNIST data
    train_images, train_labels, test_images, test_labels = load_mnist_data()

    # 2. Deskew all images
    train_images = deskew_dataset(train_images)
    test_images = deskew_dataset(test_images)

    # 3. Apply jitter augmentation to the training set
    # train_images_aug, train_labels_aug = jitter_dataset(train_images, train_labels, num_copies=1, max_shift=2)
    # print(f"Original training size: {len(train_images)}, Augmented size: {len(train_images_aug)}")

    # 4. Flatten and normalize
    X_train = flatten_normalize(train_images) # noam: removed the _aug
    y_train = train_labels # noam: removed the _aug
    X_test = flatten_normalize(test_images)
    y_test = test_labels

    # 5. Build SVM classifier with polynomial kernel (degree=9)
    classifier = build_svm_classifier(kernel='rbf')

    print("Training SVM... This may take a while.")
    classifier.fit(X_train, y_train)

    # 6. Evaluate
    y_pred = classifier.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    print(f"SVM (poly, degree=9) test accuracy: {accuracy * 100:.2f}%")

if __name__ == "__main__":
    main()

Training SVM... This may take a while.
[LibSVM]SVM (poly, degree=9) test accuracy: 92.74%
