Commit 0ff8847: Move params to the first slot of each function, class, etc.

daniel-dodd committed Nov 7, 2022
1 parent 8a2cb02
Showing 10 changed files with 577 additions and 290 deletions.
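The change throughout is a calling-convention one: every function that previously took the `params` dictionary after the data now takes it first. A minimal sketch of the new ordering, using a toy RBF kernel rather than GPJax's actual API (all names below are illustrative):

```python
import jax.numpy as jnp

# Toy RBF kernel standing in for a GPJax kernel computation; the function
# and parameter names are illustrative, not the library's API.
def rbf_gram(params, x):
    ell = params["lengthscale"]
    sq_dists = jnp.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-0.5 * sq_dists / ell**2)

params = {"lengthscale": jnp.array(1.0)}
x = jnp.linspace(0.0, 1.0, 5).reshape(-1, 1)

# Before this commit the data came first, e.g. rbf_gram(x, params);
# after it, the params dictionary occupies the first slot.
Kxx = rbf_gram(params, x)  # (5, 5) Gram matrix
```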
79 changes: 56 additions & 23 deletions gpjax/gps.py
@@ -218,8 +218,8 @@ def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.MultivariateNormalTri:
t = test_inputs
n_test = test_inputs.shape[0]

- μt = mean_function(t, params["mean_function"])
- Ktt = gram(kernel, t, params["kernel"])
+ μt = mean_function(params["mean_function"], t)
+ Ktt = gram(kernel, params["kernel"], t)
Ktt += I(n_test) * jitter
Lt = Ktt.triangular_lower()

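For context, the quantities in the hunk above assemble the prior predictive N(μt, Ktt). A dense sketch with plain JAX arrays standing in for GPJax's covariance operators (the function and argument names here are stand-ins, not library code):

```python
import distrax as dx
import jax.numpy as jnp

# Dense stand-in for the prior predictive assembled above: jitter the Gram
# matrix, take its lower Cholesky factor, and wrap both in a distrax MVN.
def prior_predictive(mu_t, Ktt, jitter=1e-6):
    n_test = mu_t.shape[0]
    Lt = jnp.linalg.cholesky(Ktt + jitter * jnp.eye(n_test))
    return dx.MultivariateNormalTri(loc=mu_t.squeeze(), scale_tri=Lt)
```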
@@ -339,8 +339,8 @@ class ConjugatePosterior(AbstractPosterior):
name: Optional[str] = "Conjugate posterior"

def predict(
- self, train_data: Dataset, params: Dict
- ) -> Callable[[Float[Array, "N D"]], dx.MultivariateNormalTri]:
+ self, params: Dict, train_data: Dataset,
+ ) -> Callable[[Float[Array, "N D"]], dx.MultivariateNormalFullCovariance]:
"""Conditional on a training data set, compute the GP's posterior
predictive distribution for a given set of parameters. The returned
function can be evaluated at a set of test inputs to compute the
@@ -372,15 +372,15 @@ def predict(
>>> xtest = jnp.linspace(0, 1).reshape(-1, 1)
>>>
>>> params = gpx.initialise(posterior)
- >>> predictive_dist = posterior.predict(gpx.Dataset(X=xtrain, y=ytrain), params)
+ >>> predictive_dist = posterior.predict(params, gpx.Dataset(X=xtrain, y=ytrain))
>>> predictive_dist(xtest)
Args:
- train_data (Dataset): A `gpx.Dataset` object that contains the input and output data used for training dataset.
params (Dict): A dictionary of parameters that should be used to compute the posterior.
+ train_data (Dataset): A `gpx.Dataset` object that contains the input and output data used for training.
Returns:
- Callable[[Float[Array, "N D"]], dx.MultivariateNormalTri]: A function that accepts an input array and returns the predictive distribution as a `dx.MultivariateNormalTri`.
+ Callable[[Float[Array, "N D"]], dx.MultivariateNormalFullCovariance]: A function that accepts an input array and returns the predictive distribution as a `dx.MultivariateNormalFullCovariance`.
"""
jitter = get_defaults()["jitter"]

@@ -397,24 +397,32 @@ def predict(

# Observation noise σ²
obs_noise = params["likelihood"]["obs_noise"]
- μx = mean_function(x, params["mean_function"])
+ μx = mean_function(params["mean_function"], x)

# Precompute Gram matrix, Kxx, at training inputs, x
- Kxx = gram(kernel, x, params["kernel"])
+ Kxx = gram(kernel, params["kernel"], x)
Kxx += I(n) * jitter

# Σ = Kxx + Iσ²
Sigma = Kxx + I(n) * obs_noise

def predict(test_inputs: Float[Array, "N D"]) -> dx.Distribution:
"""Compute the predictive distribution at a set of test inputs.
Args:
test_inputs (Float[Array, "N D"]): A Jax array of test inputs.
Returns:
dx.Distribution: A `dx.MultivariateNormalFullCovariance` object that represents the predictive distribution.
"""

# Unpack test inputs
t = test_inputs
n_test = test_inputs.shape[0]

- μt = mean_function(t, params["mean_function"])
- Ktt = gram(kernel, t, params["kernel"])
- Kxt = cross_covariance(kernel, x, t, params["kernel"])
+ μt = mean_function(params["mean_function"], t)
+ Ktt = gram(kernel, params["kernel"], t)
+ Kxt = cross_covariance(kernel, params["kernel"], x, t)

# TODO: Investigate lower triangular solves for general covariance operators
# this is more efficient than the full solve for dense matrices in the current implementation.
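The closure above computes the standard conjugate-GP predictive moments, μ* = μt + Kxtᵀ Σ⁻¹ (y − μx) and Σ* = Ktt − Kxtᵀ Σ⁻¹ Kxt. A dense sketch under those definitions (a stand-in, not the operator-aware library code):

```python
import jax.numpy as jnp
import jax.scipy as jsp

# Dense conjugate predictive moments; Sigma = Kxx + I·σ² as formed above.
def conjugate_moments(Sigma, Kxt, Ktt, mu_x, mu_t, y):
    L = jnp.linalg.cholesky(Sigma)                        # Σ = L Lᵀ
    A = jsp.linalg.solve_triangular(L, Kxt, lower=True)   # L⁻¹ Kxt
    b = jsp.linalg.solve_triangular(L, y - mu_x, lower=True)
    mean = mu_t + A.T @ b   # μt + Kxtᵀ Σ⁻¹ (y − μx)
    cov = Ktt - A.T @ A     # Ktt − Kxtᵀ Σ⁻¹ Kxt
    return mean, cov
```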
@@ -514,15 +522,24 @@ def marginal_log_likelihood(
def mll(
params: Dict,
):
"""Compute the marginal log-likelihood of the Gaussian process.
Args:
params (Dict): The model's parameters.
Returns:
Float[Array, "1"]: The marginal log-likelihood.
"""

# Observation noise σ²
obs_noise = params["likelihood"]["obs_noise"]
- μx = mean_function(x, params["mean_function"])
+ μx = mean_function(params["mean_function"], x)

# TODO: This implementation does not take advantage of the covariance operator structure.
# Future work concerns implementation of a custom Gaussian distribution / measure object that accepts a covariance operator.

# Σ = (Kxx + Iσ²) = LLᵀ
- Kxx = gram(kernel, x, params["kernel"])
+ Kxx = gram(kernel, params["kernel"], x)
Kxx += I(n) * jitter
Sigma = Kxx + I(n) * obs_noise
L = Sigma.triangular_lower()
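With L the lower Cholesky factor of Σ, the quantity `mll` returns is log N(y; μx, Σ) = −½ (y − μx)ᵀ Σ⁻¹ (y − μx) − log det L − (n/2) log 2π, where log det L = ½ log det Σ. A dense sketch of that computation (illustrative, not the covariance-operator implementation):

```python
import jax.numpy as jnp
import jax.scipy as jsp

# Dense conjugate marginal log-likelihood from the Cholesky factor of Σ.
def conjugate_mll(Sigma, mu_x, y):
    n = y.shape[0]
    L = jnp.linalg.cholesky(Sigma)
    r = jsp.linalg.solve_triangular(L, y - mu_x, lower=True)  # L⁻¹ (y − μx)
    quad = 0.5 * jnp.sum(r**2)               # ½ (y − μx)ᵀ Σ⁻¹ (y − μx)
    log_det = jnp.sum(jnp.log(jnp.diag(L)))  # log det L = ½ log det Σ
    return -(quad + log_det + 0.5 * n * jnp.log(2.0 * jnp.pi))
```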
@@ -583,7 +600,7 @@ def _initialise_params(self, key: PRNGKeyType) -> Dict:
return parameters

def predict(
- self, train_data: Dataset, params: Dict
+ self, params: Dict, train_data: Dataset,
) -> Callable[[Float[Array, "N D"]], dx.Distribution]:
"""
Conditional on a set of training data, compute the GP's posterior
@@ -594,8 +611,8 @@ def predict(
transformed through the likelihood function's inverse link function.
Args:
- train_data (Dataset): A `gpx.Dataset` object that contains the input and output data used for training dataset.
params (Dict): A dictionary of parameters that should be used to compute the posterior.
+ train_data (Dataset): A `gpx.Dataset` object that contains the input and output data used for training.
Returns:
tp.Callable[[Array], dx.Distribution]: A function that accepts an input array and returns the predictive distribution as a `dx.Distribution`.
@@ -614,19 +631,27 @@
cross_covariance = kernel.cross_covariance

# Precompute lower triangular of Gram matrix, Lx, at training inputs, x
- Kxx = gram(kernel, x, params["kernel"])
+ Kxx = gram(kernel, params["kernel"], x)
Kxx += I(n) * jitter
Lx = Kxx.triangular_lower()

def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution:
"""Predictive distribution of the latent function for a given set of test inputs.
Args:
test_inputs (Float[Array, "N D"]): A set of test inputs.
Returns:
dx.Distribution: The predictive distribution of the latent function.
"""

# Unpack test inputs
t, n_test = test_inputs, test_inputs.shape[0]

# Compute terms of the posterior predictive distribution
- Ktx = cross_covariance(kernel, t, x, params["kernel"])
- Ktt = gram(kernel, t, params["kernel"]) + I(n_test) * jitter
- μt = mean_function(t, params["mean_function"])
+ Ktx = cross_covariance(kernel, params["kernel"], t, x)
+ Ktt = gram(kernel, params["kernel"], t) + I(n_test) * jitter
+ μt = mean_function(params["mean_function"], t)

# Lx⁻¹ Kxt
Lx_inv_Kxt = jsp.linalg.solve_triangular(Lx, Ktx.T, lower=True)
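Under the whitened parameterisation f(x) = μx + Lx wx used in this file, the latent predictive mean reduces to μt + (Lx⁻¹Kxt)ᵀ wx, since Kxx⁻¹ (f(x) − μx) = Lx⁻ᵀ wx, and the covariance to Ktt − (Lx⁻¹Kxt)ᵀ (Lx⁻¹Kxt). A dense sketch with stand-in names:

```python
import jax.numpy as jnp
import jax.scipy as jsp

# Dense latent predictive under the whitened parameterisation f(x) = μx + Lx wx.
def latent_moments(Lx, Kxt, Ktt, mu_t, wx):
    A = jsp.linalg.solve_triangular(Lx, Kxt, lower=True)  # Lx⁻¹ Kxt
    mean = mu_t + A.T @ wx  # μt + Ktx Kxx⁻¹ (f(x) − μx)
    cov = Ktt - A.T @ A     # Ktt − Ktx Kxx⁻¹ Kxt
    return mean, cov
```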
@@ -689,14 +714,22 @@ def marginal_log_likelihood(
constant = jnp.array(-1.0) if negative else jnp.array(1.0)

def mll(params: Dict):
"""Compute the marginal log-likelihood of the model.
Args:
params (Dict): A dictionary of parameters that should be used to compute the marginal log-likelihood.
Returns:
Float[Array, "1"]: The marginal log-likelihood of the model.
"""

# Compute lower triangular of the kernel Gram matrix
- Kxx = gram(kernel, x, params["kernel"])
+ Kxx = gram(kernel, params["kernel"], x)
Kxx += I(n) * jitter
Lx = Kxx.triangular_lower()

# Compute the prior mean function
- μx = mean_function(x, params["mean_function"])
+ μx = mean_function(params["mean_function"], x)

# Whitened function values, wx, corresponding to the inputs, x
wx = params["latent"]
Expand All @@ -705,7 +738,7 @@ def mll(params: Dict):
fx = μx + jnp.matmul(Lx, wx)

# p(y | f(x), θ), where θ are the model hyperparameters
- likelihood = link_function(fx, params)
+ likelihood = link_function(params, fx)

# log p(θ)
log_prior_density = evaluate_priors(params, priors)
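Assembled, the non-conjugate objective de-whitens the latent values, scores the data through the likelihood's link function, and adds the log-priors. A toy version assuming a Bernoulli likelihood with a sigmoid inverse link and a standard-normal prior on wx (both are assumptions for illustration; GPJax's `link_function` and `evaluate_priors` are more general):

```python
import jax
import jax.numpy as jnp
import jax.scipy.stats as stats

# Toy whitened non-conjugate objective; Lx is the lower Cholesky factor of Kxx.
def whitened_objective(params, Lx, mu_x, y):
    wx = params["latent"]
    fx = mu_x + Lx @ wx                              # f(x) = μx + Lx wx
    p = jax.nn.sigmoid(fx)                           # inverse link (assumed)
    log_lik = jnp.sum(stats.bernoulli.logpmf(y, p))  # log p(y | f(x))
    log_prior = jnp.sum(stats.norm.logpdf(wx))       # log N(wx; 0, I)
    return log_lik + log_prior
```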
