## ThunderSVM's SVM

In [None]:
import numpy as np
from typing import Tuple
from sklearn.metrics import accuracy_score
from thundersvm import SVC  # <-- ThunderSVM's SVC instead of sklearn's
from skimage.measure import moments, moments_central
from skimage.transform import AffineTransform, warp
from scipy.ndimage import shift as nd_shift
from tensorflow.keras.datasets import mnist

In [None]:

def load_mnist_data() -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Fetch MNIST through the Keras dataset helper and flatten the nested result.

    Returns:
        A 4-tuple (train_images, train_labels, test_images, test_labels),
        i.e. the two (images, labels) pairs returned by ``mnist.load_data()``
        laid out side by side.
    """
    train_split, test_split = mnist.load_data()
    train_images, train_labels = train_split
    test_images, test_labels = test_split
    return train_images, train_labels, test_images, test_labels

def deskew_image(img: np.ndarray) -> np.ndarray:
    """
    Deskews a single 2D grayscale digit image using central moments.

    The slant is estimated as the regression slope of the column coordinate
    on the row coordinate (mu_rc / mu_rr) and removed with a horizontal shear
    about the image's middle row.

    Args:
        img (np.ndarray): A 2D array (e.g. 28x28) representing a single digit.

    Returns:
        np.ndarray: A deskewed 2D array with the same shape and dtype as
        ``img``. If the image has zero total intensity, or (almost) no
        vertical spread, ``img`` itself is returned unchanged.
    """
    float_img = img.astype(float)

    # scikit-image indexes moments in (row, col) order:
    # M[i, j] = sum(row**i * col**j * intensity).
    raw_mom = moments(float_img, order=1)
    if raw_mom[0, 0] == 0:  # no pixels or blank image
        return img

    # Centroid, in the same (row, col) order that moments_central expects.
    row_c = raw_mom[1, 0] / raw_mom[0, 0]
    col_c = raw_mom[0, 1] / raw_mom[0, 0]

    # Second-order central moments about the centroid.
    mu = moments_central(float_img, center=(row_c, col_c), order=2)
    mu_rc = mu[1, 1]  # row/column covariance
    mu_rr = mu[2, 0]  # variance along the row (vertical) axis
    if abs(mu_rr) < 1e-5:
        # No vertical extent: the slant is undefined, leave the image alone.
        return img

    # Columns drift by `skew` pixels for every row moved down.
    skew = mu_rc / mu_rr

    # AffineTransform works on (x, y) = (col, row) coordinates. This matrix
    # maps an OUTPUT pixel to the INPUT location it should be sampled from:
    #   x_in = x_out + skew * (y_out - height / 2),  y_in = y_out
    transform = AffineTransform(
        matrix=np.array([
            [1.0,    skew,   -0.5 * img.shape[0] * skew],
            [0.0,    1.0,    0.0],
            [0.0,    0.0,    1.0]
        ])
    )

    # The matrix above already is the output->input map, which is exactly
    # what warp's `inverse_map` wants; inverting it would double the slant.
    deskewed = warp(
        float_img,
        inverse_map=transform,
        output_shape=img.shape,
        cval=0.0,
        mode='constant',
        preserve_range=True
    )
    return deskewed.astype(img.dtype)

def deskew_dataset(images: np.ndarray) -> np.ndarray:
    """
    Run ``deskew_image`` over every image of a stack.

    Args:
        images (np.ndarray): Image stack, shape (num_samples, 28, 28).

    Returns:
        np.ndarray: New array with the same shape and dtype as ``images``
        holding the deskewed images.
    """
    # Float staging buffer; every slot is overwritten below before the final
    # cast back to the input dtype.
    staging = np.empty_like(images, dtype=float)
    for index, image in enumerate(images):
        staging[index] = deskew_image(image)
    return staging.astype(images.dtype)

def jitter_image(img: np.ndarray, max_shift: int = 2) -> np.ndarray:
    """
    Translate one image by a random whole-pixel offset.

    Two independent offsets are drawn uniformly from
    [-max_shift, max_shift] (first the horizontal one, then the vertical
    one) using NumPy's global random state. Pixels shifted in from outside
    the frame are filled with 0.

    Args:
        img (np.ndarray): A 2D array (28x28).
        max_shift (int): Largest absolute offset, in pixels, along each axis.

    Returns:
        np.ndarray: The translated image, same shape and dtype as ``img``.
    """
    low, high = -max_shift, max_shift + 1  # randint's upper bound is exclusive
    dx = np.random.randint(low, high)
    dy = np.random.randint(low, high)

    as_float = img.astype(float)
    # order=0 -> nearest-neighbour sampling, so pixel values are moved, not blended.
    moved = nd_shift(as_float, (dy, dx), order=0, mode='constant', cval=0.0)
    return moved.astype(img.dtype)

def jitter_dataset(
    images: np.ndarray,
    labels: np.ndarray,
    num_copies: int = 1,
    max_shift: int = 2
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Creates additional jittered copies of each image to augment the dataset.

    Each original image is immediately followed by its jittered copies, and
    the labels are repeated in the same order.

    Args:
        images (np.ndarray): Shape (num_samples, 28, 28).
        labels (np.ndarray): Shape (num_samples,).
        num_copies (int): Number of jittered copies per original image.
            Values <= 0 produce no copies.
        max_shift (int): Max absolute pixel shift in x and y.

    Returns:
        (aug_images, aug_labels):
            aug_images: Shape (num_samples*(num_copies+1), 28, 28).
            aug_labels: Shape (num_samples*(num_copies+1),).
        Both keep the dtype of their input. An empty dataset yields empty
        outputs that still carry the per-sample dimensions.
    """
    num_samples = len(images)
    per_image = max(num_copies, 0) + 1  # the original plus its copies

    # Preallocate the outputs: this keeps the trailing dimensions even when
    # there are no samples, and avoids building a temporary list of arrays.
    aug_images = np.empty(
        (num_samples * per_image, *images.shape[1:]), dtype=images.dtype
    )
    aug_labels = np.empty(
        (num_samples * per_image, *labels.shape[1:]), dtype=labels.dtype
    )

    for i, image in enumerate(images):
        start = i * per_image

        # Original
        aug_images[start] = image

        # Jittered copies
        for offset in range(1, per_image):
            aug_images[start + offset] = jitter_image(image, max_shift=max_shift)

        # Same label for the original and all of its copies.
        aug_labels[start:start + per_image] = labels[i]

    return aug_images, aug_labels

def flatten_normalize(images: np.ndarray) -> np.ndarray:
    """
    Flattens 28x28 images into 1D vectors (784) and normalizes to [0,1].

    Normalization divides by 255.0, so inputs are expected to hold pixel
    values in [0, 255].

    Args:
        images (np.ndarray): Shape (num_samples, 28, 28).

    Returns:
        np.ndarray: Flattened and normalized, shape (num_samples, 784).
        An empty stack (num_samples == 0) returns shape (0, 784).
    """
    num_samples = images.shape[0]
    # Compute the feature count explicitly: reshape(..., -1) cannot infer
    # the missing dimension when the array has zero elements.
    num_features = int(np.prod(images.shape[1:], dtype=np.int64))
    flattened = images.reshape((num_samples, num_features))
    return flattened / 255.0

def build_thundersvm_classifier(
    kernel: str = 'poly', 
    degree: int = 9,
    gpu_id: int = 0,
    C: float = 1.0
) -> SVC:
    """
    Builds and returns a ThunderSVM classifier configured for GPU usage.
    
    Args:
        kernel (str): Kernel type for SVM (e.g. 'polynomial', 'rbf', 'linear').
            scikit-learn's spelling 'poly' is accepted and translated to
            ThunderSVM's 'polynomial'. Default: 'poly'.
        degree (int): Degree of polynomial (when the kernel is polynomial). Default: 9.
        gpu_id (int): Which GPU to use (if multiple). Default: 0.
        C (float): Regularization parameter. Default: 1.0.
    
    Returns:
        thundersvm.SVC: Configured (unfitted) SVC.
    
    Notes:
      - 'verbose' is set to True to show some training progress in the console.
      - 'gamma' is always passed as 'auto'; tune it here if needed.
    """
    # ThunderSVM names the polynomial kernel 'polynomial' and does not know
    # scikit-learn's 'poly' alias, so translate it before constructing the SVC.
    thundersvm_kernel = 'polynomial' if kernel == 'poly' else kernel

    classifier = SVC(
        kernel=thundersvm_kernel,
        degree=degree,
        gamma='auto',
        C=C,
        gpu_id=gpu_id,   # Use the specified GPU
        verbose=True     # This will print progress info during training
    )
    return classifier

def main() -> None:
    """
    Main pipeline:
      1. Load MNIST data.
      2. Deskew the training and test images.
      3. Augment the training set with jitter.
      4. Flatten & normalize.
      5. Build a polynomial-kernel SVM (ThunderSVM) for GPU 0.
      6. Train on the augmented training set.
      7. Evaluate on the test set and print the accuracy.
    """
    # 1. Load data
    train_images, train_labels, test_images, test_labels = load_mnist_data()

    # 2. Deskew
    train_images = deskew_dataset(train_images)
    test_images = deskew_dataset(test_images)

    # 3. Augment with jitter (training set only; the test set stays untouched)
    train_images_aug, train_labels_aug = jitter_dataset(
        images=train_images,
        labels=train_labels,
        num_copies=1,  # 1 extra copy per image
        max_shift=2
    )
    print(f"Original training set size: {train_images.shape[0]}, "
          f"augmented size: {train_images_aug.shape[0]}")

    # 4. Flatten & normalize
    X_train = flatten_normalize(train_images_aug)
    y_train = train_labels_aug
    X_test = flatten_normalize(test_images)
    y_test = test_labels

    # 5. Build ThunderSVM classifier (GPU-based) with polynomial kernel, degree=9.
    #    ThunderSVM spells this kernel 'polynomial' (scikit-learn's 'poly' is
    #    not one of its kernel names).
    classifier = build_thundersvm_classifier(
        kernel='polynomial', degree=9, gpu_id=0, C=1.0
    )

    # 6. Train
    print("Training ThunderSVM (poly, degree=9) on GPU... This may take a while.")
    classifier.fit(X_train, y_train)

    # 7. Evaluate
    y_pred = classifier.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    print(f"ThunderSVM (poly, degree=9) test accuracy: {accuracy * 100:.2f}%")

# Run the full pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()