In [1]:
import torch
from datasets import SyntheticData

import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

from torchvision import datasets, transforms
import torchvision

import numpy as np

data=SyntheticData()


def k(x, xprime):
    with torch.no_grad():
        v = torch.linalg.norm(x) * torch.linalg.norm(xprime)
        u = .99999 * torch.dot(x, xprime) / v
        return v * (u * (torch.pi - torch.arccos(u) + torch.sqrt(1 - u ** 2) )/ (2 * np.pi)
                    +  u * (torch.pi - torch.arccos(u)) /  (2 * np.pi))

def ntk_kernel(x,z):
    n,_=x.shape
    m,_=z.shape
    H = torch.empty((n, m))
    for i in range(n):
        for j in range(m):
            H[i,j] = k(x[i], z[j])

    return H



In [83]:
def kappa(u,v):
    u=.99999*u
    return v * (u * (torch.pi - torch.arccos(u) + torch.sqrt(1 - u ** 2) )/ (2 * np.pi)
                    +  u * (torch.pi - torch.arccos(u)) /  (2 * np.pi))

def kappa2(u):
    u=.99999*u
    return 2*u/torch.pi * (torch.pi - torch.arccos(u))  + torch.sqrt(1 - u ** 2) /torch.pi

def easier_ntk(x,z):
    inner_prod=x@z.T
    norm_x=x.norm(dim=-1)
    norm_z=z.norm(dim=-1)
    norm_mat=norm_x.unsqueeze(1)@norm_z.unsqueeze(1).T

    return kappa(inner_prod/norm_mat,norm_mat)

def easier_ntk2(x,z):
    inner_prod=x@z.T
    norm_x=x.norm(dim=-1)
    norm_z=z.norm(dim=-1)
    norm_mat=norm_x.unsqueeze(1)@norm_z.unsqueeze(1).T

    return norm_mat*kappa2(inner_prod/norm_mat)

In [86]:
from sklearn.metrics import accuracy_score
training_sizes = [200, 1001, 2000, 5000, 10000]
for ntrain in training_sizes:

    X_train,y_train=data.generate_synthetic_data_separable(ntrain,0.1)
    X_test,y_test=data.generate_synthetic_data_separable(100,0.1)
    #Kernel_train=ntk_kernel(X_train,X_train)
    Kernel_train=easier_ntk2(X_train,X_train)
    # Solve for alpha = K^-1 y
    alpha_interp = torch.linalg.solve(Kernel_train, y_train)
    #alpha_interp=torch.linalg.inv(Kernel_train)@y_train
    # Compute RKHS norm for interpolated solution
    rkhs_norm_interp = torch.sqrt((alpha_interp @ ( Kernel_train@ alpha_interp)))
    rkhs_norm_interp = rkhs_norm_interp.item()


    K_test_interp = easier_ntk2(X_train, X_test)
    y_pred_interp = torch.sign(K_test_interp.T @ alpha_interp).squeeze()
    error_interp = 1 - accuracy_score(y_test.cpu().numpy(), y_pred_interp.cpu().numpy())

    print("Training size : ", ntrain, " Norm : ",rkhs_norm_interp, " Error test : ",error_interp)

Training size :  200  Norm :  2.2862820625305176  Error test :  0.10999999999999999
Training size :  1001  Norm :  5.871571063995361  Error test :  0.09999999999999998
Training size :  2000  Norm :  9.110661506652832  Error test :  0.12
Training size :  5000  Norm :  16.358280181884766  Error test :  0.09999999999999998
Training size :  10000  Norm :  25.352684020996094  Error test :  0.10999999999999999


In [110]:
from eigenpro2.models import KernelModel
kernel_fn = lambda x, y: easier_ntk2(x, y)

X_train,y_train=data.generate_synthetic_data_separable(10000,0.0)
X_test,y_test=data.generate_synthetic_data_separable(100,0.0)

n_subsamples = min(len(X_train), 5000)
top_q = min(160, n_subsamples - 1)

model_overfit = KernelModel(kernel_fn, X_train, 1, device=torch.device("cpu"))
model_overfit.predict = lambda samples: model_overfit.forward(samples)

try:
        result_overfit = model_overfit.fit(
            X_train, y_train.unsqueeze(1), X_test, y_test.unsqueeze(1),
            n_subsamples=n_subsamples, epochs=10, mem_gb=8,
            bs=64, top_q=top_q, print_every=2,run_epoch_eval=True)
except:
        result_overfit = model_overfit.fit(
                X_train, y_train.unsqueeze(1), X_test, y_test.unsqueeze(1),
                n_subsamples=n_subsamples, epochs=10, mem_gb=8,
                bs=64, print_every=2,run_epoch_eval=True)
        
rkhs_norm_overfit = torch.norm(model_overfit.weight).item()

# Predict and calculate classification error for overfitted
y_pred_overfit = model_overfit.predict(X_test).sign().squeeze()
error_overfit = 1 - accuracy_score(y_test.cpu().numpy(), y_pred_overfit.cpu().numpy())
print(rkhs_norm_overfit,error_interp)

SVD time: 65.32s, top_q: 160, top_eigval: 94.75, new top_eigval: 1.03e-01
n_subsamples=5000, bs_gpu=5000, eta=0.15, bs=64, top_eigval=9.48e+01, beta=415.98
--------------------
epoch:   0    time: 02.6s    train accuracy: 100.00%    val accuracy: 100.00%    train mse: nan    val mse: nan
epoch:   2    time: 02.0s    train accuracy: 100.00%    val accuracy: 100.00%    train mse: nan    val mse: nan
epoch:   4    time: 02.3s    train accuracy: 100.00%    val accuracy: 100.00%    train mse: nan    val mse: nan
epoch:   6    time: 02.4s    train accuracy: 100.00%    val accuracy: 100.00%    train mse: nan    val mse: nan
epoch:   8    time: 02.2s    train accuracy: 100.00%    val accuracy: 100.00%    train mse: nan    val mse: nan
nan 0.10999999999999999


In [108]:
model_overfit.weight.max()

tensor(nan)

In [109]:
model_overfit.weight.shape

torch.Size([10000, 1])