diff --git a/gpjax/gps.py b/gpjax/gps.py index f5fb8047..9ba3aa09 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -532,7 +532,7 @@ def predict( mean_t = self.prior.mean_function(t) Ktt = self.prior.kernel.gram(t) Kxt = self.prior.kernel.cross_covariance(x, t) - Sigma_inv_Kxt = cola.solve(Sigma, Kxt) + Sigma_inv_Kxt = cola.solve(Sigma, Kxt, Cholesky()) # μt + Ktx (Kxx + Io²)⁻¹ (y - μx) mean = mean_t + jnp.matmul(Sigma_inv_Kxt.T, y - mx)