In [6]:
import math

import torch
from gpytorch.kernels import RBFKernel
from linear_operator import to_linear_operator
from linear_operator.operators import DiagLinearOperator

from famgpytorch.kernels import RBFKernelApprox
from famgpytorch.functions import ChebyshevHermitePolynomials

## Different covariance matrices for the different definitions of the approximated RBF kernel

In [7]:
def rbf_kernel_joukov_kulic(x1, x2, number_of_eigenvalues, lengthscale, alpha, chebyshev=True):
    """Approximate RBF covariance from a truncated eigen-expansion ("Joukov, Kulic" variant).

    The result is ``Phi(x1) @ diag(eigenvalues) @ Phi(x2)^T``. In this variant the
    eigenvalue exponents and eigenfunction indices run over ``1..number_of_eigenvalues``
    and the exponential factor of the eigenfunctions uses ``alpha**2``. These formulas
    are intentionally kept exactly as defined here, because the notebook compares
    different definitions of the approximation.

    Args:
        x1: First input tensor. Presumably of shape (N, 1), as in the notebook's
            usage -- TODO confirm for other shapes.
        x2: Second input tensor; presumably on the same device as ``x1`` -- TODO confirm.
        number_of_eigenvalues: Number of eigenpairs used in the truncated expansion.
        lengthscale: Scalar lengthscale of the kernel.
        alpha: Scalar ``alpha`` parameter of the expansion.
        chebyshev: If True (default), use the polynomials as returned by
            ``ChebyshevHermitePolynomials``; if False, additionally scale the
            column with index ``i`` by ``sqrt(2**i)``.

    Returns:
        The product of the eigenfunction and eigenvalue linear operators
        (call ``.to_dense()`` on it to obtain a plain tensor).

    Raises:
        ValueError: If the evaluated eigenfunctions contain NaN or infinite values.
    """
    # Create the scalar hyperparameters on the inputs' device so that they can be
    # combined with the index ranges below, which are also created on that device.
    alpha = torch.tensor([[alpha]], dtype=x1.dtype, device=x1.device)
    lengthscale = torch.tensor([[lengthscale]], dtype=x1.dtype, device=x1.device)

    alpha_sq = alpha.pow(2)
    eta_sq = lengthscale.pow(-2).div(2)
    beta = eta_sq.mul(4).div(alpha_sq).add(1).pow(0.25)
    delta_sq = alpha_sq.div(2).mul(beta.pow(2).sub(1))

    # compute eigenvalues: eigenvalue_a * eigenvalue_b ** i for i = 1..n
    denominator = alpha_sq.add(delta_sq).add(eta_sq)
    eigenvalue_a = torch.sqrt(alpha_sq.div(denominator))
    eigenvalue_b = eta_sq.div(denominator)
    eigenvalues = torch.arange(1, number_of_eigenvalues + 1, dtype=x1.dtype, device=x1.device)
    eigenvalues = DiagLinearOperator(eigenvalue_a.mul(eigenvalue_b.pow(eigenvalues))[0, :])

    # define eigenfunctions
    def _eigenfunctions(x, n):
        """Evaluate the eigenfunctions with indices 1..n at ``x``."""
        # compute sqrt factor: sqrt(beta / i!), with 1 / i! obtained as exp(-lgamma(i + 1))
        range_ = torch.arange(1, n + 1, dtype=x.dtype, device=x.device)
        sqrt = torch.sqrt(beta.mul(torch.exp(-torch.lgamma(range_ + 1))))

        # compute exp factor (this variant uses alpha^2 in the exponent)
        exp = torch.exp(-alpha_sq.mul(x.pow(2)))

        # compute hermite polynomials; n + 1 are requested and the first column is
        # dropped so that the remaining columns line up with indices 1..n
        # (presumably the dropped column is the degree-0 polynomial -- TODO confirm)
        hermites = ChebyshevHermitePolynomials.apply(alpha.mul(beta).mul(math.sqrt(2) * x), n+1)[:, 1:]
        if not chebyshev:
            hermites = hermites.mul(torch.sqrt(2 ** range_))

        eigenfunctions = sqrt.mul(exp).mul(hermites)

        if torch.isnan(eigenfunctions).any() or torch.isinf(eigenfunctions).any():
            raise ValueError("Interim results too high. Try to reduce the number of eigenvalues.")

        return eigenfunctions

    eigenfunctions1 = to_linear_operator(_eigenfunctions(x1, number_of_eigenvalues))

    # reuse the first evaluation when both inputs are identical
    if torch.equal(x1, x2):
        eigenfunctions2 = eigenfunctions1
    else:
        eigenfunctions2 = to_linear_operator(_eigenfunctions(x2, number_of_eigenvalues))

    return eigenfunctions1.matmul(eigenvalues).matmul(eigenfunctions2.mT)

In [8]:
# Three equally spaced inputs in [0, 1], arranged as a single column.
data = torch.linspace(0, 1, 3).unsqueeze(-1)
l, a = 1, 1

# Reference: the RBF kernel shipped with GPyTorch.
print("---Conventional RBF Kernel---")
rbf_kernel_conv = RBFKernel()
rbf_kernel_conv.lengthscale = l
covar_conv = rbf_kernel_conv.forward(data, data).to_dense()
print(covar_conv)

# Approximation defined in this notebook (indices 1..n, Chebyshev-Hermite polynomials).
print("\n---Approx RBF Kernel Joukov, Kulic---")
covar_joukov_kulic = rbf_kernel_joukov_kulic(
    data,
    data,
    number_of_eigenvalues=15,
    lengthscale=l,
    alpha=a,
    chebyshev=True,
).to_dense()
print(covar_joukov_kulic)

# Approximation provided by the famgpytorch package.
print("\n---Approx RBF Kernel Fasshauer---")
rbf_kernel_fasshauer = RBFKernelApprox(number_of_eigenvalues=15)
rbf_kernel_fasshauer.lengthscale = l
rbf_kernel_fasshauer.alpha = a
covar_fasshauer = rbf_kernel_fasshauer.forward(data, data).to_dense()
print(covar_fasshauer)

---Conventional RBF Kernel---
tensor([[1.0000, 0.8825, 0.6065],
        [0.8825, 1.0000, 0.8825],
        [0.6065, 0.8825, 1.0000]], grad_fn=<RBFCovarianceBackward>)

---Approx RBF Kernel Joukov, Kulic---
tensor([[ 0.0366,  0.0028, -0.0327],
        [ 0.0028,  0.1440,  0.1235],
        [-0.0327,  0.1235,  0.1510]])

---Approx RBF Kernel Fasshauer---
tensor([[1.0000, 0.8825, 0.6065],
        [0.8825, 1.0000, 0.8825],
        [0.6065, 0.8825, 1.0000]], grad_fn=<MmBackward0>)


## Maximum number of eigenvalues / Numerical stability

In [9]:
def approx_rbf_unoptimized(x1, x2, number_of_eigenvalues, lengthscale, alpha):
    """Approximate RBF covariance from a truncated eigen-expansion, without numerical safeguards.

    The result is ``Phi(x1) @ diag(eigenvalues) @ Phi(x2)^T`` with indices running over
    ``0..number_of_eigenvalues - 1``. The normalisation of the eigenfunctions is
    computed naively: the denominator ``2**i * i!`` is formed explicitly and the
    polynomials are rescaled by ``sqrt(2**i)`` afterwards, so intermediate values grow
    quickly with ``i``. The function is used in this notebook to find the number of
    eigenvalues at which the computation breaks down.

    Args:
        x1: First input tensor. Presumably of shape (N, 1), as in the notebook's
            usage -- TODO confirm for other shapes.
        x2: Second input tensor; presumably on the same device as ``x1`` -- TODO confirm.
        number_of_eigenvalues: Number of eigenpairs used in the truncated expansion.
        lengthscale: Scalar lengthscale of the kernel.
        alpha: Scalar ``alpha`` parameter of the expansion.

    Returns:
        The product of the eigenfunction and eigenvalue linear operators
        (call ``.to_dense()`` on it to obtain a plain tensor).

    Raises:
        ValueError: If the evaluated eigenfunctions contain NaN or infinite values.
    """
    # Create the scalar hyperparameters on the inputs' device so that they can be
    # combined with the index ranges below, which are also created on that device.
    alpha = torch.tensor([[alpha]], dtype=x1.dtype, device=x1.device)
    lengthscale = torch.tensor([[lengthscale]], dtype=x1.dtype, device=x1.device)

    alpha_sq = alpha.pow(2)
    eta_sq = lengthscale.pow(-2).div(2)
    beta = eta_sq.mul(4).div(alpha_sq).add(1).pow(0.25)
    delta_sq = alpha_sq.div(2).mul(beta.pow(2).sub(1))

    # compute eigenvalues: eigenvalue_a * eigenvalue_b ** i for i = 0..n-1
    denominator = alpha_sq.add(delta_sq).add(eta_sq)
    eigenvalue_a = torch.sqrt(alpha_sq.div(denominator))
    eigenvalue_b = eta_sq.div(denominator)
    eigenvalues = torch.arange(number_of_eigenvalues, dtype=x1.dtype, device=x1.device)
    eigenvalues = DiagLinearOperator(eigenvalue_a.mul(eigenvalue_b.pow(eigenvalues))[0, :])

    # define eigenfunctions
    def _eigenfunctions(x, n):
        """Evaluate the eigenfunctions with indices 0..n-1 at ``x``."""
        # compute sqrt factor: sqrt(beta / (2^i * i!))
        # lgamma(i + 1) = ln(i!), so exp(lgamma(i + 1)) is the factorial itself. This
        # unoptimized variant materialises the full denominator 2^i * i! before
        # dividing, which yields extremely large interim values for large i.
        range_ = torch.arange(n, dtype=x.dtype, device=x.device)
        sqrt = torch.sqrt(beta.div(2**range_ * torch.lgamma(range_ + 1).exp()))

        # compute exp factor
        exp = torch.exp(-delta_sq.mul(x.pow(2)))

        # compute hermite polynomials and rescale the column with index i by sqrt(2^i)
        hermites = ChebyshevHermitePolynomials.apply(alpha.mul(beta).mul(math.sqrt(2) * x), n)
        hermites = hermites.mul(torch.sqrt(2 ** range_))

        eigenfunctions = sqrt.mul(exp).mul(hermites)

        if torch.isnan(eigenfunctions).any() or torch.isinf(eigenfunctions).any():
            raise ValueError("Interim results too high. Try to reduce the number of eigenvalues.")

        return eigenfunctions

    eigenfunctions1 = to_linear_operator(_eigenfunctions(x1, number_of_eigenvalues))

    # reuse the first evaluation when both inputs are identical
    if torch.equal(x1, x2):
        eigenfunctions2 = eigenfunctions1
    else:
        eigenfunctions2 = to_linear_operator(_eigenfunctions(x2, number_of_eigenvalues))

    return eigenfunctions1.matmul(eigenvalues).matmul(eigenfunctions2.mT)

In [10]:
data = torch.linspace(0, 1, 3).reshape(-1,1)
l = 1
a = 1

# Increase the number of eigenvalues until the unoptimized implementation raises
# a ValueError because of NaN/inf interim results.
print("\n---Approx RBF Kernel Joukov, Kulic---")
n = 1
while True:
    try:
        covar = approx_rbf_unoptimized(data, data, number_of_eigenvalues=n, lengthscale=l, alpha=a).to_dense()
    except ValueError:
        print(f"Maximum number of eigenvalues reached for n = {n}")
        break
    n += 1

# Repeat the search for the library kernel. The counter is restarted so that this
# search does not depend on where the previous one stopped (otherwise a limit
# below the previous result could never be reported).
print("\n---Approx RBF Kernel Fasshauer---")
n = 1
while True:
    rbf_kernel_fasshauer = RBFKernelApprox(number_of_eigenvalues=n)
    rbf_kernel_fasshauer.lengthscale = l
    rbf_kernel_fasshauer.alpha = a
    try:
        covar = rbf_kernel_fasshauer.forward(data, data).to_dense()
    except ValueError:
        print(f"Maximum number of eigenvalues reached for n = {n}")
        break
    n += 1


---Approx RBF Kernel Joukov, Kulic---
Maximum number of eigenvalues reached for n = 51

---Approx RBF Kernel Fasshauer---
Maximum number of eigenvalues reached for n = 59
