diff --git a/gpjax/gps.py b/gpjax/gps.py index 0fb135ac0..3de86dd8d 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -218,8 +218,8 @@ def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.MultivariateNormalTri: t = test_inputs n_test = test_inputs.shape[0] - μt = mean_function(t, params["mean_function"]) - Ktt = gram(kernel, t, params["kernel"]) + μt = mean_function(params["mean_function"], t) + Ktt = gram(kernel, params["kernel"], t) Ktt += I(n_test) * jitter Lt = Ktt.triangular_lower() @@ -339,8 +339,8 @@ class ConjugatePosterior(AbstractPosterior): name: Optional[str] = "Conjugate posterior" def predict( - self, train_data: Dataset, params: Dict - ) -> Callable[[Float[Array, "N D"]], dx.MultivariateNormalTri]: + self, params: Dict, train_data: Dataset, + ) -> Callable[[Float[Array, "N D"]], dx.MultivariateNormalFullCovariance]: """Conditional on a training data set, compute the GP's posterior predictive distribution for a given set of parameters. The returned function can be evaluated at a set of test inputs to compute the @@ -372,15 +372,15 @@ def predict( >>> xtest = jnp.linspace(0, 1).reshape(-1, 1) >>> >>> params = gpx.initialise(posterior) - >>> predictive_dist = posterior.predict(gpx.Dataset(X=xtrain, y=ytrain), params) + >>> predictive_dist = posterior.predict(params, gpx.Dataset(X=xtrain, y=ytrain)) >>> predictive_dist(xtest) Args: - train_data (Dataset): A `gpx.Dataset` object that contains the input and output data used for training dataset. params (Dict): A dictionary of parameters that should be used to compute the posterior. + train_data (Dataset): A `gpx.Dataset` object that contains the input and output data used for training dataset. Returns: - Callable[[Float[Array, "N D"]], dx.MultivariateNormalTri]: A function that accepts an input array and returns the predictive distribution as a `dx.MultivariateNormalTri`. + Callable[[Float[Array, "N D"]], dx.MultivariateNormalFullCovariance]: A function that accepts an input array and returns the predictive distribution as a `dx.MultivariateNormalTri`. """ jitter = get_defaults()["jitter"] @@ -397,24 +397,32 @@ def predict( # Observation noise σ² obs_noise = params["likelihood"]["obs_noise"] - μx = mean_function(x, params["mean_function"]) + μx = mean_function(params["mean_function"], x) # Precompute Gram matrix, Kxx, at training inputs, x - Kxx = gram(kernel, x, params["kernel"]) + Kxx = gram(kernel, params["kernel"], x) Kxx += I(n) * jitter # Σ = Kxx + Iσ² Sigma = Kxx + I(n) * obs_noise def predict(test_inputs: Float[Array, "N D"]) -> dx.Distribution: + """Compute the predictive distribution at a set of test inputs. + + Args: + test_inputs (Float[Array, "N D"]): A Jax array of test inputs. + + Returns: + dx.Distribution: A `dx.MultivariateNormalFullCovariance` object that represents the predictive distribution. + """ # Unpack test inputs t = test_inputs n_test = test_inputs.shape[0] - μt = mean_function(t, params["mean_function"]) - Ktt = gram(kernel, t, params["kernel"]) - Kxt = cross_covariance(kernel, x, t, params["kernel"]) + μt = mean_function(params["mean_function"], t) + Ktt = gram(kernel, params["kernel"], t) + Kxt = cross_covariance(kernel, params["kernel"], x, t) # TODO: Investigate lower triangular solves for general covariance operators # this is more efficient than the full solve for dense matrices in the current implimentation. @@ -514,15 +522,24 @@ def marginal_log_likelihood( def mll( params: Dict, ): + """Compute the marginal log-likelihood of the Gaussian process. 
+ + Args: + params (Dict): The model's parameters. + + Returns: + Float[Array, "1"]: The marginal log-likelihood. + """ + # Observation noise σ² obs_noise = params["likelihood"]["obs_noise"] - μx = mean_function(x, params["mean_function"]) + μx = mean_function(params["mean_function"], x) # TODO: This implementation does not take advantage of the covariance operator structure. # Future work concerns implementation of a custom Gaussian distribution / measure object that accepts a covariance operator. # Σ = (Kxx + Iσ²) = LLᵀ - Kxx = gram(kernel, x, params["kernel"]) + Kxx = gram(kernel, params["kernel"], x) Kxx += I(n) * jitter Sigma = Kxx + I(n) * obs_noise L = Sigma.triangular_lower() @@ -583,7 +600,7 @@ def _initialise_params(self, key: PRNGKeyType) -> Dict: return parameters def predict( - self, train_data: Dataset, params: Dict + self, params: Dict, train_data: Dataset, ) -> Callable[[Float[Array, "N D"]], dx.Distribution]: """ Conditional on a set of training data, compute the GP's posterior @@ -594,8 +611,8 @@ def predict( transformed through the likelihood function's inverse link function. Args: - train_data (Dataset): A `gpx.Dataset` object that contains the input and output data used for training dataset. params (Dict): A dictionary of parameters that should be used to compute the posterior. + train_data (Dataset): A `gpx.Dataset` object that contains the input and output data used for training dataset. Returns: tp.Callable[[Array], dx.Distribution]: A function that accepts an input array and returns the predictive distribution as a `dx.Distribution`. @@ -614,19 +631,27 @@ def predict( cross_covariance = kernel.cross_covariance # Precompute lower triangular of Gram matrix, Lx, at training inputs, x - Kxx = gram(kernel, x, params["kernel"]) + Kxx = gram(kernel, params["kernel"], x) Kxx += I(n) * jitter Lx = Kxx.triangular_lower() def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution: + """Predictive distribution of the latent function for a given set of test inputs. + + Args: + test_inputs (Float[Array, "N D"]): A set of test inputs. + + Returns: + dx.Distribution: The predictive distribution of the latent function. + """ # Unpack test inputs t, n_test = test_inputs, test_inputs.shape[0] # Compute terms of the posterior predictive distribution - Ktx = cross_covariance(kernel, t, x, params["kernel"]) - Ktt = gram(kernel, t, params["kernel"]) + I(n_test) * jitter - μt = mean_function(t, params["mean_function"]) + Ktx = cross_covariance(kernel, params["kernel"], t, x) + Ktt = gram(kernel, params["kernel"], t) + I(n_test) * jitter + μt = mean_function(params["mean_function"], t) # Lx⁻¹ Kxt Lx_inv_Kxt = jsp.linalg.solve_triangular(Lx, Ktx.T, lower=True) @@ -689,14 +714,22 @@ def marginal_log_likelihood( constant = jnp.array(-1.0) if negative else jnp.array(1.0) def mll(params: Dict): + """Compute the marginal log-likelihood of the model. + + Args: + params (Dict): A dictionary of parameters that should be used to compute the marginal log-likelihood. + + Returns: + Float[Array, "1"]: The marginal log-likelihood of the model. 
+ """ # Compute lower triangular of the kernel Gram matrix - Kxx = gram(kernel, x, params["kernel"]) + Kxx = gram(kernel, params["kernel"], x) Kxx += I(n) * jitter Lx = Kxx.triangular_lower() # Compute the prior mean function - μx = mean_function(x, params["mean_function"]) + μx = mean_function(params["mean_function"], x) # Whitened function values, wx, correponding to the inputs, x wx = params["latent"] @@ -705,7 +738,7 @@ def mll(params: Dict): fx = μx + jnp.matmul(Lx, wx) # p(y | f(x), θ), where θ are the model hyperparameters - likelihood = link_function(fx, params) + likelihood = link_function(params, fx) # log p(θ) log_prior_density = evaluate_priors(params, priors) diff --git a/gpjax/kernels.py b/gpjax/kernels.py index 70e4a3845..49f417abd 100644 --- a/gpjax/kernels.py +++ b/gpjax/kernels.py @@ -44,19 +44,19 @@ class AbstractKernel: spectral: Optional[bool] = False name: Optional[str] = "AbstractKernel" - def __post_init__(self): + def __post_init__(self) -> None: self.ndims = 1 if not self.active_dims else len(self.active_dims) @abc.abstractmethod def __call__( - self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: Dict + self,params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"], ) -> Float[Array, "1"]: """Evaluate the kernel on a pair of inputs. Args: + params (Dict): Parameter set for which the kernel should be evaluated on. x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. y (Float[Array, "1 D"]): The right hand argument of the kernel function's call - params (Dict): Parameter set for which the kernel should be evaluated on. Returns: Float[Array, "1"]: The value of :math:`k(x, y)`. @@ -123,15 +123,17 @@ class AbstractKernelComputation: @staticmethod @abc.abstractmethod def gram( - kernel: AbstractKernel, inputs: Float[Array, "N D"], params: Dict + kernel: AbstractKernel, + params: Dict, + inputs: Float[Array, "N D"], ) -> CovarianceOperator: """Compute Gram covariance operator of the kernel function. Args: kernel (AbstractKernel): The kernel function to be evaluated. - inputs (Float[Array, "N N"]): The inputs to the kernel function. params (Dict): The parameters of the kernel function. + inputs (Float[Array, "N N"]): The inputs to the kernel function. Returns: CovarianceOperator: Gram covariance operator of the kernel function. @@ -142,40 +144,44 @@ def gram( @staticmethod def cross_covariance( kernel: AbstractKernel, + params: Dict, x: Float[Array, "N D"], y: Float[Array, "M D"], - params: Dict, ) -> Float[Array, "N M"]: """For a given kernel, compute the NxM gram matrix on an a pair of input matrices with shape NxD and MxD. Args: kernel (AbstractKernel): The kernel for which the cross-covariance matrix should be computed for. + params (Dict): The kernel's parameter set. x (Float[Array,"N D"]): The first input matrix. y (Float[Array,"M D"]): The second input matrix. - params (Dict): The kernel's parameter set. Returns: Float[Array, "N M"]: The computed square Gram matrix. """ - cross_cov = vmap(lambda x: vmap(lambda y: kernel(x, y, params))(y))(x) + cross_cov = vmap(lambda x: vmap(lambda y: kernel(params, x, y))(y))(x) return cross_cov @staticmethod def diagonal( - kernel: AbstractKernel, inputs: Float[Array, "N D"], params: Dict + kernel: AbstractKernel, + params: Dict, + inputs: Float[Array, "N D"], ) -> CovarianceOperator: """For a given kernel, compute the elementwise diagonal of the NxN gram matrix on an input matrix of shape NxD. + Args: kernel (AbstractKernel): The kernel for which the variance vector should be computed for. 
- inputs (Float[Array, "N D"]): The input matrix. params (Dict): The kernel's parameter set. + inputs (Float[Array, "N D"]): The input matrix. + Returns: CovarianceOperator: The computed diagonal variance entries. """ - diag = vmap(lambda x: kernel(x, x, params))(inputs) + diag = vmap(lambda x: kernel(params, x, x))(inputs) return DiagonalCovarianceOperator(diag=diag) @@ -185,20 +191,22 @@ class DenseKernelComputation(AbstractKernelComputation): @staticmethod def gram( - kernel: AbstractKernel, inputs: Float[Array, "N D"], params: Dict + kernel: AbstractKernel, + params: Dict, + inputs: Float[Array, "N D"], ) -> CovarianceOperator: """For a given kernel, compute the NxN gram matrix on an input matrix of shape NxD. Args: kernel (AbstractKernel): The kernel for which the Gram matrix should be computed for. - inputs (Float[Array,"N D"]): The input matrix. params (Dict): The kernel's parameter set. + inputs (Float[Array,"N D"]): The input matrix. Returns: CovarianceOperator: The computed square Gram matrix. """ - matrix = vmap(lambda x: vmap(lambda y: kernel(x, y, params))(inputs))(inputs) + matrix = vmap(lambda x: vmap(lambda y: kernel(params, x, y))(inputs))(inputs) return DenseCovarianceOperator(matrix=matrix) @@ -206,26 +214,29 @@ def gram( class DiagonalKernelComputation(AbstractKernelComputation): @staticmethod def gram( - kernel: AbstractKernel, inputs: Float[Array, "N D"], params: Dict + kernel: AbstractKernel, + params: Dict, + inputs: Float[Array, "N D"], ) -> CovarianceOperator: """For a kernel with diagonal structure, compute the NxN gram matrix on an input matrix of shape NxD. Args: kernel (AbstractKernel): The kernel for which the Gram matrix should be computed for. - inputs (Float[Array, "N D"]): The input matrix. params (Dict): The kernel's parameter set. + inputs (Float[Array, "N D"]): The input matrix. Returns: CovarianceOperator: The computed square Gram matrix. """ - diag = vmap(lambda x: kernel(x, x, params))(inputs) + diag = vmap(lambda x: kernel(params, x, x))(inputs) return DiagonalCovarianceOperator(diag=diag) @dataclass class _KernelSet: + """A mixin class for storing a list of kernels. Useful for combination kernels.""" kernel_set: List[AbstractKernel] @@ -236,7 +247,7 @@ class CombinationKernel(AbstractKernel, _KernelSet, DenseKernelComputation): name: Optional[str] = "Combination kernel" combination_fn: Optional[Callable] = None - def __post_init__(self): + def __post_init__(self) -> None: """Set the kernel set to the list of kernels passed to the constructor.""" kernels = self.kernel_set @@ -255,6 +266,7 @@ def _set_kernels(self, kernels: Sequence[AbstractKernel]) -> None: kernels_list.extend(k.kernel_set) else: kernels_list.append(k) + self.kernel_set = kernels_list def _initialise_params(self, key: PRNGKeyType) -> Dict: @@ -262,10 +274,20 @@ def _initialise_params(self, key: PRNGKeyType) -> Dict: return [kernel._initialise_params(key) for kernel in self.kernel_set] def __call__( - self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: Dict + self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"], ) -> Float[Array, "1"]: + """Evaluate combination kernel on a pair of inputs. + + Args: + params (Dict): Parameter set for which the kernel should be evaluated on. + x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. + y (Float[Array, "1 D"]): The right hand argument of the kernel function's call + + Returns: + Float[Array, "1"]: The value of :math:`k(x, y)`. 
+ """ return self.combination_fn( - jnp.stack([k(x, y, p) for k, p in zip(self.kernel_set, params)]) + jnp.stack([k(p, x, y) for k, p in zip(self.kernel_set, params)]) ) @@ -294,11 +316,11 @@ class RBF(AbstractKernel, DenseKernelComputation): name: Optional[str] = "Radial basis function kernel" - def __post_init__(self): + def __post_init__(self) -> None: self.ndims = 1 if not self.active_dims else len(self.active_dims) def __call__( - self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: Dict + self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"] ) -> Float[Array, "1"]: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with length-scale parameter :math:`\ell` and variance :math:`\sigma^2` @@ -306,9 +328,9 @@ def __call__( k(x, y) = \\sigma^2 \\exp \\Bigg( \\frac{\\lVert x - y \\rVert^2_2}{2 \\ell^2} \\Bigg) Args: + params (Dict): Parameter set for which the kernel should be evaluated on. x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. y (Float[Array, "1 D"]): The right hand argument of the kernel function's call. - params (Dict): Parameter set for which the kernel should be evaluated on. Returns: Float[Array, "1"]: The value of :math:`k(x, y)`. @@ -325,22 +347,17 @@ def _initialise_params(self, key: PRNGKeyType) -> Dict: } -# @dataclass -# class RBF(_RBF, DenseKernelComputation): -# pass - - @dataclass(repr=False) class Matern12(AbstractKernel, DenseKernelComputation): """The Matérn kernel with smoothness parameter fixed at 0.5.""" name: Optional[str] = "Matern 1/2" - def __post_init__(self): + def __post_init__(self) -> None: self.ndims = 1 if not self.active_dims else len(self.active_dims) def __call__( - self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: Dict + self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"], ) -> Float[Array, "1"]: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with length-scale parameter :math:`\ell` and variance :math:`\sigma^2` @@ -348,9 +365,9 @@ def __call__( k(x, y) = \\sigma^2 \\exp \\Bigg( -\\frac{\\lvert x-y \\rvert}{\\ell} \\Bigg) Args: + params (Dict): Parameter set for which the kernel should be evaluated on. x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. y (Float[Array, "1 D"]): The right hand argument of the kernel function's call - params (Dict): Parameter set for which the kernel should be evaluated on. Returns: Float[Array, "1"]: The value of :math:`k(x, y)` """ @@ -372,11 +389,11 @@ class Matern32(AbstractKernel, DenseKernelComputation): name: Optional[str] = "Matern 3/2" - def __post_init__(self): + def __post_init__(self) -> None: self.ndims = 1 if not self.active_dims else len(self.active_dims) def __call__( - self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: Dict + self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"], ) -> Float[Array, "1"]: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with lengthscale parameter :math:`\ell` and variance :math:`\sigma^2` @@ -384,9 +401,9 @@ def __call__( k(x, y) = \\sigma^2 \\exp \\Bigg(1+ \\frac{\\sqrt{3}\\lvert x-y \\rvert}{\\ell} \\Bigg)\\exp\\Bigg(-\\frac{\\sqrt{3}\\lvert x-y\\rvert}{\\ell} \\Bigg) Args: + params (Dict): Parameter set for which the kernel should be evaluated on. x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. y (Float[Array, "1 D"]): The right hand argument of the kernel function's call. - params (Dict): Parameter set for which the kernel should be evaluated on. 
Returns: Float[Array, "1"]: The value of :math:`k(x, y)`. @@ -414,11 +431,11 @@ class Matern52(AbstractKernel, DenseKernelComputation): name: Optional[str] = "Matern 5/2" - def __post_init__(self): + def __post_init__(self) -> None: self.ndims = 1 if not self.active_dims else len(self.active_dims) def __call__( - self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: Dict + self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"] ) -> Float[Array, "1"]: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with lengthscale parameter :math:`\ell` and variance :math:`\sigma^2` @@ -426,9 +443,9 @@ def __call__( k(x, y) = \\sigma^2 \\exp \\Bigg(1+ \\frac{\\sqrt{5}\\lvert x-y \\rvert}{\\ell} + \\frac{5\\lvert x - y \\rvert^2}{3\\ell^2} \\Bigg)\\exp\\Bigg(-\\frac{\\sqrt{5}\\lvert x-y\\rvert}{\\ell} \\Bigg) Args: + params (Dict): Parameter set for which the kernel should be evaluated on. x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. y (Float[Array, "1 D"]): The right hand argument of the kernel function's call. - params (Dict): Parameter set for which the kernel should be evaluated on. Returns: Float[Array, "1"]: The value of :math:`k(x, y)`. @@ -457,12 +474,12 @@ class Polynomial(AbstractKernel, DenseKernelComputation): name: Optional[str] = "Polynomial" degree: int = 1 - def __post_init__(self): + def __post_init__(self) -> None: self.ndims = 1 if not self.active_dims else len(self.active_dims) self.name = f"Polynomial Degree: {self.degree}" def __call__( - self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: Dict + self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"] ) -> Float[Array, "1"]: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with shift parameter :math:`\\alpha` and variance :math:`\sigma^2` through @@ -470,9 +487,9 @@ def __call__( k(x, y) = \\Big( \\alpha + \\sigma^2 xy \\Big)^{d} Args: + params (Dict): Parameter set for which the kernel should be evaluated on. x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. y (Float[Array, "1 D"]): The right hand argument of the kernel function's call - params (Dict): Parameter set for which the kernel should be evaluated on. Returns: Float[Array, "1"]: The value of :math:`k(x, y)`. @@ -491,11 +508,11 @@ def _initialise_params(self, key: PRNGKeyType) -> Dict: @dataclass(repr=False) class White(AbstractKernel, DiagonalKernelComputation): - def __post_init__(self): + def __post_init__(self) -> None: self.ndims = 1 if not self.active_dims else len(self.active_dims) def __call__( - self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: Dict + self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"] ) -> Float[Array, "1"]: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with variance :math:`\sigma` @@ -503,9 +520,9 @@ def __call__( k(x, y) = \\sigma^2 \delta(x-y) Args: + params (Dict): Parameter set for which the kernel should be evaluated on. x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. y (Float[Array, "1 D"]): The right hand argument of the kernel function's call. - params (Dict): Parameter set for which the kernel should be evaluated on. Returns: Float[Array, "1"]: The value of :math:`k(x, y)`. 
@@ -537,21 +554,21 @@ class _EigenKernel: class GraphKernel(AbstractKernel, _EigenKernel, DenseKernelComputation): name: Optional[str] = "Graph kernel" - def __post_init__(self): + def __post_init__(self) -> None: self.ndims = 1 evals, self.evecs = jnp.linalg.eigh(self.laplacian) self.evals = evals.reshape(-1, 1) self.num_vertex = self.laplacian.shape[0] def __call__( - self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: Dict + self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"] ) -> Float[Array, "1"]: """Evaluate the graph kernel on a pair of vertices :math:`v_i, v_j`. Args: + params (Dict): Parameter set for which the kernel should be evaluated on. x (Float[Array, "1 D"]): Index of the ith vertex. y (Float[Array, "1 D"]): Index of the jth vertex. - params (Dict): Parameter set for which the kernel should be evaluated on. Returns: Float[Array, "1"]: The value of :math:`k(v_i, v_j)`. diff --git a/gpjax/likelihoods.py b/gpjax/likelihoods.py index 60f337cb4..e13df48b0 100644 --- a/gpjax/likelihoods.py +++ b/gpjax/likelihoods.py @@ -25,8 +25,6 @@ from .config import get_defaults from .types import PRNGKeyType -DEFAULT_JITTER = get_defaults()["jitter"] - @dataclass class AbstractLikelihood: @@ -112,24 +110,33 @@ def _initialise_params(self, key: PRNGKeyType) -> Dict: return {"obs_noise": jnp.array([1.0])} @property - def link_function(self) -> Callable: + def link_function(self) -> Callable[[Dict, Float[Array, "N 1"]], dx.Distribution]: """Return the link function of the Gaussian likelihood. Here, this is simply the identity function, but we include it for completeness. Returns: - Callable: A link function that maps the predictive distribution to the likelihood function. + Callable[[Dict, Float[Array, "N 1"]], dx.Distribution]: A link function that maps the predictive distribution to the likelihood function. """ - def link_fn(x, params: Dict) -> dx.Normal: - return dx.Normal(loc=x, scale=params["obs_noise"]) + def link_fn(params: Dict, f: Float[Array, "N 1"]) -> dx.Normal: + """The link function of the Gaussian likelihood. + + Args: + params (Dict): The parameters of the likelihood function. + f (Float[Array, "N 1"]): Function values. + + Returns: + dx.Normal: The likelihood function. + """ + return dx.Normal(loc=f, scale=params["obs_noise"]) return link_fn - def predict(self, dist: dx.MultivariateNormalTri, params: Dict) -> dx.Distribution: + def predict(self, params: Dict, dist: dx.MultivariateNormalTri) -> dx.Distribution: """Evaluate the Gaussian likelihood function at a given predictive distribution. Computationally, this is equivalent to summing the observation noise term to the diagonal elements of the predictive distribution's covariance matrix. Args: - dist (dx.Distribution): The Gaussian process posterior, evaluated at a finite set of test points. params (Dict): The parameters of the likelihood function. + dist (dx.Distribution): The Gaussian process posterior, evaluated at a finite set of test points. Returns: dx.Distribution: The predictive distribution. @@ -159,20 +166,29 @@ def _initialise_params(self, key: PRNGKeyType) -> Dict: return {} @property - def link_function(self) -> Callable: + def link_function(self) -> Callable[[Dict, Float[Array, "N 1"]], dx.Distribution]: """Return the probit link function of the Bernoulli likelihood. Returns: - Callable: A probit link function that maps the predictive distribution to the likelihood function. 
+ Callable[[Dict, Float[Array, "N 1"]], dx.Distribution]: A probit link function that maps the predictive distribution to the likelihood function. """ - def link_fn(x, params: Dict) -> dx.Distribution: - return dx.Bernoulli(probs=inv_probit(x)) + def link_fn(params: Dict, f: Float[Array, "N 1"]) -> dx.Distribution: + """The probit link function of the Bernoulli likelihood. + + Args: + params (Dict): The parameters of the likelihood function. + f (Float[Array, "N 1"]): Function values. + + Returns: + dx.Distribution: The likelihood function. + """ + return dx.Bernoulli(probs=inv_probit(f)) return link_fn @property - def predictive_moment_fn(self) -> Callable: + def predictive_moment_fn(self) -> Callable[[Dict, Float[Array, "N 1"]], Float[Array, "N 1"]]: """Instantiate the predictive moment function of the Bernoulli likelihood that is parameterised by a probit link function. Returns: @@ -180,26 +196,36 @@ def predictive_moment_fn(self) -> Callable: """ def moment_fn( - mean: Float[Array, "N D"], variance: Float[Array, "N D"], params: Dict + params: Dict, mean: Float[Array, "N 1"], variance: Float[Array, "N 1"], ): - rv = self.link_function(mean / jnp.sqrt(1 + variance), params) + """The predictive moment function of the Bernoulli likelihood. + + Args: + params (Dict): The parameters of the likelihood function. + mean (Float[Array, "N 1"]): The mean of the latent function values. + variance (Float[Array, "N 1"]): The diagonal variance of the latent function values. + + Returns: + Float[Array, "N 1"]: The pointwise predictive distribution. + """ + rv = self.link_function(params, mean / jnp.sqrt(1.0 + variance)) return rv return moment_fn - def predict(self, dist: dx.Distribution, params: Dict) -> dx.Distribution: + def predict(self, params: Dict, dist: dx.Distribution) -> dx.Distribution: """Evaluate the pointwise predictive distribution, given a Gaussian process posterior and likelihood parameters. Args: - dist (dx.Distribution): The Gaussian process posterior, evaluated at a finite set of test points. params (Dict): The parameters of the likelihood function. + dist (dx.Distribution): The Gaussian process posterior, evaluated at a finite set of test points. Returns: dx.Distribution: The pointwise predictive distribution. """ variance = jnp.diag(dist.covariance()) - mean = dist.mean() - return self.predictive_moment_fn(mean.ravel(), variance, params) + mean = dist.mean().ravel() + return self.predictive_moment_fn(params, mean, variance) def inv_probit(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]: @@ -211,7 +237,7 @@ def inv_probit(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]: Returns: Float[Array, "N 1"]: The inverse probit of the input vector. """ - jitter = 1e-3 # ensures output is strictly between 0 and 1 + jitter = 1e-3 # To ensure output is in interval (0, 1). return 0.5 * (1.0 + jsp.special.erf(x / jnp.sqrt(2.0))) * (1 - 2 * jitter) + jitter diff --git a/gpjax/mean_functions.py b/gpjax/mean_functions.py index a6a580722..2f3d6d4c1 100644 --- a/gpjax/mean_functions.py +++ b/gpjax/mean_functions.py @@ -31,12 +31,12 @@ class AbstractMeanFunction: name: Optional[str] = "Mean function" @abc.abstractmethod - def __call__(self, x: Float[Array, "N D"], params: Dict) -> Float[Array, "N Q"]: + def __call__(self, params: Dict, x: Float[Array, "N D"]) -> Float[Array, "N Q"]: """Evaluate the mean function at the given points. This method is required for all subclasses. Args: - x (Float[Array, "N D"]): The input points at which to evaluate the mean function. 
params (Dict): The parameters of the mean function. + x (Float[Array, "N D"]): The input points at which to evaluate the mean function. Returns: Float[Array, "N Q"]: The mean function evaluated point-wise on the inputs. @@ -65,12 +65,12 @@ class Zero(AbstractMeanFunction): output_dim: Optional[int] = 1 name: Optional[str] = "Zero mean function" - def __call__(self, x: Float[Array, "N D"], params: Dict) -> Float[Array, "N Q"]: + def __call__(self, params: Dict, x: Float[Array, "N D"]) -> Float[Array, "N Q"]: """Evaluate the mean function at the given points. Args: - x (Float[Array, "N D"]): The input points at which to evaluate the mean function. params (Dict): The parameters of the mean function. + x (Float[Array, "N D"]): The input points at which to evaluate the mean function. Returns: Float[Array, "N Q"]: A vector of zeros. @@ -100,12 +100,12 @@ class Constant(AbstractMeanFunction): output_dim: Optional[int] = 1 name: Optional[str] = "Constant mean function" - def __call__(self, x: Float[Array, "N D"], params: Dict) -> Float[Array, "N Q"]: + def __call__(self, params: Dict, x: Float[Array, "N D"]) -> Float[Array, "N Q"]: """Evaluate the mean function at the given points. Args: - x (Float[Array, "N D"]): The input points at which to evaluate the mean function. params (Dict): The parameters of the mean function. + x (Float[Array, "N D"]): The input points at which to evaluate the mean function. Returns: Float[Array, "N Q"]: A vector of repeated constant values. diff --git a/gpjax/variational_families.py b/gpjax/variational_families.py index 29081f025..00a5d4387 100644 --- a/gpjax/variational_families.py +++ b/gpjax/variational_families.py @@ -29,8 +29,6 @@ from .types import Dataset, PRNGKeyType from .utils import concat_dictionaries -DEFAULT_JITTER = get_defaults()["jitter"] - @dataclass class AbstractVariationalFamily: @@ -81,7 +79,6 @@ class AbstractVariationalGaussian(AbstractVariationalFamily): prior: Prior inducing_inputs: Float[Array, "N D"] name: str = "Gaussian" - jitter: Optional[float] = DEFAULT_JITTER def __post_init__(self): """Initialise the variational Gaussian distribution.""" @@ -134,6 +131,8 @@ def prior_kl(self, params: Dict) -> Float[Array, "1"]: Float[Array, "1"]: The KL-divergence between our variational approximation and the GP prior. """ + jitter = get_defaults()["jitter"] + # Unpack variational parameters mu = params["variational_family"]["moments"]["variational_mean"] sqrt = params["variational_family"]["moments"]["variational_root_covariance"] @@ -147,9 +146,9 @@ def prior_kl(self, params: Dict) -> Float[Array, "1"]: # Unpack kernel computation gram = kernel.gram - μz = mean_function(z, params["mean_function"]) - Kzz = gram(kernel, z, params["kernel"]) - Kzz += I(m) * self.jitter + μz = mean_function(params["mean_function"], z) + Kzz = gram(kernel, params["kernel"], z) + Kzz += I(m) * jitter Lz = Kzz.triangular_lower() qu = dx.MultivariateNormalTri(jnp.atleast_1d(mu.squeeze()), sqrt) @@ -172,6 +171,7 @@ def predict( Returns: Callable[[Float[Array, "N D"]], dx.MultivariateNormalTri]: A function that accepts a set of test points and will return the predictive distribution at those points. 
""" + jitter = get_defaults()["jitter"] # Unpack variational parameters mu = params["variational_family"]["moments"]["variational_mean"] @@ -187,10 +187,10 @@ def predict( gram = kernel.gram cross_covariance = kernel.cross_covariance - Kzz = gram(kernel, z, params["kernel"]) - Kzz += I(m) * self.jitter + Kzz = gram(kernel, params["kernel"], z) + Kzz += I(m) * jitter Lz = Kzz.triangular_lower() - μz = mean_function(z, params["mean_function"]) + μz = mean_function(params["mean_function"], z) def predict_fn( test_inputs: Float[Array, "N D"] @@ -199,9 +199,9 @@ def predict_fn( # Unpack test inputs t, n_test = test_inputs, test_inputs.shape[0] - Ktt = gram(kernel, t, params["kernel"]) - Kzt = cross_covariance(kernel, z, t, params["kernel"]) - μt = mean_function(t, params["mean_function"]) + Ktt = gram(kernel, params["kernel"], t) + Kzt = cross_covariance(kernel, params["kernel"], z, t) + μt = mean_function(params["mean_function"], t) # Lz⁻¹ Kzt Lz_inv_Kzt = jsp.linalg.solve_triangular(Lz, Kzt, lower=True) @@ -221,7 +221,7 @@ def predict_fn( - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt) + jnp.matmul(Ktz_Kzz_inv_sqrt, Ktz_Kzz_inv_sqrt.T) ) - covariance += I(n_test) * self.jitter + covariance += I(n_test) * jitter return dx.MultivariateNormalFullCovariance( jnp.atleast_1d(mean.squeeze()), covariance.to_dense() @@ -276,6 +276,7 @@ def predict( Returns: Callable[[Float[Array, "N D"]], dx.MultivariateNormalTri]: A function that accepts a set of test points and will return the predictive distribution at those points. """ + jitter = get_defaults()["jitter"] # Unpack variational parameters mu = params["variational_family"]["moments"]["variational_mean"] @@ -291,8 +292,8 @@ def predict( gram = kernel.gram cross_covariance = kernel.cross_covariance - Kzz = gram(kernel, z, params["kernel"]) - Kzz += I(m) * self.jitter + Kzz = gram(kernel, params["kernel"], z) + Kzz += I(m) * jitter Lz = Kzz.triangular_lower() def predict_fn( @@ -302,9 +303,9 @@ def predict_fn( # Unpack test inputs t, n_test = test_inputs, test_inputs.shape[0] - Ktt = gram(kernel, t, params["kernel"]) - Kzt = cross_covariance(kernel, z, t, params["kernel"]) - μt = mean_function(t, params["mean_function"]) + Ktt = gram(kernel, params["kernel"], t) + Kzt = cross_covariance(kernel, params["kernel"], z, t) + μt = mean_function(params["mean_function"], t) # Lz⁻¹ Kzt Lz_inv_Kzt = jsp.linalg.solve_triangular(Lz, Kzt, lower=True) @@ -321,7 +322,7 @@ def predict_fn( - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt) + jnp.matmul(Ktz_Lz_invT_sqrt, Ktz_Lz_invT_sqrt.T) ) - covariance += I(n_test) * self.jitter + covariance += I(n_test) * jitter return dx.MultivariateNormalFullCovariance( jnp.atleast_1d(mean.squeeze()), covariance.to_dense() @@ -374,6 +375,7 @@ def prior_kl(self, params: Dict) -> Float[Array, "1"]: Returns: Float[Array, "1"]: The KL-divergence between our variational approximation and the GP prior. 
""" + jitter = get_defaults()["jitter"] # Unpack variational parameters natural_vector = params["variational_family"]["moments"]["natural_vector"] @@ -390,7 +392,7 @@ def prior_kl(self, params: Dict) -> Float[Array, "1"]: # S⁻¹ = -2θ₂ S_inv = -2 * natural_matrix - S_inv += jnp.eye(m) * self.jitter + S_inv += jnp.eye(m) * jitter # Compute L⁻¹, where LLᵀ = S, via a trick found in the NumPyro source code and https://nbviewer.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril: sqrt_inv = jnp.swapaxes( @@ -406,9 +408,9 @@ def prior_kl(self, params: Dict) -> Float[Array, "1"]: # μ = Sθ₁ mu = jnp.matmul(S, natural_vector) - μz = mean_function(z, params["mean_function"]) - Kzz = gram(kernel, z, params["kernel"]) - Kzz += I(m) * self.jitter + μz = mean_function(params["mean_function"], z) + Kzz = gram(kernel, params["kernel"], z) + Kzz += I(m) * jitter Lz = Kzz.triangular_lower() qu = dx.MultivariateNormalTri(jnp.atleast_1d(mu.squeeze()), sqrt) @@ -433,6 +435,7 @@ def predict( Returns: Callable[[Float[Array, "N D"]], dx.MultivariateNormalTri]: A function that accepts a set of test points and will return the predictive distribution at those points. """ + jitter = get_defaults()["jitter"] # Unpack variational parameters natural_vector = params["variational_family"]["moments"]["natural_vector"] @@ -450,7 +453,7 @@ def predict( # S⁻¹ = -2θ₂ S_inv = -2 * natural_matrix - S_inv += jnp.eye(m) * self.jitter + S_inv += jnp.eye(m) * jitter # Compute L⁻¹, where LLᵀ = S, via a trick found in the NumPyro source code and https://nbviewer.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril: sqrt_inv = jnp.swapaxes( @@ -466,19 +469,19 @@ def predict( # μ = Sθ₁ mu = jnp.matmul(S, natural_vector) - Kzz = gram(kernel, z, params["kernel"]) - Kzz += I(m) * self.jitter + Kzz = gram(kernel, params["kernel"], z) + Kzz += I(m) * jitter Lz = Kzz.triangular_lower() - μz = mean_function(z, params["mean_function"]) + μz = mean_function(params["mean_function"], z) def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.MultivariateNormalTri: # Unpack test inputs t, n_test = test_inputs, test_inputs.shape[0] - Ktt = gram(kernel, t, params["kernel"]) - Kzt = cross_covariance(kernel, z, t, params["kernel"]) - μt = mean_function(t, params["mean_function"]) + Ktt = gram(kernel, params["kernel"], t) + Kzt = cross_covariance(kernel, params["kernel"], z, t) + μt = mean_function(params["mean_function"], t) # Lz⁻¹ Kzt Lz_inv_Kzt = jsp.linalg.solve_triangular(Lz, Kzt, lower=True) @@ -498,7 +501,7 @@ def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.MultivariateNormalTri: - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt) + jnp.matmul(Ktz_Kzz_inv_L, Ktz_Kzz_inv_L.T) ) - covariance += I(n_test) * self.jitter + covariance += I(n_test) * jitter return dx.MultivariateNormalFullCovariance( jnp.atleast_1d(mean.squeeze()), covariance.to_dense() @@ -553,6 +556,7 @@ def prior_kl(self, params: Dict) -> Float[Array, "1"]: Returns: Float[Array, "1"]: The KL-divergence between our variational approximation and the GP prior. 
""" + jitter = get_defaults()["jitter"] # Unpack variational parameters expectation_vector = params["variational_family"]["moments"][ @@ -576,14 +580,14 @@ def prior_kl(self, params: Dict) -> Float[Array, "1"]: # S = η₂ - η₁ η₁ᵀ S = expectation_matrix - jnp.outer(mu, mu) - S += jnp.eye(m) * self.jitter + S += jnp.eye(m) * jitter # S = sqrt sqrtᵀ sqrt = jnp.linalg.cholesky(S) - μz = mean_function(z, params["mean_function"]) - Kzz = gram(kernel, z, params["kernel"]) - Kzz += I(m) * self.jitter + μz = mean_function(params["mean_function"], z) + Kzz = gram(kernel, params["kernel"], z) + Kzz += I(m) * jitter Lz = Kzz.triangular_lower() qu = dx.MultivariateNormalTri(jnp.atleast_1d(mu.squeeze()), sqrt) @@ -608,6 +612,7 @@ def predict( Returns: Callable[[Float[Array, "N D"]], dx.MultivariateNormalTri]: A function that accepts a set of test points and will return the predictive distribution at those points. """ + jitter = get_defaults()["jitter"] # Unpack variational parameters expectation_vector = params["variational_family"]["moments"][ @@ -632,15 +637,15 @@ def predict( # S = η₂ - η₁ η₁ᵀ S = expectation_matrix - jnp.matmul(mu, mu.T) - S += jnp.eye(m) * self.jitter + S += jnp.eye(m) * jitter # S = sqrt sqrtᵀ sqrt = jnp.linalg.cholesky(S) - Kzz = gram(kernel, z, params["kernel"]) - Kzz += I(m) * self.jitter + Kzz = gram(kernel, params["kernel"], z) + Kzz += I(m) * jitter Lz = Kzz.triangular_lower() - μz = mean_function(z, params["mean_function"]) + μz = mean_function(params["mean_function"], z) def predict_fn( test_inputs: Float[Array, "N D"] @@ -649,9 +654,9 @@ def predict_fn( # Unpack test inputs t, n_test = test_inputs, test_inputs.shape[0] - Ktt = gram(kernel, t, params["kernel"]) - Kzt = cross_covariance(kernel, z, t, params["kernel"]) - μt = mean_function(t, params["mean_function"]) + Ktt = gram(kernel, params["kernel"], t) + Kzt = cross_covariance(kernel, params["kernel"], z, t) + μt = mean_function(params["mean_function"], t) # Lz⁻¹ Kzt Lz_inv_Kzt = jsp.linalg.solve_triangular(Lz, Kzt, lower=True) @@ -671,7 +676,7 @@ def predict_fn( - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt) + jnp.matmul(Ktz_Kzz_inv_sqrt, Ktz_Kzz_inv_sqrt.T) ) - covariance += I(n_test) * self.jitter + covariance += I(n_test) * jitter return dx.MultivariateNormalFullCovariance( jnp.atleast_1d(mean.squeeze()), covariance.to_dense() @@ -690,7 +695,6 @@ class CollapsedVariationalGaussian(AbstractVariationalFamily): inducing_inputs: Float[Array, "M D"] name: str = "Collapsed variational Gaussian" diag: Optional[bool] = False - jitter: Optional[float] = DEFAULT_JITTER def __post_init__(self): """Initialise the variational Gaussian distribution.""" @@ -720,6 +724,7 @@ def predict( Returns: Callable[[Float[Array, "N D"]], dx.MultivariateNormalTri]: A function that accepts a set of test points and will return the predictive distribution at those points. 
""" + jitter = get_defaults()["jitter"] def predict_fn( test_inputs: Float[Array, "N D"] @@ -745,9 +750,9 @@ def predict_fn( gram = kernel.gram cross_covariance = kernel.cross_covariance - Kzx = cross_covariance(kernel, z, x, params["kernel"]) - Kzz = gram(kernel, z, params["kernel"]) - Kzz += I(m) * self.jitter + Kzx = cross_covariance(kernel, params["kernel"], z, x) + Kzz = gram(kernel, params["kernel"], z) + Kzz += I(m) * jitter # Lz Lzᵀ = Kzz Lz = Kzz.triangular_lower() @@ -764,7 +769,7 @@ def predict_fn( # LLᵀ = I + AAᵀ L = jnp.linalg.cholesky(jnp.eye(m) + AAT) - μx = mean_function(x, params["mean_function"]) + μx = mean_function(params["mean_function"], x) diff = y - μx # Lz⁻¹ Kzx (y - μx) @@ -777,9 +782,9 @@ def predict_fn( Lz.T, Lz_inv_Kzx_diff, lower=False ) - Ktt = gram(kernel, t, params["kernel"]) - Kzt = cross_covariance(kernel, z, t, params["kernel"]) - μt = mean_function(t, params["mean_function"]) + Ktt = gram(kernel, params["kernel"], t) + Kzt = cross_covariance(kernel, params["kernel"], z, t) + μt = mean_function(params["mean_function"], t) # Lz⁻¹ Kzt Lz_inv_Kzt = jsp.linalg.solve_triangular(Lz, Kzt, lower=True) @@ -796,7 +801,7 @@ def predict_fn( - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt) + jnp.matmul(L_inv_Lz_inv_Kzt.T, L_inv_Lz_inv_Kzt) ) - covariance += I(n_test) * self.jitter + covariance += I(n_test) * jitter return dx.MultivariateNormalFullCovariance( jnp.atleast_1d(mean.squeeze()), covariance.to_dense() diff --git a/gpjax/variational_inference.py b/gpjax/variational_inference.py index 07ca4a17f..16e7c1409 100644 --- a/gpjax/variational_inference.py +++ b/gpjax/variational_inference.py @@ -22,6 +22,7 @@ from jax import vmap from jaxtyping import Array, Float +from .config import get_defaults from .covariance_operator import I from .gps import AbstractPosterior from .likelihoods import Gaussian @@ -132,7 +133,7 @@ def variational_expectation( # log(p(y|f(x))) link_function = self.likelihood.link_function - log_prob = vmap(lambda f, y: link_function(f, params["likelihood"]).log_prob(y)) + log_prob = vmap(lambda f, y: link_function(params["likelihood"], f).log_prob(y)) # ≈ ∫[log(p(y|f(x))) q(f(x))] df(x) expectation = gauss_hermite_quadrature(log_prob, mean, variance, y=y) @@ -180,7 +181,7 @@ def elbo( gram, cross_covariance = kernel.gram, kernel.cross_covariance m = self.num_inducing - jitter = self.variational_family.jitter + jitter = get_defaults()["jitter"] # Constant for whether or not to negate the elbo for optimisation purposes constant = jnp.array(-1.0) if negative else jnp.array(1.0) @@ -188,11 +189,11 @@ def elbo( def elbo_fn(params: Dict) -> Float[Array, "1"]: noise = params["likelihood"]["obs_noise"] z = params["variational_family"]["inducing_inputs"] - Kzz = gram(kernel, z, params["kernel"]) + Kzz = gram(kernel, params["kernel"], z) Kzz += I(m) * jitter - Kzx = cross_covariance(kernel, z, x, params["kernel"]) - Kxx_diag = vmap(kernel, in_axes=(0, 0, None))(x, x, params["kernel"]) - μx = mean_function(x, params["mean_function"]) + Kzx = cross_covariance(kernel, params["kernel"], z, x) + Kxx_diag = vmap(kernel, in_axes=(0, 0, None))(params["kernel"], x, x) + μx = mean_function(params["mean_function"], x) Lz = Kzz.triangular_lower() diff --git a/tests/test_gps.py b/tests/test_gps.py index 9f23cd74f..28bac4f00 100644 --- a/tests/test_gps.py +++ b/tests/test_gps.py @@ -79,7 +79,7 @@ def test_conjugate_posterior(num_datapoints): assert isinstance(post2, AbstractPrior) parameter_state = initialise(post, key) - params, _, bijectors = parameter_state.unpack() + params, 
*_ = parameter_state.unpack() # Marginal likelihood mll = post.marginal_log_likelihood(train_data=D) @@ -88,7 +88,7 @@ def test_conjugate_posterior(num_datapoints): assert objective_val.shape == () # Prediction - predictive_dist_fn = post(D, params) + predictive_dist_fn = post(params, D) assert isinstance(predictive_dist_fn, tp.Callable) x = jnp.linspace(-3.0, 3.0, num_datapoints).reshape(-1, 1) @@ -130,7 +130,7 @@ def test_nonconjugate_posterior(num_datapoints, likel): assert objective_val.shape == () # Prediction - predictive_dist_fn = post(D, params) + predictive_dist_fn = post(params, D) assert isinstance(predictive_dist_fn, tp.Callable) x = jnp.linspace(-3.0, 3.0, num_datapoints).reshape(-1, 1) diff --git a/tests/test_kernels.py b/tests/test_kernels.py index f1e890abb..1921feb1f 100644 --- a/tests/test_kernels.py +++ b/tests/test_kernels.py @@ -27,10 +27,9 @@ from gpjax.covariance_operator import ( CovarianceOperator, - DenseCovarianceOperator, - DiagonalCovarianceOperator, I, ) + from gpjax.kernels import ( RBF, CombinationKernel, @@ -50,9 +49,8 @@ # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) -"""Default values for tests""" _initialise_key = jr.PRNGKey(123) -_jitter = 100 +_jitter = 1e-6 def test_abstract_kernel(): @@ -111,7 +109,7 @@ def test_gram(kernel: AbstractKernel, dim: int, n: int) -> None: params = kernel._initialise_params(_initialise_key) # Test gram matrix: - Kxx = gram(kernel, x, params) + Kxx = gram(kernel, params, x) assert isinstance(Kxx, CovarianceOperator) assert Kxx.shape == (n, n) @@ -135,7 +133,7 @@ def test_cross_covariance( params = kernel._initialise_params(_initialise_key) # Test cross covariance, Kab: - Kab = cross_cov(kernel, a, b, params) + Kab = cross_cov(kernel, params, a, b) assert isinstance(Kab, jnp.ndarray) assert Kab.shape == (num_a, num_b) @@ -152,7 +150,7 @@ def test_call(kernel: AbstractKernel, dim: int) -> None: params = kernel._initialise_params(_initialise_key) # Test calling gives an autocovariance value of no dimension between the inputs: - kxy = kernel(x, y, params) + kxy = kernel(params, x, y) assert isinstance(kxy, jnp.DeviceArray) assert kxy.shape == () @@ -174,7 +172,7 @@ def test_pos_def( params = {"lengthscale": jnp.array([ell]), "variance": jnp.array([sigma])} # Test gram matrix eigenvalues are positive: - Kxx = gram(kern, x, params) + Kxx = gram(kern, params, x) Kxx += I(n) * _jitter eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) assert (eigen_values > 0.0).all() @@ -220,25 +218,37 @@ def test_polynomial( degree: int, dim: int, variance: float, shift: float, n: int ) -> None: + # Define inputs x = jnp.linspace(0.0, 1.0, n * dim).reshape(n, dim) + # Define kernel kern = Polynomial(degree=degree, active_dims=[i for i in range(dim)]) + + # Unpack kernel computation + gram = kern.gram + + # Check name assert kern.name == f"Polynomial Degree: {degree}" + # Initialise parameters params = kern._initialise_params(_initialise_key) params["shift"] * shift params["variance"] * variance - gram = kern.gram + # Check parameter keys + assert list(params.keys()) == ["shift", "variance"] - # Test positive definiteness: - Kxx = gram(kern, x, params) - Kxx += I(n) * _jitter - eigen_values, _ = jnp.linalg.eigh(Kxx.to_dense()) - assert (eigen_values > 0).all() + # Compute gram matrix + Kxx = gram(kern, params, x) + + # Check shapes assert Kxx.shape[0] == x.shape[0] assert Kxx.shape[0] == Kxx.shape[1] - assert list(params.keys()) == ["shift", "variance"] + + # Test positive definiteness + Kxx += I(n) * 
_jitter + eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) + assert (eigen_values > 0).all() @pytest.mark.parametrize("kernel", [RBF, Matern12, Matern32, Matern52]) @@ -247,20 +257,32 @@ def test_active_dim(kernel: AbstractKernel) -> None: perm_length = 2 dim_pairs = list(permutations(dim_list, r=perm_length)) n_dims = len(dim_list) - key = jr.PRNGKey(123) - X = jr.normal(key, shape=(20, n_dims)) + + # Generate random inputs + x = jr.normal(_initialise_key, shape=(20, n_dims)) for dp in dim_pairs: - Xslice = X[..., dp] + # Take slice of x + slice = x[..., dp] + + # Define kernels ad_kern = kernel(active_dims=dp) manual_kern = kernel(active_dims=[i for i in range(perm_length)]) - ad_default_params = ad_kern._initialise_params(key) - manual_default_params = manual_kern._initialise_params(key) + # Unpack kernel computation + ad_gram = ad_kern.gram + manual_gram = manual_kern.gram + + # Get initial parameters + ad_params = ad_kern._initialise_params(_initialise_key) + manual_params = manual_kern._initialise_params(_initialise_key) - k1 = ad_kern.gram(ad_kern, X, ad_default_params) - k2 = manual_kern.gram(manual_kern, Xslice, manual_default_params) - assert jnp.all(k1.to_dense() == k2.to_dense()) + # Compute gram matrices + ad_Kxx = ad_gram(ad_kern, ad_params, x) + manual_Kxx = manual_gram(manual_kern, manual_params, slice) + + # Test gram matrices are equal + assert jnp.all(ad_Kxx.to_dense() == manual_Kxx.to_dense()) @pytest.mark.parametrize("combination_type", [SumKernel, ProductKernel]) @@ -270,18 +292,44 @@ def test_combination_kernel( combination_type: CombinationKernel, kernel: AbstractKernel, n_kerns: int ) -> None: + # Create inputs n = 20 - kern_list = [kernel() for _ in range(n_kerns)] - c_kernel = combination_type(kernel_set=kern_list) - assert len(c_kernel.kernel_set) == n_kerns - assert len(c_kernel._initialise_params(_initialise_key)) == n_kerns - assert isinstance(c_kernel.kernel_set, list) - assert isinstance(c_kernel.kernel_set[0], AbstractKernel) - assert isinstance(c_kernel._initialise_params(_initialise_key)[0], dict) x = jnp.linspace(0.0, 1.0, num=n).reshape(-1, 1) - Kff = c_kernel.gram(c_kernel, x, c_kernel._initialise_params(_initialise_key)) - assert Kff.shape[0] == Kff.shape[1] - assert Kff.shape[1] == n + + # Create list of kernels + kernel_set = [kernel() for _ in range(n_kerns)] + + # Create combination kernel + combination_kernel = combination_type(kernel_set=kernel_set) + + # Unpack kernel computation + gram = combination_kernel.gram + + # Initialise default parameters + params = combination_kernel._initialise_params(_initialise_key) + + # Check params are a list of dictionaries + assert len(params) == n_kerns + + for p in params: + assert isinstance(p, dict) + + # Check combination kernel set + assert len(combination_kernel.kernel_set) == n_kerns + assert isinstance(combination_kernel.kernel_set, list) + assert isinstance(combination_kernel.kernel_set[0], AbstractKernel) + + # Compute gram matrix + Kxx = gram(combination_kernel, params, x) + + # Check shapes + assert Kxx.shape[0] == Kxx.shape[1] + assert Kxx.shape[1] == n + + # Check positive definiteness + Kxx += I(n) * _jitter + eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) + assert (eigen_values > 0).all() @pytest.mark.parametrize( @@ -291,17 +339,38 @@ def test_combination_kernel( "k2", [RBF(), Matern12(), Matern32(), Matern52(), Polynomial()] ) def test_sum_kern_value(k1: AbstractKernel, k2: AbstractKernel) -> None: + # Create inputs n = 10 - sum_kernel = SumKernel(kernel_set=[k1, k2]) x = 
jnp.linspace(0.0, 1.0, num=n).reshape(-1, 1) - Kff = sum_kernel.gram( - sum_kernel, x, sum_kernel._initialise_params(_initialise_key) - ).to_dense() - Kff_manual = ( - k1.gram(k1, x, k1._initialise_params(_initialise_key)).to_dense() - + k2.gram(k2, x, k2._initialise_params(_initialise_key)).to_dense() - ) - assert jnp.all(Kff == Kff_manual) + + # Create sum kernel + sum_kernel = SumKernel(kernel_set=[k1, k2]) + + # Unpack kernel computation + gram = sum_kernel.gram + + # Initialise default parameters + params = sum_kernel._initialise_params(_initialise_key) + + # Compute gram matrix + Kxx = gram(sum_kernel, params, x) + + # NOW we do the same thing manually and check they are equal: + + # Unpack kernel computation + k1_gram = k1.gram + k2_gram = k2.gram + + # Initialise default parameters + k1_params = k1._initialise_params(_initialise_key) + k2_params = k2._initialise_params(_initialise_key) + + # Compute gram matrix + Kxx_k1 = k1_gram(k1, k1_params, x) + Kxx_k2 = k2_gram(k2, k2_params, x) + + # Check manual and automatic gram matrices are equal + assert jnp.all(Kxx.to_dense() == Kxx_k1.to_dense() + Kxx_k2.to_dense()) @pytest.mark.parametrize( @@ -311,44 +380,81 @@ def test_sum_kern_value(k1: AbstractKernel, k2: AbstractKernel) -> None: "k2", [RBF(), Matern12(), Matern32(), Matern52(), Polynomial()] ) def test_prod_kern_value(k1: AbstractKernel, k2: AbstractKernel) -> None: + + # Create inputs n = 10 - prod_kernel = ProductKernel(kernel_set=[k1, k2]) x = jnp.linspace(0.0, 1.0, num=n).reshape(-1, 1) - Kff = prod_kernel.gram( - prod_kernel, x, prod_kernel._initialise_params(_initialise_key) - ).to_dense() - Kff_manual = ( - k1.gram(k1, x, k1._initialise_params(_initialise_key)).to_dense() - * k2.gram(k2, x, k2._initialise_params(_initialise_key)).to_dense() - ) - assert jnp.all(Kff == Kff_manual) + + # Create product kernel + prod_kernel = ProductKernel(kernel_set=[k1, k2]) + + # Unpack kernel computation + gram = prod_kernel.gram + + # Initialise default parameters + params = prod_kernel._initialise_params(_initialise_key) + + # Compute gram matrix + Kxx = gram(prod_kernel, params, x) + + # NOW we do the same thing manually and check they are equal: + + # Unpack kernel computation + k1_gram = k1.gram + k2_gram = k2.gram + + # Initialise default parameters + k1_params = k1._initialise_params(_initialise_key) + k2_params = k2._initialise_params(_initialise_key) + + # Compute gram matrix + Kxx_k1 = k1_gram(k1, k1_params, x) + Kxx_k2 = k2_gram(k2, k2_params, x) + + # Check manual and automatic gram matrices are equal + assert jnp.all(Kxx.to_dense() == Kxx_k1.to_dense() * Kxx_k2.to_dense()) def test_graph_kernel(): + + # Create a random graph, G, and verice labels, x, n_verticies = 20 n_edges = 40 G = nx.gnm_random_graph(n_verticies, n_edges, seed=123) + x = jnp.arange(n_verticies).reshape(-1, 1) + + # Compute graph laplacian L = nx.laplacian_matrix(G).toarray() + jnp.eye(n_verticies) * 1e-12 + + # Create graph kernel kern = GraphKernel(laplacian=L) assert isinstance(kern, GraphKernel) assert isinstance(kern, _EigenKernel) + assert kern.num_vertex == n_verticies + assert kern.evals.shape == (n_verticies, 1) + assert kern.evecs.shape == (n_verticies, n_verticies) + + # Unpack kernel computation + gram = kern.gram - kern_params = kern._initialise_params(_initialise_key) - assert isinstance(kern_params, dict) - assert list(sorted(list(kern_params.keys()))) == [ + # Initialise default parameters + params = kern._initialise_params(_initialise_key) + assert isinstance(params, dict) + assert 
list(sorted(list(params.keys()))) == [ "lengthscale", "smoothness", "variance", ] - x = jnp.arange(n_verticies).reshape(-1, 1) - Kxx = kern.gram(kern, x, kern._initialise_params(_initialise_key)) + + # Compute gram matrix + Kxx = gram(kern, params, x) assert Kxx.shape == (n_verticies, n_verticies) - eigen_values, _ = jnp.linalg.eigh(Kxx.to_dense() + jnp.eye(n_verticies) * 1e-8) - assert all(eigen_values > 0) - assert kern.num_vertex == n_verticies - assert kern.evals.shape == (n_verticies, 1) - assert kern.evecs.shape == (n_verticies, n_verticies) + # Check positive definiteness + Kxx += I(n_verticies) * _jitter + eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) + assert all(eigen_values > 0) + @pytest.mark.parametrize("kernel", [RBF, Matern12, Matern32, Matern52, Polynomial]) diff --git a/tests/test_likelihoods.py b/tests/test_likelihoods.py index 3c6bf1a65..5e95860e8 100644 --- a/tests/test_likelihoods.py +++ b/tests/test_likelihoods.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== -import typing as tp +from typing import Callable, Dict import distrax as dx import jax.numpy as jnp @@ -21,6 +21,7 @@ import numpy as np import pytest from jax.config import config +from jaxtyping import Array, Float from gpjax.likelihoods import ( AbstractLikelihood, @@ -28,94 +29,171 @@ Conjugate, Gaussian, NonConjugate, + inv_probit ) -from gpjax.parameters import initialise + +from gpjax.types import PRNGKeyType # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) +_initialise_key = jr.PRNGKey(123) + +# Likelihood parameter names to test in initialisation. true_initialisation = { "Gaussian": ["obs_noise"], "Bernoulli": [], } +def test_abstract_likelihood(): + # Test that abstract likelihoods cannot be instantiated. + with pytest.raises(TypeError): + AbstractLikelihood(num_datapoints=123) + + # Create a dummy likelihood class with abstract methods implemented. + class DummyLikelihood(AbstractLikelihood): + + def _initialise_params(self, key: PRNGKeyType) -> Dict: + return {} + + def predict(self, params: Dict, dist: dx.Distribution) -> dx.Distribution: + return dx.Normal(0.0, 1.0) + + def link_function(self) -> Callable: + + def link(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]: + return dx.MultivariateNormalDiag(loc=x) + return link + + # Test that the dummy likelihood can be instantiated. + dummy_likelihood = DummyLikelihood(num_datapoints=123) + assert isinstance(dummy_likelihood, AbstractLikelihood) + -@pytest.mark.parametrize("num_datapoints", [1, 10]) +@pytest.mark.parametrize("n", [1, 10]) @pytest.mark.parametrize("lik", [Gaussian, Bernoulli]) -def test_initialisers(num_datapoints, lik): - key = jr.PRNGKey(123) - lhood = lik(num_datapoints=num_datapoints) - params, _, _ = initialise(lhood, key).unpack() - assert list(params.keys()) == true_initialisation[lhood.name] - assert len(list(params.values())) == len(true_initialisation[lhood.name]) +def test_initialisers(n: int, lik: AbstractLikelihood) -> None: + key = _initialise_key + + # Initialise the likelihood. + likelihood = lik(num_datapoints=n) + + # Get default parameter dictionary. 
+ params = likelihood._initialise_params(key) + + # Check parameter dictionary + assert list(params.keys()) == true_initialisation[likelihood.name] + assert len(list(params.values())) == len(true_initialisation[likelihood.name]) @pytest.mark.parametrize("n", [1, 10]) -def test_predictive_moment(n): - lhood = Bernoulli(num_datapoints=n) - key = jr.PRNGKey(123) - fmean = jr.uniform(key=key, shape=(n,)) * -1 - fvar = jr.uniform(key=key, shape=(n,)) - pred_mom_fn = lhood.predictive_moment_fn - params, _, _ = initialise(lhood, key).unpack() - rv = pred_mom_fn(fmean, fvar, params) - mu = rv.mean() - sigma = rv.variance() - assert isinstance(lhood.predictive_moment_fn, tp.Callable) - assert mu.shape == (n,) - assert sigma.shape == (n,) +def test_bernoulli_predictive_moment(n: int) -> None: + key = _initialise_key + + # Initialise bernoulli likelihood. + likelihood = Bernoulli(num_datapoints=n) + + # Initialise parameters. + params = likelihood._initialise_params(key) + + # Construct latent function mean and variance values + mean_key, var_key = jr.split(key) + fmean = jr.uniform(mean_key, shape=(n, 1)) + fvar = jnp.exp(jr.normal(var_key, shape=(n, 1))) + + # Test predictive moments. + assert isinstance(likelihood.predictive_moment_fn, Callable) + + y = likelihood.predictive_moment_fn(params, fmean, fvar) + y_mean = y.mean() + y_var = y.variance() + + assert y_mean.shape == (n, 1) + assert y_var.shape == (n, 1) @pytest.mark.parametrize("lik", [Gaussian, Bernoulli]) @pytest.mark.parametrize("n", [1, 10]) -def test_link_fns(lik: AbstractLikelihood, n: int): - key = jr.PRNGKey(123) - lhood = lik(num_datapoints=n) - params, _, _ = initialise(lhood, key).unpack() - link_fn = lhood.link_function - assert isinstance(link_fn, tp.Callable) +def test_link_fns(lik: AbstractLikelihood, n: int) -> None: + key = _initialise_key + + # Create test inputs. x = jnp.linspace(-3.0, 3.0).reshape(-1, 1) - l_eval = link_fn(x, params) - assert isinstance(l_eval, dx.Distribution) + # Initialise likelihood. + likelihood = lik(num_datapoints=n) + + # Initialise parameters. + params = likelihood._initialise_params(key) + + # Test likelihood link function. + assert isinstance(likelihood.link_function, Callable) + assert isinstance(likelihood.link_function(params, x), dx.Distribution) @pytest.mark.parametrize("noise", [0.1, 0.5, 1.0]) @pytest.mark.parametrize("n", [1, 2, 10]) -def test_call_gaussian(noise, n): - key = jr.PRNGKey(123) - lhood = Gaussian(num_datapoints=n) - dist = dx.MultivariateNormalFullCovariance(jnp.zeros(n), jnp.eye(n)) +def test_call_gaussian(noise: float, n: int) -> None: + key = _initialise_key + + # Initialise likelihood and parameters. + likelihood = Gaussian(num_datapoints=n) params = {"likelihood": {"obs_noise": noise}} + + # Construct latent function distribution. + latent_mean = jr.uniform(key, shape=(n,)) + latent_sqrt = jr.uniform(key, shape=(n, n)) + latent_cov = jnp.matmul(latent_sqrt, latent_sqrt.T) + latent_dist = dx.MultivariateNormalFullCovariance(latent_mean, latent_cov) + + # Test call method. + pred_dist = likelihood(params, latent_dist) + + # Check that the distribution is a MultivariateNormalFullCovariance. 
+    assert isinstance(pred_dist, dx.MultivariateNormalFullCovariance)

-    l_dist = lhood(dist, params)
-    assert (l_dist.mean() == jnp.zeros(n)).all()
-    noise_mat = jnp.diag(jnp.repeat(noise, n))
-    assert np.allclose(l_dist.scale_tri, jnp.linalg.cholesky(jnp.eye(n) + noise_mat))
-    l_dist = lhood.predict(dist, params)
-    assert (l_dist.mean() == jnp.zeros(n)).all()
-    noise_mat = jnp.diag(jnp.repeat(noise, n))
-    assert np.allclose(l_dist.scale_tri, jnp.linalg.cholesky(jnp.eye(n) + noise_mat))
+    # Check predictive mean and variance.
+    assert (pred_dist.mean() == latent_mean).all()
+    noise_matrix = jnp.eye(n) * noise
+    assert np.allclose(pred_dist.scale_tri, jnp.linalg.cholesky(latent_cov + noise_matrix))

-def test_call_bernoulli():
-    n = 10
-    lhood = Bernoulli(num_datapoints=n)
-    dist = dx.MultivariateNormalFullCovariance(jnp.zeros(n), jnp.eye(n))
+
+@pytest.mark.parametrize("n", [1, 2, 10])
+def test_call_bernoulli(n: int) -> None:
+    key = _initialise_key
+
+    # Initialise likelihood and parameters.
+    likelihood = Bernoulli(num_datapoints=n)
     params = {"likelihood": {}}

-    l_dist = lhood(dist, params)
-    assert (l_dist.mean() == 0.5 * jnp.ones(n)).all()
-    assert (l_dist.variance() == 0.25 * jnp.ones(n)).all()
+    # Construct latent function distribution.
+    latent_mean = jr.uniform(key, shape=(n,))
+    latent_sqrt = jr.uniform(key, shape=(n, n))
+    latent_cov = jnp.matmul(latent_sqrt, latent_sqrt.T)
+    latent_dist = dx.MultivariateNormalFullCovariance(latent_mean, latent_cov)
+
+    # Test call method.
+    pred_dist = likelihood(params, latent_dist)

-    l_dist = lhood.predict(dist, params)
-    assert (l_dist.mean() == 0.5 * jnp.ones(n)).all()
-    assert (l_dist.variance() == 0.25 * jnp.ones(n)).all()
+    # Check that the distribution is a Bernoulli.
+    assert isinstance(pred_dist, dx.Bernoulli)
+
+    # Check predictive mean and variance.
+    p = inv_probit(latent_mean / jnp.sqrt(1.0 + jnp.diagonal(latent_cov)))
+    assert (pred_dist.mean() == p).all()
+    assert (pred_dist.variance() == p * (1.0 - p)).all()

 @pytest.mark.parametrize("lik", [Gaussian, Bernoulli])
-def test_conjugacy(lik):
-    likelihood = lik(num_datapoints=10)
+@pytest.mark.parametrize("n", [1, 2, 10])
+def test_conjugacy(lik: AbstractLikelihood, n: int) -> None:
+    likelihood = lik(num_datapoints=n)
+
+    # Gaussian likelihood is conjugate.
     if isinstance(likelihood, Gaussian):
         assert isinstance(likelihood, Conjugate)
+
+    # Bernoulli likelihood is non-conjugate.
     elif isinstance(likelihood, Bernoulli):
         assert isinstance(likelihood, NonConjugate)
diff --git a/tests/test_mean_functions.py b/tests/test_mean_functions.py
index 7d9e1bbc8..051e3ec00 100644
--- a/tests/test_mean_functions.py
+++ b/tests/test_mean_functions.py
@@ -13,36 +13,57 @@
 # limitations under the License.
 # ==============================================================================

-import typing as tp
+from typing import Dict

 import jax.numpy as jnp
 import jax.random as jr
 import pytest
 from jax.config import config
+from jaxtyping import Array, Float

-from gpjax.mean_functions import Constant, Zero
-from gpjax.parameters import initialise
+from gpjax.mean_functions import AbstractMeanFunction, Constant, Zero
+from gpjax.types import PRNGKeyType

 # Enable Float64 for more stable matrix inversions.
 config.update("jax_enable_x64", True)

+_initialise_key = jr.PRNGKey(123)

-@pytest.mark.parametrize("meanf", [Zero, Constant])
+
+def test_abstract_mean_function() -> None:
+    # Test that the abstract mean function cannot be instantiated.
+    with pytest.raises(TypeError):
+        AbstractMeanFunction()
+
+    # Create a dummy mean function class with abstract methods implemented.
+    class DummyMeanFunction(AbstractMeanFunction):
+        def __call__(self, params: Dict, x: Float[Array, "N D"]) -> Float[Array, "N 1"]:
+            return jnp.ones((x.shape[0], 1))
+
+        def _initialise_params(self, key: PRNGKeyType) -> Dict:
+            return {}
+
+    # Test that the dummy mean function can be instantiated.
+    dummy_mean_function = DummyMeanFunction()
+    assert isinstance(dummy_mean_function, AbstractMeanFunction)
+
+
+@pytest.mark.parametrize("mean_function", [Zero, Constant])
 @pytest.mark.parametrize("dim", [1, 2, 5])
-def test_shape(meanf, dim):
-    key = jr.PRNGKey(123)
-    meanf = meanf(output_dim=dim)
-    x = jnp.linspace(-1.0, 1.0, num=10).reshape(-1, 1)
-    if dim > 1:
-        x = jnp.hstack([x] * dim)
-    params, _, _ = initialise(meanf, key).unpack()
-    mu = meanf(x, params)
-    assert mu.shape[0] == x.shape[0]
-    assert mu.shape[1] == dim
+@pytest.mark.parametrize("n", [1, 2])
+def test_shape(mean_function: AbstractMeanFunction, n: int, dim: int) -> None:
+    key = _initialise_key
+
+    # Create test inputs.
+    x = jnp.linspace(-1.0, 1.0, num=n * dim).reshape(n, dim)
+
+    # Initialise mean function.
+    mf = mean_function(output_dim=dim)

-@pytest.mark.parametrize("meanf", [Zero, Constant])
-def test_initialisers(meanf):
-    key = jr.PRNGKey(123)
-    params, _, _ = initialise(meanf(), key).unpack()
-    assert isinstance(params, tp.Dict)
+    # Initialise parameters.
+    params = mf._initialise_params(key)
+    assert isinstance(params, dict)
+
+    # Test shape of mean function.
+    mu = mf(params, x)
+    assert mu.shape[0] == x.shape[0]
+    assert mu.shape[1] == dim
\ No newline at end of file
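
# -----------------------------------------------------------------------------
# Editor's usage sketch (illustrative only, not part of the patch above).
# The rewritten tests exercise the params-first calling convention: parameter
# dictionaries are passed as the first argument to mean functions and
# likelihoods, and a likelihood is invoked as likelihood(params, latent_dist).
# The snippet below mirrors the calls made in the tests; treat it as a sketch
# of that convention rather than documented GPJax API, and note that this
# standalone script itself does not appear in the diff.

import distrax as dx
import jax.numpy as jnp
import jax.random as jr

from gpjax.likelihoods import Gaussian
from gpjax.mean_functions import Constant

key = jr.PRNGKey(123)
n = 10

# Mean functions take the parameter dictionary first, then the inputs,
# as asserted in test_shape above.
mean_function = Constant(output_dim=1)
mf_params = mean_function._initialise_params(key)
x = jnp.linspace(-1.0, 1.0, num=n).reshape(-1, 1)
mu = mean_function(mf_params, x)  # expected shape (n, 1)

# Likelihoods are called with a nested parameter dictionary followed by the
# latent Gaussian distribution, as in test_call_gaussian above.
likelihood = Gaussian(num_datapoints=n)
params = {"likelihood": {"obs_noise": 0.5}}
latent_dist = dx.MultivariateNormalFullCovariance(jnp.zeros(n), jnp.eye(n))
pred_dist = likelihood(params, latent_dist)  # a dx.MultivariateNormalFullCovariance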