Merge pull request #139 from thomaspinder/params_to_first_slot
Move params to the first slot of each function, class, etc.
thomaspinder committed Nov 7, 2022
2 parents 8a2cb02 + afa13cf commit 97683d1
Showing 23 changed files with 685 additions and 365 deletions.
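Every file follows the same pattern: the parameter dictionary moves from the last argument slot to the first. Below is a minimal, self-contained sketch of the new calling convention, modelled on the kernel Gram computations in the kernels example further down; `gpx.RBF` being exported at the package's top level is an assumption.

```python
import jax.random as jr

import gpjax as gpx

key = jr.PRNGKey(123)

kernel = gpx.RBF()
params = kernel._initialise_params(key)
x = jr.normal(key, shape=(10, 1))

# Params now occupy the first slot after the kernel itself:
# gram(kernel, params, x) rather than gram(kernel, x, params).
K = kernel.gram(kernel, params, x).to_dense()
print(K.shape)  # (10, 10)
```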
8 changes: 4 additions & 4 deletions README.md
@@ -89,7 +89,7 @@ posterior = prior * likelihood
Equipped with the posterior, we seek to learn the model's hyperparameters through gradient-based optimisation of the marginal log-likelihood. We do this below, adding Jax's [just-in-time (JIT)](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html) compilation to accelerate training.

```python
mll = jit(posterior.marginal_log_likelihood(training, negative=True))
mll = jit(posterior.marginal_log_likelihood(D, negative=True))
```

For purposes of optimisation, we'll use optax's Adam.
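A minimal sketch of the fitting step that ties the objective and optimiser together; `mll` is the jitted objective defined above, while `parameter_state` (assumed to come from an earlier `gpx.initialise` call) and the `optax_optim`/`n_iters` argument names are assumptions rather than verbatim README code.

```python
import gpjax as gpx
import optax as ox

optimiser = ox.adam(learning_rate=0.01)

# Argument names besides `objective` are assumptions; `parameter_state` is
# taken to be the result of an earlier `gpx.initialise(posterior, key)`.
inference_state = gpx.fit(
    objective=mll,
    parameter_state=parameter_state,
    optax_optim=optimiser,
    n_iters=500,
)

learned_params, training_history = inference_state.unpack()
```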
@@ -117,11 +117,11 @@ Using our learned hyperparameters, we can obtain the posterior distribution of t
learned_params, _ = inference_state.unpack()
xtest = jnp.linspace(-3., 3., 100).reshape(-1, 1)

latent_distribution = posterior(training, learned_params)(xtest)
predictive_distribution = likelihood(latent_distribution, params)
latent_distribution = posterior(learned_params, D)(xtest)
predictive_distribution = likelihood(params, latent_distribution)

predictive_mean = predictive_distribution.mean
predictive_cov = predictive_distribution.covariance_matrix
predictive_cov = predictive_distribution.covariance
```

# Installation
2 changes: 1 addition & 1 deletion examples/barycentres.pct.py
@@ -117,7 +117,7 @@ def fit_gp(x: jnp.DeviceArray, y: jnp.DeviceArray) -> dx.MultivariateNormalTri:
)

learned_params, training_history = inference_state.unpack()
return likelihood(posterior(D, learned_params)(xtest), learned_params)
return likelihood(learned_params, posterior(learned_params, D)(xtest))


posterior_preds = [fit_gp(x, i) for i in ys]
22 changes: 11 additions & 11 deletions examples/classification.pct.py
@@ -9,7 +9,7 @@
# format_version: '1.3'
# jupytext_version: 1.11.2
# kernelspec:
# display_name: Python 3.9.7 ('gpjax')
# display_name: base
# language: python
# name: python3
# ---
@@ -98,9 +98,9 @@
# From which we can make predictions at novel inputs, as illustrated below.

# %%
map_latent_dist = posterior(D, map_estimate)(xtest)
map_latent_dist = posterior(map_estimate, D)(xtest)

predictive_dist = likelihood(map_latent_dist, map_estimate)
predictive_dist = likelihood(map_estimate, map_latent_dist)

predictive_mean = predictive_dist.mean()
predictive_std = predictive_dist.stddev()
@@ -159,11 +159,11 @@
# that we identify as a Gaussian distribution, $p(\boldsymbol{f}| \mathcal{D}) \approx q(\boldsymbol{f}) := \mathcal{N}(\hat{\boldsymbol{f}}, [-\nabla^2 \tilde{p}(\boldsymbol{y}|\boldsymbol{f})|_{\hat{\boldsymbol{f}}} ]^{-1} )$. Since the negative Hessian is positive definite, we can use the Cholesky decomposition to obtain the covariance matrix of the Laplace approximation at the datapoints below.

# %%
gram, cross_covariance = (prior.kernel.gram, prior.kernel.cross_covariance)
gram, cross_covariance = (kernel.gram, kernel.cross_covariance)
jitter = 1e-6

# Compute (latent) function value map estimates at training points:
Kxx = gram(prior.kernel, x, map_estimate["kernel"])
Kxx = gram(kernel, map_estimate["kernel"], x)
Kxx += I(D.n) * jitter
Lx = Kxx.triangular_lower()
f_hat = jnp.matmul(Lx, map_estimate["latent"])
@@ -193,10 +193,10 @@
# %%
def construct_laplace(test_inputs: Float[Array, "N D"]) -> dx.MultivariateNormalTri:

map_latent_dist = posterior(D, map_estimate)(test_inputs)
map_latent_dist = posterior(map_estimate, D)(test_inputs)

Kxt = cross_covariance(prior.kernel, x, test_inputs, map_estimate["kernel"])
Kxx = gram(prior.kernel, x, map_estimate["kernel"])
Kxt = cross_covariance(kernel, map_estimate["kernel"], x, test_inputs)
Kxx = gram(kernel, map_estimate["kernel"], x)
Kxx += I(D.n) * jitter
Lx = Kxx.triangular_lower()

@@ -219,7 +219,7 @@ def construct_laplace(test_inputs: Float[Array, "N D"]) -> dx.MultivariateNormal
# From this we can construct the predictive distribution at the test points.
# %%
laplace_latent_dist = construct_laplace(xtest)
predictive_dist = likelihood(laplace_latent_dist, map_estimate)
predictive_dist = likelihood(map_estimate, laplace_latent_dist)

predictive_mean = predictive_dist.mean()
predictive_std = predictive_dist.stddev()
@@ -338,8 +338,8 @@ def one_step(state, rng_key):
ps["latent"] = states.position["latent"][i, :, :]
ps = gpx.constrain(ps, bijectors)

latent_dist = posterior(D, ps)(xtest)
predictive_dist = likelihood(latent_dist, ps)
latent_dist = posterior(ps, D)(xtest)
predictive_dist = likelihood(ps, latent_dist)
samples.append(predictive_dist.sample(seed=key, sample_shape=(10,)))

samples = jnp.vstack(samples)
8 changes: 4 additions & 4 deletions examples/collapsed_vi.pct.py
@@ -8,7 +8,7 @@
# format_version: '1.3'
# jupytext_version: 1.11.2
# kernelspec:
# display_name: Python 3.9.7 ('gpjax')
# display_name: base
# language: python
# name: python3
# ---
@@ -101,7 +101,7 @@

negative_elbo = jit(sgpr.elbo(D, negative=True))

optimiser = ox.adam(learning_rate=0.005)
optimiser = ox.adam(learning_rate=5e-3)

inference_state = gpx.fit(
objective=negative_elbo,
@@ -116,8 +116,8 @@
# We show predictions of our model with the learned inducing points overlaid in grey.

# %%
latent_dist = q.predict(D, learned_params)(xtest)
predictive_dist = likelihood(latent_dist, learned_params)
latent_dist = q(learned_params, D)(xtest)
predictive_dist = likelihood(learned_params, latent_dist)

samples = latent_dist.sample(seed=key, sample_shape=(20,))

6 changes: 3 additions & 3 deletions examples/graph_kernels.pct.py
@@ -9,7 +9,7 @@
# format_version: '1.3'
# jupytext_version: 1.11.2
# kernelspec:
# display_name: Python 3.9.7 ('gpjax')
# display_name: base
# language: python
# name: python3
# ---
@@ -143,8 +143,8 @@

# %%
initial_params = parameter_state.params
initial_dist = likelihood(posterior(D, initial_params)(x), initial_params)
predictive_dist = likelihood(posterior(D, learned_params)(x), learned_params)
initial_dist = likelihood(initial_params, posterior(initial_params, D)(x))
predictive_dist = likelihood(learned_params, posterior(learned_params, D)(x))

initial_mean = initial_dist.mean()
learned_mean = predictive_dist.mean()
20 changes: 12 additions & 8 deletions examples/haiku.pct.py
@@ -9,7 +9,7 @@
# format_version: '1.3'
# jupytext_version: 1.11.2
# kernelspec:
# display_name: Python 3.9.7 ('gpjax')
# display_name: base
# language: python
# name: python3
# ---
@@ -31,9 +31,13 @@
from chex import dataclass
from jax.config import config
from scipy.signal import sawtooth
from jaxtyping import Float, Array
from typing import Dict


import gpjax as gpx
from gpjax.kernels import DenseKernelComputation, AbstractKernel
from gpjax.types import PRNGKeyType

# Enable Float64 for more stable matrix inversions.
config.update("jax_enable_x64", True)
@@ -84,18 +88,18 @@ class _DeepKernelFunction:
@dataclass
class DeepKernelFunction(AbstractKernel, DenseKernelComputation, _DeepKernelFunction):
def __call__(
self, x: jnp.DeviceArray, y: jnp.DeviceArray, params: dict
) -> jnp.ndarray:
self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"],
) -> Float[Array, "1"]:
xt = self.network.apply(params=params, x=x)
yt = self.network.apply(params=params, x=y)
return self.base_kernel(xt, yt, params=params)
return self.base_kernel(params, xt, yt)

def initialise(self, dummy_x, key):
def initialise(self, dummy_x: Float[Array, "1 D"], key: PRNGKeyType) -> None:
nn_params = self.network.init(rng=key, x=dummy_x)
base_kernel_params = self.base_kernel._initialise_params(key)
self._params = {**nn_params, **base_kernel_params}

def _initialise_params(self, key):
def _initialise_params(self, key: PRNGKeyType) -> Dict:
return self._params


@@ -181,8 +185,8 @@ def forward(x):
# With a set of learned parameters, the only remaining task is to predict the output of the model. We can do this by simply applying the model to a test data set.

# %%
latent_dist = posterior(D, learned_params)(xtest)
predictive_dist = likelihood(latent_dist, learned_params)
latent_dist = posterior(learned_params, D)(xtest)
predictive_dist = likelihood(learned_params, latent_dist)

predictive_mean = predictive_dist.mean()
predictive_std = predictive_dist.stddev()
35 changes: 22 additions & 13 deletions examples/kernels.pct.py
@@ -9,7 +9,7 @@
# format_version: '1.3'
# jupytext_version: 1.11.2
# kernelspec:
# display_name: Python 3.9.7 ('gpjax')
# display_name: base
# language: python
# name: python3
# ---
@@ -30,6 +30,7 @@
from jax.config import config
from jaxtyping import Array, Float
from optax import adam
from typing import Dict

import gpjax as gpx

@@ -100,9 +101,17 @@
# We'll now simulate some data and evaluate the kernel on the previously selected input dimensions.

# %%
# Inputs
x_matrix = jr.normal(key, shape=(50, 5))

# Gram kernel computation
gram_fn = slice_kernel.gram
K = gram_fn(slice_kernel, x_matrix, slice_kernel._initialise_params(key))

# Default parameter dictionary
params = slice_kernel._initialise_params(key)

# Compute the Gram matrix
K = gram_fn(slice_kernel, params, x_matrix)
print(K.shape)

# %% [markdown]
@@ -119,9 +128,9 @@
sum_k = k1 + k2

fig, ax = plt.subplots(ncols=3, figsize=(20, 5))
im0 = ax[0].matshow(k1.gram(k1, x, k1._initialise_params(key)).to_dense())
im1 = ax[1].matshow(k2.gram(k2, x, k2._initialise_params(key)).to_dense())
im2 = ax[2].matshow(sum_k.gram(sum_k, x, sum_k._initialise_params(key)).to_dense())
im0 = ax[0].matshow(k1.gram(k1, k1._initialise_params(key), x).to_dense())
im1 = ax[1].matshow(k2.gram(k2, k2._initialise_params(key), x).to_dense())
im2 = ax[2].matshow(sum_k.gram(sum_k, sum_k._initialise_params(key), x).to_dense())

fig.colorbar(im0, ax=ax[0])
fig.colorbar(im1, ax=ax[1])
@@ -136,10 +145,10 @@
prod_k = k1 * k2 * k3

fig, ax = plt.subplots(ncols=4, figsize=(20, 5))
im0 = ax[0].matshow(k1.gram(k1, x, k1._initialise_params(key)).to_dense())
im1 = ax[1].matshow(k2.gram(k2, x, k2._initialise_params(key)).to_dense())
im2 = ax[2].matshow(k3.gram(k3, x, k3._initialise_params(key)).to_dense())
im3 = ax[3].matshow(prod_k.gram(prod_k, x, prod_k._initialise_params(key)).to_dense())
im0 = ax[0].matshow(k1.gram(k1, k1._initialise_params(key), x).to_dense())
im1 = ax[1].matshow(k2.gram(k2, k2._initialise_params(key), x).to_dense())
im2 = ax[2].matshow(k3.gram(k3, k3._initialise_params(key), x).to_dense())
im3 = ax[3].matshow(prod_k.gram(prod_k, prod_k._initialise_params(key), x).to_dense())

fig.colorbar(im0, ax=ax[0])
fig.colorbar(im1, ax=ax[1])
@@ -206,7 +215,7 @@ def __post_init__(self):
self.c = self.period / 2.0 # in [0, \pi]

def __call__(
self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: dict
self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"]
) -> Float[Array, "1"]:
tau = params["tau"]
t = angular_distance(x, y, self.c)
@@ -248,15 +257,15 @@ def _initialise_params(self, key: jr.PRNGKey) -> dict:
#
#
# ### Custom Parameter Bijection

#
# The constraint on $\tau$ makes optimisation challenging with gradient descent.
# It would be much easier if we could instead parameterise $\tau$ to be on the
# real line. Fortunately, this can be taken care of with GPJax's `add_parameter`
# function, which only requires us to define the parameter's name and a matching
# bijection (either a Distrax or TensorFlow Probability bijector). Under the
# hood, calling this function updates a configuration object to register this
# parameter and its corresponding transform.
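A minimal sketch of such a registration, pairing the parameter name with a TensorFlow Probability bijector that maps the real line onto $(4, \infty)$; the `gpjax.config.add_parameter` import path and signature are assumptions here.

```python
import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp

tfb = tfp.bijectors

# Softplus maps the reals onto the positives; chaining a shift of 4 gives a
# bijector from the real line onto (4, inf), matching the constraint on tau.
tau_bijector = tfb.Chain([tfb.Shift(jnp.array(4.0)), tfb.Softplus()])

# Register the parameter and its transform with GPJax's configuration
# (assumed import path and signature).
from gpjax.config import add_parameter

add_parameter("tau", tau_bijector)
```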

#
# To define a bijector here we'll make use of the `Lambda` operator given in
# Distrax. This lets us convert any regular Jax function into a bijection. Given
# that we require $\tau$ to be strictly greater than $4.$, we'll apply a
@@ -321,7 +330,7 @@ def _initialise_params(self, key: jr.PRNGKey) -> dict:

# %%
posterior_rv = likelihood(
circlular_posterior(D, learned_params)(angles), learned_params
learned_params, circlular_posterior(learned_params, D)(angles)
)
mu = posterior_rv.mean()
one_sigma = posterior_rv.stddev()
6 changes: 2 additions & 4 deletions examples/natgrads.pct.py
@@ -9,7 +9,7 @@
# format_version: '1.3'
# jupytext_version: 1.11.2
# kernelspec:
# display_name: Python 3.9.7 ('gpjax')
# display_name: base
# language: python
# name: python3
# ---
@@ -108,7 +108,7 @@

# %%
latent_dist = natural_q(learned_params)(xtest)
predictive_dist = likelihood(latent_dist, learned_params)
predictive_dist = likelihood(learned_params, latent_dist)

meanf = predictive_dist.mean()
sigma = predictive_dist.stddev()
@@ -212,5 +212,3 @@
# %%
# %reload_ext watermark
# %watermark -n -u -v -iv -w -a 'Daniel Dodd'

# %%
6 changes: 3 additions & 3 deletions examples/regression.pct.py
@@ -8,7 +8,7 @@
# format_version: '1.3'
# jupytext_version: 1.11.2
# kernelspec:
# display_name: Python 3.9.7 ('gpjax')
# display_name: base
# language: python
# name: python3
# ---
@@ -200,8 +200,8 @@
# Equipped with the posterior and a set of optimised hyperparameter values, we are now in a position to query our GP's predictive distribution at novel test inputs. To do this, we use our defined `posterior` and `likelihood` at the test inputs to obtain the predictive distribution as a `Distrax` multivariate Gaussian, on which `mean` and `stddev` can be called to extract the predictive mean and standard deviation.

# %%
latent_dist = posterior(D, learned_params)(xtest)
predictive_dist = likelihood(latent_dist, learned_params)
latent_dist = posterior(learned_params, D)(xtest)
predictive_dist = likelihood(learned_params, latent_dist)

predictive_mean = predictive_dist.mean()
predictive_std = predictive_dist.stddev()
6 changes: 3 additions & 3 deletions examples/tfp_integration.pct.py
@@ -9,7 +9,7 @@
# format_version: '1.3'
# jupytext_version: 1.11.2
# kernelspec:
# display_name: Python 3.9.7 ('gpjax')
# display_name: base
# language: python
# name: python3
# ---
@@ -95,7 +95,7 @@
# %% [markdown]
# ### Specifying priors
#
# We can define Gamma priors on our hyperparameters through TensorFlow Probability's `Distributions` module.
# We can define Gamma priors on our hyperparameters through TensorFlow Probability's `Distributions` module. We transform these to the unconstrained space via `tfd.TransformedDistribution`.

# %%
import tensorflow_probability.substrates.jax as tfp
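An illustrative sketch (not taken from this commit) of a Gamma prior mapped to the unconstrained space via `tfd.TransformedDistribution`; the concentration and rate values are placeholders.

```python
import tensorflow_probability.substrates.jax as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

# Gamma prior over a positive hyperparameter (values are placeholders).
lengthscale_prior = tfd.Gamma(concentration=1.0, rate=1.0)

# An inverted softplus maps the positive support onto the real line, giving
# the corresponding prior over the unconstrained parameterisation.
unconstrained_prior = tfd.TransformedDistribution(
    distribution=lengthscale_prior,
    bijector=tfb.Invert(tfb.Softplus()),
)
```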
@@ -209,7 +209,7 @@ def run_chain(key, state):
xtest = jnp.linspace(-5.2, 5.2, 500).reshape(-1, 1)
learned_params = array_to_dict([jnp.mean(i) for i in constrained_sample_list])

predictive_dist = likelihood(posterior(D, learned_params)(xtest), learned_params)
predictive_dist = likelihood(learned_params, posterior(learned_params, D)(xtest))

mu = predictive_dist.mean()
sigma = predictive_dist.stddev()
4 changes: 2 additions & 2 deletions examples/uncollapsed_vi.pct.py
@@ -9,7 +9,7 @@
# format_version: '1.3'
# jupytext_version: 1.11.2
# kernelspec:
# display_name: Python 3.9.7 ('gpjax')
# display_name: base
# language: python
# name: python3
# ---
@@ -176,7 +176,7 @@

# %%
latent_dist = q(learned_params)(xtest)
predictive_dist = likelihood(latent_dist, learned_params)
predictive_dist = likelihood(learned_params, latent_dist)

meanf = predictive_dist.mean()
sigma = predictive_dist.stddev()