# Goal

Datasets:
1. CIFAR-10
2. MNIST

Models:
1. Convolutional Features
2. ReLU Features
3. Fourier Features

Each model transforms the data into a feature matrix $[M_{TM} \mid M_{TU}]$, where $M_{TM}$ is the data matrix for the training set and $M_{TU}$ holds the basis functions that we have not yet modeled. We compute two sets of coefficients: $\tilde{c}$, the best coefficients of the basis functions for modeling the labels on the modeled training set, and $c$, the best coefficients of all basis functions for modeling the labels on the whole training set. The error is $c_{err} = \tilde{c}-c^*$, where $c^*$ is $c$ truncated to match the size of $\tilde{c}$. In practice, we obtain $c$ as the least-squares solution on the whole training set and $\tilde{c}$ as the least-squares solution on the sampled training set.
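
The error $c_{err}$ can be sketched on synthetic data. This is only one reading of the notation, assuming $M_{TM}$ holds the columns of the modeled basis functions and $M_{TU}$ the remaining columns; the helper name `coefficient_error`, the matrix sizes, and the random data are all hypothetical and not part of the notebook.

```python
import numpy as np


def coefficient_error(M_TM, M_TU, y):
    """Return c_tilde - c_star for a column-split feature matrix (assumed layout)."""
    k = M_TM.shape[1]
    # Fit the labels using only the modeled basis functions.
    c_tilde, *_ = np.linalg.lstsq(M_TM, y, rcond=None)
    # Fit the labels using all basis functions.
    c, *_ = np.linalg.lstsq(np.hstack([M_TM, M_TU]), y, rcond=None)
    # Truncate c to the first k entries so it matches the size of c_tilde.
    c_star = c[:k]
    return c_tilde - c_star


rng = np.random.default_rng(0)
M_TM = rng.standard_normal((500, 8))   # synthetic stand-in, 8 modeled functions
M_TU = rng.standard_normal((500, 4))   # synthetic stand-in, 4 unmodeled functions
y = rng.standard_normal(500)           # synthetic labels

c_err = coefficient_error(M_TM, M_TU, y)
print(c_err.shape, np.linalg.norm(c_err))
```

When the labels lie exactly in the span of the modeled columns, both fits agree on those columns and the error vanishes up to rounding.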

## Bernoulli Sampling with Leverage Scores

We treat the leverage scores, scaled by an oversampling factor $\beta$ and clamped to at most 1, as Bernoulli sampling probabilities for each data point, and sample accordingly.

We then calculate the error between the least-squares solution found on the sampled matrix and the one found on the full matrix.

We then plot each sample as a point $(n, \text{error})$, where $n$ is the number of sampled data points.
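
The sampling step can be illustrated with a small NumPy sketch. The synthetic matrix `A` and the value of `beta` are assumptions made for the example; the notebook itself uses PyTorch tensors and real embeddings. The sketch relies on two standard facts: each leverage score lies in $[0, 1]$, and the scores sum to the rank of the matrix, so the expected number of sampled rows is at most $\beta$ times the number of columns.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((2000, 10))    # synthetic feature matrix (assumption)

# Leverage scores: squared row norms of an orthonormal basis for range(A).
Q, _ = np.linalg.qr(A, mode="reduced")
leverage = np.sum(Q**2, axis=1)

beta = 20.0                             # oversampling factor (assumption)
probs = np.minimum(1.0, beta * leverage)

# One independent Bernoulli draw per row: keep row i with probability probs[i].
keep = rng.random(A.shape[0]) < probs

print("sum of leverage scores:", leverage.sum())
print("expected sample size:", probs.sum())
print("actual sample size:", int(keep.sum()))
```

Because each row is drawn independently, the sample size is itself random, which is why the plots use the realized sample size on the horizontal axis.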

For each dataset, we will:
- Sample the data points uniformly at random and by leverage scores.
- Plot $\|A\|_2$, $\|M_{TM}^+\|_2$, and $\|\tilde{c}-c^*\|_2$ for the sampled features as a function of the number of sampled points.

We expect leverage score sampling to yield a smaller error $\|\tilde{c}-c^*\|_2$ than uniform sampling for the same number of sampled points.
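
A minimal sketch of this comparison on synthetic data follows. The helper `sampled_least_squares`, the matrix with a few artificially scaled rows, and the choice of `beta` are assumptions for illustration only. Sampled rows are reweighted by the inverse of their sampling probability, and the uniform probabilities are chosen so that both schemes have the same expected sample size. No particular outcome is claimed; the printed errors depend on the data and the random draw.

```python
import numpy as np


def sampled_least_squares(A, b, probs, rng):
    """Bernoulli-sample rows with probabilities `probs`, reweight by 1/probs, solve."""
    keep = rng.random(A.shape[0]) < probs
    scale = 1.0 / np.sqrt(probs[keep])            # square root of the importance weights
    c, *_ = np.linalg.lstsq(A[keep] * scale[:, None], b[keep] * scale, rcond=None)
    return c, int(keep.sum())


rng = np.random.default_rng(0)
n, d = 5000, 10
A = rng.standard_normal((n, d))
A[:5] *= 50.0                                     # a few high-leverage rows (assumption)
b = A @ rng.standard_normal(d) + 0.1 * rng.standard_normal(n)
c_full, *_ = np.linalg.lstsq(A, b, rcond=None)

Q, _ = np.linalg.qr(A, mode="reduced")
leverage = np.sum(Q**2, axis=1)
beta = 20.0
p_lev = np.minimum(1.0, beta * leverage)
p_uni = np.full(n, p_lev.sum() / n)               # same expected sample size

for name, p in [("leverage", p_lev), ("uniform", p_uni)]:
    c, m = sampled_least_squares(A, b, p, rng)
    print(name, m, np.linalg.norm(c - c_full))
```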


## Definitions

In [4]:
import numpy as np
import torch

BASIS_FUNCTIONS = 200

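# Embed a dataset (here, the combined training and test sets) with a trained network
def embed_dataset(X, model, device):
    # Embed the data using the convolutional layers of the network
    model.to(device)
    # Preallocate the (N, BASIS_FUNCTIONS) output in float64
    embeddings = torch.zeros(
        (X.shape[0], BASIS_FUNCTIONS),
        dtype=torch.float64,
        device=device,
    )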

    with torch.no_grad():
        for batch_start in range(0, X.shape[0], 256):
            batch_end = min(batch_start + 256, X.shape[0])
            batch = X[batch_start:batch_end].to(device)
            batch_embeddings = model.embed(batch)
            embeddings[batch_start:batch_end] = batch_embeddings
    return embeddings


# Random Fourier Features
def rff_features(X, features=200) -> torch.Tensor:
    N, *_ = X.shape
    X = X.reshape(N, -1)

    W = torch.randn(X.shape[1], features, device=X.device)

    return torch.cos(torch.pi * X @ W) / np.sqrt(features)  # Normalize


# Random ReLU Features
def relu_features(X, features=200) -> torch.Tensor:
    N, *_ = X.shape
    X = X.reshape(N, -1)
    W = torch.randn(X.shape[1], features, device=X.device)
    return torch.relu(X @ W) / np.sqrt(features)


def leverage_scores(A: torch.Tensor) -> torch.Tensor:
    q, _ = torch.linalg.qr(A, mode="reduced")
    return torch.sum(q**2, dim=1)


def sample_bernoulli(leverage_scores: torch.Tensor) -> torch.Tensor:
    random_values = torch.rand((leverage_scores.shape[0],), device=leverage_scores.device)
    sampled_indices = random_values < leverage_scores
    return sampled_indices


def least_squares_solution(A: torch.Tensor, b: torch.Tensor, regularizer: float = 1e-12, weights: torch.Tensor | None = None) -> torch.Tensor:
    A = A.float()
    b = b.float()
    if weights is not None:
        # Scale rows by sqrt(weights) out of place so the caller's tensors are not mutated
        scale = torch.sqrt(weights.float())[:, None]
        A = A * scale
        b = b * scale
    return torch.linalg.lstsq(
        A.T @ A + regularizer * torch.eye(A.shape[1], device=A.device),
        A.T @ b,
    ).solution
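
The `least_squares_solution` helper above solves the regularized normal equations $(A^\top A + \lambda I)c = A^\top b$. As a design note, forming $A^\top A$ squares the condition number of $A$, so the result can differ slightly from solving $\min_c \|Ac - b\|_2$ directly when the features are poorly conditioned. The NumPy sketch below illustrates this on a synthetic, deliberately ill-conditioned matrix (an assumption for the example); it uses `np.linalg.solve` for the normal equations rather than the `torch.linalg.lstsq` call in the notebook.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic matrix whose columns are scaled from 1 down to 1e-3 (assumption).
A = rng.standard_normal((300, 5)) * np.logspace(0, -3, 5)
b = rng.standard_normal(300)

# Route 1: regularized normal equations, in the spirit of least_squares_solution.
reg = 1e-12
c_normal = np.linalg.solve(A.T @ A + reg * np.eye(5), A.T @ b)

# Route 2: solve the least-squares problem on A directly.
c_direct, *_ = np.linalg.lstsq(A, b, rcond=None)

print("cond(A):      ", np.linalg.cond(A))
print("cond(A^T A):  ", np.linalg.cond(A.T @ A))
print("relative diff:", np.linalg.norm(c_normal - c_direct) / np.linalg.norm(c_direct))
```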




## MNIST

The MNIST dataset consists of 70,000 images of handwritten digits (0-9) in grayscale with a resolution of 28x28 pixels. This gives us a $70,000 \times 784$ data matrix.
- A convolutional neural network will transform the data to a $70,000 \times 200$ matrix (by removing the last layer).
- A random ReLU fully connected network, $y({\textbf{t}}) = \sum_{k=1}^{200} w_k \sigma(\left<\textbf{t}, {\textbf{v}}_k\right>)$ with $\sigma(x) = \max(0,x)$, where $\textbf{v}_k$ are randomly initialized weights and $w_k$ are the learned coefficients, will transform the data to a $70,000 \times 200$ matrix.
- A Fourier fully connected network, $y({\textbf{t}}) = \operatorname{Re}\left(\sum_{k=1}^{200} w_k \exp(i\pi\left<\textbf{t}, {\textbf{v}}_k\right>)\right) = \sum_{k=1}^{200} w_k \cos(\pi\left<\textbf{t}, {\textbf{v}}_k\right>)$ for real $w_k$, where $\textbf{v}_k$ are randomly initialized weights and $w_k$ are the learned coefficients, will transform the data to a $70,000 \times 200$ matrix.
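
The two random feature maps in the list above can be sketched in NumPy as follows. The tiny batch size and the random stand-in data are assumptions for illustration, and the $1/\sqrt{200}$ scaling applied by the notebook's helper functions is omitted. The point is that each model is linear in the coefficients $w_k$ once the random directions $\textbf{v}_k$ are fixed, which is what makes a least-squares fit possible.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D, K = 6, 784, 200                    # K = 200 basis functions, as in the text
T = rng.random((N, D))                   # stand-in for flattened images in [0, 1)
V = rng.standard_normal((D, K))          # random directions v_k (fixed, not learned)

relu_phi = np.maximum(0.0, T @ V)        # sigma(<t, v_k>) with sigma(x) = max(0, x)
fourier_phi = np.cos(np.pi * (T @ V))    # Re exp(i*pi*<t, v_k>) = cos(pi*<t, v_k>)

w = rng.standard_normal(K)               # placeholder coefficients w_k
y_relu = relu_phi @ w                    # y(t) = sum_k w_k sigma(<t, v_k>)
y_fourier = fourier_phi @ w              # y(t) = sum_k w_k cos(pi*<t, v_k>)

print(relu_phi.shape, fourier_phi.shape, y_relu.shape)
```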

In [5]:
import torch

from torchvision.datasets import MNIST

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("using", DEVICE)

mnist_X = (
    MNIST(root="./data", train=True, download=True)
    .data.float()
    .to(DEVICE)
    .reshape(-1, 1, 28, 28)
    / 255.0
)
mnist_y = MNIST(root="./data", train=True, download=True).targets.to(DEVICE)
test_mnist_X = (
    MNIST(root="./data", train=False, download=True)
    .data.float()
    .to(DEVICE)
    .reshape(-1, 1, 28, 28)
    / 255.0
)
test_mnist_y = MNIST(root="./data", train=False, download=True).targets.to(DEVICE)

combined_mnist_X = torch.cat([mnist_X, test_mnist_X], dim=0)
combined_mnist_y = torch.cat([mnist_y, test_mnist_y], dim=0)

using cpu


In [6]:
# Verify CNN accuracy on MNIST

from models.mnist_cnn import MnistConvNet, BASIS_FUNCTIONS
from torch.utils.data import TensorDataset

mnist_network = MnistConvNet()
mnist_network.load_state_dict(torch.load("models/mnist_cnn.pth", map_location=DEVICE))
mnist_network.eval()


def verify_mnist_cnn(model: MnistConvNet, device):
    model.to(device)
    correct = 0
    total = 0
    test_loader = torch.utils.data.DataLoader(
        TensorDataset(test_mnist_X, test_mnist_y),
        batch_size=256,
        shuffle=False,
    )
    with torch.no_grad():
        for xb, yb in test_loader:
            xb, yb = xb.to(device), yb.to(device)
            preds: torch.Tensor = model(xb)
            correct += (preds.argmax(dim=1) == yb).sum().item()
            total += yb.size(0)
    print(f"Test accuracy: {correct / total:.4f}")


# verify_mnist_cnn(mnist_network, DEVICE)

In [7]:
mnist_cnn_embedding = embed_dataset(combined_mnist_X, model=mnist_network, device=DEVICE)
print("mnist cnn on", mnist_cnn_embedding.device)
print(mnist_cnn_embedding.shape)

mnist_rff_features = rff_features(combined_mnist_X, features=200)
print("mnist rff on", mnist_rff_features.device)
print(mnist_rff_features.shape)

mnist_relu_features = relu_features(combined_mnist_X, features=200)
print("mnist relu on", mnist_relu_features.device)
print(mnist_relu_features.shape)

mnist cnn on cpu
torch.Size([70000, 200])
mnist rff on cpu
torch.Size([70000, 200])
mnist relu on cpu
torch.Size([70000, 200])


In [20]:
%matplotlib tk

from matplotlib import pyplot as plt

labels = torch.nn.functional.one_hot(
    combined_mnist_y,
    num_classes=10,
).to(DEVICE, dtype=torch.float32)


TRIALS = 100
# BETAS = 10 ** np.linspace(0, np.log10(50_000), 20)
BETAS = np.linspace(1, 100, 20)
EMBEDDINGS = {
    "MNIST CNN": mnist_cnn_embedding,
    "MNIST RFF": mnist_rff_features,
    "MNIST ReLU": mnist_relu_features,
}


fig, axs = plt.subplots(len(EMBEDDINGS), 2, figsize=(20, 5 * len(EMBEDDINGS)))

for i, (embedding_name, embedding) in enumerate(EMBEDDINGS.items()):
    print("Processing embedding:", embedding_name)
    c_full = least_squares_solution(embedding, labels)

    leverage = leverage_scores(embedding)

    # Plot leverage scores
    leverage_cpu = leverage.cpu().numpy()
    axs[i, 1].plot(np.arange(1, len(leverage_cpu) + 1), np.sort(leverage_cpu)[::-1], marker="o", linestyle="")
    axs[i, 1].set_xlabel("Rank (sorted by decreasing leverage score)")
    axs[i, 1].set_ylabel("Leverage Score")
    axs[i, 1].set_title(f"Leverage Score Distribution for {embedding_name}")
    axs[i, 1].set_xscale("log")

    for j, beta in enumerate(BETAS):
        print("  Processing beta:", beta)
        errors = []
        sample_sizes = []
        for trial in range(TRIALS):
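            # Inclusion probabilities: leverage scores scaled by beta and clamped to 1
            probabilities = torch.clamp(leverage * beta, max=1.0)
            sampled_indices = sample_bernoulli(probabilities)
            A_sampled = embedding[sampled_indices, :]
            b_sampled = labels[sampled_indices, :]
            # Importance weights are the inverse of the actual inclusion probabilities
            c_sampled = least_squares_solution(A_sampled, b_sampled, weights=1 / probabilities[sampled_indices])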
            sample_size = A_sampled.shape[0]
            error = torch.linalg.matrix_norm(c_sampled - c_full).item()
            errors.append(error)
            sample_sizes.append(sample_size)
        axs[i, 0].plot(sample_sizes, errors, "o", color=plt.cm.viridis(j / len(BETAS)))
    axs[i, 0].set_xlabel("Sample Size")
    axs[i, 0].set_ylabel("Error Norm")
    axs[i, 0].set_title(f"Embedding {embedding_name}")
    axs[i, 0].set_yscale("log")

plt.tight_layout()
plt.savefig("bernoulli_sampling_errors_mnist.png")
plt.show()

Processing embedding: MNIST CNN
  Processing beta: 1.0
  Processing beta: 6.2105263157894735
  Processing beta: 11.421052631578947
  Processing beta: 16.63157894736842
  Processing beta: 21.842105263157894
  Processing beta: 27.052631578947366
  Processing beta: 32.26315789473684
  Processing beta: 37.473684210526315
  Processing beta: 42.68421052631579
  Processing beta: 47.89473684210526
  Processing beta: 53.10526315789473
  Processing beta: 58.315789473684205
  Processing beta: 63.526315789473685
  Processing beta: 68.73684210526315
  Processing beta: 73.94736842105263
  Processing beta: 79.1578947368421
  Processing beta: 84.36842105263158
  Processing beta: 89.57894736842105
  Processing beta: 94.78947368421052
  Processing beta: 100.0
Processing embedding: MNIST RFF
  Processing beta: 1.0
  Processing beta: 6.2105263157894735
  Processing beta: 11.421052631578947
  Processing beta: 16.63157894736842
  Processing beta: 21.842105263157894
  Processing beta: 27.052631578947366
  Pr

## CIFAR-10

The CIFAR-10 dataset consists of 60,000 color images with a resolution of 32x32 pixels, divided into 10 classes (airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck). Flattening each image gives us a $60,000 \times (32 \cdot 32 \cdot 3) = 60,000 \times 3072$ data matrix.
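
The shape arithmetic can be checked with a short NumPy sketch. The batch of four all-zero images is a stand-in (an assumption); only the per-image layout of height, width, and channels matches CIFAR-10.

```python
import numpy as np

# Stand-in batch with CIFAR-10's per-image layout: height x width x channels.
batch = np.zeros((4, 32, 32, 3), dtype=np.float32)

flat = batch.reshape(batch.shape[0], -1)        # one row of 32*32*3 values per image
channels_first = batch.transpose(0, 3, 1, 2)    # channel-first layout used by PyTorch conv layers

print(flat.shape)            # (4, 3072)
print(channels_first.shape)  # (4, 3, 32, 32)
```

The channel-first transpose corresponds to the `.permute(0, 3, 1, 2)` call used when loading the images in the next cell.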

In [15]:
import torch
from torchvision.datasets import CIFAR10

from models.cifar_cnn import (
    BiggestCifarConvNet,
    CifarConvNet,
    BiggerCifarConvNet,
)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

cifar_images = (
    torch.tensor(CIFAR10(root="./data", train=True, download=True).data)
    .to(DEVICE)
    .permute(0, 3, 1, 2)
    .float()
    / 255.0
)
cifar_labels = torch.tensor(
    CIFAR10(root="./data", train=True, download=True).targets, device=DEVICE
)
test_cifar_images = (
    torch.tensor(CIFAR10(root="./data", train=False, download=True).data)
    .to(DEVICE)
    .permute(0, 3, 1, 2)
    .float()
    / 255.0
)
test_cifar_labels = torch.tensor(
    CIFAR10(root="./data", train=False, download=True).targets, device=DEVICE
)
combined_cifar_images = torch.cat([cifar_images, test_cifar_images], dim=0)
combined_cifar_labels = torch.cat([cifar_labels, test_cifar_labels], dim=0)

cifar_network = CifarConvNet()
cifar_network.load_state_dict(torch.load("models/cifar_cnn.pth", map_location=DEVICE))
cifar_network.eval()

bigger_cifar_network = BiggerCifarConvNet()
bigger_cifar_network.load_state_dict(
    torch.load("models/cifar_bigger_cnn.pth", map_location=DEVICE)
)
bigger_cifar_network.eval()

biggest_cifar_network = BiggestCifarConvNet()
biggest_cifar_network.load_state_dict(
    torch.load("models/cifar_biggest_cnn.pth", map_location=DEVICE)
)
biggest_cifar_network.eval()


def verify_cifar_cnn(model, device):
    model.to(device)
    correct = 0
    total = 0
    test_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(test_cifar_images, test_cifar_labels),
        batch_size=256,
        shuffle=False,
        num_workers=2,
    )
    with torch.no_grad():
        for xb, yb in test_loader:
            xb, yb = xb.to(device), yb.to(device)
            preds: torch.Tensor = model(xb)
            correct += (preds.argmax(dim=1) == yb).sum().item()
            total += yb.size(0)
    print(f"  Test accuracy: {correct / total:.4f}")


# print("Verifying CIFAR-10 CNN model accuracy:")
# verify_cifar_cnn(cifar_network, DEVICE)
# print("Verifying Bigger CIFAR-10 CNN model accuracy:")
# verify_cifar_cnn(bigger_cifar_network, DEVICE)
# print("Verifying Biggest CIFAR-10 CNN model accuracy:")
# verify_cifar_cnn(biggest_cifar_network, DEVICE)

In [16]:
import numpy as np

cifar_cnn_embedding = embed_dataset(combined_cifar_images, model=cifar_network, device=DEVICE)
print("cifar cnn on", cifar_cnn_embedding.device)
print(cifar_cnn_embedding.shape)

cifar_bigger_cnn_embedding = embed_dataset(
    combined_cifar_images, model=bigger_cifar_network, device=DEVICE
)
print("cifar bigger cnn on", cifar_bigger_cnn_embedding.device)
print(cifar_bigger_cnn_embedding.shape)

cifar_biggest_cnn_embedding = embed_dataset(
    combined_cifar_images, model=biggest_cifar_network, device=DEVICE
)
print("cifar biggest cnn on", cifar_biggest_cnn_embedding.device)
print(cifar_biggest_cnn_embedding.shape)

cifar_rff_features = rff_features(combined_cifar_images, features=200)
print("cifar rff on", cifar_rff_features.device)
print(cifar_rff_features.shape)

cifar_relu_features = relu_features(combined_cifar_images, features=200)
print("cifar relu on", cifar_relu_features.device)
print(cifar_relu_features.shape)

cifar cnn on cpu
torch.Size([60000, 200])
cifar bigger cnn on cpu
torch.Size([60000, 200])
cifar biggest cnn on cpu
torch.Size([60000, 200])
cifar rff on cpu
torch.Size([60000, 200])
cifar relu on cpu
torch.Size([60000, 200])


In [21]:
labels = torch.nn.functional.one_hot(
    combined_cifar_labels,
    num_classes=10,
).to(DEVICE, dtype=torch.float32)


TRIALS = 100
BETAS = np.linspace(1, 100, 20)
EMBEDDINGS = {
    "CIFAR CNN": cifar_cnn_embedding,
    "CIFAR Bigger CNN": cifar_bigger_cnn_embedding,
    "CIFAR Biggest CNN": cifar_biggest_cnn_embedding,
    "CIFAR RFF": cifar_rff_features,
    "CIFAR ReLU": cifar_relu_features,
}

fig, axs = plt.subplots(len(EMBEDDINGS), 2, figsize=(20, 5 * len(EMBEDDINGS)))

for i, (embedding_name, embedding) in enumerate(EMBEDDINGS.items()):
    print("Processing embedding:", embedding_name)
    c_full = least_squares_solution(embedding, labels)

    leverage = leverage_scores(embedding)

    # Plot leverage scores
    leverage_cpu = leverage.cpu().numpy()
    axs[i, 1].plot(np.arange(1, len(leverage_cpu) + 1), np.sort(leverage_cpu)[::-1], marker="o", linestyle="none")
    axs[i, 1].set_xlabel("Rank (sorted by decreasing leverage score)")
    axs[i, 1].set_ylabel("Leverage Score")
    axs[i, 1].set_title(f"Leverage Score Distribution for {embedding_name}")
    axs[i, 1].set_xscale("log")

    for j, beta in enumerate(BETAS):
        print("  Processing beta:", beta)
        errors = []
        sample_sizes = []
        for trial in range(TRIALS):
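            # Inclusion probabilities: leverage scores scaled by beta and clamped to 1
            probabilities = torch.clamp(leverage * beta, max=1.0)
            sampled_indices = sample_bernoulli(probabilities)
            A_sampled = embedding[sampled_indices, :]
            b_sampled = labels[sampled_indices, :]
            # Importance weights are the inverse of the actual inclusion probabilities
            c_sampled = least_squares_solution(A_sampled, b_sampled, weights=1 / probabilities[sampled_indices])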
            sample_size = A_sampled.shape[0]
            error = torch.linalg.matrix_norm(c_sampled - c_full).item()
            errors.append(error)
            sample_sizes.append(sample_size)
        axs[i, 0].plot(sample_sizes, errors, "o", color=plt.cm.viridis(j / len(BETAS)))
    axs[i, 0].set_xlabel("Sample Size")
    axs[i, 0].set_ylabel("Error Norm")
    axs[i, 0].set_title(f"Embedding {embedding_name}")
    axs[i, 0].set_yscale("log")


plt.tight_layout()
plt.savefig("bernoulli_sampling_errors_cifar.png")
plt.show()

Processing embedding: CIFAR CNN
  Processing beta: 1.0
  Processing beta: 6.2105263157894735
  Processing beta: 11.421052631578947
  Processing beta: 16.63157894736842
  Processing beta: 21.842105263157894
  Processing beta: 27.052631578947366
  Processing beta: 32.26315789473684
  Processing beta: 37.473684210526315
  Processing beta: 42.68421052631579
  Processing beta: 47.89473684210526
  Processing beta: 53.10526315789473
  Processing beta: 58.315789473684205
  Processing beta: 63.526315789473685
  Processing beta: 68.73684210526315
  Processing beta: 73.94736842105263
  Processing beta: 79.1578947368421
  Processing beta: 84.36842105263158
  Processing beta: 89.57894736842105
  Processing beta: 94.78947368421052
  Processing beta: 100.0
Processing embedding: CIFAR Bigger CNN
  Processing beta: 1.0
  Processing beta: 6.2105263157894735
  Processing beta: 11.421052631578947
  Processing beta: 16.63157894736842
  Processing beta: 21.842105263157894
  Processing beta: 27.0526315789473