In [1]:
from coreset_construction import obtainSensitivity, generateCoreset

In [2]:
import numpy as np
import torch
from torch.utils.data import DataLoader, Subset
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from sklearn import svm
from sklearn.metrics import accuracy_score

In [3]:
# Define transformation to convert images to tensors and normalize
transform = transforms.Compose([
    transforms.ToTensor()
])

# Load CIFAR-10 dataset
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Filter for airplane (0) and bird (2) classes
classes_to_keep = [0, 2]
train_indices = [i for i, label in enumerate(train_dataset.targets) if label in classes_to_keep]
test_indices = [i for i, label in enumerate(test_dataset.targets) if label in classes_to_keep]

train_subset = Subset(train_dataset, train_indices)
test_subset = Subset(test_dataset, test_indices)

In [5]:
# Create data loaders
train_loader = DataLoader(train_subset, batch_size=len(train_subset), shuffle=False)
test_loader = DataLoader(test_subset, batch_size=len(test_subset), shuffle=False)

# Extract data
x_train, y_train = next(iter(train_loader))
x_test, y_test = next(iter(test_loader))

# Flatten images into vectors
x_train_flat = x_train.view(x_train.size(0), -1).numpy()
x_test_flat = x_test.view(x_test.size(0), -1).numpy()

# Normalize pixel values
X = x_train_flat / 255.0
x_test_flat = x_test_flat / 255.0

In [6]:
%%capture
## RBFNN coreset
# ----- 1. scale to the unit ball (‖x‖₂ ≤ 1) ---------------------------------
R = np.max(np.linalg.norm(X, axis=1))
X_scaled = X / R            #  now every point satisfies ‖x‖₂ ≤ 1

# ----- 2. lifting step  q_p = [‖x‖² , -2xᵀ , 1] ------------------------------
phi = np.hstack([
    np.sum(X_scaled**2, axis=1, keepdims=True),   # ‖x‖₂²
    -2 * X_scaled,                                # -2 xᵀ
    np.ones((X_scaled.shape[0], 1))               # 1
])

# ----- 3. sensitivities & coreset -------------------------------------------
sens = obtainSensitivity(phi, w=None, approxMVEE=True)
m = 2000                                      # coreset size
idx, X_cs_rbfnn, labels, w_cs_rbfnn, _ = generateCoreset(phi, y_train, sens, m)
print(f"Coreset shape: {X_cs_rbfnn.shape}")
print(f"Coreset labels shape: {labels.shape}")
print(f"Coreset weights shape: {w_cs_rbfnn.shape}")

In [13]:
X_cs = X[idx,]
X_cs.shape
Y_cs = y_train[idx]

In [10]:
print(X.shape)
print(idx.shape)
print(f"Coreset shape: {X_cs_rbfnn.shape}")
print(f"Coreset labels shape: {labels.shape}")
print(f"Coreset weights shape: {w_cs_rbfnn.shape}")

(10000, 3072)
(2000,)
Coreset shape: (2000, 3074)
Coreset labels shape: torch.Size([2000])
Coreset weights shape: (2000,)


In [19]:
# Initialize and train SVM
clf = svm.SVC(kernel='linear')
clf.fit(X_cs,Y_cs,sample_weight=w_cs_rbfnn)

# Predict and evaluate
y_pred = clf.predict(x_test_flat)
accuracy = accuracy_score(y_test, y_pred)

print(f"SVM PCA accuracy on airplane vs bird: {accuracy * 100:.2f}%")

SVM PCA accuracy on airplane vs bird: 75.90%
