bug: possible bug in handling kernels that are combinations of combinations #428

Open
matthewrhysjones opened this issue Dec 27, 2023 · 2 comments
Labels
bug Something isn't working


@matthewrhysjones

Bug Report

GPJax version: 0.8.0

Current behavior:

This may be a problem with my understanding of how GPJax handles combination kernels, so apologies if I've missed something.

It seems that kernels built as a combination of combination kernels are not handled as expected when more than one type of combination operator is used (e.g. a sum of product kernels, or a product of sum kernels). There doesn't appear to be a problem when both operators are identical (a sum of sum kernels, or a product of product kernels).

Expected behavior:

When using a combination of combination kernels, the predictive mean should be identical whether it is computed with GPJax or manually.

Steps to reproduce:
See the related code below.

Related code:

    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    from jax import config

    import gpjax as gpx

    config.update("jax_enable_x64", True)  # double precision, as in the GPJax examples

    xall = jnp.linspace(-5, 5, 1000)
    toy_fun = lambda x: 1/5*x**2 + jnp.sin(x*5)**3 + jnp.cos(x*3)**2

    xtrain = xall[0::25][:, None]  # 40 training inputs, shape (40, 1)
    ytrain = toy_fun(xtrain)
    xtest = xall[:, None]
    ytest = toy_fun(xtest)

    D = gpx.gps.Dataset(xtrain, ytrain)

    kernel1 = gpx.kernels.RBF()
    kernel2 = gpx.kernels.Matern32()
    sum_kernel = kernel1 + kernel2

    # using GPJax
    pos_kernel = sum_kernel * sum_kernel  # pos = product of sums

    pos_prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Zero(), kernel=pos_kernel)
    pos_posterior = pos_prior * gpx.gps.Gaussian(D.n)

    latent_dist_pos = pos_posterior.likelihood(pos_posterior(xtest, train_data=D))
    mu_pos = latent_dist_pos.mean()
    std_pos = latent_dist_pos.stddev()

    # manual calculation of the predictive distribution
    # a product of kernels is the element-wise product of their Gram matrices
    sum_xx = kernel1.gram(xtrain).to_dense() + kernel2.gram(xtrain).to_dense()
    sum_xt = kernel1.cross_covariance(xtrain, xtest) + kernel2.cross_covariance(xtrain, xtest)
    sum_tt = kernel1.gram(xtest).to_dense() + kernel2.gram(xtest).to_dense()
    kxx = sum_xx * sum_xx
    kxt = sum_xt * sum_xt
    ktt = sum_tt * sum_tt

    L = jnp.linalg.cholesky(kxx + 1 * jnp.eye(D.n))  # 1 matches the obs noise assigned in the GPJax likelihood
    alpha = jnp.linalg.solve(L.T, jnp.linalg.solve(L, ytrain))  # (kxx + I)^-1 y
    v = jnp.linalg.solve(L, kxt)

    mu_manual_pos = kxt.T @ alpha
    cov_manual_pos = ktt - v.T @ v
    var_manual_pos = jnp.diag(cov_manual_pos) + 1  # add obs variance to match GPJax stddev output

    plt.plot(xtest, mu_manual_pos, ':', label="manual")
    plt.plot(xtest, mu_pos, '--', label="GPJax")
    plt.legend()
    plt.show()

    # flatten both means to 1-D before comparing, to avoid accidental broadcasting
    print(jnp.max(jnp.abs(mu_manual_pos.squeeze() - mu_pos.squeeze())))

There is a discrepancy between mu_manual_pos and mu_pos where I don't believe there should be one. The same is true for a kernel that is a sum of product kernels. However, if the combination operators are identical (a sum of sums, or a product of products), the results agree, so there appears to be a problem with how GPJax handles combinations of combinations that mix operators.

Other information:

I found this issue while working with combinations of combination kernels in a personal project, where I am seeing drastic differences between GPJax and the manual computation. I've simplified the problem for this post to make it as clear as possible.

@matthewrhysjones matthewrhysjones added the bug Something isn't working label Dec 27, 2023
@ChrisBoettner

Hey Matthew,
I just ran into the same problem. I think the issue is in __post_init__ of the CombinationKernel class:

    def __post_init__(self):
        # Add kernels to a list, flattening out instances of this class therein, as in GPFlow kernels.
        kernels_list: List[AbstractKernel] = []

        for kernel in self.kernels:
            if not isinstance(kernel, AbstractKernel):
                raise TypeError("can only combine Kernel instances")  # pragma: no cover

            if isinstance(kernel, self.__class__):
                kernels_list.extend(kernel.kernels)
            else:
                kernels_list.append(kernel)

        self.kernels = kernels_list

Here it builds a flattened list of kernels and saves it to the kernels attribute. When the kernel is called, it applies its operator across every kernel in that flattened list:

    return self.operator(jnp.stack([k(x, y) for k in self.kernels]))
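To see the effect without GPJax, here is a minimal pure-Python model of the two snippets above. It is a sketch, not GPJax code: ConstKernel is a hypothetical constant kernel, and SumKernel/ProductKernel are assumed to be functools.partial wrappers around a single CombinationKernel class that differ only in operator (consistent with the later suggestion in this thread to make them actual subclasses). Under that assumption, isinstance(kernel, self.__class__) is true for every combination kernel, whatever its operator.

```python
from dataclasses import dataclass
from functools import partial
from math import prod
from typing import Callable, List


@dataclass
class ConstKernel:
    """Hypothetical base kernel: k(x, y) = c for all inputs."""
    c: float

    def __call__(self, x: float, y: float) -> float:
        return self.c


@dataclass
class CombinationKernel:
    """Toy model of the quoted flattening logic (not GPJax's actual class)."""
    kernels: List[object]
    operator: Callable[[List[float]], float]

    def __post_init__(self):
        flat: List[object] = []
        for kernel in self.kernels:
            # Same test as the quoted code. Because Sum/Product share one class,
            # this flattens ANY child combination, whatever its operator.
            if isinstance(kernel, self.__class__):
                flat.extend(kernel.kernels)
            else:
                flat.append(kernel)
        self.kernels = flat

    def __call__(self, x: float, y: float) -> float:
        # Apply this kernel's operator across the flattened list.
        return self.operator([k(x, y) for k in self.kernels])


# Assumption: Sum/Product are partials of one class, not subclasses.
SumKernel = partial(CombinationKernel, operator=sum)
ProductKernel = partial(CombinationKernel, operator=prod)

s = SumKernel([ConstKernel(2.0), ConstKernel(3.0)])  # k1 + k2 = 5
pos = ProductKernel([s, s])                            # intended (k1 + k2)^2 = 25
print(pos(0.0, 0.0), len(pos.kernels))                 # 36.0 4, i.e. (k1 * k2)^2
```

The product of sums is evaluated as k1 * k2 * k1 * k2 = 36 instead of 25, which is the same kind of silent error the repro shows in the predictive mean.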

So the nested structure of the combination is lost: the kernel blindly applies its own operator (e.g. sum) to all sub-kernels. This explains why the results are consistent when all the combination operators are the same.
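Written out, the flattening gives the following for the repro kernel and for a sum of product kernels (k_3 and k_4 are extra kernels added only for illustration). Associativity is why the same-operator cases are unaffected:

```latex
\begin{aligned}
\text{intended:}\quad & k(x,y) = \bigl(k_1(x,y) + k_2(x,y)\bigr)\bigl(k_1(x,y) + k_2(x,y)\bigr) \\
\text{flattened:}\quad & \tilde{k}(x,y) = k_1(x,y)\,k_2(x,y)\,k_1(x,y)\,k_2(x,y) = \bigl(k_1(x,y)\,k_2(x,y)\bigr)^2 \\[4pt]
\text{intended:}\quad & k = k_1 k_2 + k_3 k_4, \qquad \text{flattened: } \tilde{k} = k_1 + k_2 + k_3 + k_4 \\[4pt]
\text{same operator:}\quad & (k_1 + k_2) + (k_3 + k_4) = k_1 + k_2 + k_3 + k_4, \qquad (k_1 k_2)(k_3 k_4) = k_1 k_2 k_3 k_4
\end{aligned}
```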

I assume the easy fix would be to keep two attributes, self.kernels and self.flattened_kernels.

@st--
Contributor

st-- commented Jan 25, 2024

This is indeed a bug; thank you for spotting it!

I don't think we need a separate flattened_kernels attribute; I would either
a) change SumKernel and ProductKernel to be actual subclasses of CombinationKernel (in which case the test on self.__class__ would only merge a child's kernels into the list when the operation matches), or
b) add an explicit check that self.operator is kernel.operator.
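Option b) keeps a single class and only tightens the flattening condition. A minimal pure-Python sketch of the idea, not GPJax code: ConstKernel is a hypothetical constant kernel, SumKernel/ProductKernel are assumed to be functools.partial wrappers sharing one class, and the operators are assumed to be shared function objects (here the builtins sum and math.prod) so that an identity check with "is" is meaningful.

```python
from dataclasses import dataclass
from functools import partial
from math import prod
from typing import Callable, List


@dataclass
class ConstKernel:
    """Hypothetical base kernel: k(x, y) = c for all inputs."""
    c: float

    def __call__(self, x: float, y: float) -> float:
        return self.c


@dataclass
class CombinationKernel:
    """Toy model of option b) (not GPJax's actual class)."""
    kernels: List[object]
    operator: Callable[[List[float]], float]

    def __post_init__(self):
        flat: List[object] = []
        for kernel in self.kernels:
            # Flatten only a child combination that applies the SAME operator.
            if isinstance(kernel, CombinationKernel) and kernel.operator is self.operator:
                flat.extend(kernel.kernels)
            else:
                flat.append(kernel)  # mixed-operator children stay nested
        self.kernels = flat

    def __call__(self, x: float, y: float) -> float:
        return self.operator([k(x, y) for k in self.kernels])


SumKernel = partial(CombinationKernel, operator=sum)
ProductKernel = partial(CombinationKernel, operator=prod)

s = SumKernel([ConstKernel(2.0), ConstKernel(3.0)])
print(ProductKernel([s, s])(0.0, 0.0))  # 25.0 = (2 + 3) * (2 + 3)
```

One trade-off of the identity check: two distinct but equivalent operator functions (for example, a lambda wrapping sum) would not be flattened together. That is still correct, just less flat.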

Personally I'd prefer a) ... @thomaspinder @daniel-dodd ?
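Option a) can be sketched the same way: once SumKernel and ProductKernel are distinct subclasses, the unchanged isinstance(kernel, self.__class__) test only flattens children of the same concrete type. This is a pure-Python toy under stated assumptions (ConstKernel is hypothetical, and each subclass supplies its operator as a class attribute), not a proposed GPJax patch; GPJax's real kernel classes carry parameters and other machinery this toy omits.

```python
from dataclasses import dataclass
from math import prod
from typing import Callable, ClassVar, List


@dataclass
class ConstKernel:
    """Hypothetical base kernel: k(x, y) = c for all inputs."""
    c: float

    def __call__(self, x: float, y: float) -> float:
        return self.c


@dataclass
class CombinationKernel:
    """Toy model of option a) (not GPJax's actual classes)."""
    kernels: List[object]
    # Each concrete subclass supplies its own reduction.
    operator: ClassVar[Callable[[List[float]], float]]

    def __post_init__(self):
        flat: List[object] = []
        for kernel in self.kernels:
            # self.__class__ is now SumKernel or ProductKernel, so only a child
            # of the same concrete type is flattened; others stay nested.
            if isinstance(kernel, self.__class__):
                flat.extend(kernel.kernels)
            else:
                flat.append(kernel)
        self.kernels = flat

    def __call__(self, x: float, y: float) -> float:
        return self.operator([k(x, y) for k in self.kernels])


class SumKernel(CombinationKernel):
    operator = staticmethod(sum)  # staticmethod keeps plain functions unbound


class ProductKernel(CombinationKernel):
    operator = staticmethod(prod)


s = SumKernel([ConstKernel(2.0), ConstKernel(3.0)])
print(ProductKernel([s, s])(0.0, 0.0))  # 25.0 = (2 + 3) * (2 + 3)
```

Compared with option b), this keeps the flattening test unchanged and lets the type system encode the operator, which is presumably why it reads as the cleaner fix.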
