# Reqd funcitons


In [4]:
import torch
import numpy as np
import torch
from numpy.linalg import solve, svd, norm
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
# import classic_kernel

import time
from tqdm import tqdm
# import hickle

In [6]:
'''Implementation of kernel functions.'''

def euclidean_distances(samples, centers, squared=True):
    '''Compute the Euclidean distances between samples and centers.
    
    Args: 
        samples: of shape (n_sample, n_feature).
        centers: of shape (n_center, n_feature).
        squared: whether to return squared distances.
        
    Returns:
        distances: of shape (n_sample, n_center).
    
    '''
    samples_norm = torch.sum(samples**2, dim=1, keepdim=True)
    if samples is centers:
        centers_norm = samples_norm
    else:
        centers_norm = torch.sum(centers**2, dim=1, keepdim=True)
    centers_norm = torch.reshape(centers_norm, (1, -1))

    distances = samples.mm(torch.t(centers))
    distances.mul_(-2)
    distances.add_(samples_norm)
    distances.add_(centers_norm)
    #print(centers_norm.size(), samples_norm.size(), distances.size())
    if not squared:
        distances.clamp_(min=0)
        distances.sqrt_()

    return distances


def euclidean_distances_M(samples, centers, M, squared=True):
    '''Compute the Euclidean distances between samples and centers.

    Args:
        samples: of shape (n_sample, n_feature).
        centers: of shape (n_center, n_feature).
        squared: whether to return squared distances.
    
    Returns:
        distances: of shape (n_sample, n_center).
    '''

    samples_norm = (samples @ M)  * samples
    samples_norm = torch.sum(samples_norm, dim=1, keepdim=True)

    if samples is centers:
        centers_norm = samples_norm
    else:
        centers_norm = (centers @ M) * centers
        centers_norm = torch.sum(centers_norm, dim=1, keepdim=True)

    centers_norm = torch.reshape(centers_norm, (1, -1))

    distances = samples.mm(M @ torch.t(centers))
    distances.mul_(-2)
    distances.add_(samples_norm)
    distances.add_(centers_norm)

    if not squared:
        distances.clamp_(min=0)
        distances.sqrt_()

    return distances


def gaussian(samples, centers, bandwidth):
    '''Gaussian kernel.

    Args:
        samples: of shape (n_sample, n_feature).
        centers: of shape (n_center, n_feature).
        bandwidth: kernel bandwidth.

    Returns:
        kernel matrix of shape (n_sample, n_center).
    '''
    assert bandwidth > 0
    kernel_mat = euclidean_distances(samples, centers)
    kernel_mat.clamp_(min=0)
    gamma = 1. / (2 * bandwidth ** 2)
    kernel_mat.mul_(-gamma)
    kernel_mat.exp_()

    #print(samples.size(), centers.size(),
    #      kernel_mat.size())
    return kernel_mat


def laplacian(samples, centers, bandwidth):
    '''Laplacian kernel.

    Args:
        samples: of shape (n_sample, n_feature).
        centers: of shape (n_center, n_feature).
        bandwidth: kernel bandwidth.

    Returns:
        kernel matrix of shape (n_sample, n_center).
    '''
    assert bandwidth > 0
    kernel_mat = euclidean_distances(samples, centers, squared=False)
    kernel_mat.clamp_(min=0)
    gamma = 1. / bandwidth
    kernel_mat.mul_(-gamma)
    kernel_mat.exp_()
    return kernel_mat



def laplacian_M(samples, centers, bandwidth, M):
    assert bandwidth > 0
    kernel_mat = euclidean_distances_M(samples, centers, M, squared=False)
    kernel_mat.clamp_(min=0)
    gamma = 1. / bandwidth
    kernel_mat.mul_(-gamma)
    kernel_mat.exp_()
    return kernel_mat


def dispersal(samples, centers, bandwidth, gamma):
    '''Dispersal kernel.

    Args:
        samples: of shape (n_sample, n_feature).
        centers: of shape (n_center, n_feature).
        bandwidth: kernel bandwidth.
        gamma: dispersal factor.

    Returns:
        kernel matrix of shape (n_sample, n_center).
    '''
    assert bandwidth > 0
    kernel_mat = euclidean_distances(samples, centers)
    kernel_mat.pow_(gamma / 2.)
    kernel_mat.mul_(-1. / bandwidth)
    kernel_mat.exp_()
    return kernel_mat


In [15]:
def get_mse(y_pred, y_true):
    return np.mean(np.square(y_pred - y_true))


def kernel(pair1, pair2, nngp=False):
    
    out = pair1 @ pair2.transpose(1, 0)
    N1 = torch.sum(torch.pow(pair1, 2), dim=-1).view(-1, 1)
    N2 = torch.sum(torch.pow(pair2, 2), dim=-1).view(-1, 1)

    XX = torch.sqrt(N1 @ N2.transpose(1, 0))
    out = out / XX

    out = torch.clamp(out, -1, 1)

    first = 1/np.pi * (out * (np.pi - torch.acos(out)) \
                       + torch.sqrt(1. - torch.pow(out, 2))) * XX
    if nngp:
        out = first
    else:
        sec = 1/np.pi * out * (np.pi - torch.acos(out)) * XX
        out = first + sec

    return out

def laplace_kernel(pair1, pair2, bandwidth):
    return laplacian(pair1, pair2, bandwidth)

def laplace_kernel_M(pair1, pair2, bandwidth, M):
    return laplacian_M(pair1, pair2, bandwidth, M)


def original_ntk(X_train, y_train, X_test, y_test, use_nngp=False):
    K_train = kernel(X_train, X_train, nngp=use_nngp).numpy()
    sol = solve(K_train, y_train).T
    K_test = kernel(X_train, X_test, nngp=use_nngp).numpy()
    y_pred = (sol @ K_test).T

    mse = get_mse(y_pred, y_test.numpy())
    if use_nngp:
        print("Original NNGP MSE: ", mse)
        return mse
    else:
        print("Original NTK MSE: ", mse)
        return mse


def get_grads(X, sol, L, P):
    M = 0.

    start = time.time()
    num_samples = 20000
    indices = np.random.randint(len(X), size=num_samples)

    #"""
    if len(X) > len(indices):
        x = X[indices, :]
    else:
        x = X

    #n, d = X.shape
    #x = np.random.normal(size=(1000, d))
    #x = torch.from_numpy(x)

    K = laplace_kernel_M(X, x, L, P)

    dist = euclidean_distances_M(X, x, P, squared=False)
    dist = torch.where(dist < 1e-10, torch.zeros(1).float(), dist)

    K = K/dist
    K[K == float("Inf")] = 0.

    a1 = torch.from_numpy(sol.T).float()
    n, d = X.shape
    n, c = a1.shape
    m, d = x.shape

    a1 = a1.reshape(n, c, 1)
    X1 = (X @ P).reshape(n, 1, d)
    step1 = a1 @ X1
    del a1, X1
    step1 = step1.reshape(-1, c*d)

    step2 = K.T @ step1
    del step1

    step2 = step2.reshape(-1, c, d)

    a2 = torch.from_numpy(sol).float()
    step3 = (a2 @ K).T

    del K, a2

    step3 = step3.reshape(m, c, 1)
    x1 = (x @ P).reshape(m, 1, d)
    step3 = step3 @ x1

    G = (step2 - step3) * -1/L

    M = 0.

    bs = 10
    batches = torch.split(G, bs)
    #for i in tqdm(range(len(batches))):
    for i in range(len(batches)):
        grad = batches[i].cuda()
        gradT = torch.transpose(grad, 1, 2)
        #gradT = torch.swapaxes(grad, 1, 2)#.cuda()
        M += torch.sum(gradT @ grad, dim=0).cpu()
        del grad, gradT
    torch.cuda.empty_cache()
    M /= len(G)

    M = M.numpy()

    end = time.time()

    #print("Time: ", end - start)
    return M


def convert_one_hot(y, c):
    o = np.zeros((y.size, c))
    o[np.arange(y.size), y] = 1
    return o


def hyperparam_train(X_train, y_train, X_test, y_test, c,
                     iters=5, reg=0, L=10, normalize=False):

    y_t_orig = y_train
    y_v_orig = y_test
    y_train = convert_one_hot(y_train, c)
    y_test = convert_one_hot( y_test, c)

    if normalize:
        X_train /= norm(X_train, axis=-1).reshape(-1, 1)
        X_test /= norm(X_test, axis=-1).reshape(-1, 1)

    X_train = torch.from_numpy(X_train).float()
    y_train = torch.from_numpy(y_train).float()
    X_test = torch.from_numpy(X_test).float()
    y_test = torch.from_numpy(y_test).float()

    best_acc = 0.
    best_iter = 0.
    best_M = 0.

    n, d = X_train.shape
    M = np.eye(d, dtype='float32')

    for i in range(iters):
        print('a')
        K_train = laplace_kernel_M(X_train, X_train, L, torch.from_numpy(M)).numpy()
        sol = solve(K_train + reg * np.eye(len(K_train)), y_train).T

        K_test = laplace_kernel_M(X_train, X_test, L, torch.from_numpy(M)).numpy()
        preds = (sol @ K_test).T

        y_pred = torch.from_numpy(preds)
        preds = torch.argmax(y_pred, dim=-1)
        labels = torch.argmax(y_test, dim=-1)
        count = torch.sum(labels == preds).numpy()

        old_test_acc = count / len(labels)
        print(old_test_acc)
        if old_test_acc > best_acc:
            best_iter = i
            best_acc = old_test_acc
            best_M = M
        M  = get_grads(X_train, sol, L, torch.from_numpy(M))
        print(M)

    return best_acc, best_iter, best_M


def train(X_train, y_train, X_test, y_test, c, M,
          iters=5, reg=0, L=10, normalize=False):

    y_t_orig = y_train
    y_v_orig = y_test
    y_train = convert_one_hot(y_train, c)
    y_test = convert_one_hot(y_test, c)

    if normalize:
        X_train /= norm(X_train, axis=-1).reshape(-1, 1)
        X_test /= norm(X_test, axis=-1).reshape(-1, 1)

    X_train = torch.from_numpy(X_train).float()
    y_train = torch.from_numpy(y_train).float()
    X_test = torch.from_numpy(X_test).float()
    y_test = torch.from_numpy(y_test).float()

    K_train = laplace_kernel_M(X_train, X_train, L, torch.from_numpy(M)).numpy()
    sol = solve(K_train + reg * np.eye(len(K_train)), y_train).T

    K_test = laplace_kernel_M(X_train, X_test, L, torch.from_numpy(M)).numpy()
    preds = (sol @ K_test).T

    y_pred = torch.from_numpy(preds)
    preds = torch.argmax(y_pred, dim=-1)
    labels = torch.argmax(y_test, dim=-1)
    count = torch.sum(labels == preds).numpy()

    acc = count / len(labels)
    return acc


# Run RFM


In [8]:
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import numpy as np

# Load the Iris dataset
iris = load_iris()
X = iris.data  # Features
y = iris.target  # Labels

# Split the dataset into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)



In [9]:
noise_train = np.random.rand(X_train.shape[0], 1)
X_train = np.hstack((X_train, noise_train))

noise_test = np.random.rand(X_test.shape[0], 1)
X_test = np.hstack((X_test, noise_test))


In [12]:
X_train.shape, y_train.shape, X_test.shape, y_test.shape

((120, 5), (120,), (30, 5), (30,))

In [16]:

hyperparam_train(X_train, y_train, X_test, y_test, 3,
          iters=10, reg=0.1, L=10, normalize=False)

a
1.0
[[ 0.03467679  0.02162893 -0.04380478 -0.05097466 -0.01649249]
 [ 0.02162893  0.04755911 -0.08503123 -0.08106354 -0.02178327]
 [-0.04380478 -0.08503123  0.23471989  0.22143796  0.05368748]
 [-0.05097466 -0.08106354  0.22143796  0.23431176  0.05978322]
 [-0.01649249 -0.02178327  0.05368748  0.05978322  0.04022331]]
a
1.0
[[ 0.01567313  0.0224993  -0.05514165 -0.05687731 -0.01597974]
 [ 0.0224993   0.03502753 -0.08730035 -0.08893613 -0.02431678]
 [-0.05514165 -0.08730035  0.22129692  0.22506167  0.06086985]
 [-0.05687731 -0.08893613  0.22506167  0.2295715   0.06259564]
 [-0.01597974 -0.02431678  0.06086985  0.06259564  0.01811201]]
a
1.0
[[ 0.01433107  0.02203385 -0.05529991 -0.05650548 -0.01548618]
 [ 0.02203385  0.0339829  -0.08538363 -0.08720069 -0.02385475]
 [-0.05529991 -0.08538363  0.21462613  0.21915673  0.05991496]
 [-0.05650548 -0.08720069  0.21915673  0.22380333  0.06120748]
 [-0.01548618 -0.02385475  0.05991496  0.06120748  0.01676692]]
a
1.0
[[ 0.01438857  0.0221706  -0

(1.0,
 0,
 array([[1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1.]], dtype=float32))