Tidy up RFF and extra tests/formatting #243

Merged · 1 commit · Apr 30, 2023
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
@@ -6,7 +6,6 @@ repos:
rev: 23.3.0
hooks:
- id: black
language_version: python3.8
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
7 changes: 6 additions & 1 deletion docs/conf.py
@@ -59,9 +59,14 @@ def find_version(*file_paths):
copyright = "2021, Thomas Pinder"
author = "Thomas Pinder"

from os.path import (
dirname,
join,
pardir,
)

# The full version, including alpha/beta/rc tags
import sys
from os.path import dirname, join, pardir

sys.path.insert(0, join(dirname(__file__), pardir))

8 changes: 7 additions & 1 deletion docs/conf_sphinx_patch.py
@@ -1,6 +1,12 @@
# This file is credited to the Flax authors.

from typing import Any, Dict, List, Set, Tuple
from typing import (
Any,
Dict,
List,
Set,
Tuple,
)

import sphinx.ext.autodoc
import sphinx.ext.autosummary.generate as ag
2 changes: 1 addition & 1 deletion gpjax/__init__.py
@@ -23,7 +23,7 @@
from gpjax.likelihoods import (
Bernoulli,
Gaussian,
Poisson
Poisson,
)
from gpjax.mean_functions import (
Constant,
67 changes: 45 additions & 22 deletions gpjax/gps.py
@@ -19,7 +19,6 @@
from beartype.typing import (
Any,
Callable,
Dict,
Optional,
)
import jax.numpy as jnp
@@ -269,17 +268,18 @@ def sample_approx(
FunctionalSample: A function representing an approximate sample from the Gaussian
process prior.
"""
if (not isinstance(num_features, int)) or num_features <= 0:
raise ValueError("num_features must be a positive integer")

if (not isinstance(num_samples, int)) or num_samples <= 0:
raise ValueError("num_samples must be a positive integer")

approximate_kernel = RFF(base_kernel=self.kernel, num_basis_fns=num_features)
# sample Fourier features
fourier_feature_fn = _build_fourier_features_fn(self, num_features, key)

# sample Fourier weights
feature_weights = normal(key, [num_samples, 2 * num_features]) # [B, L]

def sample_fn(test_inputs: Float[Array, "N D"]) -> Float[Array, "N B"]:
feature_evals = approximate_kernel.compute_features(x=test_inputs)
feature_evals *= jnp.sqrt(self.kernel.variance / num_features)
feature_evals = fourier_feature_fn(test_inputs) # [N, L]
evaluated_sample = jnp.inner(feature_evals, feature_weights) # [N, B]
return self.mean_function(test_inputs) + evaluated_sample

@@ -501,24 +501,13 @@ def sample_approx(
FunctionalSample: A function representing an approximate sample from the Gaussian
process prior.
"""
if (not isinstance(num_features, int)) or num_features <= 0:
raise ValueError("num_features must be a positive integer")
if (not isinstance(num_samples, int)) or num_samples <= 0:
raise ValueError("num_samples must be a positive integer")

# Approximate kernel with feature decomposition
approximate_kernel = RFF(
base_kernel=self.prior.kernel, num_basis_fns=num_features
)

def eval_fourier_features(
test_inputs: Float[Array, "N D"]
) -> Float[Array, "N L"]:
Phi = approximate_kernel.compute_features(x=test_inputs)
Phi *= jnp.sqrt(self.prior.kernel.variance / num_features)
return Phi
# sample Fourier features
fourier_feature_fn = _build_fourier_features_fn(self.prior, num_features, key)

# sample weights for Fourier features
# sample Fourier weights
fourier_weights = normal(key, [num_samples, 2 * num_features]) # [B, L]

# sample weights v for canonical features
@@ -531,13 +520,13 @@ def eval_fourier_features(
key, [train_data.n, num_samples]
) # [N, B]
y = train_data.y - self.prior.mean_function(train_data.X) # account for mean
Phi = eval_fourier_features(train_data.X)
Phi = fourier_feature_fn(train_data.X)
canonical_weights = Sigma.solve(
y + eps - jnp.inner(Phi, fourier_weights)
) # [N, B]

def sample_fn(test_inputs: Float[Array, "n D"]) -> Float[Array, "n B"]:
fourier_features = eval_fourier_features(test_inputs)
fourier_features = fourier_feature_fn(test_inputs) # [n, L]
weight_space_contribution = jnp.inner(
fourier_features, fourier_weights
) # [n, B]
@@ -634,6 +623,9 @@ def predict(
return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance)


#######################
# Utils
#######################
def construct_posterior(
prior: Prior, likelihood: AbstractLikelihood
) -> AbstractPosterior:
@@ -658,6 +650,37 @@ def construct_posterior(
return NonConjugatePosterior(prior=prior, likelihood=likelihood)


def _build_fourier_features_fn(
prior: Prior, num_features: int, key: KeyArray
) -> Callable[[Float[Array, "N D"]], Float[Array, "N L"]]:
"""Return a function that evaluates features sampled from the Fourier feature
decomposition of the prior's kernel.

Args:
prior (Prior): The Prior distribution.
num_features (int): The number of feature functions to be sampled.
key (KeyArray): The random seed used.

Returns:
Callable: A callable function evaluating the sampled feature functions.
"""
if (not isinstance(num_features, int)) or num_features <= 0:
raise ValueError("num_features must be a positive integer")

# Approximate kernel with feature decomposition
approximate_kernel = RFF(
base_kernel=prior.kernel, num_basis_fns=num_features, key=key
)

def eval_fourier_features(test_inputs: Float[Array, "N D"]) -> Float[Array, "N L"]:
Phi = approximate_kernel.compute_features(x=test_inputs)
Phi *= jnp.sqrt(prior.kernel.variance / num_features)
return Phi

return eval_fourier_features


__all__ = [
"AbstractPrior",
"Prior",
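For context, the refactored path can be exercised end-to-end. Below is a minimal sketch of approximate prior sampling; it mirrors the calls in the updated tests (the positional signature sample_approx(num_samples, key, num_features) is taken from the test suite, and the constructors are those imported there):

import jax.numpy as jnp
import jax.random as jr

from gpjax.gps import Prior
from gpjax.kernels import RBF
from gpjax.mean_functions import Zero

key = jr.PRNGKey(123)

# Two-dimensional inputs with an ARD lengthscale, as in the updated tests.
kern = RBF(lengthscale=jnp.array([5.0, 1.0]), variance=0.1)
prior = Prior(kernel=kern, mean_function=Zero())

# Draw 10 approximate prior sample paths built from 100 Fourier features;
# the returned function evaluates all paths at once.
sample_fn = prior.sample_approx(10, key, 100)

x = jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(50, 2))
evals = sample_fn(x)  # shape [50, 10]: one column per sample path

Each path is the mean function evaluated at x plus the sampled Fourier features weighted by the standard-normal weights drawn inside sample_approx.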
27 changes: 10 additions & 17 deletions tests/test_gps.py
@@ -21,7 +21,6 @@
ValidationErrors = ValueError

from dataclasses import is_dataclass
import shutil
from typing import Callable

from jax.config import config
@@ -31,11 +30,6 @@
import pytest
import tensorflow_probability.substrates.jax.distributions as tfd

from gpjax.base import (
load_tree,
save_tree,
)

# from gpjax.dataset import Dataset
from gpjax.dataset import Dataset
from gpjax.gaussian_distribution import GaussianDistribution
@@ -50,15 +44,13 @@
from gpjax.kernels import (
RBF,
AbstractKernel,
Matern12,
Matern32,
Matern52,
)
from gpjax.likelihoods import (
AbstractLikelihood,
Bernoulli,
Gaussian,
Poisson
Poisson,
)
from gpjax.mean_functions import (
AbstractMeanFunction,
@@ -272,7 +264,7 @@ def test_posterior_construct(
@pytest.mark.parametrize("kernel", [RBF, Matern52])
@pytest.mark.parametrize("mean_function", [Zero(), Constant()])
def test_prior_sample_approx(num_datapoints, kernel, mean_function):
kern = kernel(lengthscale=5.0, variance=0.1)
kern = kernel(lengthscale=jnp.array([5.0, 1.0]), variance=0.1)
p = Prior(kernel=kern, mean_function=mean_function)
key = jr.PRNGKey(123)

@@ -292,7 +284,7 @@ def test_prior_sample_approx(num_datapoints, kernel, mean_function):
sampled_fn = p.sample_approx(1, key, 100)
assert isinstance(sampled_fn, Callable) # check type

x = jnp.linspace(-3.0, 3.0, num_datapoints).reshape(-1, 1)
x = jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(num_datapoints, 2))
evals = sampled_fn(x)
assert evals.shape == (num_datapoints, 1) # check shape

@@ -325,16 +317,17 @@ def test_prior_sample_approx(num_datapoints, kernel, mean_function):
@pytest.mark.parametrize("kernel", [RBF, Matern52])
@pytest.mark.parametrize("mean_function", [Zero(), Constant()])
def test_conjugate_posterior_sample_approx(num_datapoints, kernel, mean_function):
kern = kernel(lengthscale=5.0, variance=0.1)
kern = kernel(lengthscale=jnp.array([5.0, 1.0]), variance=0.1)
p = Prior(kernel=kern, mean_function=mean_function) * Gaussian(
num_datapoints=num_datapoints
)
key = jr.PRNGKey(123)
x = jnp.sort(
jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(num_datapoints, 1)),
axis=0,
)

x = jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(num_datapoints, 2))
y = (
jnp.mean(jnp.sin(x), 1, keepdims=True)
+ jr.normal(key=key, shape=(num_datapoints, 1)) * 0.1
)
y = jnp.sin(x) + jr.normal(key=key, shape=x.shape) * 0.1
D = Dataset(X=x, y=y)

with pytest.raises(ValueError):
@@ -353,7 +346,7 @@ def test_conjugate_posterior_sample_approx(num_datapoints, kernel, mean_function
sampled_fn = p.sample_approx(1, D, key, 100)
assert isinstance(sampled_fn, Callable) # check type

x = jnp.linspace(-3.0, 3.0, num_datapoints).reshape(-1, 1)
x = jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(num_datapoints, 2))
evals = sampled_fn(x)
assert evals.shape == (num_datapoints, 1) # check shape

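The conjugate case follows the same pattern: a prior path drawn from the Fourier features plus a canonical correction computed from the training data (the Sigma.solve term in gps.py above). A sketch under the same assumptions, with shapes copied from test_conjugate_posterior_sample_approx:

import jax.numpy as jnp
import jax.random as jr

from gpjax.dataset import Dataset
from gpjax.gps import Prior
from gpjax.kernels import RBF
from gpjax.likelihoods import Gaussian
from gpjax.mean_functions import Zero

key = jr.PRNGKey(123)
n = 50

prior = Prior(
    kernel=RBF(lengthscale=jnp.array([5.0, 1.0]), variance=0.1),
    mean_function=Zero(),
)
posterior = prior * Gaussian(num_datapoints=n)

x = jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(n, 2))
# Noisy observations of a 2-D test function, averaged over input dimensions.
y = jnp.mean(jnp.sin(x), 1, keepdims=True) + jr.normal(key=key, shape=(n, 1)) * 0.1
D = Dataset(X=x, y=y)

# Five approximate posterior sample paths built from 100 Fourier features.
sample_fn = posterior.sample_approx(5, D, key, 100)
evals = sample_fn(x)  # shape [50, 5]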
1 change: 0 additions & 1 deletion tests/test_likelihoods.py
@@ -200,7 +200,6 @@ class TestPoisson(BaseTestLikelihood):
def _test_call_check(
likelihood: AbstractLikelihood, latent_mean, latent_cov, latent_dist
):

# Test call method.
pred_dist = likelihood(latent_dist)

Expand Down