Skip to content

Prediction Fails with Data Uncertainties #1039

Open
@oscarbranson

Description

@oscarbranson

I'm trying to run a GPRegression model on some [x, y] data where there are known uncertainties (standard deviations) in the y data. I'm using the method outlined in #196. That issue is from 2015, but I can't find any more recent mention of this approach in the docs. Is it still the recommended way?

The model optimizes fine, but whenever I try to make a prediction on new x values I get an array broadcast error.

I'm not familiar enough with GPy to tell whether this is a mistake on my part or a bug. Any thoughts?

Setup:

I'm running GPy from the devel branch, because I'm waiting for a release that supports numpy >= 1.24.

GPy                           1.12.0                             GitHub:devel
numpy                         1.26.0

MWE:

# Minimal reproduction for this issue: GP regression on 1-D [x, y] data with
# known per-point standard deviations on y, encoded via a Fixed kernel.
# The model optimizes, but the final predict_noiseless call raises the
# ValueError shown in the traceback below.
import numpy as np
import GPy

# Generate some data (no seed is set, so each run draws different values)
n = 55  # number of training points

# Inputs: n draws from U(0, 12) as an (n, 1) column, sorted ascending.
x = np.sort(np.random.uniform(0, 12, (n, 1)), axis=0)
# Targets, shape (n, 1): constant offset + sin(x) + Gaussian draws
# (mean 2.2, sd 0.2) + a 0.2 * x linear trend.
y = 2 + np.sin(x) + np.random.normal(2.2, 0.2, (n,1)) + 0.2 * x.reshape(-1,1)
y_err = np.random.uniform(0.15, 0.5, (n,1))  # the data uncertainties (per-point standard deviations of y)

# Build covariance matrix: n x n, zero off the diagonal, with the per-point
# variances y_err**2 on the diagonal.
cov = np.zeros((n, n))
cov[np.diag_indices_from(cov)] = (y_err**2).flat

# Define kernel: sum of RBF, White, Bias and Fixed components.
# NOTE(review): `cov` is n x n, i.e. it only describes the n training inputs.
# The traceback below fails in Add.Kdiag with shapes (100,) and (55,), which
# suggests the Fixed term keeps its training-set size when evaluated at
# `xnew` -- TODO confirm against GPy's Fixed kernel implementation.
kernel = (
    GPy.kern.RBF(input_dim=1, variance=0.3, lengthscale=1.)
    + GPy.kern.White(input_dim=1, variance=0.1)
    + GPy.kern.Bias(input_dim=1, variance=0.1)
    + GPy.kern.Fixed(input_dim=1, covariance_matrix=cov)  # representing the data uncertainties here
)

# Optimise model (optimize() called with its default arguments)
m = GPy.models.GPRegression(x, y, kernel=kernel)
m.optimize()

# Predict on new `x` data: 100 evenly spaced points on [0, 15].
xnew = np.linspace(0,15, 100)

# Predict the underlying function f (no likelihood variance added) at the new
# points, passed as a (100, 1) column.
# NOTE(review): predict_noiseless also accepts a `kern=` argument (see its
# signature in the traceback below); predicting with a kernel that omits the
# Fixed term may avoid the shape mismatch -- untested workaround.
pred, pred_var = m.predict_noiseless(xnew.reshape(-1,1))

Error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/home/[username]/GitHub/aquarist/aquarist/state/dev_GPR.ipynb Cell 6 line 3
      1 xn = np.linspace(0,15, 100)
----> 3 pred, pred_var = m.predict_noiseless(xn.reshape(-1,1))
      5 plt.scatter(x,y,s=5)
      6 plt.plot(xn, pred, alpha=0.3)

File ~/GitHub/GPy/GPy/core/gp.py:393, in GP.predict_noiseless(self, Xnew, full_cov, Y_metadata, kern)
    367 def predict_noiseless(self,  Xnew, full_cov=False, Y_metadata=None, kern=None):
    368     """
    369     Convenience function to predict the underlying function of the GP (often
    370     referred to as f) without adding the likelihood variance on the
   (...)
    391     Note: If you want the predictive quantiles (e.g. 95% confidence interval) use :py:func:`~GPy.core.gp.GP.predict_quantiles`.
    392     """
--> 393     return self.predict(Xnew, full_cov, Y_metadata, kern, None, False)

File ~/GitHub/GPy/GPy/core/gp.py:346, in GP.predict(self, Xnew, full_cov, Y_metadata, kern, likelihood, include_likelihood)
    310 """
    311 Predict the function(s) at the new point(s) Xnew. This includes the
    312 likelihood variance added to the predicted underlying function
   (...)
    342 interval) use :py:func:`~GPy.core.gp.GP.predict_quantiles`.
    343 """
    345 # Predict the latent function values
--> 346 mean, var = self._raw_predict(Xnew, full_cov=full_cov, kern=kern)
    348 if include_likelihood:
    349     # now push through likelihood
    350     if likelihood is None:

File ~/GitHub/GPy/GPy/core/gp.py:303, in GP._raw_predict(self, Xnew, full_cov, kern)
    290 def _raw_predict(self, Xnew, full_cov=False, kern=None):
    291     """
    292     For making predictions, does not account for normalization or likelihood
    293 
   (...)
    301         \Sigma := \texttt{Likelihood.variance / Approximate likelihood covariance}
    302     """
--> 303     mu, var = self.posterior._raw_predict(kern=self.kern if kern is None else kern, Xnew=Xnew, pred_var=self._predictive_variable, full_cov=full_cov)
    304     if self.mean_function is not None:
    305         mu += self.mean_function.f(Xnew)

File ~/GitHub/GPy/GPy/inference/latent_function_inference/posterior.py:292, in PosteriorExact._raw_predict(self, kern, Xnew, pred_var, full_cov)
    290     var = var
    291 else:
--> 292     Kxx = kern.Kdiag(Xnew)
    293     if self._woodbury_chol.ndim == 2:
    294         tmp = dtrtrs(self._woodbury_chol, Kx)[0]

File ~/GitHub/GPy/GPy/kern/src/kernel_slice_operations.py:126, in _slice_Kdiag.<locals>.wrap(self, X, *a, **kw)
    123 @wraps(f)
    124 def wrap(self, X, *a, **kw):
    125     with _Slice_wrap(self, X, None) as s:
--> 126         ret = f(self, s.X, *a, **kw)
    127     return ret

File /usr/lib/python3.11/site-packages/decorator.py:232, in decorate.<locals>.fun(*args, **kw)
    230 if not kwsyntax:
    231     args, kw = fix(args, kw, sig)
--> 232 return caller(func, *(extras + args), **kw)

File ~/.python/py3/lib/python3.11/site-packages/paramz/caching.py:283, in Cache_this.__call__.<locals>.g(obj, *args, **kw)
    281 else:
    282     cacher = cache[self.f] = Cacher(self.f, self.limit, self.ignore_args, self.force_kwargs, cacher_enabled=cache.caching_enabled)
--> 283 return cacher(*args, **kw)

File ~/.python/py3/lib/python3.11/site-packages/paramz/caching.py:172, in Cacher.__call__(self, *args, **kw)
    168         # 2: if anything is not cachable, we will just return the operation, without caching
    169         if reduce(lambda a, b: a or (not (isinstance(b, Observable) or b is None or isinstance(b, Number) or isinstance(b, str))), inputs, False):
    170 #             print 'WARNING: '+self.operation.__name__ + ' not cacheable!'
    171 #             print [not (isinstance(b, Observable)) for b in inputs]
--> 172             return self.operation(*args, **kw)
    173         # 3&4: check whether this cache_id has been cached, then has it changed?
    174         not_seen = not(cache_id in self.inputs_changed)

File ~/GitHub/GPy/GPy/kern/src/add.py:80, in Add.Kdiag(self, X, which_parts)
     77 elif not isinstance(which_parts, (list, tuple)):
     78     # if only one part is given
     79     which_parts = [which_parts]
---> 80 return reduce(np.add, (p.Kdiag(X) for p in which_parts))

ValueError: operands could not be broadcast together with shapes (100,) (55,) 

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions