Move params to the first slot of each function, class, etc. #139

Merged
merged 2 commits on Nov 7, 2022
8 changes: 4 additions & 4 deletions README.md
@@ -89,7 +89,7 @@ posterior = prior * likelihood
Equipped with the posterior, we seek to learn the model's hyperparameters through gradient-based optimisation of the marginal log-likelihood. We do this below, adding Jax's [just-in-time (JIT)](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html) compilation to accelerate training.

```python
mll = jit(posterior.marginal_log_likelihood(training, negative=True))
mll = jit(posterior.marginal_log_likelihood(D, negative=True))
```

For purposes of optimisation, we'll use optax's Adam.
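A minimal sketch of this step (the exact call sits outside this hunk; `parameter_state` is assumed to come from `gpx.initialise`, and the `optax_optim` and `n_iters` argument names follow the example notebooks rather than being confirmed here):

```python
import optax as ox

optimiser = ox.adam(learning_rate=0.01)

# Minimise the negative marginal log-likelihood with Adam.
inference_state = gpx.fit(
    objective=mll,
    parameter_state=parameter_state,
    optax_optim=optimiser,
    n_iters=500,
)
```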
@@ -117,11 +117,11 @@ Using our learned hyperparameters, we can obtain the posterior distribution of t
learned_params, _ = inference_state.unpack()
xtest = jnp.linspace(-3., 3., 100).reshape(-1, 1)

latent_distribution = posterior(training, learned_params)(xtest)
predictive_distribution = likelihood(latent_distribution, params)
latent_distribution = posterior(learned_params, D)(xtest)
predictive_distribution = likelihood(params, latent_distribution)

predictive_mean = predictive_distribution.mean
predictive_cov = predictive_distribution.covariance_matrix
predictive_cov = predictive_distribution.covariance
```

# Installation
2 changes: 1 addition & 1 deletion examples/barycentres.pct.py
@@ -117,7 +117,7 @@ def fit_gp(x: jnp.DeviceArray, y: jnp.DeviceArray) -> dx.MultivariateNormalTri:
)

learned_params, training_history = inference_state.unpack()
return likelihood(posterior(D, learned_params)(xtest), learned_params)
return likelihood(learned_params, posterior(learned_params, D)(xtest))


posterior_preds = [fit_gp(x, i) for i in ys]
22 changes: 11 additions & 11 deletions examples/classification.pct.py
@@ -9,7 +9,7 @@
# format_version: '1.3'
# jupytext_version: 1.11.2
# kernelspec:
# display_name: Python 3.9.7 ('gpjax')
# display_name: base
# language: python
# name: python3
# ---
@@ -98,9 +98,9 @@
# From which we can make predictions at novel inputs, as illustrated below.

# %%
map_latent_dist = posterior(D, map_estimate)(xtest)
map_latent_dist = posterior(map_estimate, D)(xtest)

predictive_dist = likelihood(map_latent_dist, map_estimate)
predictive_dist = likelihood(map_estimate, map_latent_dist)

predictive_mean = predictive_dist.mean()
predictive_std = predictive_dist.stddev()
@@ -159,11 +159,11 @@
# that we identify as a Gaussian distribution, $p(\boldsymbol{f}| \mathcal{D}) \approx q(\boldsymbol{f}) := \mathcal{N}(\hat{\boldsymbol{f}}, [-\nabla^2 \tilde{p}(\boldsymbol{y}|\boldsymbol{f})|_{\hat{\boldsymbol{f}}} ]^{-1} )$. Since the negative Hessian is positive definite, we can use the Cholesky decomposition to obtain the covariance matrix of the Laplace approximation at the datapoints below.

# %%
gram, cross_covariance = (prior.kernel.gram, prior.kernel.cross_covariance)
gram, cross_covariance = (kernel.gram, kernel.cross_covariance)
jitter = 1e-6

# Compute (latent) function value map estimates at training points:
Kxx = gram(prior.kernel, x, map_estimate["kernel"])
Kxx = gram(kernel, map_estimate["kernel"], x)
Kxx += I(D.n) * jitter
Lx = Kxx.triangular_lower()
f_hat = jnp.matmul(Lx, map_estimate["latent"])
@@ -193,10 +193,10 @@
# %%
def construct_laplace(test_inputs: Float[Array, "N D"]) -> dx.MultivariateNormalTri:

map_latent_dist = posterior(D, map_estimate)(test_inputs)
map_latent_dist = posterior(map_estimate, D)(test_inputs)

Kxt = cross_covariance(prior.kernel, x, test_inputs, map_estimate["kernel"])
Kxx = gram(prior.kernel, x, map_estimate["kernel"])
Kxt = cross_covariance(kernel, map_estimate["kernel"], x, test_inputs)
Kxx = gram(kernel, map_estimate["kernel"], x)
Kxx += I(D.n) * jitter
Lx = Kxx.triangular_lower()

@@ -219,7 +219,7 @@ def construct_laplace(test_inputs: Float[Array, "N D"]) -> dx.MultivariateNormal
# From this we can construct the predictive distribution at the test points.
# %%
laplace_latent_dist = construct_laplace(xtest)
predictive_dist = likelihood(laplace_latent_dist, map_estimate)
predictive_dist = likelihood(map_estimate, laplace_latent_dist)

predictive_mean = predictive_dist.mean()
predictive_std = predictive_dist.stddev()
@@ -338,8 +338,8 @@ def one_step(state, rng_key):
ps["latent"] = states.position["latent"][i, :, :]
ps = gpx.constrain(ps, bijectors)

latent_dist = posterior(D, ps)(xtest)
predictive_dist = likelihood(latent_dist, ps)
latent_dist = posterior(ps, D)(xtest)
predictive_dist = likelihood(ps, latent_dist)
samples.append(predictive_dist.sample(seed=key, sample_shape=(10,)))

samples = jnp.vstack(samples)
8 changes: 4 additions & 4 deletions examples/collapsed_vi.pct.py
@@ -8,7 +8,7 @@
# format_version: '1.3'
# jupytext_version: 1.11.2
# kernelspec:
# display_name: Python 3.9.7 ('gpjax')
# display_name: base
# language: python
# name: python3
# ---
@@ -101,7 +101,7 @@

negative_elbo = jit(sgpr.elbo(D, negative=True))

optimiser = ox.adam(learning_rate=0.005)
optimiser = ox.adam(learning_rate=5e-3)

inference_state = gpx.fit(
objective=negative_elbo,
@@ -116,8 +116,8 @@
# We show predictions of our model with the learned inducing points overlaid in grey.

# %%
latent_dist = q.predict(D, learned_params)(xtest)
predictive_dist = likelihood(latent_dist, learned_params)
latent_dist = q(learned_params, D)(xtest)
predictive_dist = likelihood(learned_params, latent_dist)

samples = latent_dist.sample(seed=key, sample_shape=(20,))

6 changes: 3 additions & 3 deletions examples/graph_kernels.pct.py
@@ -9,7 +9,7 @@
# format_version: '1.3'
# jupytext_version: 1.11.2
# kernelspec:
# display_name: Python 3.9.7 ('gpjax')
# display_name: base
# language: python
# name: python3
# ---
@@ -143,8 +143,8 @@

# %%
initial_params = parameter_state.params
initial_dist = likelihood(posterior(D, initial_params)(x), initial_params)
predictive_dist = likelihood(posterior(D, learned_params)(x), learned_params)
initial_dist = likelihood(initial_params, posterior(initial_params, D)(x))
predictive_dist = likelihood(learned_params, posterior(learned_params, D)(x))

initial_mean = initial_dist.mean()
learned_mean = predictive_dist.mean()
20 changes: 12 additions & 8 deletions examples/haiku.pct.py
@@ -9,7 +9,7 @@
# format_version: '1.3'
# jupytext_version: 1.11.2
# kernelspec:
# display_name: Python 3.9.7 ('gpjax')
# display_name: base
# language: python
# name: python3
# ---
@@ -31,9 +31,13 @@
from chex import dataclass
from jax.config import config
from scipy.signal import sawtooth
from jaxtyping import Float, Array
from typing import Dict


import gpjax as gpx
from gpjax.kernels import DenseKernelComputation, AbstractKernel
from gpjax.types import PRNGKeyType

# Enable Float64 for more stable matrix inversions.
config.update("jax_enable_x64", True)
@@ -84,18 +88,18 @@ class _DeepKernelFunction:
@dataclass
class DeepKernelFunction(AbstractKernel, DenseKernelComputation, _DeepKernelFunction):
def __call__(
self, x: jnp.DeviceArray, y: jnp.DeviceArray, params: dict
) -> jnp.ndarray:
self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"],
) -> Float[Array, "1"]:
xt = self.network.apply(params=params, x=x)
yt = self.network.apply(params=params, x=y)
return self.base_kernel(xt, yt, params=params)
return self.base_kernel(params, xt, yt)

def initialise(self, dummy_x, key):
def initialise(self, dummy_x: Float[Array, "1 D"], key: PRNGKeyType) -> None:
nn_params = self.network.init(rng=key, x=dummy_x)
base_kernel_params = self.base_kernel._initialise_params(key)
self._params = {**nn_params, **base_kernel_params}

def _initialise_params(self, key):
def _initialise_params(self, key: PRNGKeyType) -> Dict:
return self._params
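
# %% [markdown]
# A hypothetical usage sketch. The `network` and `base_kernel` field names are
# assumed from the collapsed `_DeepKernelFunction` dataclass, and the small Haiku
# MLP below is illustrative only, not the network defined later in this notebook.

# %%
import haiku as hk
import jax.numpy as jnp
import jax.random as jr


def _example_net(x):
    # Tiny feature extractor mapping inputs to a 1-dimensional feature space.
    return hk.nets.MLP([32, 1])(x)


_key = jr.PRNGKey(123)
_dummy_x = jnp.ones((1, 1))

example_network = hk.without_apply_rng(hk.transform(_example_net))
deep_kernel = DeepKernelFunction(network=example_network, base_kernel=gpx.RBF())

# Initialise both the network weights and the base kernel's parameters.
deep_kernel.initialise(_dummy_x, _key)
params = deep_kernel._initialise_params(_key)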


@@ -181,8 +185,8 @@ def forward(x):
# With a set of learned parameters, the only remaining task is to predict the output of the model. We can do this by simply applying the model to a test data set.

# %%
latent_dist = posterior(D, learned_params)(xtest)
predictive_dist = likelihood(latent_dist, learned_params)
latent_dist = posterior(learned_params, D)(xtest)
predictive_dist = likelihood(learned_params, latent_dist)

predictive_mean = predictive_dist.mean()
predictive_std = predictive_dist.stddev()
35 changes: 22 additions & 13 deletions examples/kernels.pct.py
@@ -9,7 +9,7 @@
# format_version: '1.3'
# jupytext_version: 1.11.2
# kernelspec:
# display_name: Python 3.9.7 ('gpjax')
# display_name: base
# language: python
# name: python3
# ---
@@ -30,6 +30,7 @@
from jax.config import config
from jaxtyping import Array, Float
from optax import adam
from typing import Dict

import gpjax as gpx

@@ -100,9 +101,17 @@
# We'll now simulate some data and evaluate the kernel on the previously selected input dimensions.

# %%
# Inputs
x_matrix = jr.normal(key, shape=(50, 5))

# Gram kernel computation
gram_fn = slice_kernel.gram
K = gram_fn(slice_kernel, x_matrix, slice_kernel._initialise_params(key))

# Default parameter dictionary
params = slice_kernel._initialise_params(key)

# Compute the Gram matrix
K = gram_fn(slice_kernel, params, x_matrix)
print(K.shape)

# %% [markdown]
@@ -119,9 +128,9 @@
sum_k = k1 + k2

fig, ax = plt.subplots(ncols=3, figsize=(20, 5))
im0 = ax[0].matshow(k1.gram(k1, x, k1._initialise_params(key)).to_dense())
im1 = ax[1].matshow(k2.gram(k2, x, k2._initialise_params(key)).to_dense())
im2 = ax[2].matshow(sum_k.gram(sum_k, x, sum_k._initialise_params(key)).to_dense())
im0 = ax[0].matshow(k1.gram(k1, k1._initialise_params(key), x).to_dense())
im1 = ax[1].matshow(k2.gram(k2, k2._initialise_params(key), x).to_dense())
im2 = ax[2].matshow(sum_k.gram(sum_k, sum_k._initialise_params(key), x).to_dense())

fig.colorbar(im0, ax=ax[0])
fig.colorbar(im1, ax=ax[1])
@@ -136,10 +145,10 @@
prod_k = k1 * k2 * k3

fig, ax = plt.subplots(ncols=4, figsize=(20, 5))
im0 = ax[0].matshow(k1.gram(k1, x, k1._initialise_params(key)).to_dense())
im1 = ax[1].matshow(k2.gram(k2, x, k2._initialise_params(key)).to_dense())
im2 = ax[2].matshow(k3.gram(k3, x, k3._initialise_params(key)).to_dense())
im3 = ax[3].matshow(prod_k.gram(prod_k, x, prod_k._initialise_params(key)).to_dense())
im0 = ax[0].matshow(k1.gram(k1, k1._initialise_params(key), x).to_dense())
im1 = ax[1].matshow(k2.gram(k2, k2._initialise_params(key), x).to_dense())
im2 = ax[2].matshow(k3.gram(k3, k3._initialise_params(key), x).to_dense())
im3 = ax[3].matshow(prod_k.gram(prod_k, prod_k._initialise_params(key), x).to_dense())

fig.colorbar(im0, ax=ax[0])
fig.colorbar(im1, ax=ax[1])
@@ -206,7 +215,7 @@ def __post_init__(self):
self.c = self.period / 2.0 # in [0, \pi]

def __call__(
self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: dict
self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"]
) -> Float[Array, "1"]:
tau = params["tau"]
t = angular_distance(x, y, self.c)
@@ -248,15 +257,15 @@ def _initialise_params(self, key: jr.PRNGKey) -> dict:
#
#
# ### Custom Parameter Bijection

#
# The constraint on $\tau$ makes optimisation challenging with gradient descent.
# It would be much easier if we could instead parameterise $\tau$ to be on the
# real line. Fortunately, this can be taken care of with GPJax's `add_parameter`
# function, only requiring us to define the parameter's name and matching
# bijection (either a Distrax or TensorFlow Probability bijector). Under the
# hood, calling this function updates a configuration object to register this
# parameter and its corresponding transform.

#
# To define a bijector here we'll make use of the `Lambda` operator given in
# Distrax. This lets us convert any regular Jax function into a bijection. Given
# that we require $\tau$ to be strictly greater than $4.$, we'll apply a
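#
# (The rest of this cell is collapsed here; the sketch below is a rough guess
# that assumes a shifted-softplus `dx.Lambda` and that the registration helper
# is importable as `gpjax.config.add_parameter`.)

# %%
import distrax as dx
import jax
import jax.numpy as jnp

from gpjax.config import add_parameter

# Hypothetical bijection: map the real line onto values strictly greater than 4
# with a shifted softplus.
tau_bijector = dx.Lambda(
    forward=lambda x: jax.nn.softplus(x) + 4.0,
    inverse=lambda y: jnp.log(jnp.expm1(y - 4.0)),
)

# Register the transform so that "tau" is optimised on the unconstrained scale.
add_parameter("tau", tau_bijector)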
Expand Down Expand Up @@ -321,7 +330,7 @@ def _initialise_params(self, key: jr.PRNGKey) -> dict:

# %%
posterior_rv = likelihood(
circlular_posterior(D, learned_params)(angles), learned_params
learned_params, circlular_posterior(learned_params, D)(angles)
)
mu = posterior_rv.mean()
one_sigma = posterior_rv.stddev()
6 changes: 2 additions & 4 deletions examples/natgrads.pct.py
@@ -9,7 +9,7 @@
# format_version: '1.3'
# jupytext_version: 1.11.2
# kernelspec:
# display_name: Python 3.9.7 ('gpjax')
# display_name: base
# language: python
# name: python3
# ---
@@ -108,7 +108,7 @@

# %%
latent_dist = natural_q(learned_params)(xtest)
predictive_dist = likelihood(latent_dist, learned_params)
predictive_dist = likelihood(learned_params, latent_dist)

meanf = predictive_dist.mean()
sigma = predictive_dist.stddev()
@@ -212,5 +212,3 @@
# %%
# %reload_ext watermark
# %watermark -n -u -v -iv -w -a 'Daniel Dodd'

# %%
6 changes: 3 additions & 3 deletions examples/regression.pct.py
@@ -8,7 +8,7 @@
# format_version: '1.3'
# jupytext_version: 1.11.2
# kernelspec:
# display_name: Python 3.9.7 ('gpjax')
# display_name: base
# language: python
# name: python3
# ---
@@ -200,8 +200,8 @@
# Equipped with the posterior and a set of optimised hyperparameter values, we are now in a position to query our GP's predictive distribution at novel test inputs. To do this, we use our defined `posterior` and `likelihood` at our test inputs to obtain the predictive distribution as a `Distrax` multivariate Gaussian, on which `mean` and `stddev` can be used to extract the predictive mean and standard deviation.

# %%
latent_dist = posterior(D, learned_params)(xtest)
predictive_dist = likelihood(latent_dist, learned_params)
latent_dist = posterior(learned_params, D)(xtest)
predictive_dist = likelihood(learned_params, latent_dist)

predictive_mean = predictive_dist.mean()
predictive_std = predictive_dist.stddev()
6 changes: 3 additions & 3 deletions examples/tfp_integration.pct.py
@@ -9,7 +9,7 @@
# format_version: '1.3'
# jupytext_version: 1.11.2
# kernelspec:
# display_name: Python 3.9.7 ('gpjax')
# display_name: base
# language: python
# name: python3
# ---
@@ -95,7 +95,7 @@
# %% [markdown]
# ### Specifying priors
#
# We can define Gamma priors on our hyperparameters through TensorFlow Probability's `Distributions` module.
# We can define Gamma priors on our hyperparameters through TensorFlow Probability's `Distributions` module. We transform these to the unconstrained space via `tfd.TransformedDistribution`.

# %%
import tensorflow_probability.substrates.jax as tfp
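
# A rough sketch only (the concrete prior definitions are collapsed above):
# Gamma priors over the positive hyperparameters, pushed to the unconstrained
# space with a log transform via tfd.TransformedDistribution. The parameter
# names below ("lengthscale", "variance", "obs_noise") are assumptions.
tfd = tfp.distributions
tfb = tfp.bijectors

priors = {
    "kernel": {
        "lengthscale": tfd.TransformedDistribution(tfd.Gamma(1.0, 1.0), tfb.Log()),
        "variance": tfd.TransformedDistribution(tfd.Gamma(1.0, 1.0), tfb.Log()),
    },
    "likelihood": {
        "obs_noise": tfd.TransformedDistribution(tfd.Gamma(1.0, 1.0), tfb.Log()),
    },
}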
@@ -209,7 +209,7 @@ def run_chain(key, state):
xtest = jnp.linspace(-5.2, 5.2, 500).reshape(-1, 1)
learned_params = array_to_dict([jnp.mean(i) for i in constrained_sample_list])

predictive_dist = likelihood(posterior(D, learned_params)(xtest), learned_params)
predictive_dist = likelihood(learned_params, posterior(learned_params, D)(xtest))

mu = predictive_dist.mean()
sigma = predictive_dist.stddev()
4 changes: 2 additions & 2 deletions examples/uncollapsed_vi.pct.py
@@ -9,7 +9,7 @@
# format_version: '1.3'
# jupytext_version: 1.11.2
# kernelspec:
# display_name: Python 3.9.7 ('gpjax')
# display_name: base
# language: python
# name: python3
# ---
@@ -176,7 +176,7 @@

# %%
latent_dist = q(learned_params)(xtest)
predictive_dist = likelihood(latent_dist, learned_params)
predictive_dist = likelihood(learned_params, latent_dist)

meanf = predictive_dist.mean()
sigma = predictive_dist.stddev()