In [1]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# re-imported before each cell executes, so edits to library source files
# take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import time
import numpy as np
from torch.nn.functional import one_hot
from torchvision import datasets, transforms
from klap import GaussianKernel, ExponentialKernel

# Single source of truth for the on-disk MNIST location. Both splits must be
# read from the same root; it was previously spelled two different ways
# (with and without a trailing slash).
MNIST_ROOT = '~/data/mnist'

# download=False: the data is expected to already be present under MNIST_ROOT.
train_set = datasets.MNIST(MNIST_ROOT, download=False, train=True)
test_set = datasets.MNIST(MNIST_ROOT, download=False, train=False)

In [3]:
# Split sizes, used below to flatten every image into a single row.
n_train, n_test = len(train_set), len(test_set)

# Training images as float rows of 28*28 values each.
x_train = train_set.data.view(n_train, 28 * 28).numpy().astype(float)

# Per-feature statistics are computed on the training split only and are
# reused further down to standardise the test split.
mean = np.mean(x_train, axis=0, keepdims=True)
std = np.std(x_train, axis=0, keepdims=True)

# Number of features that actually vary across the training split.
d_eff = np.sum(std > 0)

# Constant features would divide by zero; dividing them by 1 instead leaves
# them at exactly 0 once the mean has been subtracted.
std[std == 0] = 1

# Standardise the training features in place.
np.subtract(x_train, mean, out=x_train)
np.divide(x_train, std, out=x_train)

# Test images receive the same flattening and the *training* statistics.
x_test = test_set.data.view(n_test, 28 * 28).numpy().astype(float)
np.subtract(x_test, mean, out=x_test)
np.divide(x_test, std, out=x_test)

# Training targets are one-hot encoded as floats (regression targets), while
# test targets are kept as raw class labels for comparison with argmax output.
y_train = one_hot(train_set.targets).numpy().astype(float)
y_test = test_set.targets.numpy()

In [4]:
# Experiment: for three kernel bandwidths and two fitting procedures, fit the
# kernel on a subset of the training data, regress the one-hot labels on the
# resulting features, and record test accuracy and fitting time.

sigma = np.sqrt(d_eff)  # reference bandwidth: sqrt of the number of non-constant features
k = 20                  # forwarded to the fit calls as `k`
ind = 10000             # only the first `ind` training samples are used for fitting
p = 300                 # forwarded to the fit calls as `p`

# Bandwidths scanned: one decade below, at, and one decade above the reference.
sigmas = [.1 * sigma, sigma, 10*sigma]

# Row i <-> sigmas[i]; column 0 <-> `fit`, column 1 <-> `fit_with_graph_laplacian`.
accuracy = np.zeros((len(sigmas), 2))
times = np.zeros((len(sigmas), 2))

# The loop variable is deliberately NOT called `sigma`, so the reference
# bandwidth defined above is still intact after the loop.
for i, bandwidth in enumerate(sigmas):
    # Alternative kernel to compare against: ExponentialKernel(sigma=bandwidth)
    kernel = GaussianKernel(sigma=bandwidth)
    for j in range(2):
        # perf_counter is monotonic and high-resolution, unlike wall-clock
        # time.time(), so it is the right clock for measuring elapsed time.
        t_start = time.perf_counter()
        if j == 0:
            kernel.fit(x_train[:ind], p=p, k=k, L_reg=1e-14, R_reg=0, inverse_L=True)
        else:
            kernel.fit_with_graph_laplacian(kernel.kernel, x_train[:ind], p=p, k=k, L_reg=0, R_reg=0, inverse_L=True)
        times[i, j] = time.perf_counter() - t_start

        # Standardise the learned features with statistics from the training
        # features. Dedicated names are used so the pixel-level `mean`/`std`
        # computed during preprocessing are not overwritten.
        phi_train = kernel.features_map(x_train)
        phi_mean = phi_train.mean(axis=0, keepdims=True)
        phi_std = phi_train.std(axis=0, keepdims=True)
        phi_std[phi_std == 0] = 1  # constant features: avoid division by zero
        phi_train -= phi_mean
        phi_train /= phi_std
        phi_test = kernel.features_map(x_test)
        phi_test -= phi_mean
        phi_test /= phi_std

        # Least squares on the one-hot labels via the normal equations, with a
        # tiny ridge term added to the diagonal. The identity is sized from A
        # itself rather than assuming the feature map returns exactly k columns.
        A = phi_train.T @ phi_train
        A += 1e-13 * np.eye(A.shape[0])
        b = phi_train.T @ y_train
        beta = np.linalg.solve(A, b)

        # Predicted class = index of the largest regression output.
        y_pred = (phi_test @ beta).argmax(axis=1)
        accuracy[i, j] = np.mean(y_pred == y_test)

In [5]:
# Show the accuracy table followed by the timing table. In both, rows follow
# `sigmas` and the two columns are `fit` and `fit_with_graph_laplacian`.
print(accuracy, times, sep="\n")

[[0.1135 0.1009]
 [0.5941 0.7542]
 [0.803  0.799 ]]
[[2.10351014 2.86011529]
 [1.46070385 2.57164645]
 [1.40298486 2.68242002]]
