From 75b196c3bdbda483f259d4cd56df8f8528cbc9a4 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Mon, 27 Mar 2023 15:48:24 +0100 Subject: [PATCH 01/44] Add module. --- gpjax/bijectors.py | 34 ++ gpjax/module.py | 258 ++++++++++++ gpjax/param.py | 49 +++ tests/test_bijectors.py | 22 ++ tests/test_module.py | 850 ++++++++++++++++++++++++++++++++++++++++ tests/test_param.py | 41 ++ 6 files changed, 1254 insertions(+) create mode 100644 gpjax/bijectors.py create mode 100644 gpjax/module.py create mode 100644 gpjax/param.py create mode 100644 tests/test_bijectors.py create mode 100644 tests/test_module.py create mode 100644 tests/test_param.py diff --git a/gpjax/bijectors.py b/gpjax/bijectors.py new file mode 100644 index 000000000..5a257629e --- /dev/null +++ b/gpjax/bijectors.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +__all__ = ["Bijector", "Identity", "Softplus"] + +from dataclasses import dataclass +from typing import Callable + +import jax.numpy as jnp +from simple_pytree import Pytree, static_field + + +@dataclass +class Bijector(Pytree): + """ + Create a bijector. + + Args: + forward(Callable): The forward transformation. + inverse(Callable): The inverse transformation. + + Returns: + Bijector: A bijector. + """ + + forward: Callable = static_field() + inverse: Callable = static_field() + + +Identity = Bijector(forward=lambda x: x, inverse=lambda x: x) + +Softplus = Bijector( + forward=lambda x: jnp.log(1.0 + jnp.exp(x)), + inverse=lambda x: jnp.log(jnp.exp(x) - 1.0), +) diff --git a/gpjax/module.py b/gpjax/module.py new file mode 100644 index 000000000..cd0fdd9a3 --- /dev/null +++ b/gpjax/module.py @@ -0,0 +1,258 @@ +from __future__ import annotations + +__all__ = ["Module", "meta_leaves", "meta_flatten", "meta_map", "meta"] + +import dataclasses +from copy import copy, deepcopy +from typing import Any, Callable, Dict, Iterable, Tuple + +import jax +import jax.tree_util as jtu +from jax._src.tree_util import _registry +from simple_pytree import Pytree, static_field + +from .bijectors import Bijector, Identity + + +class Module(Pytree): + _pytree__meta: Dict[str, Any] = static_field() + + def __init_subclass__(cls, mutable: bool = False): + cls._pytree__meta = dict() + super().__init_subclass__(mutable=mutable) + class_vars = vars(cls) + for field, value in class_vars.items(): + if ( + field not in cls._pytree__static_fields + and isinstance(value, dataclasses.Field) + and value.metadata is not None + ): + cls._pytree__meta[field] = {**value.metadata} + + def replace(self, **kwargs: Any) -> Module: + """ + Replace the values of the fields of the object. + + Args: + **kwargs: keyword arguments to replace the fields of the object. + + Returns: + Module: with the fields replaced. + """ + fields = vars(self) + for key in kwargs: + if key not in fields: + raise ValueError(f"'{key}' is not a field of {type(self).__name__}") + + pytree = copy(self) + pytree.__dict__.update(kwargs) + return pytree + + def replace_meta(self, **kwargs: Any) -> Module: + """ + Replace the metadata of the fields. + + Args: + **kwargs: keyword arguments to replace the metadata of the fields of the object. + + Returns: + Module: with the metadata of the fields replaced. + """ + fields = vars(self) + for key in kwargs: + if key not in fields: + raise ValueError(f"'{key}' is not a field of {type(self).__name__}") + + pytree = copy(self) + pytree.__dict__.update(_pytree__meta={**pytree._pytree__meta, **kwargs}) + return pytree + + def update_meta(self, **kwargs: Any) -> Module: + """ + Update the metadata of the fields. The metadata must already exist. + + Args: + **kwargs: keyword arguments to replace the fields of the object. + + Returns: + Module: with the fields replaced. + """ + fields = vars(self) + for key in kwargs: + if key not in fields: + raise ValueError(f"'{key}' is not a field of {type(self).__name__}") + + pytree = copy(self) + new = deepcopy(pytree._pytree__meta) + for key, value in kwargs.items(): + if key in new: + new[key].update(value) + else: + new[key] = value + pytree.__dict__.update(_pytree__meta=new) + return pytree + + def replace_trainable(self: Module, **kwargs: Dict[str, bool]) -> Module: + """Replace the trainability status of local nodes of the Module.""" + return self.update_meta(**{k: {"trainable": v} for k, v in kwargs.items()}) + + def replace_bijector(self: Module, **kwargs: Dict[str, Bijector]) -> Module: + """Replace the bijectors of local nodes of the Module.""" + return self.update_meta(**{k: {"bijector": v} for k, v in kwargs.items()}) + + def constrain(self) -> Module: + """Transform model parameters to the constrained space according to their defined bijectors. + + Returns: + Module: tranformed to the constrained space. + """ + + def _apply_constrain(meta_leaf): + meta, leaf = meta_leaf + return meta.get("bijector", Identity).forward(leaf) + + return meta_map(_apply_constrain, self) + + def unconstrain(self) -> Module: + """Transform model parameters to the unconstrained space according to their defined bijectors. + + Returns: + Module: tranformed to the unconstrained space. + """ + + def _apply_unconstrain(meta_leaf): + meta, leaf = meta_leaf + return meta.get("bijector", Identity).inverse(leaf) + + return meta_map(_apply_unconstrain, self) + + def stop_gradient(self) -> Module: + """Stop gradients flowing through the Module. + + Returns: + Module: with gradients stopped. + """ + + # 🛑 Stop gradients flowing through a given leaf if it is not trainable. + def _stop_grad(leaf: jax.Array, trainable: bool) -> jax.Array: + return jax.lax.cond(trainable, lambda x: x, jax.lax.stop_gradient, leaf) + + def _apply_stop_grad(meta_leaf): + meta, leaf = meta_leaf + return _stop_grad(leaf, meta.get("trainable", True)) + + return meta_map(_apply_stop_grad, self) + + +def _toplevel_meta(pytree: Any) -> List[Dict[str, Any]]: + """Unpacks a list of meta corresponding to the top-level nodes of the pytree. + + Args: + pytree (Any): pytree to unpack the meta from. + + Returns: + List[Dict[str, Any]]: meta of the top-level nodes of the pytree. + """ + if isinstance(pytree, Iterable): + return [None] * len(pytree) + return [ + pytree._pytree__meta.get(field, {}) + for field, _ in sorted(vars(pytree).items()) + if field not in pytree._pytree__static_fields + ] + + +def meta_leaves( + pytree: Module, + *, + is_leaf: Callable[[Any], bool] | None = None, +) -> List[Tuple[Dict[str, Any], Any]]: + """ + Returns the meta of the leaves of the pytree. + + Args: + pytree (Module): pytree to get the meta of. + is_leaf (Callable[[Any], bool]): predicate to determine if a node is a leaf. Defaults to None. + + Returns: + List[Tuple[Dict[str, Any], Any]]: meta of the leaves of the pytree. + """ + + def _unpack_metadata( + meta_leaf: Any, + pytree: Module, + is_leaf: Callable[[Any], bool] | None, + ): + """Recursively unpack leaf metadata.""" + if is_leaf and is_leaf(pytree): + yield meta_leaf + return + + if type(pytree) in _registry: # Registry tree trick, thanks to PyTreeClass! + leaves_values, _ = _registry[type(pytree)].to_iter(pytree) + leaves_meta = _toplevel_meta(pytree) + + elif pytree is not None: + yield meta_leaf + return + + for metadata, leaf in zip(leaves_meta, leaves_values): + yield from _unpack_metadata((metadata, leaf), leaf, is_leaf) + + return list(_unpack_metadata(pytree, pytree, is_leaf)) + + +def meta_flatten( + pytree: Module, *, is_leaf: Callable[[Any], bool] | None = None +) -> Module: + """ + Returns the meta of the Module. + + Args: + pytree (Module): Module to get the meta of. + is_leaf (Callable[[Any], bool]): predicate to determine if a node is a leaf. Defaults to None. + + Returns: + Module: meta of the Module. + """ + return meta_leaves(pytree, is_leaf=is_leaf), jtu.tree_structure( + pytree, is_leaf=is_leaf + ) + + +def meta_map( + f: Callable[[Any, Dict[str, Any]], Any], + pytree: Module, + *rest: Any, + is_leaf: Callable[[Any], bool] | None = None, +) -> Module: + """Apply a function to a Module where the first argument are the pytree leaves, and the second argument are the Module metadata leaves. + Args: + f (Callable[[Any, Dict[str, Any]], Any]): The function to apply to the pytree. + pytree (Module): The pytree to apply the function to. + rest (Any, optional): Additional pytrees to apply the function to. Defaults to None. + is_leaf (Callable[[Any], bool], optional): predicate to determine if a node is a leaf. Defaults to None. + + Returns: + Module: The transformed pytree. + """ + leaves, treedef = meta_flatten(pytree, is_leaf=is_leaf) + all_leaves = [leaves] + [treedef.treedef.flatten_up_to(r) for r in rest] + return treedef.unflatten(f(*xs) for xs in zip(*all_leaves)) + + +def meta(pytree: Module, *, is_leaf: Callable[[Any], bool] | None = None) -> Module: + """Returns the metadata of the Module as a pytree. + + Args: + pytree (Module): pytree to get the metadata of. + + Returns: + Module: metadata of the pytree. + """ + + def _filter_meta(meta_leaf): + meta, _ = meta_leaf + return meta + + return meta_map(_filter_meta, pytree, is_leaf=is_leaf) diff --git a/gpjax/param.py b/gpjax/param.py new file mode 100644 index 000000000..745f95488 --- /dev/null +++ b/gpjax/param.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +__all__ = ["param_field"] + +import dataclasses +from typing import Any, Mapping, Optional + +from .bijectors import Bijector, Identity + + +def param_field( + default: Any = dataclasses.MISSING, + *, + bijector: Bijector = Identity, + trainable: bool = True, + default_factory: Any = dataclasses.MISSING, + init: bool = True, + repr: bool = True, + hash: Optional[bool] = None, + compare: bool = True, + metadata: Optional[Mapping[str, Any]] = None, +): + if metadata is None: + metadata = {} + else: + metadata = dict(metadata) + + if "bijector" in metadata: + raise ValueError("Cannot use metadata with `bijector` already set.") + + if "trainable" in metadata: + raise ValueError("Cannot use metadata with `trainable` already set.") + + if "pytree_node" in metadata: + raise ValueError("Cannot use metadata with `pytree_node` already set.") + + metadata["bijector"] = bijector + metadata["trainable"] = trainable + metadata["pytree_node"] = True + + return dataclasses.field( + default=default, + default_factory=default_factory, + init=init, + repr=repr, + hash=hash, + compare=compare, + metadata=metadata, + ) diff --git a/tests/test_bijectors.py b/tests/test_bijectors.py new file mode 100644 index 000000000..9e6eb5996 --- /dev/null +++ b/tests/test_bijectors.py @@ -0,0 +1,22 @@ +import jax.numpy as jnp +import pytest + +from mytree.bijectors import Bijector, Identity, Softplus + + +def test_bijector(): + bij = Bijector(forward=lambda x: jnp.exp(x), inverse=lambda x: jnp.log(x)) + assert bij.forward(1.0) == pytest.approx(jnp.exp(1.0)) + assert bij.inverse(jnp.exp(1.0)) == pytest.approx(1.0) + + +def test_identity(): + bij = Identity + assert bij.forward(1.0) == pytest.approx(1.0) + assert bij.inverse(1.0) == pytest.approx(1.0) + + +def test_softplus(): + bij = Softplus + assert bij.forward(1.0) == pytest.approx(jnp.log(1.0 + jnp.exp(1.0))) + assert bij.inverse(jnp.log(1.0 + jnp.exp(1.0))) == pytest.approx(1.0) diff --git a/tests/test_module.py b/tests/test_module.py new file mode 100644 index 000000000..41042ecc8 --- /dev/null +++ b/tests/test_module.py @@ -0,0 +1,850 @@ +import dataclasses +from dataclasses import dataclass, field +from typing import Any, Generic, Iterable, TypeVar + +import jax +import jax.numpy as jnp +import jax.tree_util as jtu +import pytest +from flax import serialization +from simple_pytree import Pytree, static_field + +from gpjax.module import Module, meta +from gpjax.bijectors import Identity, Softplus +from gpjax.param import param_field + +@pytest.mark.parametrize("is_dataclass", [True, False]) +def test_init_and_meta_scrambled(is_dataclass): + class Tree(Module): + c: float = field(metadata={"c": 4.0}) + b: float = field(metadata={"b": 5.0}) + a: float = field(metadata={"a": 6.0}) + + def __init__(self, a, b, c): + self.b = b + self.a = a + self.c = c + + if is_dataclass: + Tree = dataclass(Tree) + + # Test init + tree = Tree(1, 2, 3) + + assert isinstance(tree, Module) + assert isinstance(tree, Pytree) + + assert tree.a == 1 + assert tree.b == 2 + assert tree.c == 3 + + # Test meta + meta_tree = meta(tree) + assert meta_tree.a == {"a": 6.0} + assert meta_tree.b == {"b": 5.0} + assert meta_tree.c == {"c": 4.0} + + # Test replacing changes only the specified field + new = tree.replace(a=123) + meta_new = meta(new) + + assert new.a == 123 + assert new.b == 2 + assert new.c == 3 + + assert meta_new.a == {"a": 6.0} + assert meta_new.b == {"b": 5.0} + assert meta_new.c == {"c": 4.0} + + +@pytest.mark.parametrize("is_dataclass", [True, False]) +def test_scrambled_annotations(is_dataclass): + class Tree(Module): + c: float = field(metadata={"c": 4.0}) + b: float = field(metadata={"b": 5.0}) + a: float = field(metadata={"a": 6.0}) + + def __init__(self, a, b, c): + self.a = a + self.b = b + self.c = c + + if is_dataclass: + Tree = dataclass(Tree) + + tree = Tree(1, 2, 3) + + assert isinstance(tree, Module) + assert isinstance(tree, Pytree) + + assert tree.a == 1 + assert tree.b == 2 + assert tree.c == 3 + + meta_tree = meta(tree) + assert meta_tree.a == {"a": 6.0} + assert meta_tree.b == {"b": 5.0} + assert meta_tree.c == {"c": 4.0} + + +@pytest.mark.parametrize("is_dataclass", [True, False]) +def test_scrambled_init(is_dataclass): + class Tree(Module): + a: float = field(metadata={"a": 6.0}) + b: float = field(metadata={"b": 5.0}) + c: float = field(metadata={"c": 4.0}) + + def __init__(self, a, b, c): + self.b = b + self.a = a + self.c = c + + if is_dataclass: + Tree = dataclass(Tree) + + tree = Tree(1, 2, 3) + + assert isinstance(tree, Module) + assert isinstance(tree, Pytree) + + assert tree.a == 1 + assert tree.b == 2 + assert tree.c == 3 + + meta_tree = meta(tree) + assert meta_tree.a == {"a": 6.0} + assert meta_tree.b == {"b": 5.0} + assert meta_tree.c == {"c": 4.0} + + +@pytest.mark.parametrize("is_dataclass", [True, False]) +def test_simple_linear_model(is_dataclass): + class SimpleModel(Module): + weight: float = param_field(bijector=Softplus, trainable=False) + bias: float + + def __init__(self, weight, bias): + self.weight = weight + self.bias = bias + + def __call__(self, test_point): + return test_point * self.weight + self.bias + + if is_dataclass: + SimpleModel = dataclass(SimpleModel) + + model = SimpleModel(1.0, 2.0) + + assert isinstance(model, Module) + assert isinstance(model, Pytree) + + assert model.weight == 1.0 + assert model.bias == 2.0 + + meta_model = meta(model) + + assert meta_model.weight["bijector"] == Softplus + assert meta_model.weight["trainable"] == False + assert meta_model.bias == {} + + constrained_model = model.constrain() + assert constrained_model.weight == Softplus.forward(1.0) + assert constrained_model.bias == Identity.forward(2.0) + + meta_constrained_model = meta(constrained_model) + assert meta_constrained_model.weight["bijector"] == Softplus + assert meta_constrained_model.weight["trainable"] == False + assert meta_constrained_model.bias == {} + + unconstrained_model = constrained_model.unconstrain() + assert unconstrained_model.weight == 1.0 + assert unconstrained_model.bias == 2.0 + + meta_unconstrained_model = meta(unconstrained_model) + assert meta_unconstrained_model.weight["bijector"] == Softplus + assert meta_unconstrained_model.weight["trainable"] == False + assert meta_unconstrained_model.bias == {} + + def loss_fn(model): + model = model.stop_gradient() + return (model(1.0) - 2.0) ** 2 + + grad = jax.grad(loss_fn)(model) + assert grad.weight == 0.0 + assert grad.bias == 2.0 + + new = model.replace_meta(bias={"amazing": True}) + assert new.weight == 1.0 + assert new.bias == 2.0 + assert model.weight == 1.0 + assert model.bias == 2.0 + assert meta(new).bias == {"amazing": True} + assert meta(model).bias == {} + + with pytest.raises(ValueError, match=f"'cool' is not a field of SimpleModel"): + model.replace_meta(cool={"don't": "think so"}) + + with pytest.raises(ValueError, match=f"'cool' is not a field of SimpleModel"): + model.update_meta(cool={"don't": "think so"}) + + new = model.update_meta(bias={"amazing": True}) + assert new.weight == 1.0 + assert new.bias == 2.0 + assert model.weight == 1.0 + assert model.bias == 2.0 + assert meta(new).bias == {"amazing": True} + assert meta(model).bias == {} + + +@pytest.mark.parametrize("is_dataclass", [True, False]) +def test_nested_Module_structure(is_dataclass): + class SubTree(Module): + c: float = param_field(bijector=Identity) + d: float = param_field(bijector=Softplus) + e: float = param_field(bijector=Softplus) + + def __init__(self, c, d, e): + self.c = c + self.d = d + self.e = e + + class Tree(Module): + a: float = param_field(bijector=Identity) + sub_tree: SubTree + b: float = param_field(bijector=Softplus) + + def __init__(self, a, sub_tree, b): + self.a = a + self.sub_tree = sub_tree + self.b = b + + if is_dataclass: + SubTree = dataclass(SubTree) + Tree = dataclass(Tree) + + tree = Tree( + a=1.0, + sub_tree=SubTree(c=2.0, d=3.0, e=4.0), + b=5.0, + ) + + assert isinstance(tree, Module) + assert isinstance(tree, Pytree) + assert isinstance(tree.sub_tree, Module) + assert isinstance(tree.sub_tree, Pytree) + + assert tree.a == 1.0 + assert tree.b == 5.0 + assert tree.sub_tree.c == 2.0 + assert tree.sub_tree.d == 3.0 + assert tree.sub_tree.e == 4.0 + + meta_tree = meta(tree) + + assert isinstance(meta_tree, Module) + assert isinstance(meta_tree, Pytree) + + assert meta_tree.a["bijector"] == Identity + assert meta_tree.a["trainable"] == True + assert meta_tree.b["bijector"] == Softplus + assert meta_tree.b["trainable"] == True + assert meta_tree.sub_tree.c["bijector"] == Identity + assert meta_tree.sub_tree.c["trainable"] == True + assert meta_tree.sub_tree.d["bijector"] == Softplus + assert meta_tree.sub_tree.d["trainable"] == True + assert meta_tree.sub_tree.e["bijector"] == Softplus + assert meta_tree.sub_tree.e["trainable"] == True + + # Test constrain and unconstrain + constrained = tree.constrain() + + assert isinstance(constrained, Module) + assert isinstance(constrained, Pytree) + + assert constrained.a == Identity.forward(1.0) + assert constrained.b == Softplus.forward(5.0) + assert constrained.sub_tree.c == Identity.forward(2.0) + assert constrained.sub_tree.d == Softplus.forward(3.0) + assert constrained.sub_tree.e == Softplus.forward(4.0) + + meta_constrained = meta(constrained) + + assert isinstance(meta_constrained, Module) + assert isinstance(meta_constrained, Pytree) + + assert meta_constrained.a["bijector"] == Identity + assert meta_constrained.a["trainable"] == True + assert meta_constrained.b["bijector"] == Softplus + assert meta_constrained.b["trainable"] == True + assert meta_constrained.sub_tree.c["bijector"] == Identity + assert meta_constrained.sub_tree.c["trainable"] == True + assert meta_constrained.sub_tree.d["bijector"] == Softplus + assert meta_constrained.sub_tree.d["trainable"] == True + assert meta_constrained.sub_tree.e["bijector"] == Softplus + assert meta_constrained.sub_tree.e["trainable"] == True + + # Test constrain and unconstrain + unconstrained = tree.unconstrain() + + assert isinstance(unconstrained, Module) + assert isinstance(unconstrained, Pytree) + + assert unconstrained.a == Identity.inverse(1.0) + assert unconstrained.b == Softplus.inverse(5.0) + assert unconstrained.sub_tree.c == Identity.inverse(2.0) + assert unconstrained.sub_tree.d == Softplus.inverse(3.0) + assert unconstrained.sub_tree.e == Softplus.inverse(4.0) + + meta_unconstrained = meta(unconstrained) + + assert isinstance(meta_unconstrained, Module) + assert isinstance(meta_unconstrained, Pytree) + + assert meta_unconstrained.a["bijector"] == Identity + assert meta_unconstrained.a["trainable"] == True + assert meta_unconstrained.b["bijector"] == Softplus + assert meta_unconstrained.b["trainable"] == True + assert meta_unconstrained.sub_tree.c["bijector"] == Identity + assert meta_unconstrained.sub_tree.c["trainable"] == True + assert meta_unconstrained.sub_tree.d["bijector"] == Softplus + assert meta_unconstrained.sub_tree.d["trainable"] == True + assert meta_unconstrained.sub_tree.e["bijector"] == Softplus + assert meta_unconstrained.sub_tree.e["trainable"] == True + + # Test updating metadata + + new_subtree = tree.sub_tree.replace_bijector(c=Softplus, e=Identity) + new_subtree = new_subtree.replace_trainable(c=False, e=False) + + new_tree = tree.replace_bijector(b=Identity) + new_tree = new_tree.replace_trainable(b=False) + new_tree = new_tree.replace(sub_tree=new_subtree) + + assert isinstance(new_tree, Module) + assert isinstance(new_tree, Pytree) + + assert new_tree.a == 1.0 + assert new_tree.b == 5.0 + assert new_tree.sub_tree.c == 2.0 + assert new_tree.sub_tree.d == 3.0 + assert new_tree.sub_tree.e == 4.0 + + meta_new_tree = meta(new_tree) + + assert isinstance(meta_new_tree, Module) + assert isinstance(meta_new_tree, Pytree) + + assert meta_new_tree.a["bijector"] == Identity + assert meta_new_tree.a["trainable"] == True + assert meta_new_tree.b["bijector"] == Identity + assert meta_new_tree.b["trainable"] == False + assert meta_new_tree.sub_tree.c["bijector"] == Softplus + assert meta_new_tree.sub_tree.c["trainable"] == False + assert meta_new_tree.sub_tree.d["bijector"] == Softplus + assert meta_new_tree.sub_tree.d["trainable"] == True + assert meta_new_tree.sub_tree.e["bijector"] == Identity + assert meta_new_tree.sub_tree.e["trainable"] == False + + # Test stop gradients + def loss(tree): + t = tree.stop_gradient() + return jnp.sum( + t.a**2 + + t.sub_tree.c**2 + + t.sub_tree.d**2 + + t.sub_tree.e**2 + + t.b**2 + ) + + g = jax.grad(loss)(new_tree) + + assert g.a == 2.0 + assert g.sub_tree.c == 0.0 + assert g.sub_tree.d == 6.0 + assert g.sub_tree.e == 0.0 + assert g.b == 0.0 + + +@pytest.mark.parametrize("is_dataclass", [True, False]) +@pytest.mark.parametrize("iterable", [list, tuple]) +def test_iterable_attribute(is_dataclass, iterable): + class SubTree(Module): + a: int = param_field(bijector=Identity, default=1) + b: int = param_field(bijector=Softplus, default=2) + c: int = param_field(bijector=Identity, default=3, trainable=False) + + def __init__(self, a=1.0, b=2.0, c=3.0): + self.a = a + self.b = b + self.c = c + + class Tree(Module): + trees: Iterable + + def __init__(self, trees): + self.trees = trees + + if is_dataclass: + SubTree = dataclass(SubTree) + Tree = dataclass(Tree) + + tree = Tree(iterable([SubTree(), SubTree(), SubTree()])) + + assert isinstance(tree, Module) + assert isinstance(tree, Pytree) + + assert tree.trees[0].a == 1.0 + assert tree.trees[0].b == 2.0 + assert tree.trees[0].c == 3.0 + + assert tree.trees[1].a == 1.0 + assert tree.trees[1].b == 2.0 + assert tree.trees[1].c == 3.0 + + assert tree.trees[2].a == 1.0 + assert tree.trees[2].b == 2.0 + assert tree.trees[2].c == 3.0 + + meta_tree = meta(tree) + + assert isinstance(meta_tree, Module) + assert isinstance(meta_tree, Pytree) + + assert meta_tree.trees[0].a["bijector"] == Identity + assert meta_tree.trees[0].a["trainable"] == True + assert meta_tree.trees[0].b["bijector"] == Softplus + assert meta_tree.trees[0].b["trainable"] == True + assert meta_tree.trees[0].c["bijector"] == Identity + assert meta_tree.trees[0].c["trainable"] == False + + assert meta_tree.trees[1].a["bijector"] == Identity + assert meta_tree.trees[1].a["trainable"] == True + assert meta_tree.trees[1].b["bijector"] == Softplus + assert meta_tree.trees[1].b["trainable"] == True + assert meta_tree.trees[1].c["bijector"] == Identity + assert meta_tree.trees[1].c["trainable"] == False + + assert meta_tree.trees[2].a["bijector"] == Identity + assert meta_tree.trees[2].a["trainable"] == True + assert meta_tree.trees[2].b["bijector"] == Softplus + assert meta_tree.trees[2].b["trainable"] == True + assert meta_tree.trees[2].c["bijector"] == Identity + assert meta_tree.trees[2].c["trainable"] == False + + # Test constrain and unconstrain + + constrained_tree = tree.constrain() + unconstrained_tree = tree.unconstrain() + + assert jtu.tree_structure(unconstrained_tree) == jtu.tree_structure(tree) + assert jtu.tree_structure(constrained_tree) == jtu.tree_structure(tree) + + assert isinstance(constrained_tree, Module) + assert isinstance(constrained_tree, Pytree) + + assert isinstance(unconstrained_tree, Module) + assert isinstance(unconstrained_tree, Pytree) + + assert constrained_tree.trees[0].a == Identity.forward(1.0) + assert constrained_tree.trees[0].b == Softplus.forward(2.0) + assert constrained_tree.trees[0].c == Identity.forward(3.0) + + assert constrained_tree.trees[1].a == Identity.forward(1.0) + assert constrained_tree.trees[1].b == Softplus.forward(2.0) + assert constrained_tree.trees[1].c == Identity.forward(3.0) + + assert constrained_tree.trees[2].a == Identity.forward(1.0) + assert constrained_tree.trees[2].b == Softplus.forward(2.0) + assert constrained_tree.trees[2].c == Identity.forward(3.0) + + assert unconstrained_tree.trees[0].a == Identity.inverse(1.0) + assert unconstrained_tree.trees[0].b == Softplus.inverse(2.0) + assert unconstrained_tree.trees[0].c == Identity.inverse(3.0) + + assert unconstrained_tree.trees[1].a == Identity.inverse(1.0) + assert unconstrained_tree.trees[1].b == Softplus.inverse(2.0) + assert unconstrained_tree.trees[1].c == Identity.inverse(3.0) + + assert unconstrained_tree.trees[2].a == Identity.inverse(1.0) + assert unconstrained_tree.trees[2].b == Softplus.inverse(2.0) + assert unconstrained_tree.trees[2].c == Identity.inverse(3.0) + + +# The following tests are adapted from equinox 🏴‍☠️ + + +def test_Module_not_enough_attributes(): + @dataclass + class Tree1(Module): + weight: Any = param_field(bijector=Identity) + + with pytest.raises(TypeError): + Tree1() + + @dataclass + class Tree2(Module): + weight: Any = param_field(bijector=Identity) + + def __init__(self): + return None + + with pytest.raises(TypeError): + Tree2(1) + + +def test_Module_too_many_attributes(): + @dataclass + class Tree1(Module): + weight: Any = param_field(bijector=Identity) + + with pytest.raises(TypeError): + Tree1(1, 2) + + +def test_Module_setattr_after_init(): + @dataclass + class Tree(Module): + weight: Any = param_field(bijector=Identity) + + m = Tree(1) + with pytest.raises(AttributeError): + m.asdf = True + + +# The main part of this test is to check that __init__ works correctly. +def test_inheritance(): + # no custom init / no custom init + + @dataclass + class Tree(Module): + weight: Any = param_field(bijector=Identity) + + @dataclass + class Tree2(Tree): + weight2: Any = param_field(bijector=Identity) + + m = Tree2(1, 2) + assert m.weight == 1 + assert m.weight2 == 2 + m = Tree2(1, weight2=2) + assert m.weight == 1 + assert m.weight2 == 2 + m = Tree2(weight=1, weight2=2) + assert m.weight == 1 + assert m.weight2 == 2 + with pytest.raises(TypeError): + m = Tree2(2, weight=2) + + # not custom init / custom init + + @dataclass + class Tree3(Tree): + weight3: Any = param_field(bijector=Identity) + + def __init__(self, *, weight3, **kwargs): + self.weight3 = weight3 + super().__init__(**kwargs) + + m = Tree3(weight=1, weight3=3) + assert m.weight == 1 + assert m.weight3 == 3 + + # custom init / no custom init + + @dataclass + class Tree4(Module): + weight4: Any = param_field(bijector=Identity) + + @dataclass + class Tree5(Tree4): + weight5: Any = param_field(bijector=Identity) + + with pytest.raises(TypeError): + m = Tree5(value4=1, weight5=2) + + @dataclass + class Tree6(Tree4): + pass + + m = Tree6(weight4=1) + assert m.weight4 == 1 + + # custom init / custom init + + @dataclass + class Tree7(Tree4): + weight7: Any = param_field(bijector=Identity) + + def __init__(self, value7, **kwargs): + self.weight7 = value7 + super().__init__(**kwargs) + + m = Tree7(weight4=1, value7=2) + assert m.weight4 == 1 + assert m.weight7 == 2 + + +def test_static_field(): + @dataclass + class Tree(Module): + field1: int = param_field(bijector=Identity) + field2: int = static_field() + field3: int = static_field(default=3) + + m = Tree(1, 2) + flat, treedef = jtu.tree_flatten(m) + assert len(flat) == 1 + assert flat[0] == 1 + rm = jtu.tree_unflatten(treedef, flat) + assert rm.field1 == 1 + assert rm.field2 == 2 + assert rm.field3 == 3 + + +def test_init_subclass(): + ran = [] + + @dataclass + class Tree(Module): + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + ran.append(True) + + @dataclass + class AnotherModule(Tree): + pass + + assert ran == [True] + + +# Taken from simple-pytree version = 0.1.6 🏴‍☠️ + + +class TestPytree: + def test_immutable_pytree(self): + class Foo(Module): + x: int = static_field() + y: int + + def __init__(self, y) -> None: + self.x = 2 + self.y = y + + pytree = Foo(y=3) + + leaves = jax.tree_util.tree_leaves(pytree) + assert leaves == [3] + + pytree = jax.tree_map(lambda x: x * 2, pytree) + assert pytree.x == 2 + assert pytree.y == 6 + + pytree = pytree.replace(x=3) + assert pytree.x == 3 + assert pytree.y == 6 + + with pytest.raises( + AttributeError, match="is immutable, trying to update field" + ): + pytree.x = 4 + + def test_immutable_pytree_dataclass(self): + @dataclasses.dataclass(frozen=True) + class Foo(Module): + y: int = field() + x: int = static_field(2) + + pytree = Foo(y=3) + + leaves = jax.tree_util.tree_leaves(pytree) + assert leaves == [3] + + pytree = jax.tree_map(lambda x: x * 2, pytree) + assert pytree.x == 2 + assert pytree.y == 6 + + pytree = pytree.replace(x=3) + assert pytree.x == 3 + assert pytree.y == 6 + + with pytest.raises(AttributeError, match="cannot assign to field"): + pytree.x = 4 + + def test_jit(self): + @dataclasses.dataclass + class Foo(Module): + a: int + b: int = static_field() + + module = Foo(a=1, b=2) + + @jax.jit + def f(m: Foo): + return m.a + m.b + + assert f(module) == 3 + + def test_flax_serialization(self): + class Bar(Module): + a: int = static_field() + b: int + + def __init__(self, a, b): + self.a = a + self.b = b + + @dataclasses.dataclass + class Foo(Module): + bar: Bar + c: int + d: int = static_field() + + foo: Foo = Foo(bar=Bar(a=1, b=2), c=3, d=4) + + state_dict = serialization.to_state_dict(foo) + + assert state_dict == { + "bar": { + "b": 2, + }, + "c": 3, + } + + state_dict["bar"]["b"] = 5 + + foo = serialization.from_state_dict(foo, state_dict) + + assert foo.bar.b == 5 + + del state_dict["bar"]["b"] + + with pytest.raises(ValueError, match="Missing field"): + serialization.from_state_dict(foo, state_dict) + + state_dict["bar"]["b"] = 5 + + # add unknown field + state_dict["x"] = 6 + + with pytest.raises(ValueError, match="Unknown field"): + serialization.from_state_dict(foo, state_dict) + + def test_generics(self): + T = TypeVar("T") + + class MyClass(Module, Generic[T]): + def __init__(self, x: T): + self.x = x + + MyClass[int] + + def test_key_paths(self): + @dataclasses.dataclass + class Bar(Module): + a: int = 1 + b: int = static_field(2) + + @dataclasses.dataclass + class Foo(Module): + x: int = 3 + y: int = static_field(4) + z: Bar = field(default_factory=Bar) + + foo = Foo() + + path_values, treedef = jax.tree_util.tree_flatten_with_path(foo) + path_values = [(list(map(str, path)), value) for path, value in path_values] + + assert path_values[0] == ([".x"], 3) + assert path_values[1] == ([".z", ".a"], 1) + + def test_setter_attribute_allowed(self): + n = None + + class SetterDescriptor: + def __set__(self, _, value): + nonlocal n + n = value + + class Foo(Module): + x: int = SetterDescriptor() + + foo = Foo() + foo.x = 1 + + assert n == 1 + + with pytest.raises(AttributeError, match=r"<.*> is immutable"): + foo.y = 2 + + def test_replace_unknown_fields_error(self): + class Foo(Module): + pass + + with pytest.raises(ValueError, match="'y' is not a field of Foo"): + Foo().replace(y=1) + + def test_dataclass_inheritance(self): + @dataclasses.dataclass + class A(Module): + a: int = 1 + b: int = static_field(2) + + @dataclasses.dataclass + class B(A): + c: int = 3 + + pytree = B() + leaves = jax.tree_util.tree_leaves(pytree) + assert leaves == [1, 3] + + +class TestMutablePytree: + def test_pytree(self): + class Foo(Module, mutable=True): + x: int = static_field() + y: int + + def __init__(self, y) -> None: + self.x = 2 + self.y = y + + pytree = Foo(y=3) + + leaves = jax.tree_util.tree_leaves(pytree) + assert leaves == [3] + + pytree = jax.tree_map(lambda x: x * 2, pytree) + assert pytree.x == 2 + assert pytree.y == 6 + + pytree = pytree.replace(x=3) + assert pytree.x == 3 + assert pytree.y == 6 + + # test mutation + pytree.x = 4 + assert pytree.x == 4 + + def test_pytree_dataclass(self): + @dataclasses.dataclass + class Foo(Module, mutable=True): + y: int = field() + x: int = static_field(2) + + pytree: Foo = Foo(y=3) + + leaves = jax.tree_util.tree_leaves(pytree) + assert leaves == [3] + + pytree = jax.tree_map(lambda x: x * 2, pytree) + assert pytree.x == 2 + assert pytree.y == 6 + + pytree = pytree.replace(x=3) + assert pytree.x == 3 + assert pytree.y == 6 + + # test mutation + pytree.x = 4 + assert pytree.x == 4 \ No newline at end of file diff --git a/tests/test_param.py b/tests/test_param.py new file mode 100644 index 000000000..0dffb163f --- /dev/null +++ b/tests/test_param.py @@ -0,0 +1,41 @@ +import dataclasses + +import pytest + +from mytree import Identity, Softplus, param_field + + +@pytest.mark.parametrize("bijector", [Identity, Softplus]) +@pytest.mark.parametrize("trainable", [True, False]) +def test_param(bijector, trainable): + param_field_ = param_field(bijector=bijector, trainable=trainable) + assert isinstance(param_field_, dataclasses.Field) + assert param_field_.metadata["bijector"] == bijector + assert param_field_.metadata["trainable"] == trainable + + with pytest.raises(ValueError): + param_field( + bijector=bijector, trainable=trainable, metadata={"trainable": trainable} + ) + + with pytest.raises(ValueError): + param_field( + bijector=bijector, trainable=trainable, metadata={"bijector": bijector} + ) + + with pytest.raises(ValueError): + param_field( + bijector=bijector, + trainable=trainable, + metadata={"bijector": Softplus, "trainable": trainable}, + ) + + with pytest.raises(ValueError): + param_field( + bijector=bijector, trainable=trainable, metadata={"pytree_node": True} + ) + + with pytest.raises(ValueError): + param_field( + bijector=bijector, trainable=trainable, metadata={"pytree_node": False} + ) From fb2d6d1571a6c37d8918df74cfa761da9124e6cd Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Mon, 27 Mar 2023 17:45:56 +0100 Subject: [PATCH 02/44] Add stuff. --- gpjax/kernels/base.py | 257 +++++-------- gpjax/kernels/computations/base.py | 76 +--- gpjax/kernels/computations/basis_functions.py | 68 +--- .../kernels/computations/constant_diagonal.py | 57 +-- gpjax/kernels/computations/dense.py | 20 +- gpjax/kernels/computations/diagonal.py | 39 +- gpjax/kernels/computations/eigen.py | 59 +-- gpjax/kernels/nonstationary/linear.py | 50 +-- gpjax/kernels/nonstationary/polynomial.py | 50 +-- gpjax/kernels/stationary/rbf.py | 49 +-- tests/test_kernels/test_approximations.py | 8 +- tests/test_kernels/test_base.py | 2 +- tests/test_kernels/test_computation.py | 15 +- tests/test_kernels/test_non_euclidean.py | 2 +- tests/test_kernels/test_nonstationary.py | 65 ++-- tests/test_kernels/test_stationary.py | 353 +++++++++--------- 16 files changed, 421 insertions(+), 749 deletions(-) diff --git a/gpjax/kernels/base.py b/gpjax/kernels/base.py index 6014aa7db..725d3e022 100644 --- a/gpjax/kernels/base.py +++ b/gpjax/kernels/base.py @@ -13,100 +13,78 @@ # limitations under the License. # ============================================================================== -import abc -from typing import Callable, Dict, List, Optional, Sequence +from __future__ import annotations -import deprecation +import abc import jax.numpy as jnp -import jax.random -import jax -from jax.random import KeyArray +from typing import List, Callable, Union from jaxtyping import Array, Float -from jaxutils import PyTree +from functools import partial +from mytree import Mytree, param_field +from simple_pytree import static_field +from dataclasses import dataclass +from functools import partial from .computations import AbstractKernelComputation, DenseKernelComputation -import distrax as dx - -########################################## -# Abtract classes -########################################## -class AbstractKernel(PyTree): - """ - Base kernel class""" - - def __init__( - self, - compute_engine: AbstractKernelComputation = DenseKernelComputation, - active_dims: Optional[List[int]] = None, - spectral_density: Optional[dx.Distribution] = None, - name: Optional[str] = "AbstractKernel", - ) -> None: - self._compute_engine = compute_engine - self.active_dims = active_dims - self.spectral_density = spectral_density - self.name = name - self._stationary = False - self.ndims = 1 if not self.active_dims else len(self.active_dims) - compute_engine = self.compute_engine(kernel_fn=self.__call__) - self.gram = compute_engine.gram - self.cross_covariance = compute_engine.cross_covariance +@dataclass +class AbstractKernel(Mytree): + """Base kernel class.""" + compute_engine: AbstractKernelComputation = static_field(DenseKernelComputation) + active_dims: List[int] = static_field(None) @property - def stationary(self) -> bool: - """Boolean property as to whether the kernel is stationary or not. + def ndims(self): + return 1 if not self.active_dims else len(self.active_dims) - Returns: - bool: True if the kernel is stationary. - """ - return self._stationary + def cross_covariance(self, x: Float[Array, "N D"], y: Float[Array, "M D"]): + return self.compute_engine(self).cross_covariance(x, y) - @property - def compute_engine(self) -> AbstractKernelComputation: - """The compute engine that is used to perform the kernel computations. + def gram(self, x: Float[Array, "N D"]): + return self.compute_engine(self).gram(x) + + def slice_input(self, x: Float[Array, "N D"]) -> Float[Array, "N S"]: + """Select the relevant columns of the supplied matrix to be used within the kernel's evaluation. + Args: + x (Float[Array, "N D"]): The matrix or vector that is to be sliced. Returns: - AbstractKernelComputation: The compute engine that is used to perform the kernel computations. + Float[Array, "N S"]: A sliced form of the input matrix. """ - return self._compute_engine - - @compute_engine.setter - def compute_engine(self, compute_engine: AbstractKernelComputation) -> None: - self._compute_engine = compute_engine - compute_engine = self.compute_engine(kernel_fn=self.__call__) - self.gram = compute_engine.gram - self.cross_covariance = compute_engine.cross_covariance + return x[..., self.active_dims] @abc.abstractmethod def __call__( self, - params: Dict, - x: Float[Array, "1 D"], - y: Float[Array, "1 D"], + x: Float[Array, "D"], + y: Float[Array, "D"], ) -> Float[Array, "1"]: """Evaluate the kernel on a pair of inputs. Args: - params (Dict): Parameter set for which the kernel should be evaluated on. - x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. - y (Float[Array, "1 D"]): The right hand argument of the kernel function's call + x (Float[Array, "D"]): The left hand input of the kernel function. + y (Float[Array, "D"]): The right hand input of the kernel function. Returns: - Float[Array, "1"]: The value of :math:`k(x, y)`. + Float[Array, "1"]: The evaluated kernel function at the supplied inputs. """ raise NotImplementedError - def slice_input(self, x: Float[Array, "N D"]) -> Float[Array, "N Q"]: - """Select the relevant columns of the supplied matrix to be used within the kernel's evaluation. - + def __add__(self, other: Union[AbstractKernel, Float[Array, "1"]]) -> AbstractKernel: + """Add two kernels together. Args: - x (Float[Array, "N D"]): The matrix or vector that is to be sliced. + other (AbstractKernel): The kernel to be added to the current kernel. + Returns: - Float[Array, "N Q"]: A sliced form of the input matrix. + AbstractKernel: A new kernel that is the sum of the two kernels. """ - return x[..., self.active_dims] - def __add__(self, other: "AbstractKernel") -> "AbstractKernel": + if isinstance(other, AbstractKernel): + return SumKernel([self, other]) + + return SumKernel([self, Constant(other)]) + + def __radd__(self, other: Union[AbstractKernel, Float[Array, "1"]]) -> AbstractKernel: """Add two kernels together. Args: other (AbstractKernel): The kernel to be added to the current kernel. @@ -114,9 +92,9 @@ def __add__(self, other: "AbstractKernel") -> "AbstractKernel": Returns: AbstractKernel: A new kernel that is the sum of the two kernels. """ - return SumKernel(kernel_set=[self, other]) + return self.__add__(other) - def __mul__(self, other: "AbstractKernel") -> "AbstractKernel": + def __mul__(self, other: Union[AbstractKernel, Float[Array, "1"]]) -> AbstractKernel: """Multiply two kernels together. Args: @@ -125,131 +103,70 @@ def __mul__(self, other: "AbstractKernel") -> "AbstractKernel": Returns: AbstractKernel: A new kernel that is the product of the two kernels. """ - return ProductKernel(kernel_set=[self, other]) + if isinstance(other, AbstractKernel): + return ProductKernel([self, other]) + + return ProductKernel([self, Constant(other)]) - @property - def ard(self): - """Boolean property as to whether the kernel is isotropic or of - automatic relevance determination form. - Returns: - bool: True if the kernel is an ARD kernel. - """ - return True if self.ndims > 1 else False - - @abc.abstractmethod - def init_params(self, key: KeyArray) -> Dict: - """A template dictionary of the kernel's parameter set. - - Args: - key (KeyArray): A PRNG key to be used for initialising - the kernel's parameters. - - Returns: - Dict: A dictionary of the kernel's parameters. - """ - raise NotImplementedError +@dataclass +class Constant(AbstractKernel): + """ + A constant mean function. This function returns a repeated scalar value for all inputs. + The scalar value itself can be treated as a model hyperparameter and learned during training. + """ + constant: Float[Array, "1"] = param_field(jnp.array(0.0)) - @deprecation.deprecated( - deprecated_in="0.0.3", - removed_in="0.1.0", - ) - def _initialise_params(self, key: KeyArray) -> Dict: - """A template dictionary of the kernel's parameter set. + def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "1"]: + """Evaluate the kernel on a pair of inputs. Args: - key (KeyArray): A PRNG key to be used for initialising - the kernel's parameters. + x (Float[Array, "D"]): The left hand input of the kernel function. + y (Float[Array, "D"]): The right hand input of the kernel function. Returns: - Dict: A dictionary of the kernel's parameters. + Float[Array, "1"]: The evaluated kernel function at the supplied inputs. """ - raise NotImplementedError + return self.constant.squeeze() +@dataclass class CombinationKernel(AbstractKernel): - """A base class for products or sums of kernels.""" - - def __init__( - self, - kernel_set: List[AbstractKernel], - compute_engine: AbstractKernelComputation = DenseKernelComputation, - active_dims: Optional[List[int]] = None, - name: Optional[str] = "AbstractKernel", - ) -> None: - super().__init__(compute_engine, active_dims, name) - self.kernel_set = kernel_set - name: Optional[str] = "Combination kernel" - self.combination_fn: Optional[Callable] = None - - if not all(isinstance(k, AbstractKernel) for k in self.kernel_set): - raise TypeError("can only combine Kernel instances") # pragma: no cover - if all(k.stationary for k in self.kernel_set): - self._stationary = True - self._set_kernels(self.kernel_set) - - def _set_kernels(self, kernels: Sequence[AbstractKernel]) -> None: - """Combine multiple kernels. Based on GPFlow's Combination kernel.""" - # add kernels to a list, flattening out instances of this class therein + """A base class for products or sums of MeanFunctions.""" + kernels: List[AbstractKernel] = None + operator: Callable = static_field(None) + + def __post_init__(self): + #Add kernels to a list, flattening out instances of this class therein, as in GPFlow kernels. kernels_list: List[AbstractKernel] = [] - for k in kernels: - if isinstance(k, self.__class__): - kernels_list.extend(k.kernel_set) - else: - kernels_list.append(k) + + for kernel in self.kernels: + if not isinstance(kernel, AbstractKernel): + raise TypeError("can only combine Kernel instances") # pragma: no cover - self.kernel_set = kernels_list + if isinstance(kernel, self.__class__): + kernels_list.extend(kernel.kernels) + else: + kernels_list.append(kernel) - def init_params(self, key: KeyArray) -> Dict: - """A template dictionary of the kernel's parameter set.""" - num_kernels = len(self.kernel_set) - key_per_kernel = jax.random.split(key=key, num=num_kernels) - return [kernel.init_params(key_) for key_, kernel in zip(key_per_kernel, self.kernel_set)] + self.kernels = kernels_list def __call__( self, - params: Dict, - x: Float[Array, "1 D"], - y: Float[Array, "1 D"], + x: Float[Array, "D"], + y: Float[Array, "D"], ) -> Float[Array, "1"]: - """Evaluate combination kernel on a pair of inputs. + """Evaluate the kernel on a pair of inputs. Args: - params (Dict): Parameter set for which the kernel should be evaluated on. - x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. - y (Float[Array, "1 D"]): The right hand argument of the kernel function's call + x (Float[Array, "D"]): The left hand input of the kernel function. + y (Float[Array, "D"]): The right hand input of the kernel function. Returns: - Float[Array, "1"]: The value of :math:`k(x, y)`. + Float[Array, "1"]: The evaluated kernel function at the supplied inputs. """ - return self.combination_fn( - jnp.stack([k(p, x, y) for k, p in zip(self.kernel_set, params)]) - ) - + return self.operator(jnp.stack([k(x, y) for k in self.kernels])) + -class SumKernel(CombinationKernel): - """A kernel that is the sum of a set of kernels.""" - - def __init__( - self, - kernel_set: List[AbstractKernel], - compute_engine: AbstractKernelComputation = DenseKernelComputation, - active_dims: Optional[List[int]] = None, - name: Optional[str] = "Sum kernel", - ) -> None: - super().__init__(kernel_set, compute_engine, active_dims, name) - self.combination_fn: Optional[Callable] = jnp.sum - - -class ProductKernel(CombinationKernel): - """A kernel that is the product of a set of kernels.""" - - def __init__( - self, - kernel_set: List[AbstractKernel], - compute_engine: AbstractKernelComputation = DenseKernelComputation, - active_dims: Optional[List[int]] = None, - name: Optional[str] = "Product kernel", - ) -> None: - super().__init__(kernel_set, compute_engine, active_dims, name) - self.combination_fn: Optional[Callable] = jnp.prod +SumKernel = partial(CombinationKernel, operator=jnp.sum) +ProductKernel = partial(CombinationKernel, operator=jnp.sum) \ No newline at end of file diff --git a/gpjax/kernels/computations/base.py b/gpjax/kernels/computations/base.py index d8584a70c..718a4c65c 100644 --- a/gpjax/kernels/computations/base.py +++ b/gpjax/kernels/computations/base.py @@ -14,103 +14,61 @@ # ============================================================================== import abc -from typing import Callable, Dict - +from typing import Any from jax import vmap -from jaxtyping import Array, Float -from jaxutils import PyTree - -from ...linops import ( +from gpjax.linops import ( DenseLinearOperator, DiagonalLinearOperator, LinearOperator, ) +from jaxtyping import Array, Float +from dataclasses import dataclass +Kernel = Any -class AbstractKernelComputation(PyTree): +@dataclass +class AbstractKernelComputation: """Abstract class for kernel computations.""" - - def __init__( - self, - kernel_fn: Callable[ - [Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array - ] = None, - ) -> None: - self._kernel_fn = kernel_fn - - @property - def kernel_fn( - self, - ) -> Callable[[Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array]: - return self._kernel_fn - - @kernel_fn.setter - def kernel_fn( - self, - kernel_fn: Callable[[Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array], - ) -> None: - self._kernel_fn = kernel_fn + kernel: Kernel def gram( self, - params: Dict, - inputs: Float[Array, "N D"], + x: Float[Array, "N D"], ) -> LinearOperator: """Compute Gram covariance operator of the kernel function. Args: - kernel (AbstractKernel): The kernel function to be evaluated. - params (Dict): The parameters of the kernel function. - inputs (Float[Array, "N N"]): The inputs to the kernel function. + x (Float[Array, "N N"]): The inputs to the kernel function. Returns: LinearOperator: Gram covariance operator of the kernel function. """ - - matrix = self.cross_covariance(params, inputs, inputs) - - return DenseLinearOperator(matrix=matrix) + Kxx = self.cross_covariance(x, x) + return DenseLinearOperator(Kxx) @abc.abstractmethod - def cross_covariance( - self, - params: Dict, - x: Float[Array, "N D"], - y: Float[Array, "M D"], - ) -> Float[Array, "N M"]: + def cross_covariance(self, x: Float[Array, "N D"], y: Float[Array, "M D"]) -> Float[Array, "N M"]: """For a given kernel, compute the NxM gram matrix on an a pair of input matrices with shape NxD and MxD. Args: - kernel (AbstractKernel): The kernel for which the cross-covariance - matrix should be computed for. - params (Dict): The kernel's parameter set. x (Float[Array,"N D"]): The first input matrix. y (Float[Array,"M D"]): The second input matrix. Returns: - Float[Array, "N M"]: The computed square Gram matrix. + Float[Array, "N M"]: The computed cross-covariance. """ raise NotImplementedError - def diagonal( - self, - params: Dict, - inputs: Float[Array, "N D"], - ) -> DiagonalLinearOperator: + def diagonal(self, inputs: Float[Array, "N D"]) -> DiagonalLinearOperator: """For a given kernel, compute the elementwise diagonal of the NxN gram matrix on an input matrix of shape NxD. Args: - kernel (AbstractKernel): The kernel for which the variance - vector should be computed for. - params (Dict): The kernel's parameter set. inputs (Float[Array, "N D"]): The input matrix. Returns: - LinearOperator: The computed diagonal variance entries. + DiagonalLinearOperator: The computed diagonal variance entries. """ - diag = vmap(lambda x: self._kernel_fn(params, x, x))(inputs) - - return DiagonalLinearOperator(diag=diag) + return DiagonalLinearOperator(diag=vmap(lambda x: self.kernel(x, x))(inputs)) diff --git a/gpjax/kernels/computations/basis_functions.py b/gpjax/kernels/computations/basis_functions.py index a128ea5a0..e0a057fed 100644 --- a/gpjax/kernels/computations/basis_functions.py +++ b/gpjax/kernels/computations/basis_functions.py @@ -1,40 +1,17 @@ -from typing import Callable, Dict - import jax.numpy as jnp from jaxtyping import Array, Float from .base import AbstractKernelComputation -from ...linops import DenseLinearOperator +from gpjax.linops import DenseLinearOperator + +from dataclasses import dataclass +@dataclass class BasisFunctionComputation(AbstractKernelComputation): """Compute engine class for finite basis function approximations to a kernel.""" + num_basis_fns = None - def __init__( - self, - kernel_fn: Callable[ - [Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array - ] = None, - ) -> None: - """Initialise the computation engine for a basis function approximation to a kernel. - - Args: - kernel_fn: A. The kernel function for which the compute engine is assigned to. - """ - super().__init__(kernel_fn) - self._num_basis_fns = None - - @property - def num_basis_fns(self) -> float: - """The number of basis functions used to approximate the kernel.""" - return self._num_basis_fns - - @num_basis_fns.setter - def num_basis_fns(self, num_basis_fns: int) -> None: - self._num_basis_fns = float(num_basis_fns) - - def cross_covariance( - self, params: Dict, x: Float[Array, "N D"], y: Float[Array, "M D"] - ) -> Float[Array, "N M"]: + def cross_covariance(self, x: Float[Array, "N D"], y: Float[Array, "M D"]) -> Float[Array, "N M"]: """For a pair of inputs, compute the cross covariance matrix between the inputs. Args: params (Dict): A dictionary of parameters for which the cross-covariance matrix should be constructed with. @@ -44,16 +21,12 @@ def cross_covariance( Returns: _type_: A N x M array of cross-covariances. """ - z1 = self.compute_features( - x, params["frequencies"], scaling_factor=params["lengthscale"] - ) - z2 = self.compute_features( - y, params["frequencies"], scaling_factor=params["lengthscale"] - ) + z1 = self.compute_features(x) + z2 = self.compute_features(y) z1 /= self.num_basis_fns - return params["variance"] * jnp.matmul(z1, z2.T) + return self.kernel.variance * jnp.matmul(z1, z2.T) - def gram(self, params: Dict, inputs: Float[Array, "N D"]) -> DenseLinearOperator: + def gram(self, inputs: Float[Array, "N D"]) -> DenseLinearOperator: """For the Gram matrix, we can save computations by computing only one matrix multiplication between the inputs and the scaled frequencies. Args: @@ -63,19 +36,12 @@ def gram(self, params: Dict, inputs: Float[Array, "N D"]) -> DenseLinearOperator Returns: DenseLinearOperator: A dense linear operator representing the N x N Gram matrix. """ - z1 = self.compute_features( - inputs, params["frequencies"], scaling_factor=params["lengthscale"] - ) + z1 = self.compute_features(inputs) matrix = jnp.matmul(z1, z1.T) # shape: (n_samples, n_samples) matrix /= self.num_basis_fns - return DenseLinearOperator(params["variance"] * matrix) + return DenseLinearOperator(self.kernel.variance * matrix) - @staticmethod - def compute_features( - x: Float[Array, "N D"], - frequencies: Float[Array, "M D"], - scaling_factor: Float[Array, "D"] = None, - ) -> Float[Array, "N L"]: + def compute_features(self, x: Float[Array, "N D"]) -> Float[Array, "N L"]: """Compute the features for the inputs. Args: @@ -85,8 +51,8 @@ def compute_features( Returns: Float[Array, "N L"]: A N x L array of features where L = 2M. """ - if scaling_factor is not None: - frequencies = frequencies / scaling_factor - z = jnp.matmul(x, frequencies.T) + frequencies = self.kernel.frequencies + scaling_factor = self.kernel.lengthscale + z = jnp.matmul(x, (frequencies / scaling_factor).T) z = jnp.concatenate([jnp.cos(z), jnp.sin(z)], axis=-1) - return z + return z \ No newline at end of file diff --git a/gpjax/kernels/computations/constant_diagonal.py b/gpjax/kernels/computations/constant_diagonal.py index 5cfef9813..7d9fb2949 100644 --- a/gpjax/kernels/computations/constant_diagonal.py +++ b/gpjax/kernels/computations/constant_diagonal.py @@ -12,57 +12,29 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - -from typing import Callable, Dict import jax.numpy as jnp from jax import vmap -from jaxtyping import Array, Float -from .base import AbstractKernelComputation - -from ...linops import ( +from gpjax.linops import ( ConstantDiagonalLinearOperator, DiagonalLinearOperator, ) +from jaxtyping import Array, Float +from .base import AbstractKernelComputation class ConstantDiagonalKernelComputation(AbstractKernelComputation): - def __init__( - self, - kernel_fn: Callable[ - [Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array - ] = None, - ) -> None: - super().__init__(kernel_fn) - - def gram( - self, - params: Dict, - inputs: Float[Array, "N D"], - ) -> ConstantDiagonalLinearOperator: - """For a kernel with diagonal structure, compute the NxN gram matrix on - an input matrix of shape NxD. + def gram(self, x: Float[Array, "N D"]) -> ConstantDiagonalLinearOperator: + """Compute Gram covariance operator of the kernel function. Args: - kernel (AbstractKernel): The kernel for which the Gram matrix - should be computed for. - params (Dict): The kernel's parameter set. - inputs (Float[Array, "N D"]): The input matrix. - - Returns: - CovarianceOperator: The computed square Gram matrix. + x (Float[Array, "N N"]): The inputs to the kernel function. """ - value = self.kernel_fn(params, inputs[0], inputs[0]) + value = self.kernel(x[0], x[0]) - return ConstantDiagonalLinearOperator( - value=jnp.atleast_1d(value), size=inputs.shape[0] - ) + return ConstantDiagonalLinearOperator(value=jnp.atleast_1d(value), size=x.shape[0]) - def diagonal( - self, - params: Dict, - inputs: Float[Array, "N D"], - ) -> DiagonalLinearOperator: + def diagonal(self, inputs: Float[Array, "N D"]) -> DiagonalLinearOperator: """For a given kernel, compute the elementwise diagonal of the NxN gram matrix on an input matrix of shape NxD. @@ -76,20 +48,15 @@ def diagonal( LinearOperator: The computed diagonal variance entries. """ - diag = vmap(lambda x: self.kernel_fn(params, x, x))(inputs) + diag = vmap(lambda x: self.kernel(x, x))(inputs) return DiagonalLinearOperator(diag=diag) - def cross_covariance( - self, params: Dict, x: Float[Array, "N D"], y: Float[Array, "M D"] - ) -> Float[Array, "N M"]: + def cross_covariance(self, x: Float[Array, "N D"], y: Float[Array, "M D"]) -> Float[Array, "N M"]: """For a given kernel, compute the NxM covariance matrix on a pair of input matrices of shape NxD and MxD. Args: - kernel (AbstractKernel): The kernel for which the Gram - matrix should be computed for. - params (Dict): The kernel's parameter set. x (Float[Array,"N D"]): The input matrix. y (Float[Array,"M D"]): The input matrix. @@ -97,5 +64,5 @@ def cross_covariance( CovarianceOperator: The computed square Gram matrix. """ # TODO: This is currently a dense implementation. We should implement a sparse LinearOperator for non-square cross-covariance matrices. - cross_cov = vmap(lambda x: vmap(lambda y: self.kernel_fn(params, x, y))(y))(x) + cross_cov = vmap(lambda x: vmap(lambda y: self.kernel(x, y))(y))(x) return cross_cov diff --git a/gpjax/kernels/computations/dense.py b/gpjax/kernels/computations/dense.py index 7fb74b5ee..c64981feb 100644 --- a/gpjax/kernels/computations/dense.py +++ b/gpjax/kernels/computations/dense.py @@ -13,41 +13,27 @@ # limitations under the License. # ============================================================================== -from typing import Callable, Dict - from jax import vmap from jaxtyping import Array, Float from .base import AbstractKernelComputation - class DenseKernelComputation(AbstractKernelComputation): """Dense kernel computation class. Operations with the kernel assume a dense gram matrix structure. """ - def __init__( - self, - kernel_fn: Callable[ - [Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array - ] = None, - ) -> None: - super().__init__(kernel_fn) - def cross_covariance( - self, params: Dict, x: Float[Array, "N D"], y: Float[Array, "M D"] + self, x: Float[Array, "N D"], y: Float[Array, "M D"] ) -> Float[Array, "N M"]: """For a given kernel, compute the NxM covariance matrix on a pair of input matrices of shape NxD and MxD. Args: - kernel (AbstractKernel): The kernel for which the Gram - matrix should be computed for. - params (Dict): The kernel's parameter set. x (Float[Array,"N D"]): The input matrix. y (Float[Array,"M D"]): The input matrix. Returns: - CovarianceOperator: The computed square Gram matrix. + Float[Array, "N M"]: The computed cross-covariance. """ - cross_cov = vmap(lambda x: vmap(lambda y: self.kernel_fn(params, x, y))(y))(x) + cross_cov = vmap(lambda x: vmap(lambda y: self.kernel(x, y))(y))(x) return cross_cov diff --git a/gpjax/kernels/computations/diagonal.py b/gpjax/kernels/computations/diagonal.py index 5c2910005..fc14a257e 100644 --- a/gpjax/kernels/computations/diagonal.py +++ b/gpjax/kernels/computations/diagonal.py @@ -14,59 +14,38 @@ # ============================================================================== from jax import vmap -from typing import Callable, Dict +from gpjax.linops import ( + DiagonalLinearOperator, +) from jaxtyping import Array, Float - from .base import AbstractKernelComputation -from ...linops import DiagonalLinearOperator + class DiagonalKernelComputation(AbstractKernelComputation): - def __init__( - self, - kernel_fn: Callable[ - [Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array - ] = None, - ) -> None: - super().__init__(kernel_fn) - def gram( - self, - params: Dict, - inputs: Float[Array, "N D"], - ) -> DiagonalLinearOperator: + def gram(self, x: Float[Array, "N D"]) -> DiagonalLinearOperator: """For a kernel with diagonal structure, compute the NxN gram matrix on an input matrix of shape NxD. Args: - kernel (AbstractKernel): The kernel for which the Gram matrix - should be computed for. - params (Dict): The kernel's parameter set. inputs (Float[Array, "N D"]): The input matrix. Returns: CovarianceOperator: The computed square Gram matrix. """ + return DiagonalLinearOperator(diag=vmap(lambda x: self.kernel(x, x))(x)) - diag = vmap(lambda x: self.kernel_fn(params, x, x))(inputs) - - return DiagonalLinearOperator(diag=diag) - - def cross_covariance( - self, params: Dict, x: Float[Array, "N D"], y: Float[Array, "M D"] - ) -> Float[Array, "N M"]: + def cross_covariance(self, x: Float[Array, "N D"], y: Float[Array, "M D"]) -> Float[Array, "N M"]: """For a given kernel, compute the NxM covariance matrix on a pair of input matrices of shape NxD and MxD. Args: - kernel (AbstractKernel): The kernel for which the Gram - matrix should be computed for. - params (Dict): The kernel's parameter set. x (Float[Array,"N D"]): The input matrix. y (Float[Array,"M D"]): The input matrix. Returns: - CovarianceOperator: The computed square Gram matrix. + Float[Array, "N M"]: The computed cross-covariance. """ # TODO: This is currently a dense implementation. We should implement a sparse LinearOperator for non-square cross-covariance matrices. - cross_cov = vmap(lambda x: vmap(lambda y: self.kernel_fn(params, x, y))(y))(x) + cross_cov = vmap(lambda x: vmap(lambda y: self.kernel(x, y))(y))(x) return cross_cov diff --git a/gpjax/kernels/computations/eigen.py b/gpjax/kernels/computations/eigen.py index 094fc8d16..e80d2204c 100644 --- a/gpjax/kernels/computations/eigen.py +++ b/gpjax/kernels/computations/eigen.py @@ -13,59 +13,24 @@ # limitations under the License. # ============================================================================== -from typing import Callable, Dict +from typing import Dict import jax.numpy as jnp from jaxtyping import Array, Float from .base import AbstractKernelComputation +from dataclasses import dataclass - +@dataclass class EigenKernelComputation(AbstractKernelComputation): - def __init__( - self, - kernel_fn: Callable[ - [Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array - ] = None, - ) -> None: - super().__init__(kernel_fn) - self._eigenvalues = None - self._eigenvectors = None - self._num_verticies = None - - # Define an eigenvalue setter and getter property - @property - def eigensystem(self) -> Float[Array, "N"]: - return self._eigenvalues, self._eigenvectors, self._num_verticies - - @eigensystem.setter - def eigensystem( - self, eigenvalues: Float[Array, "N"], eigenvectors: Float[Array, "N N"] - ) -> None: - self._eigenvalues = eigenvalues - self._eigenvectors = eigenvectors - - @property - def num_vertex(self) -> int: - return self._num_verticies - - @num_vertex.setter - def num_vertex(self, num_vertex: int) -> None: - self._num_verticies = num_vertex - - def _compute_S(self, params): - evals, evecs = self.eigensystem - S = jnp.power( - evals - + 2 * params["smoothness"] / params["lengthscale"] / params["lengthscale"], - -params["smoothness"], + eigenvalues: Float[Array, "N"] = None + eigenvectors: Float[Array, "N N"] = None + num_verticies: int = None + + def cross_covariance(self, params: Dict, x: Float[Array, "N D"], y: Float[Array, "M D"]) -> Float[Array, "N M"]: + evals = self.eigenvalues + S = jnp.power(evals + 2 * self.kernel.smoothness / self.kernel.lengthscale / self.kernel.lengthscale, + - self.kernel.smoothness, ) S = jnp.multiply(S, self.num_vertex / jnp.sum(S)) S = jnp.multiply(S, params["variance"]) - return S - - def cross_covariance( - self, params: Dict, x: Float[Array, "N D"], y: Float[Array, "M D"] - ) -> Float[Array, "N M"]: - S = self._compute_S(params=params) - matrix = self.kernel_fn(params, x, y, S=S) - return matrix + return self.kernel(x, y, S=S) \ No newline at end of file diff --git a/gpjax/kernels/nonstationary/linear.py b/gpjax/kernels/nonstationary/linear.py index 9e5a2509a..226e56840 100644 --- a/gpjax/kernels/nonstationary/linear.py +++ b/gpjax/kernels/nonstationary/linear.py @@ -13,56 +13,38 @@ # limitations under the License. # ============================================================================== -from typing import Dict, List, Optional - -import jax import jax.numpy as jnp -from jax.random import KeyArray from jaxtyping import Array from ..base import AbstractKernel -from ..computations import ( - DenseKernelComputation, -) +from dataclasses import dataclass +from jaxtyping import Array, Float +from mytree import param_field, Softplus -########################################## -# Euclidean kernels -########################################## +@dataclass class Linear(AbstractKernel): """The linear kernel.""" + variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus) - def __init__( + def __call__( self, - active_dims: Optional[List[int]] = None, - stationary: Optional[bool] = False, - name: Optional[str] = "Linear", - ) -> None: - super().__init__( - DenseKernelComputation, - active_dims, - spectral_density=None, - name=name, - ) - self._stationary = False - - def __call__(self, params: dict, x: jax.Array, y: jax.Array) -> Array: - """Evaluate the kernel on a pair of inputs :math:`(x, y)` with variance parameter :math:`\\sigma` + x: Float[Array, "D"], + y: Float[Array, "D"], + ) -> Float[Array, "1"]: + """Evaluate the linear kernel on a pair of inputs :math:`(x, y)` with variance parameter :math:`\\sigma` .. math:: k(x, y) = \\sigma^2 x^{T}y Args: - x (jax.Array): The left hand argument of the kernel function's call. - y (jax.Array): The right hand argument of the kernel function's call - params (dict): Parameter set for which the kernel should be evaluated on. + x (Float[Array, "D"]): The left hand input of the kernel function. + y (Float[Array, "D"]): The right hand input of the kernel function. + Returns: - Array: The value of :math:`k(x, y)` + Float[Array, "1"]: The evaluated kernel function :math:`k(x, y)` at the supplied inputs. """ x = self.slice_input(x) y = self.slice_input(y) - K = params["variance"] * jnp.matmul(x.T, y) - return K.squeeze() - - def init_params(self, key: KeyArray) -> Dict: - return {"variance": jnp.array([1.0])} + K = self.variance * jnp.matmul(x.T, y) + return K.squeeze() \ No newline at end of file diff --git a/gpjax/kernels/nonstationary/polynomial.py b/gpjax/kernels/nonstationary/polynomial.py index 18bd581e2..582460cca 100644 --- a/gpjax/kernels/nonstationary/polynomial.py +++ b/gpjax/kernels/nonstationary/polynomial.py @@ -13,61 +13,35 @@ # limitations under the License. # ============================================================================== -from typing import Dict, List, Optional - import jax.numpy as jnp -from jax.random import KeyArray from jaxtyping import Array, Float - from ..base import AbstractKernel -from ..computations import ( - DenseKernelComputation, -) - +from dataclasses import dataclass +from simple_pytree import static_field +from mytree import param_field, Softplus +@dataclass class Polynomial(AbstractKernel): """The Polynomial kernel with variable degree.""" - def __init__( - self, - degree: int = 1, - active_dims: Optional[List[int]] = None, - stationary: Optional[bool] = False, - name: Optional[str] = "Polynomial", - ) -> None: - super().__init__( - DenseKernelComputation, - active_dims, - spectral_density=None, - name=name, - ) - self.degree = degree - self.name = f"Polynomial Degree: {self.degree}" - self._stationary = False + degree: int = static_field(2) + shift: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus) + variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus) - def __call__( - self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"] - ) -> Float[Array, "1"]: + def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "1"]: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with shift parameter :math:`\\alpha` and variance :math:`\\sigma^2` through .. math:: k(x, y) = \\Big( \\alpha + \\sigma^2 xy \\Big)^{d} Args: - params (Dict): Parameter set for which the kernel should be evaluated on. - x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. - y (Float[Array, "1 D"]): The right hand argument of the kernel function's call + x (Float[Array, "D"]): The left hand argument of the kernel function's call. + y (Float[Array, "D"]): The right hand argument of the kernel function's call Returns: Float[Array, "1"]: The value of :math:`k(x, y)`. """ x = self.slice_input(x).squeeze() y = self.slice_input(y).squeeze() - K = jnp.power(params["shift"] + jnp.dot(x * params["variance"], y), self.degree) - return K.squeeze() - - def init_params(self, key: KeyArray) -> Dict: - return { - "shift": jnp.array([1.0]), - "variance": jnp.array([1.0] * self.ndims), - } + K = jnp.power(self.shift + jnp.dot(x * self.variance, y), self.degree) + return K.squeeze() \ No newline at end of file diff --git a/gpjax/kernels/stationary/rbf.py b/gpjax/kernels/stationary/rbf.py index cc3e87b7f..3acb0608c 100644 --- a/gpjax/kernels/stationary/rbf.py +++ b/gpjax/kernels/stationary/rbf.py @@ -13,40 +13,23 @@ # limitations under the License. # ============================================================================== -from typing import Dict, List, Optional - -import jax import jax.numpy as jnp -from jax.random import KeyArray from jaxtyping import Array, Float from ..base import AbstractKernel -from ..computations import ( - DenseKernelComputation, -) from .utils import squared_distance import distrax as dx +from dataclasses import dataclass +from mytree import param_field, Softplus +@dataclass class RBF(AbstractKernel): """The Radial Basis Function (RBF) kernel.""" + lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=Softplus) + variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus) - def __init__( - self, - active_dims: Optional[List[int]] = None, - name: Optional[str] = "Radial basis function kernel", - ) -> None: - super().__init__( - DenseKernelComputation, - active_dims, - spectral_density=dx.Normal(loc=0.0, scale=1.0), - name=name, - ) - self._stationary = True - - def __call__( - self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"] - ) -> Float[Array, "1"]: + def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "1"]: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with lengthscale parameter :math:`\\ell` and variance :math:`\\sigma^2` @@ -55,20 +38,16 @@ def __call__( Args: params (Dict): Parameter set for which the kernel should be evaluated on. - x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. - y (Float[Array, "1 D"]): The right hand argument of the kernel function's call. + x (Float[Array, "D"]): The left hand argument of the kernel function's call. + y (Float[Array, "D"]): The right hand argument of the kernel function's call. Returns: Float[Array, "1"]: The value of :math:`k(x, y)`. """ - x = self.slice_input(x) / params["lengthscale"] - y = self.slice_input(y) / params["lengthscale"] - K = params["variance"] * jnp.exp(-0.5 * squared_distance(x, y)) + x = self.slice_input(x) / self.lengthscale + y = self.slice_input(y) / self.lengthscale + K = self.variance * jnp.exp(-0.5 * squared_distance(x, y)) return K.squeeze() - - def init_params(self, key: KeyArray) -> Dict: - params = { - "lengthscale": jnp.array([1.0] * self.ndims), - "variance": jnp.array([1.0]), - } - return jax.tree_util.tree_map(lambda x: jnp.atleast_1d(x), params) + + def spectral_density(self) -> dx.Normal: + return dx.Normal(loc=0.0, scale=1.0) diff --git a/tests/test_kernels/test_approximations.py b/tests/test_kernels/test_approximations.py index b48434626..9dbf855fe 100644 --- a/tests/test_kernels/test_approximations.py +++ b/tests/test_kernels/test_approximations.py @@ -1,6 +1,6 @@ import pytest -from gpjax.kernels.approximations import RFF -from gpjax.kernels.stationary import ( +from jaxkern.approximations import RFF +from jaxkern.stationary import ( Matern12, Matern32, Matern52, @@ -9,8 +9,8 @@ PoweredExponential, Periodic, ) -from gpjax.kernels.nonstationary import Polynomial, Linear -from gpjax.kernels.base import AbstractKernel +from jaxkern.nonstationary import Polynomial, Linear +from jaxkern.base import AbstractKernel import jax.random as jr from jax.config import config import jax.numpy as jnp diff --git a/tests/test_kernels/test_base.py b/tests/test_kernels/test_base.py index 13a41deb5..7e3dddce5 100644 --- a/tests/test_kernels/test_base.py +++ b/tests/test_kernels/test_base.py @@ -17,7 +17,7 @@ import jax.random as jr import pytest from jax.config import config -from gpjax.linops import identity +from jaxlinop import identity from gpjax.kernels.base import ( AbstractKernel, diff --git a/tests/test_kernels/test_computation.py b/tests/test_kernels/test_computation.py index 8aeec91b9..0e33c076b 100644 --- a/tests/test_kernels/test_computation.py +++ b/tests/test_kernels/test_computation.py @@ -1,12 +1,11 @@ import jax.numpy as jnp -import jax.random as jr import pytest from gpjax.kernels.computations import ( DiagonalKernelComputation, ConstantDiagonalKernelComputation, ) -from gpjax.kernels.stationary import ( +from jaxkern.stationary import ( RBF, Matern12, Matern32, @@ -34,16 +33,14 @@ ) def test_change_computation(kernel): x = jnp.linspace(-3.0, 3.0, 5).reshape(-1, 1) - key = jr.PRNGKey(123) - params = kernel.init_params(key) # The default computation is DenseKernelComputation - dense_matrix = kernel.gram(params, x).to_dense() + dense_matrix = kernel.gram(x).to_dense() dense_diagonals = jnp.diag(dense_matrix) # Let's now change the computation to DiagonalKernelComputation - kernel.compute_engine = DiagonalKernelComputation - diagonal_matrix = kernel.gram(params, x).to_dense() + kernel = kernel.replace(compute_engine = DiagonalKernelComputation) + diagonal_matrix = kernel.gram(x).to_dense() diag_entries = jnp.diag(diagonal_matrix) # The diagonal entries should be the same as the dense matrix @@ -53,8 +50,8 @@ def test_change_computation(kernel): assert jnp.allclose(diagonal_matrix - jnp.diag(diag_entries), 0.0) # Let's now change the computation to ConstantDiagonalKernelComputation - kernel.compute_engine = ConstantDiagonalKernelComputation - constant_diagonal_matrix = kernel.gram(params, x).to_dense() + kernel = kernel.replace(compute_engine = ConstantDiagonalKernelComputation) + constant_diagonal_matrix = kernel.gram(x).to_dense() constant_entries = jnp.diag(constant_diagonal_matrix) # Assert all the diagonal entries are the same diff --git a/tests/test_kernels/test_non_euclidean.py b/tests/test_kernels/test_non_euclidean.py index abde34d95..7ae915839 100644 --- a/tests/test_kernels/test_non_euclidean.py +++ b/tests/test_kernels/test_non_euclidean.py @@ -17,7 +17,7 @@ import jax.random as jr import networkx as nx from jax.config import config -from gpjax.linops import identity +from jaxlinop import identity from gpjax.kernels.non_euclidean import GraphKernel diff --git a/tests/test_kernels/test_nonstationary.py b/tests/test_kernels/test_nonstationary.py index 20092026d..b3932855e 100644 --- a/tests/test_kernels/test_nonstationary.py +++ b/tests/test_kernels/test_nonstationary.py @@ -17,10 +17,10 @@ import jax.numpy as jnp import jax.random as jr +import jax.tree_util as jtu import pytest from jax.config import config from gpjax.linops import LinearOperator, identity -from jaxutils.parameters import initialise from gpjax.kernels.base import AbstractKernel from gpjax.kernels.nonstationary import Linear, Polynomial @@ -48,11 +48,8 @@ def test_gram(kernel: AbstractKernel, dim: int, n: int) -> None: # Inputs x: x = jnp.linspace(0.0, 1.0, n * dim).reshape(n, dim) - # Default kernel parameters: - params = kernel.init_params(_initialise_key) - # Test gram matrix: - Kxx = kernel.gram(params, x) + Kxx = kernel.gram(x) assert isinstance(Kxx, LinearOperator) assert Kxx.shape == (n, n) @@ -74,11 +71,8 @@ def test_cross_covariance( a = jnp.linspace(-1.0, 1.0, num_a * dim).reshape(num_a, dim) b = jnp.linspace(3.0, 4.0, num_b * dim).reshape(num_b, dim) - # Default kernel parameters: - params = kernel.init_params(_initialise_key) - # Test cross covariance, Kab: - Kab = kernel.cross_covariance(params, a, b) + Kab = kernel.cross_covariance(a, b) assert isinstance(Kab, jnp.ndarray) assert Kab.shape == (num_a, num_b) @@ -97,28 +91,30 @@ def test_pos_def( # Create inputs x: x = jr.uniform(_initialise_key, (n, dim)) - params = {"variance": jnp.array([sigma]), "shift": jnp.array([shift])} + + if isinstance(kern, Polynomial): + kern = kern.replace(shift=shift, variance=sigma) + else: + kern = kern.replace(variance=sigma) # Test gram matrix eigenvalues are positive: - Kxx = kern.gram(params, x) + Kxx = kern.gram(x) Kxx += identity(n) * _jitter eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) assert (eigen_values > 0.0).all() -@pytest.mark.parametrize( - "kernel", - [ - Linear, - Polynomial, - ], -) -def test_dtype(kernel: AbstractKernel) -> None: - parameter_state = initialise(kernel(), _initialise_key) - params, *_ = parameter_state.unpack() - for k, v in params.items(): - assert v.dtype == jnp.float64 - assert isinstance(k, str) +# @pytest.mark.parametrize( +# "kernel", +# [ +# Linear, +# Polynomial, +# ], +# ) +# def test_dtype(kernel: AbstractKernel) -> None: +# params_list = jtu.tree_leaves(kernel()) +# for v in params_list: +# assert v.dtype == jnp.float64 @pytest.mark.parametrize("degree", [1, 2, 3]) @@ -136,19 +132,14 @@ def test_polynomial( # Define kernel kern = Polynomial(degree=degree, active_dims=[i for i in range(dim)]) - # Check name - assert kern.name == f"Polynomial Degree: {degree}" + # # Check name + # assert kern.name == f"Polynomial Degree: {degree}" # Initialise parameters - params = kern.init_params(_initialise_key) - params["shift"] * shift - params["variance"] * variance - - # Check parameter keys - assert list(params.keys()) == ["shift", "variance"] + kern = kern.replace(shift=kern.shift * shift, variance=kern.variance * variance) # Compute gram matrix - Kxx = kern.gram(params, x) + Kxx = kern.gram(x) # Check shapes assert Kxx.shape[0] == x.shape[0] @@ -181,13 +172,9 @@ def test_active_dim(kernel: AbstractKernel) -> None: ad_kern = kernel(active_dims=dp) manual_kern = kernel(active_dims=[i for i in range(perm_length)]) - # Get initial parameters - ad_params = ad_kern.init_params(_initialise_key) - manual_params = manual_kern.init_params(_initialise_key) - # Compute gram matrices - ad_Kxx = ad_kern.gram(ad_params, x) - manual_Kxx = manual_kern.gram(manual_params, slice) + ad_Kxx = ad_kern.gram(x) + manual_Kxx = manual_kern.gram(slice) # Test gram matrices are equal assert jnp.all(ad_Kxx.to_dense() == manual_Kxx.to_dense()) diff --git a/tests/test_kernels/test_stationary.py b/tests/test_kernels/test_stationary.py index 83d89db67..262765cdb 100644 --- a/tests/test_kernels/test_stationary.py +++ b/tests/test_kernels/test_stationary.py @@ -23,7 +23,6 @@ import distrax as dx from jax.config import config from gpjax.linops import LinearOperator, identity -from jaxutils.parameters import initialise from gpjax.kernels.base import AbstractKernel from gpjax.kernels.stationary import ( @@ -31,10 +30,6 @@ Matern12, Matern32, Matern52, - PoweredExponential, - RationalQuadratic, - Periodic, - White, ) from gpjax.kernels.stationary.utils import build_student_t_distribution @@ -48,11 +43,11 @@ "kernel", [ RBF(), - Matern12(), - Matern32(), - Matern52(), - RationalQuadratic(), - White(), + # Matern12(), + # Matern32(), + # Matern52(), + # RationalQuadratic(), + # White(), ], ) @pytest.mark.parametrize("dim", [1, 2, 5]) @@ -65,11 +60,8 @@ def test_gram(kernel: AbstractKernel, dim: int, n: int) -> None: # Inputs x: x = jnp.linspace(0.0, 1.0, n * dim).reshape(n, dim) - # Default kernel parameters: - params = kernel.init_params(_initialise_key) - # Test gram matrix: - Kxx = kernel.gram(params, x) + Kxx = kernel.gram(x) assert isinstance(Kxx, LinearOperator) assert Kxx.shape == (n, n) @@ -78,11 +70,11 @@ def test_gram(kernel: AbstractKernel, dim: int, n: int) -> None: "kernel", [ RBF(), - Matern12(), - Matern32(), - Matern52(), - RationalQuadratic(), - White(), + # Matern12(), + # Matern32(), + # Matern52(), + # RationalQuadratic(), + # White(), ], ) @pytest.mark.parametrize("num_a", [1, 2, 5]) @@ -95,16 +87,22 @@ def test_cross_covariance( a = jnp.linspace(-1.0, 1.0, num_a * dim).reshape(num_a, dim) b = jnp.linspace(3.0, 4.0, num_b * dim).reshape(num_b, dim) - # Default kernel parameters: - params = kernel.init_params(_initialise_key) - # Test cross covariance, Kab: - Kab = kernel.cross_covariance(params, a, b) + Kab = kernel.cross_covariance(a, b) assert isinstance(Kab, jnp.ndarray) assert Kab.shape == (num_a, num_b) -@pytest.mark.parametrize("kernel", [RBF(), Matern12(), Matern32(), Matern52(), White()]) +@pytest.mark.parametrize( + "kernel", + [ + RBF(), + # Matern12(), + # Matern32(), + # Matern52(), + # White(), + ], +) @pytest.mark.parametrize("dim", [1, 2, 5]) def test_call(kernel: AbstractKernel, dim: int) -> None: @@ -112,158 +110,179 @@ def test_call(kernel: AbstractKernel, dim: int) -> None: x = jnp.array([[1.0] * dim]) y = jnp.array([[0.5] * dim]) - # Defualt parameters: - params = kernel.init_params(_initialise_key) - # Test calling gives an autocovariance value of no dimension between the inputs: - kxy = kernel(params, x, y) + kxy = kernel(x, y) assert isinstance(kxy, jax.Array) assert kxy.shape == () -@pytest.mark.parametrize("kern", [RBF, Matern12, Matern32, Matern52]) +@pytest.mark.parametrize( + "kern", + [ + RBF, + # Matern12, + # Matern32, + # Matern52, + ], +) @pytest.mark.parametrize("dim", [1, 2, 5]) @pytest.mark.parametrize("ell, sigma", [(0.1, 0.2), (0.5, 0.1), (0.1, 0.5), (0.5, 0.5)]) @pytest.mark.parametrize("n", [1, 2, 5]) def test_pos_def( kern: AbstractKernel, dim: int, ell: float, sigma: float, n: int ) -> None: - kern = kern(active_dims=list(range(dim))) - - # Create inputs x: - x = jr.uniform(_initialise_key, (n, dim)) - params = {"lengthscale": jnp.array([ell]), "variance": jnp.array([sigma])} - - # Test gram matrix eigenvalues are positive: - Kxx = kern.gram(params, x) - Kxx += identity(n) * _jitter - eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) - assert (eigen_values > 0.0).all() - - -@pytest.mark.parametrize("dim", [1, 2, 5]) -@pytest.mark.parametrize("ell, sigma", [(0.1, 0.2), (0.5, 0.1), (0.1, 0.5), (0.5, 0.5)]) -@pytest.mark.parametrize("alpha", [0.1, 0.5, 1.0]) -@pytest.mark.parametrize("n", [1, 2, 5]) -def test_pos_def_rq(dim: int, ell: float, sigma: float, alpha: float, n: int) -> None: - kern = RationalQuadratic(active_dims=list(range(dim))) - # Gram constructor static method: - kern.gram - - # Create inputs x: - x = jr.uniform(_initialise_key, (n, dim)) - params = { - "lengthscale": jnp.array([ell]), - "variance": jnp.array([sigma]), - "alpha": jnp.array([alpha]), - } - - # Test gram matrix eigenvalues are positive: - Kxx = kern.gram(params, x) - Kxx += identity(n) * _jitter - eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) - assert (eigen_values > 0.0).all() - - -@pytest.mark.parametrize("dim", [1, 2, 5]) -@pytest.mark.parametrize("ell, sigma", [(0.1, 0.2), (0.5, 0.1), (0.1, 0.5), (0.5, 0.5)]) -@pytest.mark.parametrize("period", [0.1, 0.5, 1.0]) -@pytest.mark.parametrize("n", [1, 2, 5]) -def test_pos_def_periodic( - dim: int, ell: float, sigma: float, period: float, n: int -) -> None: - kern = Periodic(active_dims=list(range(dim))) - # Gram constructor static method: - kern.gram - - # Create inputs x: - x = jr.uniform(_initialise_key, (n, dim)) - params = { - "lengthscale": jnp.array([ell]), - "variance": jnp.array([sigma]), - "period": jnp.array([period]), - } - - # Test gram matrix eigenvalues are positive: - Kxx = kern.gram(params, x) - Kxx += identity(n) * _jitter - eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) - assert (eigen_values > 0.0).all() - - -@pytest.mark.parametrize("dim", [1, 2, 5]) -@pytest.mark.parametrize("ell, sigma", [(0.1, 0.2), (0.5, 0.1), (0.1, 0.5), (0.5, 0.5)]) -@pytest.mark.parametrize("power", [0.1, 0.5, 1.0]) -@pytest.mark.parametrize("n", [1, 2, 5]) -def test_pos_def_power_exp( - dim: int, ell: float, sigma: float, power: float, n: int -) -> None: - kern = PoweredExponential(active_dims=list(range(dim))) - # Gram constructor static method: - kern.gram + kern = kern( + active_dims=list(range(dim)), + lengthscale=jnp.array([ell]), + variance=jnp.array([sigma]), + ) # Create inputs x: x = jr.uniform(_initialise_key, (n, dim)) - params = { - "lengthscale": jnp.array([ell]), - "variance": jnp.array([sigma]), - "power": jnp.array([power]), - } # Test gram matrix eigenvalues are positive: - Kxx = kern.gram(params, x) + Kxx = kern.gram(x) Kxx += identity(n) * _jitter eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) assert (eigen_values > 0.0).all() -@pytest.mark.parametrize("kernel", [RBF, Matern12, Matern32, Matern52]) -@pytest.mark.parametrize("dim", [None, 1, 2, 5, 10]) -def test_initialisation(kernel: AbstractKernel, dim: int) -> None: - - if dim is None: - kern = kernel() - assert kern.ndims == 1 - - else: - kern = kernel(active_dims=[i for i in range(dim)]) - params = kern.init_params(_initialise_key) - - assert list(params.keys()) == ["lengthscale", "variance"] - assert all(params["lengthscale"] == jnp.array([1.0] * dim)) - assert params["variance"] == jnp.array([1.0]) - - if dim > 1: - assert kern.ard - else: - assert not kern.ard +# @pytest.mark.parametrize("dim", [1, 2, 5]) +# @pytest.mark.parametrize("ell, sigma", [(0.1, 0.2), (0.5, 0.1), (0.1, 0.5), (0.5, 0.5)]) +# @pytest.mark.parametrize("alpha", [0.1, 0.5, 1.0]) +# @pytest.mark.parametrize("n", [1, 2, 5]) +# def test_pos_def_rq(dim: int, ell: float, sigma: float, alpha: float, n: int) -> None: +# kern = RationalQuadratic(active_dims=list(range(dim))) +# # Gram constructor static method: +# kern.gram + +# # Create inputs x: +# x = jr.uniform(_initialise_key, (n, dim)) +# params = { +# "lengthscale": jnp.array([ell]), +# "variance": jnp.array([sigma]), +# "alpha": jnp.array([alpha]), +# } + +# # Test gram matrix eigenvalues are positive: +# Kxx = kern.gram(params, x) +# Kxx += identity(n) * _jitter +# eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) +# assert (eigen_values > 0.0).all() + + +# @pytest.mark.parametrize("dim", [1, 2, 5]) +# @pytest.mark.parametrize("ell, sigma", [(0.1, 0.2), (0.5, 0.1), (0.1, 0.5), (0.5, 0.5)]) +# @pytest.mark.parametrize("period", [0.1, 0.5, 1.0]) +# @pytest.mark.parametrize("n", [1, 2, 5]) +# def test_pos_def_periodic( +# dim: int, ell: float, sigma: float, period: float, n: int +# ) -> None: +# kern = Periodic(active_dims=list(range(dim))) +# # Gram constructor static method: +# kern.gram + +# # Create inputs x: +# x = jr.uniform(_initialise_key, (n, dim)) +# params = { +# "lengthscale": jnp.array([ell]), +# "variance": jnp.array([sigma]), +# "period": jnp.array([period]), +# } + +# # Test gram matrix eigenvalues are positive: +# Kxx = kern.gram(params, x) +# Kxx += identity(n) * _jitter +# eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) +# # assert (eigen_values > 0.0).all() + + +# @pytest.mark.parametrize("dim", [1, 2, 5]) +# @pytest.mark.parametrize("ell, sigma", [(0.1, 0.2), (0.5, 0.1), (0.1, 0.5), (0.5, 0.5)]) +# @pytest.mark.parametrize("power", [0.1, 0.5, 1.0]) +# @pytest.mark.parametrize("n", [1, 2, 5]) +# def test_pos_def_power_exp( +# dim: int, ell: float, sigma: float, power: float, n: int +# ) -> None: +# kern = PoweredExponential(active_dims=list(range(dim))) +# # Gram constructor static method: +# kern.gram + +# # Create inputs x: +# x = jr.uniform(_initialise_key, (n, dim)) +# params = { +# "lengthscale": jnp.array([ell]), +# "variance": jnp.array([sigma]), +# "power": jnp.array([power]), +# } + +# # Test gram matrix eigenvalues are positive: +# Kxx = kern.gram(params, x) +# Kxx += identity(n) * _jitter +# eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) +# assert (eigen_values > 0.0).all() + + +# @pytest.mark.parametrize("kernel", +# [ +# RBF, +# #Matern12, +# #Matern32, +# #Matern52, +# ], +# ) +# @pytest.mark.parametrize("dim", [None, 1, 2, 5, 10]) +# def test_initialisation(kernel: AbstractKernel, dim: int) -> None: + +# if dim is None: +# kern = kernel() +# assert kern.ndims == 1 + +# else: +# kern = kernel(active_dims=[i for i in range(dim)]) +# params = kern.init_params(_initialise_key) + +# assert list(params.keys()) == ["lengthscale", "variance"] +# assert all(params["lengthscale"] == jnp.array([1.0] * dim)) +# assert params["variance"] == jnp.array([1.0]) + +# if dim > 1: +# assert kern.ard +# else: +# assert not kern.ard + + +# @pytest.mark.parametrize( +# "kernel", +# [ +# RBF, +# # Matern12, +# # Matern32, +# # Matern52, +# # RationalQuadratic, +# # Periodic, +# # PoweredExponential, +# ], +# ) +# def test_dtype(kernel: AbstractKernel) -> None: +# parameter_state = initialise(kernel(), _initialise_key) +# params, *_ = parameter_state.unpack() +# for k, v in params.items(): +# assert v.dtype == jnp.float64 +# assert isinstance(k, str) @pytest.mark.parametrize( "kernel", [ RBF, - Matern12, - Matern32, - Matern52, - RationalQuadratic, - Periodic, - PoweredExponential, + # Matern12, + # Matern32, + # Matern52, + # RationalQuadratic, ], ) -def test_dtype(kernel: AbstractKernel) -> None: - parameter_state = initialise(kernel(), _initialise_key) - params, *_ = parameter_state.unpack() - for k, v in params.items(): - assert v.dtype == jnp.float64 - assert isinstance(k, str) - - -@pytest.mark.parametrize( - "kernel", - [RBF, Matern12, Matern32, Matern52, RationalQuadratic], -) def test_active_dim(kernel: AbstractKernel) -> None: dim_list = [0, 1, 2, 3] perm_length = 2 @@ -281,13 +300,9 @@ def test_active_dim(kernel: AbstractKernel) -> None: ad_kern = kernel(active_dims=dp) manual_kern = kernel(active_dims=[i for i in range(perm_length)]) - # Get initial parameters - ad_params = ad_kern.init_params(_initialise_key) - manual_params = manual_kern.init_params(_initialise_key) - # Compute gram matrices - ad_Kxx = ad_kern.gram(ad_params, x) - manual_Kxx = manual_kern.gram(manual_params, slice) + ad_Kxx = ad_kern.gram(x) + manual_Kxx = manual_kern.gram(slice) # Test gram matrices are equal assert jnp.all(ad_Kxx.to_dense() == manual_Kxx.to_dense()) @@ -299,20 +314,20 @@ def test_build_studentt_dist(smoothness: int) -> None: assert isinstance(dist, dx.Distribution) -@pytest.mark.parametrize( - "kern, df", [(Matern12(), 1), (Matern32(), 3), (Matern52(), 5)] -) -def test_matern_spectral_density(kern, df) -> None: - sdensity = kern.spectral_density - assert sdensity.name == "StudentT" - assert sdensity.df == df - assert sdensity.loc == jnp.array(0.0) - assert sdensity.scale == jnp.array(1.0) - - -def test_rbf_spectral_density() -> None: - kern = RBF() - sdensity = kern.spectral_density - assert sdensity.name == "Normal" - assert sdensity.loc == jnp.array(0.0) - assert sdensity.scale == jnp.array(1.0) +# @pytest.mark.parametrize( +# "kern, df", [(Matern12(), 1), (Matern32(), 3), (Matern52(), 5)] +# ) +# def test_matern_spectral_density(kern, df) -> None: +# sdensity = kern.spectral_density +# assert sdensity.name == "StudentT" +# assert sdensity.df == df +# assert sdensity.loc == jnp.array(0.0) +# assert sdensity.scale == jnp.array(1.0) + + +# def test_rbf_spectral_density() -> None: +# kern = RBF() +# sdensity = kern.spectral_density +# assert sdensity.name == "Normal" +# assert sdensity.loc == jnp.array(0.0) +# assert sdensity.scale == jnp.array(1.0) From 1e3591f2e49d0afc8968d968671f5da91b8fcbaf Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Mon, 27 Mar 2023 20:36:57 +0100 Subject: [PATCH 03/44] Fix mytree links --- gpjax/parameters/__init__.py | 5 +++++ gpjax/{ => parameters}/bijectors.py | 0 gpjax/{ => parameters}/module.py | 0 gpjax/{ => parameters}/param.py | 0 tests/test_bijectors.py | 2 +- tests/test_param.py | 2 +- 6 files changed, 7 insertions(+), 2 deletions(-) create mode 100644 gpjax/parameters/__init__.py rename gpjax/{ => parameters}/bijectors.py (100%) rename gpjax/{ => parameters}/module.py (100%) rename gpjax/{ => parameters}/param.py (100%) diff --git a/gpjax/parameters/__init__.py b/gpjax/parameters/__init__.py new file mode 100644 index 000000000..19c6675c9 --- /dev/null +++ b/gpjax/parameters/__init__.py @@ -0,0 +1,5 @@ +from .bijectors import Identity, Softplus +from .module import Module +from .param import param_field + +__all__ = ['Identity', 'Module', 'Softplus', 'param_field'] \ No newline at end of file diff --git a/gpjax/bijectors.py b/gpjax/parameters/bijectors.py similarity index 100% rename from gpjax/bijectors.py rename to gpjax/parameters/bijectors.py diff --git a/gpjax/module.py b/gpjax/parameters/module.py similarity index 100% rename from gpjax/module.py rename to gpjax/parameters/module.py diff --git a/gpjax/param.py b/gpjax/parameters/param.py similarity index 100% rename from gpjax/param.py rename to gpjax/parameters/param.py diff --git a/tests/test_bijectors.py b/tests/test_bijectors.py index 9e6eb5996..14a486a08 100644 --- a/tests/test_bijectors.py +++ b/tests/test_bijectors.py @@ -1,7 +1,7 @@ import jax.numpy as jnp import pytest -from mytree.bijectors import Bijector, Identity, Softplus +from gpjax.parameters.bijectors import Bijector, Identity, Softplus def test_bijector(): diff --git a/tests/test_param.py b/tests/test_param.py index 0dffb163f..c50caae37 100644 --- a/tests/test_param.py +++ b/tests/test_param.py @@ -2,7 +2,7 @@ import pytest -from mytree import Identity, Softplus, param_field +from gpjax.parameters import Identity, Softplus, param_field @pytest.mark.parametrize("bijector", [Identity, Softplus]) From f8612443a1235685b78855cf53e9d68ef13ee673 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Mon, 27 Mar 2023 20:52:10 +0100 Subject: [PATCH 04/44] Tests fixed --- gpjax/__init__.py | 6 +- gpjax/_version.py | 2 +- gpjax/abstractions.py | 9 +- gpjax/gaussian_distribution.py | 13 +- gpjax/gps.py | 15 +- gpjax/kernels/base.py | 18 +- gpjax/kernels/nonstationary/linear.py | 2 +- gpjax/kernels/nonstationary/polynomial.py | 2 +- gpjax/kernels/stationary/rbf.py | 4 +- gpjax/likelihoods.py | 9 +- gpjax/mean_functions.py | 3 +- gpjax/natural_gradients.py | 2 +- gpjax/{parameters.py => params.py} | 0 gpjax/types.py | 2 +- gpjax/utils.py | 2 +- gpjax/variational_families.py | 14 +- gpjax/variational_inference.py | 11 +- tests/test_config.py | 10 +- tests/test_gaussian_distribution.py | 3 +- tests/test_gps.py | 4 +- tests/test_likelihoods.py | 3 +- tests/test_mean_functions.py | 2 +- tests/test_module.py | 3 +- tests/test_param.py | 41 --- tests/test_parameters.py | 308 +++------------------- tests/test_params.py | 273 +++++++++++++++++++ tests/test_types.py | 1 + tests/test_variational_families.py | 4 +- 28 files changed, 373 insertions(+), 393 deletions(-) rename gpjax/{parameters.py => params.py} (100%) delete mode 100644 tests/test_param.py create mode 100644 tests/test_params.py diff --git a/gpjax/__init__.py b/gpjax/__init__.py index a5bc4c95f..dd47bb6a3 100644 --- a/gpjax/__init__.py +++ b/gpjax/__init__.py @@ -13,12 +13,14 @@ # limitations under the License. # ============================================================================== +from . import _version from .abstractions import fit, fit_batches, fit_natgrads from .gps import Prior, construct_posterior from .kernels import * from .likelihoods import Bernoulli, Gaussian from .mean_functions import Constant, Zero -from .parameters import constrain, copy_dict_structure, initialise, unconstrain +from .params import constrain, copy_dict_structure, initialise, unconstrain +from .types import Dataset from .variational_families import ( CollapsedVariationalGaussian, ExpectationVariationalGaussian, @@ -26,9 +28,7 @@ VariationalGaussian, WhitenedVariationalGaussian, ) -from .types import Dataset from .variational_inference import CollapsedVI, StochasticVI -from . import _version __version__ = _version.get_versions()["version"] __license__ = "MIT" diff --git a/gpjax/_version.py b/gpjax/_version.py index 0aaf90d44..b5961bba0 100644 --- a/gpjax/_version.py +++ b/gpjax/_version.py @@ -11,12 +11,12 @@ """Git implementation of _version.py.""" import errno +import functools import os import re import subprocess import sys from typing import Callable, Dict -import functools def get_keywords(): diff --git a/gpjax/abstractions.py b/gpjax/abstractions.py index ac56b39dc..3682aa950 100644 --- a/gpjax/abstractions.py +++ b/gpjax/abstractions.py @@ -13,22 +13,21 @@ # limitations under the License. # ============================================================================== -from typing import Callable, Dict, Optional, Tuple, Any, Union +from typing import Any, Callable, Dict, Optional, Tuple, Union import jax import jax.numpy as jnp import jax.random as jr import optax as ox - -from jax.random import KeyArray from jax import lax from jax.experimental import host_callback +from jax.random import KeyArray from jaxtyping import Array, Float +from jaxutils import Dataset, PyTree from tqdm.auto import tqdm from .natural_gradients import natural_gradients -from .parameters import ParameterState, constrain, trainable_params, unconstrain -from jaxutils import Dataset, PyTree +from .params import ParameterState, constrain, trainable_params, unconstrain from .variational_inference import StochasticVI diff --git a/gpjax/gaussian_distribution.py b/gpjax/gaussian_distribution.py index 6db481d2f..dd27c5246 100644 --- a/gpjax/gaussian_distribution.py +++ b/gpjax/gaussian_distribution.py @@ -13,17 +13,16 @@ # limitations under the License. # ============================================================================== -import jax.numpy as jnp -from .linops import LinearOperator, IdentityLinearOperator - -from jaxtyping import Array, Float -from jax import vmap - -from typing import Tuple, Optional, Any +from typing import Any, Optional, Tuple import distrax as dx +import jax.numpy as jnp import jax.random as jr +from jax import vmap from jax.random import KeyArray +from jaxtyping import Array, Float + +from .linops import IdentityLinearOperator, LinearOperator def _check_loc_scale(loc: Optional[Any], scale: Optional[Any]) -> None: diff --git a/gpjax/gps.py b/gpjax/gps.py index de8442940..15cb826aa 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -16,24 +16,21 @@ from abc import abstractmethod from typing import Any, Callable, Dict, Optional +import deprecation import distrax as dx import jax.numpy as jnp -from jaxtyping import Array, Float from jax.random import KeyArray - -from .linops import identity -from .kernels.base import AbstractKernel -from jaxutils import PyTree +from jaxtyping import Array, Float +from jaxutils import Dataset, PyTree from .config import get_global_config +from .gaussian_distribution import GaussianDistribution from .kernels import AbstractKernel +from .kernels.base import AbstractKernel from .likelihoods import AbstractLikelihood, Conjugate, NonConjugate +from .linops import identity from .mean_functions import AbstractMeanFunction, Zero -from jaxutils import Dataset from .utils import concat_dictionaries -from .gaussian_distribution import GaussianDistribution - -import deprecation class AbstractPrior(PyTree): diff --git a/gpjax/kernels/base.py b/gpjax/kernels/base.py index 725d3e022..f4db57879 100644 --- a/gpjax/kernels/base.py +++ b/gpjax/kernels/base.py @@ -20,7 +20,7 @@ from typing import List, Callable, Union from jaxtyping import Array, Float from functools import partial -from mytree import Mytree, param_field +from ..parameters import Module, param_field from simple_pytree import static_field from dataclasses import dataclass from functools import partial @@ -28,7 +28,7 @@ from .computations import AbstractKernelComputation, DenseKernelComputation @dataclass -class AbstractKernel(Mytree): +class AbstractKernel(Module): """Base kernel class.""" compute_engine: AbstractKernelComputation = static_field(DenseKernelComputation) active_dims: List[int] = static_field(None) @@ -42,7 +42,7 @@ def cross_covariance(self, x: Float[Array, "N D"], y: Float[Array, "M D"]): def gram(self, x: Float[Array, "N D"]): return self.compute_engine(self).gram(x) - + def slice_input(self, x: Float[Array, "N D"]) -> Float[Array, "N S"]: """Select the relevant columns of the supplied matrix to be used within the kernel's evaluation. @@ -81,9 +81,9 @@ def __add__(self, other: Union[AbstractKernel, Float[Array, "1"]]) -> AbstractKe if isinstance(other, AbstractKernel): return SumKernel([self, other]) - + return SumKernel([self, Constant(other)]) - + def __radd__(self, other: Union[AbstractKernel, Float[Array, "1"]]) -> AbstractKernel: """Add two kernels together. Args: @@ -105,7 +105,7 @@ def __mul__(self, other: Union[AbstractKernel, Float[Array, "1"]]) -> AbstractKe """ if isinstance(other, AbstractKernel): return ProductKernel([self, other]) - + return ProductKernel([self, Constant(other)]) @@ -135,11 +135,11 @@ class CombinationKernel(AbstractKernel): """A base class for products or sums of MeanFunctions.""" kernels: List[AbstractKernel] = None operator: Callable = static_field(None) - + def __post_init__(self): #Add kernels to a list, flattening out instances of this class therein, as in GPFlow kernels. kernels_list: List[AbstractKernel] = [] - + for kernel in self.kernels: if not isinstance(kernel, AbstractKernel): raise TypeError("can only combine Kernel instances") # pragma: no cover @@ -166,7 +166,7 @@ def __call__( Float[Array, "1"]: The evaluated kernel function at the supplied inputs. """ return self.operator(jnp.stack([k(x, y) for k in self.kernels])) - + SumKernel = partial(CombinationKernel, operator=jnp.sum) ProductKernel = partial(CombinationKernel, operator=jnp.sum) \ No newline at end of file diff --git a/gpjax/kernels/nonstationary/linear.py b/gpjax/kernels/nonstationary/linear.py index 226e56840..7749e961e 100644 --- a/gpjax/kernels/nonstationary/linear.py +++ b/gpjax/kernels/nonstationary/linear.py @@ -20,7 +20,7 @@ from dataclasses import dataclass from jaxtyping import Array, Float -from mytree import param_field, Softplus +from ...parameters import param_field, Softplus @dataclass class Linear(AbstractKernel): diff --git a/gpjax/kernels/nonstationary/polynomial.py b/gpjax/kernels/nonstationary/polynomial.py index 582460cca..a87cacaa4 100644 --- a/gpjax/kernels/nonstationary/polynomial.py +++ b/gpjax/kernels/nonstationary/polynomial.py @@ -18,7 +18,7 @@ from ..base import AbstractKernel from dataclasses import dataclass from simple_pytree import static_field -from mytree import param_field, Softplus +from ...parameters import param_field, Softplus @dataclass class Polynomial(AbstractKernel): diff --git a/gpjax/kernels/stationary/rbf.py b/gpjax/kernels/stationary/rbf.py index 3acb0608c..2ab9c1d08 100644 --- a/gpjax/kernels/stationary/rbf.py +++ b/gpjax/kernels/stationary/rbf.py @@ -21,7 +21,7 @@ import distrax as dx from dataclasses import dataclass -from mytree import param_field, Softplus +from ...parameters import param_field, Softplus @dataclass class RBF(AbstractKernel): @@ -48,6 +48,6 @@ def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, " y = self.slice_input(y) / self.lengthscale K = self.variance * jnp.exp(-0.5 * squared_distance(x, y)) return K.squeeze() - + def spectral_density(self) -> dx.Normal: return dx.Normal(loc=0.0, scale=1.0) diff --git a/gpjax/likelihoods.py b/gpjax/likelihoods.py index 8a8a3194f..5543c6de7 100644 --- a/gpjax/likelihoods.py +++ b/gpjax/likelihoods.py @@ -15,17 +15,16 @@ import abc from typing import Any, Callable, Dict, Optional -from .linops.utils import to_dense -from jaxutils import PyTree +import deprecation import distrax as dx import jax.numpy as jnp import jax.scipy as jsp -from jaxtyping import Array, Float - from jax.random import KeyArray +from jaxtyping import Array, Float +from jaxutils import PyTree -import deprecation +from .linops.utils import to_dense class AbstractLikelihood(PyTree): diff --git a/gpjax/mean_functions.py b/gpjax/mean_functions.py index 75ebf43ee..43c72b3d6 100644 --- a/gpjax/mean_functions.py +++ b/gpjax/mean_functions.py @@ -16,13 +16,12 @@ import abc from typing import Dict, Optional +import deprecation import jax.numpy as jnp from jax.random import KeyArray from jaxtyping import Array, Float from jaxutils import PyTree -import deprecation - class AbstractMeanFunction(PyTree): """Abstract mean function that is used to parameterise the Gaussian process.""" diff --git a/gpjax/natural_gradients.py b/gpjax/natural_gradients.py index baa169390..acdf5bd36 100644 --- a/gpjax/natural_gradients.py +++ b/gpjax/natural_gradients.py @@ -24,7 +24,7 @@ from .config import get_global_config from .gps import AbstractPosterior -from .parameters import build_trainables, constrain, trainable_params +from .params import build_trainables, constrain, trainable_params from .variational_families import ( AbstractVariationalFamily, ExpectationVariationalGaussian, diff --git a/gpjax/parameters.py b/gpjax/params.py similarity index 100% rename from gpjax/parameters.py rename to gpjax/params.py diff --git a/gpjax/types.py b/gpjax/types.py index d1e8b110a..dfa9abb7b 100644 --- a/gpjax/types.py +++ b/gpjax/types.py @@ -13,8 +13,8 @@ # limitations under the License. # ============================================================================== -import jaxutils import deprecation +import jaxutils Dataset = deprecation.deprecated( deprecated_in="0.5.5", diff --git a/gpjax/utils.py b/gpjax/utils.py index 27dcb507a..254c76ddc 100644 --- a/gpjax/utils.py +++ b/gpjax/utils.py @@ -13,8 +13,8 @@ # limitations under the License. # ============================================================================== -import jaxutils import deprecation +import jaxutils depreciate = deprecation.deprecated( deprecated_in="0.5.6", diff --git a/gpjax/variational_families.py b/gpjax/variational_families.py index 60315f3fc..f25ac340d 100644 --- a/gpjax/variational_families.py +++ b/gpjax/variational_families.py @@ -16,26 +16,20 @@ import abc from typing import Any, Callable, Dict, Optional +import deprecation import distrax as dx import jax.numpy as jnp import jax.scipy as jsp from jax.random import KeyArray from jaxtyping import Array, Float - -from .linops import identity -from jaxutils import PyTree, Dataset -from .linops import ( - DenseLinearOperator, - LowerTriangularLinearOperator -) +from jaxutils import Dataset, PyTree from .config import get_global_config +from .gaussian_distribution import GaussianDistribution from .gps import Prior from .likelihoods import AbstractLikelihood, Gaussian +from .linops import DenseLinearOperator, LowerTriangularLinearOperator, identity from .utils import concat_dictionaries -from .gaussian_distribution import GaussianDistribution - -import deprecation class AbstractVariationalFamily(PyTree): diff --git a/gpjax/variational_inference.py b/gpjax/variational_inference.py index 219db7798..3942079ae 100644 --- a/gpjax/variational_inference.py +++ b/gpjax/variational_inference.py @@ -16,28 +16,25 @@ import abc from typing import Callable, Dict +import deprecation import jax.numpy as jnp import jax.scipy as jsp from jax import vmap -from jaxtyping import Array, Float - -from .linops import identity from jax.random import KeyArray -from jaxutils import PyTree +from jaxtyping import Array, Float +from jaxutils import Dataset, PyTree from .config import get_global_config from .gps import AbstractPosterior from .likelihoods import Gaussian +from .linops import identity from .quadrature import gauss_hermite_quadrature -from jaxutils import Dataset from .utils import concat_dictionaries from .variational_families import ( AbstractVariationalFamily, CollapsedVariationalGaussian, ) -import deprecation - class AbstractVariationalInference(PyTree): """A base class for inference and training of variational families against an exact posterior""" diff --git a/tests/test_config.py b/tests/test_config.py index e084e5761..39be389c3 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -13,17 +13,13 @@ # limitations under the License. # ============================================================================== -import jax import distrax as dx +import jax from jax.config import config from ml_collections import ConfigDict -from gpjax.config import ( - Identity, - add_parameter, - get_global_config, - get_global_config_if_exists, # ignore: unused-import -) +from gpjax.config import get_global_config_if_exists # ignore: unused-import +from gpjax.config import Identity, add_parameter, get_global_config # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) diff --git a/tests/test_gaussian_distribution.py b/tests/test_gaussian_distribution.py index 612153a88..f3f522fe4 100644 --- a/tests/test_gaussian_distribution.py +++ b/tests/test_gaussian_distribution.py @@ -23,11 +23,10 @@ # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) +from gpjax.gaussian_distribution import GaussianDistribution from gpjax.linops.dense_linear_operator import DenseLinearOperator from gpjax.linops.diagonal_linear_operator import DiagonalLinearOperator -from gpjax.gaussian_distribution import GaussianDistribution - _key = jr.PRNGKey(seed=42) from distrax import MultivariateNormalDiag, MultivariateNormalFullCovariance diff --git a/tests/test_gps.py b/tests/test_gps.py index 9bab15d84..df90e0e8c 100644 --- a/tests/test_gps.py +++ b/tests/test_gps.py @@ -15,8 +15,8 @@ import typing as tp -import jax import distrax as dx +import jax import jax.numpy as jnp import jax.random as jr import pytest @@ -24,8 +24,8 @@ from gpjax import Dataset, initialise from gpjax.gps import ( - AbstractPrior, AbstractPosterior, + AbstractPrior, ConjugatePosterior, NonConjugatePosterior, Prior, diff --git a/tests/test_likelihoods.py b/tests/test_likelihoods.py index d27770b36..6d9d7918c 100644 --- a/tests/test_likelihoods.py +++ b/tests/test_likelihoods.py @@ -20,8 +20,8 @@ import jax.random as jr import numpy as np import pytest -from jax.random import KeyArray from jax.config import config +from jax.random import KeyArray from jaxtyping import Array, Float from gpjax.likelihoods import ( @@ -33,7 +33,6 @@ inv_probit, ) - # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) _initialise_key = jr.PRNGKey(123) diff --git a/tests/test_mean_functions.py b/tests/test_mean_functions.py index 46acdacb1..9d97d82ee 100644 --- a/tests/test_mean_functions.py +++ b/tests/test_mean_functions.py @@ -18,8 +18,8 @@ import jax.numpy as jnp import jax.random as jr import pytest -from jax.random import KeyArray from jax.config import config +from jax.random import KeyArray from jaxtyping import Array, Float from gpjax.mean_functions import AbstractMeanFunction, Constant, Zero diff --git a/tests/test_module.py b/tests/test_module.py index 41042ecc8..dfe445e38 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -9,10 +9,11 @@ from flax import serialization from simple_pytree import Pytree, static_field -from gpjax.module import Module, meta from gpjax.bijectors import Identity, Softplus +from gpjax.module import Module, meta from gpjax.param import param_field + @pytest.mark.parametrize("is_dataclass", [True, False]) def test_init_and_meta_scrambled(is_dataclass): class Tree(Module): diff --git a/tests/test_param.py b/tests/test_param.py deleted file mode 100644 index c50caae37..000000000 --- a/tests/test_param.py +++ /dev/null @@ -1,41 +0,0 @@ -import dataclasses - -import pytest - -from gpjax.parameters import Identity, Softplus, param_field - - -@pytest.mark.parametrize("bijector", [Identity, Softplus]) -@pytest.mark.parametrize("trainable", [True, False]) -def test_param(bijector, trainable): - param_field_ = param_field(bijector=bijector, trainable=trainable) - assert isinstance(param_field_, dataclasses.Field) - assert param_field_.metadata["bijector"] == bijector - assert param_field_.metadata["trainable"] == trainable - - with pytest.raises(ValueError): - param_field( - bijector=bijector, trainable=trainable, metadata={"trainable": trainable} - ) - - with pytest.raises(ValueError): - param_field( - bijector=bijector, trainable=trainable, metadata={"bijector": bijector} - ) - - with pytest.raises(ValueError): - param_field( - bijector=bijector, - trainable=trainable, - metadata={"bijector": Softplus, "trainable": trainable}, - ) - - with pytest.raises(ValueError): - param_field( - bijector=bijector, trainable=trainable, metadata={"pytree_node": True} - ) - - with pytest.raises(ValueError): - param_field( - bijector=bijector, trainable=trainable, metadata={"pytree_node": False} - ) diff --git a/tests/test_parameters.py b/tests/test_parameters.py index 94ed622fe..c50caae37 100644 --- a/tests/test_parameters.py +++ b/tests/test_parameters.py @@ -1,273 +1,41 @@ -# Copyright 2022 The GPJax Contributors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== +import dataclasses -import typing as tp - -import distrax as dx -import jax.numpy as jnp -import jax.random as jr import pytest -from jax.config import config - -from gpjax.gps import Prior -from gpjax.kernels import RBF -from gpjax.likelihoods import Bernoulli, Gaussian -from gpjax.parameters import ( - build_bijectors, - build_trainables, - constrain, - copy_dict_structure, - evaluate_priors, - initialise, - log_density, - prior_checks, - recursive_complete, - recursive_items, - structure_priors, - unconstrain, -) - -# Enable Float64 for more stable matrix inversions. -config.update("jax_enable_x64", True) - -######################### -# Test base functionality -######################### -@pytest.mark.parametrize("lik", [Gaussian]) -def test_initialise(lik): - key = jr.PRNGKey(123) - posterior = Prior(kernel=RBF()) * lik(num_datapoints=10) - params, _, _ = initialise(posterior, key).unpack() - assert list(sorted(params.keys())) == [ - "kernel", - "likelihood", - "mean_function", - ] - - -def test_non_conjugate_initialise(): - posterior = Prior(kernel=RBF()) * Bernoulli(num_datapoints=10) - params, _, _ = initialise(posterior, jr.PRNGKey(123)).unpack() - assert list(sorted(params.keys())) == [ - "kernel", - "latent", - "likelihood", - "mean_function", - ] - - -######################### -# Test priors -######################### -@pytest.mark.parametrize("x", [-1.0, 0.0, 1.0]) -def test_lpd(x): - val = jnp.array(x) - dist = dx.Normal(loc=0.0, scale=1.0) - lpd = log_density(val, dist) - assert lpd is not None - assert log_density(val, None) == 0.0 - - -@pytest.mark.parametrize("lik", [Gaussian, Bernoulli]) -def test_prior_template(lik): - posterior = Prior(kernel=RBF()) * lik(num_datapoints=10) - params, _, _ = initialise(posterior, jr.PRNGKey(123)).unpack() - prior_container = copy_dict_structure(params) - for ( - k, - v1, - v2, - ) in recursive_items(params, prior_container): - assert v2 == None - - -@pytest.mark.parametrize("lik", [Gaussian, Bernoulli]) -def test_recursive_complete(lik): - posterior = Prior(kernel=RBF()) * lik(num_datapoints=10) - params, _, _ = initialise(posterior, jr.PRNGKey(123)).unpack() - priors = {"kernel": {}} - priors["kernel"]["lengthscale"] = dx.Laplace(loc=0.0, scale=1.0) - container = copy_dict_structure(params) - complete_priors = recursive_complete(container, priors) - for ( - k, - v1, - v2, - ) in recursive_items(params, complete_priors): - if k == "lengthscale": - assert isinstance(v2, dx.Laplace) - else: - assert v2 == None - - -def test_prior_evaluation(): - """ - Test the regular setup that every parameter has a corresponding prior distribution attached to its unconstrained - value. - """ - params = { - "kernel": { - "lengthscale": jnp.array([1.0]), - "variance": jnp.array([1.0]), - }, - "likelihood": {"obs_noise": jnp.array([1.0])}, - } - priors = { - "kernel": { - "lengthscale": dx.Gamma(1.0, 1.0), - "variance": dx.Gamma(2.0, 2.0), - }, - "likelihood": {"obs_noise": dx.Gamma(3.0, 3.0)}, - } - lpd = evaluate_priors(params, priors) - assert pytest.approx(lpd) == -2.0110168 - - -def test_none_prior(): - """ - Test that multiple dispatch is working in the case of no priors. - """ - params = { - "kernel": { - "lengthscale": jnp.array([1.0]), - "variance": jnp.array([1.0]), - }, - "likelihood": {"obs_noise": jnp.array([1.0])}, - } - priors = copy_dict_structure(params) - lpd = evaluate_priors(params, priors) - assert lpd == 0.0 - - -def test_incomplete_priors(): - """ - Test the case where a user specifies priors for some, but not all, parameters. - """ - params = { - "kernel": { - "lengthscale": jnp.array([1.0]), - "variance": jnp.array([1.0]), - }, - "likelihood": {"obs_noise": jnp.array([1.0])}, - } - priors = { - "kernel": { - "lengthscale": dx.Gamma(1.0, 1.0), - "variance": dx.Gamma(2.0, 2.0), - }, - } - container = copy_dict_structure(params) - complete_priors = recursive_complete(container, priors) - lpd = evaluate_priors(params, complete_priors) - assert pytest.approx(lpd) == -1.6137061 - - -@pytest.mark.parametrize("num_datapoints", [1, 10]) -def test_checks(num_datapoints): - incomplete_priors = {"lengthscale": jnp.array([1.0])} - posterior = Prior(kernel=RBF()) * Bernoulli(num_datapoints=num_datapoints) - priors = prior_checks(incomplete_priors) - assert "latent" in priors.keys() - assert "variance" not in priors.keys() - assert isinstance(priors["latent"], dx.Normal) - - -def test_structure_priors(): - posterior = Prior(kernel=RBF()) * Gaussian(num_datapoints=10) - params, _, _ = initialise(posterior, jr.PRNGKey(123)).unpack() - priors = { - "kernel": { - "lengthscale": dx.Gamma(1.0, 1.0), - "variance": dx.Gamma(2.0, 2.0), - }, - } - structured_priors = structure_priors(params, priors) - - def recursive_fn(d1, d2, fn: tp.Callable[[tp.Any], tp.Any]): - for key, value in d1.items(): - if type(value) is dict: - yield from recursive_fn(value, d2[key], fn) - else: - yield fn(key, key) - - for v in recursive_fn(params, structured_priors, lambda k1, k2: k1 == k2): - assert v - - -@pytest.mark.parametrize("latent_prior", [dx.Laplace(0.0, 1.0), dx.Laplace(0.0, 1.0)]) -def test_prior_checks(latent_prior): - priors = { - "kernel": {"lengthscale": None, "variance": None}, - "mean_function": {}, - "liklelihood": {"variance": None}, - "latent": None, - } - new_priors = prior_checks(priors) - assert "latent" in new_priors.keys() - assert isinstance(new_priors["latent"], dx.Normal) - - priors = { - "kernel": {"lengthscale": None, "variance": None}, - "mean_function": {}, - "liklelihood": {"variance": None}, - } - new_priors = prior_checks(priors) - assert "latent" in new_priors.keys() - assert isinstance(new_priors["latent"], dx.Normal) - - priors = { - "kernel": {"lengthscale": None, "variance": None}, - "mean_function": {}, - "liklelihood": {"variance": None}, - "latent": latent_prior, - } - with pytest.warns(UserWarning): - new_priors = prior_checks(priors) - assert "latent" in new_priors.keys() - assert isinstance(new_priors["latent"], dx.Laplace) - - -######################### -# Test transforms -######################### -@pytest.mark.parametrize("num_datapoints", [1, 10]) -@pytest.mark.parametrize("likelihood", [Gaussian, Bernoulli]) -def test_output(num_datapoints, likelihood): - posterior = Prior(kernel=RBF()) * likelihood(num_datapoints=num_datapoints) - params, _, bijectors = initialise(posterior, jr.PRNGKey(123)).unpack() - - assert isinstance(bijectors, dict) - for k, v1, v2 in recursive_items(bijectors, bijectors): - assert isinstance(v1.forward, tp.Callable) - assert isinstance(v2.inverse, tp.Callable) - - unconstrained_params = unconstrain(params, bijectors) - assert ( - unconstrained_params["kernel"]["lengthscale"] != params["kernel"]["lengthscale"] - ) - backconstrained_params = constrain(unconstrained_params, bijectors) - for k, v1, v2 in recursive_items(params, unconstrained_params): - assert v1.dtype == v2.dtype - - for k, v1, v2 in recursive_items(params, backconstrained_params): - assert all(v1 == v2) - - augmented_params = params - augmented_params["test_param"] = jnp.array([1.0]) - a_bijectors = build_bijectors(augmented_params) - assert "test_param" in list(a_bijectors.keys()) - assert a_bijectors["test_param"].forward(jnp.array([1.0])) == 1.0 - assert a_bijectors["test_param"].inverse(jnp.array([1.0])) == 1.0 +from gpjax.parameters import Identity, Softplus, param_field + + +@pytest.mark.parametrize("bijector", [Identity, Softplus]) +@pytest.mark.parametrize("trainable", [True, False]) +def test_param(bijector, trainable): + param_field_ = param_field(bijector=bijector, trainable=trainable) + assert isinstance(param_field_, dataclasses.Field) + assert param_field_.metadata["bijector"] == bijector + assert param_field_.metadata["trainable"] == trainable + + with pytest.raises(ValueError): + param_field( + bijector=bijector, trainable=trainable, metadata={"trainable": trainable} + ) + + with pytest.raises(ValueError): + param_field( + bijector=bijector, trainable=trainable, metadata={"bijector": bijector} + ) + + with pytest.raises(ValueError): + param_field( + bijector=bijector, + trainable=trainable, + metadata={"bijector": Softplus, "trainable": trainable}, + ) + + with pytest.raises(ValueError): + param_field( + bijector=bijector, trainable=trainable, metadata={"pytree_node": True} + ) + + with pytest.raises(ValueError): + param_field( + bijector=bijector, trainable=trainable, metadata={"pytree_node": False} + ) diff --git a/tests/test_params.py b/tests/test_params.py new file mode 100644 index 000000000..f9867e4d7 --- /dev/null +++ b/tests/test_params.py @@ -0,0 +1,273 @@ +# Copyright 2022 The GPJax Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import typing as tp + +import distrax as dx +import jax.numpy as jnp +import jax.random as jr +import pytest +from jax.config import config + +from gpjax.gps import Prior +from gpjax.kernels import RBF +from gpjax.likelihoods import Bernoulli, Gaussian +from gpjax.params import ( + build_bijectors, + build_trainables, + constrain, + copy_dict_structure, + evaluate_priors, + initialise, + log_density, + prior_checks, + recursive_complete, + recursive_items, + structure_priors, + unconstrain, +) + +# Enable Float64 for more stable matrix inversions. +config.update("jax_enable_x64", True) + +######################### +# Test base functionality +######################### +@pytest.mark.parametrize("lik", [Gaussian]) +def test_initialise(lik): + key = jr.PRNGKey(123) + posterior = Prior(kernel=RBF()) * lik(num_datapoints=10) + params, _, _ = initialise(posterior, key).unpack() + assert list(sorted(params.keys())) == [ + "kernel", + "likelihood", + "mean_function", + ] + + +def test_non_conjugate_initialise(): + posterior = Prior(kernel=RBF()) * Bernoulli(num_datapoints=10) + params, _, _ = initialise(posterior, jr.PRNGKey(123)).unpack() + assert list(sorted(params.keys())) == [ + "kernel", + "latent", + "likelihood", + "mean_function", + ] + + +######################### +# Test priors +######################### +@pytest.mark.parametrize("x", [-1.0, 0.0, 1.0]) +def test_lpd(x): + val = jnp.array(x) + dist = dx.Normal(loc=0.0, scale=1.0) + lpd = log_density(val, dist) + assert lpd is not None + assert log_density(val, None) == 0.0 + + +@pytest.mark.parametrize("lik", [Gaussian, Bernoulli]) +def test_prior_template(lik): + posterior = Prior(kernel=RBF()) * lik(num_datapoints=10) + params, _, _ = initialise(posterior, jr.PRNGKey(123)).unpack() + prior_container = copy_dict_structure(params) + for ( + k, + v1, + v2, + ) in recursive_items(params, prior_container): + assert v2 == None + + +@pytest.mark.parametrize("lik", [Gaussian, Bernoulli]) +def test_recursive_complete(lik): + posterior = Prior(kernel=RBF()) * lik(num_datapoints=10) + params, _, _ = initialise(posterior, jr.PRNGKey(123)).unpack() + priors = {"kernel": {}} + priors["kernel"]["lengthscale"] = dx.Laplace(loc=0.0, scale=1.0) + container = copy_dict_structure(params) + complete_priors = recursive_complete(container, priors) + for ( + k, + v1, + v2, + ) in recursive_items(params, complete_priors): + if k == "lengthscale": + assert isinstance(v2, dx.Laplace) + else: + assert v2 == None + + +def test_prior_evaluation(): + """ + Test the regular setup that every parameter has a corresponding prior distribution attached to its unconstrained + value. + """ + params = { + "kernel": { + "lengthscale": jnp.array([1.0]), + "variance": jnp.array([1.0]), + }, + "likelihood": {"obs_noise": jnp.array([1.0])}, + } + priors = { + "kernel": { + "lengthscale": dx.Gamma(1.0, 1.0), + "variance": dx.Gamma(2.0, 2.0), + }, + "likelihood": {"obs_noise": dx.Gamma(3.0, 3.0)}, + } + lpd = evaluate_priors(params, priors) + assert pytest.approx(lpd) == -2.0110168 + + +def test_none_prior(): + """ + Test that multiple dispatch is working in the case of no priors. + """ + params = { + "kernel": { + "lengthscale": jnp.array([1.0]), + "variance": jnp.array([1.0]), + }, + "likelihood": {"obs_noise": jnp.array([1.0])}, + } + priors = copy_dict_structure(params) + lpd = evaluate_priors(params, priors) + assert lpd == 0.0 + + +def test_incomplete_priors(): + """ + Test the case where a user specifies priors for some, but not all, parameters. + """ + params = { + "kernel": { + "lengthscale": jnp.array([1.0]), + "variance": jnp.array([1.0]), + }, + "likelihood": {"obs_noise": jnp.array([1.0])}, + } + priors = { + "kernel": { + "lengthscale": dx.Gamma(1.0, 1.0), + "variance": dx.Gamma(2.0, 2.0), + }, + } + container = copy_dict_structure(params) + complete_priors = recursive_complete(container, priors) + lpd = evaluate_priors(params, complete_priors) + assert pytest.approx(lpd) == -1.6137061 + + +@pytest.mark.parametrize("num_datapoints", [1, 10]) +def test_checks(num_datapoints): + incomplete_priors = {"lengthscale": jnp.array([1.0])} + posterior = Prior(kernel=RBF()) * Bernoulli(num_datapoints=num_datapoints) + priors = prior_checks(incomplete_priors) + assert "latent" in priors.keys() + assert "variance" not in priors.keys() + assert isinstance(priors["latent"], dx.Normal) + + +def test_structure_priors(): + posterior = Prior(kernel=RBF()) * Gaussian(num_datapoints=10) + params, _, _ = initialise(posterior, jr.PRNGKey(123)).unpack() + priors = { + "kernel": { + "lengthscale": dx.Gamma(1.0, 1.0), + "variance": dx.Gamma(2.0, 2.0), + }, + } + structured_priors = structure_priors(params, priors) + + def recursive_fn(d1, d2, fn: tp.Callable[[tp.Any], tp.Any]): + for key, value in d1.items(): + if type(value) is dict: + yield from recursive_fn(value, d2[key], fn) + else: + yield fn(key, key) + + for v in recursive_fn(params, structured_priors, lambda k1, k2: k1 == k2): + assert v + + +@pytest.mark.parametrize("latent_prior", [dx.Laplace(0.0, 1.0), dx.Laplace(0.0, 1.0)]) +def test_prior_checks(latent_prior): + priors = { + "kernel": {"lengthscale": None, "variance": None}, + "mean_function": {}, + "liklelihood": {"variance": None}, + "latent": None, + } + new_priors = prior_checks(priors) + assert "latent" in new_priors.keys() + assert isinstance(new_priors["latent"], dx.Normal) + + priors = { + "kernel": {"lengthscale": None, "variance": None}, + "mean_function": {}, + "liklelihood": {"variance": None}, + } + new_priors = prior_checks(priors) + assert "latent" in new_priors.keys() + assert isinstance(new_priors["latent"], dx.Normal) + + priors = { + "kernel": {"lengthscale": None, "variance": None}, + "mean_function": {}, + "liklelihood": {"variance": None}, + "latent": latent_prior, + } + with pytest.warns(UserWarning): + new_priors = prior_checks(priors) + assert "latent" in new_priors.keys() + assert isinstance(new_priors["latent"], dx.Laplace) + + +######################### +# Test transforms +######################### +@pytest.mark.parametrize("num_datapoints", [1, 10]) +@pytest.mark.parametrize("likelihood", [Gaussian, Bernoulli]) +def test_output(num_datapoints, likelihood): + posterior = Prior(kernel=RBF()) * likelihood(num_datapoints=num_datapoints) + params, _, bijectors = initialise(posterior, jr.PRNGKey(123)).unpack() + + assert isinstance(bijectors, dict) + for k, v1, v2 in recursive_items(bijectors, bijectors): + assert isinstance(v1.forward, tp.Callable) + assert isinstance(v2.inverse, tp.Callable) + + unconstrained_params = unconstrain(params, bijectors) + assert ( + unconstrained_params["kernel"]["lengthscale"] != params["kernel"]["lengthscale"] + ) + backconstrained_params = constrain(unconstrained_params, bijectors) + for k, v1, v2 in recursive_items(params, unconstrained_params): + assert v1.dtype == v2.dtype + + for k, v1, v2 in recursive_items(params, backconstrained_params): + assert all(v1 == v2) + + augmented_params = params + augmented_params["test_param"] = jnp.array([1.0]) + a_bijectors = build_bijectors(augmented_params) + + assert "test_param" in list(a_bijectors.keys()) + assert a_bijectors["test_param"].forward(jnp.array([1.0])) == 1.0 + assert a_bijectors["test_param"].inverse(jnp.array([1.0])) == 1.0 diff --git a/tests/test_types.py b/tests/test_types.py index f08f66bf0..8b7243c3b 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -15,6 +15,7 @@ import jax.numpy as jnp import pytest + from gpjax.types import Dataset, verify_dataset diff --git a/tests/test_variational_families.py b/tests/test_variational_families.py index c3c507abe..47ea0cac8 100644 --- a/tests/test_variational_families.py +++ b/tests/test_variational_families.py @@ -15,13 +15,13 @@ from typing import Callable, Dict, Tuple -import jax import distrax as dx +import jax import jax.numpy as jnp import jax.random as jr import pytest from jax.config import config -from jaxtyping import Float, Array +from jaxtyping import Array, Float import gpjax as gpx from gpjax.variational_families import ( From cf6443e1a33e386427da279dc422389fddf4874c Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Mon, 27 Mar 2023 20:55:22 +0100 Subject: [PATCH 05/44] Refactor tests --- tests/{ => test_params}/test_bijectors.py | 0 tests/{ => test_params}/test_module.py | 0 tests/{ => test_params}/test_parameters.py | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename tests/{ => test_params}/test_bijectors.py (100%) rename tests/{ => test_params}/test_module.py (100%) rename tests/{ => test_params}/test_parameters.py (100%) diff --git a/tests/test_bijectors.py b/tests/test_params/test_bijectors.py similarity index 100% rename from tests/test_bijectors.py rename to tests/test_params/test_bijectors.py diff --git a/tests/test_module.py b/tests/test_params/test_module.py similarity index 100% rename from tests/test_module.py rename to tests/test_params/test_module.py diff --git a/tests/test_parameters.py b/tests/test_params/test_parameters.py similarity index 100% rename from tests/test_parameters.py rename to tests/test_params/test_parameters.py From 623be429058c99aa4606b1da1ee4a5c831f16e3c Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Mon, 27 Mar 2023 21:02:41 +0100 Subject: [PATCH 06/44] Reformat --- gpjax/gps.py | 20 +++++------------ gpjax/kernels/base.py | 22 ++++++++++++++----- gpjax/kernels/computations/base.py | 6 ++++- gpjax/kernels/computations/basis_functions.py | 7 ++++-- .../kernels/computations/constant_diagonal.py | 8 +++++-- gpjax/kernels/computations/dense.py | 1 + gpjax/kernels/computations/diagonal.py | 5 +++-- gpjax/kernels/computations/eigen.py | 20 +++++++++++++---- gpjax/kernels/nonstationary/linear.py | 4 +++- gpjax/kernels/nonstationary/polynomial.py | 5 +++-- gpjax/kernels/stationary/rbf.py | 2 ++ gpjax/linops/__init__.py | 2 +- .../constant_diagonal_linear_operator.py | 2 +- gpjax/parameters/__init__.py | 2 +- gpjax/variational_inference.py | 2 +- tests/test_gaussian_distribution.py | 2 +- tests/test_quadrature.py | 2 +- 17 files changed, 71 insertions(+), 41 deletions(-) diff --git a/gpjax/gps.py b/gpjax/gps.py index 15cb826aa..ffdc7b3ea 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -228,9 +228,7 @@ def predict( mean_function = self.mean_function kernel = self.kernel - def predict_fn( - test_inputs: Float[Array, "N D"] - ) -> GaussianDistribution: + def predict_fn(test_inputs: Float[Array, "N D"]) -> GaussianDistribution: # Unpack test inputs t = test_inputs @@ -463,9 +461,7 @@ def predict(test_inputs: Float[Array, "N D"]) -> GaussianDistribution: covariance = Ktt - jnp.matmul(Kxt.T, Sigma_inv_Kxt) covariance += identity(n_test) * jitter - return GaussianDistribution( - jnp.atleast_1d(mean.squeeze()), covariance - ) + return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance) return predict @@ -577,9 +573,7 @@ def mll( ) return constant * ( - marginal_likelihood.log_prob( - jnp.atleast_1d(y.squeeze()) - ).squeeze() + marginal_likelihood.log_prob(jnp.atleast_1d(y.squeeze())).squeeze() ) return mll @@ -627,9 +621,7 @@ def init_params(self, key: KeyArray) -> Dict: self.prior.init_params(key), {"likelihood": self.likelihood.init_params(key)}, ) - parameters["latent"] = jnp.zeros( - shape=(self.likelihood.num_datapoints, 1) - ) + parameters["latent"] = jnp.zeros(shape=(self.likelihood.num_datapoints, 1)) return parameters def predict( @@ -701,9 +693,7 @@ def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution: covariance = Ktt - jnp.matmul(Lx_inv_Kxt.T, Lx_inv_Kxt) covariance += identity(n_test) * jitter - return GaussianDistribution( - jnp.atleast_1d(mean.squeeze()), covariance - ) + return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance) return predict_fn diff --git a/gpjax/kernels/base.py b/gpjax/kernels/base.py index f4db57879..88bc53b50 100644 --- a/gpjax/kernels/base.py +++ b/gpjax/kernels/base.py @@ -27,9 +27,11 @@ from .computations import AbstractKernelComputation, DenseKernelComputation + @dataclass class AbstractKernel(Module): """Base kernel class.""" + compute_engine: AbstractKernelComputation = static_field(DenseKernelComputation) active_dims: List[int] = static_field(None) @@ -70,7 +72,9 @@ def __call__( """ raise NotImplementedError - def __add__(self, other: Union[AbstractKernel, Float[Array, "1"]]) -> AbstractKernel: + def __add__( + self, other: Union[AbstractKernel, Float[Array, "1"]] + ) -> AbstractKernel: """Add two kernels together. Args: other (AbstractKernel): The kernel to be added to the current kernel. @@ -84,7 +88,9 @@ def __add__(self, other: Union[AbstractKernel, Float[Array, "1"]]) -> AbstractKe return SumKernel([self, Constant(other)]) - def __radd__(self, other: Union[AbstractKernel, Float[Array, "1"]]) -> AbstractKernel: + def __radd__( + self, other: Union[AbstractKernel, Float[Array, "1"]] + ) -> AbstractKernel: """Add two kernels together. Args: other (AbstractKernel): The kernel to be added to the current kernel. @@ -94,7 +100,9 @@ def __radd__(self, other: Union[AbstractKernel, Float[Array, "1"]]) -> AbstractK """ return self.__add__(other) - def __mul__(self, other: Union[AbstractKernel, Float[Array, "1"]]) -> AbstractKernel: + def __mul__( + self, other: Union[AbstractKernel, Float[Array, "1"]] + ) -> AbstractKernel: """Multiply two kernels together. Args: @@ -115,6 +123,7 @@ class Constant(AbstractKernel): A constant mean function. This function returns a repeated scalar value for all inputs. The scalar value itself can be treated as a model hyperparameter and learned during training. """ + constant: Float[Array, "1"] = param_field(jnp.array(0.0)) def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "1"]: @@ -133,16 +142,17 @@ def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, " @dataclass class CombinationKernel(AbstractKernel): """A base class for products or sums of MeanFunctions.""" + kernels: List[AbstractKernel] = None operator: Callable = static_field(None) def __post_init__(self): - #Add kernels to a list, flattening out instances of this class therein, as in GPFlow kernels. + # Add kernels to a list, flattening out instances of this class therein, as in GPFlow kernels. kernels_list: List[AbstractKernel] = [] for kernel in self.kernels: if not isinstance(kernel, AbstractKernel): - raise TypeError("can only combine Kernel instances") # pragma: no cover + raise TypeError("can only combine Kernel instances") # pragma: no cover if isinstance(kernel, self.__class__): kernels_list.extend(kernel.kernels) @@ -169,4 +179,4 @@ def __call__( SumKernel = partial(CombinationKernel, operator=jnp.sum) -ProductKernel = partial(CombinationKernel, operator=jnp.sum) \ No newline at end of file +ProductKernel = partial(CombinationKernel, operator=jnp.sum) diff --git a/gpjax/kernels/computations/base.py b/gpjax/kernels/computations/base.py index 718a4c65c..af45826c2 100644 --- a/gpjax/kernels/computations/base.py +++ b/gpjax/kernels/computations/base.py @@ -26,9 +26,11 @@ Kernel = Any + @dataclass class AbstractKernelComputation: """Abstract class for kernel computations.""" + kernel: Kernel def gram( @@ -48,7 +50,9 @@ def gram( return DenseLinearOperator(Kxx) @abc.abstractmethod - def cross_covariance(self, x: Float[Array, "N D"], y: Float[Array, "M D"]) -> Float[Array, "N M"]: + def cross_covariance( + self, x: Float[Array, "N D"], y: Float[Array, "M D"] + ) -> Float[Array, "N M"]: """For a given kernel, compute the NxM gram matrix on an a pair of input matrices with shape NxD and MxD. diff --git a/gpjax/kernels/computations/basis_functions.py b/gpjax/kernels/computations/basis_functions.py index e0a057fed..97e12fe1c 100644 --- a/gpjax/kernels/computations/basis_functions.py +++ b/gpjax/kernels/computations/basis_functions.py @@ -9,9 +9,12 @@ @dataclass class BasisFunctionComputation(AbstractKernelComputation): """Compute engine class for finite basis function approximations to a kernel.""" + num_basis_fns = None - def cross_covariance(self, x: Float[Array, "N D"], y: Float[Array, "M D"]) -> Float[Array, "N M"]: + def cross_covariance( + self, x: Float[Array, "N D"], y: Float[Array, "M D"] + ) -> Float[Array, "N M"]: """For a pair of inputs, compute the cross covariance matrix between the inputs. Args: params (Dict): A dictionary of parameters for which the cross-covariance matrix should be constructed with. @@ -55,4 +58,4 @@ def compute_features(self, x: Float[Array, "N D"]) -> Float[Array, "N L"]: scaling_factor = self.kernel.lengthscale z = jnp.matmul(x, (frequencies / scaling_factor).T) z = jnp.concatenate([jnp.cos(z), jnp.sin(z)], axis=-1) - return z \ No newline at end of file + return z diff --git a/gpjax/kernels/computations/constant_diagonal.py b/gpjax/kernels/computations/constant_diagonal.py index 7d9fb2949..58b8087a1 100644 --- a/gpjax/kernels/computations/constant_diagonal.py +++ b/gpjax/kernels/computations/constant_diagonal.py @@ -32,7 +32,9 @@ def gram(self, x: Float[Array, "N D"]) -> ConstantDiagonalLinearOperator: """ value = self.kernel(x[0], x[0]) - return ConstantDiagonalLinearOperator(value=jnp.atleast_1d(value), size=x.shape[0]) + return ConstantDiagonalLinearOperator( + value=jnp.atleast_1d(value), size=x.shape[0] + ) def diagonal(self, inputs: Float[Array, "N D"]) -> DiagonalLinearOperator: """For a given kernel, compute the elementwise diagonal of the @@ -52,7 +54,9 @@ def diagonal(self, inputs: Float[Array, "N D"]) -> DiagonalLinearOperator: return DiagonalLinearOperator(diag=diag) - def cross_covariance(self, x: Float[Array, "N D"], y: Float[Array, "M D"]) -> Float[Array, "N M"]: + def cross_covariance( + self, x: Float[Array, "N D"], y: Float[Array, "M D"] + ) -> Float[Array, "N M"]: """For a given kernel, compute the NxM covariance matrix on a pair of input matrices of shape NxD and MxD. diff --git a/gpjax/kernels/computations/dense.py b/gpjax/kernels/computations/dense.py index c64981feb..d55bb034b 100644 --- a/gpjax/kernels/computations/dense.py +++ b/gpjax/kernels/computations/dense.py @@ -17,6 +17,7 @@ from jaxtyping import Array, Float from .base import AbstractKernelComputation + class DenseKernelComputation(AbstractKernelComputation): """Dense kernel computation class. Operations with the kernel assume a dense gram matrix structure. diff --git a/gpjax/kernels/computations/diagonal.py b/gpjax/kernels/computations/diagonal.py index fc14a257e..999bac468 100644 --- a/gpjax/kernels/computations/diagonal.py +++ b/gpjax/kernels/computations/diagonal.py @@ -22,7 +22,6 @@ class DiagonalKernelComputation(AbstractKernelComputation): - def gram(self, x: Float[Array, "N D"]) -> DiagonalLinearOperator: """For a kernel with diagonal structure, compute the NxN gram matrix on an input matrix of shape NxD. @@ -35,7 +34,9 @@ def gram(self, x: Float[Array, "N D"]) -> DiagonalLinearOperator: """ return DiagonalLinearOperator(diag=vmap(lambda x: self.kernel(x, x))(x)) - def cross_covariance(self, x: Float[Array, "N D"], y: Float[Array, "M D"]) -> Float[Array, "N M"]: + def cross_covariance( + self, x: Float[Array, "N D"], y: Float[Array, "M D"] + ) -> Float[Array, "N M"]: """For a given kernel, compute the NxM covariance matrix on a pair of input matrices of shape NxD and MxD. diff --git a/gpjax/kernels/computations/eigen.py b/gpjax/kernels/computations/eigen.py index e80d2204c..ab3271336 100644 --- a/gpjax/kernels/computations/eigen.py +++ b/gpjax/kernels/computations/eigen.py @@ -20,17 +20,29 @@ from .base import AbstractKernelComputation from dataclasses import dataclass + @dataclass class EigenKernelComputation(AbstractKernelComputation): eigenvalues: Float[Array, "N"] = None eigenvectors: Float[Array, "N N"] = None num_verticies: int = None - def cross_covariance(self, params: Dict, x: Float[Array, "N D"], y: Float[Array, "M D"]) -> Float[Array, "N M"]: + def cross_covariance( + self, params: Dict, x: Float[Array, "N D"], y: Float[Array, "M D"] + ) -> Float[Array, "N M"]: + # Extract the graph Laplacian's eigenvalues evals = self.eigenvalues - S = jnp.power(evals + 2 * self.kernel.smoothness / self.kernel.lengthscale / self.kernel.lengthscale, - - self.kernel.smoothness, + # Transform the eigenvalues of the graph Laplacian according to the + # RBF kernel's SPDE form. + S = jnp.power( + evals + + 2 + * self.kernel.smoothness + / self.kernel.lengthscale + / self.kernel.lengthscale, + -self.kernel.smoothness, ) S = jnp.multiply(S, self.num_vertex / jnp.sum(S)) + # Scale the transform eigenvalues by the kernel variance S = jnp.multiply(S, params["variance"]) - return self.kernel(x, y, S=S) \ No newline at end of file + return self.kernel(x, y, S=S) diff --git a/gpjax/kernels/nonstationary/linear.py b/gpjax/kernels/nonstationary/linear.py index 7749e961e..acb5dfb8f 100644 --- a/gpjax/kernels/nonstationary/linear.py +++ b/gpjax/kernels/nonstationary/linear.py @@ -22,9 +22,11 @@ from jaxtyping import Array, Float from ...parameters import param_field, Softplus + @dataclass class Linear(AbstractKernel): """The linear kernel.""" + variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus) def __call__( @@ -47,4 +49,4 @@ def __call__( x = self.slice_input(x) y = self.slice_input(y) K = self.variance * jnp.matmul(x.T, y) - return K.squeeze() \ No newline at end of file + return K.squeeze() diff --git a/gpjax/kernels/nonstationary/polynomial.py b/gpjax/kernels/nonstationary/polynomial.py index a87cacaa4..e8a3d12d5 100644 --- a/gpjax/kernels/nonstationary/polynomial.py +++ b/gpjax/kernels/nonstationary/polynomial.py @@ -20,11 +20,12 @@ from simple_pytree import static_field from ...parameters import param_field, Softplus + @dataclass class Polynomial(AbstractKernel): """The Polynomial kernel with variable degree.""" - degree: int = static_field(2) + degree: int = static_field(2) shift: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus) variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus) @@ -44,4 +45,4 @@ def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, " x = self.slice_input(x).squeeze() y = self.slice_input(y).squeeze() K = jnp.power(self.shift + jnp.dot(x * self.variance, y), self.degree) - return K.squeeze() \ No newline at end of file + return K.squeeze() diff --git a/gpjax/kernels/stationary/rbf.py b/gpjax/kernels/stationary/rbf.py index 2ab9c1d08..235556ee6 100644 --- a/gpjax/kernels/stationary/rbf.py +++ b/gpjax/kernels/stationary/rbf.py @@ -23,9 +23,11 @@ from dataclasses import dataclass from ...parameters import param_field, Softplus + @dataclass class RBF(AbstractKernel): """The Radial Basis Function (RBF) kernel.""" + lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=Softplus) variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus) diff --git a/gpjax/linops/__init__.py b/gpjax/linops/__init__.py index b827066fa..d41888f99 100644 --- a/gpjax/linops/__init__.py +++ b/gpjax/linops/__init__.py @@ -39,4 +39,4 @@ "UpperTriangularLinearOperator", "identity", "to_dense", -] \ No newline at end of file +] diff --git a/gpjax/linops/constant_diagonal_linear_operator.py b/gpjax/linops/constant_diagonal_linear_operator.py index 065304b8c..d3a53d19f 100644 --- a/gpjax/linops/constant_diagonal_linear_operator.py +++ b/gpjax/linops/constant_diagonal_linear_operator.py @@ -182,7 +182,7 @@ def from_root( Returns: ConstantDiagonalLinearOperator: Covariance operator. """ - return ConstantDiagonalLinearOperator(value=root.value**2, size=root.size) + return ConstantDiagonalLinearOperator(value=root.value ** 2, size=root.size) __all__ = [ diff --git a/gpjax/parameters/__init__.py b/gpjax/parameters/__init__.py index 19c6675c9..40aff12bb 100644 --- a/gpjax/parameters/__init__.py +++ b/gpjax/parameters/__init__.py @@ -2,4 +2,4 @@ from .module import Module from .param import param_field -__all__ = ['Identity', 'Module', 'Softplus', 'param_field'] \ No newline at end of file +__all__ = ["Identity", "Module", "Softplus", "param_field"] diff --git a/gpjax/variational_inference.py b/gpjax/variational_inference.py index 3942079ae..8e2166e99 100644 --- a/gpjax/variational_inference.py +++ b/gpjax/variational_inference.py @@ -262,7 +262,7 @@ def elbo_fn(params: Dict) -> Float[Array, "1"]: ) # (y - μx)ᵀ (Iσ² + Q)⁻¹ (y - μx) - quad = (jnp.sum(diff**2) - jnp.sum(L_inv_A_diff**2)) / noise + quad = (jnp.sum(diff ** 2) - jnp.sum(L_inv_A_diff ** 2)) / noise # 2 * log N(y; μx, Iσ² + Q) two_log_prob = -n * jnp.log(2.0 * jnp.pi * noise) - log_det_B - quad diff --git a/tests/test_gaussian_distribution.py b/tests/test_gaussian_distribution.py index f3f522fe4..faa571995 100644 --- a/tests/test_gaussian_distribution.py +++ b/tests/test_gaussian_distribution.py @@ -67,7 +67,7 @@ def test_diag_linear_operator(n: int) -> None: mean = jr.uniform(key_mean, shape=(n,)) diag = jr.uniform(key_diag, shape=(n,)) - dist_diag = GaussianDistribution(loc=mean, scale=DiagonalLinearOperator(diag**2)) + dist_diag = GaussianDistribution(loc=mean, scale=DiagonalLinearOperator(diag ** 2)) distrax_dist = MultivariateNormalDiag(loc=mean, scale_diag=diag) assert approx_equal(dist_diag.mean(), distrax_dist.mean()) diff --git a/tests/test_quadrature.py b/tests/test_quadrature.py index dbe1a03e1..75c91d4af 100644 --- a/tests/test_quadrature.py +++ b/tests/test_quadrature.py @@ -28,7 +28,7 @@ @pytest.mark.parametrize("jit", [True, False]) def test_quadrature(jit): def test(): - fun = lambda x: x**2 + fun = lambda x: x ** 2 mean = jnp.array([[2.0]]) var = jnp.array([[1.0]]) fn_val = gauss_hermite_quadrature(fun, mean, var) From 309d0c4826c7d62a292299eb1e0a81a4841a0712 Mon Sep 17 00:00:00 2001 From: frazane Date: Tue, 28 Mar 2023 17:07:45 +0200 Subject: [PATCH 07/44] Module methods' return type is Self (the subclass) --- gpjax/parameters/module.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/gpjax/parameters/module.py b/gpjax/parameters/module.py index cd0fdd9a3..07f09af40 100644 --- a/gpjax/parameters/module.py +++ b/gpjax/parameters/module.py @@ -4,7 +4,8 @@ import dataclasses from copy import copy, deepcopy -from typing import Any, Callable, Dict, Iterable, Tuple +from typing import Any, Callable, Dict, Iterable, Tuple, List +from typing_extensions import Self import jax import jax.tree_util as jtu @@ -29,7 +30,7 @@ def __init_subclass__(cls, mutable: bool = False): ): cls._pytree__meta[field] = {**value.metadata} - def replace(self, **kwargs: Any) -> Module: + def replace(self, **kwargs: Any) -> Self: """ Replace the values of the fields of the object. @@ -48,7 +49,7 @@ def replace(self, **kwargs: Any) -> Module: pytree.__dict__.update(kwargs) return pytree - def replace_meta(self, **kwargs: Any) -> Module: + def replace_meta(self, **kwargs: Any) -> Self: """ Replace the metadata of the fields. @@ -67,7 +68,7 @@ def replace_meta(self, **kwargs: Any) -> Module: pytree.__dict__.update(_pytree__meta={**pytree._pytree__meta, **kwargs}) return pytree - def update_meta(self, **kwargs: Any) -> Module: + def update_meta(self, **kwargs: Any) -> Self: """ Update the metadata of the fields. The metadata must already exist. @@ -92,15 +93,15 @@ def update_meta(self, **kwargs: Any) -> Module: pytree.__dict__.update(_pytree__meta=new) return pytree - def replace_trainable(self: Module, **kwargs: Dict[str, bool]) -> Module: + def replace_trainable(self: Module, **kwargs: Dict[str, bool]) -> Self: """Replace the trainability status of local nodes of the Module.""" return self.update_meta(**{k: {"trainable": v} for k, v in kwargs.items()}) - def replace_bijector(self: Module, **kwargs: Dict[str, Bijector]) -> Module: + def replace_bijector(self: Module, **kwargs: Dict[str, Bijector]) -> Self: """Replace the bijectors of local nodes of the Module.""" return self.update_meta(**{k: {"bijector": v} for k, v in kwargs.items()}) - def constrain(self) -> Module: + def constrain(self) -> Self: """Transform model parameters to the constrained space according to their defined bijectors. Returns: @@ -113,7 +114,7 @@ def _apply_constrain(meta_leaf): return meta_map(_apply_constrain, self) - def unconstrain(self) -> Module: + def unconstrain(self) -> Self: """Transform model parameters to the unconstrained space according to their defined bijectors. Returns: @@ -126,7 +127,7 @@ def _apply_unconstrain(meta_leaf): return meta_map(_apply_unconstrain, self) - def stop_gradient(self) -> Module: + def stop_gradient(self) -> Self: """Stop gradients flowing through the Module. Returns: From 6859304473b5406a95881815ecc861f1e3cc5f8b Mon Sep 17 00:00:00 2001 From: frazane Date: Tue, 28 Mar 2023 17:31:36 +0200 Subject: [PATCH 08/44] refactor matern12 --- gpjax/kernels/stationary/matern12.py | 40 +++++++++++----------------- 1 file changed, 15 insertions(+), 25 deletions(-) diff --git a/gpjax/kernels/stationary/matern12.py b/gpjax/kernels/stationary/matern12.py index 58bbb5ce1..a93e1f16a 100644 --- a/gpjax/kernels/stationary/matern12.py +++ b/gpjax/kernels/stationary/matern12.py @@ -13,36 +13,29 @@ # limitations under the License. # ============================================================================== +from dataclasses import dataclass from typing import Dict, List, Optional import jax.numpy as jnp from jax.random import KeyArray from jaxtyping import Array, Float +import distrax as dx +from ...parameters import Softplus, param_field from ..base import AbstractKernel -from ..computations import ( - DenseKernelComputation, -) -from .utils import euclidean_distance, build_student_t_distribution +from ..computations import DenseKernelComputation +from .utils import build_student_t_distribution, euclidean_distance +@dataclass class Matern12(AbstractKernel): """The Matérn kernel with smoothness parameter fixed at 0.5.""" - def __init__( - self, - active_dims: Optional[List[int]] = None, - name: Optional[str] = "Matérn 1/2 kernel", - ) -> None: - spectral_density = build_student_t_distribution(nu=1) - super().__init__(DenseKernelComputation, active_dims, spectral_density, name) - self._stationary = True + lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=Softplus) + variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus) def __call__( - self, - params: Dict, - x: Float[Array, "1 D"], - y: Float[Array, "1 D"], + self, x: Float[Array, "1 D"], y: Float[Array, "1 D"] ) -> Float[Array, "1"]: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with lengthscale parameter :math:`\\ell` and variance :math:`\\sigma^2` @@ -51,19 +44,16 @@ def __call__( k(x, y) = \\sigma^2 \\exp \\Bigg( -\\frac{\\lvert x-y \\rvert}{2\\ell^2} \\Bigg) Args: - params (Dict): Parameter set for which the kernel should be evaluated on. x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. y (Float[Array, "1 D"]): The right hand argument of the kernel function's call Returns: Float[Array, "1"]: The value of :math:`k(x, y)` """ - x = self.slice_input(x) / params["lengthscale"] - y = self.slice_input(y) / params["lengthscale"] - K = params["variance"] * jnp.exp(-euclidean_distance(x, y)) + x = self.slice_input(x) / self.lengthscale + y = self.slice_input(y) / self.lengthscale + K = self.variance * jnp.exp(-euclidean_distance(x, y)) return K.squeeze() - def init_params(self, key: KeyArray) -> Dict: - return { - "lengthscale": jnp.array([1.0] * self.ndims), - "variance": jnp.array([1.0]), - } + @property + def spectral_density(self) -> dx.Distribution: + return build_student_t_distribution(nu=1) From ace9e5f5be993a0dcb5dfa9953a71858d9933d00 Mon Sep 17 00:00:00 2001 From: frazane Date: Tue, 28 Mar 2023 17:58:34 +0200 Subject: [PATCH 09/44] stationary kernels refactoring --- gpjax/kernels/stationary/matern32.py | 39 ++++++----------- gpjax/kernels/stationary/matern52.py | 35 +++++++--------- gpjax/kernels/stationary/periodic.py | 37 +++++++--------- .../kernels/stationary/powered_exponential.py | 37 ++++++---------- .../kernels/stationary/rational_quadratic.py | 42 +++++++------------ gpjax/kernels/stationary/white.py | 36 ++++------------ 6 files changed, 77 insertions(+), 149 deletions(-) diff --git a/gpjax/kernels/stationary/matern32.py b/gpjax/kernels/stationary/matern32.py index a714261bc..42454bfaf 100644 --- a/gpjax/kernels/stationary/matern32.py +++ b/gpjax/kernels/stationary/matern32.py @@ -13,34 +13,28 @@ # limitations under the License. # ============================================================================== +from dataclasses import dataclass from typing import Dict, List, Optional import jax.numpy as jnp from jax.random import KeyArray from jaxtyping import Array, Float +from ...parameters import Softplus, param_field from ..base import AbstractKernel -from ..computations import ( - DenseKernelComputation, -) -from .utils import euclidean_distance, build_student_t_distribution +from ..computations import DenseKernelComputation +from .utils import build_student_t_distribution, euclidean_distance +@dataclass class Matern32(AbstractKernel): """The Matérn kernel with smoothness parameter fixed at 1.5.""" - def __init__( - self, - active_dims: Optional[List[int]] = None, - name: Optional[str] = "Matern 3/2", - ) -> None: - spectral_density = build_student_t_distribution(nu=3) - super().__init__(DenseKernelComputation, active_dims, spectral_density, name) - self._stationary = True + lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=Softplus) + variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus) def __call__( self, - params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"], ) -> Float[Array, "1"]: @@ -51,25 +45,18 @@ def __call__( k(x, y) = \\sigma^2 \\exp \\Bigg(1+ \\frac{\\sqrt{3}\\lvert x-y \\rvert}{\\ell^2} \\Bigg)\\exp\\Bigg(-\\frac{\\sqrt{3}\\lvert x-y\\rvert}{\\ell^2} \\Bigg) Args: - params (Dict): Parameter set for which the kernel should be evaluated on. x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. y (Float[Array, "1 D"]): The right hand argument of the kernel function's call. Returns: Float[Array, "1"]: The value of :math:`k(x, y)`. """ - x = self.slice_input(x) / params["lengthscale"] - y = self.slice_input(y) / params["lengthscale"] + x = self.slice_input(x) / self.lengthscale + y = self.slice_input(y) / self.lengthscale tau = euclidean_distance(x, y) - K = ( - params["variance"] - * (1.0 + jnp.sqrt(3.0) * tau) - * jnp.exp(-jnp.sqrt(3.0) * tau) - ) + K = self.variance * (1.0 + jnp.sqrt(3.0) * tau) * jnp.exp(-jnp.sqrt(3.0) * tau) return K.squeeze() - def init_params(self, key: KeyArray) -> Dict: - return { - "lengthscale": jnp.array([1.0] * self.ndims), - "variance": jnp.array([1.0]), - } + @property + def spectral_density(self): + return build_student_t_distribution(nu=3) diff --git a/gpjax/kernels/stationary/matern52.py b/gpjax/kernels/stationary/matern52.py index f7d771baf..a2c8a5f16 100644 --- a/gpjax/kernels/stationary/matern52.py +++ b/gpjax/kernels/stationary/matern52.py @@ -13,32 +13,28 @@ # limitations under the License. # ============================================================================== +from dataclasses import dataclass from typing import Dict, List, Optional import jax.numpy as jnp from jax.random import KeyArray from jaxtyping import Array, Float +from ...parameters import Softplus, param_field from ..base import AbstractKernel -from ..computations import ( - DenseKernelComputation, -) -from .utils import euclidean_distance, build_student_t_distribution +from ..computations import DenseKernelComputation +from .utils import build_student_t_distribution, euclidean_distance +@dataclass class Matern52(AbstractKernel): """The Matérn kernel with smoothness parameter fixed at 2.5.""" - def __init__( - self, - active_dims: Optional[List[int]] = None, - name: Optional[str] = "Matern 5/2", - ) -> None: - spectral_density = build_student_t_distribution(nu=5) - super().__init__(DenseKernelComputation, active_dims, spectral_density, name) + lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=Softplus) + variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus) def __call__( - self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"] + self, x: Float[Array, "1 D"], y: Float[Array, "1 D"] ) -> Float[Array, "1"]: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with lengthscale parameter :math:`\\ell` and variance :math:`\\sigma^2` @@ -47,25 +43,22 @@ def __call__( k(x, y) = \\sigma^2 \\exp \\Bigg(1+ \\frac{\\sqrt{5}\\lvert x-y \\rvert}{\\ell^2} + \\frac{5\\lvert x - y \\rvert^2}{3\\ell^2} \\Bigg)\\exp\\Bigg(-\\frac{\\sqrt{5}\\lvert x-y\\rvert}{\\ell^2} \\Bigg) Args: - params (Dict): Parameter set for which the kernel should be evaluated on. x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. y (Float[Array, "1 D"]): The right hand argument of the kernel function's call. Returns: Float[Array, "1"]: The value of :math:`k(x, y)`. """ - x = self.slice_input(x) / params["lengthscale"] - y = self.slice_input(y) / params["lengthscale"] + x = self.slice_input(x) / self.lengthscale + y = self.slice_input(y) / self.lengthscale tau = euclidean_distance(x, y) K = ( - params["variance"] + self.variance * (1.0 + jnp.sqrt(5.0) * tau + 5.0 / 3.0 * jnp.square(tau)) * jnp.exp(-jnp.sqrt(5.0) * tau) ) return K.squeeze() - def init_params(self, key: KeyArray) -> Dict: - return { - "lengthscale": jnp.array([1.0] * self.ndims), - "variance": jnp.array([1.0]), - } + @property + def spectral_density(self): + return build_student_t_distribution(nu=5) diff --git a/gpjax/kernels/stationary/periodic.py b/gpjax/kernels/stationary/periodic.py index 0d82e5fdb..2b6793b90 100644 --- a/gpjax/kernels/stationary/periodic.py +++ b/gpjax/kernels/stationary/periodic.py @@ -18,54 +18,45 @@ import jax import jax.numpy as jnp from jax.random import KeyArray -from jaxtyping import Array +from jaxtyping import Array, Float from ..base import AbstractKernel from ..computations import ( DenseKernelComputation, ) +from dataclasses import dataclass +from ...parameters import param_field, Softplus +@dataclass class Periodic(AbstractKernel): """The periodic kernel. Key reference is MacKay 1998 - "Introduction to Gaussian processes". """ - def __init__( - self, - active_dims: Optional[List[int]] = None, - name: Optional[str] = "Periodic", - ) -> None: - super().__init__( - DenseKernelComputation, active_dims, spectral_density=None, name=name - ) - self._stationary = True + lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=Softplus) + variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus) + period: Float[Array, "1"] = param_field( + jnp.array([1.0]), bijector=Softplus + ) # NOTE: is bijector needed? - def __call__(self, params: dict, x: jax.Array, y: jax.Array) -> Array: + def __call__(self, x: jax.Array, y: jax.Array) -> Array: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with length-scale parameter :math:`\\ell` and variance :math:`\\sigma` + TODO: write docstring + .. math:: k(x, y) = \\sigma^2 \\exp \\Bigg( -0.5 \\sum_{i=1}^{d} \\Bigg) Args: x (jax.Array): The left hand argument of the kernel function's call. y (jax.Array): The right hand argument of the kernel function's call - params (dict): Parameter set for which the kernel should be evaluated on. Returns: Array: The value of :math:`k(x, y)` """ x = self.slice_input(x) y = self.slice_input(y) - sine_squared = ( - jnp.sin(jnp.pi * (x - y) / params["period"]) / params["lengthscale"] - ) ** 2 - K = params["variance"] * jnp.exp(-0.5 * jnp.sum(sine_squared, axis=0)) + sine_squared = (jnp.sin(jnp.pi * (x - y) / self.period) / self.lengthscale) ** 2 + K = self.variance * jnp.exp(-0.5 * jnp.sum(sine_squared, axis=0)) return K.squeeze() - - def init_params(self, key: KeyArray) -> Dict: - return { - "lengthscale": jnp.array([1.0] * self.ndims), - "variance": jnp.array([1.0]), - "period": jnp.array([1.0] * self.ndims), - } diff --git a/gpjax/kernels/stationary/powered_exponential.py b/gpjax/kernels/stationary/powered_exponential.py index 62830ced5..1f5c10153 100644 --- a/gpjax/kernels/stationary/powered_exponential.py +++ b/gpjax/kernels/stationary/powered_exponential.py @@ -13,20 +13,21 @@ # limitations under the License. # ============================================================================== +from dataclasses import dataclass from typing import Dict, List, Optional import jax import jax.numpy as jnp from jax.random import KeyArray -from jaxtyping import Array +from jaxtyping import Array, Float +from ...parameters import Softplus, param_field from ..base import AbstractKernel -from ..computations import ( - DenseKernelComputation, -) +from ..computations import DenseKernelComputation from .utils import euclidean_distance +@dataclass class PoweredExponential(AbstractKernel): """The powered exponential family of kernels. @@ -34,17 +35,11 @@ class PoweredExponential(AbstractKernel): """ - def __init__( - self, - active_dims: Optional[List[int]] = None, - name: Optional[str] = "Powered exponential", - ) -> None: - super().__init__( - DenseKernelComputation, active_dims, spectral_density=None, name=name - ) - self._stationary = True + lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=Softplus) + variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus) + power: Float[Array, "1"] = param_field(jnp.array([1.0])) - def __call__(self, params: dict, x: jax.Array, y: jax.Array) -> Array: + def __call__(self, x: jax.Array, y: jax.Array) -> Array: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with length-scale parameter :math:`\\ell`, :math:`\\sigma` and power :math:`\\kappa`. .. math:: @@ -53,19 +48,11 @@ def __call__(self, params: dict, x: jax.Array, y: jax.Array) -> Array: Args: x (jax.Array): The left hand argument of the kernel function's call. y (jax.Array): The right hand argument of the kernel function's call - params (dict): Parameter set for which the kernel should be evaluated on. Returns: Array: The value of :math:`k(x, y)` """ - x = self.slice_input(x) / params["lengthscale"] - y = self.slice_input(y) / params["lengthscale"] - K = params["variance"] * jnp.exp(-euclidean_distance(x, y) ** params["power"]) + x = self.slice_input(x) / self.lengthscale + y = self.slice_input(y) / self.lengthscale + K = self.variance * jnp.exp(-euclidean_distance(x, y) ** self.power) return K.squeeze() - - def init_params(self, key: KeyArray) -> Dict: - return { - "lengthscale": jnp.array([1.0] * self.ndims), - "variance": jnp.array([1.0]), - "power": jnp.array([1.0]), - } diff --git a/gpjax/kernels/stationary/rational_quadratic.py b/gpjax/kernels/stationary/rational_quadratic.py index eed0e819e..b5e797406 100644 --- a/gpjax/kernels/stationary/rational_quadratic.py +++ b/gpjax/kernels/stationary/rational_quadratic.py @@ -13,32 +13,28 @@ # limitations under the License. # ============================================================================== +from dataclasses import dataclass from typing import List, Optional import jax import jax.numpy as jnp from jax.random import KeyArray -from jaxtyping import Array +from jaxtyping import Array, Float +from ...parameters import Softplus, param_field from ..base import AbstractKernel -from ..computations import ( - DenseKernelComputation, -) +from ..computations import DenseKernelComputation from .utils import squared_distance +@dataclass class RationalQuadratic(AbstractKernel): - def __init__( - self, - active_dims: Optional[List[int]] = None, - name: Optional[str] = "Rational Quadratic", - ) -> None: - super().__init__( - DenseKernelComputation, active_dims, spectral_density=None, name=name - ) - self._stationary = True - def __call__(self, params: dict, x: jax.Array, y: jax.Array) -> Array: + lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=Softplus) + variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus) + alpha: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus) + + def __call__(self, x: jax.Array, y: jax.Array) -> Array: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with length-scale parameter :math:`\\ell` and variance :math:`\\sigma` .. math:: @@ -47,20 +43,12 @@ def __call__(self, params: dict, x: jax.Array, y: jax.Array) -> Array: Args: x (jax.Array): The left hand argument of the kernel function's call. y (jax.Array): The right hand argument of the kernel function's call - params (dict): Parameter set for which the kernel should be evaluated on. Returns: Array: The value of :math:`k(x, y)` """ - x = self.slice_input(x) / params["lengthscale"] - y = self.slice_input(y) / params["lengthscale"] - K = params["variance"] * ( - 1 + 0.5 * squared_distance(x, y) / params["alpha"] - ) ** (-params["alpha"]) + x = self.slice_input(x) / self.lengthscale + y = self.slice_input(y) / self.lengthscale + K = self.variance * (1 + 0.5 * squared_distance(x, y) / self.alpha) ** ( + -self.alpha + ) return K.squeeze() - - def init_params(self, key: KeyArray) -> dict: - return { - "lengthscale": jnp.array([1.0] * self.ndims), - "variance": jnp.array([1.0]), - "alpha": jnp.array([1.0]), - } diff --git a/gpjax/kernels/stationary/white.py b/gpjax/kernels/stationary/white.py index 0aa04f555..91b48c410 100644 --- a/gpjax/kernels/stationary/white.py +++ b/gpjax/kernels/stationary/white.py @@ -13,30 +13,24 @@ # limitations under the License. # ============================================================================== -from typing import Dict, Optional, List +from dataclasses import dataclass +from typing import Dict, List, Optional import jax.numpy as jnp from jaxtyping import Array, Float +from ...parameters import Softplus, param_field from ..base import AbstractKernel -from ..computations import ( - ConstantDiagonalKernelComputation, - AbstractKernelComputation, -) +from ..computations import AbstractKernelComputation, ConstantDiagonalKernelComputation +@dataclass class White(AbstractKernel): - def __init__( - self, - compute_engine: AbstractKernelComputation = ConstantDiagonalKernelComputation, - active_dims: Optional[List[int]] = None, - name: Optional[str] = "White Noise Kernel", - ) -> None: - super().__init__(compute_engine, active_dims, spectral_density=None, name=name) - self._stationary = True + + variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus) def __call__( - self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"] + self, x: Float[Array, "1 D"], y: Float[Array, "1 D"] ) -> Float[Array, "1"]: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with variance :math:`\\sigma` @@ -44,23 +38,11 @@ def __call__( k(x, y) = \\sigma^2 \\delta(x-y) Args: - params (Dict): Parameter set for which the kernel should be evaluated on. x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. y (Float[Array, "1 D"]): The right hand argument of the kernel function's call. Returns: Float[Array, "1"]: The value of :math:`k(x, y)`. """ - K = jnp.all(jnp.equal(x, y)) * params["variance"] + K = jnp.all(jnp.equal(x, y)) * self.variance return K.squeeze() - - def init_params(self, key: Float[Array, "1 D"]) -> Dict: - """Initialise the kernel parameters. - - Args: - key (Float[Array, "1 D"]): The key to initialise the parameters with. - - Returns: - Dict: The initialised parameters. - """ - return {"variance": jnp.array([1.0])} From b609c0b48af9c3e5af37530920ddb8924fe6c730 Mon Sep 17 00:00:00 2001 From: frazane Date: Wed, 29 Mar 2023 16:18:42 +0200 Subject: [PATCH 10/44] tests draft --- tests/test_kernels/test_stationary.py | 454 ++++++++++---------------- 1 file changed, 172 insertions(+), 282 deletions(-) diff --git a/tests/test_kernels/test_stationary.py b/tests/test_kernels/test_stationary.py index 262765cdb..873d8944f 100644 --- a/tests/test_kernels/test_stationary.py +++ b/tests/test_kernels/test_stationary.py @@ -14,11 +14,12 @@ # ============================================================================== -from itertools import permutations +from itertools import permutations, product import jax import jax.numpy as jnp import jax.random as jr +import jax.tree_util as jtu import pytest import distrax as dx from jax.config import config @@ -30,8 +31,14 @@ Matern12, Matern32, Matern52, + White, + Periodic, + PoweredExponential, + RationalQuadratic, ) +from gpjax.kernels.computations import DenseKernelComputation, DiagonalKernelComputation from gpjax.kernels.stationary.utils import build_student_t_distribution +from gpjax.parameters.bijectors import Identity, Softplus # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) @@ -39,295 +46,178 @@ _jitter = 1e-6 -@pytest.mark.parametrize( - "kernel", - [ - RBF(), - # Matern12(), - # Matern32(), - # Matern52(), - # RationalQuadratic(), - # White(), - ], -) -@pytest.mark.parametrize("dim", [1, 2, 5]) -@pytest.mark.parametrize("n", [1, 2, 10]) -def test_gram(kernel: AbstractKernel, dim: int, n: int) -> None: - - # Gram constructor static method: - kernel.gram - - # Inputs x: - x = jnp.linspace(0.0, 1.0, n * dim).reshape(n, dim) - - # Test gram matrix: - Kxx = kernel.gram(x) - assert isinstance(Kxx, LinearOperator) - assert Kxx.shape == (n, n) - - -@pytest.mark.parametrize( - "kernel", - [ - RBF(), - # Matern12(), - # Matern32(), - # Matern52(), - # RationalQuadratic(), - # White(), - ], -) -@pytest.mark.parametrize("num_a", [1, 2, 5]) -@pytest.mark.parametrize("num_b", [1, 2, 5]) -@pytest.mark.parametrize("dim", [1, 2, 5]) -def test_cross_covariance( - kernel: AbstractKernel, num_a: int, num_b: int, dim: int -) -> None: - # Inputs a, b: - a = jnp.linspace(-1.0, 1.0, num_a * dim).reshape(num_a, dim) - b = jnp.linspace(3.0, 4.0, num_b * dim).reshape(num_b, dim) - - # Test cross covariance, Kab: - Kab = kernel.cross_covariance(a, b) - assert isinstance(Kab, jnp.ndarray) - assert Kab.shape == (num_a, num_b) - - -@pytest.mark.parametrize( - "kernel", - [ - RBF(), - # Matern12(), - # Matern32(), - # Matern52(), - # White(), - ], -) -@pytest.mark.parametrize("dim", [1, 2, 5]) -def test_call(kernel: AbstractKernel, dim: int) -> None: - - # Datapoint x and datapoint y: - x = jnp.array([[1.0] * dim]) - y = jnp.array([[0.5] * dim]) - - # Test calling gives an autocovariance value of no dimension between the inputs: - kxy = kernel(x, y) - - assert isinstance(kxy, jax.Array) - assert kxy.shape == () - - -@pytest.mark.parametrize( - "kern", - [ - RBF, - # Matern12, - # Matern32, - # Matern52, - ], -) -@pytest.mark.parametrize("dim", [1, 2, 5]) -@pytest.mark.parametrize("ell, sigma", [(0.1, 0.2), (0.5, 0.1), (0.1, 0.5), (0.5, 0.5)]) -@pytest.mark.parametrize("n", [1, 2, 5]) -def test_pos_def( - kern: AbstractKernel, dim: int, ell: float, sigma: float, n: int -) -> None: - kern = kern( - active_dims=list(range(dim)), - lengthscale=jnp.array([ell]), - variance=jnp.array([sigma]), +class BaseTestKernel: + """A base class that contains all tests applied on stationary kernels.""" + + kernel: AbstractKernel + default_compute_engine = type + spectral_density_name: str + + def pytest_generate_tests(self, metafunc): + """This is called automatically by pytest""" + id_func = lambda x: "-".join([f"{k}={v}" for k, v in x.items()]) + funcarglist = metafunc.cls.params.get(metafunc.function.__name__, None) + + if funcarglist is None: + + return + else: + argnames = sorted(funcarglist[0]) + metafunc.parametrize( + argnames, + [[funcargs[name] for name in argnames] for funcargs in funcarglist], + ids=id_func, + ) + + @pytest.mark.parametrize("dim", [None, 1, 3], ids=lambda x: f"dim={x}") + def test_initialization(self, fields: dict, dim: int) -> None: + + fields = {k: jnp.array([v]) for k, v in fields.items()} + + # number of dimensions + if dim is None: + kernel: AbstractKernel = self.kernel(**fields) + assert kernel.ndims == 1 + else: + kernel: AbstractKernel = self.kernel( + active_dims=[i for i in range(dim)], **fields + ) + assert kernel.ndims == dim + + # compute engine + assert kernel.compute_engine == self.default_compute_engine + + # properties + for field, value in fields.items(): + assert getattr(kernel, field) == value + + # pytree + leaves = jtu.tree_leaves(kernel) + assert len(leaves) == len(fields) + + # meta + meta_leaves = kernel._pytree__meta + assert meta_leaves.keys() == fields.keys() + for field in fields: + if field in ["variance", "lengthscale", "period", "alpha"]: + assert meta_leaves[field]["bijector"] == Softplus + if field in ["power"]: + assert meta_leaves[field]["bijector"] == Identity + assert meta_leaves[field]["trainable"] == True + + # call + x = jnp.linspace(0.0, 1.0, 10 * kernel.ndims).reshape(10, kernel.ndims) + kernel(x, x) + + @pytest.mark.parametrize("n", [1, 5], ids=lambda x: f"n={x}") + @pytest.mark.parametrize("dim", [1, 3], ids=lambda x: f"dim={x}") + def test_gram(self, dim: int, n: int) -> None: + kernel: AbstractKernel = self.kernel() + kernel.gram + x = jnp.linspace(0.0, 1.0, n * dim).reshape(n, dim) + Kxx = kernel.gram(x) + assert isinstance(Kxx, LinearOperator) + assert Kxx.shape == (n, n) + assert jnp.all(jnp.linalg.eigvalsh(Kxx.to_dense()) > 0.0) + + @pytest.mark.parametrize("n_a", [1, 2, 5], ids=lambda x: f"n_a={x}") + @pytest.mark.parametrize("n_b", [1, 2, 5], ids=lambda x: f"n_b={x}") + @pytest.mark.parametrize("dim", [1, 2, 5], ids=lambda x: f"dim={x}") + def test_cross_covariance(self, n_a: int, n_b: int, dim: int) -> None: + + kernel: AbstractKernel = self.kernel() + a = jnp.linspace(-1.0, 1.0, n_a * dim).reshape(n_a, dim) + b = jnp.linspace(3.0, 4.0, n_b * dim).reshape(n_b, dim) + Kab = kernel.cross_covariance(a, b) + assert isinstance(Kab, jnp.ndarray) + assert Kab.shape == (n_a, n_b) + + def test_spectral_density(self): + + kernel: AbstractKernel = self.kernel() + + if self.kernel not in [RBF, Matern12, Matern32, Matern52]: + with pytest.raises(AttributeError): + kernel.spectral_density + else: + sdensity = kernel.spectral_density + assert sdensity.name == self.spectral_density_name + assert sdensity.loc == jnp.array(0.0) + assert sdensity.scale == jnp.array(1.0) + + +prod = lambda inp: [ + {"fields": dict(zip(inp.keys(), values))} for values in product(*inp.values()) +] + + +class TestRBF(BaseTestKernel): + kernel = RBF + fields = prod({"lengthscale": [0.1, 1.0], "variance": [0.1, 1.0]}) + params = {"test_initialization": fields} + default_compute_engine = DenseKernelComputation + spectral_density_name = "Normal" + + +class TestMatern12(BaseTestKernel): + kernel = Matern12 + fields = prod({"lengthscale": [0.1, 1.0], "variance": [0.1, 1.0]}) + params = {"test_initialization": fields} + default_compute_engine = DenseKernelComputation + spectral_density_name = "StudentT" + + +class TestMatern32(BaseTestKernel): + kernel = Matern32 + fields = prod({"lengthscale": [0.1, 1.0], "variance": [0.1, 1.0]}) + params = {"test_initialization": fields} + default_compute_engine = DenseKernelComputation + spectral_density_name = "StudentT" + + +class TestMatern52(BaseTestKernel): + kernel = Matern52 + fields = prod({"lengthscale": [0.1, 1.0], "variance": [0.1, 1.0]}) + params = {"test_initialization": fields} + default_compute_engine = DenseKernelComputation + spectral_density_name = "StudentT" + + +class TestWhite(BaseTestKernel): + kernel = White + fields = prod({"variance": [0.1, 1.0]}) + params = {"test_initialization": fields} + default_compute_engine = DenseKernelComputation + + +class TestPeriodic(BaseTestKernel): + kernel = Periodic + fields = prod( + {"lengthscale": [0.1, 1.0], "variance": [0.1, 1.0], "period": [0.1, 1.0]} ) + params = {"test_initialization": fields} + default_compute_engine = DenseKernelComputation - # Create inputs x: - x = jr.uniform(_initialise_key, (n, dim)) - - # Test gram matrix eigenvalues are positive: - Kxx = kern.gram(x) - Kxx += identity(n) * _jitter - eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) - assert (eigen_values > 0.0).all() - - -# @pytest.mark.parametrize("dim", [1, 2, 5]) -# @pytest.mark.parametrize("ell, sigma", [(0.1, 0.2), (0.5, 0.1), (0.1, 0.5), (0.5, 0.5)]) -# @pytest.mark.parametrize("alpha", [0.1, 0.5, 1.0]) -# @pytest.mark.parametrize("n", [1, 2, 5]) -# def test_pos_def_rq(dim: int, ell: float, sigma: float, alpha: float, n: int) -> None: -# kern = RationalQuadratic(active_dims=list(range(dim))) -# # Gram constructor static method: -# kern.gram - -# # Create inputs x: -# x = jr.uniform(_initialise_key, (n, dim)) -# params = { -# "lengthscale": jnp.array([ell]), -# "variance": jnp.array([sigma]), -# "alpha": jnp.array([alpha]), -# } - -# # Test gram matrix eigenvalues are positive: -# Kxx = kern.gram(params, x) -# Kxx += identity(n) * _jitter -# eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) -# assert (eigen_values > 0.0).all() - - -# @pytest.mark.parametrize("dim", [1, 2, 5]) -# @pytest.mark.parametrize("ell, sigma", [(0.1, 0.2), (0.5, 0.1), (0.1, 0.5), (0.5, 0.5)]) -# @pytest.mark.parametrize("period", [0.1, 0.5, 1.0]) -# @pytest.mark.parametrize("n", [1, 2, 5]) -# def test_pos_def_periodic( -# dim: int, ell: float, sigma: float, period: float, n: int -# ) -> None: -# kern = Periodic(active_dims=list(range(dim))) -# # Gram constructor static method: -# kern.gram - -# # Create inputs x: -# x = jr.uniform(_initialise_key, (n, dim)) -# params = { -# "lengthscale": jnp.array([ell]), -# "variance": jnp.array([sigma]), -# "period": jnp.array([period]), -# } - -# # Test gram matrix eigenvalues are positive: -# Kxx = kern.gram(params, x) -# Kxx += identity(n) * _jitter -# eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) -# # assert (eigen_values > 0.0).all() - - -# @pytest.mark.parametrize("dim", [1, 2, 5]) -# @pytest.mark.parametrize("ell, sigma", [(0.1, 0.2), (0.5, 0.1), (0.1, 0.5), (0.5, 0.5)]) -# @pytest.mark.parametrize("power", [0.1, 0.5, 1.0]) -# @pytest.mark.parametrize("n", [1, 2, 5]) -# def test_pos_def_power_exp( -# dim: int, ell: float, sigma: float, power: float, n: int -# ) -> None: -# kern = PoweredExponential(active_dims=list(range(dim))) -# # Gram constructor static method: -# kern.gram - -# # Create inputs x: -# x = jr.uniform(_initialise_key, (n, dim)) -# params = { -# "lengthscale": jnp.array([ell]), -# "variance": jnp.array([sigma]), -# "power": jnp.array([power]), -# } - -# # Test gram matrix eigenvalues are positive: -# Kxx = kern.gram(params, x) -# Kxx += identity(n) * _jitter -# eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) -# assert (eigen_values > 0.0).all() - - -# @pytest.mark.parametrize("kernel", -# [ -# RBF, -# #Matern12, -# #Matern32, -# #Matern52, -# ], -# ) -# @pytest.mark.parametrize("dim", [None, 1, 2, 5, 10]) -# def test_initialisation(kernel: AbstractKernel, dim: int) -> None: - -# if dim is None: -# kern = kernel() -# assert kern.ndims == 1 - -# else: -# kern = kernel(active_dims=[i for i in range(dim)]) -# params = kern.init_params(_initialise_key) - -# assert list(params.keys()) == ["lengthscale", "variance"] -# assert all(params["lengthscale"] == jnp.array([1.0] * dim)) -# assert params["variance"] == jnp.array([1.0]) - -# if dim > 1: -# assert kern.ard -# else: -# assert not kern.ard - - -# @pytest.mark.parametrize( -# "kernel", -# [ -# RBF, -# # Matern12, -# # Matern32, -# # Matern52, -# # RationalQuadratic, -# # Periodic, -# # PoweredExponential, -# ], -# ) -# def test_dtype(kernel: AbstractKernel) -> None: -# parameter_state = initialise(kernel(), _initialise_key) -# params, *_ = parameter_state.unpack() -# for k, v in params.items(): -# assert v.dtype == jnp.float64 -# assert isinstance(k, str) - - -@pytest.mark.parametrize( - "kernel", - [ - RBF, - # Matern12, - # Matern32, - # Matern52, - # RationalQuadratic, - ], -) -def test_active_dim(kernel: AbstractKernel) -> None: - dim_list = [0, 1, 2, 3] - perm_length = 2 - dim_pairs = list(permutations(dim_list, r=perm_length)) - n_dims = len(dim_list) - - # Generate random inputs - x = jr.normal(_initialise_key, shape=(20, n_dims)) - - for dp in dim_pairs: - # Take slice of x - slice = x[..., dp] - # Define kernels - ad_kern = kernel(active_dims=dp) - manual_kern = kernel(active_dims=[i for i in range(perm_length)]) +class TestPoweredExponential(BaseTestKernel): + kernel = PoweredExponential + fields = prod( + {"lengthscale": [0.1, 1.0], "variance": [0.1, 1.0], "power": [0.1, 2.0]} + ) + params = {"test_initialization": fields} + default_compute_engine = DenseKernelComputation - # Compute gram matrices - ad_Kxx = ad_kern.gram(x) - manual_Kxx = manual_kern.gram(slice) - # Test gram matrices are equal - assert jnp.all(ad_Kxx.to_dense() == manual_Kxx.to_dense()) +class TestRationalQuadratic(BaseTestKernel): + kernel = RationalQuadratic + fields = prod( + {"lengthscale": [0.1, 1.0], "variance": [0.1, 1.0], "alpha": [0.1, 1.0]} + ) + params = {"test_initialization": fields} + default_compute_engine = DenseKernelComputation @pytest.mark.parametrize("smoothness", [1, 2, 3]) def test_build_studentt_dist(smoothness: int) -> None: dist = build_student_t_distribution(smoothness) assert isinstance(dist, dx.Distribution) - - -# @pytest.mark.parametrize( -# "kern, df", [(Matern12(), 1), (Matern32(), 3), (Matern52(), 5)] -# ) -# def test_matern_spectral_density(kern, df) -> None: -# sdensity = kern.spectral_density -# assert sdensity.name == "StudentT" -# assert sdensity.df == df -# assert sdensity.loc == jnp.array(0.0) -# assert sdensity.scale == jnp.array(1.0) - - -# def test_rbf_spectral_density() -> None: -# kern = RBF() -# sdensity = kern.spectral_density -# assert sdensity.name == "Normal" -# assert sdensity.loc == jnp.array(0.0) -# assert sdensity.scale == jnp.array(1.0) From f39d07b5233ede1dcd28ae69f90a212bf0618542 Mon Sep 17 00:00:00 2001 From: frazane Date: Wed, 29 Mar 2023 16:59:04 +0200 Subject: [PATCH 11/44] spectral density as property (RBF) --- gpjax/kernels/stationary/rbf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/gpjax/kernels/stationary/rbf.py b/gpjax/kernels/stationary/rbf.py index 235556ee6..d04fcef53 100644 --- a/gpjax/kernels/stationary/rbf.py +++ b/gpjax/kernels/stationary/rbf.py @@ -51,5 +51,6 @@ def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, " K = self.variance * jnp.exp(-0.5 * squared_distance(x, y)) return K.squeeze() + @property def spectral_density(self) -> dx.Normal: return dx.Normal(loc=0.0, scale=1.0) From 982354390e3066b76718cddd8b81c1b914e96bf8 Mon Sep 17 00:00:00 2001 From: frazane Date: Wed, 29 Mar 2023 18:56:07 +0200 Subject: [PATCH 12/44] add jitter in gram test --- tests/test_kernels/test_stationary.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_kernels/test_stationary.py b/tests/test_kernels/test_stationary.py index 873d8944f..a05e0d745 100644 --- a/tests/test_kernels/test_stationary.py +++ b/tests/test_kernels/test_stationary.py @@ -118,7 +118,7 @@ def test_gram(self, dim: int, n: int) -> None: Kxx = kernel.gram(x) assert isinstance(Kxx, LinearOperator) assert Kxx.shape == (n, n) - assert jnp.all(jnp.linalg.eigvalsh(Kxx.to_dense()) > 0.0) + assert jnp.all(jnp.linalg.eigvalsh(Kxx.to_dense() + jnp.eye(n) * 1e-6) > 0.0) @pytest.mark.parametrize("n_a", [1, 2, 5], ids=lambda x: f"n_a={x}") @pytest.mark.parametrize("n_b", [1, 2, 5], ids=lambda x: f"n_b={x}") From cc0c22051f6e27241f30acb42e8a7960844fed88 Mon Sep 17 00:00:00 2001 From: frazane Date: Wed, 29 Mar 2023 18:59:03 +0200 Subject: [PATCH 13/44] fix default engine for white kernel --- gpjax/kernels/stationary/white.py | 4 ++++ tests/test_kernels/test_stationary.py | 8 ++++++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/gpjax/kernels/stationary/white.py b/gpjax/kernels/stationary/white.py index 91b48c410..e39262bdf 100644 --- a/gpjax/kernels/stationary/white.py +++ b/gpjax/kernels/stationary/white.py @@ -18,6 +18,7 @@ import jax.numpy as jnp from jaxtyping import Array, Float +from simple_pytree import static_field from ...parameters import Softplus, param_field from ..base import AbstractKernel @@ -28,6 +29,9 @@ class White(AbstractKernel): variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus) + compute_engine: AbstractKernelComputation = static_field( + ConstantDiagonalKernelComputation + ) def __call__( self, x: Float[Array, "1 D"], y: Float[Array, "1 D"] diff --git a/tests/test_kernels/test_stationary.py b/tests/test_kernels/test_stationary.py index a05e0d745..7bdb65874 100644 --- a/tests/test_kernels/test_stationary.py +++ b/tests/test_kernels/test_stationary.py @@ -36,7 +36,11 @@ PoweredExponential, RationalQuadratic, ) -from gpjax.kernels.computations import DenseKernelComputation, DiagonalKernelComputation +from gpjax.kernels.computations import ( + DenseKernelComputation, + DiagonalKernelComputation, + ConstantDiagonalKernelComputation, +) from gpjax.kernels.stationary.utils import build_student_t_distribution from gpjax.parameters.bijectors import Identity, Softplus @@ -187,7 +191,7 @@ class TestWhite(BaseTestKernel): kernel = White fields = prod({"variance": [0.1, 1.0]}) params = {"test_initialization": fields} - default_compute_engine = DenseKernelComputation + default_compute_engine = ConstantDiagonalKernelComputation class TestPeriodic(BaseTestKernel): From 2d06b50b04c242ae31d59f8bc7fbed12c04807a7 Mon Sep 17 00:00:00 2001 From: frazane Date: Wed, 29 Mar 2023 19:06:52 +0200 Subject: [PATCH 14/44] fix jaxtyping hints --- gpjax/kernels/stationary/matern12.py | 8 +++----- gpjax/kernels/stationary/matern32.py | 8 ++++---- gpjax/kernels/stationary/matern52.py | 6 +++--- gpjax/kernels/stationary/periodic.py | 15 +++++++-------- gpjax/kernels/stationary/powered_exponential.py | 8 ++++---- gpjax/kernels/stationary/rational_quadratic.py | 8 ++++---- gpjax/kernels/stationary/utils.py | 12 ++++++------ gpjax/kernels/stationary/white.py | 8 +++----- 8 files changed, 34 insertions(+), 39 deletions(-) diff --git a/gpjax/kernels/stationary/matern12.py b/gpjax/kernels/stationary/matern12.py index a93e1f16a..a991231d6 100644 --- a/gpjax/kernels/stationary/matern12.py +++ b/gpjax/kernels/stationary/matern12.py @@ -34,9 +34,7 @@ class Matern12(AbstractKernel): lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=Softplus) variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus) - def __call__( - self, x: Float[Array, "1 D"], y: Float[Array, "1 D"] - ) -> Float[Array, "1"]: + def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "1"]: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with lengthscale parameter :math:`\\ell` and variance :math:`\\sigma^2` @@ -44,8 +42,8 @@ def __call__( k(x, y) = \\sigma^2 \\exp \\Bigg( -\\frac{\\lvert x-y \\rvert}{2\\ell^2} \\Bigg) Args: - x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. - y (Float[Array, "1 D"]): The right hand argument of the kernel function's call + x (Float[Array, "D"]): The left hand argument of the kernel function's call. + y (Float[Array, "D"]): The right hand argument of the kernel function's call Returns: Float[Array, "1"]: The value of :math:`k(x, y)` """ diff --git a/gpjax/kernels/stationary/matern32.py b/gpjax/kernels/stationary/matern32.py index 42454bfaf..9a630392e 100644 --- a/gpjax/kernels/stationary/matern32.py +++ b/gpjax/kernels/stationary/matern32.py @@ -35,8 +35,8 @@ class Matern32(AbstractKernel): def __call__( self, - x: Float[Array, "1 D"], - y: Float[Array, "1 D"], + x: Float[Array, "D"], + y: Float[Array, "D"], ) -> Float[Array, "1"]: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with lengthscale parameter :math:`\\ell` and variance :math:`\\sigma^2` @@ -45,8 +45,8 @@ def __call__( k(x, y) = \\sigma^2 \\exp \\Bigg(1+ \\frac{\\sqrt{3}\\lvert x-y \\rvert}{\\ell^2} \\Bigg)\\exp\\Bigg(-\\frac{\\sqrt{3}\\lvert x-y\\rvert}{\\ell^2} \\Bigg) Args: - x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. - y (Float[Array, "1 D"]): The right hand argument of the kernel function's call. + x (Float[Array, "D"]): The left hand argument of the kernel function's call. + y (Float[Array, "D"]): The right hand argument of the kernel function's call. Returns: Float[Array, "1"]: The value of :math:`k(x, y)`. diff --git a/gpjax/kernels/stationary/matern52.py b/gpjax/kernels/stationary/matern52.py index a2c8a5f16..70078b734 100644 --- a/gpjax/kernels/stationary/matern52.py +++ b/gpjax/kernels/stationary/matern52.py @@ -34,7 +34,7 @@ class Matern52(AbstractKernel): variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus) def __call__( - self, x: Float[Array, "1 D"], y: Float[Array, "1 D"] + self, x: Float[Array, "D"], y: Float[Array, "D"] ) -> Float[Array, "1"]: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with lengthscale parameter :math:`\\ell` and variance :math:`\\sigma^2` @@ -43,8 +43,8 @@ def __call__( k(x, y) = \\sigma^2 \\exp \\Bigg(1+ \\frac{\\sqrt{5}\\lvert x-y \\rvert}{\\ell^2} + \\frac{5\\lvert x - y \\rvert^2}{3\\ell^2} \\Bigg)\\exp\\Bigg(-\\frac{\\sqrt{5}\\lvert x-y\\rvert}{\\ell^2} \\Bigg) Args: - x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. - y (Float[Array, "1 D"]): The right hand argument of the kernel function's call. + x (Float[Array, "D"]): The left hand argument of the kernel function's call. + y (Float[Array, "D"]): The right hand argument of the kernel function's call. Returns: Float[Array, "1"]: The value of :math:`k(x, y)`. diff --git a/gpjax/kernels/stationary/periodic.py b/gpjax/kernels/stationary/periodic.py index 2b6793b90..4d03a414a 100644 --- a/gpjax/kernels/stationary/periodic.py +++ b/gpjax/kernels/stationary/periodic.py @@ -28,6 +28,7 @@ from dataclasses import dataclass from ...parameters import param_field, Softplus + @dataclass class Periodic(AbstractKernel): """The periodic kernel. @@ -37,23 +38,21 @@ class Periodic(AbstractKernel): lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=Softplus) variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus) - period: Float[Array, "1"] = param_field( - jnp.array([1.0]), bijector=Softplus - ) # NOTE: is bijector needed? + period: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus) - def __call__(self, x: jax.Array, y: jax.Array) -> Array: + def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "1"]: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with length-scale parameter :math:`\\ell` and variance :math:`\\sigma` - TODO: write docstring + TODO: update docstring .. math:: k(x, y) = \\sigma^2 \\exp \\Bigg( -0.5 \\sum_{i=1}^{d} \\Bigg) Args: - x (jax.Array): The left hand argument of the kernel function's call. - y (jax.Array): The right hand argument of the kernel function's call + x (Float[Array, "D"]): The left hand argument of the kernel function's call. + y (Float[Array, "D"]): The right hand argument of the kernel function's call Returns: - Array: The value of :math:`k(x, y)` + Float[Array, "1"]: The value of :math:`k(x, y)` """ x = self.slice_input(x) y = self.slice_input(y) diff --git a/gpjax/kernels/stationary/powered_exponential.py b/gpjax/kernels/stationary/powered_exponential.py index 1f5c10153..320d10879 100644 --- a/gpjax/kernels/stationary/powered_exponential.py +++ b/gpjax/kernels/stationary/powered_exponential.py @@ -39,18 +39,18 @@ class PoweredExponential(AbstractKernel): variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus) power: Float[Array, "1"] = param_field(jnp.array([1.0])) - def __call__(self, x: jax.Array, y: jax.Array) -> Array: + def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "1"]: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with length-scale parameter :math:`\\ell`, :math:`\\sigma` and power :math:`\\kappa`. .. math:: k(x, y) = \\sigma^2 \\exp \\Bigg( - \\Big( \\frac{\\lVert x - y \\rVert^2}{\\ell^2} \\Big)^\\kappa \\Bigg) Args: - x (jax.Array): The left hand argument of the kernel function's call. - y (jax.Array): The right hand argument of the kernel function's call + x (Float[Array, "D"]): The left hand argument of the kernel function's call. + y (Float[Array, "D"]): The right hand argument of the kernel function's call Returns: - Array: The value of :math:`k(x, y)` + Float[Array, "1"]: The value of :math:`k(x, y)` """ x = self.slice_input(x) / self.lengthscale y = self.slice_input(y) / self.lengthscale diff --git a/gpjax/kernels/stationary/rational_quadratic.py b/gpjax/kernels/stationary/rational_quadratic.py index b5e797406..cf36c24c8 100644 --- a/gpjax/kernels/stationary/rational_quadratic.py +++ b/gpjax/kernels/stationary/rational_quadratic.py @@ -34,17 +34,17 @@ class RationalQuadratic(AbstractKernel): variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus) alpha: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus) - def __call__(self, x: jax.Array, y: jax.Array) -> Array: + def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "1"]: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with length-scale parameter :math:`\\ell` and variance :math:`\\sigma` .. math:: k(x, y) = \\sigma^2 \\exp \\Bigg( 1 + \\frac{\\lVert x - y \\rVert^2_2}{2 \\alpha \\ell^2} \\Bigg) Args: - x (jax.Array): The left hand argument of the kernel function's call. - y (jax.Array): The right hand argument of the kernel function's call + x (Float[Array, "D"]): The left hand argument of the kernel function's call. + y (Float[Array, "D"]): The right hand argument of the kernel function's call Returns: - Array: The value of :math:`k(x, y)` + Float[Array, "1"]: The value of :math:`k(x, y)` """ x = self.slice_input(x) / self.lengthscale y = self.slice_input(y) / self.lengthscale diff --git a/gpjax/kernels/stationary/utils.py b/gpjax/kernels/stationary/utils.py index f265d5671..fa04f2310 100644 --- a/gpjax/kernels/stationary/utils.py +++ b/gpjax/kernels/stationary/utils.py @@ -35,13 +35,13 @@ def build_student_t_distribution(nu: int) -> dx.Distribution: def squared_distance( - x: Float[Array, "1 D"], y: Float[Array, "1 D"] + x: Float[Array, "D"], y: Float[Array, "D"] ) -> Float[Array, "1"]: """Compute the squared distance between a pair of inputs. Args: - x (Float[Array, "1 D"]): First input. - y (Float[Array, "1 D"]): Second input. + x (Float[Array, "D"]): First input. + y (Float[Array, "D"]): Second input. Returns: Float[Array, "1"]: The squared distance between the inputs. @@ -51,13 +51,13 @@ def squared_distance( def euclidean_distance( - x: Float[Array, "1 D"], y: Float[Array, "1 D"] + x: Float[Array, "D"], y: Float[Array, "D"] ) -> Float[Array, "1"]: """Compute the euclidean distance between a pair of inputs. Args: - x (Float[Array, "1 D"]): First input. - y (Float[Array, "1 D"]): Second input. + x (Float[Array, "D"]): First input. + y (Float[Array, "D"]): Second input. Returns: Float[Array, "1"]: The euclidean distance between the inputs. diff --git a/gpjax/kernels/stationary/white.py b/gpjax/kernels/stationary/white.py index e39262bdf..8b7c6caaf 100644 --- a/gpjax/kernels/stationary/white.py +++ b/gpjax/kernels/stationary/white.py @@ -33,17 +33,15 @@ class White(AbstractKernel): ConstantDiagonalKernelComputation ) - def __call__( - self, x: Float[Array, "1 D"], y: Float[Array, "1 D"] - ) -> Float[Array, "1"]: + def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "1"]: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with variance :math:`\\sigma` .. math:: k(x, y) = \\sigma^2 \\delta(x-y) Args: - x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. - y (Float[Array, "1 D"]): The right hand argument of the kernel function's call. + x (Float[Array, "D"]): The left hand argument of the kernel function's call. + y (Float[Array, "D"]): The right hand argument of the kernel function's call. Returns: Float[Array, "1"]: The value of :math:`k(x, y)`. From aee2a12fa5d29d287d90c763e62bf1a4317b3097 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Thu, 30 Mar 2023 19:13:32 +0100 Subject: [PATCH 15/44] Fix bugs on the base. --- gpjax/kernels/base.py | 10 ++-- tests/test_kernels/test_base.py | 103 ++++++++++---------------------- 2 files changed, 36 insertions(+), 77 deletions(-) diff --git a/gpjax/kernels/base.py b/gpjax/kernels/base.py index 88bc53b50..5406310d1 100644 --- a/gpjax/kernels/base.py +++ b/gpjax/kernels/base.py @@ -84,9 +84,9 @@ def __add__( """ if isinstance(other, AbstractKernel): - return SumKernel([self, other]) + return SumKernel(kernels=[self, other]) - return SumKernel([self, Constant(other)]) + return SumKernel(kernels=[self, Constant(other)]) def __radd__( self, other: Union[AbstractKernel, Float[Array, "1"]] @@ -112,9 +112,9 @@ def __mul__( AbstractKernel: A new kernel that is the product of the two kernels. """ if isinstance(other, AbstractKernel): - return ProductKernel([self, other]) + return ProductKernel(kernels=[self, other]) - return ProductKernel([self, Constant(other)]) + return ProductKernel(kernels=[self, Constant(other)]) @dataclass @@ -179,4 +179,4 @@ def __call__( SumKernel = partial(CombinationKernel, operator=jnp.sum) -ProductKernel = partial(CombinationKernel, operator=jnp.sum) +ProductKernel = partial(CombinationKernel, operator=jnp.prod) diff --git a/tests/test_kernels/test_base.py b/tests/test_kernels/test_base.py index 7e3dddce5..203315626 100644 --- a/tests/test_kernels/test_base.py +++ b/tests/test_kernels/test_base.py @@ -14,7 +14,6 @@ # ============================================================================== import jax.numpy as jnp -import jax.random as jr import pytest from jax.config import config from jaxlinop import identity @@ -33,38 +32,35 @@ RationalQuadratic, ) from gpjax.kernels.nonstationary import Polynomial, Linear -from jax.random import KeyArray from jaxtyping import Array, Float -from typing import Dict +from dataclasses import dataclass +from mytree import param_field, Softplus # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) -_initialise_key = jr.PRNGKey(123) -_jitter = 1e-6 def test_abstract_kernel(): - # Test initialising abstract kernel raises TypeError with unimplemented __call__ and _init_params methods: + # Test initialising abstract kernel raises TypeError with unimplemented __call__ method: with pytest.raises(TypeError): AbstractKernel() - # Create a dummy kernel class with __call__ and _init_params methods implemented: + # Create a dummy kernel class with __call__ implemented: + @dataclass class DummyKernel(AbstractKernel): - def __call__( - self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: Dict - ) -> Float[Array, "1"]: - return x * params["test"] * y + test_a: Float[Array, "1"] = jnp.array([1.0]) + test_b: Float[Array, "1"] = param_field(jnp.array([2.0]), bijector=Softplus) - def init_params(self, key: KeyArray) -> Dict: - return {"test": 1.0} + def __call__(self, x: Float[Array, "1 D"], y: Float[Array, "1 D"]) -> Float[Array, "1"]: + return x * self.test_b * y - # Initialise dummy kernel class and test __call__ and _init_params methods: + # Initialise dummy kernel class and test __call__ method: dummy_kernel = DummyKernel() - assert dummy_kernel.init_params(_initialise_key) == {"test": 1.0} - assert ( - dummy_kernel(jnp.array([1.0]), jnp.array([2.0]), {"test": 2.0}) == 4.0 - ) + assert dummy_kernel.test_a == jnp.array([1.0]) + assert dummy_kernel._pytree__meta["test_b"].get("bijector") == Softplus + assert dummy_kernel.test_b == jnp.array([2.0]) + assert (dummy_kernel(jnp.array([1.0]), jnp.array([2.0])) == 4.0) @pytest.mark.parametrize("combination_type", [SumKernel, ProductKernel]) @@ -82,35 +78,29 @@ def test_combination_kernel( x = jnp.linspace(0.0, 1.0, num=n).reshape(-1, 1) # Create list of kernels - kernel_set = [kernel() for _ in range(n_kerns)] + kernels = [kernel() for _ in range(n_kerns)] # Create combination kernel - combination_kernel = combination_type(kernel_set=kernel_set) - - # Initialise default parameters - params = combination_kernel.init_params(_initialise_key) + combination_kernel = combination_type(kernels=kernels) # Check params are a list of dictionaries - assert len(params) == n_kerns - - for p in params: - assert isinstance(p, dict) + assert combination_kernel.kernels == kernels # Check combination kernel set - assert len(combination_kernel.kernel_set) == n_kerns - assert isinstance(combination_kernel.kernel_set, list) - assert isinstance(combination_kernel.kernel_set[0], AbstractKernel) + assert len(combination_kernel.kernels) == n_kerns + assert isinstance(combination_kernel.kernels, list) + assert isinstance(combination_kernel.kernels[0], AbstractKernel) # Compute gram matrix - Kxx = combination_kernel.gram(params, x) + Kxx = combination_kernel.gram(x) # Check shapes assert Kxx.shape[0] == Kxx.shape[1] assert Kxx.shape[1] == n # Check positive definiteness - Kxx += identity(n) * _jitter - eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) + jitter = 1e-6 + eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense() + jnp.eye(n) * jitter) assert (eigen_values > 0).all() @@ -126,22 +116,14 @@ def test_sum_kern_value(k1: AbstractKernel, k2: AbstractKernel) -> None: x = jnp.linspace(0.0, 1.0, num=n).reshape(-1, 1) # Create sum kernel - sum_kernel = SumKernel(kernel_set=[k1, k2]) - - # Initialise default parameters - params = sum_kernel.init_params(_initialise_key) + sum_kernel = SumKernel(kernels=[k1, k2]) # Compute gram matrix - Kxx = sum_kernel.gram(params, x) - - # NOW we do the same thing manually and check they are equal: - # Initialise default parameters - k1_params = k1.init_params(_initialise_key) - k2_params = k2.init_params(_initialise_key) + Kxx = sum_kernel.gram(x) # Compute gram matrix - Kxx_k1 = k1.gram(k1_params, x) - Kxx_k2 = k2.gram(k2_params, x) + Kxx_k1 = k1.gram(x) + Kxx_k2 = k2.gram(x) # Check manual and automatic gram matrices are equal assert jnp.all(Kxx.to_dense() == Kxx_k1.to_dense() + Kxx_k2.to_dense()) @@ -180,37 +162,14 @@ def test_prod_kern_value(k1: AbstractKernel, k2: AbstractKernel) -> None: x = jnp.linspace(0.0, 1.0, num=n).reshape(-1, 1) # Create product kernel - prod_kernel = ProductKernel(kernel_set=[k1, k2]) - - # Initialise default parameters - params = prod_kernel.init_params(_initialise_key) + prod_kernel = ProductKernel(kernels=[k1, k2]) # Compute gram matrix - Kxx = prod_kernel.gram(params, x) - - # NOW we do the same thing manually and check they are equal: - - # Initialise default parameters - k1_params = k1.init_params(_initialise_key) - k2_params = k2.init_params(_initialise_key) + Kxx = prod_kernel.gram(x) # Compute gram matrix - Kxx_k1 = k1.gram(k1_params, x) - Kxx_k2 = k2.gram(k2_params, x) + Kxx_k1 = k1.gram(x) + Kxx_k2 = k2.gram(x) # Check manual and automatic gram matrices are equal assert jnp.all(Kxx.to_dense() == Kxx_k1.to_dense() * Kxx_k2.to_dense()) - - -@pytest.mark.parametrize( - "kernel", - [RBF, Matern12, Matern32, Matern52, Polynomial, Linear, RationalQuadratic], -) -def test_combination_kernel_type(kernel: AbstractKernel) -> None: - prod_kern = kernel() * kernel() - assert isinstance(prod_kern, ProductKernel) - assert isinstance(prod_kern, CombinationKernel) - - add_kern = kernel() + kernel() - assert isinstance(add_kern, SumKernel) - assert isinstance(add_kern, CombinationKernel) From a669eb92fb1706493e681afd8e19fdcc22200900 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Tue, 28 Mar 2023 22:16:21 +0100 Subject: [PATCH 16/44] Refactored variational families. --- gpjax/mean_functions.py | 190 +++++------ gpjax/variational_families.py | 531 +++++++++++------------------ tests/test_mean_functions.py | 80 ++--- tests/test_variational_families.py | 227 ++++++------ 4 files changed, 402 insertions(+), 626 deletions(-) diff --git a/gpjax/mean_functions.py b/gpjax/mean_functions.py index 43c72b3d6..f0173b3a7 100644 --- a/gpjax/mean_functions.py +++ b/gpjax/mean_functions.py @@ -13,151 +13,147 @@ # limitations under the License. # ============================================================================== -import abc -from typing import Dict, Optional +from __future__ import annotations -import deprecation +import abc +import dataclasses import jax.numpy as jnp -from jax.random import KeyArray +from beartype.typing import List, Callable, Union from jaxtyping import Array, Float -from jaxutils import PyTree - - -class AbstractMeanFunction(PyTree): - """Abstract mean function that is used to parameterise the Gaussian process.""" +from mytree import Mytree, param_field +from simple_pytree import static_field +from functools import partial - def __init__( - self, output_dim: Optional[int] = 1, name: Optional[str] = "Mean function" - ): - """Initialise the mean function. - Args: - output_dim (Optional[int]): The output dimension of the mean function. Defaults to 1. - name (Optional[str]): The name of the mean function. Defaults to "Mean function". - """ - self.output_dim = output_dim - self.name = name +@dataclasses.dataclass +class AbstractMeanFunction(Mytree): + """Mean function that is used to parameterise the Gaussian process.""" @abc.abstractmethod - def __call__(self, params: Dict, x: Float[Array, "N D"]) -> Float[Array, "N Q"]: + def __call__(self, x: Float[Array, "N D"]) -> Float[Array, "N 1"]: """Evaluate the mean function at the given points. This method is required for all subclasses. Args: - params (Dict): The parameters of the mean function. - x (Float[Array, "N D"]): The input points at which to evaluate the mean function. + x (Float[Array, "D"]): The point at which to evaluate the mean function. Returns: - Float[Array, "N Q"]: The mean function evaluated point-wise on the inputs. + Float[Array, "1]: The evaluated mean function. """ raise NotImplementedError - - @abc.abstractmethod - def init_params(self, key: KeyArray) -> Dict: - """Return the parameters of the mean function. This method is required for all subclasses. + + def __add__(self, other: Union[AbstractMeanFunction, Float[Array, "1"]]) -> AbstractMeanFunction: + """Add two mean functions. Args: - key (KeyArray): The PRNG key to use for initialising the parameters. + other (AbstractMeanFunction): The other mean function to add. Returns: - Dict: The parameters of the mean function. + AbstractMeanFunction: The sum of the two mean functions. """ - raise NotImplementedError - - @deprecation.deprecated( - deprecated_in="0.5.7", - removed_in="0.6.0", - details="Use the ``init_params`` method for parameter initialisation.", - ) - def _initialise_params(self, key: KeyArray) -> Dict: - """Deprecated method for initialising the GP's parameters. Succeeded by ``init_params``.""" - return self.init_params(key) - -class Zero(AbstractMeanFunction): - """ - A zero mean function. This function returns zero for all inputs. - """ - - def __init__( - self, output_dim: Optional[int] = 1, name: Optional[str] = "Mean function" - ): - """Initialise the zero-mean function. + if isinstance(other, AbstractMeanFunction): + return SumMeanFunction([self, other]) + + return SumMeanFunction([self, Constant(other)]) + + def __radd__(self, other: Union[AbstractMeanFunction, Float[Array, "1"]]) -> AbstractMeanFunction: + """Add two mean functions. Args: - output_dim (Optional[int]): The output dimension of the mean function. Defaults to 1. - name (Optional[str]): The name of the mean function. Defaults to "Mean function". - """ - super().__init__(output_dim, name) + other (AbstractMeanFunction): The other mean function to add. - def __call__(self, params: Dict, x: Float[Array, "N D"]) -> Float[Array, "N Q"]: - """Evaluate the mean function at the given points. + Returns: + AbstractMeanFunction: The sum of the two mean functions. + """ + return self.__add__(other) + + def __mul__(self, other: Union[AbstractMeanFunction, Float[Array, "1"]]) -> AbstractMeanFunction: + """Multiply two mean functions. Args: - params (Dict): The parameters of the mean function. - x (Float[Array, "N D"]): The input points at which to evaluate the mean function. - + other (AbstractMeanFunction): The other mean function to multiply. + Returns: - Float[Array, "N Q"]: A vector of zeros. + AbstractMeanFunction: The product of the two mean functions. """ - out_shape = (x.shape[0], self.output_dim) - return jnp.zeros(shape=out_shape) - - def init_params(self, key: KeyArray) -> Dict: - """The parameters of the mean function. For the zero-mean function, this is an empty dictionary. + if isinstance(other, AbstractMeanFunction): + return ProductMeanFunction([self, other]) + + return ProductMeanFunction([self, Constant(other)]) + + def __rmul__(self, other: Union[AbstractMeanFunction, Float[Array, "1"]]) -> AbstractMeanFunction: + """Multiply two mean functions. Args: - key (KeyArray): The PRNG key to use for initialising the parameters. - + other (AbstractMeanFunction): The other mean function to multiply. + Returns: - Dict: The parameters of the mean function. + AbstractMeanFunction: The product of the two mean functions. """ - return {} + return self.__mul__(other) +@dataclasses.dataclass class Constant(AbstractMeanFunction): """ - A zero mean function. This function returns a repeated scalar value for all inputs. + A constant mean function. This function returns a repeated scalar value for all inputs. The scalar value itself can be treated as a model hyperparameter and learned during training. """ + constant: Float[Array, "1"] = param_field(jnp.array([0.0])) - def __init__( - self, output_dim: Optional[int] = 1, name: Optional[str] = "Mean function" - ): - """Initialise the constant-mean function. - - Args: - output_dim (Optional[int]): The output dimension of the mean function. Defaults to 1. - name (Optional[str]): The name of the mean function. Defaults to "Mean function". - """ - super().__init__(output_dim, name) - - def __call__(self, params: Dict, x: Float[Array, "N D"]) -> Float[Array, "N Q"]: + def __call__(self, x: Float[Array, "N D"]) -> Float[Array, "N 1"]: """Evaluate the mean function at the given points. Args: - params (Dict): The parameters of the mean function. - x (Float[Array, "N D"]): The input points at which to evaluate the mean function. + x (Float[Array, "D"]): The point at which to evaluate the mean function. Returns: - Float[Array, "N Q"]: A vector of repeated constant values. + Float[Array, "1"]: The evaluated mean function. """ - out_shape = (x.shape[0], self.output_dim) - return jnp.ones(shape=out_shape) * params["constant"] + return jnp.ones((x.shape[0], 1)) * self.constant + - def init_params(self, key: KeyArray) -> Dict: - """The parameters of the mean function. For the constant-mean function, this is a dictionary with a single value. +@dataclasses.dataclass +class CombinationMeanFunction(AbstractMeanFunction): + """A base class for products or sums of AbstractMeanFunctions.""" + items: List[AbstractMeanFunction] + operator: Callable = static_field() + + def __init__( + self, + items: List[AbstractMeanFunction], + operator: Callable, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + #Add items to a list, flattening out instances of this class therein, as in GPFlow kernels. + items_list: List[AbstractMeanFunction] = [] + + for item in items: + if not isinstance(item, AbstractMeanFunction): + raise TypeError("can only combine AbstractMeanFunction instances") # pragma: no cover + + if isinstance(item, self.__class__): + items_list.extend(item.items) + else: + items_list.append(item) + + self.items = items_list + self.operator = operator + + def __call__(self, x: Float[Array, "N D"]) -> Float[Array, "N 1"]: + """Evaluate combination kernel on a pair of inputs. Args: - key (KeyArray): The PRNG key to use for initialising the parameters. + x (Float[Array, "D"]): The point at which to evaluate the mean function. Returns: - Dict: The parameters of the mean function. + Float[Array, "Q"]: The evaluated mean function. """ - return {"constant": jnp.array([1.0])} - + return self.operator(jnp.stack([m(x) for m in self.items])) + -__all__ = [ - "AbstractMeanFunction", - "Zero", - "Constant", -] +SumMeanFunction = partial(CombinationMeanFunction, operator=partial(jnp.sum, axis=0)) +ProductMeanFunction = partial(CombinationMeanFunction, operator=partial(jnp.sum, axis=0)) +Zero = partial(Constant, constant=jnp.array([0.0])) \ No newline at end of file diff --git a/gpjax/variational_families.py b/gpjax/variational_families.py index f25ac340d..ad79e5e8a 100644 --- a/gpjax/variational_families.py +++ b/gpjax/variational_families.py @@ -14,7 +14,7 @@ # ============================================================================== import abc -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict import deprecation import distrax as dx @@ -22,7 +22,15 @@ import jax.scipy as jsp from jax.random import KeyArray from jaxtyping import Array, Float -from jaxutils import Dataset, PyTree + +from .linops import identity +from jaxutils import Dataset +from .linops import ( + DenseLinearOperator, + LowerTriangularLinearOperator +) +from mytree import Mytree, param_field +from simple_pytree import static_field from .config import get_global_config from .gaussian_distribution import GaussianDistribution @@ -30,9 +38,14 @@ from .likelihoods import AbstractLikelihood, Gaussian from .linops import DenseLinearOperator, LowerTriangularLinearOperator, identity from .utils import concat_dictionaries +from .gaussian_distribution import GaussianDistribution + +from dataclasses import dataclass +import tensorflow_probability.substrates.jax.bijectors as tfb -class AbstractVariationalFamily(PyTree): +@dataclass +class AbstractVariationalFamily(Mytree): """ Abstract base class used to represent families of distributions that can be used within variational inference. @@ -52,28 +65,6 @@ def __call__(self, *args: Any, **kwargs: Any) -> GaussianDistribution: """ return self.predict(*args, **kwargs) - @abc.abstractmethod - def init_params(self, key: KeyArray) -> Dict: - """ - The parameters of the distribution. For example, the multivariate - Gaussian would return a mean vector and covariance matrix. - - Args: - key (KeyArray): The PRNG key used to initialise the parameters. - - Returns: - Dict: The parameters of the distribution. - """ - raise NotImplementedError - - @deprecation.deprecated( - deprecated_in="0.5.7", - removed_in="0.6.0", - details="Use the ``init_params`` method for parameter initialisation.", - ) - def _initialise_params(self, key: KeyArray) -> Dict: - """Deprecated method for initialising the GP's parameters. Succeeded by ``init_params``.""" - return self.init_params(key) @abc.abstractmethod def predict(self, *args: Any, **kwargs: Any) -> GaussianDistribution: @@ -91,27 +82,19 @@ def predict(self, *args: Any, **kwargs: Any) -> GaussianDistribution: raise NotImplementedError +@dataclass class AbstractVariationalGaussian(AbstractVariationalFamily): """The variational Gaussian family of probability distributions.""" + prior: Prior + inducing_inputs: Float[Array, "N D"] - def __init__( - self, - prior: Prior, - inducing_inputs: Float[Array, "N D"], - name: Optional[str] = "Variational Gaussian", - ) -> None: - """ - Args: - prior (Prior): The prior distribution. - inducing_inputs (Float[Array, "N D"]): The inducing inputs. - name (Optional[str]): The name of the variational family. Defaults to "Gaussian". - """ - self.prior = prior - self.inducing_inputs = inducing_inputs - self.num_inducing = self.inducing_inputs.shape[0] - self.name = name + @property + def num_inducing(self) -> int: + """The number of inducing inputs.""" + return self.inducing_inputs.shape[0] +@dataclass class VariationalGaussian(AbstractVariationalGaussian): """The variational Gaussian family of probability distributions. @@ -121,35 +104,19 @@ class VariationalGaussian(AbstractVariationalGaussian): :math:`q(u) = \\mathcal{N}(\\mu, S)`. We parameterise this over :math:`\\mu` and sqrt with S = sqrt sqrtᵀ. """ + variational_mean: Float[Array, "N 1"] = param_field(None) + variational_root_covariance: Float[Array, "N N"] = param_field(None, bijector=tfb.FillScaleTriL(diag_shift=jnp.array(1e-6))) + jitter: Float[Array, "1"] = static_field(1e-6) - def init_params(self, key: KeyArray) -> Dict: - """ - Return the variational mean vector, variational root covariance matrix, - and inducing input vector that parameterise the variational Gaussian - distribution. - - Args: - key (KeyArray): The PRNG key used to initialise the parameters. + def __post_init__(self) -> None: + if self.variational_mean is None: + self.variational_mean = jnp.zeros((self.num_inducing, 1)) - Returns: - Dict: The parameters of the distribution. - """ - m = self.num_inducing + if self.variational_root_covariance is None: + self.variational_root_covariance = jnp.eye(self.num_inducing) - return concat_dictionaries( - self.prior.init_params(key), - { - "variational_family": { - "inducing_inputs": self.inducing_inputs, - "moments": { - "variational_mean": jnp.zeros((m, 1)), - "variational_root_covariance": jnp.eye(m), - }, - } - }, - ) - def prior_kl(self, params: Dict) -> Float[Array, "1"]: + def prior_kl(self) -> Float[Array, "1"]: """ Compute the KL-divergence between our variational approximation and the Gaussian process prior. @@ -158,10 +125,6 @@ def prior_kl(self, params: Dict) -> Float[Array, "1"]: = KL[ N(μ, S) || N(μz, Kzz) ], where u = f(z) and z are the inducing inputs. - Args: - params (Dict): The parameters at which our variational distribution - and GP prior are to be evaluated. - Returns: Float[Array, "1"]: The KL-divergence between our variational approximation and the GP prior. @@ -170,17 +133,17 @@ def prior_kl(self, params: Dict) -> Float[Array, "1"]: jitter = get_global_config()["jitter"] # Unpack variational parameters - mu = params["variational_family"]["moments"]["variational_mean"] - sqrt = params["variational_family"]["moments"]["variational_root_covariance"] - z = params["variational_family"]["inducing_inputs"] + mu = self.variational_mean + sqrt = self.variational_root_covariance + z = self.inducing_inputs m = self.num_inducing # Unpack mean function and kernel mean_function = self.prior.mean_function kernel = self.prior.kernel - μz = mean_function(params["mean_function"], z) - Kzz = kernel.gram(params["kernel"], z) + μz = mean_function(z) + Kzz = kernel.gram(z) Kzz += identity(m) * jitter sqrt = LowerTriangularLinearOperator.from_dense(sqrt) @@ -191,9 +154,7 @@ def prior_kl(self, params: Dict) -> Float[Array, "1"]: return qu.kl_divergence(pu) - def predict( - self, params: Dict - ) -> Callable[[Float[Array, "N D"]], GaussianDistribution]: + def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution: """ Compute the predictive distribution of the GP at the test inputs t. @@ -203,67 +164,61 @@ def predict( N[f(t); μt + Ktz Kzz⁻¹ (μ - μz), Ktt - Ktz Kzz⁻¹ Kzt + Ktz Kzz⁻¹ S Kzz⁻¹ Kzt ]. Args: - params (Dict): The set of parameters that are to be used to - parameterise our variational approximation and GP. + test_inputs (Float[Array, "N D"]): The test inputs at which we wish to make a prediction. Returns: - Callable[[Float[Array, "N D"]], dx.MultivariateNormalTri]: A - function that accepts a set of test points and will return - the predictive distribution at those points. + GaussianDistribution: The predictive distribution of the low-rank GP at the test inputs. """ jitter = get_global_config()["jitter"] # Unpack variational parameters - mu = params["variational_family"]["moments"]["variational_mean"] - sqrt = params["variational_family"]["moments"]["variational_root_covariance"] - z = params["variational_family"]["inducing_inputs"] + mu = self.variational_mean + sqrt = self.variational_root_covariance + z = self.inducing_inputs m = self.num_inducing # Unpack mean function and kernel mean_function = self.prior.mean_function kernel = self.prior.kernel - Kzz = kernel.gram(params["kernel"], z) + Kzz = kernel.gram(z) Kzz += identity(m) * jitter Lz = Kzz.to_root() - μz = mean_function(params["mean_function"], z) + μz = mean_function(z) - def predict_fn(test_inputs: Float[Array, "N D"]) -> GaussianDistribution: + # Unpack test inputs + t, n_test = test_inputs, test_inputs.shape[0] - # Unpack test inputs - t, n_test = test_inputs, test_inputs.shape[0] + Ktt = kernel.gram(t) + Kzt = kernel.cross_covariance(z, t) + μt = mean_function(t) - Ktt = kernel.gram(params["kernel"], t) - Kzt = kernel.cross_covariance(params["kernel"], z, t) - μt = mean_function(params["mean_function"], t) + # Lz⁻¹ Kzt + Lz_inv_Kzt = Lz.solve(Kzt) - # Lz⁻¹ Kzt - Lz_inv_Kzt = Lz.solve(Kzt) + # Kzz⁻¹ Kzt + Kzz_inv_Kzt = Lz.T.solve(Lz_inv_Kzt) - # Kzz⁻¹ Kzt - Kzz_inv_Kzt = Lz.T.solve(Lz_inv_Kzt) + # Ktz Kzz⁻¹ sqrt + Ktz_Kzz_inv_sqrt = jnp.matmul(Kzz_inv_Kzt.T, sqrt) - # Ktz Kzz⁻¹ sqrt - Ktz_Kzz_inv_sqrt = jnp.matmul(Kzz_inv_Kzt.T, sqrt) + # μt + Ktz Kzz⁻¹ (μ - μz) + mean = μt + jnp.matmul(Kzz_inv_Kzt.T, mu - μz) - # μt + Ktz Kzz⁻¹ (μ - μz) - mean = μt + jnp.matmul(Kzz_inv_Kzt.T, mu - μz) - - # Ktt - Ktz Kzz⁻¹ Kzt + Ktz Kzz⁻¹ S Kzz⁻¹ Kzt [recall S = sqrt sqrtᵀ] - covariance = ( - Ktt - - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt) - + jnp.matmul(Ktz_Kzz_inv_sqrt, Ktz_Kzz_inv_sqrt.T) - ) - covariance += identity(n_test) * jitter - - return GaussianDistribution( - loc=jnp.atleast_1d(mean.squeeze()), scale=covariance - ) + # Ktt - Ktz Kzz⁻¹ Kzt + Ktz Kzz⁻¹ S Kzz⁻¹ Kzt [recall S = sqrt sqrtᵀ] + covariance = ( + Ktt + - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt) + + jnp.matmul(Ktz_Kzz_inv_sqrt, Ktz_Kzz_inv_sqrt.T) + ) + covariance += identity(n_test) * jitter - return predict_fn + return GaussianDistribution( + loc=jnp.atleast_1d(mean.squeeze()), scale=covariance + ) +@dataclass class WhitenedVariationalGaussian(VariationalGaussian): """ The whitened variational Gaussian family of probability distributions. @@ -272,43 +227,22 @@ class WhitenedVariationalGaussian(VariationalGaussian): are the function values at the inducing inputs z and the distribution over the inducing inputs is q(u) = N(Lz μ + mz, Lz S Lzᵀ). We parameterise this over μ and sqrt with S = sqrt sqrtᵀ. - """ - def __init__( - self, - prior: Prior, - inducing_inputs: Float[Array, "N D"], - name: Optional[str] = "Whitened variational Gaussian", - ) -> None: - """Initialise the whitened variational Gaussian family. - - Args: - prior (Prior): The GP prior. - inducing_inputs (Float[Array, "N D"]): The inducing inputs. - name (Optional[str]): The name of the variational family. - """ - - super().__init__(prior, inducing_inputs, name) - - def prior_kl(self, params: Dict) -> Float[Array, "1"]: + def prior_kl(self) -> Float[Array, "1"]: """Compute the KL-divergence between our variational approximation and the Gaussian process prior. For this variational family, we have KL[q(f(·))||p(·)] = KL[q(u)||p(u)] = KL[N(μ, S)||N(0, I)]. - Args: - params (Dict): The parameters at which our variational distribution - and GP prior are to be evaluated. - Returns: Float[Array, "1"]: The KL-divergence between our variational approximation and the GP prior. """ # Unpack variational parameters - mu = params["variational_family"]["moments"]["variational_mean"] - sqrt = params["variational_family"]["moments"]["variational_root_covariance"] + mu = self.variational_mean + sqrt = self.variational_root_covariance sqrt = LowerTriangularLinearOperator.from_dense(sqrt) S = DenseLinearOperator.from_root(sqrt) @@ -318,9 +252,7 @@ def prior_kl(self, params: Dict) -> Float[Array, "1"]: pu = GaussianDistribution(loc=jnp.zeros_like(jnp.atleast_1d(mu.squeeze()))) return qu.kl_divergence(pu) - def predict( - self, params: Dict - ) -> Callable[[Float[Array, "N D"]], GaussianDistribution]: + def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution: """Compute the predictive distribution of the GP at the test inputs t. This is the integral q(f(t)) = ∫ p(f(t)|u) q(u) du, which can be computed in closed form as @@ -328,60 +260,56 @@ def predict( N[f(t); μt + Ktz Lz⁻ᵀ μ, Ktt - Ktz Kzz⁻¹ Kzt + Ktz Lz⁻ᵀ S Lz⁻¹ Kzt]. Args: - params (Dict): The set of parameters that are to be used to parameterise our variational approximation and GP. + test_inputs (Float[Array, "N D"]): The test inputs at which we wish to make a prediction. Returns: - Callable[[Float[Array, "N D"]], dx.MultivariateNormalTri]: A function that accepts a set of test points and will return the predictive distribution at those points. + GaussianDistribution: The predictive distribution of the low-rank GP at the test inputs. """ jitter = get_global_config()["jitter"] # Unpack variational parameters - mu = params["variational_family"]["moments"]["variational_mean"] - sqrt = params["variational_family"]["moments"]["variational_root_covariance"] - z = params["variational_family"]["inducing_inputs"] + mu = self.variational_mean + sqrt = self.variational_root_covariance + z = self.inducing_inputs m = self.num_inducing # Unpack mean function and kernel mean_function = self.prior.mean_function kernel = self.prior.kernel - Kzz = kernel.gram(params["kernel"], z) + Kzz = kernel.gram(z) Kzz += identity(m) * jitter Lz = Kzz.to_root() - def predict_fn(test_inputs: Float[Array, "N D"]) -> GaussianDistribution: + # Unpack test inputs + t, n_test = test_inputs, test_inputs.shape[0] - # Unpack test inputs - t, n_test = test_inputs, test_inputs.shape[0] + Ktt = kernel.gram(t) + Kzt = kernel.cross_covariance(z, t) + μt = mean_function(t) - Ktt = kernel.gram(params["kernel"], t) - Kzt = kernel.cross_covariance(params["kernel"], z, t) - μt = mean_function(params["mean_function"], t) + # Lz⁻¹ Kzt + Lz_inv_Kzt = Lz.solve(Kzt) - # Lz⁻¹ Kzt - Lz_inv_Kzt = Lz.solve(Kzt) - - # Ktz Lz⁻ᵀ sqrt - Ktz_Lz_invT_sqrt = jnp.matmul(Lz_inv_Kzt.T, sqrt) + # Ktz Lz⁻ᵀ sqrt + Ktz_Lz_invT_sqrt = jnp.matmul(Lz_inv_Kzt.T, sqrt) - # μt + Ktz Lz⁻ᵀ μ - mean = μt + jnp.matmul(Lz_inv_Kzt.T, mu) + # μt + Ktz Lz⁻ᵀ μ + mean = μt + jnp.matmul(Lz_inv_Kzt.T, mu) - # Ktt - Ktz Kzz⁻¹ Kzt + Ktz Lz⁻ᵀ S Lz⁻¹ Kzt [recall S = sqrt sqrtᵀ] - covariance = ( - Ktt - - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt) - + jnp.matmul(Ktz_Lz_invT_sqrt, Ktz_Lz_invT_sqrt.T) - ) - covariance += identity(n_test) * jitter - - return GaussianDistribution( - loc=jnp.atleast_1d(mean.squeeze()), scale=covariance - ) - - return predict_fn + # Ktt - Ktz Kzz⁻¹ Kzt + Ktz Lz⁻ᵀ S Lz⁻¹ Kzt [recall S = sqrt sqrtᵀ] + covariance = ( + Ktt + - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt) + + jnp.matmul(Ktz_Lz_invT_sqrt, Ktz_Lz_invT_sqrt.T) + ) + covariance += identity(n_test) * jitter + return GaussianDistribution( + loc=jnp.atleast_1d(mean.squeeze()), scale=covariance + ) +@dataclass class NaturalVariationalGaussian(AbstractVariationalGaussian): """The natural variational Gaussian family of probability distributions. @@ -390,60 +318,33 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian): exponential family, q(u) = exp(θᵀ T(u) - a(θ)), gives rise to the natural parameterisation θ = (θ₁, θ₂) = (S⁻¹μ, -S⁻¹/2), to perform model inference, where T(u) = [u, uuᵀ] are the sufficient statistics. """ + natural_vector: Float[Array, "M 1"] = None + natural_matrix: Float[Array, "M M"] = None - def __init__( - self, - prior: Prior, - inducing_inputs: Float[Array, "N D"], - name: Optional[str] = "Natural variational Gaussian", - ) -> None: - """Initialise the natural variational Gaussian family. - - Args: - prior (Prior): The GP prior. - inducing_inputs (Float[Array, "N D"]): The inducing inputs. - name (Optional[str]): The name of the variational family. - """ - - super().__init__(prior, inducing_inputs, name) + def __post_init__(self): + if self.natural_vector is None: + self.natural_vector = jnp.zeros((self.num_inducing, 1)) + + if self.natural_matrix is None: + self.natural_matrix = -0.5 * jnp.eye(self.num_inducing) - def init_params(self, key: KeyArray) -> Dict: - """Return the natural vector and matrix, inducing inputs, and hyperparameters that parameterise the natural Gaussian distribution.""" - m = self.num_inducing - - return concat_dictionaries( - self.prior.init_params(key), - { - "variational_family": { - "inducing_inputs": self.inducing_inputs, - "moments": { - "natural_vector": jnp.zeros((m, 1)), - "natural_matrix": -0.5 * jnp.eye(m), - }, - } - }, - ) - - def prior_kl(self, params: Dict) -> Float[Array, "1"]: + def prior_kl(self) -> Float[Array, "1"]: """Compute the KL-divergence between our current variational approximation and the Gaussian process prior. For this variational family, we have KL[q(f(·))||p(·)] = KL[q(u)||p(u)] = KL[N(μ, S)||N(mz, Kzz)], with μ and S computed from the natural parameterisation θ = (S⁻¹μ, -S⁻¹/2). - Args: - params (Dict): The parameters at which our variational distribution and GP prior are to be evaluated. - Returns: Float[Array, "1"]: The KL-divergence between our variational approximation and the GP prior. """ jitter = get_global_config()["jitter"] # Unpack variational parameters - natural_vector = params["variational_family"]["moments"]["natural_vector"] - natural_matrix = params["variational_family"]["moments"]["natural_matrix"] - z = params["variational_family"]["inducing_inputs"] + natural_vector = self.natural_vector + natural_matrix = self.natural_matrix + z = self.inducing_inputs m = self.num_inducing # Unpack mean function and kernel @@ -469,8 +370,8 @@ def prior_kl(self, params: Dict) -> Float[Array, "1"]: # μ = Sθ₁ mu = S @ natural_vector - μz = mean_function(params["mean_function"], z) - Kzz = kernel.gram(params["kernel"], z) + μz = mean_function(z) + Kzz = kernel.gram(z) Kzz += identity(m) * jitter qu = GaussianDistribution(loc=jnp.atleast_1d(mu.squeeze()), scale=S) @@ -478,9 +379,7 @@ def prior_kl(self, params: Dict) -> Float[Array, "1"]: return qu.kl_divergence(pu) - def predict( - self, params: Dict - ) -> Callable[[Float[Array, "N D"]], GaussianDistribution]: + def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution: """Compute the predictive distribution of the GP at the test inputs t. This is the integral q(f(t)) = ∫ p(f(t)|u) q(u) du, which can be computed in closed form as @@ -489,18 +388,15 @@ def predict( with μ and S computed from the natural parameterisation θ = (S⁻¹μ, -S⁻¹/2). - Args: - params (Dict): The set of parameters that are to be used to parameterise our variational approximation and GP. - Returns: - Callable[[Float[Array, "N D"]], GaussianDistribution]: A function that accepts a set of test points and will return the predictive distribution at those points. + GaussianDistribution: A function that accepts a set of test points and will return the predictive distribution at those points. """ jitter = get_global_config()["jitter"] # Unpack variational parameters - natural_vector = params["variational_family"]["moments"]["natural_vector"] - natural_matrix = params["variational_family"]["moments"]["natural_matrix"] - z = params["variational_family"]["inducing_inputs"] + natural_vector = self.natural_vector + natural_matrix = self.natural_matrix + z = self.inducing_inputs m = self.num_inducing # Unpack mean function and kernel @@ -525,47 +421,44 @@ def predict( # μ = Sθ₁ mu = jnp.matmul(S, natural_vector) - Kzz = kernel.gram(params["kernel"], z) + Kzz = kernel.gram(z) Kzz += identity(m) * jitter Lz = Kzz.to_root() - μz = mean_function(params["mean_function"], z) + μz = mean_function(z) - def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.MultivariateNormalTri: + # Unpack test inputs + t, n_test = test_inputs, test_inputs.shape[0] - # Unpack test inputs - t, n_test = test_inputs, test_inputs.shape[0] + Ktt = kernel.gram(t) + Kzt = kernel.cross_covariance(z, t) + μt = mean_function(t) - Ktt = kernel.gram(params["kernel"], t) - Kzt = kernel.cross_covariance(params["kernel"], z, t) - μt = mean_function(params["mean_function"], t) + # Lz⁻¹ Kzt + Lz_inv_Kzt = Lz.solve(Kzt) - # Lz⁻¹ Kzt - Lz_inv_Kzt = Lz.solve(Kzt) + # Kzz⁻¹ Kzt + Kzz_inv_Kzt = Lz.T.solve(Lz_inv_Kzt) - # Kzz⁻¹ Kzt - Kzz_inv_Kzt = Lz.T.solve(Lz_inv_Kzt) + # Ktz Kzz⁻¹ L + Ktz_Kzz_inv_L = jnp.matmul(Kzz_inv_Kzt.T, sqrt) - # Ktz Kzz⁻¹ L - Ktz_Kzz_inv_L = jnp.matmul(Kzz_inv_Kzt.T, sqrt) + # μt + Ktz Kzz⁻¹ (μ - μz) + mean = μt + jnp.matmul(Kzz_inv_Kzt.T, mu - μz) - # μt + Ktz Kzz⁻¹ (μ - μz) - mean = μt + jnp.matmul(Kzz_inv_Kzt.T, mu - μz) - - # Ktt - Ktz Kzz⁻¹ Kzt + Ktz Kzz⁻¹ S Kzz⁻¹ Kzt [recall S = LLᵀ] - covariance = ( - Ktt - - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt) - + jnp.matmul(Ktz_Kzz_inv_L, Ktz_Kzz_inv_L.T) - ) - covariance += identity(n_test) * jitter - - return GaussianDistribution( - loc=jnp.atleast_1d(mean.squeeze()), scale=covariance - ) + # Ktt - Ktz Kzz⁻¹ Kzt + Ktz Kzz⁻¹ S Kzz⁻¹ Kzt [recall S = LLᵀ] + covariance = ( + Ktt + - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt) + + jnp.matmul(Ktz_Kzz_inv_L, Ktz_Kzz_inv_L.T) + ) + covariance += identity(n_test) * jitter - return predict_fn + return GaussianDistribution( + loc=jnp.atleast_1d(mean.squeeze()), scale=covariance + ) +@dataclass class ExpectationVariationalGaussian(AbstractVariationalGaussian): """The natural variational Gaussian family of probability distributions. @@ -575,44 +468,17 @@ class ExpectationVariationalGaussian(AbstractVariationalGaussian): sufficient statistics T(u) = [u, uuᵀ]. The expectation parameters are given by η = ∫ T(u) q(u) du. This gives a parameterisation, η = (η₁, η₁) = (μ, S + uuᵀ) to perform model inference over. """ + expectation_vector: Float[Array, "M 1"] = None + expectation_matrix: Float[Array, "M M"] = None - def __init__( - self, - prior: Prior, - inducing_inputs: Float[Array, "N D"], - name: Optional[str] = "Expectation variational Gaussian", - ) -> None: - """Initialise the expectation variational Gaussian family. - - Args: - prior (Prior): The GP prior. - inducing_inputs (Float[Array, "N D"]): The inducing inputs. - name (Optional[str]): The name of the variational family. - """ - - super().__init__(prior, inducing_inputs, name) + def __post_init__(self): + if self.expectation_vector is None: + self.expectation_vector = jnp.zeros((self.num_inducing, 1)) + if self.expectation_matrix is None: + self.expectation_matrix = jnp.eye(self.num_inducing) - def init_params(self, key: KeyArray) -> Dict: - """Return the expectation vector and matrix, inducing inputs, and hyperparameters that parameterise the expectation Gaussian distribution.""" - - self.num_inducing = self.inducing_inputs.shape[0] - m = self.num_inducing - - return concat_dictionaries( - self.prior.init_params(key), - { - "variational_family": { - "inducing_inputs": self.inducing_inputs, - "moments": { - "expectation_vector": jnp.zeros((m, 1)), - "expectation_matrix": jnp.eye(m), - }, - } - }, - ) - - def prior_kl(self, params: Dict) -> Float[Array, "1"]: + def prior_kl(self) -> Float[Array, "1"]: """Compute the KL-divergence between our current variational approximation and the Gaussian process prior. For this variational family, we have KL[q(f(·))||p(·)] = KL[q(u)||p(u)] = KL[N(μ, S)||N(mz, Kzz)], @@ -628,13 +494,9 @@ def prior_kl(self, params: Dict) -> Float[Array, "1"]: jitter = get_global_config()["jitter"] # Unpack variational parameters - expectation_vector = params["variational_family"]["moments"][ - "expectation_vector" - ] - expectation_matrix = params["variational_family"]["moments"][ - "expectation_matrix" - ] - z = params["variational_family"]["inducing_inputs"] + expectation_vector = self.expectation_vector + expectation_matrix = self.expectation_matrix + z = self.inducing_inputs m = self.num_inducing # Unpack mean function and kernel @@ -649,8 +511,8 @@ def prior_kl(self, params: Dict) -> Float[Array, "1"]: S = DenseLinearOperator(S) S += identity(m) * jitter - μz = mean_function(params["mean_function"], z) - Kzz = kernel.gram(params["kernel"], z) + μz = mean_function(z) + Kzz = kernel.gram(z) Kzz += identity(m) * jitter qu = GaussianDistribution(loc=jnp.atleast_1d(mu.squeeze()), scale=S) @@ -658,9 +520,7 @@ def prior_kl(self, params: Dict) -> Float[Array, "1"]: return qu.kl_divergence(pu) - def predict( - self, params: Dict - ) -> Callable[[Float[Array, "N D"]], GaussianDistribution]: + def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution: """Compute the predictive distribution of the GP at the test inputs t. This is the integral q(f(t)) = ∫ p(f(t)|u) q(u) du, which can be computed in closed form as @@ -669,22 +529,15 @@ def predict( with μ and S computed from the expectation parameterisation η = (μ, S + uuᵀ). - Args: - params (Dict): The set of parameters that are to be used to parameterise our variational approximation and GP. - Returns: - Callable[[Float[Array, "N D"]], GaussianDistribution]: A function that accepts a set of test points and will return the predictive distribution at those points. + GaussianDistribution: The predictive distribution of the GP at the test inputs t. """ jitter = get_global_config()["jitter"] # Unpack variational parameters - expectation_vector = params["variational_family"]["moments"][ - "expectation_vector" - ] - expectation_matrix = params["variational_family"]["moments"][ - "expectation_matrix" - ] - z = params["variational_family"]["inducing_inputs"] + expectation_vector = self.expectation_vector + expectation_matrix = self.expectation_matrix + z = self.inducing_inputs m = self.num_inducing # Unpack mean function and kernel @@ -702,45 +555,41 @@ def predict( # S = sqrt sqrtᵀ sqrt = S.to_root().to_dense() - Kzz = kernel.gram(params["kernel"], z) + Kzz = kernel.gram(z) Kzz += identity(m) * jitter Lz = Kzz.to_root() - μz = mean_function(params["mean_function"], z) - - def predict_fn(test_inputs: Float[Array, "N D"]) -> GaussianDistribution: + μz = mean_function(z) - # Unpack test inputs - t, n_test = test_inputs, test_inputs.shape[0] + # Unpack test inputs + t, n_test = test_inputs, test_inputs.shape[0] - Ktt = kernel.gram(params["kernel"], t) - Kzt = kernel.cross_covariance(params["kernel"], z, t) - μt = mean_function(params["mean_function"], t) + Ktt = kernel.gram(t) + Kzt = kernel.cross_covariance(z, t) + μt = mean_function(t) - # Lz⁻¹ Kzt - Lz_inv_Kzt = Lz.solve(Kzt) + # Lz⁻¹ Kzt + Lz_inv_Kzt = Lz.solve(Kzt) - # Kzz⁻¹ Kzt - Kzz_inv_Kzt = Lz.T.solve(Lz_inv_Kzt) + # Kzz⁻¹ Kzt + Kzz_inv_Kzt = Lz.T.solve(Lz_inv_Kzt) - # Ktz Kzz⁻¹ sqrt - Ktz_Kzz_inv_sqrt = jnp.matmul(Kzz_inv_Kzt.T, sqrt) + # Ktz Kzz⁻¹ sqrt + Ktz_Kzz_inv_sqrt = jnp.matmul(Kzz_inv_Kzt.T, sqrt) - # μt + Ktz Kzz⁻¹ (μ - μz) - mean = μt + jnp.matmul(Kzz_inv_Kzt.T, mu - μz) + # μt + Ktz Kzz⁻¹ (μ - μz) + mean = μt + jnp.matmul(Kzz_inv_Kzt.T, mu - μz) - # Ktt - Ktz Kzz⁻¹ Kzt + Ktz Kzz⁻¹ S Kzz⁻¹ Kzt [recall S = sqrt sqrtᵀ] - covariance = ( - Ktt - - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt) - + jnp.matmul(Ktz_Kzz_inv_sqrt, Ktz_Kzz_inv_sqrt.T) - ) - covariance += identity(n_test) * jitter - - return GaussianDistribution( - loc=jnp.atleast_1d(mean.squeeze()), scale=covariance - ) + # Ktt - Ktz Kzz⁻¹ Kzt + Ktz Kzz⁻¹ S Kzz⁻¹ Kzt [recall S = sqrt sqrtᵀ] + covariance = ( + Ktt + - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt) + + jnp.matmul(Ktz_Kzz_inv_sqrt, Ktz_Kzz_inv_sqrt.T) + ) + covariance += identity(n_test) * jitter - return predict_fn + return GaussianDistribution( + loc=jnp.atleast_1d(mean.squeeze()), scale=covariance + ) class CollapsedVariationalGaussian(AbstractVariationalFamily): diff --git a/tests/test_mean_functions.py b/tests/test_mean_functions.py index 9d97d82ee..d677b7840 100644 --- a/tests/test_mean_functions.py +++ b/tests/test_mean_functions.py @@ -1,69 +1,33 @@ -# Copyright 2022 The GPJax Contributors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from typing import Dict - -import jax.numpy as jnp -import jax.random as jr import pytest -from jax.config import config -from jax.random import KeyArray -from jaxtyping import Array, Float - -from gpjax.mean_functions import AbstractMeanFunction, Constant, Zero - -# Enable Float64 for more stable matrix inversions. -config.update("jax_enable_x64", True) -_initialise_key = jr.PRNGKey(123) +import jax +import mytree +from gpjax.mean_functions import AbstractMeanFunction, Constant +from jaxtyping import Array, Float -def test_abstract_mean_function() -> None: - # Test that the abstract mean function cannot be instantiated. +def test_abstract() -> None: + # Check abstract mean function cannot be instantiated, as the `__call__` method is not defined. with pytest.raises(TypeError): AbstractMeanFunction() - # Create a dummy mean funcion class with abstract methods implemented. + # Check a "dummy" mean funcion with defined abstract method, `__call__`, can be instantiated. class DummyMeanFunction(AbstractMeanFunction): - def __call__(self, params: Dict, x: Float[Array, "N D"]) -> Float[Array, "N 1"]: - return jnp.ones((x.shape[0], 1)) - - def init_params(self, key: KeyArray) -> Dict: - return {} - - # Test that the dummy mean function can be instantiated. - dummy_mean_function = DummyMeanFunction() - assert isinstance(dummy_mean_function, AbstractMeanFunction) - - -@pytest.mark.parametrize("mean_function", [Zero, Constant]) -@pytest.mark.parametrize("dim", [1, 2, 5]) -@pytest.mark.parametrize("n", [1, 2]) -def test_shape(mean_function: AbstractMeanFunction, n: int, dim: int) -> None: - key = _initialise_key + def __call__(self, x: Float[Array, "D"]) -> Float[Array, "1"]: + return jax.numpy.array([1.0]) - # Create test inputs. - x = jnp.linspace(-1.0, 1.0, num=n * dim).reshape(n, dim) + mf = DummyMeanFunction() + assert isinstance(mf, mytree.Mytree) + assert isinstance(mf, AbstractMeanFunction) + assert (mf(jax.numpy.array([1.0])) == jax.numpy.array([1.0])).all() + assert (mf(jax.numpy.array([2.0, 3.0])) == jax.numpy.array([1.0])).all() - # Initialise mean function. - mf = mean_function(output_dim=dim) - # Initialise parameters. - params = mf.init_params(key) - assert isinstance(params, dict) +@pytest.mark.parametrize("constant", [jax.numpy.array([0.0]), jax.numpy.array([1.0]), jax.numpy.array([3.0])]) +def test_constant(constant: Float[Array, "Q"]) -> None: + mf = Constant(constant = constant) - # Test shape of mean function. - mu = mf(params, x) - assert mu.shape[0] == x.shape[0] - assert mu.shape[1] == dim + assert isinstance(mf, AbstractMeanFunction) + assert (mf(jax.numpy.array([1.0])) == constant).all() + assert (mf(jax.numpy.array([2.0, 3.0])) == constant).all() + assert (jax.vmap(mf)(jax.numpy.array([[1.0], [2.0]])) == jax.numpy.array([constant, constant])).all() + assert (jax.vmap(mf)(jax.numpy.array([[1.0, 2.0], [3.0, 4.0]])) == jax.numpy.array([constant, constant])).all() \ No newline at end of file diff --git a/tests/test_variational_families.py b/tests/test_variational_families.py index 47ea0cac8..b11835b33 100644 --- a/tests/test_variational_families.py +++ b/tests/test_variational_families.py @@ -26,7 +26,7 @@ import gpjax as gpx from gpjax.variational_families import ( AbstractVariationalFamily, - CollapsedVariationalGaussian, + #CollapsedVariationalGaussian, ExpectationVariationalGaussian, NaturalVariationalGaussian, VariationalGaussian, @@ -47,9 +47,6 @@ class DummyVariationalFamily(AbstractVariationalFamily): def predict(self, params: Dict, x: Float[Array, "N D"]) -> dx.Distribution: return dx.MultivariateNormalDiag(loc=x) - def init_params(self, key: jr.PRNGKey) -> dict: - return {} - # Test that the dummy variational family can be instantiated. dummy_variational_family = DummyVariationalFamily() assert isinstance(dummy_variational_family, AbstractVariationalFamily) @@ -89,42 +86,11 @@ def diag_matrix_fn(n_inducing: int) -> Float[Array, "n_inducing n_inducing"]: @pytest.mark.parametrize("n_test", [1, 10]) @pytest.mark.parametrize("n_inducing", [1, 10, 20]) -@pytest.mark.parametrize( - "variational_family, moment_names, shapes, values", - [ - ( - VariationalGaussian, - ["variational_mean", "variational_root_covariance"], - [vector_shape, matrix_shape], - [vector_val(0.0), diag_matrix_val(1.0)], - ), - ( - WhitenedVariationalGaussian, - ["variational_mean", "variational_root_covariance"], - [vector_shape, matrix_shape], - [vector_val(0.0), diag_matrix_val(1.0)], - ), - ( - NaturalVariationalGaussian, - ["natural_vector", "natural_matrix"], - [vector_shape, matrix_shape], - [vector_val(0.0), diag_matrix_val(-0.5)], - ), - ( - ExpectationVariationalGaussian, - ["expectation_vector", "expectation_matrix"], - [vector_shape, matrix_shape], - [vector_val(0.0), diag_matrix_val(1.0)], - ), - ], -) +@pytest.mark.parametrize("variational_family", [VariationalGaussian, WhitenedVariationalGaussian, NaturalVariationalGaussian, ExpectationVariationalGaussian]) def test_variational_gaussians( n_test: int, n_inducing: int, variational_family: AbstractVariationalFamily, - moment_names: Tuple[str, str], - shapes: Tuple, - values: Tuple, ) -> None: # Initialise variational family: @@ -137,39 +103,40 @@ def test_variational_gaussians( assert q.num_inducing == n_inducing assert isinstance(q, AbstractVariationalFamily) - # Test params and keys: - params = q.init_params(jr.PRNGKey(123)) - assert isinstance(params, dict) - - config_params = gpx.config.get_global_config() - - # Test inducing induput parameters: - assert "inducing_inputs" in params["variational_family"].keys() - assert "inducing_inputs" in config_params["transformations"].keys() - - for moment_name, shape, value in zip(moment_names, shapes, values): - - moment_params = params["variational_family"]["moments"] - - assert moment_name in moment_params.keys() - assert moment_name in config_params["transformations"].keys() - - # Test moment shape and values: - moment = moment_params[moment_name] - assert isinstance(moment, jnp.ndarray) - assert moment.shape == shape(n_inducing) - assert (moment == value(n_inducing)).all() + if isinstance(q, VariationalGaussian): + assert q.variational_mean.shape == vector_shape(n_inducing) + assert q.variational_root_covariance.shape == matrix_shape(n_inducing) + assert (q.variational_mean == vector_val(0.0)(n_inducing)).all() + assert (q.variational_root_covariance == diag_matrix_val(1.0)(n_inducing)).all() + + elif isinstance(q, WhitenedVariationalGaussian): + assert q.variational_mean.shape == vector_shape(n_inducing) + assert q.variational_root_covariance.shape == matrix_shape(n_inducing) + assert (q.variational_mean == vector_val(0.0)(n_inducing)).all() + assert (q.variational_root_covariance == diag_matrix_val(1.0)(n_inducing)).all() + + elif isinstance(q, NaturalVariationalGaussian): + assert q.natural_vector.shape == vector_shape(n_inducing) + assert q.natural_matrix.shape == matrix_shape(n_inducing) + assert (q.natural_vector == vector_val(0.0)(n_inducing)).all() + assert (q.natural_matrix == diag_matrix_val(-0.5)(n_inducing)).all() + + elif isinstance(q, ExpectationVariationalGaussian): + assert q.expectation_vector.shape == vector_shape(n_inducing) + assert q.expectation_matrix.shape == matrix_shape(n_inducing) + assert (q.expectation_vector == vector_val(0.0)(n_inducing)).all() + assert (q.expectation_matrix == diag_matrix_val(1.0)(n_inducing)).all() + + # Test KL - params = q.init_params(jr.PRNGKey(123)) - kl = q.prior_kl(params) + kl = q.prior_kl() assert isinstance(kl, jnp.ndarray) + assert kl.shape == () + assert kl >= 0.0 # Test predictions - predictive_dist_fn = q(params) - assert isinstance(predictive_dist_fn, Callable) - - predictive_dist = predictive_dist_fn(test_inputs) + predictive_dist = q(test_inputs) assert isinstance(predictive_dist, dx.Distribution) mu = predictive_dist.mean() @@ -181,69 +148,69 @@ def test_variational_gaussians( assert sigma.shape == (n_test, n_test) -@pytest.mark.parametrize("n_test", [1, 10]) -@pytest.mark.parametrize("n_datapoints", [1, 10]) -@pytest.mark.parametrize("n_inducing", [1, 10, 20]) -@pytest.mark.parametrize("point_dim", [1, 2]) -def test_collapsed_variational_gaussian( - n_test: int, n_inducing: int, n_datapoints: int, point_dim: int -) -> None: - x = jnp.linspace(-5.0, 5.0, n_datapoints).reshape(-1, 1) - y = jnp.sin(x) + jr.normal(key=jr.PRNGKey(123), shape=x.shape) * 0.1 - x = jnp.hstack([x] * point_dim) - D = gpx.Dataset(X=x, y=y) - - prior = gpx.Prior(kernel=gpx.RBF()) - - inducing_inputs = jnp.linspace(-5.0, 5.0, n_inducing).reshape(-1, 1) - inducing_inputs = jnp.hstack([inducing_inputs] * point_dim) - test_inputs = jnp.linspace(-5.0, 5.0, n_test).reshape(-1, 1) - test_inputs = jnp.hstack([test_inputs] * point_dim) - - variational_family = CollapsedVariationalGaussian( - prior=prior, - likelihood=gpx.Gaussian(num_datapoints=D.n), - inducing_inputs=inducing_inputs, - ) - - # We should raise an error for non-Gaussian likelihoods: - with pytest.raises(TypeError): - CollapsedVariationalGaussian( - prior=prior, - likelihood=gpx.Bernoulli(num_datapoints=D.n), - inducing_inputs=inducing_inputs, - ) - - # Test init - assert variational_family.num_inducing == n_inducing - params = gpx.config.get_global_config() - assert "inducing_inputs" in params["transformations"].keys() - assert (variational_family.inducing_inputs == inducing_inputs).all() - - # Test params - params = variational_family.init_params(jr.PRNGKey(123)) - assert isinstance(params, dict) - assert "likelihood" in params.keys() - assert "obs_noise" in params["likelihood"].keys() - assert "inducing_inputs" in params["variational_family"].keys() - assert params["variational_family"]["inducing_inputs"].shape == ( - n_inducing, - point_dim, - ) - assert isinstance(params["variational_family"]["inducing_inputs"], jax.Array) - - # Test predictions - params = variational_family.init_params(jr.PRNGKey(123)) - predictive_dist_fn = variational_family(params, D) - assert isinstance(predictive_dist_fn, Callable) - - predictive_dist = predictive_dist_fn(test_inputs) - assert isinstance(predictive_dist, dx.Distribution) - - mu = predictive_dist.mean() - sigma = predictive_dist.covariance() - - assert isinstance(mu, jnp.ndarray) - assert isinstance(sigma, jnp.ndarray) - assert mu.shape == (n_test,) - assert sigma.shape == (n_test, n_test) +# @pytest.mark.parametrize("n_test", [1, 10]) +# @pytest.mark.parametrize("n_datapoints", [1, 10]) +# @pytest.mark.parametrize("n_inducing", [1, 10, 20]) +# @pytest.mark.parametrize("point_dim", [1, 2]) +# def test_collapsed_variational_gaussian( +# n_test: int, n_inducing: int, n_datapoints: int, point_dim: int +# ) -> None: +# x = jnp.linspace(-5.0, 5.0, n_datapoints).reshape(-1, 1) +# y = jnp.sin(x) + jr.normal(key=jr.PRNGKey(123), shape=x.shape) * 0.1 +# x = jnp.hstack([x] * point_dim) +# D = gpx.Dataset(X=x, y=y) + +# prior = gpx.Prior(kernel=gpx.RBF()) + +# inducing_inputs = jnp.linspace(-5.0, 5.0, n_inducing).reshape(-1, 1) +# inducing_inputs = jnp.hstack([inducing_inputs] * point_dim) +# test_inputs = jnp.linspace(-5.0, 5.0, n_test).reshape(-1, 1) +# test_inputs = jnp.hstack([test_inputs] * point_dim) + +# variational_family = CollapsedVariationalGaussian( +# prior=prior, +# likelihood=gpx.Gaussian(num_datapoints=D.n), +# inducing_inputs=inducing_inputs, +# ) + +# # We should raise an error for non-Gaussian likelihoods: +# with pytest.raises(TypeError): +# CollapsedVariationalGaussian( +# prior=prior, +# likelihood=gpx.Bernoulli(num_datapoints=D.n), +# inducing_inputs=inducing_inputs, +# ) + +# # Test init +# assert variational_family.num_inducing == n_inducing +# params = gpx.config.get_global_config() +# assert "inducing_inputs" in params["transformations"].keys() +# assert (variational_family.inducing_inputs == inducing_inputs).all() + +# # Test params +# params = variational_family.init_params(jr.PRNGKey(123)) +# assert isinstance(params, dict) +# assert "likelihood" in params.keys() +# assert "obs_noise" in params["likelihood"].keys() +# assert "inducing_inputs" in params["variational_family"].keys() +# assert params["variational_family"]["inducing_inputs"].shape == ( +# n_inducing, +# point_dim, +# ) +# assert isinstance(params["variational_family"]["inducing_inputs"], jax.Array) + +# # Test predictions +# params = variational_family.init_params(jr.PRNGKey(123)) +# predictive_dist_fn = variational_family(params, D) +# assert isinstance(predictive_dist_fn, Callable) + +# predictive_dist = predictive_dist_fn(test_inputs) +# assert isinstance(predictive_dist, dx.Distribution) + +# mu = predictive_dist.mean() +# sigma = predictive_dist.covariance() + +# assert isinstance(mu, jnp.ndarray) +# assert isinstance(sigma, jnp.ndarray) +# assert mu.shape == (n_test,) +# assert sigma.shape == (n_test, n_test) From 8051384bd288cc724b731fb0537e7074abc387ce Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Wed, 29 Mar 2023 13:33:00 +0100 Subject: [PATCH 17/44] Update likelihoods and refactor collapsed variational family --- gpjax/likelihoods.py | 116 ++++----------------- gpjax/variational_families.py | 157 +++++++++++------------------ tests/test_variational_families.py | 117 +++++++++------------ 3 files changed, 130 insertions(+), 260 deletions(-) diff --git a/gpjax/likelihoods.py b/gpjax/likelihoods.py index 5543c6de7..4a9c5895f 100644 --- a/gpjax/likelihoods.py +++ b/gpjax/likelihoods.py @@ -14,31 +14,23 @@ # ============================================================================== import abc -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict +from .linops.utils import to_dense import deprecation import distrax as dx import jax.numpy as jnp import jax.scipy as jsp -from jax.random import KeyArray from jaxtyping import Array, Float -from jaxutils import PyTree - -from .linops.utils import to_dense +from dataclasses import dataclass +from mytree import Mytree -class AbstractLikelihood(PyTree): +@dataclass +class AbstractLikelihood(Mytree): """Abstract base class for likelihoods.""" + num_datapoints: int - def __init__(self, num_datapoints: int, name: Optional[str] = None): - """Initialise the likelihood. - - Args: - num_datapoints (int): The number of datapoints that the likelihood factorises over. - name (Optional[str]): The name of the likelihood. Defaults to None. - """ - self.num_datapoints = num_datapoints - self.name = name def __call__(self, *args: Any, **kwargs: Any) -> dx.Distribution: """Evaluate the likelihood function at a given predictive distribution. @@ -65,27 +57,6 @@ def predict(self, *args: Any, **kwargs: Any) -> dx.Distribution: """ raise NotImplementedError - @abc.abstractmethod - def init_params(self, key: KeyArray) -> Dict: - """Return the parameters of the likelihood function. - - Args: - key (KeyArray): A PRNG key. - - Returns: - Dict: The parameters of the likelihood function. - """ - raise NotImplementedError - - @deprecation.deprecated( - deprecated_in="0.5.7", - removed_in="0.6.0", - details="Use the ``init_params`` method for parameter initialisation.", - ) - def _initialise_params(self, key: KeyArray) -> Dict: - """Deprecated method for initialising the GP's parameters. Succeeded by ``init_params``.""" - return self.init_params(key) - @property @abc.abstractmethod def link_function(self) -> Callable: @@ -97,6 +68,8 @@ def link_function(self) -> Callable: raise NotImplementedError +# I don't think we need these? Either we inherit a conjugate likelihood for the Gaussian, or we check it's Gaussian in the posterior construction. +# I don't like multiple iheritance, and think its an annoying thing to have to remember to do, if the user want to add their own likelihoods! class Conjugate: """An abstract class for conjugate likelihoods with respect to a Gaussian process prior.""" @@ -105,31 +78,10 @@ class NonConjugate: """An abstract class for non-conjugate likelihoods with respect to a Gaussian process prior.""" -# TODO: revamp this with covariance operators. - - +@dataclass class Gaussian(AbstractLikelihood, Conjugate): """Gaussian likelihood object.""" - - def __init__(self, num_datapoints: int, name: Optional[str] = "Gaussian"): - """Initialise the Gaussian likelihood. - - Args: - num_datapoints (int): The number of datapoints that the likelihood factorises over. - name (Optional[str]): The name of the likelihood. Defaults to "Gaussian". - """ - super().__init__(num_datapoints, name) - - def init_params(self, key: KeyArray) -> Dict: - """Return the variance parameter of the likelihood function. - - Args: - key (KeyArray): A PRNG key. - - Returns: - Dict: The parameters of the likelihood function. - """ - return {"obs_noise": jnp.array([1.0])} + obs_noise: float = 1.0 @property def link_function(self) -> Callable[[Dict, Float[Array, "N 1"]], dx.Distribution]: @@ -141,7 +93,7 @@ def link_function(self) -> Callable[[Dict, Float[Array, "N 1"]], dx.Distribution function that maps the predictive distribution to the likelihood function. """ - def link_fn(params: Dict, f: Float[Array, "N 1"]) -> dx.Normal: + def link_fn(f: Float[Array, "N 1"]) -> dx.Normal: """The link function of the Gaussian likelihood. Args: @@ -151,11 +103,11 @@ def link_fn(params: Dict, f: Float[Array, "N 1"]) -> dx.Normal: Returns: dx.Normal: The likelihood function. """ - return dx.Normal(loc=f, scale=params["obs_noise"]) + return dx.Normal(loc=f, scale=self.obs_noise) return link_fn - def predict(self, params: Dict, dist: dx.MultivariateNormalTri) -> dx.Distribution: + def predict(self, dist: dx.MultivariateNormalTri) -> dx.Distribution: """ Evaluate the Gaussian likelihood function at a given predictive distribution. Computationally, this is equivalent to summing the @@ -172,48 +124,26 @@ def predict(self, params: Dict, dist: dx.MultivariateNormalTri) -> dx.Distributi """ n_data = dist.event_shape[0] cov = to_dense(dist.covariance()) - noisy_cov = cov.at[jnp.diag_indices(n_data)].add( - params["likelihood"]["obs_noise"] - ) + noisy_cov = cov.at[jnp.diag_indices(n_data)].add(self.obs_noise) return dx.MultivariateNormalFullCovariance(dist.mean(), noisy_cov) +@dataclass class Bernoulli(AbstractLikelihood, NonConjugate): - def __init__(self, num_datapoints: int, name: Optional[str] = "Bernoulli"): - """Initialise the Bernoulli likelihood. - - Args: - num_datapoints (int): The number of datapoints that the likelihood factorises over. - name (Optional[str]): The name of the likelihood. Defaults to "Bernoulli". - """ - super().__init__(num_datapoints, name) - - def init_params(self, key: KeyArray) -> Dict: - """Initialise the parameter set of a Bernoulli likelihood. - - Args: - key (KeyArray): A PRNG key. - - Returns: - Dict: The parameters of the likelihood function (empty for the Bernoulli likelihood). - """ - return {} @property - def link_function(self) -> Callable[[Dict, Float[Array, "N 1"]], dx.Distribution]: + def link_function(self) -> Callable[[Float[Array, "N 1"]], dx.Distribution]: """Return the probit link function of the Bernoulli likelihood. Returns: Callable[[Dict, Float[Array, "N 1"]], dx.Distribution]: A probit link function that maps the predictive distribution to the likelihood function. """ - - def link_fn(params: Dict, f: Float[Array, "N 1"]) -> dx.Distribution: + def link_fn(f: Float[Array, "N 1"]) -> dx.Distribution: """The probit link function of the Bernoulli likelihood. Args: - params (Dict): The parameters of the likelihood function. f (Float[Array, "N 1"]): Function values. Returns: @@ -226,7 +156,7 @@ def link_fn(params: Dict, f: Float[Array, "N 1"]) -> dx.Distribution: @property def predictive_moment_fn( self, - ) -> Callable[[Dict, Float[Array, "N 1"]], Float[Array, "N 1"]]: + ) -> Callable[[Float[Array, "N 1"]], Float[Array, "N 1"]]: """Instantiate the predictive moment function of the Bernoulli likelihood that is parameterised by a probit link function. @@ -236,26 +166,24 @@ def predictive_moment_fn( """ def moment_fn( - params: Dict, mean: Float[Array, "N 1"], variance: Float[Array, "N 1"], ): """The predictive moment function of the Bernoulli likelihood. Args: - params (Dict): The parameters of the likelihood function. mean (Float[Array, "N 1"]): The mean of the latent function values. variance (Float[Array, "N 1"]): The diagonal variance of the latent function values. Returns: Float[Array, "N 1"]: The pointwise predictive distribution. """ - rv = self.link_function(params, mean / jnp.sqrt(1.0 + variance)) + rv = self.link_function(mean / jnp.sqrt(1.0 + variance)) return rv return moment_fn - def predict(self, params: Dict, dist: dx.Distribution) -> dx.Distribution: + def predict(self, dist: dx.Distribution) -> dx.Distribution: """Evaluate the pointwise predictive distribution, given a Gaussian process posterior and likelihood parameters. @@ -269,7 +197,7 @@ def predict(self, params: Dict, dist: dx.Distribution) -> dx.Distribution: """ variance = jnp.diag(dist.covariance()) mean = dist.mean().ravel() - return self.predictive_moment_fn(params, mean, variance) + return self.predictive_moment_fn(mean, variance) def inv_probit(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]: diff --git a/gpjax/variational_families.py b/gpjax/variational_families.py index ad79e5e8a..7255e6cc9 100644 --- a/gpjax/variational_families.py +++ b/gpjax/variational_families.py @@ -591,137 +591,96 @@ def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution: loc=jnp.atleast_1d(mean.squeeze()), scale=covariance ) - -class CollapsedVariationalGaussian(AbstractVariationalFamily): +@dataclass +class CollapsedVariationalGaussian(AbstractVariationalGaussian): """Collapsed variational Gaussian family of probability distributions. The key reference is Titsias, (2009) - Variational Learning of Inducing Variables in Sparse Gaussian Processes.""" + likelihood: AbstractLikelihood - def __init__( - self, - prior: Prior, - likelihood: AbstractLikelihood, - inducing_inputs: Float[Array, "M D"], - name: str = "Collapsed variational Gaussian", - ): - """Initialise the collapsed variational Gaussian family of probability distributions. - - Args: - prior (Prior): The prior distribution that we are approximating. - likelihood (AbstractLikelihood): The likelihood function that we are using to model the data. - inducing_inputs (Float[Array, "M D"]): The inducing inputs that are to be used to parameterise the variational Gaussian distribution. - name (str, optional): The name of the variational family. Defaults to "Collapsed variational Gaussian". - """ - - if not isinstance(likelihood, Gaussian): + def __post_init__(self): + if not isinstance(self.likelihood, Gaussian): raise TypeError("Likelihood must be Gaussian.") - self.prior = prior - self.likelihood = likelihood - self.inducing_inputs = inducing_inputs - self.num_inducing = self.inducing_inputs.shape[0] - self.name = name - - def init_params(self, key: KeyArray) -> Dict: - """Return the variational mean vector, variational root covariance matrix, and inducing input vector that parameterise the variational Gaussian distribution.""" - return concat_dictionaries( - self.prior.init_params(key), - { - "variational_family": {"inducing_inputs": self.inducing_inputs}, - "likelihood": { - "obs_noise": self.likelihood.init_params(key)["obs_noise"] - }, - }, - ) - - def predict( - self, - params: Dict, - train_data: Dataset, - ) -> Callable[[Float[Array, "N D"]], GaussianDistribution]: + def predict(self, test_inputs: Float[Array, "N D"], train_data: Dataset) -> GaussianDistribution: """Compute the predictive distribution of the GP at the test inputs. Args: - params (Dict): The set of parameters that are to be used to parameterise our variational approximation and GP. train_data (Dataset): The training data that was used to fit the GP. Returns: - Callable[[Float[Array, "N D"]], dx.MultivariateNormalTri]: A function that accepts a set of test points and will return the predictive distribution at those points. + GaussianDistribution: The predictive distribution of the collapsed variational Gaussian process at the test inputs t. """ jitter = get_global_config()["jitter"] - def predict_fn(test_inputs: Float[Array, "N D"]) -> GaussianDistribution: - - # Unpack test inputs - t, n_test = test_inputs, test_inputs.shape[0] - - # Unpack training data - x, y = train_data.X, train_data.y + # Unpack test inputs + t, n_test = test_inputs, test_inputs.shape[0] - # Unpack variational parameters - noise = params["likelihood"]["obs_noise"] - z = params["variational_family"]["inducing_inputs"] - m = self.num_inducing + # Unpack training data + x, y = train_data.X, train_data.y - # Unpack mean function and kernel - mean_function = self.prior.mean_function - kernel = self.prior.kernel + # Unpack variational parameters + noise = self.likelihood.obs_noise + z = self.inducing_inputs + m = self.num_inducing - Kzx = kernel.cross_covariance(params["kernel"], z, x) - Kzz = kernel.gram(params["kernel"], z) - Kzz += identity(m) * jitter + # Unpack mean function and kernel + mean_function = self.prior.mean_function + kernel = self.prior.kernel - # Lz Lzᵀ = Kzz - Lz = Kzz.to_root() + Kzx = kernel.cross_covariance(z, x) + Kzz = kernel.gram(z) + Kzz += identity(m) * jitter - # Lz⁻¹ Kzx - Lz_inv_Kzx = Lz.solve(Kzx) + # Lz Lzᵀ = Kzz + Lz = Kzz.to_root() - # A = Lz⁻¹ Kzt / σ - A = Lz_inv_Kzx / jnp.sqrt(noise) + # Lz⁻¹ Kzx + Lz_inv_Kzx = Lz.solve(Kzx) - # AAᵀ - AAT = jnp.matmul(A, A.T) + # A = Lz⁻¹ Kzt / σ + A = Lz_inv_Kzx / jnp.sqrt(noise) - # LLᵀ = I + AAᵀ - L = jnp.linalg.cholesky(jnp.eye(m) + AAT) + # AAᵀ + AAT = jnp.matmul(A, A.T) - μx = mean_function(params["mean_function"], x) - diff = y - μx + # LLᵀ = I + AAᵀ + L = jnp.linalg.cholesky(jnp.eye(m) + AAT) - # Lz⁻¹ Kzx (y - μx) - Lz_inv_Kzx_diff = jsp.linalg.cho_solve( - (L, True), jnp.matmul(Lz_inv_Kzx, diff) - ) + μx = mean_function(x) + diff = y - μx - # Kzz⁻¹ Kzx (y - μx) - Kzz_inv_Kzx_diff = Lz.T.solve(Lz_inv_Kzx_diff) + # Lz⁻¹ Kzx (y - μx) + Lz_inv_Kzx_diff = jsp.linalg.cho_solve( + (L, True), jnp.matmul(Lz_inv_Kzx, diff) + ) - Ktt = kernel.gram(params["kernel"], t) - Kzt = kernel.cross_covariance(params["kernel"], z, t) - μt = mean_function(params["mean_function"], t) + # Kzz⁻¹ Kzx (y - μx) + Kzz_inv_Kzx_diff = Lz.T.solve(Lz_inv_Kzx_diff) - # Lz⁻¹ Kzt - Lz_inv_Kzt = Lz.solve(Kzt) + Ktt = kernel.gram(t) + Kzt = kernel.cross_covariance(z, t) + μt = mean_function(t) - # L⁻¹ Lz⁻¹ Kzt - L_inv_Lz_inv_Kzt = jsp.linalg.solve_triangular(L, Lz_inv_Kzt, lower=True) + # Lz⁻¹ Kzt + Lz_inv_Kzt = Lz.solve(Kzt) - # μt + 1/σ² Ktz Kzz⁻¹ Kzx (y - μx) - mean = μt + jnp.matmul(Kzt.T / noise, Kzz_inv_Kzx_diff) + # L⁻¹ Lz⁻¹ Kzt + L_inv_Lz_inv_Kzt = jsp.linalg.solve_triangular(L, Lz_inv_Kzt, lower=True) - # Ktt - Ktz Kzz⁻¹ Kzt + Ktz Lz⁻¹ (I + AAᵀ)⁻¹ Lz⁻¹ Kzt - covariance = ( - Ktt - - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt) - + jnp.matmul(L_inv_Lz_inv_Kzt.T, L_inv_Lz_inv_Kzt) - ) - covariance += identity(n_test) * jitter + # μt + 1/σ² Ktz Kzz⁻¹ Kzx (y - μx) + mean = μt + jnp.matmul(Kzt.T / noise, Kzz_inv_Kzx_diff) - return GaussianDistribution( - loc=jnp.atleast_1d(mean.squeeze()), scale=covariance - ) + # Ktt - Ktz Kzz⁻¹ Kzt + Ktz Lz⁻¹ (I + AAᵀ)⁻¹ Lz⁻¹ Kzt + covariance = ( + Ktt + - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt) + + jnp.matmul(L_inv_Lz_inv_Kzt.T, L_inv_Lz_inv_Kzt) + ) + covariance += identity(n_test) * jitter - return predict_fn + return GaussianDistribution( + loc=jnp.atleast_1d(mean.squeeze()), scale=covariance + ) __all__ = [ diff --git a/tests/test_variational_families.py b/tests/test_variational_families.py index b11835b33..7fa3fa344 100644 --- a/tests/test_variational_families.py +++ b/tests/test_variational_families.py @@ -26,7 +26,7 @@ import gpjax as gpx from gpjax.variational_families import ( AbstractVariationalFamily, - #CollapsedVariationalGaussian, + CollapsedVariationalGaussian, ExpectationVariationalGaussian, NaturalVariationalGaussian, VariationalGaussian, @@ -148,69 +148,52 @@ def test_variational_gaussians( assert sigma.shape == (n_test, n_test) -# @pytest.mark.parametrize("n_test", [1, 10]) -# @pytest.mark.parametrize("n_datapoints", [1, 10]) -# @pytest.mark.parametrize("n_inducing", [1, 10, 20]) -# @pytest.mark.parametrize("point_dim", [1, 2]) -# def test_collapsed_variational_gaussian( -# n_test: int, n_inducing: int, n_datapoints: int, point_dim: int -# ) -> None: -# x = jnp.linspace(-5.0, 5.0, n_datapoints).reshape(-1, 1) -# y = jnp.sin(x) + jr.normal(key=jr.PRNGKey(123), shape=x.shape) * 0.1 -# x = jnp.hstack([x] * point_dim) -# D = gpx.Dataset(X=x, y=y) - -# prior = gpx.Prior(kernel=gpx.RBF()) - -# inducing_inputs = jnp.linspace(-5.0, 5.0, n_inducing).reshape(-1, 1) -# inducing_inputs = jnp.hstack([inducing_inputs] * point_dim) -# test_inputs = jnp.linspace(-5.0, 5.0, n_test).reshape(-1, 1) -# test_inputs = jnp.hstack([test_inputs] * point_dim) - -# variational_family = CollapsedVariationalGaussian( -# prior=prior, -# likelihood=gpx.Gaussian(num_datapoints=D.n), -# inducing_inputs=inducing_inputs, -# ) - -# # We should raise an error for non-Gaussian likelihoods: -# with pytest.raises(TypeError): -# CollapsedVariationalGaussian( -# prior=prior, -# likelihood=gpx.Bernoulli(num_datapoints=D.n), -# inducing_inputs=inducing_inputs, -# ) - -# # Test init -# assert variational_family.num_inducing == n_inducing -# params = gpx.config.get_global_config() -# assert "inducing_inputs" in params["transformations"].keys() -# assert (variational_family.inducing_inputs == inducing_inputs).all() - -# # Test params -# params = variational_family.init_params(jr.PRNGKey(123)) -# assert isinstance(params, dict) -# assert "likelihood" in params.keys() -# assert "obs_noise" in params["likelihood"].keys() -# assert "inducing_inputs" in params["variational_family"].keys() -# assert params["variational_family"]["inducing_inputs"].shape == ( -# n_inducing, -# point_dim, -# ) -# assert isinstance(params["variational_family"]["inducing_inputs"], jax.Array) - -# # Test predictions -# params = variational_family.init_params(jr.PRNGKey(123)) -# predictive_dist_fn = variational_family(params, D) -# assert isinstance(predictive_dist_fn, Callable) - -# predictive_dist = predictive_dist_fn(test_inputs) -# assert isinstance(predictive_dist, dx.Distribution) - -# mu = predictive_dist.mean() -# sigma = predictive_dist.covariance() - -# assert isinstance(mu, jnp.ndarray) -# assert isinstance(sigma, jnp.ndarray) -# assert mu.shape == (n_test,) -# assert sigma.shape == (n_test, n_test) +@pytest.mark.parametrize("n_test", [1, 10]) +@pytest.mark.parametrize("n_datapoints", [1, 10]) +@pytest.mark.parametrize("n_inducing", [1, 10, 20]) +@pytest.mark.parametrize("point_dim", [1, 2]) +def test_collapsed_variational_gaussian( + n_test: int, n_inducing: int, n_datapoints: int, point_dim: int +) -> None: + x = jnp.linspace(-5.0, 5.0, n_datapoints).reshape(-1, 1) + y = jnp.sin(x) + jr.normal(key=jr.PRNGKey(123), shape=x.shape) * 0.1 + x = jnp.hstack([x] * point_dim) + D = gpx.Dataset(X=x, y=y) + + prior = gpx.Prior(kernel=gpx.RBF()) + + inducing_inputs = jnp.linspace(-5.0, 5.0, n_inducing).reshape(-1, 1) + inducing_inputs = jnp.hstack([inducing_inputs] * point_dim) + test_inputs = jnp.linspace(-5.0, 5.0, n_test).reshape(-1, 1) + test_inputs = jnp.hstack([test_inputs] * point_dim) + + variational_family = CollapsedVariationalGaussian( + prior=prior, + likelihood=gpx.Gaussian(num_datapoints=D.n), + inducing_inputs=inducing_inputs, + ) + + # We should raise an error for non-Gaussian likelihoods: + with pytest.raises(TypeError): + CollapsedVariationalGaussian( + prior=prior, + likelihood=gpx.Bernoulli(num_datapoints=D.n), + inducing_inputs=inducing_inputs, + ) + + # Test init + assert variational_family.num_inducing == n_inducing + assert (variational_family.inducing_inputs == inducing_inputs).all() + assert variational_family.likelihood.obs_noise == 1.0 + + # Test predictions + predictive_dist = variational_family(test_inputs, D) + assert isinstance(predictive_dist, dx.Distribution) + + mu = predictive_dist.mean() + sigma = predictive_dist.covariance() + + assert isinstance(mu, jnp.ndarray) + assert isinstance(sigma, jnp.ndarray) + assert mu.shape == (n_test,) + assert sigma.shape == (n_test, n_test) From 4c1c9657e46a7bf60fa3f0f3de4a6c685f242e48 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Wed, 29 Mar 2023 20:40:35 +0100 Subject: [PATCH 18/44] Update likelihoods. --- gpjax/gps.py | 5 ++ gpjax/likelihoods.py | 109 +++++++++----------------------------- tests/test_likelihoods.py | 104 +++++++++--------------------------- 3 files changed, 53 insertions(+), 165 deletions(-) diff --git a/gpjax/gps.py b/gpjax/gps.py index ffdc7b3ea..cf2b0b9a9 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -24,6 +24,11 @@ from jaxutils import Dataset, PyTree from .config import get_global_config +from .kernels import AbstractKernel +from .likelihoods import AbstractLikelihood +from .mean_functions import AbstractMeanFunction, Zero +from jaxutils import Dataset +from .utils import concat_dictionaries from .gaussian_distribution import GaussianDistribution from .kernels import AbstractKernel from .kernels.base import AbstractKernel diff --git a/gpjax/likelihoods.py b/gpjax/likelihoods.py index 4a9c5895f..152d58152 100644 --- a/gpjax/likelihoods.py +++ b/gpjax/likelihoods.py @@ -14,7 +14,7 @@ # ============================================================================== import abc -from typing import Any, Callable, Dict +from typing import Any from .linops.utils import to_dense import deprecation @@ -22,14 +22,15 @@ import jax.numpy as jnp import jax.scipy as jsp from jaxtyping import Array, Float +from simple_pytree import static_field from dataclasses import dataclass -from mytree import Mytree +from mytree import Mytree, param_field, Softplus @dataclass class AbstractLikelihood(Mytree): """Abstract base class for likelihoods.""" - num_datapoints: int + num_datapoints: int = static_field() def __call__(self, *args: Any, **kwargs: Any) -> dx.Distribution: @@ -59,53 +60,31 @@ def predict(self, *args: Any, **kwargs: Any) -> dx.Distribution: @property @abc.abstractmethod - def link_function(self) -> Callable: + def link_function(self) -> dx.Distribution: """Return the link function of the likelihood function. Returns: - Callable: The link function of the likelihood function. + dx.Distribution: The distribution of observations, y, given values of the Gaussian process, f. """ raise NotImplementedError -# I don't think we need these? Either we inherit a conjugate likelihood for the Gaussian, or we check it's Gaussian in the posterior construction. -# I don't like multiple iheritance, and think its an annoying thing to have to remember to do, if the user want to add their own likelihoods! -class Conjugate: - """An abstract class for conjugate likelihoods with respect to a Gaussian process prior.""" - - -class NonConjugate: - """An abstract class for non-conjugate likelihoods with respect to a Gaussian process prior.""" - - @dataclass -class Gaussian(AbstractLikelihood, Conjugate): +class Gaussian(AbstractLikelihood): """Gaussian likelihood object.""" - obs_noise: float = 1.0 + obs_noise: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus) - @property - def link_function(self) -> Callable[[Dict, Float[Array, "N 1"]], dx.Distribution]: - """Return the link function of the Gaussian likelihood. Here, this is - simply the identity function, but we include it for completeness. + def link_function(self, f: Float[Array, "N 1"]) -> dx.Normal: + """The link function of the Gaussian likelihood. + + Args: + params (Dict): The parameters of the likelihood function. + f (Float[Array, "N 1"]): Function values. Returns: - Callable[[Dict, Float[Array, "N 1"]], dx.Distribution]: A link - function that maps the predictive distribution to the likelihood function. + dx.Normal: The likelihood function. """ - - def link_fn(f: Float[Array, "N 1"]) -> dx.Normal: - """The link function of the Gaussian likelihood. - - Args: - params (Dict): The parameters of the likelihood function. - f (Float[Array, "N 1"]): Function values. - - Returns: - dx.Normal: The likelihood function. - """ - return dx.Normal(loc=f, scale=self.obs_noise) - - return link_fn + return dx.Normal(loc=f, scale=self.obs_noise) def predict(self, dist: dx.MultivariateNormalTri) -> dx.Distribution: """ @@ -130,58 +109,18 @@ def predict(self, dist: dx.MultivariateNormalTri) -> dx.Distribution: @dataclass -class Bernoulli(AbstractLikelihood, NonConjugate): +class Bernoulli(AbstractLikelihood): - @property - def link_function(self) -> Callable[[Float[Array, "N 1"]], dx.Distribution]: - """Return the probit link function of the Bernoulli likelihood. + def link_function(self, f: Float[Array, "N 1"]) -> dx.Distribution: + """The probit link function of the Bernoulli likelihood. - Returns: - Callable[[Dict, Float[Array, "N 1"]], dx.Distribution]: A probit link - function that maps the predictive distribution to the likelihood function. - """ - def link_fn(f: Float[Array, "N 1"]) -> dx.Distribution: - """The probit link function of the Bernoulli likelihood. - - Args: - f (Float[Array, "N 1"]): Function values. - - Returns: - dx.Distribution: The likelihood function. - """ - return dx.Bernoulli(probs=inv_probit(f)) - - return link_fn - - @property - def predictive_moment_fn( - self, - ) -> Callable[[Float[Array, "N 1"]], Float[Array, "N 1"]]: - """Instantiate the predictive moment function of the Bernoulli likelihood - that is parameterised by a probit link function. + Args: + f (Float[Array, "N 1"]): Function values. Returns: - Callable: A callable object that accepts a mean and variance term - from which the predictive random variable is computed. + dx.Distribution: The likelihood function. """ - - def moment_fn( - mean: Float[Array, "N 1"], - variance: Float[Array, "N 1"], - ): - """The predictive moment function of the Bernoulli likelihood. - - Args: - mean (Float[Array, "N 1"]): The mean of the latent function values. - variance (Float[Array, "N 1"]): The diagonal variance of the latent function values. - - Returns: - Float[Array, "N 1"]: The pointwise predictive distribution. - """ - rv = self.link_function(mean / jnp.sqrt(1.0 + variance)) - return rv - - return moment_fn + return dx.Bernoulli(probs=inv_probit(f)) def predict(self, dist: dx.Distribution) -> dx.Distribution: """Evaluate the pointwise predictive distribution, given a Gaussian @@ -197,7 +136,7 @@ def predict(self, dist: dx.Distribution) -> dx.Distribution: """ variance = jnp.diag(dist.covariance()) mean = dist.mean().ravel() - return self.predictive_moment_fn(mean, variance) + return self.link_function(mean / jnp.sqrt(1.0 + variance)) def inv_probit(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]: diff --git a/tests/test_likelihoods.py b/tests/test_likelihoods.py index 6d9d7918c..7486f0b15 100644 --- a/tests/test_likelihoods.py +++ b/tests/test_likelihoods.py @@ -13,8 +13,9 @@ # limitations under the License. # ============================================================================== -from typing import Callable, Dict +from typing import Callable +import jax.tree_util as jtu import distrax as dx import jax.numpy as jnp import jax.random as jr @@ -27,21 +28,12 @@ from gpjax.likelihoods import ( AbstractLikelihood, Bernoulli, - Conjugate, Gaussian, - NonConjugate, inv_probit, ) # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) -_initialise_key = jr.PRNGKey(123) - -# Likelihood parameter names to test in initialisation. -true_initialisation = { - "Gaussian": ["obs_noise"], - "Bernoulli": [], -} def test_abstract_likelihood(): @@ -51,17 +43,12 @@ def test_abstract_likelihood(): # Create a dummy likelihood class with abstract methods implemented. class DummyLikelihood(AbstractLikelihood): - def init_params(self, key: KeyArray) -> Dict: - return {} - def predict(self, params: Dict, dist: dx.Distribution) -> dx.Distribution: + def predict(self, dist: dx.Distribution) -> dx.Distribution: return dx.Normal(0.0, 1.0) - def link_function(self) -> Callable: - def link(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]: - return dx.MultivariateNormalDiag(loc=x) - - return link + def link_function(self, f: Float[Array, "N 1"]) -> Float[Array, "N 1"]: + return dx.MultivariateNormalDiag(loc=f) # Test that the dummy likelihood can be instantiated. dummy_likelihood = DummyLikelihood(num_datapoints=123) @@ -69,74 +56,46 @@ def link(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]: @pytest.mark.parametrize("n", [1, 10]) -@pytest.mark.parametrize("lik", [Gaussian, Bernoulli]) -def test_initialisers(n: int, lik: AbstractLikelihood) -> None: - key = _initialise_key - - # Initialise the likelihood. - likelihood = lik(num_datapoints=n) - - # Get default parameter dictionary. - params = likelihood.init_params(key) +@pytest.mark.parametrize("noise", [0.1, 0.5, 1.0]) +def test_gaussian_init(n: int, noise: float) -> None: - # Check parameter dictionary - assert list(params.keys()) == true_initialisation[likelihood.name] - assert len(list(params.values())) == len(true_initialisation[likelihood.name]) + likelihood = Gaussian(num_datapoints=n, obs_noise=jnp.array([noise])) + + assert likelihood.obs_noise == jnp.array([noise]) + assert likelihood.num_datapoints == n + assert jtu.tree_leaves(likelihood) == [jnp.array([noise])] @pytest.mark.parametrize("n", [1, 10]) -def test_bernoulli_predictive_moment(n: int) -> None: - key = _initialise_key +def test_beroulli_init(n: int) -> None: - # Initialise bernoulli likelihood. likelihood = Bernoulli(num_datapoints=n) - - # Initialise parameters. - params = likelihood.init_params(key) - - # Construct latent function mean and variance values - mean_key, var_key = jr.split(key) - fmean = jr.uniform(mean_key, shape=(n, 1)) - fvar = jnp.exp(jr.normal(var_key, shape=(n, 1))) - - # Test predictive moments. - assert isinstance(likelihood.predictive_moment_fn, Callable) - - y = likelihood.predictive_moment_fn(params, fmean, fvar) - y_mean = y.mean() - y_var = y.variance() - - assert y_mean.shape == (n, 1) - assert y_var.shape == (n, 1) + assert likelihood.num_datapoints == n + assert jtu.tree_leaves(likelihood) == [] @pytest.mark.parametrize("lik", [Gaussian, Bernoulli]) @pytest.mark.parametrize("n", [1, 10]) def test_link_fns(lik: AbstractLikelihood, n: int) -> None: - key = _initialise_key - # Create test inputs. - x = jnp.linspace(-3.0, 3.0).reshape(-1, 1) + # Create function values. + f = jnp.linspace(-3.0, 3.0).reshape(-1, 1) # Initialise likelihood. likelihood = lik(num_datapoints=n) - # Initialise parameters. - params = likelihood.init_params(key) - # Test likelihood link function. assert isinstance(likelihood.link_function, Callable) - assert isinstance(likelihood.link_function(params, x), dx.Distribution) + assert isinstance(likelihood.link_function(f), dx.Distribution) @pytest.mark.parametrize("noise", [0.1, 0.5, 1.0]) @pytest.mark.parametrize("n", [1, 2, 10]) def test_call_gaussian(noise: float, n: int) -> None: - key = _initialise_key + key = jr.PRNGKey(123) # Initialise likelihood and parameters. - likelihood = Gaussian(num_datapoints=n) - params = {"likelihood": {"obs_noise": noise}} + likelihood = Gaussian(num_datapoints=n, obs_noise=jnp.array([noise])) # Construct latent function distribution. latent_mean = jr.uniform(key, shape=(n,)) @@ -145,7 +104,7 @@ def test_call_gaussian(noise: float, n: int) -> None: latent_dist = dx.MultivariateNormalFullCovariance(latent_mean, latent_cov) # Test call method. - pred_dist = likelihood(params, latent_dist) + pred_dist = likelihood(latent_dist) # Check that the distribution is a MultivariateNormalFullCovariance. assert isinstance(pred_dist, dx.MultivariateNormalFullCovariance) @@ -161,11 +120,10 @@ def test_call_gaussian(noise: float, n: int) -> None: @pytest.mark.parametrize("n", [1, 2, 10]) def test_call_bernoulli(n: int) -> None: - key = _initialise_key + key = jr.PRNGKey(123) # Initialise likelihood and parameters. likelihood = Bernoulli(num_datapoints=n) - params = {"likelihood": {}} # Construct latent function distribution. latent_mean = jr.uniform(key, shape=(n,)) @@ -174,7 +132,7 @@ def test_call_bernoulli(n: int) -> None: latent_dist = dx.MultivariateNormalFullCovariance(latent_mean, latent_cov) # Test call method. - pred_dist = likelihood(params, latent_dist) + pred_dist = likelihood(latent_dist) # Check that the distribution is a Bernoulli. assert isinstance(pred_dist, dx.Bernoulli) @@ -183,18 +141,4 @@ def test_call_bernoulli(n: int) -> None: p = inv_probit(latent_mean / jnp.sqrt(1.0 + jnp.diagonal(latent_cov))) assert (pred_dist.mean() == p).all() - assert (pred_dist.variance() == p * (1.0 - p)).all() - - -@pytest.mark.parametrize("lik", [Gaussian, Bernoulli]) -@pytest.mark.parametrize("n", [1, 2, 10]) -def test_conjugacy(lik: AbstractLikelihood, n: int) -> None: - likelihood = lik(num_datapoints=n) - - # Gaussian likelihood is conjugate. - if isinstance(likelihood, Gaussian): - assert isinstance(likelihood, Conjugate) - - # Bernoulli likelihood is non-conjugate. - elif isinstance(likelihood, Bernoulli): - assert isinstance(likelihood, NonConjugate) + assert (pred_dist.variance() == p * (1.0 - p)).all() \ No newline at end of file From cacac4edb474238ea049e3a86c2810bb925c8f78 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Thu, 30 Mar 2023 19:29:15 +0100 Subject: [PATCH 19/44] Add fit.py and test. --- gpjax/abstractions.py | 400 ------------------------------------------ gpjax/fit.py | 231 ++++++++++++++++++++++++ gpjax/progress_bar.py | 126 +++++++++++++ tests/test_fit.py | 57 ++++++ 4 files changed, 414 insertions(+), 400 deletions(-) delete mode 100644 gpjax/abstractions.py create mode 100644 gpjax/fit.py create mode 100644 gpjax/progress_bar.py create mode 100644 tests/test_fit.py diff --git a/gpjax/abstractions.py b/gpjax/abstractions.py deleted file mode 100644 index 3682aa950..000000000 --- a/gpjax/abstractions.py +++ /dev/null @@ -1,400 +0,0 @@ -# Copyright 2022 The GPJax Contributors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from typing import Any, Callable, Dict, Optional, Tuple, Union - -import jax -import jax.numpy as jnp -import jax.random as jr -import optax as ox -from jax import lax -from jax.experimental import host_callback -from jax.random import KeyArray -from jaxtyping import Array, Float -from jaxutils import Dataset, PyTree -from tqdm.auto import tqdm - -from .natural_gradients import natural_gradients -from .params import ParameterState, constrain, trainable_params, unconstrain -from .variational_inference import StochasticVI - - -class InferenceState(PyTree): - """Immutable class for storing optimised parameters and training history.""" - - def __init__(self, params: Dict, history: Float[Array, "num_iters"]): - self._params = params - self._history = history - - @property - def params(self) -> Dict: - """Parameters. - - Returns: - Dict: Parameters. - """ - return self._params - - @property - def history(self) -> Float[Array, "num_iters"]: - """Training history. - - Returns: - Float[Array, "num_iters"]: Training history. - """ - return self._history - - def unpack(self) -> Tuple[Dict, Float[Array, "num_iters"]]: - """Unpack parameters and training history into a tuple. - - Returns: - Tuple[Dict, Float[Array, "num_iters"]]: Tuple of parameters and training history. - """ - return self.params, self.history - - -def fit( - objective: Callable, - parameter_state: ParameterState, - optax_optim: ox.GradientTransformation, - num_iters: Optional[int] = 100, - log_rate: Optional[int] = 10, - verbose: Optional[bool] = True, -) -> InferenceState: - """Abstracted method for fitting a GP model with respect to a supplied objective function. - Optimisers used here should originate from Optax. - - Args: - objective (Callable): The objective function that we are optimising with respect to. - parameter_state (ParameterState): The initial parameter state. - optax_optim (GradientTransformation): The Optax optimiser that is to be used for learning a parameter set. - num_iters (Optional[int]): The number of optimisation steps to run. Defaults to 100. - log_rate (Optional[int]): How frequently the objective function's value should be printed. Defaults to 10. - verbose (Optional[bool]): Whether to print the training loading bar. Defaults to True. - - Returns: - InferenceState: An InferenceState object comprising the optimised parameters and training history respectively. - """ - - params, trainables, bijectors = parameter_state.unpack() - - # Define optimisation loss function on unconstrained space, with a stop gradient rule for trainables that are set to False - def loss(params: Dict) -> Float[Array, "1"]: - params = trainable_params(params, trainables) - params = constrain(params, bijectors) - return objective(params) - - # Transform params to unconstrained space - params = unconstrain(params, bijectors) - - # Initialise optimiser state - opt_state = optax_optim.init(params) - - # Iteration loop numbers to scan over - iter_nums = jnp.arange(num_iters) - - # Optimisation step - def step(carry, iter_num: int): - params, opt_state = carry - loss_val, loss_gradient = jax.value_and_grad(loss)(params) - updates, opt_state = optax_optim.update(loss_gradient, opt_state, params) - params = ox.apply_updates(params, updates) - carry = params, opt_state - return carry, loss_val - - # Display progress bar if verbose is True - if verbose: - step = progress_bar_scan(num_iters, log_rate)(step) - - # Run the optimisation loop - (params, _), history = jax.lax.scan(step, (params, opt_state), iter_nums) - - # Transform final params to constrained space - params = constrain(params, bijectors) - - return InferenceState(params=params, history=history) - - -def fit_batches( - objective: Callable, - parameter_state: ParameterState, - train_data: Dataset, - optax_optim: ox.GradientTransformation, - key: KeyArray, - batch_size: int, - num_iters: Optional[int] = 100, - log_rate: Optional[int] = 10, - verbose: Optional[bool] = True, -) -> InferenceState: - """Abstracted method for fitting a GP model with mini-batches respect to a - supplied objective function. - Optimisers used here should originate from Optax. - - Args: - objective (Callable): The objective function that we are optimising with respect to. - parameter_state (ParameterState): The parameters for which we would like to minimise our objective function with. - train_data (Dataset): The training dataset. - optax_optim (GradientTransformation): The Optax optimiser that is to be used for learning a parameter set. - key (KeyArray): The PRNG key for the mini-batch sampling. - batch_size (int): The batch_size. - num_iters (Optional[int]): The number of optimisation steps to run. Defaults to 100. - log_rate (Optional[int]): How frequently the objective function's value should be printed. Defaults to 10. - verbose (Optional[bool]): Whether to print the training loading bar. Defaults to True. - - Returns: - InferenceState: An InferenceState object comprising the optimised parameters and training history respectively. - """ - - params, trainables, bijectors = parameter_state.unpack() - - # Define optimisation loss function on unconstrained space, with a stop gradient rule for trainables that are set to False - def loss(params: Dict, batch: Dataset) -> Float[Array, "1"]: - params = trainable_params(params, trainables) - params = constrain(params, bijectors) - return objective(params, batch) - - # Transform params to unconstrained space - params = unconstrain(params, bijectors) - - # Initialise optimiser state - opt_state = optax_optim.init(params) - - # Mini-batch random keys and iteration loop numbers to scan over - keys = jr.split(key, num_iters) - iter_nums = jnp.arange(num_iters) - - # Optimisation step - def step(carry, iter_num__and__key): - iter_num, key = iter_num__and__key - params, opt_state = carry - - batch = get_batch(train_data, batch_size, key) - - loss_val, loss_gradient = jax.value_and_grad(loss)(params, batch) - updates, opt_state = optax_optim.update(loss_gradient, opt_state, params) - params = ox.apply_updates(params, updates) - - carry = params, opt_state - return carry, loss_val - - # Display progress bar if verbose is True - if verbose: - step = progress_bar_scan(num_iters, log_rate)(step) - - # Run the optimisation loop - (params, _), history = jax.lax.scan(step, (params, opt_state), (iter_nums, keys)) - - # Transform final params to constrained space - params = constrain(params, bijectors) - - return InferenceState(params=params, history=history) - - -def get_batch(train_data: Dataset, batch_size: int, key: KeyArray) -> Dataset: - """Batch the data into mini-batches. Sampling is done with replacement. - - Args: - train_data (Dataset): The training dataset. - batch_size (int): The batch size. - - Returns: - Dataset: The batched dataset. - """ - x, y, n = train_data.X, train_data.y, train_data.n - - # Subsample data indices with replacement to get the mini-batch - indicies = jr.choice(key, n, (batch_size,), replace=True) - - return Dataset(X=x[indicies], y=y[indicies]) - - -def fit_natgrads( - stochastic_vi: StochasticVI, - parameter_state: ParameterState, - train_data: Dataset, - moment_optim: ox.GradientTransformation, - hyper_optim: ox.GradientTransformation, - key: KeyArray, - batch_size: int, - num_iters: Optional[int] = 100, - log_rate: Optional[int] = 10, - verbose: Optional[bool] = True, -) -> Dict: - """This is a training loop for natural gradients. See Salimbeni et al. - (2018) Natural Gradients in Practice: Non-Conjugate Variational Inference - in Gaussian Process Models - Each iteration comprises a hyperparameter gradient step followed by natural - gradient step to avoid a stale posterior. - - Args: - stochastic_vi (StochasticVI): The stochastic variational inference algorithm to be used for training. - parameter_state (ParameterState): The initial parameter state. - train_data (Dataset): The training dataset. - moment_optim (GradientTransformation): The Optax optimiser for the natural gradient updates on the moments. - hyper_optim (GradientTransformation): The Optax optimiser for gradient updates on the hyperparameters. - key (KeyArray): The PRNG key for the mini-batch sampling. - batch_size(int): The batch_size. - num_iters (Optional[int]): The number of optimisation steps to run. Defaults to 100. - log_rate (Optional[int]): How frequently the objective function's value should be printed. Defaults to 10. - verbose (Optional[bool]): Whether to print the training loading bar. Defaults to True. - - Returns: - InferenceState: A class comprising optimised parameters and training history. - """ - - params, trainables, bijectors = parameter_state.unpack() - - # Transform params to unconstrained space - params = unconstrain(params, bijectors) - - # Initialise optimiser states - hyper_state = hyper_optim.init(params) - moment_state = moment_optim.init(params) - - # Build natural and hyperparameter gradient functions - nat_grads_fn, hyper_grads_fn = natural_gradients( - stochastic_vi, train_data, bijectors, trainables - ) - - # Mini-batch random keys and iteration loop numbers to scan over - keys = jax.random.split(key, num_iters) - iter_nums = jnp.arange(num_iters) - - # Optimisation step - def step(carry, iter_num__and__key): - iter_num, key = iter_num__and__key - params, hyper_state, moment_state = carry - - batch = get_batch(train_data, batch_size, key) - - # Hyper-parameters update: - loss_val, loss_gradient = hyper_grads_fn(params, batch) - updates, hyper_state = hyper_optim.update(loss_gradient, hyper_state, params) - params = ox.apply_updates(params, updates) - - # Natural gradients update: - loss_val, loss_gradient = nat_grads_fn(params, batch) - updates, moment_state = moment_optim.update(loss_gradient, moment_state, params) - params = ox.apply_updates(params, updates) - - carry = params, hyper_state, moment_state - return carry, loss_val - - # Display progress bar if verbose is True - if verbose: - step = progress_bar_scan(num_iters, log_rate)(step) - - # Run the optimisation loop - (params, _, _), history = jax.lax.scan( - step, (params, hyper_state, moment_state), (iter_nums, keys) - ) - - # Transform final params to constrained space - params = constrain(params, bijectors) - - return InferenceState(params=params, history=history) - - -def progress_bar_scan(num_iters: int, log_rate: int) -> Callable: - """Progress bar for Jax.lax scans (adapted from https://www.jeremiecoullon.com/2021/01/29/jax_progress_bar/).""" - - tqdm_bars = {} - remainder = num_iters % log_rate - - def _define_tqdm(args: Any, transform: Any) -> None: - """Define a tqdm progress bar.""" - tqdm_bars[0] = tqdm(range(num_iters)) - - def _update_tqdm(args: Any, transform: Any) -> None: - """Update the tqdm progress bar with the latest objective value.""" - loss_val, arg = args - tqdm_bars[0].update(arg) - tqdm_bars[0].set_postfix({"Objective": f"{loss_val: .2f}"}) - - def _close_tqdm(args: Any, transform: Any) -> None: - """Close the tqdm progress bar.""" - tqdm_bars[0].close() - - def _callback(cond: bool, func: Callable, arg: Any) -> None: - """Callback a function for a given argument if a condition is true.""" - dummy_result = 0 - - def _do_callback(_) -> int: - """Perform the callback.""" - return host_callback.id_tap(func, arg, result=dummy_result) - - def _not_callback(_) -> int: - """Do nothing.""" - return dummy_result - - _ = lax.cond(cond, _do_callback, _not_callback, operand=None) - - def _update_progress_bar(loss_val: Float[Array, "1"], iter_num: int) -> None: - """Updates tqdm progress bar of a JAX scan or loop.""" - - # Conditions for iteration number - is_first: bool = iter_num == 0 - is_multiple: bool = (iter_num % log_rate == 0) & ( - iter_num != num_iters - remainder - ) - is_remainder: bool = iter_num == num_iters - remainder - is_last: bool = iter_num == num_iters - 1 - - # Define progress bar, if first iteration - _callback(is_first, _define_tqdm, None) - - # Update progress bar, if multiple of log_rate - _callback(is_multiple, _update_tqdm, (loss_val, log_rate)) - - # Update progress bar, if remainder - _callback(is_remainder, _update_tqdm, (loss_val, remainder)) - - # Close progress bar, if last iteration - _callback(is_last, _close_tqdm, None) - - def _progress_bar_scan(body_fun: Callable) -> Callable: - """Decorator that adds a progress bar to `body_fun` used in `lax.scan`.""" - - def wrapper_progress_bar(carry: Any, x: Union[tuple, int]) -> Any: - - # Get iteration number - if type(x) is tuple: - iter_num, *_ = x - else: - iter_num = x - - # Compute iteration step - result = body_fun(carry, x) - - # Get loss value - *_, loss_val = result - - # Update progress bar - _update_progress_bar(loss_val, iter_num) - - return result - - return wrapper_progress_bar - - return _progress_bar_scan - - -__all__ = [ - "fit", - "fit_natgrads", - "get_batch", - "natural_gradients", - "progress_bar_scan", -] diff --git a/gpjax/fit.py b/gpjax/fit.py new file mode 100644 index 000000000..9388014e9 --- /dev/null +++ b/gpjax/fit.py @@ -0,0 +1,231 @@ +# Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Optional, Tuple + +import jax +import jax.random as jr +import optax as ox + +from jax.random import KeyArray +from jax._src.random import _check_prng_key +from jaxtyping import Array, Float +from typing import Any + +from .dataset import Dataset +from .objective import Objective +from .scan import vscan + +Module = Any + + +def fit( + *, + model: Module, + objective: Objective, + train_data: Dataset, + optim: ox.GradientTransformation, + num_iters: Optional[int] = 100, + batch_size: Optional[int] = -1, + key: Optional[KeyArray] = jr.PRNGKey(42), + log_rate: Optional[int] = 10, + verbose: Optional[bool] = True, + unroll: int = 1, +) -> Tuple[Module, Array]: + """Train a Module model with respect to a supplied Objective function. Optimisers used here should originate from Optax. + + Example: + >>> import jax.numpy as jnp + >>> import jax.random as jr + >>> import optax as ox + >>> import jaxutils as ju + >>> + >>> # (1) Create a dataset: + >>> X = jnp.linspace(0.0, 10.0, 100)[:, None] + >>> y = 2.0 * X + 1.0 + 10 * jr.normal(jr.PRNGKey(0), X.shape) + >>> D = ju.Dataset(X, y) + >>> + >>> # (2) Define your model: + >>> class LinearModel(ju.Module): + ... weight: float = ju.param(ju.Identity) + ... bias: float = ju.param(ju.Identity) + ... + ... def __call__(self, x): + ... return self.weight * x + self.bias + ... + >>> model = LinearModel(weight=1.0, bias=1.0) + >>> + >>> # (3) Define your loss function: + >>> class MeanSqaureError(ju.Objective): + ... def evaluate(self, model: LinearModel, train_data: ju.Dataset) -> float: + ... return jnp.mean((train_data.y - model(train_data.X)) ** 2) + ... + >>> loss = MeanSqaureError() + >>> + >>> # (4) Train! + >>> trained_model, history = ju.fit( + ... model=model, objective=loss, train_data=D, optim=ox.sgd(0.001), num_iters=1000 + ... ) + + Args: + model (Module): The model Module to be optimised. + objective (Objective): The objective function that we are optimising with respect to. + train_data (Dataset): The training data to be used for the optimisation. + optim (GradientTransformation): The Optax optimiser that is to be used for learning a parameter set. + num_iters (Optional[int]): The number of optimisation steps to run. Defaults to 100. + batch_size (Optional[int]): The size of the mini-batch to use. Defaults to -1 (i.e. full batch). + key (Optional[KeyArray]): The random key to use for the optimisation batch selection. Defaults to jr.PRNGKey(42). + log_rate (Optional[int]): How frequently the objective function's value should be printed. Defaults to 10. + verbose (Optional[bool]): Whether to print the training loading bar. Defaults to True. + unroll (int): The number of unrolled steps to use for the optimisation. Defaults to 1. + + Returns: + Tuple[Module, Array]: A Tuple comprising the optimised model and training history respectively. + """ + + # Check inputs. + _check_model(model) + _check_objective(objective) + _check_train_data(train_data) + _check_optim(optim) + _check_num_iters(num_iters) + _check_batch_size(batch_size) + _check_prng_key(key) + _check_log_rate(log_rate) + _check_verbose(verbose) + + # Unconstrained space loss function with stop-gradient rule for non-trainable params. + def loss(model: Module, batch: Dataset) -> Float[Array, "1"]: + model = model.stop_gradients() + return objective(model.constrain(), batch) + + # Unconstrained space model. + model = model.unconstrain() + + # Initialise optimiser state. + state = optim.init(model) + + # Mini-batch random keys to scan over. + iter_keys = jr.split(key, num_iters) + + # Optimisation step. + def step(carry, key): + model, opt_state = carry + + if batch_size != -1: + batch = get_batch(train_data, batch_size, key) + else: + batch = train_data + + loss_val, loss_gradient = jax.value_and_grad(loss)(model, batch) + updates, opt_state = optim.update(loss_gradient, opt_state, model) + model = ox.apply_updates(model, updates) + + carry = model, opt_state + return carry, loss_val + + # Optimisation scan. + scan = vscan if verbose else jax.lax.scan + + # Optimisation loop. + (model, _), history = scan(step, (model, state), (iter_keys), unroll=unroll) + + # Constrained space. + model = model.constrain() + + return model, history + + +def get_batch(train_data: Dataset, batch_size: int, key: KeyArray) -> Dataset: + """Batch the data into mini-batches. Sampling is done with replacement. + + Args: + train_data (Dataset): The training dataset. + batch_size (int): The batch size. + key (KeyArray): The random key to use for the batch selection. + + Returns: + Dataset: The batched dataset. + """ + x, y, n = train_data.X, train_data.y, train_data.n + + # Subsample mini-batch indicies with replacement. + indicies = jr.choice(key, n, (batch_size,), replace=True) + + return Dataset(X=x[indicies], y=y[indicies]) + + +def _check_model(model: Any) -> None: + """Check that the model is of type Module. Check trainables and bijectors tree structure.""" + if not isinstance(model, Module): + raise TypeError("model must be of type jaxutils.Module") + + +def _check_objective(objective: Any) -> None: + """Check that the objective is of type Objective.""" + if not isinstance(objective, Objective): + raise TypeError("objective must be of type jaxutils.Objective") + + +def _check_train_data(train_data: Any) -> None: + """Check that the train_data is of type Dataset.""" + if not isinstance(train_data, Dataset): + raise TypeError("train_data must be of type jaxutils.Dataset") + + +def _check_optim(optim: Any) -> None: + """Check that the optimiser is of type GradientTransformation.""" + if not isinstance(optim, ox.GradientTransformation): + raise TypeError("optax_optim must be of type optax.GradientTransformation") + + +def _check_num_iters(num_iters: Any) -> None: + """Check that the number of iterations is of type int and positive.""" + if not isinstance(num_iters, int): + raise TypeError("num_iters must be of type int") + + if not num_iters > 0: + raise ValueError("num_iters must be positive") + + +def _check_log_rate(log_rate: Any) -> None: + """Check that the log rate is of type int and positive.""" + if not isinstance(log_rate, int): + raise TypeError("log_rate must be of type int") + + if not log_rate > 0: + raise ValueError("log_rate must be positive") + + +def _check_verbose(verbose: Any) -> None: + """Check that the verbose is of type bool.""" + if not isinstance(verbose, bool): + raise TypeError("verbose must be of type bool") + + +def _check_batch_size(batch_size: Any) -> None: + """Check that the batch size is of type int and positive if not minus 1.""" + if not isinstance(batch_size, int): + raise TypeError("batch_size must be of type int") + + if not batch_size == -1: + if not batch_size > 0: + raise ValueError("batch_size must be positive") + + +__all__ = [ + "fit", + "get_batch", +] \ No newline at end of file diff --git a/gpjax/progress_bar.py b/gpjax/progress_bar.py new file mode 100644 index 000000000..68d00e8c7 --- /dev/null +++ b/gpjax/progress_bar.py @@ -0,0 +1,126 @@ +# Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Callable, Any, Union + +from jax import lax +from jax.experimental import host_callback +from jaxtyping import Array, Float +from tqdm.auto import tqdm + + +def progress_bar(num_iters: int, log_rate: int) -> Callable: + """Progress bar decorator for the body function of a `jax.lax.scan`. + + !!! example + ```python + + carry = jnp.array(0.0) + iteration_numbers = jnp.arange(100) + + @progress_bar(num_iters=x.shape[0], log_rate=10) + def body_func(carry, x): + return carry, x + + carry, _ = jax.lax.scan(body_func, carry, iteration_numbers) + ``` + + Adapted from the excellent blog post: https://www.jeremiecoullon.com/2021/01/29/jax_progress_bar/. + + Might be nice in future to directly create a general purpose `verbose scan` inplace of a for a jax.lax.scan, + that takes the same arguments as a jax.lax.scan, but prints a progress bar. + """ + + tqdm_bars = {} + remainder = num_iters % log_rate + + """Define a tqdm progress bar.""" + tqdm_bars[0] = tqdm(range(num_iters)) + tqdm_bars[0].set_description("Compiling...", refresh=True) + + def _update_tqdm(args: Any, transform: Any) -> None: + """Update the tqdm progress bar with the latest objective value.""" + value, iter_num = args + tqdm_bars[0].set_description(f"Running", refresh=False) + tqdm_bars[0].update(iter_num) + tqdm_bars[0].set_postfix({"Value": f"{value: .2f}"}) + + def _close_tqdm(args: Any, transform: Any) -> None: + """Close the tqdm progress bar.""" + tqdm_bars[0].close() + + def _callback(cond: bool, func: Callable, arg: Any) -> None: + """Callback a function for a given argument if a condition is true.""" + dummy_result = 0 + + def _do_callback(_) -> int: + """Perform the callback.""" + return host_callback.id_tap(func, arg, result=dummy_result) + + def _not_callback(_) -> int: + """Do nothing.""" + return dummy_result + + _ = lax.cond(cond, _do_callback, _not_callback, operand=None) + + def _update_progress_bar(value: Float[Array, "1"], iter_num: int) -> None: + """Update the tqdm progress bar.""" + + # Conditions for iteration number + is_multiple: bool = (iter_num % log_rate == 0) & ( + iter_num != num_iters - remainder + ) + is_remainder: bool = iter_num == num_iters - remainder + is_last: bool = iter_num == num_iters - 1 + + # Update progress bar, if multiple of log_rate + _callback(is_multiple, _update_tqdm, (value, log_rate)) + + # Update progress bar, if remainder + _callback(is_remainder, _update_tqdm, (value, remainder)) + + # Close progress bar, if last iteration + _callback(is_last, _close_tqdm, None) + + def _progress_bar(body_fun: Callable) -> Callable: + """Decorator that adds a progress bar to `body_fun` used in `jax.lax.scan`.""" + + def wrapper_progress_bar(carry: Any, x: Union[tuple, int]) -> Any: + + # Get iteration number + if type(x) is tuple: + iter_num, *_ = x + else: + iter_num = x + + # Compute iteration step + result = body_fun(carry, x) + + # Get value + *_, value = result + + # Update progress bar + _update_progress_bar(value, iter_num) + + return result + + return wrapper_progress_bar + + return _progress_bar + + +__all__ = [ + "progress_bar", +] \ No newline at end of file diff --git a/tests/test_fit.py b/tests/test_fit.py new file mode 100644 index 000000000..ecf989f65 --- /dev/null +++ b/tests/test_fit.py @@ -0,0 +1,57 @@ +# Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from jaxutils.dataset import Dataset +from jaxutils.fit import fit +from jaxutils.bijectors import Identity +from jaxutils.module import param, Module +from jaxutils.objective import Objective + +import jax.numpy as jnp +import jax.random as jr +import optax as ox + + +def test_simple_linear_model(): + # (1) Create a dataset: + X = jnp.linspace(0.0, 10.0, 100).reshape(-1, 1) + y = 2.0 * X + 1.0 + 10 * jr.normal(jr.PRNGKey(0), X.shape).reshape(-1, 1) + D = Dataset(X, y) + + # (2) Define your model: + class LinearModel(Module): + weight: float = param(Identity) + bias: float = param(Identity) + + def __call__(self, x): + return self.weight * x + self.bias + + model = LinearModel(weight=1.0, bias=1.0) + + # (3) Define your loss function: + class MeanSqaureError(Objective): + def evaluate(self, model: LinearModel, train_data: Dataset) -> float: + return jnp.mean((train_data.y - model(train_data.X)) ** 2) + + loss = MeanSqaureError() + + # (4) Train! + trained_model, hist = fit( + model=model, objective=loss, train_data=D, optim=ox.sgd(0.001), num_iters=100 + ) + + assert len(hist) == 100 + assert isinstance(trained_model, LinearModel) + assert loss(trained_model, D) < loss(model, D) From e07ffe5d7817284a0b1542c4869800d7d299cd22 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Thu, 30 Mar 2023 19:35:09 +0100 Subject: [PATCH 20/44] Remove types and add dataset. --- gpjax/dataset.py | 101 ++++++++++++++++++++++++++++++++++++++++++ gpjax/types.py | 32 ------------- tests/test_dataset.py | 90 +++++++++++++++++++++++++++++++++++++ 3 files changed, 191 insertions(+), 32 deletions(-) create mode 100644 gpjax/dataset.py delete mode 100644 gpjax/types.py create mode 100644 tests/test_dataset.py diff --git a/gpjax/dataset.py b/gpjax/dataset.py new file mode 100644 index 000000000..b5b9e9052 --- /dev/null +++ b/gpjax/dataset.py @@ -0,0 +1,101 @@ +# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import annotations + +import jax.numpy as jnp +from jaxtyping import Array, Float +from typing import Optional +from simple_pytree import Pytree +from dataclasses import dataclass + + +@dataclass +class Dataset(Pytree): + """Base class for datasets. + + Attributes: + X (Optional[Float[Array, "N D"]]): Input data. + y (Optional[Float[Array, "N Q"]]): Output data. + """ + + X: Optional[Float[Array, "N D"]] = None + y: Optional[Float[Array, "N Q"]] = None + + def __post_init__(self) -> None: + """Checks that the shapes of X and y are compatible.""" + _check_shape(self.X, self.y) + + def __repr__(self) -> str: + """Returns a string representation of the dataset.""" + repr = ( + f"- Number of observations: {self.n}\n- Input dimension:" + f" {self.in_dim}\n- Output dimension: {self.out_dim}" + ) + return repr + + def is_supervised(self) -> bool: + """Returns `True` if the dataset is supervised.""" + return self.X is not None and self.y is not None + + def is_unsupervised(self) -> bool: + """Returns `True` if the dataset is unsupervised.""" + return self.X is None and self.y is not None + + def __add__(self, other: Dataset) -> Dataset: + """Combine two datasets. Right hand dataset is stacked beneath the left.""" + X = jnp.concatenate((self.X, other.X)) + y = jnp.concatenate((self.y, other.y)) + + return Dataset(X=X, y=y) + + @property + def n(self) -> int: + """Number of observations.""" + return self.X.shape[0] + + @property + def in_dim(self) -> int: + """Dimension of the inputs, X.""" + return self.X.shape[1] + + @property + def out_dim(self) -> int: + """Dimension of the outputs, y.""" + return self.y.shape[1] + + +def _check_shape(X: Float[Array, "N D"], y: Float[Array, "N Q"]) -> None: + """Checks that the shapes of X and y are compatible.""" + if X is not None and y is not None: + if X.shape[0] != y.shape[0]: + raise ValueError( + "Inputs, X, and outputs, y, must have the same number of rows." + f" Got X.shape={X.shape} and y.shape={y.shape}." + ) + + if X is not None and X.ndim != 2: + raise ValueError( + f"Inputs, X, must be a 2-dimensional array. Got X.ndim={X.ndim}." + ) + + if y is not None and y.ndim != 2: + raise ValueError( + f"Outputs, y, must be a 2-dimensional array. Got y.ndim={y.ndim}." + ) + + +__all__ = [ + "Dataset", +] diff --git a/gpjax/types.py b/gpjax/types.py deleted file mode 100644 index dfa9abb7b..000000000 --- a/gpjax/types.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright 2022 The GPJax Contributors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import deprecation -import jaxutils - -Dataset = deprecation.deprecated( - deprecated_in="0.5.5", - removed_in="0.6.0", - details="Use JaxUtils for a Dataset object", -)(jaxutils.Dataset) - -verify_dataset = deprecation.deprecated( - deprecated_in="0.5.5", - removed_in="0.6.0", - details="Use JaxUtils for a Dataset object", -)(jaxutils.verify_dataset) - - -__all__ = ["Dataset" "verify_dataset"] diff --git a/tests/test_dataset.py b/tests/test_dataset.py new file mode 100644 index 000000000..58a768825 --- /dev/null +++ b/tests/test_dataset.py @@ -0,0 +1,90 @@ +# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import jax.numpy as jnp +import pytest +from jaxutils.dataset import Dataset + + +@pytest.mark.parametrize("n", [1, 10]) +@pytest.mark.parametrize("outd", [1, 2, 10]) +@pytest.mark.parametrize("ind", [1, 2, 10]) +@pytest.mark.parametrize("n2", [1, 10]) +def test_dataset(n: int, outd: int, ind: int, n2: int) -> None: + x = jnp.ones((n, ind)) + y = jnp.ones((n, outd)) + d = Dataset(X=x, y=y) + + assert d.n == n + assert d.in_dim == ind + assert d.out_dim == outd + assert ( + d.__repr__() + == f"- Number of observations: {n}\n- Input dimension: {ind}\n- Output" + f" dimension: {outd}" + ) + + # Test combine datasets. + x2 = 2 * jnp.ones((n2, ind)) + y2 = 2 * jnp.ones((n2, outd)) + d2 = Dataset(X=x2, y=y2) + + d_combined = d + d2 + assert d_combined.n == n + n2 + assert d_combined.in_dim == ind + assert d_combined.out_dim == outd + assert (d_combined.y[:n] == 1.0).all() + assert (d_combined.y[n:] == 2.0).all() + assert (d_combined.X[:n] == 1.0).all() + assert (d_combined.X[n:] == 2.0).all() + + # Test supervised and unsupervised. + assert d.is_supervised() is True + dunsup = Dataset(y=y) + assert dunsup.is_unsupervised() is True + + +@pytest.mark.parametrize("nx, ny", [(1, 2), (2, 1), (10, 5), (5, 10)]) +@pytest.mark.parametrize("outd", [1, 2, 10]) +@pytest.mark.parametrize("ind", [1, 2, 10]) +def test_dataset_assertions(nx: int, ny: int, outd: int, ind: int) -> None: + x = jnp.ones((nx, ind)) + y = jnp.ones((ny, outd)) + + with pytest.raises(ValueError): + Dataset(X=x, y=y) + + +@pytest.mark.parametrize("n", [1, 2, 10]) +@pytest.mark.parametrize("outd", [1, 2, 10]) +@pytest.mark.parametrize("ind", [1, 2, 10]) +def test_2d_inputs(n: int, outd: int, ind: int) -> None: + x = jnp.ones((n, ind)) + y = jnp.ones((n,)) + + with pytest.raises(ValueError): + Dataset(X=x, y=y) + + x = jnp.ones((n,)) + y = jnp.ones((n, outd)) + + with pytest.raises(ValueError): + Dataset(X=x, y=y) + + +def test_y_none() -> None: + x = jnp.ones((10, 1)) + d = Dataset(X=x) + assert d.y is None From 57daf099c19ec071f170b3f47233c11b3091be1a Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Sun, 2 Apr 2023 21:58:05 +0100 Subject: [PATCH 21/44] Commit. --- gpjax/__init__.py | 12 +- gpjax/config.py | 45 -- gpjax/fit.py | 26 +- gpjax/gps.py | 602 +++++-------------------- gpjax/kernels/computations/dense.py | 1 - gpjax/likelihoods.py | 1 - gpjax/module/__init__.py | 4 + gpjax/{parameters => module}/module.py | 0 gpjax/{parameters => module}/param.py | 0 gpjax/natural_gradients.py | 302 ------------- gpjax/objectives.py | 364 +++++++++++++++ gpjax/parameters/__init__.py | 5 - gpjax/parameters/bijectors.py | 34 -- gpjax/params.py | 420 ----------------- gpjax/scan.py | 164 +++++++ gpjax/utils.py | 37 -- gpjax/variational_families.py | 104 ++--- gpjax/variational_inference.py | 283 ------------ tests/test_dataset.py | 2 +- tests/test_fit.py | 22 +- tests/test_gps.py | 181 +++----- tests/test_natural_gradients.py | 263 ----------- tests/test_objectives.py | 178 ++++++++ tests/test_variational_families.py | 22 +- 24 files changed, 967 insertions(+), 2105 deletions(-) delete mode 100644 gpjax/config.py create mode 100644 gpjax/module/__init__.py rename gpjax/{parameters => module}/module.py (100%) rename gpjax/{parameters => module}/param.py (100%) delete mode 100644 gpjax/natural_gradients.py create mode 100644 gpjax/objectives.py delete mode 100644 gpjax/parameters/__init__.py delete mode 100644 gpjax/parameters/bijectors.py delete mode 100644 gpjax/params.py create mode 100644 gpjax/scan.py delete mode 100644 gpjax/utils.py delete mode 100644 gpjax/variational_inference.py delete mode 100644 tests/test_natural_gradients.py create mode 100644 tests/test_objectives.py diff --git a/gpjax/__init__.py b/gpjax/__init__.py index dd47bb6a3..fb3eb83ad 100644 --- a/gpjax/__init__.py +++ b/gpjax/__init__.py @@ -14,13 +14,12 @@ # ============================================================================== from . import _version -from .abstractions import fit, fit_batches, fit_natgrads +from .fit import fit from .gps import Prior, construct_posterior from .kernels import * from .likelihoods import Bernoulli, Gaussian from .mean_functions import Constant, Zero -from .params import constrain, copy_dict_structure, initialise, unconstrain -from .types import Dataset +from .dataset import Dataset from .variational_families import ( CollapsedVariationalGaussian, ExpectationVariationalGaussian, @@ -28,7 +27,6 @@ VariationalGaussian, WhitenedVariationalGaussian, ) -from .variational_inference import CollapsedVI, StochasticVI __version__ = _version.get_versions()["version"] __license__ = "MIT" @@ -40,8 +38,6 @@ __all__ = [ "kernels", "fit", - "fit_batches", - "fit_natgrads", "Prior", "construct_posterior", "RBF", @@ -56,10 +52,6 @@ "Gaussian", "Constant", "Zero", - "constrain", - "copy_dict_structure", - "initialise", - "unconstrain", "Dataset", "CollapsedVariationalGaussian", "ExpectationVariationalGaussian", diff --git a/gpjax/config.py b/gpjax/config.py deleted file mode 100644 index 26c592db5..000000000 --- a/gpjax/config.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2022 The GPJax Contributors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - - -import deprecation - -depreciate = deprecation.deprecated( - deprecated_in="0.5.6", - removed_in="0.6.0", - details="Use method from jaxutils.config instead.", -) - -from jaxutils import config - -Identity = config.Identity -Softplus = config.Softplus -reset_global_config = depreciate(config.reset_global_config) -get_global_config = depreciate(config.get_global_config) -get_default_config = depreciate(config.get_default_config) -update_x64_sensitive_settings = depreciate(config.update_x64_sensitive_settings) -get_global_config_if_exists = depreciate(config.get_global_config_if_exists) -add_parameter = depreciate(config.add_parameter) - -__all__ = [ - "Identity", - "Softplus", - "reset_global_config", - "get_global_config", - "get_default_config", - "update_x64_sensitive_settings", - "get_global_config_if_exists", - "set_global_config", -] diff --git a/gpjax/fit.py b/gpjax/fit.py index 9388014e9..757f6b463 100644 --- a/gpjax/fit.py +++ b/gpjax/fit.py @@ -25,7 +25,7 @@ from typing import Any from .dataset import Dataset -from .objective import Objective +from .objectives import AbstractObjective from .scan import vscan Module = Any @@ -34,7 +34,7 @@ def fit( *, model: Module, - objective: Objective, + objective: AbstractObjective, train_data: Dataset, optim: ox.GradientTransformation, num_iters: Optional[int] = 100, @@ -96,19 +96,19 @@ def fit( """ # Check inputs. - _check_model(model) - _check_objective(objective) - _check_train_data(train_data) - _check_optim(optim) - _check_num_iters(num_iters) - _check_batch_size(batch_size) - _check_prng_key(key) - _check_log_rate(log_rate) - _check_verbose(verbose) + # _check_model(model) + # _check_objective(objective) + # _check_train_data(train_data) + # _check_optim(optim) + # _check_num_iters(num_iters) + # _check_batch_size(batch_size) + # _check_prng_key(key) + # _check_log_rate(log_rate) + # _check_verbose(verbose) # Unconstrained space loss function with stop-gradient rule for non-trainable params. def loss(model: Module, batch: Dataset) -> Float[Array, "1"]: - model = model.stop_gradients() + model = model.stop_gradient() return objective(model.constrain(), batch) # Unconstrained space model. @@ -175,7 +175,7 @@ def _check_model(model: Any) -> None: def _check_objective(objective: Any) -> None: """Check that the objective is of type Objective.""" - if not isinstance(objective, Objective): + if not isinstance(objective, AbstractObjective): raise TypeError("objective must be of type jaxutils.Objective") diff --git a/gpjax/gps.py b/gpjax/gps.py index cf2b0b9a9..22baa8b8b 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -14,36 +14,31 @@ # ============================================================================== from abc import abstractmethod -from typing import Any, Callable, Dict, Optional +from typing import Any -import deprecation -import distrax as dx import jax.numpy as jnp -from jax.random import KeyArray -from jaxtyping import Array, Float -from jaxutils import Dataset, PyTree - -from .config import get_global_config -from .kernels import AbstractKernel -from .likelihoods import AbstractLikelihood -from .mean_functions import AbstractMeanFunction, Zero -from jaxutils import Dataset -from .utils import concat_dictionaries -from .gaussian_distribution import GaussianDistribution -from .kernels import AbstractKernel from .kernels.base import AbstractKernel -from .likelihoods import AbstractLikelihood, Conjugate, NonConjugate -from .linops import identity -from .mean_functions import AbstractMeanFunction, Zero -from .utils import concat_dictionaries - - -class AbstractPrior(PyTree): - """Abstract Gaussian process prior. +from jaxtyping import Array, Float +from jax.random import KeyArray, PRNGKey, normal - All Gaussian processes priors should inherit from this class.""" +from .dataset import Dataset +from .linops import identity - def __call__(self, *args: Any, **kwargs: Any) -> dx.Distribution: +from .gaussian_distribution import GaussianDistribution +from .likelihoods import AbstractLikelihood, Gaussian +from .mean_functions import AbstractMeanFunction +from mytree import Mytree +from simple_pytree import static_field +from dataclasses import dataclass + +@dataclass +class AbstractPrior(Mytree): + """Abstract Gaussian process prior.""" + kernel: AbstractKernel + mean_function: AbstractMeanFunction + jitter: float = static_field(1e-6) + + def __call__(self, *args: Any, **kwargs: Any) -> GaussianDistribution: """Evaluate the Gaussian process at the given points. The output of this function is a `Distrax distribution `_ from which the the latent function's mean and covariance can be evaluated and the distribution @@ -58,12 +53,12 @@ def __call__(self, *args: Any, **kwargs: Any) -> dx.Distribution: **kwargs (Any): The keyword arguments to pass to the GP's `predict` method. Returns: - dx.Distribution: A multivariate normal random variable representation of the Gaussian process. + GaussianDistribution: A multivariate normal random variable representation of the Gaussian process. """ return self.predict(*args, **kwargs) @abstractmethod - def predict(self, *args: Any, **kwargs: Any) -> dx.Distribution: + def predict(self, *args: Any, **kwargs: Any) -> GaussianDistribution: """Compute the latent function's multivariate normal distribution for a given set of parameters. For any class inheriting the ``AbstractPrior`` class, this method must be implemented. @@ -73,41 +68,15 @@ def predict(self, *args: Any, **kwargs: Any) -> dx.Distribution: **kwargs (Any): Keyword arguments to the predict method. Returns: - dx.Distribution: A multivariate normal random variable representation of the Gaussian process. - """ - raise NotImplementedError - - @abstractmethod - def init_params(self, key: KeyArray) -> Dict: - """An initialisation method for the GP's parameters. This method should - be implemented for all classes that inherit the ``AbstractPrior`` class. - Whilst not always necessary, the method accepts a PRNG key to allow - for stochastic initialisation. The method should is most often invoked - through the ``initialise`` function given in GPJax. - - Args: - key (KeyArray): The PRNG key. - - Returns: - Dict: The initialised parameter set. + GaussianDistribution: A multivariate normal random variable representation of the Gaussian process. """ raise NotImplementedError - @deprecation.deprecated( - deprecated_in="0.5.7", - removed_in="0.6.0", - details="Use the ``init_params`` method for parameter initialisation.", - ) - def _initialise_params(self, key: KeyArray) -> Dict: - """Deprecated method for initialising the GP's parameters. Succeeded by ``init_params``.""" - return self.init_params(key) - ####################### # GP Priors ####################### - - +@dataclass class Prior(AbstractPrior): """A Gaussian process prior object. The GP is parameterised by a `mean `_ @@ -131,24 +100,6 @@ class Prior(AbstractPrior): >>> prior = gpx.Prior(kernel = kernel) """ - def __init__( - self, - kernel: AbstractKernel, - mean_function: Optional[AbstractMeanFunction] = Zero(), - name: Optional[str] = "GP prior", - ) -> None: - """Initialise the GP prior. - - Args: - kernel (AbstractKernel): The kernel function used to parameterise the prior. - mean_function (Optional[MeanFunction]): The mean function used to parameterise the - prior. Defaults to zero. - name (Optional[str]): The name of the GP prior. Defaults to "GP prior". - """ - self.kernel = kernel - self.mean_function = mean_function - self.name = name - def __mul__(self, other: AbstractLikelihood): """The product of a prior and likelihood is proportional to the posterior distribution. By computing the product of a GP prior and a @@ -196,9 +147,7 @@ def __rmul__(self, other: AbstractLikelihood): """ return self.__mul__(other) - def predict( - self, params: Dict - ) -> Callable[[Float[Array, "N D"]], GaussianDistribution]: + def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution: """Compute the predictive prior distribution for a given set of parameters. The output of this function is a function that computes a Distrax distribution for a given set of inputs. @@ -218,105 +167,67 @@ def predict( >>> prior_predictive(jnp.linspace(0, 1, 100)) Args: - params (Dict): The specific set of parameters for which the mean - function should be defined for. + test_inputs (Float[Array, "N D"]): The inputs at which to evaluate the prior distribution. Returns: - Callable[[Float[Array, "N D"]], GaussianDistribution]: A mean + GaussianDistribution: A mean function that accepts an input array for where the mean function should be evaluated at. The mean function's value at these points is then returned. """ - jitter = get_global_config()["jitter"] - - # Unpack mean function and kernel - mean_function = self.mean_function - kernel = self.kernel + x = test_inputs + mx = self.mean_function(x) + Kxx = self.kernel.gram(x) + Kxx += identity(x.shape[0]) * self.jitter - def predict_fn(test_inputs: Float[Array, "N D"]) -> GaussianDistribution: - - # Unpack test inputs - t = test_inputs - n_test = test_inputs.shape[0] - - μt = mean_function(params["mean_function"], t) - Ktt = kernel.gram(params["kernel"], t) - Ktt += identity(n_test) * jitter - - return GaussianDistribution(jnp.atleast_1d(μt.squeeze()), Ktt) - - return predict_fn - - def init_params(self, key: KeyArray) -> Dict: - """Initialise the GP prior's parameter set. - - Args: - key (KeyArray): The PRNG key. - - Returns: - Dict: The initialised parameter set. - """ - return { - "kernel": self.kernel.init_params(key), - "mean_function": self.mean_function.init_params(key), - } + return GaussianDistribution(jnp.atleast_1d(mx.squeeze()), Kxx) ####################### # GP Posteriors ####################### -class AbstractPosterior(AbstractPrior): +@dataclass +class AbstractPosterior(Mytree): """The base GP posterior object conditioned on an observed dataset. All posterior objects should inherit from this class.""" + prior: AbstractPrior + likelihood: AbstractLikelihood - def __init__( - self, - prior: AbstractPrior, - likelihood: AbstractLikelihood, - name: Optional[str] = "GP posterior", - ) -> None: - """Initialise the GP posterior object. - - Args: - prior (Prior): The prior distribution of the GP. - likelihood (AbstractLikelihood): The likelihood distribution of the observed dataset. - name (Optional[str]): The name of the GP posterior. Defaults to "GP posterior". - """ - self.prior = prior - self.likelihood = likelihood - self.name = name + def __call__(self, *args: Any, **kwargs: Any) -> GaussianDistribution: + """Evaluate the Gaussian process at the given points. The output of this function + is a `Distrax distribution `_ from which the + the latent function's mean and covariance can be evaluated and the distribution + can be sampled. - @abstractmethod - def predict(self, *args: Any, **kwargs: Any) -> GaussianDistribution: - """Compute the predictive posterior distribution of the latent function - for a given set of parameters. For any class inheriting the - ``AbstractPosterior`` class, this method must be implemented. + Under the hood, ``__call__`` is calling the objects ``predict`` method. For this + reasons, classes inheriting the ``AbstractPrior`` class, should not overwrite the + ``__call__`` method and should instead define a ``predict`` method. Args: - *args (Any): Arguments to the predict method. **kwargs (Any): - Keyword arguments to the predict method. + *args (Any): The arguments to pass to the GP's `predict` method. + **kwargs (Any): The keyword arguments to pass to the GP's `predict` method. Returns: - GaussianDistribution: A multivariate normal random variable - representation of the Gaussian process. + GaussianDistribution: A multivariate normal random variable representation of the Gaussian process. """ - raise NotImplementedError + return self.predict(*args, **kwargs) - def init_params(self, key: KeyArray) -> Dict: - """Initialise the parameter set of a GP posterior. + @abstractmethod + def predict(self, *args: Any, **kwargs: Any) -> GaussianDistribution: + """Compute the latent function's multivariate normal distribution for a + given set of parameters. For any class inheriting the ``AbstractPrior`` class, + this method must be implemented. Args: - key (KeyArray): The PRNG key. + *args (Any): Arguments to the predict method. + **kwargs (Any): Keyword arguments to the predict method. Returns: - Dict: The initialised parameter set. + GaussianDistribution: A multivariate normal random variable representation of the Gaussian process. """ - return concat_dictionaries( - self.prior.init_params(key), - {"likelihood": self.likelihood.init_params(key)}, - ) - + raise NotImplementedError +@dataclass class ConjugatePosterior(AbstractPosterior): """A Gaussian process posterior distribution when the constituent likelihood function is a Gaussian distribution. In such cases, the latent function values @@ -349,29 +260,11 @@ class ConjugatePosterior(AbstractPosterior): >>> >>> posterior = prior * likelihood """ - - def __init__( - self, - prior: AbstractPrior, - likelihood: AbstractLikelihood, - name: Optional[str] = "GP posterior", - ) -> None: - """Initialise the conjugate GP posterior object. - - Args: - prior (Prior): The prior distribution of the GP. - likelihood (AbstractLikelihood): The likelihood distribution of the observed dataset. - name (Optional[str]): The name of the GP posterior. Defaults to "GP posterior". - """ - self.prior = prior - self.likelihood = likelihood - self.name = name - def predict( self, - params: Dict, + test_inputs: Float[Array, "N D"], train_data: Dataset, - ) -> Callable[[Float[Array, "N D"]], GaussianDistribution]: + ) -> GaussianDistribution: """Conditional on a training data set, compute the GP's posterior predictive distribution for a given set of parameters. The returned function can be evaluated at a set of test inputs to compute the @@ -413,177 +306,44 @@ def predict( input and output data used for training dataset. Returns: - Callable[[Float[Array, "N D"]], GaussianDistribution]: A + GaussianDistribution: A function that accepts an input array and returns the predictive distribution as a ``GaussianDistribution``. """ - jitter = get_global_config()["jitter"] - # Unpack training data x, y, n = train_data.X, train_data.y, train_data.n - # Unpack mean function and kernel - mean_function = self.prior.mean_function - kernel = self.prior.kernel + # Unpack test inputs + t, n_test = test_inputs, test_inputs.shape[0] # Observation noise σ² - obs_noise = params["likelihood"]["obs_noise"] - μx = mean_function(params["mean_function"], x) + obs_noise = self.likelihood.obs_noise + mx = self.prior.mean_function(x) # Precompute Gram matrix, Kxx, at training inputs, x - Kxx = kernel.gram(params["kernel"], x) - Kxx += identity(n) * jitter + Kxx = self.prior.kernel.gram(x) + (identity(n) * self.prior.jitter) # Σ = Kxx + Iσ² Sigma = Kxx + identity(n) * obs_noise + + μt = self.prior.mean_function(t) + Ktt = self.prior.kernel.gram(t) + Kxt = self.prior.kernel.cross_covariance(x, t) - def predict(test_inputs: Float[Array, "N D"]) -> GaussianDistribution: - """Compute the predictive distribution at a set of test inputs. - - Args: - test_inputs (Float[Array, "N D"]): A Jax array of test inputs. - - Returns: - A ``GaussianDistribution`` object that represents the - predictive distribution. - """ + # Σ⁻¹ Kxt + Sigma_inv_Kxt = Sigma.solve(Kxt) - # Unpack test inputs - t = test_inputs - n_test = test_inputs.shape[0] + # μt + Ktx (Kxx + Iσ²)⁻¹ (y - μx) + mean = μt + jnp.matmul(Sigma_inv_Kxt.T, y - mx) - μt = mean_function(params["mean_function"], t) - Ktt = kernel.gram(params["kernel"], t) - Kxt = kernel.cross_covariance(params["kernel"], x, t) + # Ktt - Ktx (Kxx + Iσ²)⁻¹ Kxt, TODO: Take advantage of covariance structure to compute Schur complement more efficiently. + covariance = Ktt - jnp.matmul(Kxt.T, Sigma_inv_Kxt) + covariance += identity(n_test) * self.prior.jitter - # Σ⁻¹ Kxt - Sigma_inv_Kxt = Sigma.solve(Kxt) - - # μt + Ktx (Kxx + Iσ²)⁻¹ (y - μx) - mean = μt + jnp.matmul(Sigma_inv_Kxt.T, y - μx) - - # Ktt - Ktx (Kxx + Iσ²)⁻¹ Kxt, TODO: Take advantage of covariance structure to compute Schur complement more efficiently. - covariance = Ktt - jnp.matmul(Kxt.T, Sigma_inv_Kxt) - covariance += identity(n_test) * jitter - - return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance) - - return predict - - def marginal_log_likelihood( - self, - train_data: Dataset, - negative: bool = False, - ) -> Callable[[Dict], Float[Array, "1"]]: - """Compute the marginal log-likelihood function of the Gaussian process. - The returned function can then be used for gradient based optimisation - of the model's parameters or for model comparison. The implementation - given here enables exact estimation of the Gaussian process' latent - function values. - - For a training dataset :math:`\\{x_n, y_n\\}_{n=1}^N`, set of test - inputs :math:`\\mathbf{x}^{\\star}` the corresponding latent function - evaluations are given by :math:`\\mathbf{f} = f(\\mathbf{x})` - and :math:`\\mathbf{f}^{\\star} = f(\\mathbf{x}^{\\star})`, the marginal - log-likelihood is given by: - - .. math:: - - \\log p(\\mathbf{y}) & = \\int p(\\mathbf{y}\\mid\\mathbf{f})p(\\mathbf{f}, \\mathbf{f}^{\\star}\\mathrm{d}\\mathbf{f}^{\\star}\\\\ - &=0.5\\left(-\\mathbf{y}^{\\top}\\left(k(\\mathbf{x}, \\mathbf{x}') +\\sigma^2\\mathbf{I}_N \\right)^{-1}\\mathbf{y}-\\log\\lvert k(\\mathbf{x}, \\mathbf{x}') + \\sigma^2\\mathbf{I}_N\\rvert - n\\log 2\\pi \\right). - - Example: - - For a given ``ConjugatePosterior`` object, the following code snippet shows - how the marginal log-likelihood can be evaluated. - - >>> import gpjax as gpx - >>> - >>> xtrain = jnp.linspace(0, 1).reshape(-1, 1) - >>> ytrain = jnp.sin(xtrain) - >>> D = gpx.Dataset(X=xtrain, y=ytrain) - >>> - >>> params = gpx.initialise(posterior) - >>> mll = posterior.marginal_log_likelihood(train_data = D) - >>> mll(params) - - Our goal is to maximise the marginal log-likelihood. Therefore, when - optimising the model's parameters with respect to the parameters, we - use the negative marginal log-likelihood. This can be realised through - - >>> mll = posterior.marginal_log_likelihood(train_data = D, negative=True) - - Further, prior distributions can be passed into the marginal log-likelihood - - >>> mll = posterior.marginal_log_likelihood(train_data = D) - - For optimal performance, the marginal log-likelihood should be ``jax.jit`` - compiled. - - >>> mll = jit(posterior.marginal_log_likelihood(train_data = D)) - - Args: - train_data (Dataset): The training dataset used to compute the - marginal log-likelihood. - negative (Optional[bool]): Whether or not the returned function - should be negative. For optimisation, the negative is useful - as minimisation of the negative marginal log-likelihood is - equivalent to maximisation of the marginal log-likelihood. - Defaults to False. - - Returns: - Callable[[Dict], Float[Array, "1"]]: A functional representation - of the marginal log-likelihood that can be evaluated at a - given parameter set. - """ - jitter = get_global_config()["jitter"] - - # Unpack training data - x, y, n = train_data.X, train_data.y, train_data.n - - # Unpack mean function and kernel - mean_function = self.prior.mean_function - kernel = self.prior.kernel - - # The sign of the marginal log-likelihood depends on whether we are maximising or minimising - constant = jnp.array(-1.0) if negative else jnp.array(1.0) - - def mll( - params: Dict, - ): - """Compute the marginal log-likelihood of the Gaussian process. - - Args: - params (Dict): The model's parameters. - - Returns: - Float[Array, "1"]: The marginal log-likelihood. - """ - - # Observation noise σ² - obs_noise = params["likelihood"]["obs_noise"] - μx = mean_function(params["mean_function"], x) - - # TODO: This implementation does not take advantage of the covariance operator structure. - # Future work concerns implementation of a custom Gaussian distribution / measure object that accepts a covariance operator. - - # Σ = (Kxx + Iσ²) = LLᵀ - Kxx = kernel.gram(params["kernel"], x) - Kxx += identity(n) * jitter - Sigma = Kxx + identity(n) * obs_noise - - # p(y | x, θ), where θ are the model hyperparameters: - marginal_likelihood = GaussianDistribution( - jnp.atleast_1d(μx.squeeze()), Sigma - ) - - return constant * ( - marginal_likelihood.log_prob(jnp.atleast_1d(y.squeeze())).squeeze() - ) - - return mll + return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance) +@dataclass class NonConjugatePosterior(AbstractPosterior): """ A Gaussian process posterior object for models where the likelihood is @@ -595,45 +355,15 @@ class NonConjugatePosterior(AbstractPosterior): variational inference, or Laplace approximations can then be used to sample from, or optimise an approximation to, the posterior distribution. """ + latent: Float[Array, "N 1"] = None + key: KeyArray = PRNGKey(42) - def __init__( - self, - prior: AbstractPrior, - likelihood: AbstractLikelihood, - name: Optional[str] = "GP posterior", - ) -> None: - """Initialise a non-conjugate Gaussian process posterior object. - - Args: - prior (AbstractPrior): The Gaussian process prior distribution. - likelihood (AbstractLikelihood): The likelihood function that represents the data. - name (Optional[str]): The name of the posterior object. Defaults to "GP posterior". - """ - self.prior = prior - self.likelihood = likelihood - self.name = name - - def init_params(self, key: KeyArray) -> Dict: - """Initialise the parameter set of a non-conjugate GP posterior. + def __post_init__(self): + if self.latent is None: + self.latent = normal(self.key, shape=(self.likelihood.num_datapoints, 1)) + - Args: - key (KeyArray): A PRNG key used to initialise the parameters. - - Returns: - Dict: A dictionary containing the default parameter set. - """ - parameters = concat_dictionaries( - self.prior.init_params(key), - {"likelihood": self.likelihood.init_params(key)}, - ) - parameters["latent"] = jnp.zeros(shape=(self.likelihood.num_datapoints, 1)) - return parameters - - def predict( - self, - params: Dict, - train_data: Dataset, - ) -> Callable[[Float[Array, "N D"]], dx.Distribution]: + def predict(self, test_inputs: Float[Array, "N D"], train_data: Dataset) -> GaussianDistribution: """ Conditional on a set of training data, compute the GP's posterior predictive distribution for a given set of parameters. The returned @@ -643,18 +373,14 @@ def predict( transformed through the likelihood function's inverse link function. Args: - params (Dict): A dictionary of parameters that should be used to - compute the posterior. train_data (Dataset): A `gpx.Dataset` object that contains the input and output data used for training dataset. Returns: - Callable[[Array], dx.Distribution]: A function that accepts an + GaussianDistribution: A function that accepts an input array and returns the predictive distribution as a ``dx.Distribution``. """ - jitter = get_global_config()["jitter"] - # Unpack training data x, n = train_data.X, train_data.n @@ -663,135 +389,35 @@ def predict( kernel = self.prior.kernel # Precompute lower triangular of Gram matrix, Lx, at training inputs, x - Kxx = kernel.gram(params["kernel"], x) - Kxx += identity(n) * jitter + Kxx = kernel.gram(x) + Kxx += identity(n) * self.prior.jitter Lx = Kxx.to_root() - def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution: - """Predictive distribution of the latent function for a given set of test inputs. - - Args: - test_inputs (Float[Array, "N D"]): A set of test inputs. - - Returns: - dx.Distribution: The predictive distribution of the latent function. - """ - - # Unpack test inputs - t, n_test = test_inputs, test_inputs.shape[0] - - # Compute terms of the posterior predictive distribution - Ktx = kernel.cross_covariance(params["kernel"], t, x) - Ktt = kernel.gram(params["kernel"], t) + identity(n_test) * jitter - μt = mean_function(params["mean_function"], t) - - # Lx⁻¹ Kxt - Lx_inv_Kxt = Lx.solve(Ktx.T) - - # Whitened function values, wx, corresponding to the inputs, x - wx = params["latent"] - - # μt + Ktx Lx⁻¹ wx - mean = μt + jnp.matmul(Lx_inv_Kxt.T, wx) - - # Ktt - Ktx Kxx⁻¹ Kxt, TODO: Take advantage of covariance structure to compute Schur complement more efficiently. - covariance = Ktt - jnp.matmul(Lx_inv_Kxt.T, Lx_inv_Kxt) - covariance += identity(n_test) * jitter - - return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance) - - return predict_fn - - def marginal_log_likelihood( - self, - train_data: Dataset, - negative: bool = False, - ) -> Callable[[Dict], Float[Array, "1"]]: - """ - Compute the marginal log-likelihood function of the Gaussian process. - The returned function can then be used for gradient based optimisation - of the model's parameters or for model comparison. The implementation - given here is general and will work for any likelihood support by GPJax. - - Unlike the marginal_log_likelihood function of the ConjugatePosterior - object, the marginal_log_likelihood function of the - ``NonConjugatePosterior`` object does not provide an exact marginal - log-likelihood function. Instead, the ``NonConjugatePosterior`` object - represents the posterior distributions as a function of the model's - hyperparameters and the latent function. Markov chain Monte Carlo, - variational inference, or Laplace approximations can then be used to - sample from, or optimise an approximation to, the posterior - distribution. - - Args: - train_data (Dataset): The training dataset used to compute the - marginal log-likelihood. - negative (Optional[bool]): Whether or not the returned function - should be negative. For optimisation, the negative is useful as - minimisation of the negative marginal log-likelihood is equivalent - to maximisation of the marginal log-likelihood. Defaults to False. - - Returns: - Callable[[Dict], Float[Array, "1"]]: A functional representation - of the marginal log-likelihood that can be evaluated at a given - parameter set. - """ - jitter = get_global_config()["jitter"] + # Unpack test inputs + t, n_test = test_inputs, test_inputs.shape[0] - # Unpack dataset - x, y, n = train_data.X, train_data.y, train_data.n - - # Unpack mean function and kernel - mean_function = self.prior.mean_function - kernel = self.prior.kernel + # Compute terms of the posterior predictive distribution + Ktx = kernel.cross_covariance(t, x) + Ktt = kernel.gram(t) + identity(n_test) * self.prior.jitter + μt = mean_function(t) - # Link function of the likelihood - link_function = self.likelihood.link_function + # Lx⁻¹ Kxt + Lx_inv_Kxt = Lx.solve(Ktx.T) - # The sign of the marginal log-likelihood depends on whether we are maximising or minimising - constant = jnp.array(-1.0) if negative else jnp.array(1.0) + # Whitened function values, wx, corresponding to the inputs, x + wx = self.latent - def mll(params: Dict): - """Compute the marginal log-likelihood of the model. + # μt + Ktx Lx⁻¹ wx + mean = μt + jnp.matmul(Lx_inv_Kxt.T, wx) - Args: - params (Dict): A dictionary of parameters that should be used - to compute the marginal log-likelihood. + # Ktt - Ktx Kxx⁻¹ Kxt, TODO: Take advantage of covariance structure to compute Schur complement more efficiently. + covariance = Ktt - jnp.matmul(Lx_inv_Kxt.T, Lx_inv_Kxt) + covariance += identity(n_test) * self.prior.jitter - Returns: - Float[Array, "1"]: The marginal log-likelihood of the model. - """ + return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance) - # Compute lower triangular of the kernel Gram matrix - Kxx = kernel.gram(params["kernel"], x) - Kxx += identity(n) * jitter - Lx = Kxx.to_root() - # Compute the prior mean function - μx = mean_function(params["mean_function"], x) - - # Whitened function values, wx, corresponding to the inputs, x - wx = params["latent"] - - # f(x) = μx + Lx wx - fx = μx + Lx @ wx - - # p(y | f(x), θ), where θ are the model hyperparameters - likelihood = link_function(params, fx) - - # Whitened latent function values prior, p(wx | θ) = N(0, I) - latent_prior = dx.Normal(loc=0.0, scale=1.0) - - return constant * ( - likelihood.log_prob(y).sum() + latent_prior.log_prob(wx).sum() - ) - - return mll - - -def construct_posterior( - prior: Prior, likelihood: AbstractLikelihood -) -> AbstractPosterior: +def construct_posterior(prior: Prior, likelihood: AbstractLikelihood) -> AbstractPosterior: """Utility function for constructing a posterior object from a prior and likelihood. The function will automatically select the correct posterior object based on the likelihood. @@ -806,18 +432,10 @@ def construct_posterior( Gaussian, then a ``ConjugatePosterior`` will be returned. Otherwise, a ``NonConjugatePosterior`` will be returned. """ - if isinstance(likelihood, Conjugate): - PosteriorGP = ConjugatePosterior - - elif isinstance(likelihood, NonConjugate): - PosteriorGP = NonConjugatePosterior - - else: - raise NotImplementedError( - f"No posterior implemented for {likelihood.name} likelihood" - ) + if isinstance(likelihood, Gaussian): + return ConjugatePosterior(prior=prior, likelihood=likelihood) - return PosteriorGP(prior=prior, likelihood=likelihood) + return NonConjugatePosterior(prior=prior, likelihood=likelihood) __all__ = [ @@ -827,4 +445,4 @@ def construct_posterior( "ConjugatePosterior", "NonConjugatePosterior", "construct_posterior", -] +] \ No newline at end of file diff --git a/gpjax/kernels/computations/dense.py b/gpjax/kernels/computations/dense.py index d55bb034b..c64981feb 100644 --- a/gpjax/kernels/computations/dense.py +++ b/gpjax/kernels/computations/dense.py @@ -17,7 +17,6 @@ from jaxtyping import Array, Float from .base import AbstractKernelComputation - class DenseKernelComputation(AbstractKernelComputation): """Dense kernel computation class. Operations with the kernel assume a dense gram matrix structure. diff --git a/gpjax/likelihoods.py b/gpjax/likelihoods.py index 152d58152..478720fe5 100644 --- a/gpjax/likelihoods.py +++ b/gpjax/likelihoods.py @@ -17,7 +17,6 @@ from typing import Any from .linops.utils import to_dense -import deprecation import distrax as dx import jax.numpy as jnp import jax.scipy as jsp diff --git a/gpjax/module/__init__.py b/gpjax/module/__init__.py new file mode 100644 index 000000000..140ea2d8f --- /dev/null +++ b/gpjax/module/__init__.py @@ -0,0 +1,4 @@ +from .module import Module +from .param import param_field + +__all__ = ["Module", "param_field"] diff --git a/gpjax/parameters/module.py b/gpjax/module/module.py similarity index 100% rename from gpjax/parameters/module.py rename to gpjax/module/module.py diff --git a/gpjax/parameters/param.py b/gpjax/module/param.py similarity index 100% rename from gpjax/parameters/param.py rename to gpjax/module/param.py diff --git a/gpjax/natural_gradients.py b/gpjax/natural_gradients.py deleted file mode 100644 index acdf5bd36..000000000 --- a/gpjax/natural_gradients.py +++ /dev/null @@ -1,302 +0,0 @@ -# Copyright 2022 The GPJax Contributors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from copy import deepcopy -from typing import Callable, Dict, Tuple - -import jax.numpy as jnp -import jax.scipy as jsp -from jax import value_and_grad -from jaxtyping import Array, Float -from jaxutils import Dataset - -from .config import get_global_config -from .gps import AbstractPosterior -from .params import build_trainables, constrain, trainable_params -from .variational_families import ( - AbstractVariationalFamily, - ExpectationVariationalGaussian, - NaturalVariationalGaussian, -) -from .variational_inference import StochasticVI - - -def natural_to_expectation(params: Dict) -> Dict: - """ - Translate natural parameters to expectation parameters. - - In particular, in terms of the Gaussian mean μ and covariance matrix μ for - the Gaussian variational family, - - - the natural parameterisation is θ = (S⁻¹μ, -S⁻¹/2) - - the expectation parameters are η = (μ, S + μ μᵀ). - - This function solves these equations in terms of μ and S to convert θ to η. - - Writing θ = (θ₁, θ₂), we have that S⁻¹ = -2θ₂ . Taking the cholesky - decomposition of the inverse covariance, S⁻¹ = LLᵀ and defining C = L⁻¹, we - have S = CᵀC and μ = Sθ₁ = CᵀC θ₁. - - Now from here, using μ and S found from θ, we compute η as η₁ = μ, and η₂ = S + μ μᵀ. - - Args: - params: A dictionary of variational Gaussian parameters under the natural - parameterisation. - - Returns: - Dict: A dictionary of Gaussian moments under the expectation parameterisation. - """ - - natural_matrix = params["variational_family"]["moments"]["natural_matrix"] - natural_vector = params["variational_family"]["moments"]["natural_vector"] - m = natural_vector.shape[0] - - # S⁻¹ = -2θ₂ - S_inv = -2 * natural_matrix - jitter = get_global_config()["jitter"] - S_inv += jnp.eye(m) * jitter - - # S⁻¹ = LLᵀ - L = jnp.linalg.cholesky(S_inv) - - # C = L⁻¹I - C = jsp.linalg.solve_triangular(L, jnp.eye(m), lower=True) - - # S = CᵀC - S = jnp.matmul(C.T, C) - - # μ = Sθ₁ - mu = jnp.matmul(S, natural_vector) - - # η₁ = μ - expectation_vector = mu - - # η₂ = S + μ μᵀ - expectation_matrix = S + jnp.matmul(mu, mu.T) - - params["variational_family"]["moments"] = { - "expectation_vector": expectation_vector, - "expectation_matrix": expectation_matrix, - } - - return params - - -def _expectation_elbo( - posterior: AbstractPosterior, - variational_family: AbstractVariationalFamily, - train_data: Dataset, -) -> Callable[[Dict, Dataset], float]: - """ - Construct evidence lower bound (ELBO) for variational Gaussian under the - expectation parameterisation. - - Args: - posterior: An instance of AbstractPosterior. - variational_family: An instance of AbstractVariationalFamily. - - Returns: - Callable[[Dict, Dataset], float]: A function that computes the ELBO. - """ - expectation_vartiational_gaussian = ExpectationVariationalGaussian( - prior=variational_family.prior, - inducing_inputs=variational_family.inducing_inputs, - ) - svgp = StochasticVI( - posterior=posterior, variational_family=expectation_vartiational_gaussian - ) - - return svgp.elbo(train_data, negative=True) - - -def _rename_expectation_to_natural(params: Dict) -> Dict: - """ - This function renames the gradient components (that have expectation - parameterisation keys) to match the natural parameterisation PyTree. - - Args: - params (Dict): A dictionary of variational Gaussian parameters - under the expectation parameterisation moment names. - - Returns: - Dict: A dictionary of variational Gaussian parameters under the - natural parameterisation moment names. - """ - params["variational_family"]["moments"] = { - "natural_vector": params["variational_family"]["moments"]["expectation_vector"], - "natural_matrix": params["variational_family"]["moments"]["expectation_matrix"], - } - - return params - - -def _rename_natural_to_expectation(params: Dict) -> Dict: - """ - This function renames the gradient components (that have natural - parameterisation keys) to match the expectation parameterisation PyTree. - - Args: - params (Dict): A dictionary of variational Gaussian parameters - under the natural parameterisation moment names. - - Returns: - Dict: A dictionary of variational Gaussian parameters under - the expectation parameterisation moment names. - """ - params["variational_family"]["moments"] = { - "expectation_vector": params["variational_family"]["moments"]["natural_vector"], - "expectation_matrix": params["variational_family"]["moments"]["natural_matrix"], - } - - return params - - -def _get_moment_trainables(trainables: Dict) -> Dict: - """ - This function takes a trainables dictionary, and sets non-moment parameter - training to false for gradient stopping. - - Args: - trainables (Dict): A dictionary of trainables. - - Returns: - Dict: A dictionary of trainables with non-moment parameters set to False. - """ - expectation_trainables = _rename_natural_to_expectation(deepcopy(trainables)) - moment_trainables = build_trainables(expectation_trainables, False) - moment_trainables["variational_family"]["moments"] = expectation_trainables[ - "variational_family" - ]["moments"] - - return moment_trainables - - -def _get_hyperparameter_trainables(trainables: Dict) -> Dict: - """ - This function takes a trainables dictionary, and sets moment parameter - training to false for gradient stopping. - - Args: - trainables (Dict): A dictionary of trainables. - - Returns: - Dict: A dictionary of trainables with moment parameters set to False. - """ - hyper_trainables = deepcopy(trainables) - hyper_trainables["variational_family"]["moments"] = build_trainables( - trainables["variational_family"]["moments"], False - ) - - return hyper_trainables - - -def natural_gradients( - stochastic_vi: StochasticVI, - train_data: Dataset, - bijectors: Dict, - trainables: Dict, -) -> Tuple[Callable[[Dict, Dataset], Dict]]: - """ - Computes the gradient with respect to the natural parameters. Currently only - implemented for the natural variational Gaussian family. - - Args: - posterior: An instance of AbstractPosterior. - variational_family: An instance of AbstractVariationalFamily. - train_data: A Dataset. - bijectors: A dictionary of bijectors. - - Returns: - Tuple[Callable[[Dict, Dataset], Dict]]: Functions that compute natural - gradients and hyperparameter gradients respectively. - """ - posterior = stochastic_vi.posterior - variational_family = stochastic_vi.variational_family - - # The ELBO under the user chosen parameterisation xi. - xi_elbo = stochastic_vi.elbo(train_data, negative=True) - - # The ELBO under the expectation parameterisation, L(η). - expectation_elbo = _expectation_elbo(posterior, variational_family, train_data) - - # Trainable dictionaries for alternating gradient updates. - moment_trainables = _get_moment_trainables(trainables) - hyper_trainables = _get_hyperparameter_trainables(trainables) - - if isinstance(variational_family, NaturalVariationalGaussian): - - def nat_grads_fn(params: Dict, batch: Dataset) -> Dict: - """ - Computes the natural gradients of the ELBO. - - Args: - params (Dict): A dictionary of parameters. - batch (Dataset): A Dataset. - - Returns: - Dict: A dictionary of natural gradients. - """ - # Transform parameters to constrained space. - params = constrain(params, bijectors) - - # Convert natural parameterisation θ to the expectation parameterisation η. - expectation_params = natural_to_expectation(params) - - # Compute gradient ∂L/∂η: - def loss_fn(params: Dict, batch: Dataset) -> Float[Array, "1"]: - # Stop gradients for non-trainable and non-moment parameters. - params = trainable_params(params, moment_trainables) - - return expectation_elbo(params, batch) - - value, dL_dexp = value_and_grad(loss_fn)(expectation_params, batch) - - nat_grad = _rename_expectation_to_natural(dL_dexp) - - return value, nat_grad - - else: - raise NotImplementedError - - def hyper_grads_fn(params: Dict, batch: Dataset) -> Dict: - """ - Computes the hyperparameter gradients of the ELBO. - - Args: - params (Dict): A dictionary of parameters. - batch (Dataset): A Dataset. - - Returns: - Dict: A dictionary of hyperparameter gradients. - """ - - def loss_fn(params: Dict, batch: Dataset) -> Float[Array, "1"]: - # Stop gradients for non-trainable and moment parameters. - params = constrain(params, bijectors) - params = trainable_params(params, hyper_trainables) - - return xi_elbo(params, batch) - - value, dL_dhyper = value_and_grad(loss_fn)(params, batch) - - return value, dL_dhyper - - return nat_grads_fn, hyper_grads_fn - - -__all__ = [ - "natural_to_expectation", - "natural_gradients", -] diff --git a/gpjax/objectives.py b/gpjax/objectives.py new file mode 100644 index 000000000..e71755e84 --- /dev/null +++ b/gpjax/objectives.py @@ -0,0 +1,364 @@ +from __future__ import annotations +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .gps import ConjugatePosterior, NonConjugatePosterior + from .variational_families import AbstractVariationalFamily + +from abc import abstractmethod + +import distrax as dx +from jax import vmap +import jax.numpy as jnp +import jax.scipy as jsp +from .linops import identity +from jaxtyping import Array, Float + +from .dataset import Dataset +from .gaussian_distribution import GaussianDistribution +from .quadrature import gauss_hermite_quadrature + +from mytree import Mytree +from dataclasses import dataclass +from simple_pytree import static_field + +import jax.tree_util as jtu + +@dataclass +class AbstractObjective(Mytree): + """Abstract base class for objectives.""" + negative: bool = static_field(False) + constant: float = static_field(init=False, repr=False) + + def __post_init__(self) -> None: + self.constant = jnp.array(-1.0) if self.negative else jnp.array(1.0) + + def __hash__(self): + return hash(tuple(jtu.tree_leaves(self))) # Probably put this on the Module! + + @abstractmethod + def __call__(self, *args, **kwargs) -> Float[Array, "1"]: + raise NotImplementedError + + +class ConjugateMLL(AbstractObjective): + + def __call__(self, posterior: ConjugatePosterior, train_data: Dataset) -> Float[Array, "1"]: + """Compute the marginal log-likelihood function of the Gaussian process. + The returned function can then be used for gradient based optimisation + of the model's parameters or for model comparison. The implementation + given here enables exact estimation of the Gaussian process' latent + function values. + + For a training dataset :math:`\\{x_n, y_n\\}_{n=1}^N`, set of test + inputs :math:`\\mathbf{x}^{\\star}` the corresponding latent function + evaluations are given by :math:`\\mathbf{f} = f(\\mathbf{x})` + and :math:`\\mathbf{f}^{\\star} = f(\\mathbf{x}^{\\star})`, the marginal + log-likelihood is given by: + + .. math:: + + \\log p(\\mathbf{y}) & = \\int p(\\mathbf{y}\\mid\\mathbf{f})p(\\mathbf{f}, \\mathbf{f}^{\\star}\\mathrm{d}\\mathbf{f}^{\\star}\\\\ + &=0.5\\left(-\\mathbf{y}^{\\top}\\left(k(\\mathbf{x}, \\mathbf{x}') +\\sigma^2\\mathbf{I}_N \\right)^{-1}\\mathbf{y}-\\log\\lvert k(\\mathbf{x}, \\mathbf{x}') + \\sigma^2\\mathbf{I}_N\\rvert - n\\log 2\\pi \\right). + + Example: + + For a given ``ConjugatePosterior`` object, the following code snippet shows + how the marginal log-likelihood can be evaluated. + + >>> import gpjax as gpx + >>> + >>> xtrain = jnp.linspace(0, 1).reshape(-1, 1) + >>> ytrain = jnp.sin(xtrain) + >>> D = gpx.Dataset(X=xtrain, y=ytrain) + >>> + >>> params = gpx.initialise(posterior) + >>> mll = posterior.marginal_log_likelihood(train_data = D) + >>> mll(params) + + Our goal is to maximise the marginal log-likelihood. Therefore, when + optimising the model's parameters with respect to the parameters, we + use the negative marginal log-likelihood. This can be realised through + + >>> mll = posterior.marginal_log_likelihood(train_data = D, negative=True) + + Further, prior distributions can be passed into the marginal log-likelihood + + >>> mll = posterior.marginal_log_likelihood(train_data = D) + + For optimal performance, the marginal log-likelihood should be ``jax.jit`` + compiled. + + >>> mll = jit(posterior.marginal_log_likelihood(train_data = D)) + + Args: + train_data (Dataset): The training dataset used to compute the + marginal log-likelihood. + negative (Optional[bool]): Whether or not the returned function + should be negative. For optimisation, the negative is useful + as minimisation of the negative marginal log-likelihood is + equivalent to maximisation of the marginal log-likelihood. + Defaults to False. + + Returns: + Callable[[Parameters], Float[Array, "1"]]: A functional representation + of the marginal log-likelihood that can be evaluated at a + given parameter set. + """ + + x, y, n = train_data.X, train_data.y, train_data.n + + # Observation noise σ² + obs_noise = posterior.likelihood.obs_noise + mx = posterior.prior.mean_function(x) + + # Σ = (Kxx + Iσ²) = LLᵀ + Kxx = posterior.prior.kernel.gram(x) + Kxx += identity(n) * posterior.prior.jitter + Sigma = Kxx + identity(n) * obs_noise + + # p(y | x, θ), where θ are the model hyperparameters: + mll = GaussianDistribution(jnp.atleast_1d(mx.squeeze()), Sigma) + + return self.constant * (mll.log_prob(jnp.atleast_1d(y.squeeze())).squeeze()) + + +class NonConjugateMLL(AbstractObjective): + + def __call__(self, posterior: NonConjugatePosterior, data: Dataset) -> Float[Array, "1"]: + """ + Compute the marginal log-likelihood function of the Gaussian process. + The returned function can then be used for gradient based optimisation + of the model's parameters or for model comparison. The implementation + given here is general and will work for any likelihood support by GPJax. + + Unlike the marginal_log_likelihood function of the ConjugatePosterior + object, the marginal_log_likelihood function of the + ``NonConjugatePosterior`` object does not provide an exact marginal + log-likelihood function. Instead, the ``NonConjugatePosterior`` object + represents the posterior distributions as a function of the model's + hyperparameters and the latent function. Markov chain Monte Carlo, + variational inference, or Laplace approximations can then be used to + sample from, or optimise an approximation to, the posterior + distribution. + + Args: + train_data (Dataset): The training dataset used to compute the + marginal log-likelihood. + negative (Optional[bool]): Whether or not the returned function + should be negative. For optimisation, the negative is useful as + minimisation of the negative marginal log-likelihood is equivalent + to maximisation of the marginal log-likelihood. Defaults to False. + + Returns: + Callable[[Parameters], Float[Array, "1"]]: A functional representation + of the marginal log-likelihood that can be evaluated at a given + parameter set. + """ + # Unpack the training data + x, y, n = data.X, data.y, data.n + Kxx = posterior.prior.kernel.gram(x) + Kxx += identity(n) * posterior.prior.jitter + Lx = Kxx.to_root() + + # Compute the prior mean function + mx = posterior.prior.mean_function(x) + + # Whitened function values, wx, corresponding to the inputs, x + wx = posterior.latent + + # f(x) = mx + Lx wx + fx = mx + Lx @ wx + + # p(y | f(x), θ), where θ are the model hyperparameters + likelihood = posterior.likelihood.link_function(fx) + + # Whitened latent function values prior, p(wx | θ) = N(0, I) + latent_prior = dx.Normal(loc=0.0, scale=1.0) + + return self.constant * ( + likelihood.log_prob(y).sum() + latent_prior.log_prob(wx).sum() + ) + + +class ELBO(AbstractObjective): + + def __call__(self, variational_family: AbstractVariationalFamily, train_data: Dataset) -> Float[Array, "1"]: + """Compute the evidence lower bound under this model. In short, this requires + evaluating the expectation of the model's log-likelihood under the variational + approximation. To this, we sum the KL divergence from the variational posterior + to the prior. When batching occurs, the result is scaled by the batch size + relative to the full dataset size. + + Args: + params (Parameters): The set of parameters that induce our variational + approximation. + train_data (Dataset): The training data for which we should maximise the + ELBO with respect to. + negative (bool, optional): Whether or not the resultant elbo function should + be negative. For gradient descent where we optimise our objective + function this argument should be true as minimisation of the negative + corresponds to maximisation of the ELBO. Defaults to False. + + Returns: + Callable[[Parameters, Dataset], Array]: A callable function that accepts a + current parameter estimate and batch of data for which gradients should + be computed. + """ + + # KL[q(f(·)) || p(f(·))] + kl = variational_family.prior_kl() + + # ∫[log(p(y|f(·))) q(f(·))] df(·) + var_exp = variational_expectation(variational_family, train_data) + + # For batch size b, we compute n/b * Σᵢ[ ∫log(p(y|f(xᵢ))) q(f(xᵢ)) df(xᵢ)] - KL[q(f(·)) || p(f(·))] + return self.constant * ( + jnp.sum(var_exp) * variational_family.posterior.likelihood.num_datapoints / train_data.n - kl + ) + + +LogPosteriorDensity = NonConjugateMLL + + +def variational_expectation( + variational_family: AbstractVariationalFamily, + train_data: Dataset, +) -> Float[Array, "N 1"]: + """Compute the expectation of our model's log-likelihood under our variational + distribution. Batching can be done here to speed up computation. + + Args: + variational_family (AbstractVariationalFamily): The variational family that we are using to approximate the posterior. + train_data (Dataset): The batch for which the expectation should be computed for. + + Returns: + Array: The expectation of the model's log-likelihood under our variational + distribution. + """ + + # Unpack training batch + x, y = train_data.X, train_data.y + + # Variational distribution q(f(·)) = N(f(·); μ(·), Σ(·, ·)) + q = variational_family + + # Compute variational mean, μ(x), and variance, √diag(Σ(x, x)), at the training + # inputs, x + def q_moments(x): + qx = q(x) + return qx.mean(), qx.variance() + + mean, variance = vmap(q_moments)(x[:, None]) + + # log(p(y|f(x))) + link_function = variational_family.posterior.likelihood.link_function + log_prob = vmap(lambda f, y: link_function(f).log_prob(y)) + + # ≈ ∫[log(p(y|f(x))) q(f(x))] df(x) + expectation = gauss_hermite_quadrature(log_prob, mean, jnp.sqrt(variance), y=y) + + return expectation + + +class CollapsedELBO(AbstractObjective): + """Collapsed variational inference for a sparse Gaussian process regression model. + The key reference is Titsias, (2009) - Variational Learning of Inducing Variables + in Sparse Gaussian Processes. + """ + + def __call__(self, variational_family: AbstractVariationalFamily, train_data: Dataset) -> Float[Array, "1"]: + """Compute the evidence lower bound under this model. In short, this requires + evaluating the expectation of the model's log-likelihood under the variational + approximation. To this, we sum the KL divergence from the variational posterior + to the prior. When batching occurs, the result is scaled by the batch size + relative to the full dataset size. + + Args: + train_data (Dataset): The training data for which we should maximise the + ELBO with respect to. + negative (bool, optional): Whether or not the resultant elbo function should + be negative. For gradient descent where we optimise our objective + function this argument should be true as minimisation of the negative + corresponds to maximisation of the ELBO. Defaults to False. + + Returns: + Callable[[Parameters, Dataset], Array]: A callable function that accepts a + current parameter estimate for which gradients should be computed. + """ + + # Unpack training data + x, y, n = train_data.X, train_data.y, train_data.n + + # Unpack mean function and kernel + mean_function = variational_family.posterior.prior.mean_function + kernel = variational_family.posterior.prior.kernel + + m = variational_family.num_inducing + + noise = variational_family.posterior.likelihood.obs_noise + z = variational_family.inducing_inputs + Kzz = kernel.gram(z) + Kzz += identity(m) * variational_family.jitter + Kzx = kernel.cross_covariance(z, x) + Kxx_diag = vmap(kernel, in_axes=(0, 0))(x, x) + μx = mean_function(x) + + Lz = Kzz.to_root() + + # Notation and derivation: + # + # Let Q = KxzKzz⁻¹Kzx, we must compute the log normal pdf: + # + # log N(y; μx, σ²I + Q) = -nπ - n/2 log|σ²I + Q| + # - 1/2 (y - μx)ᵀ (σ²I + Q)⁻¹ (y - μx). + # + # The log determinant |σ²I + Q| is computed via applying the matrix determinant + # lemma + # + # |σ²I + Q| = log|σ²I| + log|I + Lz⁻¹ Kzx (σ²I)⁻¹ Kxz Lz⁻¹| = log(σ²) + log|B|, + # + # with B = I + AAᵀ and A = Lz⁻¹ Kzx / σ. + # + # Similarly we apply matrix inversion lemma to invert σ²I + Q + # + # (σ²I + Q)⁻¹ = (Iσ²)⁻¹ - (Iσ²)⁻¹ Kxz Lz⁻ᵀ (I + Lz⁻¹ Kzx (Iσ²)⁻¹ Kxz Lz⁻ᵀ )⁻¹ Lz⁻¹ Kzx (Iσ²)⁻¹ + # = (Iσ²)⁻¹ - (Iσ²)⁻¹ σAᵀ (I + σA (Iσ²)⁻¹ σAᵀ)⁻¹ σA (Iσ²)⁻¹ + # = I/σ² - Aᵀ B⁻¹ A/σ², + # + # giving the quadratic term as + # + # (y - μx)ᵀ (σ²I + Q)⁻¹ (y - μx) = [(y - μx)ᵀ(y - µx) - (y - μx)ᵀ Aᵀ B⁻¹ A (y - μx)]/σ², + # + # with A and B defined as above. + + A = Lz.solve(Kzx) / jnp.sqrt(noise) + + # AAᵀ + AAT = jnp.matmul(A, A.T) + + # B = I + AAᵀ + B = jnp.eye(m) + AAT + + # LLᵀ = I + AAᵀ + L = jnp.linalg.cholesky(B) + + # log|B| = 2 trace(log|L|) = 2 Σᵢ log Lᵢᵢ [since |B| = |LLᵀ| = |L|² => log|B| = 2 log|L|, and |L| = Πᵢ Lᵢᵢ] + log_det_B = 2.0 * jnp.sum(jnp.log(jnp.diagonal(L))) + + diff = y - μx + + # L⁻¹ A (y - μx) + L_inv_A_diff = jsp.linalg.solve_triangular(L, jnp.matmul(A, diff), lower=True) + + # (y - μx)ᵀ (Iσ² + Q)⁻¹ (y - μx) + quad = (jnp.sum(diff**2) - jnp.sum(L_inv_A_diff**2)) / noise + + # 2 * log N(y; μx, Iσ² + Q) + two_log_prob = -n * jnp.log(2.0 * jnp.pi * noise) - log_det_B - quad + + # 1/σ² tr(Kxx - Q) [Trace law tr(AB) = tr(BA) => tr(KxzKzz⁻¹Kzx) = tr(KxzLz⁻ᵀLz⁻¹Kzx) = tr(Lz⁻¹Kzx KxzLz⁻ᵀ) = trace(σ²AAᵀ)] + two_trace = jnp.sum(Kxx_diag) / noise - jnp.trace(AAT) + + # log N(y; μx, Iσ² + KxzKzz⁻¹Kzx) - 1/2σ² tr(Kxx - KxzKzz⁻¹Kzx) + return self.constant * (two_log_prob - two_trace).squeeze() / 2.0 \ No newline at end of file diff --git a/gpjax/parameters/__init__.py b/gpjax/parameters/__init__.py deleted file mode 100644 index 40aff12bb..000000000 --- a/gpjax/parameters/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .bijectors import Identity, Softplus -from .module import Module -from .param import param_field - -__all__ = ["Identity", "Module", "Softplus", "param_field"] diff --git a/gpjax/parameters/bijectors.py b/gpjax/parameters/bijectors.py deleted file mode 100644 index 5a257629e..000000000 --- a/gpjax/parameters/bijectors.py +++ /dev/null @@ -1,34 +0,0 @@ -from __future__ import annotations - -__all__ = ["Bijector", "Identity", "Softplus"] - -from dataclasses import dataclass -from typing import Callable - -import jax.numpy as jnp -from simple_pytree import Pytree, static_field - - -@dataclass -class Bijector(Pytree): - """ - Create a bijector. - - Args: - forward(Callable): The forward transformation. - inverse(Callable): The inverse transformation. - - Returns: - Bijector: A bijector. - """ - - forward: Callable = static_field() - inverse: Callable = static_field() - - -Identity = Bijector(forward=lambda x: x, inverse=lambda x: x) - -Softplus = Bijector( - forward=lambda x: jnp.log(1.0 + jnp.exp(x)), - inverse=lambda x: jnp.log(jnp.exp(x) - 1.0), -) diff --git a/gpjax/params.py b/gpjax/params.py deleted file mode 100644 index 02a438c98..000000000 --- a/gpjax/params.py +++ /dev/null @@ -1,420 +0,0 @@ -# Copyright 2022 The GPJax Contributors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import warnings -from copy import deepcopy -from typing import Dict, Tuple -from warnings import warn - -import distrax as dx -import jax -import jax.numpy as jnp -import jax.random as jr -from jax.random import KeyArray -from jaxtyping import Array, Float -from jaxutils import PyTree - -from .config import Identity, get_global_config -from .utils import merge_dictionaries - - -################################ -# Base operations -################################ -class ParameterState(PyTree): - """ - The state of the model. This includes the parameter set, which parameters - are to be trained and bijectors that allow parameters to be constrained and - unconstrained. - """ - - def __init__(self, params: Dict, trainables: Dict, bijectors: Dict) -> None: - self.params = params - self.trainables = trainables - self.bijectors = bijectors - - def unpack(self): - """Unpack the state into a tuple of parameters, trainables and bijectors. - - Returns: - Tuple[Dict, Dict, Dict]: The parameters, trainables and bijectors. - """ - return self.params, self.trainables, self.bijectors - - -def initialise(model, key: KeyArray = None, **kwargs) -> ParameterState: - """ - Initialise the stateful parameters of any GPJax object. This function also - returns the trainability status of each parameter and set of bijectors that - allow parameters to be constrained and unconstrained. - - Args: - model: The GPJax object that is to be initialised. - key (KeyArray, optional): The random key that is to be used for - initialisation. Defaults to None. - - Returns: - ParameterState: The state of the model. This includes the parameter - set, which parameters are to be trained and bijectors that allow - parameters to be constrained and unconstrained. - """ - - if key is None: - warn("No PRNGKey specified. Defaulting to seed 123.", UserWarning, stacklevel=2) - key = jr.PRNGKey(123) - - # Initialise the parameters. - if hasattr(model, "init_params"): - params = model.init_params(key) - - elif hasattr(model, "_initialise_params"): - warn( - "`_initialise_params` is deprecated. Please use `init_params` instead.", - DeprecationWarning, - stacklevel=2, - ) - params = model._initialise_params(key) - - else: - raise AttributeError("No `init_params` or `_initialise_params` method found.") - - if kwargs: - _validate_kwargs(kwargs, params) - for k, v in kwargs.items(): - params[k] = merge_dictionaries(params[k], v) - - bijectors = build_bijectors(params) - trainables = build_trainables(params) - - return ParameterState( - params=params, - trainables=trainables, - bijectors=bijectors, - ) - - -def _validate_kwargs(kwargs, params): - for k, v in kwargs.items(): - if k not in params.keys(): - raise ValueError(f"Parameter {k} is not a valid parameter.") - - -def recursive_items(d1: Dict, d2: Dict): - """ - Recursive loop over pair of dictionaries whereby the value of a given key in - either dictionary can be itself a dictionary. - - Args: - d1 (_type_): _description_ - d2 (_type_): _description_ - - Yields: - _type_: _description_ - """ - for key, value in d1.items(): - if type(value) is dict: - yield from recursive_items(value, d2[key]) - else: - yield (key, value, d2[key]) - - -def recursive_complete(d1: Dict, d2: Dict) -> Dict: - """ - Recursive loop over pair of dictionaries whereby the value of a given key in - either dictionary can be itself a dictionary. If the value of the key in the - second dictionary is None, the value of the key in the first dictionary is - used. - - Args: - d1 (Dict): The reference dictionary. - d2 (Dict): The potentially incomplete dictionary. - - Returns: - Dict: A completed form of the second dictionary. - """ - for key, value in d1.items(): - if type(value) is dict: - if key in d2.keys(): - recursive_complete(value, d2[key]) - else: - if key in d2.keys(): - d1[key] = d2[key] - return d1 - - -################################ -# Parameter transformation -################################ -def build_bijectors(params: Dict) -> Dict: - """ - For each parameter, build the bijection pair that allows the parameter to be - constrained and unconstrained. - - Args: - params (Dict): _description_ - - Returns: - Dict: A dictionary that maps each parameter to a bijection. - """ - bijectors = copy_dict_structure(params) - config = get_global_config() - transform_set = config["transformations"] - - def recursive_bijectors_list(ps, bs): - return [recursive_bijectors(ps[i], bs[i]) for i in range(len(bs))] - - def recursive_bijectors(ps, bs) -> Tuple[Dict, Dict]: - if type(ps) is list: - bs = recursive_bijectors_list(ps, bs) - - else: - for key, value in ps.items(): - if type(value) is dict: - recursive_bijectors(value, bs[key]) - elif type(value) is list: - bs[key] = recursive_bijectors_list(value, bs[key]) - else: - if key in transform_set.keys(): - transform_type = transform_set[key] - bijector = transform_set[transform_type] - else: - bijector = Identity - warnings.warn( - f"Parameter {key} has no transform. Defaulting to identity transform." - ) - bs[key] = bijector - return bs - - return recursive_bijectors(params, bijectors) - - -def constrain(params: Dict, bijectors: Dict) -> Dict: - """ - Transform the parameters to the constrained space for corresponding - bijectors. - - Args: - params (Dict): The parameters that are to be transformed. - bijectors (Dict): The bijectors that are to be used for - transformation. - - Returns: - Dict: A transformed parameter set. The dictionary is equal in - structure to the input params dictionary. - """ - map = lambda param, trans: trans.forward(param) - - return jax.tree_util.tree_map(map, params, bijectors) - - -def unconstrain(params: Dict, bijectors: Dict) -> Dict: - """Transform the parameters to the unconstrained space for corresponding - bijectors. - - Args: - params (Dict): The parameters that are to be transformed. - bijectors (Dict): The corresponding dictionary of transforms that - should be applied to the parameter set. - - Returns: - Dict: A transformed parameter set. The dictionary is equal in - structure to the input params dictionary. - """ - - map = lambda param, trans: trans.inverse(param) - - return jax.tree_util.tree_map(map, params, bijectors) - - -################################ -# Priors -################################ -def log_density( - param: Float[Array, "D"], density: dx.Distribution -) -> Float[Array, "1"]: - """Compute the log density of a parameter given a distribution. - - Args: - param (Float[Array, "D"]): The parameter that is to be evaluated. - density (dx.Distribution): The distribution that is to be evaluated. - - Returns: - Float[Array, "1"]: The log density of the parameter. - """ - if type(density) == type(None): - log_prob = jnp.array(0.0) - else: - log_prob = jnp.sum(density.log_prob(param)) - return log_prob - - -def copy_dict_structure(params: Dict) -> Dict: - """Copy the structure of a dictionary. - - Args: - params (Dict): The dictionary that is to be copied. - - Returns: - Dict: A copy of the input dictionary. - """ - # Copy dictionary structure - prior_container = deepcopy(params) - # Set all values to zero - prior_container = jax.tree_util.tree_map(lambda _: None, prior_container) - return prior_container - - -def structure_priors(params: Dict, priors: Dict) -> Dict: - """First create a dictionary with equal structure to the parameters. - Then, for each supplied prior, overwrite the None value if it exists. - - Args: - params (Dict): [description] - priors (Dict): [description] - - Returns: - Dict: [description] - """ - prior_container = copy_dict_structure(params) - # Where a prior has been supplied, override the None value by the prior distribution. - complete_prior = recursive_complete(prior_container, priors) - return complete_prior - - -def evaluate_priors(params: Dict, priors: Dict) -> Dict: - """ - Recursive loop over pair of dictionaries that correspond to a parameter's - current value and the parameter's respective prior distribution. For - parameters where a prior distribution is specified, the log-prior density is - evaluated at the parameter's current value. - - Args: params (Dict): Dictionary containing the current set of parameter - estimates. priors (Dict): Dictionary specifying the parameters' prior - distributions. - - Returns: - Dict: The log-prior density, summed over all parameters. - """ - lpd = jnp.array(0.0) - if priors is not None: - for name, param, prior in recursive_items(params, priors): - lpd += log_density(param, prior) - return lpd - - -def prior_checks(priors: Dict) -> Dict: - """ - Run checks on the parameters' prior distributions. This checks that for - Gaussian processes that are constructed with non-conjugate likelihoods, the - prior distribution on the function's latent values is a unit Gaussian. - - Args: - priors (Dict): Dictionary specifying the parameters' prior distributions. - - Returns: - Dict: Dictionary specifying the parameters' prior distributions. - """ - if "latent" in priors.keys(): - latent_prior = priors["latent"] - if latent_prior is not None: - if not isinstance(latent_prior, dx.Normal): - warnings.warn( - f"A {type(latent_prior)} distribution prior has been placed on" - " the latent function. It is strongly advised that a" - " unit Gaussian prior is used." - ) - else: - warnings.warn("Placing unit Gaussian prior on latent function.") - priors["latent"] = dx.Normal(loc=0.0, scale=1.0) - else: - priors["latent"] = dx.Normal(loc=0.0, scale=1.0) - - return priors - - -def build_trainables(params: Dict, status: bool = True) -> Dict: - """ - Construct a dictionary of trainable statuses for each parameter. By default, - every parameter within the model is trainable. - - Args: - params (Dict): The parameter set for which trainable statuses should be - derived from. - status (bool): The status of each parameter. Default is True. - - Returns: - Dict: A dictionary of boolean trainability statuses. The dictionary is - equal in structure to the input params dictionary. - """ - # Copy dictionary structure - prior_container = deepcopy(params) - # Set all values to zero - prior_container = jax.tree_util.tree_map(lambda _: status, prior_container) - return prior_container - - -def _stop_grad(param: Dict, trainable: Dict) -> Dict: - """ - When taking a gradient, we want to stop the gradient from flowing through a - parameter if it is not trainable. This is achieved using the model's - dictionary of parameters and the corresponding trainability status. - - Args: - param (Dict): The parameter set for which trainable statuses should be - derived from. - trainable (Dict): A boolean value denoting the training status the `param`. - - Returns: - Dict: The gradient is stopped for non-trainable parameters. - """ - return jax.lax.cond(trainable, lambda x: x, jax.lax.stop_gradient, param) - - -def trainable_params(params: Dict, trainables: Dict) -> Dict: - """ - Stop the gradients flowing through parameters whose trainable status is - False. - - Args: - params (Dict): The parameter set for which trainable statuses should - be derived from. - trainables (Dict): A dictionary of boolean trainability statuses. The - dictionary is equal in structure to the input params dictionary. - - Returns: - Dict: A dictionary parameters. The dictionary is equal in structure to - the input params dictionary. - """ - return jax.tree_util.tree_map( - lambda param, trainable: _stop_grad(param, trainable), params, trainables - ) - - -__all__ = [ - "ParameterState", - "initialise", - "recursive_items", - "recursive_complete", - "build_bijectors", - "constrain", - "unconstrain", - "log_density", - "copy_dict_structure", - "structure_priors", - "evaluate_priors", - "prior_checks", - "build_trainables", - "trainable_params", -] diff --git a/gpjax/scan.py b/gpjax/scan.py new file mode 100644 index 000000000..211964850 --- /dev/null +++ b/gpjax/scan.py @@ -0,0 +1,164 @@ +# Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Callable, List, Optional, Tuple, TypeVar, Any +from jax import lax +from jax.experimental import host_callback as hcb +from tqdm.auto import trange + +import jax.tree_util as jtu +import jax +import jax.numpy as jnp + +Carry = TypeVar("Carry") +X = TypeVar("X") +Y = TypeVar("Y") + + +def _callback(cond: bool, func: Callable, *args: Any) -> None: + """Callback a function for a given argument if a condition is true. + + Args: + cond (bool): The condition. + func (Callable): The function to call. + *args (Any): The arguments to pass to the function. + """ + + # lax.cond requires a result, so we use a dummy result. + _dummy_result = 0 + + def _do_callback(_) -> int: + """Perform the callback.""" + return hcb.id_tap(func, *args, result=_dummy_result) + + def _not_callback(_) -> int: + """Do nothing.""" + return _dummy_result + + _ = lax.cond(cond, _do_callback, _not_callback, operand=None) + + +def vscan( + f: Callable[[Carry, X], Tuple[Carry, Y]], + init: Carry, + xs: X, + length: Optional[int] = None, + reverse: Optional[bool] = False, + unroll: Optional[int] = 1, + log_rate: Optional[int] = 10, + log_value: Optional[bool] = True, +) -> Tuple[Carry, List[Y]]: + """Scan with verbose output. + + This is based on code from the excellent blog post: + https://www.jeremiecoullon.com/2021/01/29/jax_progress_bar/. + + Example: + >>> def f(carry, x): + ... return carry + x, carry + x + >>> init = 0 + >>> xs = jnp.arange(10) + >>> vscan(f, init, xs) + (45, DeviceArray([ 0, 1, 3, 6, 10, 15, 21, 28, 36, 45], dtype=int32)) + + Args: + f (Callable[[Carry, X], Tuple[Carry, Y]]): A function that takes in a carry and + an input and returns a tuple of a new carry and an output. + init (Carry): The initial carry. + xs (X): The inputs. + length (Optional[int]): The length of the inputs. If None, then the length of + the inputs is inferred. + reverse (bool): Whether to scan in reverse. + unroll (int): The number of iterations to unroll. + log_rate (int): The rate at which to log the progress bar. + log_value (bool): Whether to log the value of the objective function. + + Returns: + Tuple[Carry, List[Y]]: A tuple of the final carry and the outputs. + """ + + # TODO: Scope out lower level API for jax.lax.scan, to avoid the need for finding + # the length of the inputs / check inputs. + # TODO: Scope out lower level API for tqdm, for more control over the progress bar. + # Need to check this. + _xs_flat = jtu.tree_leaves(xs) + _length = length if length is not None else len(_xs_flat[0]) + _iter_nums = jnp.arange(_length) + _remainder = _length % log_rate + + _progress_bar = trange(_length) + _progress_bar.set_description("Compiling...", refresh=True) + + def _set_running(args: Any, transform: Any) -> None: + """Set the tqdm progress bar to running.""" + _progress_bar.set_description("Running", refresh=False) + + def _update_tqdm(args: Any, transform: Any) -> None: + """Update the tqdm progress bar with the latest objective value.""" + _value, _iter_num = args + _progress_bar.update(_iter_num) + + if log_value and _value is not None: + _progress_bar.set_postfix({"Value": f"{_value: .2f}"}) + + def _close_tqdm(args: Any, transform: Any) -> None: + """Close the tqdm progress bar.""" + _progress_bar.close() + + def _body_fun(carry: Carry, iter_num_and_x: Tuple[int, X]) -> Tuple[Carry, Y]: + + # Unpack iter_num and x. + iter_num, x = iter_num_and_x + + # Compute body function. + carry, y = f(carry, x) + + # Conditions for iteration number. + _is_first: bool = iter_num == 0 + _is_multiple: bool = (iter_num % log_rate == 0) & ( + iter_num != _length - _remainder + ) + _is_remainder: bool = iter_num == _length - _remainder + _is_last: bool = iter_num == _length - 1 + + # Update progress bar, if first of log_rate. + _callback(_is_first, _set_running, (y, log_rate)) + + # Update progress bar, if multiple of log_rate. + _callback(_is_multiple, _update_tqdm, (y, log_rate)) + + # Update progress bar, if remainder. + _callback(_is_remainder, _update_tqdm, (y, _remainder)) + + # Close progress bar, if last iteration. + _callback(_is_last, _close_tqdm, (y, None)) + + return carry, y + + carry, ys = jax.lax.scan( + _body_fun, + init, + (_iter_nums, xs), + length=length, + reverse=reverse, + unroll=unroll, + ) + + return carry, ys + + +__all__ = [ + "vscan", +] diff --git a/gpjax/utils.py b/gpjax/utils.py deleted file mode 100644 index 254c76ddc..000000000 --- a/gpjax/utils.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright 2022 The GPJax Contributors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import deprecation -import jaxutils - -depreciate = deprecation.deprecated( - deprecated_in="0.5.6", - removed_in="0.6.0", - details="Use method from jaxutils.config instead.", -) - - -concat_dictionaries = depreciate(jaxutils.dict.concat_dictionaries) -merge_dictionaries = depreciate(jaxutils.dict.merge_dictionaries) -sort_dictionary = depreciate(jaxutils.dict.sort_dictionary) -dict_array_coercion = depreciate(jaxutils.dict.dict_array_coercion) - - -__all__ = [ - "concat_dictionaries", - "merge_dictionaries", - "sort_dictionary", - "dict_array_coercion", -] diff --git a/gpjax/variational_families.py b/gpjax/variational_families.py index 7255e6cc9..a8208f1a9 100644 --- a/gpjax/variational_families.py +++ b/gpjax/variational_families.py @@ -14,30 +14,19 @@ # ============================================================================== import abc -from typing import Any, Callable, Dict - -import deprecation -import distrax as dx +from typing import Any import jax.numpy as jnp import jax.scipy as jsp -from jax.random import KeyArray from jaxtyping import Array, Float -from .linops import identity -from jaxutils import Dataset -from .linops import ( - DenseLinearOperator, - LowerTriangularLinearOperator -) from mytree import Mytree, param_field from simple_pytree import static_field -from .config import get_global_config +from .dataset import Dataset from .gaussian_distribution import GaussianDistribution -from .gps import Prior -from .likelihoods import AbstractLikelihood, Gaussian +from .gps import AbstractPosterior +from .likelihoods import Gaussian from .linops import DenseLinearOperator, LowerTriangularLinearOperator, identity -from .utils import concat_dictionaries from .gaussian_distribution import GaussianDistribution from dataclasses import dataclass @@ -50,6 +39,7 @@ class AbstractVariationalFamily(Mytree): Abstract base class used to represent families of distributions that can be used within variational inference. """ + posterior: AbstractPosterior def __call__(self, *args: Any, **kwargs: Any) -> GaussianDistribution: """For a given set of parameters, compute the latent function's prediction @@ -85,8 +75,8 @@ def predict(self, *args: Any, **kwargs: Any) -> GaussianDistribution: @dataclass class AbstractVariationalGaussian(AbstractVariationalFamily): """The variational Gaussian family of probability distributions.""" - prior: Prior inducing_inputs: Float[Array, "N D"] + jitter: Float[Array, "1"] = static_field(1e-6) @property def num_inducing(self) -> int: @@ -106,7 +96,7 @@ class VariationalGaussian(AbstractVariationalGaussian): """ variational_mean: Float[Array, "N 1"] = param_field(None) variational_root_covariance: Float[Array, "N N"] = param_field(None, bijector=tfb.FillScaleTriL(diag_shift=jnp.array(1e-6))) - jitter: Float[Array, "1"] = static_field(1e-6) + def __post_init__(self) -> None: if self.variational_mean is None: @@ -130,8 +120,6 @@ def prior_kl(self) -> Float[Array, "1"]: approximation and the GP prior. """ - jitter = get_global_config()["jitter"] - # Unpack variational parameters mu = self.variational_mean sqrt = self.variational_root_covariance @@ -139,12 +127,12 @@ def prior_kl(self) -> Float[Array, "1"]: m = self.num_inducing # Unpack mean function and kernel - mean_function = self.prior.mean_function - kernel = self.prior.kernel + mean_function = self.posterior.prior.mean_function + kernel = self.posterior.prior.kernel μz = mean_function(z) Kzz = kernel.gram(z) - Kzz += identity(m) * jitter + Kzz += identity(m) * self.jitter sqrt = LowerTriangularLinearOperator.from_dense(sqrt) S = DenseLinearOperator.from_root(sqrt) @@ -169,7 +157,6 @@ def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution: Returns: GaussianDistribution: The predictive distribution of the low-rank GP at the test inputs. """ - jitter = get_global_config()["jitter"] # Unpack variational parameters mu = self.variational_mean @@ -178,11 +165,11 @@ def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution: m = self.num_inducing # Unpack mean function and kernel - mean_function = self.prior.mean_function - kernel = self.prior.kernel + mean_function = self.posterior.prior.mean_function + kernel = self.posterior.prior.kernel Kzz = kernel.gram(z) - Kzz += identity(m) * jitter + Kzz += identity(m) * self.jitter Lz = Kzz.to_root() μz = mean_function(z) @@ -211,7 +198,7 @@ def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution: - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt) + jnp.matmul(Ktz_Kzz_inv_sqrt, Ktz_Kzz_inv_sqrt.T) ) - covariance += identity(n_test) * jitter + covariance += identity(n_test) * self.jitter return GaussianDistribution( loc=jnp.atleast_1d(mean.squeeze()), scale=covariance @@ -265,7 +252,6 @@ def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution: Returns: GaussianDistribution: The predictive distribution of the low-rank GP at the test inputs. """ - jitter = get_global_config()["jitter"] # Unpack variational parameters mu = self.variational_mean @@ -274,11 +260,11 @@ def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution: m = self.num_inducing # Unpack mean function and kernel - mean_function = self.prior.mean_function - kernel = self.prior.kernel + mean_function = self.posterior.prior.mean_function + kernel = self.posterior.prior.kernel Kzz = kernel.gram(z) - Kzz += identity(m) * jitter + Kzz += identity(m) * self.jitter Lz = Kzz.to_root() # Unpack test inputs @@ -303,7 +289,7 @@ def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution: - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt) + jnp.matmul(Ktz_Lz_invT_sqrt, Ktz_Lz_invT_sqrt.T) ) - covariance += identity(n_test) * jitter + covariance += identity(n_test) * self.jitter return GaussianDistribution( loc=jnp.atleast_1d(mean.squeeze()), scale=covariance @@ -339,7 +325,6 @@ def prior_kl(self) -> Float[Array, "1"]: Returns: Float[Array, "1"]: The KL-divergence between our variational approximation and the GP prior. """ - jitter = get_global_config()["jitter"] # Unpack variational parameters natural_vector = self.natural_vector @@ -348,12 +333,12 @@ def prior_kl(self) -> Float[Array, "1"]: m = self.num_inducing # Unpack mean function and kernel - mean_function = self.prior.mean_function - kernel = self.prior.kernel + mean_function = self.posterior.prior.mean_function + kernel = self.posterior.prior.kernel # S⁻¹ = -2θ₂ S_inv = -2 * natural_matrix - S_inv += jnp.eye(m) * jitter + S_inv += jnp.eye(m) * self.jitter # Compute L⁻¹, where LLᵀ = S, via a trick found in the NumPyro source code and https://nbviewer.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril: sqrt_inv = jnp.swapaxes( @@ -372,7 +357,7 @@ def prior_kl(self) -> Float[Array, "1"]: μz = mean_function(z) Kzz = kernel.gram(z) - Kzz += identity(m) * jitter + Kzz += identity(m) * self.jitter qu = GaussianDistribution(loc=jnp.atleast_1d(mu.squeeze()), scale=S) pu = GaussianDistribution(loc=jnp.atleast_1d(μz.squeeze()), scale=Kzz) @@ -391,7 +376,6 @@ def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution: Returns: GaussianDistribution: A function that accepts a set of test points and will return the predictive distribution at those points. """ - jitter = get_global_config()["jitter"] # Unpack variational parameters natural_vector = self.natural_vector @@ -400,12 +384,12 @@ def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution: m = self.num_inducing # Unpack mean function and kernel - mean_function = self.prior.mean_function - kernel = self.prior.kernel + mean_function = self.posterior.prior.mean_function + kernel = self.posterior.prior.kernel # S⁻¹ = -2θ₂ S_inv = -2 * natural_matrix - S_inv += jnp.eye(m) * jitter + S_inv += jnp.eye(m) * self.jitter # Compute L⁻¹, where LLᵀ = S, via a trick found in the NumPyro source code and https://nbviewer.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril: sqrt_inv = jnp.swapaxes( @@ -422,7 +406,7 @@ def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution: mu = jnp.matmul(S, natural_vector) Kzz = kernel.gram(z) - Kzz += identity(m) * jitter + Kzz += identity(m) * self.jitter Lz = Kzz.to_root() μz = mean_function(z) @@ -451,7 +435,7 @@ def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution: - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt) + jnp.matmul(Ktz_Kzz_inv_L, Ktz_Kzz_inv_L.T) ) - covariance += identity(n_test) * jitter + covariance += identity(n_test) * self.jitter return GaussianDistribution( loc=jnp.atleast_1d(mean.squeeze()), scale=covariance @@ -491,7 +475,6 @@ def prior_kl(self) -> Float[Array, "1"]: Returns: Float[Array, "1"]: The KL-divergence between our variational approximation and the GP prior. """ - jitter = get_global_config()["jitter"] # Unpack variational parameters expectation_vector = self.expectation_vector @@ -500,8 +483,8 @@ def prior_kl(self) -> Float[Array, "1"]: m = self.num_inducing # Unpack mean function and kernel - mean_function = self.prior.mean_function - kernel = self.prior.kernel + mean_function = self.posterior.prior.mean_function + kernel = self.posterior.prior.kernel # μ = η₁ mu = expectation_vector @@ -509,11 +492,11 @@ def prior_kl(self) -> Float[Array, "1"]: # S = η₂ - η₁ η₁ᵀ S = expectation_matrix - jnp.outer(mu, mu) S = DenseLinearOperator(S) - S += identity(m) * jitter + S += identity(m) * self.jitter μz = mean_function(z) Kzz = kernel.gram(z) - Kzz += identity(m) * jitter + Kzz += identity(m) * self.jitter qu = GaussianDistribution(loc=jnp.atleast_1d(mu.squeeze()), scale=S) pu = GaussianDistribution(loc=jnp.atleast_1d(μz.squeeze()), scale=Kzz) @@ -532,7 +515,6 @@ def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution: Returns: GaussianDistribution: The predictive distribution of the GP at the test inputs t. """ - jitter = get_global_config()["jitter"] # Unpack variational parameters expectation_vector = self.expectation_vector @@ -541,8 +523,8 @@ def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution: m = self.num_inducing # Unpack mean function and kernel - mean_function = self.prior.mean_function - kernel = self.prior.kernel + mean_function = self.posterior.prior.mean_function + kernel = self.posterior.prior.kernel # μ = η₁ mu = expectation_vector @@ -550,13 +532,13 @@ def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution: # S = η₂ - η₁ η₁ᵀ S = expectation_matrix - jnp.matmul(mu, mu.T) S = DenseLinearOperator(S) - S += identity(m) * jitter + S += identity(m) * self.jitter # S = sqrt sqrtᵀ sqrt = S.to_root().to_dense() Kzz = kernel.gram(z) - Kzz += identity(m) * jitter + Kzz += identity(m) * self.jitter Lz = Kzz.to_root() μz = mean_function(z) @@ -585,7 +567,7 @@ def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution: - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt) + jnp.matmul(Ktz_Kzz_inv_sqrt, Ktz_Kzz_inv_sqrt.T) ) - covariance += identity(n_test) * jitter + covariance += identity(n_test) * self.jitter return GaussianDistribution( loc=jnp.atleast_1d(mean.squeeze()), scale=covariance @@ -595,10 +577,9 @@ def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution: class CollapsedVariationalGaussian(AbstractVariationalGaussian): """Collapsed variational Gaussian family of probability distributions. The key reference is Titsias, (2009) - Variational Learning of Inducing Variables in Sparse Gaussian Processes.""" - likelihood: AbstractLikelihood def __post_init__(self): - if not isinstance(self.likelihood, Gaussian): + if not isinstance(self.posterior.likelihood, Gaussian): raise TypeError("Likelihood must be Gaussian.") def predict(self, test_inputs: Float[Array, "N D"], train_data: Dataset) -> GaussianDistribution: @@ -610,7 +591,6 @@ def predict(self, test_inputs: Float[Array, "N D"], train_data: Dataset) -> Gaus Returns: GaussianDistribution: The predictive distribution of the collapsed variational Gaussian process at the test inputs t. """ - jitter = get_global_config()["jitter"] # Unpack test inputs t, n_test = test_inputs, test_inputs.shape[0] @@ -619,17 +599,17 @@ def predict(self, test_inputs: Float[Array, "N D"], train_data: Dataset) -> Gaus x, y = train_data.X, train_data.y # Unpack variational parameters - noise = self.likelihood.obs_noise + noise = self.posterior.likelihood.obs_noise z = self.inducing_inputs m = self.num_inducing # Unpack mean function and kernel - mean_function = self.prior.mean_function - kernel = self.prior.kernel + mean_function = self.posterior.prior.mean_function + kernel = self.posterior.prior.kernel Kzx = kernel.cross_covariance(z, x) Kzz = kernel.gram(z) - Kzz += identity(m) * jitter + Kzz += identity(m) * self.jitter # Lz Lzᵀ = Kzz Lz = Kzz.to_root() @@ -676,7 +656,7 @@ def predict(self, test_inputs: Float[Array, "N D"], train_data: Dataset) -> Gaus - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt) + jnp.matmul(L_inv_Lz_inv_Kzt.T, L_inv_Lz_inv_Kzt) ) - covariance += identity(n_test) * jitter + covariance += identity(n_test) * self.jitter return GaussianDistribution( loc=jnp.atleast_1d(mean.squeeze()), scale=covariance diff --git a/gpjax/variational_inference.py b/gpjax/variational_inference.py deleted file mode 100644 index 8e2166e99..000000000 --- a/gpjax/variational_inference.py +++ /dev/null @@ -1,283 +0,0 @@ -# Copyright 2022 The GPJax Contributors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import abc -from typing import Callable, Dict - -import deprecation -import jax.numpy as jnp -import jax.scipy as jsp -from jax import vmap -from jax.random import KeyArray -from jaxtyping import Array, Float -from jaxutils import Dataset, PyTree - -from .config import get_global_config -from .gps import AbstractPosterior -from .likelihoods import Gaussian -from .linops import identity -from .quadrature import gauss_hermite_quadrature -from .utils import concat_dictionaries -from .variational_families import ( - AbstractVariationalFamily, - CollapsedVariationalGaussian, -) - - -class AbstractVariationalInference(PyTree): - """A base class for inference and training of variational families against an exact posterior""" - - def __init__( - self, - posterior: AbstractPosterior, - variational_family: AbstractVariationalFamily, - ) -> None: - """Initialise the variational inference module. - - Args: - posterior (AbstractPosterior): The exact posterior distribution. - variational_family (AbstractVariationalFamily): The variational family to be trained. - """ - self.posterior = posterior - self.prior = self.posterior.prior - self.likelihood = self.posterior.likelihood - self.variational_family = variational_family - - def init_params(self, key: KeyArray) -> Dict: - """Construct the parameter set used within the variational scheme adopted.""" - hyperparams = concat_dictionaries( - {"likelihood": self.posterior.likelihood.init_params(key)}, - self.variational_family.init_params(key), - ) - return hyperparams - - @deprecation.deprecated( - deprecated_in="0.5.7", - removed_in="0.6.0", - details="Use the ``init_params`` method for parameter initialisation.", - ) - def _initialise_params(self, key: KeyArray) -> Dict: - """Deprecated method for initialising the GP's parameters. Succeeded by ``init_params``.""" - return self.init_params(key) - - @abc.abstractmethod - def elbo( - self, - train_data: Dataset, - ) -> Callable[[Dict], Float[Array, "1"]]: - """Placeholder method for computing the evidence lower bound function (ELBO), given a training dataset and a set of transformations that map each parameter onto the entire real line. - - Args: - train_data (Dataset): The training dataset for which the ELBO is to be computed. - - Returns: - Callable[[Array], Array]: A function that computes the ELBO given a set of parameters. - """ - raise NotImplementedError - - -class StochasticVI(AbstractVariationalInference): - """Stochastic Variational inference training module. The key reference is Hensman et. al., (2013) - Gaussian processes for big data.""" - - def elbo( - self, train_data: Dataset, negative: bool = False - ) -> Callable[[Float[Array, "N D"]], Float[Array, "1"]]: - """Compute the evidence lower bound under this model. In short, this requires evaluating the expectation of the model's log-likelihood under the variational approximation. To this, we sum the KL divergence from the variational posterior to the prior. When batching occurs, the result is scaled by the batch size relative to the full dataset size. - - Args: - train_data (Dataset): The training data for which we should maximise the ELBO with respect to. - negative (bool, optional): Whether or not the resultant elbo function should be negative. For gradient descent where we minimise our objective function this argument should be true as minimisation of the negative corresponds to maximisation of the ELBO. Defaults to False. - - Returns: - Callable[[Dict, Dataset], Array]: A callable function that accepts a current parameter estimate and batch of data for which gradients should be computed. - """ - - # Constant for whether or not to negate the elbo for optimisation purposes - constant = jnp.array(-1.0) if negative else jnp.array(1.0) - - def elbo_fn(params: Dict, batch: Dataset) -> Float[Array, "1"]: - # KL[q(f(·)) || p(f(·))] - kl = self.variational_family.prior_kl(params) - - # ∫[log(p(y|f(·))) q(f(·))] df(·) - var_exp = self.variational_expectation(params, batch) - - # For batch size b, we compute n/b * Σᵢ[ ∫log(p(y|f(xᵢ))) q(f(xᵢ)) df(xᵢ)] - KL[q(f(·)) || p(f(·))] - return constant * (jnp.sum(var_exp) * train_data.n / batch.n - kl) - - return elbo_fn - - def variational_expectation( - self, params: Dict, batch: Dataset - ) -> Float[Array, "N 1"]: - """Compute the expectation of our model's log-likelihood under our variational distribution. Batching can be done here to speed up computation. - - Args: - params (Dict): The set of parameters that induce our variational approximation. - batch (Dataset): The data batch for which the expectation should be computed for. - - Returns: - Array: The expectation of the model's log-likelihood under our variational distribution. - """ - - # Unpack training batch - x, y = batch.X, batch.y - - # Variational distribution q(f(·)) = N(f(·); μ(·), Σ(·, ·)) - q = self.variational_family(params) - - # Compute variational mean, μ(x), and variance, √diag(Σ(x, x)), at training inputs, x - def q_moments(x): - qx = q(x) - return qx.mean(), qx.variance() - - mean, variance = vmap(q_moments)(x[:, None]) - - # log(p(y|f(x))) - link_function = self.likelihood.link_function - log_prob = vmap(lambda f, y: link_function(params["likelihood"], f).log_prob(y)) - - # ≈ ∫[log(p(y|f(x))) q(f(x))] df(x) - expectation = gauss_hermite_quadrature(log_prob, mean, jnp.sqrt(variance), y=y) - - return expectation - - -class CollapsedVI(AbstractVariationalInference): - """Collapsed variational inference for a sparse Gaussian process regression model. - The key reference is Titsias, (2009) - Variational Learning of Inducing Variables in Sparse Gaussian Processes.""" - - def __init__( - self, - posterior: AbstractPosterior, - variational_family: AbstractVariationalFamily, - ) -> None: - """Initialise the variational inference module. - - Args: - posterior (AbstractPosterior): The exact posterior distribution. - variational_family (AbstractVariationalFamily): The variational family to be trained. - """ - - if not isinstance(posterior.likelihood, Gaussian): - raise TypeError("Likelihood must be Gaussian.") - - if not isinstance(variational_family, CollapsedVariationalGaussian): - raise TypeError("Variational family must be CollapsedVariationalGaussian.") - - super().__init__(posterior, variational_family) - - def elbo( - self, train_data: Dataset, negative: bool = False - ) -> Callable[[Dict], Float[Array, "1"]]: - """Compute the evidence lower bound under this model. In short, this requires evaluating the expectation of the model's log-likelihood under the variational approximation. To this, we sum the KL divergence from the variational posterior to the prior. When batching occurs, the result is scaled by the batch size relative to the full dataset size. - - Args: - train_data (Dataset): The training data for which we should maximise the ELBO with respect to. - negative (bool, optional): Whether or not the resultant elbo function should be negative. For gradient descent where we minimise our objective function this argument should be true as minimisation of the negative corresponds to maximisation of the ELBO. Defaults to False. - - Returns: - Callable[[Dict, Dataset], Array]: A callable function that accepts a current parameter estimate for which gradients should be computed. - """ - - # Unpack training data - x, y, n = train_data.X, train_data.y, train_data.n - - # Unpack mean function and kernel - mean_function = self.prior.mean_function - kernel = self.prior.kernel - - m = self.variational_family.num_inducing - jitter = get_global_config()["jitter"] - - # Constant for whether or not to negate the elbo for optimisation purposes - constant = jnp.array(-1.0) if negative else jnp.array(1.0) - - def elbo_fn(params: Dict) -> Float[Array, "1"]: - noise = params["likelihood"]["obs_noise"] - z = params["variational_family"]["inducing_inputs"] - Kzz = kernel.gram(params["kernel"], z) - Kzz += identity(m) * jitter - Kzx = kernel.cross_covariance(params["kernel"], z, x) - Kxx_diag = vmap(kernel, in_axes=(None, 0, 0))(params["kernel"], x, x) - μx = mean_function(params["mean_function"], x) - - Lz = Kzz.to_root() - - # Notation and derivation: - # - # Let Q = KxzKzz⁻¹Kzx, we must compute the log normal pdf: - # - # log N(y; μx, σ²I + Q) = -nπ - n/2 log|σ²I + Q| - 1/2 (y - μx)ᵀ (σ²I + Q)⁻¹ (y - μx). - # - # The log determinant |σ²I + Q| is computed via applying the matrix determinant lemma - # - # |σ²I + Q| = log|σ²I| + log|I + Lz⁻¹ Kzx (σ²I)⁻¹ Kxz Lz⁻¹| = log(σ²) + log|B|, - # - # with B = I + AAᵀ and A = Lz⁻¹ Kzx / σ. - # - # Similarly we apply matrix inversion lemma to invert σ²I + Q - # - # (σ²I + Q)⁻¹ = (Iσ²)⁻¹ - (Iσ²)⁻¹ Kxz Lz⁻ᵀ (I + Lz⁻¹ Kzx (Iσ²)⁻¹ Kxz Lz⁻ᵀ )⁻¹ Lz⁻¹ Kzx (Iσ²)⁻¹ - # = (Iσ²)⁻¹ - (Iσ²)⁻¹ σAᵀ (I + σA (Iσ²)⁻¹ σAᵀ)⁻¹ σA (Iσ²)⁻¹ - # = I/σ² - Aᵀ B⁻¹ A/σ², - # - # giving the quadratic term as - # - # (y - μx)ᵀ (σ²I + Q)⁻¹ (y - μx) = [(y - μx)ᵀ(y - µx) - (y - μx)ᵀ Aᵀ B⁻¹ A (y - μx)]/σ², - # - # with A and B defined as above. - - A = Lz.solve(Kzx) / jnp.sqrt(noise) - - # AAᵀ - AAT = jnp.matmul(A, A.T) - - # B = I + AAᵀ - B = jnp.eye(m) + AAT - - # LLᵀ = I + AAᵀ - L = jnp.linalg.cholesky(B) - - # log|B| = 2 trace(log|L|) = 2 Σᵢ log Lᵢᵢ [since |B| = |LLᵀ| = |L|² => log|B| = 2 log|L|, and |L| = Πᵢ Lᵢᵢ] - log_det_B = 2.0 * jnp.sum(jnp.log(jnp.diagonal(L))) - - diff = y - μx - - # L⁻¹ A (y - μx) - L_inv_A_diff = jsp.linalg.solve_triangular( - L, jnp.matmul(A, diff), lower=True - ) - - # (y - μx)ᵀ (Iσ² + Q)⁻¹ (y - μx) - quad = (jnp.sum(diff ** 2) - jnp.sum(L_inv_A_diff ** 2)) / noise - - # 2 * log N(y; μx, Iσ² + Q) - two_log_prob = -n * jnp.log(2.0 * jnp.pi * noise) - log_det_B - quad - - # 1/σ² tr(Kxx - Q) [Trace law tr(AB) = tr(BA) => tr(KxzKzz⁻¹Kzx) = tr(KxzLz⁻ᵀLz⁻¹Kzx) = tr(Lz⁻¹Kzx KxzLz⁻ᵀ) = trace(σ²AAᵀ)] - two_trace = jnp.sum(Kxx_diag) / noise - jnp.trace(AAT) - - # log N(y; μx, Iσ² + KxzKzz⁻¹Kzx) - 1/2σ² tr(Kxx - KxzKzz⁻¹Kzx) - return constant * (two_log_prob - two_trace).squeeze() / 2.0 - - return elbo_fn - - -__all__ = [ - "AbstractVariationalInference", - "StochasticVI", - "CollapsedVI", -] diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 58a768825..25efb81d1 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -15,7 +15,7 @@ import jax.numpy as jnp import pytest -from jaxutils.dataset import Dataset +from gpjax.dataset import Dataset @pytest.mark.parametrize("n", [1, 10]) diff --git a/tests/test_fit.py b/tests/test_fit.py index ecf989f65..7ab04b314 100644 --- a/tests/test_fit.py +++ b/tests/test_fit.py @@ -13,11 +13,13 @@ # limitations under the License. # ============================================================================== -from jaxutils.dataset import Dataset -from jaxutils.fit import fit -from jaxutils.bijectors import Identity -from jaxutils.module import param, Module -from jaxutils.objective import Objective +from dataclasses import dataclass + +from gpjax.dataset import Dataset +from gpjax.fit import fit +from gpjax.parameters.bijectors import Identity +from gpjax.parameters import param_field, Module +from gpjax.objectives import AbstractObjective import jax.numpy as jnp import jax.random as jr @@ -31,9 +33,10 @@ def test_simple_linear_model(): D = Dataset(X, y) # (2) Define your model: + @dataclass class LinearModel(Module): - weight: float = param(Identity) - bias: float = param(Identity) + weight: float = param_field(bijector=Identity) + bias: float = param_field(bijector=Identity) def __call__(self, x): return self.weight * x + self.bias @@ -41,8 +44,9 @@ def __call__(self, x): model = LinearModel(weight=1.0, bias=1.0) # (3) Define your loss function: - class MeanSqaureError(Objective): - def evaluate(self, model: LinearModel, train_data: Dataset) -> float: + @dataclass + class MeanSqaureError(AbstractObjective): + def __call__(self, model: LinearModel, train_data: Dataset) -> float: return jnp.mean((train_data.y - model(train_data.X)) ** 2) loss = MeanSqaureError() diff --git a/tests/test_gps.py b/tests/test_gps.py index df90e0e8c..0ac5e7835 100644 --- a/tests/test_gps.py +++ b/tests/test_gps.py @@ -13,27 +13,25 @@ # limitations under the License. # ============================================================================== -import typing as tp - import distrax as dx -import jax import jax.numpy as jnp import jax.random as jr import pytest from jax.config import config -from gpjax import Dataset, initialise + from gpjax.gps import ( - AbstractPosterior, AbstractPrior, + AbstractPosterior, ConjugatePosterior, NonConjugatePosterior, Prior, construct_posterior, ) -from gpjax.kernels import RBF, Matern12, Matern32, Matern52 +from gpjax.kernels import RBF +from gpjax.mean_functions import Constant from gpjax.likelihoods import Bernoulli, Gaussian -from gpjax.parameters import ParameterState +from gpjax.dataset import Dataset # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) @@ -42,16 +40,13 @@ @pytest.mark.parametrize("num_datapoints", [1, 10]) def test_prior(num_datapoints): - p = Prior(kernel=RBF()) - parameter_state = initialise(p, jr.PRNGKey(123)) - params, _, _ = parameter_state.unpack() + p = Prior(mean_function=Constant(), kernel=RBF()) + assert isinstance(p, Prior) assert isinstance(p, AbstractPrior) - prior_rv_fn = p(params) - assert isinstance(prior_rv_fn, tp.Callable) x = jnp.linspace(-3.0, 3.0, num_datapoints).reshape(-1, 1) - predictive_dist = prior_rv_fn(x) + predictive_dist = p(x) assert isinstance(predictive_dist, dx.Distribution) mu = predictive_dist.mean() sigma = predictive_dist.covariance() @@ -60,7 +55,8 @@ def test_prior(num_datapoints): @pytest.mark.parametrize("num_datapoints", [1, 2, 10]) -def test_conjugate_posterior(num_datapoints): +@pytest.mark.parametrize("jit_compile", [True, False]) +def test_conjugate_posterior(num_datapoints, jit_compile): key = jr.PRNGKey(123) x = jnp.sort( jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(num_datapoints, 1)), @@ -68,33 +64,19 @@ def test_conjugate_posterior(num_datapoints): ) y = jnp.sin(x) + jr.normal(key=key, shape=x.shape) * 0.1 D = Dataset(X=x, y=y) + # Initialisation - p = Prior(kernel=RBF()) + p = Prior(mean_function=Constant(), kernel=RBF()) lik = Gaussian(num_datapoints=num_datapoints) post = p * lik assert isinstance(post, ConjugatePosterior) - assert isinstance(post, AbstractPrior) - assert isinstance(p, AbstractPrior) post2 = lik * p assert isinstance(post2, ConjugatePosterior) - assert isinstance(post2, AbstractPrior) - - parameter_state = initialise(post, key) - params, *_ = parameter_state.unpack() - - # Marginal likelihood - mll = post.marginal_log_likelihood(train_data=D) - objective_val = mll(params) - assert isinstance(objective_val, jax.Array) - assert objective_val.shape == () # Prediction - predictive_dist_fn = post(params, D) - assert isinstance(predictive_dist_fn, tp.Callable) - x = jnp.linspace(-3.0, 3.0, num_datapoints).reshape(-1, 1) - predictive_dist = predictive_dist_fn(x) + predictive_dist = post(x, D) assert isinstance(predictive_dist, dx.Distribution) mu = predictive_dist.mean() @@ -102,10 +84,21 @@ def test_conjugate_posterior(num_datapoints): assert mu.shape == (num_datapoints,) assert sigma.shape == (num_datapoints, num_datapoints) + # # Loss function + # loss_fn = post.loss_function() + # assert isinstance(loss_fn, AbstractObjective) + # assert isinstance(loss_fn, ConjugateMLL) + # if jit_compile: + # loss_fn = jax.jit(loss_fn) + # objective_val = loss_fn(params=params, data=D) + # assert isinstance(objective_val, jax.Array) + # assert objective_val.shape == () + @pytest.mark.parametrize("num_datapoints", [1, 2, 10]) @pytest.mark.parametrize("likel", NonConjugateLikelihoods) -def test_nonconjugate_posterior(num_datapoints, likel): +@pytest.mark.parametrize("jit_compile", [True, False]) +def test_nonconjugate_posterior(num_datapoints, likel, jit_compile): key = jr.PRNGKey(123) x = jnp.sort( jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(num_datapoints, 1)), @@ -114,98 +107,56 @@ def test_nonconjugate_posterior(num_datapoints, likel): y = 0.5 * jnp.sign(jnp.cos(3 * x + jr.normal(key, shape=x.shape) * 0.05)) + 0.5 D = Dataset(X=x, y=y) # Initialisation - p = Prior(kernel=RBF()) + p = Prior(mean_function=Constant(), kernel=RBF()) lik = likel(num_datapoints=num_datapoints) post = p * lik assert isinstance(post, NonConjugatePosterior) - assert isinstance(post, AbstractPrior) - assert isinstance(p, AbstractPrior) - - parameter_state = initialise(post, key) - params, _, _ = parameter_state.unpack() - assert isinstance(parameter_state, ParameterState) - - # Marginal likelihood - mll = post.marginal_log_likelihood(train_data=D) - objective_val = mll(params) - assert isinstance(objective_val, jax.Array) - assert objective_val.shape == () + assert (post.latent == jnp.zeros((num_datapoints, 1))).all() # Prediction - predictive_dist_fn = post(params, D) - assert isinstance(predictive_dist_fn, tp.Callable) - x = jnp.linspace(-3.0, 3.0, num_datapoints).reshape(-1, 1) - predictive_dist = predictive_dist_fn(x) + predictive_dist = post(x, D) assert isinstance(predictive_dist, dx.Distribution) mu = predictive_dist.mean() sigma = predictive_dist.covariance() assert mu.shape == (num_datapoints,) assert sigma.shape == (num_datapoints, num_datapoints) - - -@pytest.mark.parametrize("num_datapoints", [1, 10]) -@pytest.mark.parametrize("lik", [Bernoulli, Gaussian]) -def test_param_construction(num_datapoints, lik): - p = Prior(kernel=RBF()) * lik(num_datapoints=num_datapoints) - parameter_state = initialise(p, jr.PRNGKey(123)) - params, _, _ = parameter_state.unpack() - - if isinstance(lik, Bernoulli): - assert sorted(list(params.keys())) == [ - "kernel", - "latent_fn", - "likelihood", - "mean_function", - ] - elif isinstance(lik, Gaussian): - assert sorted(list(params.keys())) == [ - "kernel", - "likelihood", - "mean_function", - ] - - -@pytest.mark.parametrize("lik", [Bernoulli, Gaussian]) -def test_abstract_posterior(lik): - pr = Prior(kernel=RBF()) - likelihood = lik(num_datapoints=10) - - with pytest.raises(TypeError): - _ = AbstractPosterior(pr, likelihood) - - class DummyPosterior(AbstractPosterior): - def predict(self): - pass - - dummy_post = DummyPosterior(pr, likelihood) - assert isinstance(dummy_post, AbstractPosterior) - assert dummy_post.likelihood == likelihood - assert dummy_post.prior == pr - - -@pytest.mark.parametrize("lik", [Bernoulli, Gaussian]) -def test_posterior_construct(lik): - pr = Prior(kernel=RBF()) - likelihood = lik(num_datapoints=10) - p1 = pr * likelihood - p2 = construct_posterior(prior=pr, likelihood=likelihood) - assert type(p1) == type(p2) - - -@pytest.mark.parametrize("kernel", [RBF(), Matern12(), Matern32(), Matern52()]) -def test_initialisation_override(kernel): - key = jr.PRNGKey(123) - override_params = {"lengthscale": jnp.array([0.5]), "variance": jnp.array([0.1])} - p = Prior(kernel=kernel) * Gaussian(num_datapoints=10) - parameter_state = initialise(p, key, kernel=override_params) - ds = parameter_state.unpack() - for d in ds: - assert "lengthscale" in d["kernel"].keys() - assert "variance" in d["kernel"].keys() - assert ds[0]["kernel"]["lengthscale"] == jnp.array([0.5]) - assert ds[0]["kernel"]["variance"] == jnp.array([0.1]) - - with pytest.raises(ValueError): - parameter_state = initialise(p, key, keernel=override_params) + + + # # Loss function + # loss_fn = post.loss_function() + # assert isinstance(loss_fn, AbstractObjective) + # assert isinstance(loss_fn, NonConjugateMLL) + # if jit_compile: + # loss_fn = jax.jit(loss_fn) + # objective_val = loss_fn(params=params, data=D) + # assert isinstance(objective_val, jax.Array) + # assert objective_val.shape == () + + +# @pytest.mark.parametrize("lik", [Bernoulli, Gaussian]) +# def test_abstract_posterior(lik): +# pr = Prior(kernel=RBF()) +# likelihood = lik(num_datapoints=10) + +# with pytest.raises(TypeError): +# _ = AbstractPosterior(pr, likelihood) + +# class DummyPosterior(AbstractPosterior): +# def predict(self): +# pass + +# dummy_post = DummyPosterior(pr, likelihood) +# assert isinstance(dummy_post, AbstractPosterior) +# assert dummy_post.likelihood == likelihood +# assert dummy_post.prior == pr + + +# @pytest.mark.parametrize("lik", [Bernoulli, Gaussian]) +# def test_posterior_construct(lik): +# pr = Prior(kernel=RBF()) +# likelihood = lik(num_datapoints=10) +# p1 = pr * likelihood +# p2 = construct_posterior(prior=pr, likelihood=likelihood) +# assert type(p1) == type(p2) \ No newline at end of file diff --git a/tests/test_natural_gradients.py b/tests/test_natural_gradients.py deleted file mode 100644 index 847a3148b..000000000 --- a/tests/test_natural_gradients.py +++ /dev/null @@ -1,263 +0,0 @@ -# Copyright 2022 The GPJax Contributors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import typing as tp - -import jax -import jax.numpy as jnp -import jax.random as jr -import pytest -from jax.config import config - -import gpjax as gpx -from gpjax.abstractions import get_batch -from gpjax.natural_gradients import ( - _expectation_elbo, - _rename_expectation_to_natural, - _rename_natural_to_expectation, - natural_gradients, - natural_to_expectation, -) -from gpjax.parameters import recursive_items - -# Enable Float64 for more stable matrix inversions. -config.update("jax_enable_x64", True) -key = jr.PRNGKey(123) - - -def get_data_and_gp(n_datapoints): - x = jnp.linspace(-5.0, 5.0, n_datapoints).reshape(-1, 1) - y = jnp.sin(x) + jr.normal(key=jr.PRNGKey(123), shape=x.shape) * 0.1 - D = gpx.Dataset(X=x, y=y) - - p = gpx.Prior(kernel=gpx.RBF()) - lik = gpx.Gaussian(num_datapoints=n_datapoints) - post = p * lik - return D, post, p - - -@pytest.mark.parametrize("dim", [1, 2, 3]) -def test_natural_to_expectation(dim): - - _, posterior, prior = get_data_and_gp(10) - - z = jnp.linspace(-5.0, 5.0, 5 * dim).reshape(-1, dim) - expectation_variational_family = ( - gpx.variational_families.ExpectationVariationalGaussian( - prior=prior, inducing_inputs=z - ) - ) - - natural_variational_family = gpx.variational_families.NaturalVariationalGaussian( - prior=prior, inducing_inputs=z - ) - - natural_svgp = gpx.StochasticVI( - posterior=posterior, variational_family=natural_variational_family - ) - expectation_svgp = gpx.StochasticVI( - posterior=posterior, variational_family=expectation_variational_family - ) - - key = jr.PRNGKey(123) - natural_params, *_ = gpx.initialise(natural_svgp, key).unpack() - expectation_params, *_ = gpx.initialise(expectation_svgp, key).unpack() - - expectation_params_test = natural_to_expectation(natural_params) - - assert ( - "expectation_vector" - in expectation_params_test["variational_family"]["moments"].keys() - ) - assert ( - "expectation_matrix" - in expectation_params_test["variational_family"]["moments"].keys() - ) - assert ( - expectation_params_test["variational_family"]["moments"][ - "expectation_vector" - ].shape - == expectation_params["variational_family"]["moments"][ - "expectation_vector" - ].shape - ) - assert ( - expectation_params_test["variational_family"]["moments"][ - "expectation_matrix" - ].shape - == expectation_params["variational_family"]["moments"][ - "expectation_matrix" - ].shape - ) - - -from copy import deepcopy - - -def test_renaming(): - - _, posterior, prior = get_data_and_gp(10) - - z = jnp.linspace(-5.0, 5.0, 5).reshape(-1, 1) - expectation_variational_family = ( - gpx.variational_families.ExpectationVariationalGaussian( - prior=prior, inducing_inputs=z - ) - ) - - natural_variational_family = gpx.variational_families.NaturalVariationalGaussian( - prior=prior, inducing_inputs=z - ) - - natural_svgp = gpx.StochasticVI( - posterior=posterior, variational_family=natural_variational_family - ) - expectation_svgp = gpx.StochasticVI( - posterior=posterior, variational_family=expectation_variational_family - ) - - key = jr.PRNGKey(123) - natural_params, *_ = gpx.initialise(natural_svgp, key).unpack() - expectation_params, *_ = gpx.initialise(expectation_svgp, key).unpack() - - _nat = deepcopy(natural_params) - _exp = deepcopy(expectation_params) - - rename_expectation_to_natural = _rename_expectation_to_natural(_exp) - rename_natural_to_expectation = _rename_natural_to_expectation(_nat) - - # Check correct names are in the dictionaries: - assert ( - "expectation_vector" - in rename_natural_to_expectation["variational_family"]["moments"].keys() - ) - assert ( - "expectation_matrix" - in rename_natural_to_expectation["variational_family"]["moments"].keys() - ) - assert ( - "natural_vector" - not in rename_natural_to_expectation["variational_family"]["moments"].keys() - ) - assert ( - "natural_matrix" - not in rename_natural_to_expectation["variational_family"]["moments"].keys() - ) - - assert ( - "natural_vector" - in rename_expectation_to_natural["variational_family"]["moments"].keys() - ) - assert ( - "natural_matrix" - in rename_expectation_to_natural["variational_family"]["moments"].keys() - ) - assert ( - "expectation_vector" - not in rename_expectation_to_natural["variational_family"]["moments"].keys() - ) - assert ( - "expectation_matrix" - not in rename_expectation_to_natural["variational_family"]["moments"].keys() - ) - - # Check the values are unchanged: - for v1, v2 in zip( - rename_natural_to_expectation["variational_family"]["moments"].values(), - natural_params["variational_family"]["moments"].values(), - ): - assert jnp.all(v1 == v2) - - for v1, v2 in zip( - rename_expectation_to_natural["variational_family"]["moments"].values(), - expectation_params["variational_family"]["moments"].values(), - ): - assert jnp.all(v1 == v2) - - -@pytest.mark.parametrize("jit_fns", [True, False]) -def test_expectation_elbo(jit_fns): - """ - Tests the expectation ELBO. - """ - D, posterior, prior = get_data_and_gp(10) - - z = jnp.linspace(-5.0, 5.0, 5).reshape(-1, 1) - variational_family = gpx.variational_families.ExpectationVariationalGaussian( - prior=prior, inducing_inputs=z - ) - - svgp = gpx.StochasticVI(posterior=posterior, variational_family=variational_family) - - params, _, _ = gpx.initialise(svgp, jr.PRNGKey(123)).unpack() - - expectation_elbo = _expectation_elbo(posterior, variational_family, D) - - if jit_fns: - elbo_fn = jax.jit(expectation_elbo) - else: - elbo_fn = expectation_elbo - - assert isinstance(elbo_fn, tp.Callable) - elbo_value = elbo_fn(params, D) - assert isinstance(elbo_value, jnp.ndarray) - - # Test gradients - grads = jax.grad(elbo_fn, argnums=0)(params, D) - assert isinstance(grads, tp.Dict) - assert len(grads) == len(params) - - -def test_natural_gradients(): - """ - Tests the natural gradient and hyperparameter gradients. - """ - D, p, prior = get_data_and_gp(10) - - z = jnp.linspace(-5.0, 5.0, 5).reshape(-1, 1) - prior = gpx.Prior(kernel=gpx.RBF()) - q = gpx.NaturalVariationalGaussian(prior=prior, inducing_inputs=z) - - svgp = gpx.StochasticVI(posterior=p, variational_family=q) - - params, trainables, bijectors = gpx.initialise(svgp, jr.PRNGKey(123)).unpack() - - batch = get_batch(D, batch_size=10, key=jr.PRNGKey(42)) - - nat_grads_fn, hyper_grads_fn = natural_gradients(svgp, D, bijectors, trainables) - - assert isinstance(nat_grads_fn, tp.Callable) - assert isinstance(hyper_grads_fn, tp.Callable) - - val, nat_grads = nat_grads_fn(params, batch) - val, hyper_grads = hyper_grads_fn(params, batch) - - assert isinstance(val, jnp.ndarray) - assert isinstance(nat_grads, tp.Dict) - assert isinstance(hyper_grads, tp.Dict) - - # Need to check moments are zero in hyper_grads: - assert jnp.array( - [ - (v == 0.0).all() - for v in hyper_grads["variational_family"]["moments"].values() - ] - ).all() - - # Check non-moments are zero in nat_grads: - d = jax.tree_map(lambda x: (x == 0.0).all(), nat_grads) - d["variational_family"]["moments"] = True - - assert jnp.array([v1 == True for k, v1, v2 in recursive_items(d, d)]).all() diff --git a/tests/test_objectives.py b/tests/test_objectives.py new file mode 100644 index 000000000..81bcb561f --- /dev/null +++ b/tests/test_objectives.py @@ -0,0 +1,178 @@ +from gpjax.objectives import ( + AbstractObjective, + LogPosteriorDensity, + ConjugateMLL, + NonConjugateMLL, + CollapsedELBO, + ELBO, +) +import gpjax as gpx +import pytest +import jax.random as jr +import jax.numpy as jnp +from gpjax import Prior, Gaussian, Bernoulli +import jax +from gpjax.dataset import Dataset + + +def test_abstract_objective(): + with pytest.raises(TypeError): + AbstractObjective() + + +def build_data(num_datapoints: int, num_dims: int, key, binary: bool): + x = jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(num_datapoints, num_dims)) + if binary: + y = ( + 0.5 + * jnp.sign( + jnp.cos( + 3 * x[:, 1].reshape(-1, 1) + + jr.normal(key, shape=(num_datapoints, 1)) * 0.05 + ) + ) + + 0.5 + ) + else: + y = ( + jnp.sin(x[:, 1]).reshape(-1, 1) + + jr.normal(key=key, shape=(num_datapoints, 1)) * 0.1 + ) + D = Dataset(X=x, y=y) + return D + + +@pytest.mark.parametrize("num_datapoints", [1, 2, 10]) +@pytest.mark.parametrize("num_dims", [1, 2, 3]) +@pytest.mark.parametrize("negative", [False, True]) +@pytest.mark.parametrize("jit_compile", [False, True]) +@pytest.mark.parametrize("key_val", [123, 42]) +def test_conjugate_mll( + num_datapoints: int, num_dims: int, negative: bool, jit_compile: bool, key_val: int +): + key = jr.PRNGKey(key_val) + D = build_data(num_datapoints, num_dims, key, binary=False) + + # Build model + p = Prior(kernel=gpx.RBF(active_dims=list(range(num_dims))), mean_function=gpx.Constant()) + lik = Gaussian(num_datapoints=num_datapoints) + post = p * lik + + mll = ConjugateMLL(negative=negative) + assert isinstance(mll, AbstractObjective) + + if jit_compile: + mll = jax.jit(mll) + + evaluation = mll(post, D) + assert isinstance(evaluation, jax.Array) + assert evaluation.shape == () + + +@pytest.mark.parametrize("num_datapoints", [1, 2, 10]) +@pytest.mark.parametrize("num_dims", [1, 2, 3]) +@pytest.mark.parametrize("negative", [False, True]) +@pytest.mark.parametrize("jit_compile", [False, True]) +@pytest.mark.parametrize("key_val", [123, 42]) +def test_non_conjugate_mll( + num_datapoints: int, num_dims: int, negative: bool, jit_compile: bool, key_val: int +): + key = jr.PRNGKey(key_val) + D = build_data(num_datapoints, num_dims, key, binary=True) + + # Build model + p = Prior(kernel=gpx.RBF(active_dims=list(range(num_dims))), mean_function=gpx.Constant()) + lik = Bernoulli(num_datapoints=num_datapoints) + post = p * lik + + mll = NonConjugateMLL(negative=negative) + assert isinstance(mll, AbstractObjective) + if jit_compile: + mll = jax.jit(mll) + + evaluation = mll(post, D) + assert isinstance(evaluation, jax.Array) + assert evaluation.shape == () + + mll2 = LogPosteriorDensity(negative=negative) + + if jit_compile: + mll2 = jax.jit(mll2) + assert mll2(post, D) == evaluation + + +@pytest.mark.parametrize("num_datapoints", [10, 20]) +@pytest.mark.parametrize("num_dims", [1, 2, 3]) +@pytest.mark.parametrize("negative", [False, True]) +@pytest.mark.parametrize("jit_compile", [False, True]) +@pytest.mark.parametrize("key_val", [123, 42]) +def test_collapsed_elbo( + num_datapoints: int, num_dims: int, negative: bool, jit_compile: bool, key_val: int +): + key = jr.PRNGKey(key_val) + D = build_data(num_datapoints, num_dims, key, binary=False) + z = jr.uniform( + key=key, minval=-2.0, maxval=2.0, shape=(num_datapoints // 2, num_dims) + ) + + p = Prior(kernel=gpx.RBF(active_dims=list(range(num_dims))), mean_function=gpx.Constant()) + lik = Gaussian(num_datapoints=num_datapoints) + q = gpx.CollapsedVariationalGaussian(posterior= p * lik, inducing_inputs=z) + + negative_elbo = CollapsedELBO(negative=negative) + + assert isinstance(negative_elbo, AbstractObjective) + + if jit_compile: + negative_elbo = jax.jit(negative_elbo) + + evaluation = negative_elbo(q, D) + assert isinstance(evaluation, jax.Array) + assert evaluation.shape == () + + # bern_post = p * Bernoulli(num_datapoints=num_datapoints) + # with pytest.raises(TypeError): + # gpx.CollapsedELBO(posterior=bern_post, variational_family=q, negative=negative) + + +@pytest.mark.parametrize("num_datapoints", [1, 2, 10]) +@pytest.mark.parametrize("num_dims", [1, 2, 3]) +@pytest.mark.parametrize("negative", [False, True]) +@pytest.mark.parametrize("jit_compile", [False, True]) +@pytest.mark.parametrize("key_val", [123, 42]) +@pytest.mark.parametrize("binary", [True, False]) +def test_elbo( + num_datapoints: int, + num_dims: int, + negative: bool, + jit_compile: bool, + key_val: int, + binary: bool, +): + key = jr.PRNGKey(key_val) + D = build_data(num_datapoints, num_dims, key, binary=binary) + z = jr.uniform( + key=key, minval=-2.0, maxval=2.0, shape=(num_datapoints // 2, num_dims) + ) + + p = Prior(kernel=gpx.RBF(active_dims=list(range(num_dims))), mean_function=gpx.Constant()) + if binary: + lik = Bernoulli(num_datapoints=num_datapoints) + else: + lik = Gaussian(num_datapoints=num_datapoints) + post = p * lik + + q = gpx.VariationalGaussian(posterior= post, inducing_inputs=z) + + negative_elbo = ELBO( + negative=negative, + ) + + assert isinstance(negative_elbo, AbstractObjective) + + if jit_compile: + negative_elbo = jax.jit(negative_elbo) + + evaluation = negative_elbo(q, D) + assert isinstance(evaluation, jax.Array) + assert evaluation.shape == () diff --git a/tests/test_variational_families.py b/tests/test_variational_families.py index 7fa3fa344..1451b0810 100644 --- a/tests/test_variational_families.py +++ b/tests/test_variational_families.py @@ -13,10 +13,9 @@ # limitations under the License. # ============================================================================== -from typing import Callable, Dict, Tuple +from typing import Callable, Tuple import distrax as dx -import jax import jax.numpy as jnp import jax.random as jr import pytest @@ -44,11 +43,11 @@ def test_abstract_variational_family(): # Create a dummy variational family class with abstract methods implemented. class DummyVariationalFamily(AbstractVariationalFamily): - def predict(self, params: Dict, x: Float[Array, "N D"]) -> dx.Distribution: + def predict(self, x: Float[Array, "N D"]) -> dx.Distribution: return dx.MultivariateNormalDiag(loc=x) # Test that the dummy variational family can be instantiated. - dummy_variational_family = DummyVariationalFamily() + dummy_variational_family = DummyVariationalFamily(posterior=None) assert isinstance(dummy_variational_family, AbstractVariationalFamily) @@ -94,10 +93,11 @@ def test_variational_gaussians( ) -> None: # Initialise variational family: - prior = gpx.Prior(kernel=gpx.RBF()) + prior = gpx.Prior(kernel=gpx.RBF(), mean_function=gpx.Constant()) + likelihood = gpx.Gaussian(123) inducing_inputs = jnp.linspace(-5.0, 5.0, n_inducing).reshape(-1, 1) test_inputs = jnp.linspace(-5.0, 5.0, n_test).reshape(-1, 1) - q = variational_family(prior=prior, inducing_inputs=inducing_inputs) + q = variational_family(posterior = prior*likelihood, inducing_inputs=inducing_inputs) # Test init: assert q.num_inducing == n_inducing @@ -160,7 +160,7 @@ def test_collapsed_variational_gaussian( x = jnp.hstack([x] * point_dim) D = gpx.Dataset(X=x, y=y) - prior = gpx.Prior(kernel=gpx.RBF()) + prior = gpx.Prior(kernel=gpx.RBF(), mean_function=gpx.Constant()) inducing_inputs = jnp.linspace(-5.0, 5.0, n_inducing).reshape(-1, 1) inducing_inputs = jnp.hstack([inducing_inputs] * point_dim) @@ -168,23 +168,21 @@ def test_collapsed_variational_gaussian( test_inputs = jnp.hstack([test_inputs] * point_dim) variational_family = CollapsedVariationalGaussian( - prior=prior, - likelihood=gpx.Gaussian(num_datapoints=D.n), + posterior=prior*gpx.Gaussian(num_datapoints=D.n), inducing_inputs=inducing_inputs, ) # We should raise an error for non-Gaussian likelihoods: with pytest.raises(TypeError): CollapsedVariationalGaussian( - prior=prior, - likelihood=gpx.Bernoulli(num_datapoints=D.n), + posterior= prior * gpx.Bernoulli(num_datapoints=D.n), inducing_inputs=inducing_inputs, ) # Test init assert variational_family.num_inducing == n_inducing assert (variational_family.inducing_inputs == inducing_inputs).all() - assert variational_family.likelihood.obs_noise == 1.0 + assert variational_family.posterior.likelihood.obs_noise == 1.0 # Test predictions predictive_dist = variational_family(test_inputs, D) From 3c6aaa0ad9c563cabd7978b843d60a31d46b05ab Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Sun, 2 Apr 2023 22:52:36 +0100 Subject: [PATCH 22/44] Use tfb bijectors, update base. --- gpjax/{module => base}/__init__.py | 0 gpjax/{module => base}/module.py | 9 +- gpjax/{module => base}/param.py | 5 +- gpjax/kernels/base.py | 3 +- gpjax/kernels/nonstationary/linear.py | 11 +- gpjax/kernels/nonstationary/polynomial.py | 12 +- gpjax/kernels/stationary/matern12.py | 18 +- gpjax/kernels/stationary/matern32.py | 17 +- gpjax/kernels/stationary/matern52.py | 17 +- gpjax/kernels/stationary/periodic.py | 21 +- .../kernels/stationary/powered_exponential.py | 16 +- .../kernels/stationary/rational_quadratic.py | 18 +- gpjax/kernels/stationary/rbf.py | 17 +- gpjax/kernels/stationary/white.py | 11 +- tests/test_abstractions.py | 355 +++++++++--------- tests/test_base/__init__.py | 0 .../{test_params => test_base}/test_module.py | 192 +++++----- tests/test_base/test_params.py | 43 +++ tests/test_config.py | 72 ---- tests/test_fit.py | 20 +- tests/test_kernels/test_stationary.py | 18 +- tests/test_params/test_bijectors.py | 22 -- tests/test_params/test_parameters.py | 41 -- 23 files changed, 427 insertions(+), 511 deletions(-) rename gpjax/{module => base}/__init__.py (100%) rename gpjax/{module => base}/module.py (96%) rename gpjax/{module => base}/param.py (91%) create mode 100644 tests/test_base/__init__.py rename tests/{test_params => test_base}/test_module.py (74%) create mode 100644 tests/test_base/test_params.py delete mode 100644 tests/test_config.py delete mode 100644 tests/test_params/test_bijectors.py delete mode 100644 tests/test_params/test_parameters.py diff --git a/gpjax/module/__init__.py b/gpjax/base/__init__.py similarity index 100% rename from gpjax/module/__init__.py rename to gpjax/base/__init__.py diff --git a/gpjax/module/module.py b/gpjax/base/module.py similarity index 96% rename from gpjax/module/module.py rename to gpjax/base/module.py index 07f09af40..c0b491a46 100644 --- a/gpjax/module/module.py +++ b/gpjax/base/module.py @@ -12,8 +12,7 @@ from jax._src.tree_util import _registry from simple_pytree import Pytree, static_field -from .bijectors import Bijector, Identity - +import tensorflow_probability.substrates.jax.bijectors as tfb class Module(Pytree): _pytree__meta: Dict[str, Any] = static_field() @@ -97,7 +96,7 @@ def replace_trainable(self: Module, **kwargs: Dict[str, bool]) -> Self: """Replace the trainability status of local nodes of the Module.""" return self.update_meta(**{k: {"trainable": v} for k, v in kwargs.items()}) - def replace_bijector(self: Module, **kwargs: Dict[str, Bijector]) -> Self: + def replace_bijector(self: Module, **kwargs: Dict[str, tfb.Bijector]) -> Self: """Replace the bijectors of local nodes of the Module.""" return self.update_meta(**{k: {"bijector": v} for k, v in kwargs.items()}) @@ -110,7 +109,7 @@ def constrain(self) -> Self: def _apply_constrain(meta_leaf): meta, leaf = meta_leaf - return meta.get("bijector", Identity).forward(leaf) + return meta.get("bijector", tfb.Identity()).forward(leaf) return meta_map(_apply_constrain, self) @@ -123,7 +122,7 @@ def unconstrain(self) -> Self: def _apply_unconstrain(meta_leaf): meta, leaf = meta_leaf - return meta.get("bijector", Identity).inverse(leaf) + return meta.get("bijector", tfb.Identity()).inverse(leaf) return meta_map(_apply_unconstrain, self) diff --git a/gpjax/module/param.py b/gpjax/base/param.py similarity index 91% rename from gpjax/module/param.py rename to gpjax/base/param.py index 745f95488..543a9fd9e 100644 --- a/gpjax/module/param.py +++ b/gpjax/base/param.py @@ -5,13 +5,12 @@ import dataclasses from typing import Any, Mapping, Optional -from .bijectors import Bijector, Identity - +import tensorflow_probability.substrates.jax.bijectors as tfb def param_field( default: Any = dataclasses.MISSING, *, - bijector: Bijector = Identity, + bijector: tfb.Bijector = tfb.Identity(), trainable: bool = True, default_factory: Any = dataclasses.MISSING, init: bool = True, diff --git a/gpjax/kernels/base.py b/gpjax/kernels/base.py index 5406310d1..0a1aee259 100644 --- a/gpjax/kernels/base.py +++ b/gpjax/kernels/base.py @@ -20,11 +20,10 @@ from typing import List, Callable, Union from jaxtyping import Array, Float from functools import partial -from ..parameters import Module, param_field from simple_pytree import static_field from dataclasses import dataclass -from functools import partial +from ..base import Module, param_field from .computations import AbstractKernelComputation, DenseKernelComputation diff --git a/gpjax/kernels/nonstationary/linear.py b/gpjax/kernels/nonstationary/linear.py index acb5dfb8f..66ada13fb 100644 --- a/gpjax/kernels/nonstationary/linear.py +++ b/gpjax/kernels/nonstationary/linear.py @@ -14,20 +14,21 @@ # ============================================================================== import jax.numpy as jnp -from jaxtyping import Array - -from ..base import AbstractKernel +import tensorflow_probability.substrates.jax.bijectors as tfb +from jaxtyping import Array from dataclasses import dataclass from jaxtyping import Array, Float -from ...parameters import param_field, Softplus + +from ..base import AbstractKernel +from ...base import param_field @dataclass class Linear(AbstractKernel): """The linear kernel.""" - variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus) + variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) def __call__( self, diff --git a/gpjax/kernels/nonstationary/polynomial.py b/gpjax/kernels/nonstationary/polynomial.py index e8a3d12d5..4a25db188 100644 --- a/gpjax/kernels/nonstationary/polynomial.py +++ b/gpjax/kernels/nonstationary/polynomial.py @@ -14,11 +14,15 @@ # ============================================================================== import jax.numpy as jnp +import tensorflow_probability.substrates.jax.bijectors as tfb + from jaxtyping import Array, Float -from ..base import AbstractKernel from dataclasses import dataclass from simple_pytree import static_field -from ...parameters import param_field, Softplus + +from ..base import AbstractKernel +from ...base import param_field + @dataclass @@ -26,8 +30,8 @@ class Polynomial(AbstractKernel): """The Polynomial kernel with variable degree.""" degree: int = static_field(2) - shift: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus) - variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus) + shift: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) + variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "1"]: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with shift parameter :math:`\\alpha` and variance :math:`\\sigma^2` through diff --git a/gpjax/kernels/stationary/matern12.py b/gpjax/kernels/stationary/matern12.py index a991231d6..23ecb881b 100644 --- a/gpjax/kernels/stationary/matern12.py +++ b/gpjax/kernels/stationary/matern12.py @@ -13,17 +13,15 @@ # limitations under the License. # ============================================================================== -from dataclasses import dataclass -from typing import Dict, List, Optional - import jax.numpy as jnp -from jax.random import KeyArray +import tensorflow_probability.substrates.jax.bijectors as tfb +import tensorflow_probability.substrates.jax.distributions as tfd + from jaxtyping import Array, Float -import distrax as dx +from dataclasses import dataclass -from ...parameters import Softplus, param_field +from ...base import param_field from ..base import AbstractKernel -from ..computations import DenseKernelComputation from .utils import build_student_t_distribution, euclidean_distance @@ -31,8 +29,8 @@ class Matern12(AbstractKernel): """The Matérn kernel with smoothness parameter fixed at 0.5.""" - lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=Softplus) - variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus) + lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) + variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "1"]: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with @@ -53,5 +51,5 @@ def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, " return K.squeeze() @property - def spectral_density(self) -> dx.Distribution: + def spectral_density(self) -> tfd.Distribution: return build_student_t_distribution(nu=1) diff --git a/gpjax/kernels/stationary/matern32.py b/gpjax/kernels/stationary/matern32.py index 9a630392e..ca2068e81 100644 --- a/gpjax/kernels/stationary/matern32.py +++ b/gpjax/kernels/stationary/matern32.py @@ -13,16 +13,15 @@ # limitations under the License. # ============================================================================== -from dataclasses import dataclass -from typing import Dict, List, Optional - import jax.numpy as jnp -from jax.random import KeyArray +import tensorflow_probability.substrates.jax.bijectors as tfb +import tensorflow_probability.substrates.jax.distributions as tfd + from jaxtyping import Array, Float +from dataclasses import dataclass -from ...parameters import Softplus, param_field +from ...base import param_field from ..base import AbstractKernel -from ..computations import DenseKernelComputation from .utils import build_student_t_distribution, euclidean_distance @@ -30,8 +29,8 @@ class Matern32(AbstractKernel): """The Matérn kernel with smoothness parameter fixed at 1.5.""" - lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=Softplus) - variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus) + lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) + variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) def __call__( self, @@ -58,5 +57,5 @@ def __call__( return K.squeeze() @property - def spectral_density(self): + def spectral_density(self) -> tfd.Distribution: return build_student_t_distribution(nu=3) diff --git a/gpjax/kernels/stationary/matern52.py b/gpjax/kernels/stationary/matern52.py index 70078b734..4bae5ca99 100644 --- a/gpjax/kernels/stationary/matern52.py +++ b/gpjax/kernels/stationary/matern52.py @@ -13,16 +13,15 @@ # limitations under the License. # ============================================================================== -from dataclasses import dataclass -from typing import Dict, List, Optional - import jax.numpy as jnp -from jax.random import KeyArray +import tensorflow_probability.substrates.jax.bijectors as tfb +import tensorflow_probability.substrates.jax.distributions as tfd + from jaxtyping import Array, Float +from dataclasses import dataclass -from ...parameters import Softplus, param_field +from ...base import param_field from ..base import AbstractKernel -from ..computations import DenseKernelComputation from .utils import build_student_t_distribution, euclidean_distance @@ -30,8 +29,8 @@ class Matern52(AbstractKernel): """The Matérn kernel with smoothness parameter fixed at 2.5.""" - lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=Softplus) - variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus) + lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) + variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) def __call__( self, x: Float[Array, "D"], y: Float[Array, "D"] @@ -60,5 +59,5 @@ def __call__( return K.squeeze() @property - def spectral_density(self): + def spectral_density(self) -> tfd.Distribution: return build_student_t_distribution(nu=5) diff --git a/gpjax/kernels/stationary/periodic.py b/gpjax/kernels/stationary/periodic.py index 4d03a414a..b9fccf4ec 100644 --- a/gpjax/kernels/stationary/periodic.py +++ b/gpjax/kernels/stationary/periodic.py @@ -13,20 +13,15 @@ # limitations under the License. # ============================================================================== -from typing import Dict, List, Optional - -import jax import jax.numpy as jnp -from jax.random import KeyArray +import tensorflow_probability.substrates.jax.bijectors as tfb +import tensorflow_probability.substrates.jax.distributions as tfd + from jaxtyping import Array, Float +from dataclasses import dataclass +from ...base import param_field from ..base import AbstractKernel -from ..computations import ( - DenseKernelComputation, -) - -from dataclasses import dataclass -from ...parameters import param_field, Softplus @dataclass @@ -36,9 +31,9 @@ class Periodic(AbstractKernel): Key reference is MacKay 1998 - "Introduction to Gaussian processes". """ - lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=Softplus) - variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus) - period: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus) + lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) + variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) + period: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "1"]: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with length-scale parameter :math:`\\ell` and variance :math:`\\sigma` diff --git a/gpjax/kernels/stationary/powered_exponential.py b/gpjax/kernels/stationary/powered_exponential.py index 320d10879..725f17f3e 100644 --- a/gpjax/kernels/stationary/powered_exponential.py +++ b/gpjax/kernels/stationary/powered_exponential.py @@ -13,17 +13,15 @@ # limitations under the License. # ============================================================================== -from dataclasses import dataclass -from typing import Dict, List, Optional - -import jax import jax.numpy as jnp -from jax.random import KeyArray +import tensorflow_probability.substrates.jax.bijectors as tfb +import tensorflow_probability.substrates.jax.distributions as tfd + from jaxtyping import Array, Float +from dataclasses import dataclass -from ...parameters import Softplus, param_field +from ...base import param_field from ..base import AbstractKernel -from ..computations import DenseKernelComputation from .utils import euclidean_distance @@ -35,8 +33,8 @@ class PoweredExponential(AbstractKernel): """ - lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=Softplus) - variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus) + lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) + variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) power: Float[Array, "1"] = param_field(jnp.array([1.0])) def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "1"]: diff --git a/gpjax/kernels/stationary/rational_quadratic.py b/gpjax/kernels/stationary/rational_quadratic.py index cf36c24c8..f3887c71b 100644 --- a/gpjax/kernels/stationary/rational_quadratic.py +++ b/gpjax/kernels/stationary/rational_quadratic.py @@ -13,26 +13,24 @@ # limitations under the License. # ============================================================================== -from dataclasses import dataclass -from typing import List, Optional - -import jax import jax.numpy as jnp -from jax.random import KeyArray +import tensorflow_probability.substrates.jax.bijectors as tfb +import tensorflow_probability.substrates.jax.distributions as tfd + from jaxtyping import Array, Float +from dataclasses import dataclass -from ...parameters import Softplus, param_field +from ...base import param_field from ..base import AbstractKernel -from ..computations import DenseKernelComputation from .utils import squared_distance @dataclass class RationalQuadratic(AbstractKernel): - lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=Softplus) - variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus) - alpha: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus) + lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) + variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) + alpha: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "1"]: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with length-scale parameter :math:`\\ell` and variance :math:`\\sigma` diff --git a/gpjax/kernels/stationary/rbf.py b/gpjax/kernels/stationary/rbf.py index d04fcef53..ae78cb158 100644 --- a/gpjax/kernels/stationary/rbf.py +++ b/gpjax/kernels/stationary/rbf.py @@ -14,22 +14,23 @@ # ============================================================================== import jax.numpy as jnp +import tensorflow_probability.substrates.jax.bijectors as tfb +import tensorflow_probability.substrates.jax.distributions as tfd + from jaxtyping import Array, Float +from dataclasses import dataclass +from ...base import param_field from ..base import AbstractKernel from .utils import squared_distance -import distrax as dx - -from dataclasses import dataclass -from ...parameters import param_field, Softplus @dataclass class RBF(AbstractKernel): """The Radial Basis Function (RBF) kernel.""" - lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=Softplus) - variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus) + lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) + variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "1"]: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with @@ -52,5 +53,5 @@ def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, " return K.squeeze() @property - def spectral_density(self) -> dx.Normal: - return dx.Normal(loc=0.0, scale=1.0) + def spectral_density(self) -> tfd.Normal: + return tfd.Normal(loc=0.0, scale=1.0) diff --git a/gpjax/kernels/stationary/white.py b/gpjax/kernels/stationary/white.py index 8b7c6caaf..83144513c 100644 --- a/gpjax/kernels/stationary/white.py +++ b/gpjax/kernels/stationary/white.py @@ -13,14 +13,15 @@ # limitations under the License. # ============================================================================== -from dataclasses import dataclass -from typing import Dict, List, Optional - import jax.numpy as jnp +import tensorflow_probability.substrates.jax.bijectors as tfb +import tensorflow_probability.substrates.jax.distributions as tfd + from jaxtyping import Array, Float +from dataclasses import dataclass from simple_pytree import static_field -from ...parameters import Softplus, param_field +from ...base import param_field from ..base import AbstractKernel from ..computations import AbstractKernelComputation, ConstantDiagonalKernelComputation @@ -28,7 +29,7 @@ @dataclass class White(AbstractKernel): - variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus) + variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) compute_engine: AbstractKernelComputation = static_field( ConstantDiagonalKernelComputation ) diff --git a/tests/test_abstractions.py b/tests/test_abstractions.py index 4b7fce300..7105a54de 100644 --- a/tests/test_abstractions.py +++ b/tests/test_abstractions.py @@ -1,178 +1,177 @@ -# Copyright 2022 The GPJax Contributors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import jax.numpy as jnp -import jax.random as jr -import optax -import pytest -from jax.config import config - -import gpjax as gpx -from gpjax import RBF, Dataset, Gaussian, Prior, initialise -from gpjax.abstractions import InferenceState, fit, fit_batches, fit_natgrads, get_batch -from gpjax.parameters import ParameterState, build_bijectors - -# Enable Float64 for more stable matrix inversions. -config.update("jax_enable_x64", True) - - -@pytest.mark.parametrize("num_iters", [1, 5]) -@pytest.mark.parametrize("n", [1, 20]) -@pytest.mark.parametrize("verbose", [True, False]) -def test_fit(num_iters, n, verbose): - key = jr.PRNGKey(123) - x = jnp.sort(jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(n, 1)), axis=0) - y = jnp.sin(x) + jr.normal(key=key, shape=x.shape) * 0.1 - D = Dataset(X=x, y=y) - p = Prior(kernel=RBF()) * Gaussian(num_datapoints=n) - parameter_state = initialise(p, key) - mll = p.marginal_log_likelihood(D, negative=True) - pre_mll_val = mll(parameter_state.params) - optimiser = optax.adam(learning_rate=0.1) - inference_state = fit(mll, parameter_state, optimiser, num_iters, verbose=verbose) - optimised_params, history = inference_state.unpack() - assert isinstance(inference_state, InferenceState) - assert isinstance(optimised_params, dict) - assert mll(optimised_params) < pre_mll_val - assert isinstance(history, jnp.ndarray) - assert history.shape[0] == num_iters - - -def test_stop_grads(): - params = {"x": jnp.array(3.0), "y": jnp.array(4.0)} - trainables = {"x": True, "y": False} - bijectors = build_bijectors(params) - loss_fn = lambda params: params["x"] ** 2 + params["y"] ** 2 - optimiser = optax.adam(learning_rate=0.1) - parameter_state = ParameterState( - params=params, trainables=trainables, bijectors=bijectors - ) - inference_state = fit(loss_fn, parameter_state, optimiser, num_iters=1) - learned_params = inference_state.params - assert isinstance(inference_state, InferenceState) - assert learned_params["y"] == params["y"] - assert learned_params["x"] != params["x"] - - -@pytest.mark.parametrize("num_iters", [1, 5]) -@pytest.mark.parametrize("nb", [1, 20, 50]) -@pytest.mark.parametrize("ndata", [50]) -@pytest.mark.parametrize("verbose", [True, False]) -def test_batch_fitting(num_iters, nb, ndata, verbose): - key = jr.PRNGKey(123) - x = jnp.sort(jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(ndata, 1)), axis=0) - y = jnp.sin(x) + jr.normal(key=key, shape=x.shape) * 0.1 - D = Dataset(X=x, y=y) - prior = Prior(kernel=RBF()) - likelihood = Gaussian(num_datapoints=ndata) - p = prior * likelihood - z = jnp.linspace(-2.0, 2.0, 10).reshape(-1, 1) - - q = gpx.VariationalGaussian(prior=prior, inducing_inputs=z) - - svgp = gpx.StochasticVI(posterior=p, variational_family=q) - parameter_state = initialise(svgp, key) - objective = svgp.elbo(D, negative=True) - - pre_mll_val = objective(parameter_state.params, D) - - D = Dataset(X=x, y=y) - - optimiser = optax.adam(learning_rate=0.1) - key = jr.PRNGKey(42) - inference_state = fit_batches( - objective, parameter_state, D, optimiser, key, nb, num_iters, verbose=verbose - ) - optimised_params, history = inference_state.unpack() - assert isinstance(inference_state, InferenceState) - assert isinstance(optimised_params, dict) - assert objective(optimised_params, D) < pre_mll_val - assert isinstance(history, jnp.ndarray) - assert history.shape[0] == num_iters - - -@pytest.mark.parametrize("num_iters", [1, 5]) -@pytest.mark.parametrize("nb", [1, 20, 50]) -@pytest.mark.parametrize("ndata", [50]) -@pytest.mark.parametrize("verbose", [True, False]) -def test_natural_gradients(ndata, nb, num_iters, verbose): - key = jr.PRNGKey(123) - x = jnp.sort(jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(ndata, 1)), axis=0) - y = jnp.sin(x) + jr.normal(key=key, shape=x.shape) * 0.1 - D = Dataset(X=x, y=y) - prior = Prior(kernel=RBF()) - likelihood = Gaussian(num_datapoints=ndata) - p = prior * likelihood - z = jnp.linspace(-2.0, 2.0, 10).reshape(-1, 1) - - q = gpx.NaturalVariationalGaussian(prior=prior, inducing_inputs=z) - - svgp = gpx.StochasticVI(posterior=p, variational_family=q) - training_state = initialise(svgp, key) - - D = Dataset(X=x, y=y) - - hyper_optimiser = optax.adam(learning_rate=0.1) - moment_optimiser = optax.sgd(learning_rate=1.0) - - objective = svgp.elbo(D, negative=True) - parameter_state = initialise(svgp, key) - pre_mll_val = objective(parameter_state.params, D) - - key = jr.PRNGKey(42) - inference_state = fit_natgrads( - svgp, - training_state, - D, - moment_optimiser, - hyper_optimiser, - key, - nb, - num_iters, - verbose=verbose, - ) - optimised_params, history = inference_state.unpack() - assert isinstance(inference_state, InferenceState) - assert isinstance(optimised_params, dict) - assert objective(optimised_params, D) < pre_mll_val - assert isinstance(history, jnp.ndarray) - assert history.shape[0] == num_iters - - -@pytest.mark.parametrize("batch_size", [1, 2, 50]) -@pytest.mark.parametrize("ndim", [1, 2, 3]) -@pytest.mark.parametrize("ndata", [50]) -@pytest.mark.parametrize("key", [jr.PRNGKey(123)]) -def test_get_batch(ndata, ndim, batch_size, key): - x = jnp.sort( - jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(ndata, ndim)), axis=0 - ) - y = jnp.sin(x) + jr.normal(key=key, shape=x.shape) * 0.1 - D = Dataset(X=x, y=y) - - B = get_batch(D, batch_size, key) - - assert B.n == batch_size - assert B.X.shape[1:] == x.shape[1:] - assert B.y.shape[1:] == y.shape[1:] - - # test no caching of batches: - key, subkey = jr.split(key) - Bnew = get_batch(D, batch_size, subkey) - assert Bnew.n == batch_size - assert Bnew.X.shape[1:] == x.shape[1:] - assert Bnew.y.shape[1:] == y.shape[1:] - assert (Bnew.X != B.X).all() - assert (Bnew.y != B.y).all() +# # Copyright 2022 The GPJax Contributors. All Rights Reserved. +# # +# # Licensed under the Apache License, Version 2.0 (the "License"); +# # you may not use this file except in compliance with the License. +# # You may obtain a copy of the License at +# # +# # http://www.apache.org/licenses/LICENSE-2.0 +# # +# # Unless required by applicable law or agreed to in writing, software +# # distributed under the License is distributed on an "AS IS" BASIS, +# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# # See the License for the specific language governing permissions and +# # limitations under the License. +# # ============================================================================== + +# import jax.numpy as jnp +# import jax.random as jr +# import optax +# import pytest +# from jax.config import config + +# import gpjax as gpx +# from gpjax import RBF, Dataset, Gaussian, Prior, initialise +# from gpjax.abstractions import InferenceState, fit, fit_batches, fit_natgrads, get_batch + +# # Enable Float64 for more stable matrix inversions. +# config.update("jax_enable_x64", True) + + +# @pytest.mark.parametrize("num_iters", [1, 5]) +# @pytest.mark.parametrize("n", [1, 20]) +# @pytest.mark.parametrize("verbose", [True, False]) +# def test_fit(num_iters, n, verbose): +# key = jr.PRNGKey(123) +# x = jnp.sort(jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(n, 1)), axis=0) +# y = jnp.sin(x) + jr.normal(key=key, shape=x.shape) * 0.1 +# D = Dataset(X=x, y=y) +# p = Prior(kernel=RBF()) * Gaussian(num_datapoints=n) +# parameter_state = initialise(p, key) +# mll = p.marginal_log_likelihood(D, negative=True) +# pre_mll_val = mll(parameter_state.params) +# optimiser = optax.adam(learning_rate=0.1) +# inference_state = fit(mll, parameter_state, optimiser, num_iters, verbose=verbose) +# optimised_params, history = inference_state.unpack() +# assert isinstance(inference_state, InferenceState) +# assert isinstance(optimised_params, dict) +# assert mll(optimised_params) < pre_mll_val +# assert isinstance(history, jnp.ndarray) +# assert history.shape[0] == num_iters + + +# def test_stop_grads(): +# params = {"x": jnp.array(3.0), "y": jnp.array(4.0)} +# trainables = {"x": True, "y": False} +# bijectors = build_bijectors(params) +# loss_fn = lambda params: params["x"] ** 2 + params["y"] ** 2 +# optimiser = optax.adam(learning_rate=0.1) +# parameter_state = ParameterState( +# params=params, trainables=trainables, bijectors=bijectors +# ) +# inference_state = fit(loss_fn, parameter_state, optimiser, num_iters=1) +# learned_params = inference_state.params +# assert isinstance(inference_state, InferenceState) +# assert learned_params["y"] == params["y"] +# assert learned_params["x"] != params["x"] + + +# @pytest.mark.parametrize("num_iters", [1, 5]) +# @pytest.mark.parametrize("nb", [1, 20, 50]) +# @pytest.mark.parametrize("ndata", [50]) +# @pytest.mark.parametrize("verbose", [True, False]) +# def test_batch_fitting(num_iters, nb, ndata, verbose): +# key = jr.PRNGKey(123) +# x = jnp.sort(jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(ndata, 1)), axis=0) +# y = jnp.sin(x) + jr.normal(key=key, shape=x.shape) * 0.1 +# D = Dataset(X=x, y=y) +# prior = Prior(kernel=RBF()) +# likelihood = Gaussian(num_datapoints=ndata) +# p = prior * likelihood +# z = jnp.linspace(-2.0, 2.0, 10).reshape(-1, 1) + +# q = gpx.VariationalGaussian(prior=prior, inducing_inputs=z) + +# svgp = gpx.StochasticVI(posterior=p, variational_family=q) +# parameter_state = initialise(svgp, key) +# objective = svgp.elbo(D, negative=True) + +# pre_mll_val = objective(parameter_state.params, D) + +# D = Dataset(X=x, y=y) + +# optimiser = optax.adam(learning_rate=0.1) +# key = jr.PRNGKey(42) +# inference_state = fit_batches( +# objective, parameter_state, D, optimiser, key, nb, num_iters, verbose=verbose +# ) +# optimised_params, history = inference_state.unpack() +# assert isinstance(inference_state, InferenceState) +# assert isinstance(optimised_params, dict) +# assert objective(optimised_params, D) < pre_mll_val +# assert isinstance(history, jnp.ndarray) +# assert history.shape[0] == num_iters + + +# @pytest.mark.parametrize("num_iters", [1, 5]) +# @pytest.mark.parametrize("nb", [1, 20, 50]) +# @pytest.mark.parametrize("ndata", [50]) +# @pytest.mark.parametrize("verbose", [True, False]) +# def test_natural_gradients(ndata, nb, num_iters, verbose): +# key = jr.PRNGKey(123) +# x = jnp.sort(jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(ndata, 1)), axis=0) +# y = jnp.sin(x) + jr.normal(key=key, shape=x.shape) * 0.1 +# D = Dataset(X=x, y=y) +# prior = Prior(kernel=RBF()) +# likelihood = Gaussian(num_datapoints=ndata) +# p = prior * likelihood +# z = jnp.linspace(-2.0, 2.0, 10).reshape(-1, 1) + +# q = gpx.NaturalVariationalGaussian(prior=prior, inducing_inputs=z) + +# svgp = gpx.StochasticVI(posterior=p, variational_family=q) +# training_state = initialise(svgp, key) + +# D = Dataset(X=x, y=y) + +# hyper_optimiser = optax.adam(learning_rate=0.1) +# moment_optimiser = optax.sgd(learning_rate=1.0) + +# objective = svgp.elbo(D, negative=True) +# parameter_state = initialise(svgp, key) +# pre_mll_val = objective(parameter_state.params, D) + +# key = jr.PRNGKey(42) +# inference_state = fit_natgrads( +# svgp, +# training_state, +# D, +# moment_optimiser, +# hyper_optimiser, +# key, +# nb, +# num_iters, +# verbose=verbose, +# ) +# optimised_params, history = inference_state.unpack() +# assert isinstance(inference_state, InferenceState) +# assert isinstance(optimised_params, dict) +# assert objective(optimised_params, D) < pre_mll_val +# assert isinstance(history, jnp.ndarray) +# assert history.shape[0] == num_iters + + +# @pytest.mark.parametrize("batch_size", [1, 2, 50]) +# @pytest.mark.parametrize("ndim", [1, 2, 3]) +# @pytest.mark.parametrize("ndata", [50]) +# @pytest.mark.parametrize("key", [jr.PRNGKey(123)]) +# def test_get_batch(ndata, ndim, batch_size, key): +# x = jnp.sort( +# jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(ndata, ndim)), axis=0 +# ) +# y = jnp.sin(x) + jr.normal(key=key, shape=x.shape) * 0.1 +# D = Dataset(X=x, y=y) + +# B = get_batch(D, batch_size, key) + +# assert B.n == batch_size +# assert B.X.shape[1:] == x.shape[1:] +# assert B.y.shape[1:] == y.shape[1:] + +# # test no caching of batches: +# key, subkey = jr.split(key) +# Bnew = get_batch(D, batch_size, subkey) +# assert Bnew.n == batch_size +# assert Bnew.X.shape[1:] == x.shape[1:] +# assert Bnew.y.shape[1:] == y.shape[1:] +# assert (Bnew.X != B.X).all() +# assert (Bnew.y != B.y).all() diff --git a/tests/test_base/__init__.py b/tests/test_base/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_params/test_module.py b/tests/test_base/test_module.py similarity index 74% rename from tests/test_params/test_module.py rename to tests/test_base/test_module.py index dfe445e38..1e33d6747 100644 --- a/tests/test_params/test_module.py +++ b/tests/test_base/test_module.py @@ -1,3 +1,21 @@ +# Copyright 2022 The GPJax Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import jax.numpy as jnp +import tensorflow_probability.substrates.jax.bijectors as tfb + import dataclasses from dataclasses import dataclass, field from typing import Any, Generic, Iterable, TypeVar @@ -9,9 +27,9 @@ from flax import serialization from simple_pytree import Pytree, static_field -from gpjax.bijectors import Identity, Softplus -from gpjax.module import Module, meta -from gpjax.param import param_field + +from gpjax.base.module import Module, meta +from gpjax.base.param import param_field @pytest.mark.parametrize("is_dataclass", [True, False]) @@ -121,7 +139,7 @@ def __init__(self, a, b, c): @pytest.mark.parametrize("is_dataclass", [True, False]) def test_simple_linear_model(is_dataclass): class SimpleModel(Module): - weight: float = param_field(bijector=Softplus, trainable=False) + weight: float = param_field(bijector=tfb.Softplus(), trainable=False) bias: float def __init__(self, weight, bias): @@ -144,16 +162,16 @@ def __call__(self, test_point): meta_model = meta(model) - assert meta_model.weight["bijector"] == Softplus + assert isinstance(meta_model.weight["bijector"], tfb.Softplus) assert meta_model.weight["trainable"] == False assert meta_model.bias == {} constrained_model = model.constrain() - assert constrained_model.weight == Softplus.forward(1.0) - assert constrained_model.bias == Identity.forward(2.0) + assert constrained_model.weight == tfb.Softplus().forward(1.0) + assert constrained_model.bias == tfb.Identity().forward(2.0) meta_constrained_model = meta(constrained_model) - assert meta_constrained_model.weight["bijector"] == Softplus + assert isinstance(meta_constrained_model.weight["bijector"], tfb.Softplus) assert meta_constrained_model.weight["trainable"] == False assert meta_constrained_model.bias == {} @@ -162,7 +180,7 @@ def __call__(self, test_point): assert unconstrained_model.bias == 2.0 meta_unconstrained_model = meta(unconstrained_model) - assert meta_unconstrained_model.weight["bijector"] == Softplus + assert isinstance(meta_unconstrained_model.weight["bijector"], tfb.Softplus) assert meta_unconstrained_model.weight["trainable"] == False assert meta_unconstrained_model.bias == {} @@ -200,9 +218,9 @@ def loss_fn(model): @pytest.mark.parametrize("is_dataclass", [True, False]) def test_nested_Module_structure(is_dataclass): class SubTree(Module): - c: float = param_field(bijector=Identity) - d: float = param_field(bijector=Softplus) - e: float = param_field(bijector=Softplus) + c: float = param_field(bijector=tfb.Identity()) + d: float = param_field(bijector=tfb.Softplus()) + e: float = param_field(bijector=tfb.Softplus()) def __init__(self, c, d, e): self.c = c @@ -210,9 +228,9 @@ def __init__(self, c, d, e): self.e = e class Tree(Module): - a: float = param_field(bijector=Identity) + a: float = param_field(bijector=tfb.Identity()) sub_tree: SubTree - b: float = param_field(bijector=Softplus) + b: float = param_field(bijector=tfb.Softplus()) def __init__(self, a, sub_tree, b): self.a = a @@ -245,15 +263,15 @@ def __init__(self, a, sub_tree, b): assert isinstance(meta_tree, Module) assert isinstance(meta_tree, Pytree) - assert meta_tree.a["bijector"] == Identity + assert isinstance(meta_tree.a["bijector"], tfb.Identity) assert meta_tree.a["trainable"] == True - assert meta_tree.b["bijector"] == Softplus + assert isinstance(meta_tree.b["bijector"], tfb.Softplus) assert meta_tree.b["trainable"] == True - assert meta_tree.sub_tree.c["bijector"] == Identity + assert isinstance(meta_tree.sub_tree.c["bijector"], tfb.Identity) assert meta_tree.sub_tree.c["trainable"] == True - assert meta_tree.sub_tree.d["bijector"] == Softplus + assert isinstance(meta_tree.sub_tree.d["bijector"], tfb.Softplus) assert meta_tree.sub_tree.d["trainable"] == True - assert meta_tree.sub_tree.e["bijector"] == Softplus + assert isinstance(meta_tree.sub_tree.e["bijector"], tfb.Softplus) assert meta_tree.sub_tree.e["trainable"] == True # Test constrain and unconstrain @@ -262,26 +280,26 @@ def __init__(self, a, sub_tree, b): assert isinstance(constrained, Module) assert isinstance(constrained, Pytree) - assert constrained.a == Identity.forward(1.0) - assert constrained.b == Softplus.forward(5.0) - assert constrained.sub_tree.c == Identity.forward(2.0) - assert constrained.sub_tree.d == Softplus.forward(3.0) - assert constrained.sub_tree.e == Softplus.forward(4.0) + assert constrained.a == tfb.Identity().forward(1.0) + assert constrained.b == tfb.Softplus().forward(5.0) + assert constrained.sub_tree.c == tfb.Identity().forward(2.0) + assert constrained.sub_tree.d == tfb.Softplus().forward(3.0) + assert constrained.sub_tree.e == tfb.Softplus().forward(4.0) meta_constrained = meta(constrained) assert isinstance(meta_constrained, Module) assert isinstance(meta_constrained, Pytree) - assert meta_constrained.a["bijector"] == Identity + assert isinstance(meta_constrained.a["bijector"], tfb.Identity) assert meta_constrained.a["trainable"] == True - assert meta_constrained.b["bijector"] == Softplus + assert isinstance(meta_constrained.b["bijector"], tfb.Softplus) assert meta_constrained.b["trainable"] == True - assert meta_constrained.sub_tree.c["bijector"] == Identity + assert isinstance(meta_constrained.sub_tree.c["bijector"], tfb.Identity) assert meta_constrained.sub_tree.c["trainable"] == True - assert meta_constrained.sub_tree.d["bijector"] == Softplus + assert isinstance(meta_constrained.sub_tree.d["bijector"], tfb.Softplus) assert meta_constrained.sub_tree.d["trainable"] == True - assert meta_constrained.sub_tree.e["bijector"] == Softplus + assert isinstance(meta_constrained.sub_tree.e["bijector"], tfb.Softplus) assert meta_constrained.sub_tree.e["trainable"] == True # Test constrain and unconstrain @@ -290,34 +308,34 @@ def __init__(self, a, sub_tree, b): assert isinstance(unconstrained, Module) assert isinstance(unconstrained, Pytree) - assert unconstrained.a == Identity.inverse(1.0) - assert unconstrained.b == Softplus.inverse(5.0) - assert unconstrained.sub_tree.c == Identity.inverse(2.0) - assert unconstrained.sub_tree.d == Softplus.inverse(3.0) - assert unconstrained.sub_tree.e == Softplus.inverse(4.0) + assert unconstrained.a == tfb.Identity().inverse(1.0) + assert unconstrained.b == tfb.Softplus().inverse(5.0) + assert unconstrained.sub_tree.c == tfb.Identity().inverse(2.0) + assert unconstrained.sub_tree.d == tfb.Softplus().inverse(3.0) + assert unconstrained.sub_tree.e == tfb.Softplus().inverse(4.0) meta_unconstrained = meta(unconstrained) assert isinstance(meta_unconstrained, Module) assert isinstance(meta_unconstrained, Pytree) - assert meta_unconstrained.a["bijector"] == Identity + assert isinstance(meta_unconstrained.a["bijector"], tfb.Identity) assert meta_unconstrained.a["trainable"] == True - assert meta_unconstrained.b["bijector"] == Softplus + assert isinstance(meta_unconstrained.b["bijector"], tfb.Softplus) assert meta_unconstrained.b["trainable"] == True - assert meta_unconstrained.sub_tree.c["bijector"] == Identity + assert isinstance(meta_unconstrained.sub_tree.c["bijector"], tfb.Identity) assert meta_unconstrained.sub_tree.c["trainable"] == True - assert meta_unconstrained.sub_tree.d["bijector"] == Softplus + assert isinstance(meta_unconstrained.sub_tree.d["bijector"], tfb.Softplus) assert meta_unconstrained.sub_tree.d["trainable"] == True - assert meta_unconstrained.sub_tree.e["bijector"] == Softplus + assert isinstance(meta_unconstrained.sub_tree.e["bijector"], tfb.Softplus) assert meta_unconstrained.sub_tree.e["trainable"] == True # Test updating metadata - new_subtree = tree.sub_tree.replace_bijector(c=Softplus, e=Identity) + new_subtree = tree.sub_tree.replace_bijector(c=tfb.Softplus(), e=tfb.Identity()) new_subtree = new_subtree.replace_trainable(c=False, e=False) - new_tree = tree.replace_bijector(b=Identity) + new_tree = tree.replace_bijector(b=tfb.Identity()) new_tree = new_tree.replace_trainable(b=False) new_tree = new_tree.replace(sub_tree=new_subtree) @@ -335,15 +353,15 @@ def __init__(self, a, sub_tree, b): assert isinstance(meta_new_tree, Module) assert isinstance(meta_new_tree, Pytree) - assert meta_new_tree.a["bijector"] == Identity + assert isinstance(meta_new_tree.a["bijector"], tfb.Identity) assert meta_new_tree.a["trainable"] == True - assert meta_new_tree.b["bijector"] == Identity + assert isinstance(meta_new_tree.b["bijector"], tfb.Identity) assert meta_new_tree.b["trainable"] == False - assert meta_new_tree.sub_tree.c["bijector"] == Softplus + assert isinstance(meta_new_tree.sub_tree.c["bijector"], tfb.Softplus) assert meta_new_tree.sub_tree.c["trainable"] == False - assert meta_new_tree.sub_tree.d["bijector"] == Softplus + assert isinstance(meta_new_tree.sub_tree.d["bijector"], tfb.Softplus) assert meta_new_tree.sub_tree.d["trainable"] == True - assert meta_new_tree.sub_tree.e["bijector"] == Identity + assert isinstance(meta_new_tree.sub_tree.e["bijector"], tfb.Identity) assert meta_new_tree.sub_tree.e["trainable"] == False # Test stop gradients @@ -370,9 +388,9 @@ def loss(tree): @pytest.mark.parametrize("iterable", [list, tuple]) def test_iterable_attribute(is_dataclass, iterable): class SubTree(Module): - a: int = param_field(bijector=Identity, default=1) - b: int = param_field(bijector=Softplus, default=2) - c: int = param_field(bijector=Identity, default=3, trainable=False) + a: int = param_field(bijector=tfb.Identity(), default=1) + b: int = param_field(bijector=tfb.Softplus(), default=2) + c: int = param_field(bijector=tfb.Identity(), default=3, trainable=False) def __init__(self, a=1.0, b=2.0, c=3.0): self.a = a @@ -411,25 +429,25 @@ def __init__(self, trees): assert isinstance(meta_tree, Module) assert isinstance(meta_tree, Pytree) - assert meta_tree.trees[0].a["bijector"] == Identity + assert isinstance(meta_tree.trees[0].a["bijector"], tfb.Identity) assert meta_tree.trees[0].a["trainable"] == True - assert meta_tree.trees[0].b["bijector"] == Softplus + assert isinstance(meta_tree.trees[0].b["bijector"], tfb.Softplus) assert meta_tree.trees[0].b["trainable"] == True - assert meta_tree.trees[0].c["bijector"] == Identity + assert isinstance(meta_tree.trees[0].c["bijector"], tfb.Identity) assert meta_tree.trees[0].c["trainable"] == False - assert meta_tree.trees[1].a["bijector"] == Identity + assert isinstance(meta_tree.trees[1].a["bijector"], tfb.Identity) assert meta_tree.trees[1].a["trainable"] == True - assert meta_tree.trees[1].b["bijector"] == Softplus + assert isinstance(meta_tree.trees[1].b["bijector"], tfb.Softplus) assert meta_tree.trees[1].b["trainable"] == True - assert meta_tree.trees[1].c["bijector"] == Identity + assert isinstance(meta_tree.trees[1].c["bijector"], tfb.Identity) assert meta_tree.trees[1].c["trainable"] == False - assert meta_tree.trees[2].a["bijector"] == Identity + assert isinstance(meta_tree.trees[2].a["bijector"], tfb.Identity) assert meta_tree.trees[2].a["trainable"] == True - assert meta_tree.trees[2].b["bijector"] == Softplus + assert isinstance(meta_tree.trees[2].b["bijector"], tfb.Softplus) assert meta_tree.trees[2].b["trainable"] == True - assert meta_tree.trees[2].c["bijector"] == Identity + assert isinstance(meta_tree.trees[2].c["bijector"], tfb.Identity) assert meta_tree.trees[2].c["trainable"] == False # Test constrain and unconstrain @@ -446,29 +464,29 @@ def __init__(self, trees): assert isinstance(unconstrained_tree, Module) assert isinstance(unconstrained_tree, Pytree) - assert constrained_tree.trees[0].a == Identity.forward(1.0) - assert constrained_tree.trees[0].b == Softplus.forward(2.0) - assert constrained_tree.trees[0].c == Identity.forward(3.0) + assert constrained_tree.trees[0].a == tfb.Identity().forward(1.0) + assert constrained_tree.trees[0].b == tfb.Softplus().forward(2.0) + assert constrained_tree.trees[0].c == tfb.Identity().forward(3.0) - assert constrained_tree.trees[1].a == Identity.forward(1.0) - assert constrained_tree.trees[1].b == Softplus.forward(2.0) - assert constrained_tree.trees[1].c == Identity.forward(3.0) + assert constrained_tree.trees[1].a == tfb.Identity().forward(1.0) + assert constrained_tree.trees[1].b == tfb.Softplus().forward(2.0) + assert constrained_tree.trees[1].c == tfb.Identity().forward(3.0) - assert constrained_tree.trees[2].a == Identity.forward(1.0) - assert constrained_tree.trees[2].b == Softplus.forward(2.0) - assert constrained_tree.trees[2].c == Identity.forward(3.0) + assert constrained_tree.trees[2].a == tfb.Identity().forward(1.0) + assert constrained_tree.trees[2].b == tfb.Softplus().forward(2.0) + assert constrained_tree.trees[2].c == tfb.Identity().forward(3.0) - assert unconstrained_tree.trees[0].a == Identity.inverse(1.0) - assert unconstrained_tree.trees[0].b == Softplus.inverse(2.0) - assert unconstrained_tree.trees[0].c == Identity.inverse(3.0) + assert unconstrained_tree.trees[0].a == tfb.Identity().inverse(1.0) + assert unconstrained_tree.trees[0].b == tfb.Softplus().inverse(2.0) + assert unconstrained_tree.trees[0].c == tfb.Identity().inverse(3.0) - assert unconstrained_tree.trees[1].a == Identity.inverse(1.0) - assert unconstrained_tree.trees[1].b == Softplus.inverse(2.0) - assert unconstrained_tree.trees[1].c == Identity.inverse(3.0) + assert unconstrained_tree.trees[1].a == tfb.Identity().inverse(1.0) + assert unconstrained_tree.trees[1].b == tfb.Softplus().inverse(2.0) + assert unconstrained_tree.trees[1].c == tfb.Identity().inverse(3.0) - assert unconstrained_tree.trees[2].a == Identity.inverse(1.0) - assert unconstrained_tree.trees[2].b == Softplus.inverse(2.0) - assert unconstrained_tree.trees[2].c == Identity.inverse(3.0) + assert unconstrained_tree.trees[2].a == tfb.Identity().inverse(1.0) + assert unconstrained_tree.trees[2].b == tfb.Softplus().inverse(2.0) + assert unconstrained_tree.trees[2].c == tfb.Identity().inverse(3.0) # The following tests are adapted from equinox 🏴‍☠️ @@ -477,14 +495,14 @@ def __init__(self, trees): def test_Module_not_enough_attributes(): @dataclass class Tree1(Module): - weight: Any = param_field(bijector=Identity) + weight: Any = param_field(bijector=tfb.Identity()) with pytest.raises(TypeError): Tree1() @dataclass class Tree2(Module): - weight: Any = param_field(bijector=Identity) + weight: Any = param_field(bijector=tfb.Identity()) def __init__(self): return None @@ -496,7 +514,7 @@ def __init__(self): def test_Module_too_many_attributes(): @dataclass class Tree1(Module): - weight: Any = param_field(bijector=Identity) + weight: Any = param_field(bijector=tfb.Identity()) with pytest.raises(TypeError): Tree1(1, 2) @@ -505,7 +523,7 @@ class Tree1(Module): def test_Module_setattr_after_init(): @dataclass class Tree(Module): - weight: Any = param_field(bijector=Identity) + weight: Any = param_field(bijector=tfb.Identity()) m = Tree(1) with pytest.raises(AttributeError): @@ -518,11 +536,11 @@ def test_inheritance(): @dataclass class Tree(Module): - weight: Any = param_field(bijector=Identity) + weight: Any = param_field(bijector=tfb.Identity()) @dataclass class Tree2(Tree): - weight2: Any = param_field(bijector=Identity) + weight2: Any = param_field(bijector=tfb.Identity()) m = Tree2(1, 2) assert m.weight == 1 @@ -540,7 +558,7 @@ class Tree2(Tree): @dataclass class Tree3(Tree): - weight3: Any = param_field(bijector=Identity) + weight3: Any = param_field(bijector=tfb.Identity()) def __init__(self, *, weight3, **kwargs): self.weight3 = weight3 @@ -554,11 +572,11 @@ def __init__(self, *, weight3, **kwargs): @dataclass class Tree4(Module): - weight4: Any = param_field(bijector=Identity) + weight4: Any = param_field(bijector=tfb.Identity()) @dataclass class Tree5(Tree4): - weight5: Any = param_field(bijector=Identity) + weight5: Any = param_field(bijector=tfb.Identity()) with pytest.raises(TypeError): m = Tree5(value4=1, weight5=2) @@ -574,7 +592,7 @@ class Tree6(Tree4): @dataclass class Tree7(Tree4): - weight7: Any = param_field(bijector=Identity) + weight7: Any = param_field(bijector=tfb.Identity()) def __init__(self, value7, **kwargs): self.weight7 = value7 @@ -588,7 +606,7 @@ def __init__(self, value7, **kwargs): def test_static_field(): @dataclass class Tree(Module): - field1: int = param_field(bijector=Identity) + field1: int = param_field(bijector=tfb.Identity()) field2: int = static_field() field3: int = static_field(default=3) diff --git a/tests/test_base/test_params.py b/tests/test_base/test_params.py new file mode 100644 index 000000000..17e3bca48 --- /dev/null +++ b/tests/test_base/test_params.py @@ -0,0 +1,43 @@ +import dataclasses + +import pytest + +import tensorflow_probability.substrates.jax.bijectors as tfb + +from gpjax.base import param_field + + +@pytest.mark.parametrize("bijector", [tfb.Identity, tfb.Softplus]) +@pytest.mark.parametrize("trainable", [True, False]) +def test_param(bijector, trainable): + param_field_ = param_field(bijector=bijector(), trainable=trainable) + assert isinstance(param_field_, dataclasses.Field) + assert isinstance(param_field_.metadata["bijector"], bijector) + assert param_field_.metadata["trainable"] == trainable + + with pytest.raises(ValueError): + param_field( + bijector=bijector(), trainable=trainable, metadata={"trainable": trainable} + ) + + with pytest.raises(ValueError): + param_field( + bijector=bijector(), trainable=trainable, metadata={"bijector": bijector()} + ) + + with pytest.raises(ValueError): + param_field( + bijector=bijector(), + trainable=trainable, + metadata={"bijector": tfb.Softplus(), "trainable": trainable}, + ) + + with pytest.raises(ValueError): + param_field( + bijector=bijector(), trainable=trainable, metadata={"pytree_node": True} + ) + + with pytest.raises(ValueError): + param_field( + bijector=bijector(), trainable=trainable, metadata={"pytree_node": False} + ) diff --git a/tests/test_config.py b/tests/test_config.py deleted file mode 100644 index 39be389c3..000000000 --- a/tests/test_config.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright 2022 The GPJax Contributors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import distrax as dx -import jax -from jax.config import config -from ml_collections import ConfigDict - -from gpjax.config import get_global_config_if_exists # ignore: unused-import -from gpjax.config import Identity, add_parameter, get_global_config - -# Enable Float64 for more stable matrix inversions. -config.update("jax_enable_x64", True) - -# TODO: Fix this test. -# This test needs to be run first to ensure that the global config is not set on library import. -# def test_config_on_library_import(): -# config = get_global_config_if_exists() -# assert config is None - - -def test_add_parameter(): - add_parameter("test_parameter", Identity) - config = get_global_config() - assert "test_parameter" in config.transformations - assert "test_parameter_transform" in config.transformations - assert config.transformations["test_parameter"] == "test_parameter_transform" - assert isinstance(config.transformations["test_parameter_transform"], dx.Bijector) - - -def test_add_parameter(): - config = get_global_config() - add_parameter("test_parameter", Identity) - config = get_global_config() - assert "test_parameter" in config.transformations - assert "test_parameter_transform" in config.transformations - assert config.transformations["test_parameter"] == "test_parameter_transform" - assert isinstance(config.transformations["test_parameter_transform"], dx.Bijector) - - -def test_get_global_config(): - config = get_global_config() - assert isinstance(config, ConfigDict) - assert isinstance(config.transformations, ConfigDict) - - -def test_x64_based_config_update(): - cached_jax_precision = jax.config.x64_enabled - - jax.config.update("jax_enable_x64", True) - config = get_global_config() - assert config.x64_state is True - - jax.config.update("jax_enable_x64", False) - config = get_global_config() - assert config.x64_state is False - - # Reset the JAX precision to the original value. - jax.config.update("jax_enable_x64", cached_jax_precision) - get_global_config() diff --git a/tests/test_fit.py b/tests/test_fit.py index 7ab04b314..6cab1c8d0 100644 --- a/tests/test_fit.py +++ b/tests/test_fit.py @@ -13,17 +13,19 @@ # limitations under the License. # ============================================================================== -from dataclasses import dataclass - -from gpjax.dataset import Dataset -from gpjax.fit import fit -from gpjax.parameters.bijectors import Identity -from gpjax.parameters import param_field, Module -from gpjax.objectives import AbstractObjective import jax.numpy as jnp import jax.random as jr import optax as ox +import tensorflow_probability.substrates.jax.bijectors as tfb +import tensorflow_probability.substrates.jax.distributions as tfd + +from dataclasses import dataclass + +from gpjax.base import param_field, Module +from gpjax.objectives import AbstractObjective +from gpjax.dataset import Dataset +from gpjax.fit import fit def test_simple_linear_model(): @@ -35,8 +37,8 @@ def test_simple_linear_model(): # (2) Define your model: @dataclass class LinearModel(Module): - weight: float = param_field(bijector=Identity) - bias: float = param_field(bijector=Identity) + weight: float = param_field(bijector=tfb.Identity()) + bias: float = param_field(bijector=tfb.Identity()) def __call__(self, x): return self.weight * x + self.bias diff --git a/tests/test_kernels/test_stationary.py b/tests/test_kernels/test_stationary.py index 7bdb65874..20dc8f010 100644 --- a/tests/test_kernels/test_stationary.py +++ b/tests/test_kernels/test_stationary.py @@ -14,16 +14,17 @@ # ============================================================================== -from itertools import permutations, product +from itertools import product -import jax import jax.numpy as jnp -import jax.random as jr import jax.tree_util as jtu +import tensorflow_probability.substrates.jax.bijectors as tfb +import tensorflow_probability.substrates.jax.distributions as tfd + import pytest import distrax as dx from jax.config import config -from gpjax.linops import LinearOperator, identity +from gpjax.linops import LinearOperator from gpjax.kernels.base import AbstractKernel from gpjax.kernels.stationary import ( @@ -38,16 +39,13 @@ ) from gpjax.kernels.computations import ( DenseKernelComputation, - DiagonalKernelComputation, ConstantDiagonalKernelComputation, ) + from gpjax.kernels.stationary.utils import build_student_t_distribution -from gpjax.parameters.bijectors import Identity, Softplus # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) -_initialise_key = jr.PRNGKey(123) -_jitter = 1e-6 class BaseTestKernel: @@ -104,9 +102,9 @@ def test_initialization(self, fields: dict, dim: int) -> None: assert meta_leaves.keys() == fields.keys() for field in fields: if field in ["variance", "lengthscale", "period", "alpha"]: - assert meta_leaves[field]["bijector"] == Softplus + assert isinstance(meta_leaves[field]["bijector"], tfb.Softplus) if field in ["power"]: - assert meta_leaves[field]["bijector"] == Identity + assert isinstance(meta_leaves[field]["bijector"], tfb.Identity) assert meta_leaves[field]["trainable"] == True # call diff --git a/tests/test_params/test_bijectors.py b/tests/test_params/test_bijectors.py deleted file mode 100644 index 14a486a08..000000000 --- a/tests/test_params/test_bijectors.py +++ /dev/null @@ -1,22 +0,0 @@ -import jax.numpy as jnp -import pytest - -from gpjax.parameters.bijectors import Bijector, Identity, Softplus - - -def test_bijector(): - bij = Bijector(forward=lambda x: jnp.exp(x), inverse=lambda x: jnp.log(x)) - assert bij.forward(1.0) == pytest.approx(jnp.exp(1.0)) - assert bij.inverse(jnp.exp(1.0)) == pytest.approx(1.0) - - -def test_identity(): - bij = Identity - assert bij.forward(1.0) == pytest.approx(1.0) - assert bij.inverse(1.0) == pytest.approx(1.0) - - -def test_softplus(): - bij = Softplus - assert bij.forward(1.0) == pytest.approx(jnp.log(1.0 + jnp.exp(1.0))) - assert bij.inverse(jnp.log(1.0 + jnp.exp(1.0))) == pytest.approx(1.0) diff --git a/tests/test_params/test_parameters.py b/tests/test_params/test_parameters.py deleted file mode 100644 index c50caae37..000000000 --- a/tests/test_params/test_parameters.py +++ /dev/null @@ -1,41 +0,0 @@ -import dataclasses - -import pytest - -from gpjax.parameters import Identity, Softplus, param_field - - -@pytest.mark.parametrize("bijector", [Identity, Softplus]) -@pytest.mark.parametrize("trainable", [True, False]) -def test_param(bijector, trainable): - param_field_ = param_field(bijector=bijector, trainable=trainable) - assert isinstance(param_field_, dataclasses.Field) - assert param_field_.metadata["bijector"] == bijector - assert param_field_.metadata["trainable"] == trainable - - with pytest.raises(ValueError): - param_field( - bijector=bijector, trainable=trainable, metadata={"trainable": trainable} - ) - - with pytest.raises(ValueError): - param_field( - bijector=bijector, trainable=trainable, metadata={"bijector": bijector} - ) - - with pytest.raises(ValueError): - param_field( - bijector=bijector, - trainable=trainable, - metadata={"bijector": Softplus, "trainable": trainable}, - ) - - with pytest.raises(ValueError): - param_field( - bijector=bijector, trainable=trainable, metadata={"pytree_node": True} - ) - - with pytest.raises(ValueError): - param_field( - bijector=bijector, trainable=trainable, metadata={"pytree_node": False} - ) From 2ec87277f84492a1f1082010579c103cdbad7f28 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Sun, 2 Apr 2023 23:21:59 +0100 Subject: [PATCH 23/44] Minimal passing tests except for eigen and basis work Note: - The tests need rewriting, with additional checks for the param_field leaves and for static_fields. - Docstrings need doing, etc. - The code is rough, and needs to be cleaned up, and the structure needs to be revised in places. --- gpjax/base/__init__.py | 19 +- gpjax/base/module.py | 15 ++ gpjax/base/param.py | 15 ++ gpjax/gps.py | 16 +- gpjax/likelihoods.py | 7 +- gpjax/mean_functions.py | 20 +- gpjax/objectives.py | 7 +- gpjax/variational_families.py | 4 +- setup.py | 1 - tests/test_base/test_params.py | 17 +- tests/test_gps.py | 3 +- tests/test_kernels/test_approximations.py | 284 +++++++++++----------- tests/test_kernels/test_base.py | 10 +- tests/test_kernels/test_computation.py | 2 +- tests/test_kernels/test_non_euclidean.py | 110 ++++----- tests/test_mean_functions.py | 2 - tests/test_params.py | 273 --------------------- tests/test_types.py | 90 ------- tests/test_utilities.py | 67 ----- tests/test_variational_inference.py | 159 ------------ 20 files changed, 296 insertions(+), 825 deletions(-) delete mode 100644 tests/test_params.py delete mode 100644 tests/test_types.py delete mode 100644 tests/test_utilities.py delete mode 100644 tests/test_variational_inference.py diff --git a/gpjax/base/__init__.py b/gpjax/base/__init__.py index 140ea2d8f..2c872ec45 100644 --- a/gpjax/base/__init__.py +++ b/gpjax/base/__init__.py @@ -1,4 +1,19 @@ -from .module import Module +# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from .module import Module, meta_leaves, meta_flatten, meta_map, meta from .param import param_field -__all__ = ["Module", "param_field"] +__all__ = ["Module", "meta_leaves", "meta_flatten", "meta_map", "meta", "param_field"] diff --git a/gpjax/base/module.py b/gpjax/base/module.py index c0b491a46..dd8b97a18 100644 --- a/gpjax/base/module.py +++ b/gpjax/base/module.py @@ -1,3 +1,18 @@ +# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + from __future__ import annotations __all__ = ["Module", "meta_leaves", "meta_flatten", "meta_map", "meta"] diff --git a/gpjax/base/param.py b/gpjax/base/param.py index 543a9fd9e..5d4c25e9d 100644 --- a/gpjax/base/param.py +++ b/gpjax/base/param.py @@ -1,3 +1,18 @@ +# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + from __future__ import annotations __all__ = ["param_field"] diff --git a/gpjax/gps.py b/gpjax/gps.py index 22baa8b8b..a90a9f72d 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -21,18 +21,18 @@ from jaxtyping import Array, Float from jax.random import KeyArray, PRNGKey, normal +from simple_pytree import static_field +from dataclasses import dataclass + +from .base import Module, param_field from .dataset import Dataset from .linops import identity - from .gaussian_distribution import GaussianDistribution from .likelihoods import AbstractLikelihood, Gaussian from .mean_functions import AbstractMeanFunction -from mytree import Mytree -from simple_pytree import static_field -from dataclasses import dataclass @dataclass -class AbstractPrior(Mytree): +class AbstractPrior(Module): """Abstract Gaussian process prior.""" kernel: AbstractKernel mean_function: AbstractMeanFunction @@ -187,7 +187,7 @@ def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution: # GP Posteriors ####################### @dataclass -class AbstractPosterior(Mytree): +class AbstractPosterior(Module): """The base GP posterior object conditioned on an observed dataset. All posterior objects should inherit from this class.""" prior: AbstractPrior @@ -355,8 +355,8 @@ class NonConjugatePosterior(AbstractPosterior): variational inference, or Laplace approximations can then be used to sample from, or optimise an approximation to, the posterior distribution. """ - latent: Float[Array, "N 1"] = None - key: KeyArray = PRNGKey(42) + latent: Float[Array, "N 1"] = param_field(None) + key: KeyArray = static_field(PRNGKey(42)) def __post_init__(self): if self.latent is None: diff --git a/gpjax/likelihoods.py b/gpjax/likelihoods.py index 478720fe5..3369105b0 100644 --- a/gpjax/likelihoods.py +++ b/gpjax/likelihoods.py @@ -18,16 +18,17 @@ from .linops.utils import to_dense import distrax as dx +import tensorflow_probability.substrates.jax.bijectors as tfb import jax.numpy as jnp import jax.scipy as jsp from jaxtyping import Array, Float from simple_pytree import static_field from dataclasses import dataclass -from mytree import Mytree, param_field, Softplus +from .base import Module, param_field @dataclass -class AbstractLikelihood(Mytree): +class AbstractLikelihood(Module): """Abstract base class for likelihoods.""" num_datapoints: int = static_field() @@ -71,7 +72,7 @@ def link_function(self) -> dx.Distribution: @dataclass class Gaussian(AbstractLikelihood): """Gaussian likelihood object.""" - obs_noise: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus) + obs_noise: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) def link_function(self, f: Float[Array, "N 1"]) -> dx.Normal: """The link function of the Gaussian likelihood. diff --git a/gpjax/mean_functions.py b/gpjax/mean_functions.py index f0173b3a7..dde0302c0 100644 --- a/gpjax/mean_functions.py +++ b/gpjax/mean_functions.py @@ -20,13 +20,15 @@ import jax.numpy as jnp from beartype.typing import List, Callable, Union from jaxtyping import Array, Float -from mytree import Mytree, param_field + + +from .base import Module, param_field from simple_pytree import static_field from functools import partial @dataclasses.dataclass -class AbstractMeanFunction(Mytree): +class AbstractMeanFunction(Module): """Mean function that is used to parameterise the Gaussian process.""" @abc.abstractmethod @@ -116,30 +118,30 @@ def __call__(self, x: Float[Array, "N D"]) -> Float[Array, "N 1"]: @dataclasses.dataclass class CombinationMeanFunction(AbstractMeanFunction): """A base class for products or sums of AbstractMeanFunctions.""" - items: List[AbstractMeanFunction] + means: List[AbstractMeanFunction] operator: Callable = static_field() def __init__( self, - items: List[AbstractMeanFunction], + means: List[AbstractMeanFunction], operator: Callable, **kwargs, ) -> None: super().__init__(**kwargs) - #Add items to a list, flattening out instances of this class therein, as in GPFlow kernels. + #Add means to a list, flattening out instances of this class therein, as in GPFlow kernels. items_list: List[AbstractMeanFunction] = [] - for item in items: + for item in means: if not isinstance(item, AbstractMeanFunction): raise TypeError("can only combine AbstractMeanFunction instances") # pragma: no cover if isinstance(item, self.__class__): - items_list.extend(item.items) + items_list.extend(item.means) else: items_list.append(item) - self.items = items_list + self.means = items_list self.operator = operator def __call__(self, x: Float[Array, "N D"]) -> Float[Array, "N 1"]: @@ -151,7 +153,7 @@ def __call__(self, x: Float[Array, "N D"]) -> Float[Array, "N 1"]: Returns: Float[Array, "Q"]: The evaluated mean function. """ - return self.operator(jnp.stack([m(x) for m in self.items])) + return self.operator(jnp.stack([m(x) for m in self.means])) SumMeanFunction = partial(CombinationMeanFunction, operator=partial(jnp.sum, axis=0)) diff --git a/gpjax/objectives.py b/gpjax/objectives.py index e71755e84..a3166aa50 100644 --- a/gpjax/objectives.py +++ b/gpjax/objectives.py @@ -11,21 +11,22 @@ from jax import vmap import jax.numpy as jnp import jax.scipy as jsp -from .linops import identity + from jaxtyping import Array, Float +from .base import Module +from .linops import identity from .dataset import Dataset from .gaussian_distribution import GaussianDistribution from .quadrature import gauss_hermite_quadrature -from mytree import Mytree from dataclasses import dataclass from simple_pytree import static_field import jax.tree_util as jtu @dataclass -class AbstractObjective(Mytree): +class AbstractObjective(Module): """Abstract base class for objectives.""" negative: bool = static_field(False) constant: float = static_field(init=False, repr=False) diff --git a/gpjax/variational_families.py b/gpjax/variational_families.py index a8208f1a9..0f431391e 100644 --- a/gpjax/variational_families.py +++ b/gpjax/variational_families.py @@ -19,9 +19,9 @@ import jax.scipy as jsp from jaxtyping import Array, Float -from mytree import Mytree, param_field from simple_pytree import static_field +from .base import Module, param_field from .dataset import Dataset from .gaussian_distribution import GaussianDistribution from .gps import AbstractPosterior @@ -34,7 +34,7 @@ import tensorflow_probability.substrates.jax.bijectors as tfb @dataclass -class AbstractVariationalFamily(Mytree): +class AbstractVariationalFamily(Module): """ Abstract base class used to represent families of distributions that can be used within variational inference. diff --git a/setup.py b/setup.py index 9968a4fc2..01808385b 100644 --- a/setup.py +++ b/setup.py @@ -58,7 +58,6 @@ def get_versions(): "jax>=0.4.1", "jaxlib>=0.4.1", "optax", - "jaxutils>=0.0.6", "distrax>=0.1.2", "tqdm>=4.0.0", "ml-collections==0.1.0", diff --git a/tests/test_base/test_params.py b/tests/test_base/test_params.py index 17e3bca48..7b3a797af 100644 --- a/tests/test_base/test_params.py +++ b/tests/test_base/test_params.py @@ -1,9 +1,22 @@ +# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + import dataclasses import pytest - import tensorflow_probability.substrates.jax.bijectors as tfb - from gpjax.base import param_field diff --git a/tests/test_gps.py b/tests/test_gps.py index 0ac5e7835..589d1c52e 100644 --- a/tests/test_gps.py +++ b/tests/test_gps.py @@ -19,7 +19,6 @@ import pytest from jax.config import config - from gpjax.gps import ( AbstractPrior, AbstractPosterior, @@ -111,7 +110,7 @@ def test_nonconjugate_posterior(num_datapoints, likel, jit_compile): lik = likel(num_datapoints=num_datapoints) post = p * lik assert isinstance(post, NonConjugatePosterior) - assert (post.latent == jnp.zeros((num_datapoints, 1))).all() + assert (post.latent == jr.normal(post.key,(num_datapoints, 1))).all() # Prediction x = jnp.linspace(-3.0, 3.0, num_datapoints).reshape(-1, 1) diff --git a/tests/test_kernels/test_approximations.py b/tests/test_kernels/test_approximations.py index 9dbf855fe..1631145a9 100644 --- a/tests/test_kernels/test_approximations.py +++ b/tests/test_kernels/test_approximations.py @@ -1,149 +1,149 @@ -import pytest -from jaxkern.approximations import RFF -from jaxkern.stationary import ( - Matern12, - Matern32, - Matern52, - RBF, - RationalQuadratic, - PoweredExponential, - Periodic, -) -from jaxkern.nonstationary import Polynomial, Linear -from jaxkern.base import AbstractKernel -import jax.random as jr -from jax.config import config -import jax.numpy as jnp -from gpjax.linops import DenseLinearOperator -from typing import Tuple -import jax - -config.update("jax_enable_x64", True) -_jitter = 1e-5 - - -@pytest.mark.parametrize("kernel", [RBF, Matern12, Matern32, Matern52]) -@pytest.mark.parametrize("num_basis_fns", [2, 10, 20]) -@pytest.mark.parametrize("n_dims", [1, 2, 5]) -def test_frequency_sampler(kernel: AbstractKernel, num_basis_fns: int, n_dims: int): - key = jr.PRNGKey(123) - base_kernel = kernel(active_dims=list(range(n_dims))) - approximate = RFF(base_kernel, num_basis_fns) - - params = approximate.init_params(key) - assert params["frequencies"].shape == (num_basis_fns, n_dims) - - -@pytest.mark.parametrize("kernel", [RBF, Matern12, Matern32, Matern52]) -@pytest.mark.parametrize("num_basis_fns", [2, 10, 20]) -@pytest.mark.parametrize("n_dims", [1, 2, 5]) -@pytest.mark.parametrize("n_data", [50, 100]) -def test_gram(kernel: AbstractKernel, num_basis_fns: int, n_dims: int, n_data: int): - key = jr.PRNGKey(123) - x = jr.uniform(key, shape=(n_data, 1), minval=-3.0, maxval=3.0).reshape(-1, 1) - if n_dims > 1: - x = jnp.hstack([x] * n_dims) - base_kernel = kernel(active_dims=list(range(n_dims))) - approximate = RFF(base_kernel, num_basis_fns) - - params = approximate.init_params(key) - - linop = approximate.gram(params, x) - - # Check the return type - assert isinstance(linop, DenseLinearOperator) - - Kxx = linop.to_dense() + jnp.eye(n_data) * _jitter - - # Check that the shape is correct - assert Kxx.shape == (n_data, n_data) - - # Check that the Gram matrix is PSD - evals, _ = jnp.linalg.eigh(Kxx) - assert jnp.all(evals > 0) - - -@pytest.mark.parametrize("kernel", [RBF, Matern12, Matern32, Matern52]) -@pytest.mark.parametrize("num_basis_fns", [2, 10, 20]) -@pytest.mark.parametrize("n_dims", [1, 2, 5]) -@pytest.mark.parametrize("n_datas", [(50, 100), (100, 50)]) -def test_cross_covariance( - kernel: AbstractKernel, - num_basis_fns: int, - n_dims: int, - n_datas: Tuple[int, int], -): - nd1, nd2 = n_datas - key = jr.PRNGKey(123) - x1 = jr.uniform(key, shape=(nd1, 1), minval=-3.0, maxval=3.0) - if n_dims > 1: - x1 = jnp.hstack([x1] * n_dims) - x2 = jr.uniform(key, shape=(nd2, 1), minval=-3.0, maxval=3.0) - if n_dims > 1: - x2 = jnp.hstack([x2] * n_dims) - - base_kernel = kernel(active_dims=list(range(n_dims))) - approximate = RFF(base_kernel, num_basis_fns) - - params = approximate.init_params(key) - - Kxx = approximate.cross_covariance(params, x1, x2) - - # Check the return type - assert isinstance(Kxx, jax.Array) - - # Check that the shape is correct - assert Kxx.shape == (nd1, nd2) - - -@pytest.mark.parametrize("kernel", [RBF, Matern12, Matern32, Matern52]) -@pytest.mark.parametrize("n_dim", [1, 2, 5]) -def test_improvement(kernel, n_dim): - n_data = 100 - key = jr.PRNGKey(123) - - x = jr.uniform(key, minval=-3.0, maxval=3.0, shape=(n_data, n_dim)) - base_kernel = kernel(active_dims=list(range(n_dim))) - exact_params = base_kernel.init_params(key) - exact_linop = base_kernel.gram(exact_params, x).to_dense() - - crude_approximation = RFF(base_kernel, num_basis_fns=10) - c_params = crude_approximation.init_params(key) - c_linop = crude_approximation.gram(c_params, x).to_dense() - - better_approximation = RFF(base_kernel, num_basis_fns=50) - b_params = better_approximation.init_params(key) - b_linop = better_approximation.gram(b_params, x).to_dense() +# import pytest +# from jaxkern.approximations import RFF +# from jaxkern.stationary import ( +# Matern12, +# Matern32, +# Matern52, +# RBF, +# RationalQuadratic, +# PoweredExponential, +# Periodic, +# ) +# from jaxkern.nonstationary import Polynomial, Linear +# from jaxkern.base import AbstractKernel +# import jax.random as jr +# from jax.config import config +# import jax.numpy as jnp +# from gpjax.linops import DenseLinearOperator +# from typing import Tuple +# import jax + +# config.update("jax_enable_x64", True) +# _jitter = 1e-5 + + +# @pytest.mark.parametrize("kernel", [RBF, Matern12, Matern32, Matern52]) +# @pytest.mark.parametrize("num_basis_fns", [2, 10, 20]) +# @pytest.mark.parametrize("n_dims", [1, 2, 5]) +# def test_frequency_sampler(kernel: AbstractKernel, num_basis_fns: int, n_dims: int): +# key = jr.PRNGKey(123) +# base_kernel = kernel(active_dims=list(range(n_dims))) +# approximate = RFF(base_kernel, num_basis_fns) + +# params = approximate.init_params(key) +# assert params["frequencies"].shape == (num_basis_fns, n_dims) + + +# @pytest.mark.parametrize("kernel", [RBF, Matern12, Matern32, Matern52]) +# @pytest.mark.parametrize("num_basis_fns", [2, 10, 20]) +# @pytest.mark.parametrize("n_dims", [1, 2, 5]) +# @pytest.mark.parametrize("n_data", [50, 100]) +# def test_gram(kernel: AbstractKernel, num_basis_fns: int, n_dims: int, n_data: int): +# key = jr.PRNGKey(123) +# x = jr.uniform(key, shape=(n_data, 1), minval=-3.0, maxval=3.0).reshape(-1, 1) +# if n_dims > 1: +# x = jnp.hstack([x] * n_dims) +# base_kernel = kernel(active_dims=list(range(n_dims))) +# approximate = RFF(base_kernel, num_basis_fns) + +# params = approximate.init_params(key) + +# linop = approximate.gram(params, x) + +# # Check the return type +# assert isinstance(linop, DenseLinearOperator) + +# Kxx = linop.to_dense() + jnp.eye(n_data) * _jitter + +# # Check that the shape is correct +# assert Kxx.shape == (n_data, n_data) + +# # Check that the Gram matrix is PSD +# evals, _ = jnp.linalg.eigh(Kxx) +# assert jnp.all(evals > 0) + + +# @pytest.mark.parametrize("kernel", [RBF, Matern12, Matern32, Matern52]) +# @pytest.mark.parametrize("num_basis_fns", [2, 10, 20]) +# @pytest.mark.parametrize("n_dims", [1, 2, 5]) +# @pytest.mark.parametrize("n_datas", [(50, 100), (100, 50)]) +# def test_cross_covariance( +# kernel: AbstractKernel, +# num_basis_fns: int, +# n_dims: int, +# n_datas: Tuple[int, int], +# ): +# nd1, nd2 = n_datas +# key = jr.PRNGKey(123) +# x1 = jr.uniform(key, shape=(nd1, 1), minval=-3.0, maxval=3.0) +# if n_dims > 1: +# x1 = jnp.hstack([x1] * n_dims) +# x2 = jr.uniform(key, shape=(nd2, 1), minval=-3.0, maxval=3.0) +# if n_dims > 1: +# x2 = jnp.hstack([x2] * n_dims) + +# base_kernel = kernel(active_dims=list(range(n_dims))) +# approximate = RFF(base_kernel, num_basis_fns) + +# params = approximate.init_params(key) + +# Kxx = approximate.cross_covariance(params, x1, x2) + +# # Check the return type +# assert isinstance(Kxx, jax.Array) + +# # Check that the shape is correct +# assert Kxx.shape == (nd1, nd2) + + +# @pytest.mark.parametrize("kernel", [RBF, Matern12, Matern32, Matern52]) +# @pytest.mark.parametrize("n_dim", [1, 2, 5]) +# def test_improvement(kernel, n_dim): +# n_data = 100 +# key = jr.PRNGKey(123) + +# x = jr.uniform(key, minval=-3.0, maxval=3.0, shape=(n_data, n_dim)) +# base_kernel = kernel(active_dims=list(range(n_dim))) +# exact_params = base_kernel.init_params(key) +# exact_linop = base_kernel.gram(exact_params, x).to_dense() + +# crude_approximation = RFF(base_kernel, num_basis_fns=10) +# c_params = crude_approximation.init_params(key) +# c_linop = crude_approximation.gram(c_params, x).to_dense() + +# better_approximation = RFF(base_kernel, num_basis_fns=50) +# b_params = better_approximation.init_params(key) +# b_linop = better_approximation.gram(b_params, x).to_dense() - c_delta = jnp.linalg.norm(exact_linop - c_linop, ord="fro") - b_delta = jnp.linalg.norm(exact_linop - b_linop, ord="fro") +# c_delta = jnp.linalg.norm(exact_linop - c_linop, ord="fro") +# b_delta = jnp.linalg.norm(exact_linop - b_linop, ord="fro") - # The frobenius norm of the difference between the exact and approximate - # should improve as we increase the number of basis functions - assert c_delta > b_delta +# # The frobenius norm of the difference between the exact and approximate +# # should improve as we increase the number of basis functions +# assert c_delta > b_delta -@pytest.mark.parametrize("kernel", [RBF(), Matern12(), Matern32(), Matern52()]) -def test_exactness(kernel): - n_data = 100 - key = jr.PRNGKey(123) - - x = jr.uniform(key, minval=-3.0, maxval=3.0, shape=(n_data, 1)) - exact_params = kernel.init_params(key) - exact_linop = kernel.gram(exact_params, x).to_dense() +# @pytest.mark.parametrize("kernel", [RBF(), Matern12(), Matern32(), Matern52()]) +# def test_exactness(kernel): +# n_data = 100 +# key = jr.PRNGKey(123) + +# x = jr.uniform(key, minval=-3.0, maxval=3.0, shape=(n_data, 1)) +# exact_params = kernel.init_params(key) +# exact_linop = kernel.gram(exact_params, x).to_dense() - better_approximation = RFF(kernel, num_basis_fns=500) - b_params = better_approximation.init_params(key) - b_linop = better_approximation.gram(b_params, x).to_dense() - - max_delta = jnp.max(exact_linop - b_linop) - assert max_delta < 0.1 +# better_approximation = RFF(kernel, num_basis_fns=500) +# b_params = better_approximation.init_params(key) +# b_linop = better_approximation.gram(b_params, x).to_dense() + +# max_delta = jnp.max(exact_linop - b_linop) +# assert max_delta < 0.1 -@pytest.mark.parametrize( - "kernel", - [RationalQuadratic, PoweredExponential, Polynomial, Linear, Periodic], -) -def test_value_error(kernel): - with pytest.raises(ValueError): - RFF(kernel(), num_basis_fns=10) +# @pytest.mark.parametrize( +# "kernel", +# [RationalQuadratic, PoweredExponential, Polynomial, Linear, Periodic], +# ) +# def test_value_error(kernel): +# with pytest.raises(ValueError): +# RFF(kernel(), num_basis_fns=10) diff --git a/tests/test_kernels/test_base.py b/tests/test_kernels/test_base.py index 203315626..40b348931 100644 --- a/tests/test_kernels/test_base.py +++ b/tests/test_kernels/test_base.py @@ -16,7 +16,6 @@ import jax.numpy as jnp import pytest from jax.config import config -from jaxlinop import identity from gpjax.kernels.base import ( AbstractKernel, @@ -34,7 +33,10 @@ from gpjax.kernels.nonstationary import Polynomial, Linear from jaxtyping import Array, Float from dataclasses import dataclass -from mytree import param_field, Softplus +from gpjax.base import param_field + +import tensorflow_probability.substrates.jax.bijectors as tfb + # Enable Float64 for more stable matrix inversions. @@ -50,7 +52,7 @@ def test_abstract_kernel(): @dataclass class DummyKernel(AbstractKernel): test_a: Float[Array, "1"] = jnp.array([1.0]) - test_b: Float[Array, "1"] = param_field(jnp.array([2.0]), bijector=Softplus) + test_b: Float[Array, "1"] = param_field(jnp.array([2.0]), bijector=tfb.Softplus()) def __call__(self, x: Float[Array, "1 D"], y: Float[Array, "1 D"]) -> Float[Array, "1"]: return x * self.test_b * y @@ -58,7 +60,7 @@ def __call__(self, x: Float[Array, "1 D"], y: Float[Array, "1 D"]) -> Float[Arra # Initialise dummy kernel class and test __call__ method: dummy_kernel = DummyKernel() assert dummy_kernel.test_a == jnp.array([1.0]) - assert dummy_kernel._pytree__meta["test_b"].get("bijector") == Softplus + assert isinstance(dummy_kernel._pytree__meta["test_b"].get("bijector"), tfb.Softplus) assert dummy_kernel.test_b == jnp.array([2.0]) assert (dummy_kernel(jnp.array([1.0]), jnp.array([2.0])) == 4.0) diff --git a/tests/test_kernels/test_computation.py b/tests/test_kernels/test_computation.py index 0e33c076b..38ba5202a 100644 --- a/tests/test_kernels/test_computation.py +++ b/tests/test_kernels/test_computation.py @@ -5,7 +5,7 @@ DiagonalKernelComputation, ConstantDiagonalKernelComputation, ) -from jaxkern.stationary import ( +from gpjax.kernels.stationary import ( RBF, Matern12, Matern32, diff --git a/tests/test_kernels/test_non_euclidean.py b/tests/test_kernels/test_non_euclidean.py index 7ae915839..266115689 100644 --- a/tests/test_kernels/test_non_euclidean.py +++ b/tests/test_kernels/test_non_euclidean.py @@ -1,66 +1,66 @@ -# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== +# # Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved. +# # +# # Licensed under the Apache License, Version 2.0 (the "License"); +# # you may not use this file except in compliance with the License. +# # You may obtain a copy of the License at +# # +# # http://www.apache.org/licenses/LICENSE-2.0 +# # +# # Unless required by applicable law or agreed to in writing, software +# # distributed under the License is distributed on an "AS IS" BASIS, +# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# # See the License for the specific language governing permissions and +# # limitations under the License. +# # ============================================================================== -import jax.numpy as jnp -import jax.random as jr -import networkx as nx -from jax.config import config -from jaxlinop import identity +# import jax.numpy as jnp +# import jax.random as jr +# import networkx as nx +# from jax.config import config +# from jaxlinop import identity -from gpjax.kernels.non_euclidean import GraphKernel +# from gpjax.kernels.non_euclidean import GraphKernel -# Enable Float64 for more stable matrix inversions. -config.update("jax_enable_x64", True) -_initialise_key = jr.PRNGKey(123) -_jitter = 1e-6 +# # Enable Float64 for more stable matrix inversions. +# config.update("jax_enable_x64", True) +# _initialise_key = jr.PRNGKey(123) +# _jitter = 1e-6 -def test_graph_kernel(): - # Create a random graph, G, and verice labels, x, - n_verticies = 20 - n_edges = 40 - G = nx.gnm_random_graph(n_verticies, n_edges, seed=123) - x = jnp.arange(n_verticies).reshape(-1, 1) +# def test_graph_kernel(): +# # Create a random graph, G, and verice labels, x, +# n_verticies = 20 +# n_edges = 40 +# G = nx.gnm_random_graph(n_verticies, n_edges, seed=123) +# x = jnp.arange(n_verticies).reshape(-1, 1) - # Compute graph laplacian - L = nx.laplacian_matrix(G).toarray() + jnp.eye(n_verticies) * 1e-12 +# # Compute graph laplacian +# L = nx.laplacian_matrix(G).toarray() + jnp.eye(n_verticies) * 1e-12 - # Create graph kernel - kern = GraphKernel(laplacian=L) - assert isinstance(kern, GraphKernel) - assert kern.num_vertex == n_verticies - assert kern.evals.shape == (n_verticies, 1) - assert kern.evecs.shape == (n_verticies, n_verticies) +# # Create graph kernel +# kern = GraphKernel(laplacian=L) +# assert isinstance(kern, GraphKernel) +# assert kern.num_vertex == n_verticies +# assert kern.evals.shape == (n_verticies, 1) +# assert kern.evecs.shape == (n_verticies, n_verticies) - # Unpack kernel computation - kern.gram +# # Unpack kernel computation +# kern.gram - # Initialise default parameters - params = kern.init_params(_initialise_key) - assert isinstance(params, dict) - assert list(sorted(list(params.keys()))) == [ - "lengthscale", - "smoothness", - "variance", - ] +# # Initialise default parameters +# params = kern.init_params(_initialise_key) +# assert isinstance(params, dict) +# assert list(sorted(list(params.keys()))) == [ +# "lengthscale", +# "smoothness", +# "variance", +# ] - # Compute gram matrix - Kxx = kern.gram(params, x) - assert Kxx.shape == (n_verticies, n_verticies) +# # Compute gram matrix +# Kxx = kern.gram(params, x) +# assert Kxx.shape == (n_verticies, n_verticies) - # Check positive definiteness - Kxx += identity(n_verticies) * _jitter - eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) - assert all(eigen_values > 0) +# # Check positive definiteness +# Kxx += identity(n_verticies) * _jitter +# eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) +# assert all(eigen_values > 0) diff --git a/tests/test_mean_functions.py b/tests/test_mean_functions.py index d677b7840..7b7cfb42c 100644 --- a/tests/test_mean_functions.py +++ b/tests/test_mean_functions.py @@ -1,6 +1,5 @@ import pytest import jax -import mytree from gpjax.mean_functions import AbstractMeanFunction, Constant from jaxtyping import Array, Float @@ -16,7 +15,6 @@ def __call__(self, x: Float[Array, "D"]) -> Float[Array, "1"]: return jax.numpy.array([1.0]) mf = DummyMeanFunction() - assert isinstance(mf, mytree.Mytree) assert isinstance(mf, AbstractMeanFunction) assert (mf(jax.numpy.array([1.0])) == jax.numpy.array([1.0])).all() assert (mf(jax.numpy.array([2.0, 3.0])) == jax.numpy.array([1.0])).all() diff --git a/tests/test_params.py b/tests/test_params.py deleted file mode 100644 index f9867e4d7..000000000 --- a/tests/test_params.py +++ /dev/null @@ -1,273 +0,0 @@ -# Copyright 2022 The GPJax Contributors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import typing as tp - -import distrax as dx -import jax.numpy as jnp -import jax.random as jr -import pytest -from jax.config import config - -from gpjax.gps import Prior -from gpjax.kernels import RBF -from gpjax.likelihoods import Bernoulli, Gaussian -from gpjax.params import ( - build_bijectors, - build_trainables, - constrain, - copy_dict_structure, - evaluate_priors, - initialise, - log_density, - prior_checks, - recursive_complete, - recursive_items, - structure_priors, - unconstrain, -) - -# Enable Float64 for more stable matrix inversions. -config.update("jax_enable_x64", True) - -######################### -# Test base functionality -######################### -@pytest.mark.parametrize("lik", [Gaussian]) -def test_initialise(lik): - key = jr.PRNGKey(123) - posterior = Prior(kernel=RBF()) * lik(num_datapoints=10) - params, _, _ = initialise(posterior, key).unpack() - assert list(sorted(params.keys())) == [ - "kernel", - "likelihood", - "mean_function", - ] - - -def test_non_conjugate_initialise(): - posterior = Prior(kernel=RBF()) * Bernoulli(num_datapoints=10) - params, _, _ = initialise(posterior, jr.PRNGKey(123)).unpack() - assert list(sorted(params.keys())) == [ - "kernel", - "latent", - "likelihood", - "mean_function", - ] - - -######################### -# Test priors -######################### -@pytest.mark.parametrize("x", [-1.0, 0.0, 1.0]) -def test_lpd(x): - val = jnp.array(x) - dist = dx.Normal(loc=0.0, scale=1.0) - lpd = log_density(val, dist) - assert lpd is not None - assert log_density(val, None) == 0.0 - - -@pytest.mark.parametrize("lik", [Gaussian, Bernoulli]) -def test_prior_template(lik): - posterior = Prior(kernel=RBF()) * lik(num_datapoints=10) - params, _, _ = initialise(posterior, jr.PRNGKey(123)).unpack() - prior_container = copy_dict_structure(params) - for ( - k, - v1, - v2, - ) in recursive_items(params, prior_container): - assert v2 == None - - -@pytest.mark.parametrize("lik", [Gaussian, Bernoulli]) -def test_recursive_complete(lik): - posterior = Prior(kernel=RBF()) * lik(num_datapoints=10) - params, _, _ = initialise(posterior, jr.PRNGKey(123)).unpack() - priors = {"kernel": {}} - priors["kernel"]["lengthscale"] = dx.Laplace(loc=0.0, scale=1.0) - container = copy_dict_structure(params) - complete_priors = recursive_complete(container, priors) - for ( - k, - v1, - v2, - ) in recursive_items(params, complete_priors): - if k == "lengthscale": - assert isinstance(v2, dx.Laplace) - else: - assert v2 == None - - -def test_prior_evaluation(): - """ - Test the regular setup that every parameter has a corresponding prior distribution attached to its unconstrained - value. - """ - params = { - "kernel": { - "lengthscale": jnp.array([1.0]), - "variance": jnp.array([1.0]), - }, - "likelihood": {"obs_noise": jnp.array([1.0])}, - } - priors = { - "kernel": { - "lengthscale": dx.Gamma(1.0, 1.0), - "variance": dx.Gamma(2.0, 2.0), - }, - "likelihood": {"obs_noise": dx.Gamma(3.0, 3.0)}, - } - lpd = evaluate_priors(params, priors) - assert pytest.approx(lpd) == -2.0110168 - - -def test_none_prior(): - """ - Test that multiple dispatch is working in the case of no priors. - """ - params = { - "kernel": { - "lengthscale": jnp.array([1.0]), - "variance": jnp.array([1.0]), - }, - "likelihood": {"obs_noise": jnp.array([1.0])}, - } - priors = copy_dict_structure(params) - lpd = evaluate_priors(params, priors) - assert lpd == 0.0 - - -def test_incomplete_priors(): - """ - Test the case where a user specifies priors for some, but not all, parameters. - """ - params = { - "kernel": { - "lengthscale": jnp.array([1.0]), - "variance": jnp.array([1.0]), - }, - "likelihood": {"obs_noise": jnp.array([1.0])}, - } - priors = { - "kernel": { - "lengthscale": dx.Gamma(1.0, 1.0), - "variance": dx.Gamma(2.0, 2.0), - }, - } - container = copy_dict_structure(params) - complete_priors = recursive_complete(container, priors) - lpd = evaluate_priors(params, complete_priors) - assert pytest.approx(lpd) == -1.6137061 - - -@pytest.mark.parametrize("num_datapoints", [1, 10]) -def test_checks(num_datapoints): - incomplete_priors = {"lengthscale": jnp.array([1.0])} - posterior = Prior(kernel=RBF()) * Bernoulli(num_datapoints=num_datapoints) - priors = prior_checks(incomplete_priors) - assert "latent" in priors.keys() - assert "variance" not in priors.keys() - assert isinstance(priors["latent"], dx.Normal) - - -def test_structure_priors(): - posterior = Prior(kernel=RBF()) * Gaussian(num_datapoints=10) - params, _, _ = initialise(posterior, jr.PRNGKey(123)).unpack() - priors = { - "kernel": { - "lengthscale": dx.Gamma(1.0, 1.0), - "variance": dx.Gamma(2.0, 2.0), - }, - } - structured_priors = structure_priors(params, priors) - - def recursive_fn(d1, d2, fn: tp.Callable[[tp.Any], tp.Any]): - for key, value in d1.items(): - if type(value) is dict: - yield from recursive_fn(value, d2[key], fn) - else: - yield fn(key, key) - - for v in recursive_fn(params, structured_priors, lambda k1, k2: k1 == k2): - assert v - - -@pytest.mark.parametrize("latent_prior", [dx.Laplace(0.0, 1.0), dx.Laplace(0.0, 1.0)]) -def test_prior_checks(latent_prior): - priors = { - "kernel": {"lengthscale": None, "variance": None}, - "mean_function": {}, - "liklelihood": {"variance": None}, - "latent": None, - } - new_priors = prior_checks(priors) - assert "latent" in new_priors.keys() - assert isinstance(new_priors["latent"], dx.Normal) - - priors = { - "kernel": {"lengthscale": None, "variance": None}, - "mean_function": {}, - "liklelihood": {"variance": None}, - } - new_priors = prior_checks(priors) - assert "latent" in new_priors.keys() - assert isinstance(new_priors["latent"], dx.Normal) - - priors = { - "kernel": {"lengthscale": None, "variance": None}, - "mean_function": {}, - "liklelihood": {"variance": None}, - "latent": latent_prior, - } - with pytest.warns(UserWarning): - new_priors = prior_checks(priors) - assert "latent" in new_priors.keys() - assert isinstance(new_priors["latent"], dx.Laplace) - - -######################### -# Test transforms -######################### -@pytest.mark.parametrize("num_datapoints", [1, 10]) -@pytest.mark.parametrize("likelihood", [Gaussian, Bernoulli]) -def test_output(num_datapoints, likelihood): - posterior = Prior(kernel=RBF()) * likelihood(num_datapoints=num_datapoints) - params, _, bijectors = initialise(posterior, jr.PRNGKey(123)).unpack() - - assert isinstance(bijectors, dict) - for k, v1, v2 in recursive_items(bijectors, bijectors): - assert isinstance(v1.forward, tp.Callable) - assert isinstance(v2.inverse, tp.Callable) - - unconstrained_params = unconstrain(params, bijectors) - assert ( - unconstrained_params["kernel"]["lengthscale"] != params["kernel"]["lengthscale"] - ) - backconstrained_params = constrain(unconstrained_params, bijectors) - for k, v1, v2 in recursive_items(params, unconstrained_params): - assert v1.dtype == v2.dtype - - for k, v1, v2 in recursive_items(params, backconstrained_params): - assert all(v1 == v2) - - augmented_params = params - augmented_params["test_param"] = jnp.array([1.0]) - a_bijectors = build_bijectors(augmented_params) - - assert "test_param" in list(a_bijectors.keys()) - assert a_bijectors["test_param"].forward(jnp.array([1.0])) == 1.0 - assert a_bijectors["test_param"].inverse(jnp.array([1.0])) == 1.0 diff --git a/tests/test_types.py b/tests/test_types.py deleted file mode 100644 index 8b7243c3b..000000000 --- a/tests/test_types.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import jax.numpy as jnp -import pytest - -from gpjax.types import Dataset, verify_dataset - - -@pytest.mark.parametrize("n", [1, 10]) -@pytest.mark.parametrize("outd", [1, 2, 10]) -@pytest.mark.parametrize("ind", [1, 2, 10]) -@pytest.mark.parametrize("n2", [1, 10]) -def test_dataset(n: int, outd: int, ind: int, n2: int) -> None: - x = jnp.ones((n, ind)) - y = jnp.ones((n, outd)) - d = Dataset(X=x, y=y) - - verify_dataset(d) - assert d.n == n - assert d.in_dim == ind - assert d.out_dim == outd - - assert d.__repr__() == f"- Number of datapoints: {n}\n- Dimension: {ind}" - - # Test combine datasets. - x2 = 2 * jnp.ones((n2, ind)) - y2 = 2 * jnp.ones((n2, outd)) - d2 = Dataset(X=x2, y=y2) - - d_combined = d + d2 - assert d_combined.n == n + n2 - assert d_combined.in_dim == ind - assert d_combined.out_dim == outd - assert (d_combined.y[:n] == 1.0).all() - assert (d_combined.y[n:] == 2.0).all() - assert (d_combined.X[:n] == 1.0).all() - assert (d_combined.X[n:] == 2.0).all() - - # Test supervised and unsupervised. - assert d.is_supervised() is True - dunsup = Dataset(y=y) - assert dunsup.is_unsupervised() is True - - -@pytest.mark.parametrize("nx, ny", [(1, 2), (2, 1), (10, 5), (5, 10)]) -@pytest.mark.parametrize("outd", [1, 2, 10]) -@pytest.mark.parametrize("ind", [1, 2, 10]) -def test_dataset_assertions(nx: int, ny: int, outd: int, ind: int) -> None: - x = jnp.ones((nx, ind)) - y = jnp.ones((ny, outd)) - - with pytest.raises(ValueError): - ds = Dataset(X=x, y=y) - - -@pytest.mark.parametrize("n", [1, 2, 10]) -@pytest.mark.parametrize("outd", [1, 2, 10]) -@pytest.mark.parametrize("ind", [1, 2, 10]) -def test_2d_inputs(n: int, outd: int, ind: int) -> None: - x = jnp.ones((n, ind)) - y = jnp.ones((n,)) - - with pytest.raises(ValueError): - ds = Dataset(X=x, y=y) - - x = jnp.ones((n,)) - y = jnp.ones((n, outd)) - - with pytest.raises(ValueError): - ds = Dataset(X=x, y=y) - - -def test_y_none() -> None: - x = jnp.ones((10, 1)) - d = Dataset(X=x) - verify_dataset(d) - assert d.y is None diff --git a/tests/test_utilities.py b/tests/test_utilities.py deleted file mode 100644 index 21fbd961e..000000000 --- a/tests/test_utilities.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright 2022 The GPJax Contributors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import jax.numpy as jnp -import pytest -from jax.config import config - -from gpjax.utils import ( - concat_dictionaries, - dict_array_coercion, - merge_dictionaries, - sort_dictionary, -) - -# Enable Float64 for more stable matrix inversions. -config.update("jax_enable_x64", True) - - -def test_concat_dict(): - d1 = {"a": 1, "b": 2} - d2 = {"c": 3, "d": 4} - d = concat_dictionaries(d1, d2) - assert list(d.keys()) == ["a", "b", "c", "d"] - assert list(d.values()) == [1, 2, 3, 4] - - -def test_merge_dicts(): - d1 = {"a": 1, "b": 2} - d2 = {"b": 3} - d = merge_dictionaries(d1, d2) - assert list(d.keys()) == ["a", "b"] - assert list(d.values()) == [1, 3] - - -def test_sort_dict(): - unsorted = {"b": 1, "a": 2} - sorted_dict = sort_dictionary(unsorted) - assert list(sorted_dict.keys()) == ["a", "b"] - assert list(sorted_dict.values()) == [2, 1] - - -@pytest.mark.parametrize("d", [1, 2, 10]) -def test_array_coercion(d): - params = { - "kernel": { - "lengthscale": jnp.array([1.0] * d), - "variance": jnp.array([1.0]), - }, - "likelihood": {"obs_noise": jnp.array([1.0])}, - "mean_function": {}, - } - dict_to_array, array_to_dict = dict_array_coercion(params) - assert array_to_dict(dict_to_array(params)) == params - assert isinstance(dict_to_array(params), list) - assert isinstance(array_to_dict(dict_to_array(params)), dict) diff --git a/tests/test_variational_inference.py b/tests/test_variational_inference.py deleted file mode 100644 index 1e7eb9eba..000000000 --- a/tests/test_variational_inference.py +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright 2022 The GPJax Contributors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import typing as tp - -import jax -import jax.numpy as jnp -import jax.random as jr -import pytest -from jax.config import config - -import gpjax as gpx -from gpjax.variational_families import ( - CollapsedVariationalGaussian, - ExpectationVariationalGaussian, - NaturalVariationalGaussian, - VariationalGaussian, - WhitenedVariationalGaussian, -) - -# Enable Float64 for more stable matrix inversions. -config.update("jax_enable_x64", True) - - -def test_abstract_variational_inference(): - prior = gpx.Prior(kernel=gpx.RBF()) - lik = gpx.Gaussian(num_datapoints=20) - post = prior * lik - n_inducing_points = 10 - inducing_inputs = jnp.linspace(-5.0, 5.0, n_inducing_points).reshape(-1, 1) - vartiational_family = gpx.VariationalGaussian( - prior=prior, inducing_inputs=inducing_inputs - ) - - with pytest.raises(TypeError): - gpx.variational_inference.AbstractVariationalInference( - posterior=post, vartiational_family=vartiational_family - ) - - -def get_data_and_gp(n_datapoints, point_dim): - x = jnp.linspace(-5.0, 5.0, n_datapoints).reshape(-1, 1) - y = jnp.sin(x) + jr.normal(key=jr.PRNGKey(123), shape=x.shape) * 0.1 - x = jnp.hstack([x] * point_dim) - D = gpx.Dataset(X=x, y=y) - - p = gpx.Prior(kernel=gpx.RBF()) - lik = gpx.Gaussian(num_datapoints=n_datapoints) - post = p * lik - return D, post, p - - -@pytest.mark.parametrize("n_datapoints, n_inducing_points", [(10, 2), (100, 10)]) -@pytest.mark.parametrize("jit_fns", [False, True]) -@pytest.mark.parametrize("point_dim", [1, 2, 3]) -@pytest.mark.parametrize( - "variational_family", - [ - VariationalGaussian, - WhitenedVariationalGaussian, - NaturalVariationalGaussian, - ExpectationVariationalGaussian, - ], -) -def test_stochastic_vi( - n_datapoints, n_inducing_points, jit_fns, point_dim, variational_family -): - D, post, prior = get_data_and_gp(n_datapoints, point_dim) - inducing_inputs = jnp.linspace(-5.0, 5.0, n_inducing_points).reshape(-1, 1) - inducing_inputs = jnp.hstack([inducing_inputs] * point_dim) - - q = variational_family(prior=prior, inducing_inputs=inducing_inputs) - - svgp = gpx.StochasticVI(posterior=post, variational_family=q) - assert svgp.posterior.prior == post.prior - assert svgp.posterior.likelihood == post.likelihood - - params, _, _ = gpx.initialise(svgp, jr.PRNGKey(123)).unpack() - - assert svgp.prior == post.prior - assert svgp.likelihood == post.likelihood - - if jit_fns: - elbo_fn = jax.jit(svgp.elbo(D)) - else: - elbo_fn = svgp.elbo(D) - assert isinstance(elbo_fn, tp.Callable) - elbo_value = elbo_fn(params, D) - assert isinstance(elbo_value, jnp.ndarray) - - # Test gradients - grads = jax.grad(elbo_fn, argnums=0)(params, D) - assert isinstance(grads, tp.Dict) - assert len(grads) == len(params) - - -@pytest.mark.parametrize("n_datapoints, n_inducing_points", [(10, 2), (100, 10)]) -@pytest.mark.parametrize("jit_fns", [False, True]) -@pytest.mark.parametrize("point_dim", [1, 2]) -def test_collapsed_vi(n_datapoints, n_inducing_points, jit_fns, point_dim): - D, post, prior = get_data_and_gp(n_datapoints, point_dim) - likelihood = gpx.Gaussian(num_datapoints=n_datapoints) - - inducing_inputs = jnp.linspace(-5.0, 5.0, n_inducing_points).reshape(-1, 1) - inducing_inputs = jnp.hstack([inducing_inputs] * point_dim) - - q = CollapsedVariationalGaussian( - prior=prior, likelihood=likelihood, inducing_inputs=inducing_inputs - ) - - sgpr = gpx.variational_inference.CollapsedVI(posterior=post, variational_family=q) - assert sgpr.posterior.prior == post.prior - assert sgpr.posterior.likelihood == post.likelihood - - params, _, _ = gpx.initialise(sgpr, jr.PRNGKey(123)).unpack() - - assert sgpr.prior == post.prior - assert sgpr.likelihood == post.likelihood - - if jit_fns: - elbo_fn = jax.jit(sgpr.elbo(D)) - else: - elbo_fn = sgpr.elbo(D) - assert isinstance(elbo_fn, tp.Callable) - elbo_value = elbo_fn(params) - assert isinstance(elbo_value, jnp.ndarray) - - # Test gradients - grads = jax.grad(elbo_fn)(params) - assert isinstance(grads, tp.Dict) - assert len(grads) == len(params) - - # We should raise an error for non-Collapsed variational families: - with pytest.raises(TypeError): - q = gpx.variational_families.VariationalGaussian( - prior=prior, inducing_inputs=inducing_inputs - ) - gpx.variational_inference.CollapsedVI(posterior=post, variational_family=q) - - # We should raise an error for non-Gaussian likelihoods: - with pytest.raises(TypeError): - q = gpx.variational_families.CollapsedVariationalGaussian( - prior=prior, likelihood=likelihood, inducing_inputs=inducing_inputs - ) - gpx.variational_inference.CollapsedVI( - posterior=prior * gpx.Bernoulli(num_datapoints=D.n), variational_family=q - ) From cd1cce7f80bf410037fa3f2e1d0db340be57fca9 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Mon, 3 Apr 2023 13:33:43 +0100 Subject: [PATCH 24/44] Improve dataset tests. --- gpjax/dataset.py | 11 ++- tests/test_dataset.py | 152 +++++++++++++++++++++++++++++------------- 2 files changed, 116 insertions(+), 47 deletions(-) diff --git a/gpjax/dataset.py b/gpjax/dataset.py index b5b9e9052..9b496c1e6 100644 --- a/gpjax/dataset.py +++ b/gpjax/dataset.py @@ -55,8 +55,15 @@ def is_unsupervised(self) -> bool: def __add__(self, other: Dataset) -> Dataset: """Combine two datasets. Right hand dataset is stacked beneath the left.""" - X = jnp.concatenate((self.X, other.X)) - y = jnp.concatenate((self.y, other.y)) + + X = None + y = None + + if self.X is not None and other.X is not None: + X = jnp.concatenate((self.X, other.X)) + + if self.y is not None and other.y is not None: + y = jnp.concatenate((self.y, other.y)) return Dataset(X=X, y=y) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 25efb81d1..4a22e98f3 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -15,76 +15,138 @@ import jax.numpy as jnp import pytest +import jax.tree_util as jtu + +from dataclasses import is_dataclass from gpjax.dataset import Dataset +@pytest.mark.parametrize("n", [1, 2, 10]) +@pytest.mark.parametrize("out_dim", [1, 2, 10]) +@pytest.mark.parametrize("in_dim", [1, 2, 10]) +def test_dataset_init(n: int, in_dim: int, out_dim: int) -> None: + + # Create dataset + x = jnp.ones((n, in_dim)) + y = jnp.ones((n, out_dim)) + D = Dataset(X=x, y=y) + + # Test dataset shapes + assert D.n == n + assert D.in_dim == in_dim + assert D.out_dim == out_dim + + # Test representation + assert ( + D.__repr__() + == f"- Number of observations: {n}\n- Input dimension: {in_dim}\n- Output" + f" dimension: {out_dim}" + ) + + # Ensure dataclass + assert is_dataclass(D) + + # Test supervised and unsupervised + assert Dataset(X=x, y=y).is_supervised() is True + assert Dataset(y=y).is_unsupervised() is True + + # Check tree flatten + assert jtu.tree_leaves(D) == [x, y] + -@pytest.mark.parametrize("n", [1, 10]) -@pytest.mark.parametrize("outd", [1, 2, 10]) -@pytest.mark.parametrize("ind", [1, 2, 10]) -@pytest.mark.parametrize("n2", [1, 10]) -def test_dataset(n: int, outd: int, ind: int, n2: int) -> None: - x = jnp.ones((n, ind)) - y = jnp.ones((n, outd)) - d = Dataset(X=x, y=y) +@pytest.mark.parametrize("n1", [1, 2, 10]) +@pytest.mark.parametrize("n2", [1, 2, 10]) +@pytest.mark.parametrize("out_dim", [1, 2, 10]) +@pytest.mark.parametrize("in_dim", [1, 2, 10]) +def test_dataset_add(n1: int, n2: int, in_dim: int, out_dim: int) -> None: + + # Create first dataset + x1 = jnp.ones((n1, in_dim)) + y1 = jnp.ones((n1, out_dim)) + D1 = Dataset(X=x1, y=y1) - assert d.n == n - assert d.in_dim == ind - assert d.out_dim == outd + # Create second dataset + x2 = 2 * jnp.ones((n2, in_dim)) + y2 = 2 * jnp.ones((n2, out_dim)) + D2 = Dataset(X=x2, y=y2) + + # Add datasets + D = D1 + D2 + + # Test shapes + assert D.n == n1 + n2 + assert D.in_dim == in_dim + assert D.out_dim == out_dim + + # Test representation assert ( - d.__repr__() - == f"- Number of observations: {n}\n- Input dimension: {ind}\n- Output" - f" dimension: {outd}" + D.__repr__() + == f"- Number of observations: {n1 + n2}\n- Input dimension: {in_dim}\n- Output" + f" dimension: {out_dim}" ) - # Test combine datasets. - x2 = 2 * jnp.ones((n2, ind)) - y2 = 2 * jnp.ones((n2, outd)) - d2 = Dataset(X=x2, y=y2) + # Ensure dataclass + assert is_dataclass(D) - d_combined = d + d2 - assert d_combined.n == n + n2 - assert d_combined.in_dim == ind - assert d_combined.out_dim == outd - assert (d_combined.y[:n] == 1.0).all() - assert (d_combined.y[n:] == 2.0).all() - assert (d_combined.X[:n] == 1.0).all() - assert (d_combined.X[n:] == 2.0).all() + # Test supervised and unsupervised + assert (Dataset(X=x1, y=y1) + Dataset(X=x2, y=y2)).is_supervised() is True + assert (Dataset(y=y1) + Dataset(y=y2)).is_unsupervised() is True - # Test supervised and unsupervised. - assert d.is_supervised() is True - dunsup = Dataset(y=y) - assert dunsup.is_unsupervised() is True + # Check tree flatten + x = jnp.concatenate((x1, x2)) + y = jnp.concatenate((y1, y2)) + (jtu.tree_leaves(D)[0] == x).all() + (jtu.tree_leaves(D)[1] == y).all() @pytest.mark.parametrize("nx, ny", [(1, 2), (2, 1), (10, 5), (5, 10)]) -@pytest.mark.parametrize("outd", [1, 2, 10]) -@pytest.mark.parametrize("ind", [1, 2, 10]) -def test_dataset_assertions(nx: int, ny: int, outd: int, ind: int) -> None: - x = jnp.ones((nx, ind)) - y = jnp.ones((ny, outd)) - +@pytest.mark.parametrize("out_dim", [1, 2, 10]) +@pytest.mark.parametrize("in_dim", [1, 2, 10]) +def test_dataset_incorrect_lengths(nx: int, ny: int, out_dim: int, in_dim: int) -> None: + + # Create input and output pairs of different lengths + x = jnp.ones((nx, in_dim)) + y = jnp.ones((ny, out_dim)) + + # Ensure error is raised upon dataset creation with pytest.raises(ValueError): Dataset(X=x, y=y) @pytest.mark.parametrize("n", [1, 2, 10]) -@pytest.mark.parametrize("outd", [1, 2, 10]) -@pytest.mark.parametrize("ind", [1, 2, 10]) -def test_2d_inputs(n: int, outd: int, ind: int) -> None: - x = jnp.ones((n, ind)) +@pytest.mark.parametrize("out_dim", [1, 2, 10]) +@pytest.mark.parametrize("in_dim", [1, 2, 10]) +def test_2d_inputs(n: int, out_dim: int, in_dim: int) -> None: + + # Create dataset where output dimension is incorrectly not 2D + x = jnp.ones((n, in_dim)) y = jnp.ones((n,)) + # Ensure error is raised upon dataset creation with pytest.raises(ValueError): Dataset(X=x, y=y) + # Create dataset where input dimension is incorrectly not 2D x = jnp.ones((n,)) - y = jnp.ones((n, outd)) + y = jnp.ones((n, out_dim)) + # Ensure error is raised upon dataset creation with pytest.raises(ValueError): Dataset(X=x, y=y) -def test_y_none() -> None: - x = jnp.ones((10, 1)) - d = Dataset(X=x) - assert d.y is None +@pytest.mark.parametrize("n", [1, 2, 10]) +@pytest.mark.parametrize("in_dim", [1, 2, 10]) +def test_y_none(n: int, in_dim: int) -> None: + + # Create a dataset with no output + x = jnp.ones((n, in_dim)) + D = Dataset(X=x) + + # Ensure is dataclass + assert is_dataclass(D) + + # Ensure output is None + assert D.y is None + + # Check tree flatten + assert jtu.tree_leaves(D) == [x] \ No newline at end of file From aa356e516209dbf517b50630d4b11d2f072df001 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Mon, 3 Apr 2023 15:40:00 +0100 Subject: [PATCH 25/44] Update fit testing. --- gpjax/fit.py | 37 ++++---- gpjax/scan.py | 5 -- tests/test_abstractions.py | 177 ------------------------------------- tests/test_fit.py | 148 +++++++++++++++++++++++++++++-- 4 files changed, 159 insertions(+), 208 deletions(-) delete mode 100644 tests/test_abstractions.py diff --git a/gpjax/fit.py b/gpjax/fit.py index 757f6b463..b4dddda2f 100644 --- a/gpjax/fit.py +++ b/gpjax/fit.py @@ -24,12 +24,11 @@ from jaxtyping import Array, Float from typing import Any +from .base import Module from .dataset import Dataset from .objectives import AbstractObjective from .scan import vscan -Module = Any - def fit( *, @@ -50,17 +49,17 @@ def fit( >>> import jax.numpy as jnp >>> import jax.random as jr >>> import optax as ox - >>> import jaxutils as ju + >>> import gpjax as gpx >>> >>> # (1) Create a dataset: >>> X = jnp.linspace(0.0, 10.0, 100)[:, None] >>> y = 2.0 * X + 1.0 + 10 * jr.normal(jr.PRNGKey(0), X.shape) - >>> D = ju.Dataset(X, y) + >>> D = gpx.Dataset(X, y) >>> >>> # (2) Define your model: - >>> class LinearModel(ju.Module): - ... weight: float = ju.param(ju.Identity) - ... bias: float = ju.param(ju.Identity) + >>> class LinearModel(gpx.Module): + ... weight: float = gpx.param_field() + ... bias: float = gpx.param_field() ... ... def __call__(self, x): ... return self.weight * x + self.bias @@ -68,14 +67,14 @@ def fit( >>> model = LinearModel(weight=1.0, bias=1.0) >>> >>> # (3) Define your loss function: - >>> class MeanSqaureError(ju.Objective): - ... def evaluate(self, model: LinearModel, train_data: ju.Dataset) -> float: + >>> class MeanSqaureError(gpx.AbstractObjective): + ... def evaluate(self, model: LinearModel, train_data: gpx.Dataset) -> float: ... return jnp.mean((train_data.y - model(train_data.X)) ** 2) ... >>> loss = MeanSqaureError() >>> >>> # (4) Train! - >>> trained_model, history = ju.fit( + >>> trained_model, history = gpx.fit( ... model=model, objective=loss, train_data=D, optim=ox.sgd(0.001), num_iters=1000 ... ) @@ -96,15 +95,15 @@ def fit( """ # Check inputs. - # _check_model(model) - # _check_objective(objective) - # _check_train_data(train_data) - # _check_optim(optim) - # _check_num_iters(num_iters) - # _check_batch_size(batch_size) - # _check_prng_key(key) - # _check_log_rate(log_rate) - # _check_verbose(verbose) + _check_model(model) + _check_objective(objective) + _check_train_data(train_data) + _check_optim(optim) + _check_num_iters(num_iters) + _check_batch_size(batch_size) + _check_prng_key(key) + _check_log_rate(log_rate) + _check_verbose(verbose) # Unconstrained space loss function with stop-gradient rule for non-trainable params. def loss(model: Module, batch: Dataset) -> Float[Array, "1"]: diff --git a/gpjax/scan.py b/gpjax/scan.py index 211964850..e5e80f4b0 100644 --- a/gpjax/scan.py +++ b/gpjax/scan.py @@ -88,11 +88,6 @@ def vscan( Returns: Tuple[Carry, List[Y]]: A tuple of the final carry and the outputs. """ - - # TODO: Scope out lower level API for jax.lax.scan, to avoid the need for finding - # the length of the inputs / check inputs. - # TODO: Scope out lower level API for tqdm, for more control over the progress bar. - # Need to check this. _xs_flat = jtu.tree_leaves(xs) _length = length if length is not None else len(_xs_flat[0]) _iter_nums = jnp.arange(_length) diff --git a/tests/test_abstractions.py b/tests/test_abstractions.py deleted file mode 100644 index 7105a54de..000000000 --- a/tests/test_abstractions.py +++ /dev/null @@ -1,177 +0,0 @@ -# # Copyright 2022 The GPJax Contributors. All Rights Reserved. -# # -# # Licensed under the Apache License, Version 2.0 (the "License"); -# # you may not use this file except in compliance with the License. -# # You may obtain a copy of the License at -# # -# # http://www.apache.org/licenses/LICENSE-2.0 -# # -# # Unless required by applicable law or agreed to in writing, software -# # distributed under the License is distributed on an "AS IS" BASIS, -# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# # See the License for the specific language governing permissions and -# # limitations under the License. -# # ============================================================================== - -# import jax.numpy as jnp -# import jax.random as jr -# import optax -# import pytest -# from jax.config import config - -# import gpjax as gpx -# from gpjax import RBF, Dataset, Gaussian, Prior, initialise -# from gpjax.abstractions import InferenceState, fit, fit_batches, fit_natgrads, get_batch - -# # Enable Float64 for more stable matrix inversions. -# config.update("jax_enable_x64", True) - - -# @pytest.mark.parametrize("num_iters", [1, 5]) -# @pytest.mark.parametrize("n", [1, 20]) -# @pytest.mark.parametrize("verbose", [True, False]) -# def test_fit(num_iters, n, verbose): -# key = jr.PRNGKey(123) -# x = jnp.sort(jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(n, 1)), axis=0) -# y = jnp.sin(x) + jr.normal(key=key, shape=x.shape) * 0.1 -# D = Dataset(X=x, y=y) -# p = Prior(kernel=RBF()) * Gaussian(num_datapoints=n) -# parameter_state = initialise(p, key) -# mll = p.marginal_log_likelihood(D, negative=True) -# pre_mll_val = mll(parameter_state.params) -# optimiser = optax.adam(learning_rate=0.1) -# inference_state = fit(mll, parameter_state, optimiser, num_iters, verbose=verbose) -# optimised_params, history = inference_state.unpack() -# assert isinstance(inference_state, InferenceState) -# assert isinstance(optimised_params, dict) -# assert mll(optimised_params) < pre_mll_val -# assert isinstance(history, jnp.ndarray) -# assert history.shape[0] == num_iters - - -# def test_stop_grads(): -# params = {"x": jnp.array(3.0), "y": jnp.array(4.0)} -# trainables = {"x": True, "y": False} -# bijectors = build_bijectors(params) -# loss_fn = lambda params: params["x"] ** 2 + params["y"] ** 2 -# optimiser = optax.adam(learning_rate=0.1) -# parameter_state = ParameterState( -# params=params, trainables=trainables, bijectors=bijectors -# ) -# inference_state = fit(loss_fn, parameter_state, optimiser, num_iters=1) -# learned_params = inference_state.params -# assert isinstance(inference_state, InferenceState) -# assert learned_params["y"] == params["y"] -# assert learned_params["x"] != params["x"] - - -# @pytest.mark.parametrize("num_iters", [1, 5]) -# @pytest.mark.parametrize("nb", [1, 20, 50]) -# @pytest.mark.parametrize("ndata", [50]) -# @pytest.mark.parametrize("verbose", [True, False]) -# def test_batch_fitting(num_iters, nb, ndata, verbose): -# key = jr.PRNGKey(123) -# x = jnp.sort(jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(ndata, 1)), axis=0) -# y = jnp.sin(x) + jr.normal(key=key, shape=x.shape) * 0.1 -# D = Dataset(X=x, y=y) -# prior = Prior(kernel=RBF()) -# likelihood = Gaussian(num_datapoints=ndata) -# p = prior * likelihood -# z = jnp.linspace(-2.0, 2.0, 10).reshape(-1, 1) - -# q = gpx.VariationalGaussian(prior=prior, inducing_inputs=z) - -# svgp = gpx.StochasticVI(posterior=p, variational_family=q) -# parameter_state = initialise(svgp, key) -# objective = svgp.elbo(D, negative=True) - -# pre_mll_val = objective(parameter_state.params, D) - -# D = Dataset(X=x, y=y) - -# optimiser = optax.adam(learning_rate=0.1) -# key = jr.PRNGKey(42) -# inference_state = fit_batches( -# objective, parameter_state, D, optimiser, key, nb, num_iters, verbose=verbose -# ) -# optimised_params, history = inference_state.unpack() -# assert isinstance(inference_state, InferenceState) -# assert isinstance(optimised_params, dict) -# assert objective(optimised_params, D) < pre_mll_val -# assert isinstance(history, jnp.ndarray) -# assert history.shape[0] == num_iters - - -# @pytest.mark.parametrize("num_iters", [1, 5]) -# @pytest.mark.parametrize("nb", [1, 20, 50]) -# @pytest.mark.parametrize("ndata", [50]) -# @pytest.mark.parametrize("verbose", [True, False]) -# def test_natural_gradients(ndata, nb, num_iters, verbose): -# key = jr.PRNGKey(123) -# x = jnp.sort(jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(ndata, 1)), axis=0) -# y = jnp.sin(x) + jr.normal(key=key, shape=x.shape) * 0.1 -# D = Dataset(X=x, y=y) -# prior = Prior(kernel=RBF()) -# likelihood = Gaussian(num_datapoints=ndata) -# p = prior * likelihood -# z = jnp.linspace(-2.0, 2.0, 10).reshape(-1, 1) - -# q = gpx.NaturalVariationalGaussian(prior=prior, inducing_inputs=z) - -# svgp = gpx.StochasticVI(posterior=p, variational_family=q) -# training_state = initialise(svgp, key) - -# D = Dataset(X=x, y=y) - -# hyper_optimiser = optax.adam(learning_rate=0.1) -# moment_optimiser = optax.sgd(learning_rate=1.0) - -# objective = svgp.elbo(D, negative=True) -# parameter_state = initialise(svgp, key) -# pre_mll_val = objective(parameter_state.params, D) - -# key = jr.PRNGKey(42) -# inference_state = fit_natgrads( -# svgp, -# training_state, -# D, -# moment_optimiser, -# hyper_optimiser, -# key, -# nb, -# num_iters, -# verbose=verbose, -# ) -# optimised_params, history = inference_state.unpack() -# assert isinstance(inference_state, InferenceState) -# assert isinstance(optimised_params, dict) -# assert objective(optimised_params, D) < pre_mll_val -# assert isinstance(history, jnp.ndarray) -# assert history.shape[0] == num_iters - - -# @pytest.mark.parametrize("batch_size", [1, 2, 50]) -# @pytest.mark.parametrize("ndim", [1, 2, 3]) -# @pytest.mark.parametrize("ndata", [50]) -# @pytest.mark.parametrize("key", [jr.PRNGKey(123)]) -# def test_get_batch(ndata, ndim, batch_size, key): -# x = jnp.sort( -# jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(ndata, ndim)), axis=0 -# ) -# y = jnp.sin(x) + jr.normal(key=key, shape=x.shape) * 0.1 -# D = Dataset(X=x, y=y) - -# B = get_batch(D, batch_size, key) - -# assert B.n == batch_size -# assert B.X.shape[1:] == x.shape[1:] -# assert B.y.shape[1:] == y.shape[1:] - -# # test no caching of batches: -# key, subkey = jr.split(key) -# Bnew = get_batch(D, batch_size, subkey) -# assert Bnew.n == batch_size -# assert Bnew.X.shape[1:] == x.shape[1:] -# assert Bnew.y.shape[1:] == y.shape[1:] -# assert (Bnew.X != B.X).all() -# assert (Bnew.y != B.y).all() diff --git a/tests/test_fit.py b/tests/test_fit.py index 6cab1c8d0..6dcd0477e 100644 --- a/tests/test_fit.py +++ b/tests/test_fit.py @@ -20,32 +20,41 @@ import tensorflow_probability.substrates.jax.bijectors as tfb import tensorflow_probability.substrates.jax.distributions as tfd +import pytest + from dataclasses import dataclass from gpjax.base import param_field, Module from gpjax.objectives import AbstractObjective from gpjax.dataset import Dataset -from gpjax.fit import fit +from gpjax.fit import fit, get_batch +from gpjax.gps import Prior, ConjugatePosterior +from gpjax.likelihoods import Gaussian +from gpjax.kernels import RBF +from gpjax.mean_functions import Constant +from gpjax.objectives import ConjugateMLL, ELBO +from gpjax.variational_families import VariationalGaussian -def test_simple_linear_model(): - # (1) Create a dataset: +def test_simple_linear_model() -> None: + + # Create dataset: X = jnp.linspace(0.0, 10.0, 100).reshape(-1, 1) y = 2.0 * X + 1.0 + 10 * jr.normal(jr.PRNGKey(0), X.shape).reshape(-1, 1) D = Dataset(X, y) - # (2) Define your model: + # Define linear model: @dataclass class LinearModel(Module): weight: float = param_field(bijector=tfb.Identity()) - bias: float = param_field(bijector=tfb.Identity()) + bias: float = param_field(bijector=tfb.Identity(), trainable=False) def __call__(self, x): return self.weight * x + self.bias model = LinearModel(weight=1.0, bias=1.0) - # (3) Define your loss function: + # Define loss function: @dataclass class MeanSqaureError(AbstractObjective): def __call__(self, model: LinearModel, train_data: Dataset) -> float: @@ -53,11 +62,136 @@ def __call__(self, model: LinearModel, train_data: Dataset) -> float: loss = MeanSqaureError() - # (4) Train! + # Train! trained_model, hist = fit( model=model, objective=loss, train_data=D, optim=ox.sgd(0.001), num_iters=100 ) + # Ensure we return a history of the correct length assert len(hist) == 100 + + # Ensure we return a model of the same class assert isinstance(trained_model, LinearModel) + + # Test reduction in loss: assert loss(trained_model, D) < loss(model, D) + + # Test stop_gradient on bias: + assert trained_model.bias == 1.0 + + +@pytest.mark.parametrize("num_iters", [1, 5]) +@pytest.mark.parametrize("n_data", [1, 20]) +@pytest.mark.parametrize("verbose", [True, False]) +def test_gaussian_process_regression(num_iters, n_data: int, verbose: bool) -> None: + + # Create dataset: + key = jr.PRNGKey(123) + x = jnp.sort(jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(n_data, 1)), axis=0) + y = jnp.sin(x) + jr.normal(key=key, shape=x.shape) * 0.1 + D = Dataset(X=x, y=y) + + # Define GP model: + prior = Prior(kernel=RBF(), mean_function=Constant()) + likelihood = Gaussian(num_datapoints=n_data) + posterior = prior * likelihood + + # Define loss function: + mll = ConjugateMLL(negative=True) + + # Train! + trained_model, history = fit( + model=posterior, + objective=mll, + train_data=D, + optim=ox.adam(0.1), + num_iters=num_iters, + verbose=verbose, + ) + + # Ensure the trained model is a Gaussian process posterior + assert isinstance(trained_model, ConjugatePosterior) + + # Ensure we return a history of the correct length + assert len(history) == num_iters + + # Ensure we reduce the loss + assert mll(trained_model, D) < mll(posterior, D) + + +@pytest.mark.parametrize("num_iters", [1, 5]) +@pytest.mark.parametrize("batch_size", [1, 20, 50]) +@pytest.mark.parametrize("n_data", [50]) +@pytest.mark.parametrize("verbose", [True, False]) +def test_batch_fitting(num_iters: int, batch_size: int, n_data: int, verbose: bool) -> None: + + # Create dataset: + key = jr.PRNGKey(123) + x = jnp.sort(jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(n_data, 1)), axis=0) + y = jnp.sin(x) + jr.normal(key=key, shape=x.shape) * 0.1 + D = Dataset(X=x, y=y) + + # Define GP model: + prior = Prior(kernel=RBF(), mean_function=Constant()) + likelihood = Gaussian(num_datapoints=n_data) + posterior = prior * likelihood + + # Define variational family: + z = jnp.linspace(-2.0, 2.0, 10).reshape(-1, 1) + q = VariationalGaussian(posterior=posterior, inducing_inputs=z) + + # Define loss function: + elbo = ELBO(negative=True) + + # Train! + trained_model, history = fit( + model=q, + objective=elbo, + train_data=D, + optim=ox.adam(0.1), + num_iters=num_iters, + batch_size=batch_size, + verbose=verbose, + ) + + # Ensure the trained model is a Gaussian process posterior + assert isinstance(trained_model, VariationalGaussian) + + # Ensure we return a history of the correct length + assert len(history) == num_iters + + # Ensure we reduce the loss + assert elbo(trained_model, D) < elbo(q, D) + + + +@pytest.mark.parametrize("n_data", [50]) +@pytest.mark.parametrize("n_dim", [1, 2, 3]) +@pytest.mark.parametrize("batch_size", [1, 2, 50]) +def test_get_batch(n_data: int, n_dim: int, batch_size: int): + + key = jr.PRNGKey(123) + + # Create dataset: + x = jnp.sort( + jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(n_data, n_dim)), axis=0 + ) + y = jnp.sin(x) + jr.normal(key=key, shape=x.shape) * 0.1 + D = Dataset(X=x, y=y) + + # Sample out a batch: + B = get_batch(D, batch_size, key) + + # Check batch is correct size and shape dimensions: + assert B.n == batch_size + assert B.X.shape[1:] == x.shape[1:] + assert B.y.shape[1:] == y.shape[1:] + + # Ensure no caching of batches: + key, subkey = jr.split(key) + New = get_batch(D, batch_size, subkey) + assert New.n == batch_size + assert New.X.shape[1:] == x.shape[1:] + assert New.y.shape[1:] == y.shape[1:] + assert (New.X != B.X).all() + assert (New.y != B.y).all() From 587769824eb383e96f2fae79a11f7d3907b3c186 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Mon, 3 Apr 2023 21:04:10 +0100 Subject: [PATCH 26/44] Refactor docs --- examples/barycentres.pct.py | 115 ++++++++++++++++++++-------- examples/classification.pct.py | 4 +- examples/collapsed_vi.pct.py | 2 +- examples/graph_kernels.pct.py | 2 +- examples/haiku.pct.py | 9 +-- examples/kernels.pct.py | 8 +- examples/natgrads.pct.py | 2 +- examples/regression.pct.py | 132 ++++++++++++++++++-------------- examples/tfp_integration.pct.py | 2 +- examples/uncollapsed_vi.pct.py | 5 +- examples/yacht.pct.py | 5 +- gpjax/__init__.py | 6 ++ gpjax/likelihoods.py | 8 +- 13 files changed, 186 insertions(+), 114 deletions(-) diff --git a/examples/barycentres.pct.py b/examples/barycentres.pct.py index 9a48d8ac0..280fd61aa 100644 --- a/examples/barycentres.pct.py +++ b/examples/barycentres.pct.py @@ -17,22 +17,28 @@ # %% [markdown] # # Gaussian Processes Barycentres # -# In this notebook we'll give an implementation of . In this work, the existence of a Wasserstein barycentre between a collection of Gaussian processes is proven. When faced with trying to _average_ a set of probability distributions, the Wasserstein barycentre is an attractive choice as it enables uncertainty amongst the individual distributions to be incorporated into the averaged distribution. When compared to a naive _mean of means_ and _mean of variances_ approach to computing the average probability distributions, it can be seen that Wasserstein barycentres offer significantly more favourable uncertainty estimation. +# In this notebook we'll give an implementation of +# . In this work, the existence of a +# Wasserstein barycentre between a collection of Gaussian processes is proven. When +# faced with trying to _average_ a set of probability distributions, the Wasserstein +# barycentre is an attractive choice as it enables uncertainty amongst the individual +# distributions to be incorporated into the averaged distribution. When compared to a +# naive _mean of means_ and _mean of variances_ approach to computing the average +# probability distributions, it can be seen that Wasserstein barycentres offer +# significantly more favourable uncertainty estimation. # # %% import typing as tp -import distrax as dx import jax import jax.numpy as jnp import jax.random as jr import jax.scipy.linalg as jsl import matplotlib.pyplot as plt import optax as ox +import tensorflow_probability.substrates.jax.distributions as tfd from jax.config import config -from jaxutils import Dataset -import gpjax.kernels as jk import gpjax as gpx @@ -45,29 +51,53 @@ # # ### Wasserstein distance # -# The 2-Wasserstein distance metric between two probability measures $\mu$ and $\nu$ quantifies the minimal cost required to transport the unit mass from $\mu$ to $\nu$, or vice-versa. Typically, computing this metric requires solving a linear program. However, when $\mu$ and $\nu$ both belong to the family of multivariate Gaussian distributions, the solution is analytically given by +# The 2-Wasserstein distance metric between two probability measures $\mu$ and $\nu$ +# quantifies the minimal cost required to transport the unit mass from $\mu$ to $\nu$, +# or vice-versa. Typically, computing this metric requires solving a linear program. +# However, when $\mu$ and $\nu$ both belong to the family of multivariate Gaussian +# distributions, the solution is analytically given by # $$W_2^2(\mu, \nu) = \lVert m_1- m_2 \rVert^2_2 + \operatorname{Tr}(S_1 + S_2 - 2(S_1^{1/2}S_2S_1^{1/2})^{1/2}),$$ # where $\mu \sim \mathcal{N}(m_1, S_1)$ and $\nu\sim\mathcal{N}(m_2, S_2)$. # # ### Wasserstein barycentre # -# For a collection of $T$ measures $\lbrace\mu_i\rbrace_{t=1}^T \in \mathcal{P}_2(\theta)$, the Wasserstein barycentre $\bar{\mu}$ is the measure that minimises the average Wasserstein distance to all other measures in the set. More formally, the Wasserstein barycentre is the Fréchet mean on a Wasserstein space that we can write as +# For a collection of $T$ measures +# $\lbrace\mu_i\rbrace_{t=1}^T \in \mathcal{P}_2(\theta)$, the Wasserstein barycentre +# $\bar{\mu}$ is the measure that minimises the average Wasserstein distance to all +# other measures in the set. More formally, the Wasserstein barycentre is the Fréchet +# mean on a Wasserstein space that we can write as # $$\bar{\mu} = \operatorname{argmin}_{\mu\in\mathcal{P}_2(\theta)}\sum_{t=1}^T \alpha_t W_2^2(\mu, \mu_t),$$ # where $\alpha\in\bbR^T$ is a weight vector that sums to 1. # -# As with the Wasserstein distance, identifying the Wasserstein barycentre $\bar{\mu}$ is often an computationally demanding optimisation problem. However, when all the measures admit a multivariate Gaussian density, the barycentre $\bar{\mu} = \mathcal{N}(\bar{m}, \bar{S})$ has analytical solutions +# As with the Wasserstein distance, identifying the Wasserstein barycentre $\bar{\mu}$ +# is often an computationally demanding optimisation problem. However, when all the +# measures admit a multivariate Gaussian density, the barycentre +# $\bar{\mu} = \mathcal{N}(\bar{m}, \bar{S})$ has analytical solutions # $$\bar{m} = \sum_{t=1}^T \alpha_t m_t\,, \quad \bar{S}=\sum_{t=1}^T\alpha_t (\bar{S}^{1/2}S_t\bar{S}^{1/2})^{1/2}\,. \qquad (\star)$$ # Identifying $\bar{S}$ is achieved through a fixed-point iterative update. # # ## Barycentre of Gaussian processes # -# It was shown in that the barycentre $\bar{f}$ of a collection of Gaussian processes $\lbrace f_i\rbrace_{i=1}^T$ such that $f_i \sim \mathcal{GP}(m_i, K_i)$ can be found using the same solutions as in $(\star)$. For a full theoretical understanding, we recommend reading the original paper. However, the central argument to this result is that one can first show that the barycentre GP $\bar{f}\sim\mathcal{GP}(\bar{m}, \bar{S})$ is non-degenerate for any finite set of GPs $\lbrace f_t\rbrace_{t=1}^T$ i.e., $T<\infty$. With this established, one can show that for a $n$-dimensional finite Gaussian distribution $f_{i,n}$, the Wasserstein metric between any two Gaussian distributions $f_{i, n}, f_{j, n}$ converges to the Wasserstein metric between GPs as $n\to\infty$. +# It was shown in that the +# barycentre $\bar{f}$ of a collection of Gaussian processes +# $\lbrace f_i\rbrace_{i=1}^T$ such that $f_i \sim \mathcal{GP}(m_i, K_i)$ can be +# found using the same solutions as in $(\star)$. For a full theoretical understanding, +# we recommend reading the original paper. However, the central argument to this result +# is that one can first show that the barycentre GP +# $\bar{f}\sim\mathcal{GP}(\bar{m}, \bar{S})$ is non-degenerate for any finite set of +# GPs $\lbrace f_t\rbrace_{t=1}^T$ i.e., $T<\infty$. With this established, one can +# show that for a $n$-dimensional finite Gaussian distribution $f_{i,n}$, the +# Wasserstein metric between any two Gaussian distributions $f_{i, n}, f_{j, n}$ +# converges to the Wasserstein metric between GPs as $n\to\infty$. # # In this notebook, we will demonstrate how this can be achieved in GPJax. # # ## Dataset # -# We'll simulate five datasets and develop a Gaussian process posterior before identifying the Gaussian process barycentre at a set of test points. Each dataset will be a sine function with a different vertical shift, periodicity, and quantity of noise. +# We'll simulate five datasets and develop a Gaussian process posterior before +# identifying the Gaussian process barycentre at a set of test points. Each dataset +# will be a sine function with a different vertical shift, periodicity, and quantity +# of noise. # %% n = 100 @@ -96,30 +126,32 @@ # %% [markdown] # ## Learning a posterior distribution # -# We'll now independently learn Gaussian process posterior distributions for each dataset. We won't spend any time here discussing how GP hyperparameters are optimised. For advice on achieving this, see the [Regression notebook](https://gpjax.readthedocs.io/en/latest/nbs/regression.html) for advice on optimisation and the [Kernels notebook](https://gpjax.readthedocs.io/en/latest/nbs/kernels.html) for advice on selecting an appropriate kernel. +# We'll now independently learn Gaussian process posterior distributions for each +# dataset. We won't spend any time here discussing how GP hyperparameters are +# optimised. For advice on achieving this, see the +# [Regression notebook](https://gpjax.readthedocs.io/en/latest/nbs/regression.html) +# for advice on optimisation and the +# [Kernels notebook](https://gpjax.readthedocs.io/en/latest/nbs/kernels.html) for +# advice on selecting an appropriate kernel. # %% -def fit_gp(x: jax.Array, y: jax.Array) -> dx.MultivariateNormalTri: +def fit_gp(x: jax.Array, y: jax.Array) -> tfd.MultivariateNormalFullCovariance: if y.ndim == 1: y = y.reshape(-1, 1) - D = Dataset(X=x, y=y) + D = gpx.Dataset(X=x, y=y) likelihood = gpx.Gaussian(num_datapoints=n) - posterior = gpx.Prior(kernel=jk.RBF()) * likelihood - - parameter_state = gpx.initialise(posterior, key) - negative_mll = jax.jit(posterior.marginal_log_likelihood(D, negative=True)) - optimiser = ox.adam(learning_rate=0.01) - - inference_state = gpx.fit( - objective=negative_mll, - parameter_state=parameter_state, - optax_optim=optimiser, - num_iters=1000, + posterior = gpx.Prior(mean_function=gpx.Constant(), kernel=gpx.RBF()) * likelihood + + opt_posterior, _ = gpx.fit( + model=posterior, + objective=jax.jit(gpx.ConjugateMLL(negative=True)), + train_data=D, + optim=ox.adamw(learning_rate=0.01), + num_iters=500, ) - - learned_params, training_history = inference_state.unpack() - return likelihood(learned_params, posterior(learned_params, D)(xtest)) + latent_dist = opt_posterior.predict(xtest, train_data=D) + return opt_posterior.likelihood(latent_dist) posterior_preds = [fit_gp(x, i) for i in ys] @@ -127,7 +159,12 @@ def fit_gp(x: jax.Array, y: jax.Array) -> dx.MultivariateNormalTri: # %% [markdown] # ## Computing the barycentre # -# In GPJax, the predictive distribution of a GP is given by a [Distrax](https://github.com/deepmind/distrax) distribution, making it straightforward to extract the mean vector and covariance matrix of each GP for learning a barycentre. We implement the fixed point scheme given in (3) in the following cell by utilising Jax's `vmap` operator to speed up large matrix operations using broadcasting in `tensordot`. +# In GPJax, the predictive distribution of a GP is given by a +# [Distrax](https://github.com/deepmind/distrax) distribution, making it +# straightforward to extract the mean vector and covariance matrix of each GP for +# learning a barycentre. We implement the fixed point scheme given in (3) in the +# following cell by utilising Jax's `vmap` operator to speed up large matrix operations +# using broadcasting in `tensordot`. # %% def sqrtm(A: jax.Array): @@ -135,7 +172,7 @@ def sqrtm(A: jax.Array): def wasserstein_barycentres( - distributions: tp.List[dx.MultivariateNormalTri], weights: jax.Array + distributions: tp.List[tfd.MultivariateNormalFullCovariance], weights: jax.Array ): covariances = [d.covariance() for d in distributions] cov_stack = jnp.stack(covariances) @@ -152,7 +189,12 @@ def step(covariance_candidate: jax.Array, idx: None): # %% [markdown] -# With a function defined for learning a barycentre, we'll now compute it using the `lax.scan` operator that drastically speeds up for loops in Jax (see the [Jax documentation](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html)). The iterative update will be executed 100 times, with convergence measured by the difference between the previous and current iteration that we can confirm by inspecting the `sequence` array in the following cell. +# With a function defined for learning a barycentre, we'll now compute it using the +# `lax.scan` operator that drastically speeds up for loops in Jax (see the +# [Jax documentation](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html)). +# The iterative update will be executed 100 times, with convergence measured by the +# difference between the previous and current iteration that we can confirm by +# inspecting the `sequence` array in the following cell. # %% weights = jnp.ones((n_datasets,)) / n_datasets @@ -168,16 +210,18 @@ def step(covariance_candidate: jax.Array, idx: None): ) L = jnp.linalg.cholesky(barycentre_covariance) -barycentre_process = dx.MultivariateNormalTri(barycentre_mean, L) +barycentre_process = tfd.MultivariateNormalTriL(barycentre_mean, L) # %% [markdown] # ## Plotting the result # -# With a barycentre learned, we can visualise the result. We can see that the result looks reasonable as it follows the sinusoidal curve of all the inferred GPs, and the uncertainty bands are sensible. +# With a barycentre learned, we can visualise the result. We can see that the result +# looks reasonable as it follows the sinusoidal curve of all the inferred GPs, and the +# uncertainty bands are sensible. # %% def plot( - dist: dx.MultivariateNormalTri, + dist: tfd.MultivariateNormalTriL, ax, color: str = "tab:blue", label: str = None, @@ -206,7 +250,12 @@ def plot( # %% [markdown] # ## Displacement interpolation # -# In the above example, we assigned uniform weights to each of the posteriors within the barycentre. In practice, we may have prior knowledge of which posterior is most likely to be the correct one. Regardless of the weights chosen, the barycentre remains a Gaussian process. We can interpolate between a pair of posterior distributions $\mu_1$ and $\mu_2$ to visualise the corresponding barycentre $\bar{\mu}$. +# In the above example, we assigned uniform weights to each of the posteriors within +# the barycentre. In practice, we may have prior knowledge of which posterior is most +# likely to be the correct one. Regardless of the weights chosen, the barycentre +# remains a Gaussian process. We can interpolate between a pair of posterior +# distributions $\mu_1$ and $\mu_2$ to visualise the corresponding barycentre +# $\bar{\mu}$. # # ![](figs/barycentre_gp.gif) diff --git a/examples/classification.pct.py b/examples/classification.pct.py index 4d3f9f98b..a656f6a90 100644 --- a/examples/classification.pct.py +++ b/examples/classification.pct.py @@ -22,6 +22,7 @@ # %% import blackjax import distrax as dx +import jax import jax.numpy as jnp import jax.random as jr import jax.scipy as jsp @@ -30,10 +31,9 @@ from jax.config import config from jaxtyping import Array, Float from jaxutils import Dataset -import gpjax.kernels as jk -import jax import gpjax as gpx +import gpjax.kernels as jk # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) diff --git a/examples/collapsed_vi.pct.py b/examples/collapsed_vi.pct.py index 9a22fc513..a6cf8469f 100644 --- a/examples/collapsed_vi.pct.py +++ b/examples/collapsed_vi.pct.py @@ -27,9 +27,9 @@ from jax import jit from jax.config import config from jaxutils import Dataset -import gpjax.kernels as jk import gpjax as gpx +import gpjax.kernels as jk # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) diff --git a/examples/graph_kernels.pct.py b/examples/graph_kernels.pct.py index e0ee1bc63..f87059f65 100644 --- a/examples/graph_kernels.pct.py +++ b/examples/graph_kernels.pct.py @@ -30,9 +30,9 @@ from jax import jit from jax.config import config from jaxutils import Dataset -import gpjax.kernels as jk import gpjax as gpx +import gpjax.kernels as jk # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) diff --git a/examples/haiku.pct.py b/examples/haiku.pct.py index c6192cf25..9355a50a7 100644 --- a/examples/haiku.pct.py +++ b/examples/haiku.pct.py @@ -21,6 +21,7 @@ # %% import typing as tp +from typing import Dict import haiku as hk import jax @@ -29,14 +30,12 @@ import matplotlib.pyplot as plt import optax as ox from jax.config import config -from scipy.signal import sawtooth -from jaxtyping import Float, Array -from typing import Dict +from jaxtyping import Array, Float from jaxutils import Dataset -import gpjax.kernels as jk - +from scipy.signal import sawtooth import gpjax as gpx +import gpjax.kernels as jk from gpjax.kernels import DenseKernelComputation from gpjax.kernels.base import AbstractKernel from gpjax.kernels.computations import AbstractKernelComputation diff --git a/examples/kernels.pct.py b/examples/kernels.pct.py index e42432e62..009e05d06 100644 --- a/examples/kernels.pct.py +++ b/examples/kernels.pct.py @@ -23,6 +23,8 @@ # In this guide, we introduce the kernels available in GPJax and demonstrate how to create custom ones. +from typing import Dict + # %% import distrax as dx import jax.numpy as jnp @@ -31,12 +33,11 @@ from jax import jit from jax.config import config from jaxtyping import Array, Float -from optax import adam -from typing import Dict from jaxutils import Dataset -import gpjax.kernels as jk +from optax import adam import gpjax as gpx +import gpjax.kernels as jk # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) @@ -254,6 +255,7 @@ def _initialise_params(self, key: jr.KeyArray) -> Dict: # %% from jax.nn import softplus + from gpjax.config import add_parameter bij_fn = lambda x: softplus(x + jnp.array(4.0)) diff --git a/examples/natgrads.pct.py b/examples/natgrads.pct.py index 54e03451d..7dc2c8a0c 100644 --- a/examples/natgrads.pct.py +++ b/examples/natgrads.pct.py @@ -27,9 +27,9 @@ import optax as ox from jax.config import config from jaxutils import Dataset -import gpjax.kernels as jk import gpjax as gpx +import gpjax.kernels as jk # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) diff --git a/examples/regression.pct.py b/examples/regression.pct.py index 9882bd003..44b2e3c44 100644 --- a/examples/regression.pct.py +++ b/examples/regression.pct.py @@ -28,19 +28,19 @@ from jax import jit from jax.config import config from jaxutils import Dataset -import gpjax.kernels as jk import gpjax as gpx # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) -pp = PrettyPrinter(indent=4) key = jr.PRNGKey(123) # %% [markdown] # ## Dataset # -# With the necessary modules imported, we simulate a dataset $\mathcal{D} = (\boldsymbol{x}, \boldsymbol{y}) = \{(x_i, y_i)\}_{i=1}^{100}$ with inputs $\boldsymbol{x}$ sampled uniformly on $(-3., 3)$ and corresponding independent noisy outputs +# With the necessary modules imported, we simulate a dataset +# $\mathcal{D} = (\boldsymbol{x}, \boldsymbol{y}) = \{(x_i, y_i)\}_{i=1}^{100}$ with inputs $\boldsymbol{x}$ +# sampled uniformly on $(-3., 3)$ and corresponding independent noisy outputs # # $$\boldsymbol{y} \sim \mathcal{N} \left(\sin(4\boldsymbol{x}) + \cos(2 \boldsymbol{x}), \textbf{I} * 0.3^2 \right).$$ # @@ -62,7 +62,8 @@ ytest = f(xtest) # %% [markdown] -# To better understand what we have simulated, we plot both the underlying latent function and the observed data that is subject to Gaussian noise. +# To better understand what we have simulated, we plot both the underlying latent +# function and the observed data that is subject to Gaussian noise. # %% fig, ax = plt.subplots(figsize=(10, 5)) @@ -71,30 +72,46 @@ ax.legend(loc="best") # %% [markdown] -# Our aim in this tutorial will be to reconstruct the latent function from our noisy observations $\mathcal{D}$ via Gaussian process regression. We begin by defining a Gaussian process prior in the next section. +# Our aim in this tutorial will be to reconstruct the latent function from our noisy +# observations $\mathcal{D}$ via Gaussian process regression. We begin by defining a +# Gaussian process prior in the next section. # %% [markdown] # ## Defining the prior # -# A zero-mean Gaussian process (GP) places a prior distribution over real-valued functions $f(\cdot)$ where $f(\boldsymbol{x}) \sim \mathcal{N}(0, \mathbf{K}_{\boldsymbol{x}\boldsymbol{x}})$ for any finite collection of inputs $\boldsymbol{x}$. -# Here $\mathbf{K}_{\boldsymbol{x}\boldsymbol{x}}$ is the Gram matrix generated by a user-specified symmetric, non-negative definite kernel function $k(\cdot, \cdot')$ with $[\mathbf{K}_{\boldsymbol{x}\boldsymbol{x}}]_{i, j} = k(x_i, x_j)$. -# The choice of kernel function is critical as, among other things, it governs the smoothness of the outputs that our GP can generate. +# A zero-mean Gaussian process (GP) places a prior distribution over real-valued +# functions $f(\cdot)$ where +# $f(\boldsymbol{x}) \sim \mathcal{N}(0, \mathbf{K}_{\boldsymbol{x}\boldsymbol{x}})$ +# for any finite collection of inputs $\boldsymbol{x}$. +# +# Here $\mathbf{K}_{\boldsymbol{x}\boldsymbol{x}}$ is the Gram matrix generated by a +# user-specified symmetric, non-negative definite kernel function $k(\cdot, \cdot')$ +# with $[\mathbf{K}_{\boldsymbol{x}\boldsymbol{x}}]_{i, j} = k(x_i, x_j)$. +# The choice of kernel function is critical as, among other things, it governs the +# smoothness of the outputs that our GP can generate. +# # For simplicity, we consider a radial basis function (RBF) kernel: # $$k(x, x') = \sigma^2 \exp\left(-\frac{\lVert x - x' \rVert_2^2}{2 \ell^2}\right).$$ # -# On paper a GP is written as $f(\cdot) \sim \mathcal{GP}(\textbf{0}, k(\cdot, \cdot'))$, we can reciprocate this process in GPJax via defining a `Prior` with our chosen `RBF` kernel. +# On paper a GP is written as $f(\cdot) \sim \mathcal{GP}(\textbf{0}, k(\cdot, \cdot'))$, +# we can reciprocate this process in GPJax via defining a `Prior` with our chosen `RBF` +# kernel. # %% -kernel = jk.RBF() -prior = gpx.Prior(kernel=kernel) +kernel = gpx.kernels.RBF() +meanf = gpx.mean_functions.Constant(constant=0.0) +meanf = meanf.replace_trainable(constant=False) +prior = gpx.Prior(mean_function=meanf, kernel=kernel) # %% [markdown] # -# The above construction forms the foundation for GPJax's models. Moreover, the GP prior we have just defined can be represented by a [Distrax](https://github.com/deepmind/distrax) multivariate Gaussian distribution. Such functionality enables trivial sampling, and mean and covariance evaluation of the GP. +# The above construction forms the foundation for GPJax's models. Moreover, the GP prior +# we have just defined can be represented by a [Distrax](https://github.com/deepmind/distrax) +# multivariate Gaussian distribution. Such functionality enables trivial sampling, and +# mean and covariance evaluation of the GP. # %% -parameter_state = gpx.initialise(prior, key) -prior_dist = prior(parameter_state.params)(xtest) +prior_dist = prior.predict(xtest) prior_mean = prior_dist.mean() prior_std = prior_dist.stddev() @@ -114,7 +131,11 @@ # %% [markdown] # ## Constructing the posterior # -# Having defined our GP, we proceed to define a description of our data $\mathcal{D}$ conditional on our knowledge of $f(\cdot)$ --- this is exactly the notion of a likelihood function $p(\mathcal{D} | f(\cdot))$. While the choice of likelihood is a critical in Bayesian modelling, for simplicity we consider a Gaussian with noise parameter $\alpha$ +# Having defined our GP, we proceed to define a description of our data +# $\mathcal{D}$ conditional on our knowledge of $f(\cdot)$ --- this is exactly the +# notion of a likelihood function $p(\mathcal{D} | f(\cdot))$. While the choice of +# likelihood is a critical in Bayesian modelling, for simplicity we consider a +# Gaussian with noise parameter $\alpha$ # $$p(\mathcal{D} | f(\cdot)) = \mathcal{N}(\boldsymbol{y}; f(\boldsymbol{x}), \textbf{I} \alpha^2).$$ # This is defined in GPJax through calling a `Gaussian` instance. @@ -134,11 +155,15 @@ # %% [markdown] # +# Our kernel is parameterised by a length-scale $\ell^2$ and variance parameter +# $\sigma^2$, while our likelihood controls the observation noise with $\alpha^2$. +# Using Jax's automatic differentiation module, we can take derivatives of --> # # ## Parameter state # -# So far, all of the objects that we've defined have been stateless. To give our model state, we can use the `initialise` function provided in GPJax. Upon calling this, a `ParameterState` class is returned that contains four dictionaries: +# So far, all of the objects that we've defined have been stateless. To give our model +# state, we can use the `initialise` function provided in GPJax. Upon calling this, a +# `ParameterState` class is returned that contains four dictionaries: # # | Dictionary | Description | # |---|---| @@ -146,71 +171,60 @@ # | `trainable` | Boolean dictionary that determines the training status of parameters (`True` for being trained and `False` otherwise). | # | `bijectors` | Bijectors that can map parameters between the _unconstrained space_ and their original _constrained space_. | # -# Further, upon calling `initialise`, we can state specific initial values for some, or all, of the parameters within our model. By default, the kernel lengthscale and variance and the likelihood's variance parameter are all initialised to 1. However, in the following cell, we'll demonstrate how the kernel lengthscale can be initialised to 0.5. - -# %% -parameter_state = gpx.initialise( - posterior, key, kernel={"lengthscale": jnp.array([0.5])} -) -print(type(parameter_state)) - -# %% [markdown] -# Note, for this example a key is not strictly necessary as none of the parameters are stochastic variables. For this reason, it is valid to call `initialise` without supplying a key. For some models, such as the sparse spectrum GP, the parameters are themselves random variables and the key is therefore essential. -# -# We can now unpack the `ParameterState` to receive each of the four components listed above. - -# %% -params, trainable, bijectors = parameter_state.unpack() -pp.pprint(params) - -# %% [markdown] -# To motivate the purpose the `bijectors` more precisely, notice that our model hyperparameters $\{\ell^2, \sigma^2, \alpha^2 \}$ are all strictly positive, bijectors act to unconstrain these during the optimisation proceedure. - -# %% [markdown] -# To train our hyperparameters, we optimising the marginal log-likelihood of the posterior with respect to them. We define the marginal log-likelihood with `marginal_log_likelihood` on the posterior. +# Further, upon calling `initialise`, we can state specific initial values for some, or +# all, of the parameters within our model. By default, the kernel lengthscale and +# variance and the likelihood's variance parameter are all initialised to 1. However, +# in the following cell, we'll demonstrate how the kernel lengthscale can be +# initialised to 0.5. # %% -negative_mll = jit(posterior.marginal_log_likelihood(D, negative=True)) -negative_mll(params) +negative_mll = gpx.objectives.ConjugateMLL(negative=True) +negative_mll(posterior, train_data=D) # %% [markdown] -# Since most optimisers (including here) minimise a given function, we have realised the negative marginal log-likelihood and just-in-time (JIT) compiled this to accelerate training. +# Since most optimisers (including here) minimise a given function, we have realised +# the negative marginal log-likelihood and just-in-time (JIT) compiled this to +# accelerate training. # %% [markdown] -# We can now define an optimiser with `optax`. For this example we'll use the `adam` optimiser. +# We can now define an optimiser with `optax`. For this example we'll use the `adam` +# optimiser. # %% -optimiser = ox.adam(learning_rate=0.01) - -inference_state = gpx.fit( - objective=negative_mll, - parameter_state=parameter_state, - optax_optim=optimiser, +opt_posterior, history = gpx.fit( + model=posterior, + objective=gpx.ConjugateMLL(negative=True), + train_data=D, + optim=ox.adam(learning_rate=0.01), num_iters=500, ) # %% [markdown] -# Similar to the `ParameterState` object above, the returned variable from the `fit` function is a class, namely an `InferenceState` object that contains the parameters' final values and a tracked array of the evaluation of our objective function throughout optimisation. - -# %% -learned_params, training_history = inference_state.unpack() - -pp.pprint(learned_params) +# Similar to the `ParameterState` object above, the returned variable from the `fit` +# function is a class, namely an `InferenceState` object that contains the parameters' +# final values and a tracked array of the evaluation of our objective function +# throughout optimisation. # %% [markdown] # ## Prediction # -# Equipped with the posterior and a set of optimised hyperparameter values, we are now in a position to query our GP's predictive distribution at novel test inputs. To do this, we use our defined `posterior` and `likelihood` at our test inputs to obtain the predictive distribution as a `Distrax` multivariate Gaussian upon which `mean` and `stddev` can be used to extract the predictive mean and standard deviatation. +# Equipped with the posterior and a set of optimised hyperparameter values, we are now +# in a position to query our GP's predictive distribution at novel test inputs. To do +# this, we use our defined `posterior` and `likelihood` at our test inputs to obtain +# the predictive distribution as a `Distrax` multivariate Gaussian upon which `mean` +# and `stddev` can be used to extract the predictive mean and standard deviatation. # %% -latent_dist = posterior(learned_params, D)(xtest) -predictive_dist = likelihood(learned_params, latent_dist) +latent_dist = opt_posterior.predict(xtest, train_data=D) +predictive_dist = opt_posterior.likelihood(latent_dist) predictive_mean = predictive_dist.mean() predictive_std = predictive_dist.stddev() # %% [markdown] -# With the predictions and their uncertainty acquired, we illustrate the GP's performance at explaining the data $\mathcal{D}$ and recovering the underlying latent function of interest. +# With the predictions and their uncertainty acquired, we illustrate the GP's +# performance at explaining the data $\mathcal{D}$ and recovering the underlying +# latent function of interest. # %% fig, ax = plt.subplots(figsize=(12, 5)) diff --git a/examples/tfp_integration.pct.py b/examples/tfp_integration.pct.py index 18444a4d3..7b0f9bd04 100644 --- a/examples/tfp_integration.pct.py +++ b/examples/tfp_integration.pct.py @@ -27,9 +27,9 @@ import matplotlib.pyplot as plt from jax.config import config from jaxutils import Dataset -import gpjax.kernels as jk import gpjax as gpx +import gpjax.kernels as jk from gpjax.utils import dict_array_coercion # Enable Float64 for more stable matrix inversions. diff --git a/examples/uncollapsed_vi.pct.py b/examples/uncollapsed_vi.pct.py index 8b88c1221..d20734e26 100644 --- a/examples/uncollapsed_vi.pct.py +++ b/examples/uncollapsed_vi.pct.py @@ -24,16 +24,17 @@ import jax.random as jr import matplotlib.pyplot as plt import optax as ox +import tensorflow_probability.substrates.jax as tfp from jax import jit from jax.config import config from jaxutils import Dataset -import gpjax.kernels as jk -import tensorflow_probability.substrates.jax as tfp +import gpjax.kernels as jk tfb = tfp.bijectors import distrax as dx + import gpjax as gpx from gpjax.config import get_global_config, reset_global_config diff --git a/examples/yacht.pct.py b/examples/yacht.pct.py index 3f8cb3852..b72472725 100644 --- a/examples/yacht.pct.py +++ b/examples/yacht.pct.py @@ -20,6 +20,7 @@ import optax as ox from jax.config import config from jaxutils import Dataset + import gpjax.kernels as jk # Enable Float64 for more stable matrix inversions. @@ -196,9 +197,7 @@ ax[1].scatter(predictive_mean.squeeze(), residuals) ax[1].plot([0, 1], [0.5, 0.5], color="tab:orange", transform=ax[1].transAxes) ax[1].set_ylim([-1.0, 1.0]) -ax[1].set( - xlabel="Predicted", ylabel="Residuals", title="Predicted vs Residuals" -) +ax[1].set(xlabel="Predicted", ylabel="Residuals", title="Predicted vs Residuals") ax[2].hist(np.asarray(residuals), bins=30) ax[2].set_title("Residuals") diff --git a/gpjax/__init__.py b/gpjax/__init__.py index fb3eb83ad..bd01252da 100644 --- a/gpjax/__init__.py +++ b/gpjax/__init__.py @@ -27,6 +27,7 @@ VariationalGaussian, WhitenedVariationalGaussian, ) +from .objectives import ConjugateMLL, NonConjugateMLL, LogPosteriorDensity, CollapsedELBO, ELBO __version__ = _version.get_versions()["version"] __license__ = "MIT" @@ -60,4 +61,9 @@ "WhitenedVariationalGaussian", "CollapsedVI", "StochasticVI", + "ConjugateMLL", + "NonConjugateMLL", + "LogPosteriorDensity", + "CollapsedELBO", + "ELBO" ] diff --git a/gpjax/likelihoods.py b/gpjax/likelihoods.py index 3369105b0..a2742dad1 100644 --- a/gpjax/likelihoods.py +++ b/gpjax/likelihoods.py @@ -18,7 +18,7 @@ from .linops.utils import to_dense import distrax as dx -import tensorflow_probability.substrates.jax.bijectors as tfb +import tensorflow_probability.substrates.jax as tfp import jax.numpy as jnp import jax.scipy as jsp from jaxtyping import Array, Float @@ -26,6 +26,8 @@ from dataclasses import dataclass from .base import Module, param_field +tfb = tfp.bijectors +tfd = tfp.distributions @dataclass class AbstractLikelihood(Module): @@ -86,7 +88,7 @@ def link_function(self, f: Float[Array, "N 1"]) -> dx.Normal: """ return dx.Normal(loc=f, scale=self.obs_noise) - def predict(self, dist: dx.MultivariateNormalTri) -> dx.Distribution: + def predict(self, dist: tfd.MultivariateNormalTriL) -> tfd.MultivariateNormalFullCovariance: """ Evaluate the Gaussian likelihood function at a given predictive distribution. Computationally, this is equivalent to summing the @@ -105,7 +107,7 @@ def predict(self, dist: dx.MultivariateNormalTri) -> dx.Distribution: cov = to_dense(dist.covariance()) noisy_cov = cov.at[jnp.diag_indices(n_data)].add(self.obs_noise) - return dx.MultivariateNormalFullCovariance(dist.mean(), noisy_cov) + return tfd.MultivariateNormalFullCovariance(dist.mean(), noisy_cov) @dataclass From bacfa64803e0561c35adf60436431c2a19900755 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Mon, 3 Apr 2023 22:09:18 +0100 Subject: [PATCH 27/44] Classification nb --- examples/classification.pct.py | 202 +++++++++++++++++++++++---------- 1 file changed, 140 insertions(+), 62 deletions(-) diff --git a/examples/classification.pct.py b/examples/classification.pct.py index a656f6a90..f9b1f1d8d 100644 --- a/examples/classification.pct.py +++ b/examples/classification.pct.py @@ -17,11 +17,14 @@ # %% [markdown] # # Classification # -# In this notebook we demonstrate how to perform inference for Gaussian process models with non-Gaussian likelihoods via maximum a posteriori (MAP) and Markov chain Monte Carlo (MCMC). We focus on a classification task here and use [BlackJax](https://github.com/blackjax-devs/blackjax/) for sampling. +# In this notebook we demonstrate how to perform inference for Gaussian process models +# with non-Gaussian likelihoods via maximum a posteriori (MAP) and Markov chain Monte +# Carlo (MCMC). We focus on a classification task here and use +# [BlackJax](https://github.com/blackjax-devs/blackjax/) for sampling. # %% import blackjax -import distrax as dx +import tensorflow_probability.substrates.jax as tfp import jax import jax.numpy as jnp import jax.random as jr @@ -30,31 +33,33 @@ import optax as ox from jax.config import config from jaxtyping import Array, Float -from jaxutils import Dataset import gpjax as gpx -import gpjax.kernels as jk # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) +tfd = tfp.distributions I = jnp.eye key = jr.PRNGKey(123) # %% [markdown] # ## Dataset # -# With the necessary modules imported, we simulate a dataset $\mathcal{D} = (, \boldsymbol{y}) = \{(x_i, y_i)\}_{i=1}^{100}$ with inputs $\boldsymbol{x}$ sampled uniformly on $(-1., 1)$ and corresponding binary outputs +# With the necessary modules imported, we simulate a dataset +# $\mathcal{D} = (, \boldsymbol{y}) = \{(x_i, y_i)\}_{i=1}^{100}$ with inputs +# $\boldsymbol{x}$ sampled uniformly on $(-1., 1)$ and corresponding binary outputs # # $$\boldsymbol{y} = 0.5 * \text{sign}(\cos(2 * + \boldsymbol{\epsilon})) + 0.5, \quad \boldsymbol{\epsilon} \sim \mathcal{N} \left(\textbf{0}, \textbf{I} * (0.05)^{2} \right).$$ # -# We store our data $\mathcal{D}$ as a GPJax `Dataset` and create test inputs for later. +# We store our data $\mathcal{D}$ as a GPJax `Dataset` and create test inputs for +# later. # %% key, subkey = jr.split(key) x = jr.uniform(key, shape=(100, 1), minval=-1.0, maxval=1.0) y = 0.5 * jnp.sign(jnp.cos(3 * x + jr.normal(subkey, shape=x.shape) * 0.05)) + 0.5 -D = Dataset(X=x, y=y) +D = gpx.Dataset(X=x, y=y) xtest = jnp.linspace(-1.0, 1.0, 500).reshape(-1, 1) plt.plot(x, y, "o", markersize=8) @@ -62,11 +67,14 @@ # %% [markdown] # ## MAP inference # -# We begin by defining a Gaussian process prior with a radial basis function (RBF) kernel, chosen for the purpose of exposition. Since our observations are binary, we choose a Bernoulli likelihood with a probit link function. +# We begin by defining a Gaussian process prior with a radial basis function (RBF) +# kernel, chosen for the purpose of exposition. Since our observations are binary, we +# choose a Bernoulli likelihood with a probit link function. # %% -kernel = jk.RBF() -prior = gpx.Prior(kernel=kernel) +kernel = gpx.RBF() +meanf = gpx.Constant() +prior = gpx.Prior(mean_function=meanf, kernel=kernel) likelihood = gpx.Bernoulli(num_datapoints=D.n) # %% [markdown] @@ -77,33 +85,40 @@ print(type(posterior)) # %% [markdown] -# Whilst the latent function is Gaussian, the posterior distribution is non-Gaussian since our generative model first samples the latent GP and propagates these samples through the likelihood function's inverse link function. This step prevents us from being able to analytically integrate the latent function's values out of our posterior, and we must instead adopt alternative inference techniques. We begin with maximum a posteriori (MAP) estimation, a fast inference procedure to obtain point estimates for the latent function and the kernel's hyperparameters by maximising the marginal log-likelihood. +# Whilst the latent function is Gaussian, the posterior distribution is non-Gaussian +# since our generative model first samples the latent GP and propagates these samples +# through the likelihood function's inverse link function. This step prevents us from +# being able to analytically integrate the latent function's values out of our +# posterior, and we must instead adopt alternative inference techniques. We begin with +# maximum a posteriori (MAP) estimation, a fast inference procedure to obtain point +# estimates for the latent function and the kernel's hyperparameters by maximising the +# marginal log-likelihood. # %% [markdown] -# To begin we obtain an initial parameter state through the `initialise` callable (see the [regression notebook](https://gpjax.readthedocs.io/en/latest/nbs/regression.html)). We can obtain a MAP estimate by optimising the marginal log-likelihood with Optax's optimisers. +# To begin we obtain an initial parameter state through the `initialise` callable (see +# the [regression notebook](https://gpjax.readthedocs.io/en/latest/nbs/regression.html)). +# We can obtain a MAP estimate by optimising the marginal log-likelihood with +# Optax's optimisers. # %% -parameter_state = gpx.initialise(posterior) -negative_mll = jax.jit(posterior.marginal_log_likelihood(D, negative=True)) +negative_lpd = jax.jit(gpx.LogPosteriorDensity(negative=True)) optimiser = ox.adam(learning_rate=0.01) -inference_state = gpx.fit( - objective=negative_mll, - parameter_state=parameter_state, - optax_optim=optimiser, +opt_posterior, history = gpx.fit( + model=posterior, + objective=negative_lpd, + train_data=D, + optim=ox.adamw(learning_rate=0.01), num_iters=1000, ) -map_estimate, training_history = inference_state.unpack() - # %% [markdown] # From which we can make predictions at novel inputs, as illustrated below. # %% -map_latent_dist = posterior(map_estimate, D)(xtest) - -predictive_dist = likelihood(map_estimate, map_latent_dist) +map_latent_dist = opt_posterior.predict(xtest, train_data=D) +predictive_dist = opt_posterior.likelihood(map_latent_dist) predictive_mean = predictive_dist.mean() predictive_std = predictive_dist.stddev() @@ -137,42 +152,56 @@ ax.legend() # %% [markdown] -# Here we projected the map estimates $\hat{\boldsymbol{f}}$ for the function values $\boldsymbol{f}$ at the data points $\boldsymbol{x}$ to get predictions over the whole domain, +# Here we projected the map estimates $\hat{\boldsymbol{f}}$ for the function values +# $\boldsymbol{f}$ at the data points $\boldsymbol{x}$ to get predictions over the +# whole domain, # # \begin{align} # p(f(\cdot)| \mathcal{D}) \approx q_{map}(f(\cdot)) := \int p(f(\cdot)| \boldsymbol{f}) \delta(\boldsymbol{f} - \hat{\boldsymbol{f}}) d \boldsymbol{f} = \mathcal{N}(\mathbf{K}_{\boldsymbol{(\cdot)x}} \mathbf{K}_{\boldsymbol{xx}}^{-1} \hat{\boldsymbol{f}}, \mathbf{K}_{\boldsymbol{(\cdot, \cdot)}} - \mathbf{K}_{\boldsymbol{(\cdot)\boldsymbol{x}}} \mathbf{K}_{\boldsymbol{xx}}^{-1} \mathbf{K}_{\boldsymbol{\boldsymbol{x}(\cdot)}}). # \end{align} # %% [markdown] -# However, as a point estimate, MAP estimation is severely limited for uncertainty quantification, providing only a single piece of information about the posterior. +# However, as a point estimate, MAP estimation is severely limited for uncertainty +# quantification, providing only a single piece of information about the posterior. # %% [markdown] # ## Laplace approximation -# The Laplace approximation improves uncertainty quantification by incorporating curvature induced by the marginal log-likelihood's Hessian to construct an approximate Gaussian distribution centered on the MAP estimate. Writing $\tilde{p}(\boldsymbol{f}|\mathcal{D}) = p(\boldsymbol{y}|\boldsymbol{f}) p(\boldsymbol{f})$ as the unormalised posterior for function values $\boldsymbol{f}$ at the datapoints $\boldsymbol{x}$, we can expand the log of this about the posterior mode $\hat{\boldsymbol{f}}$ via a Taylor expansion. This gives: +# The Laplace approximation improves uncertainty quantification by incorporating +# curvature induced by the marginal log-likelihood's Hessian to construct an +# approximate Gaussian distribution centered on the MAP estimate. Writing +# $\tilde{p}(\boldsymbol{f}|\mathcal{D}) = p(\boldsymbol{y}|\boldsymbol{f}) p(\boldsymbol{f})$ +# as the unormalised posterior for function values $\boldsymbol{f}$ at the datapoints +# $\boldsymbol{x}$, we can expand the log of this about the posterior mode +# $\hat{\boldsymbol{f}}$ via a Taylor expansion. This gives: # # \begin{align} # \log\tilde{p}(\boldsymbol{f}|\mathcal{D}) = \log\tilde{p}(\hat{\boldsymbol{f}}|\mathcal{D}) + \left[\nabla \log\tilde{p}({\boldsymbol{f}}|\mathcal{D})|_{\hat{\boldsymbol{f}}}\right]^{T} (\boldsymbol{f}-\hat{\boldsymbol{f}}) + \frac{1}{2} (\boldsymbol{f}-\hat{\boldsymbol{f}})^{T} \left[\nabla^2 \tilde{p}(\boldsymbol{y}|\boldsymbol{f})|_{\hat{\boldsymbol{f}}} \right] (\boldsymbol{f}-\hat{\boldsymbol{f}}) + \mathcal{O}(\lVert \boldsymbol{f} - \hat{\boldsymbol{f}} \rVert^3). # \end{align} # -# Now since $\nabla \log\tilde{p}({\boldsymbol{f}}|\mathcal{D})$ is zero at the mode, this suggests the following approximation +# Now since $\nabla \log\tilde{p}({\boldsymbol{f}}|\mathcal{D})$ is zero at the mode, +# this suggests the following approximation # \begin{align} # \tilde{p}(\boldsymbol{f}|\mathcal{D}) \approx \log\tilde{p}(\hat{\boldsymbol{f}}|\mathcal{D}) \exp\left\{ \frac{1}{2} (\boldsymbol{f}-\hat{\boldsymbol{f}})^{T} \left[-\nabla^2 \tilde{p}(\boldsymbol{y}|\boldsymbol{f})|_{\hat{\boldsymbol{f}}} \right] (\boldsymbol{f}-\hat{\boldsymbol{f}}) \right\} # \end{align}, # -# that we identify as a Gaussian distribution, $p(\boldsymbol{f}| \mathcal{D}) \approx q(\boldsymbol{f}) := \mathcal{N}(\hat{\boldsymbol{f}}, [-\nabla^2 \tilde{p}(\boldsymbol{y}|\boldsymbol{f})|_{\hat{\boldsymbol{f}}} ]^{-1} )$. Since the negative Hessian is positive definite, we can use the Cholesky decomposition to obtain the covariance matrix of the Laplace approximation at the datapoints below. +# that we identify as a Gaussian distribution, +# $p(\boldsymbol{f}| \mathcal{D}) \approx q(\boldsymbol{f}) := \mathcal{N}(\hat{\boldsymbol{f}}, [-\nabla^2 \tilde{p}(\boldsymbol{y}|\boldsymbol{f})|_{\hat{\boldsymbol{f}}} ]^{-1} )$. +# Since the negative Hessian is positive definite, we can use the Cholesky +# decomposition to obtain the covariance matrix of the Laplace approximation at the +# datapoints below. # %% gram, cross_covariance = (kernel.gram, kernel.cross_covariance) jitter = 1e-6 # Compute (latent) function value map estimates at training points: -Kxx = gram(map_estimate["kernel"], x) +Kxx = opt_posterior.prior.kernel.gram(x) Kxx += I(D.n) * jitter Lx = Kxx.to_root() -f_hat = Lx @ map_estimate["latent"] +f_hat = Lx @ opt_posterior.latent # Negative Hessian, H = -∇²p_tilde(y|f): -H = jax.jacfwd(jax.jacrev(negative_mll))(map_estimate)["latent"]["latent"][:, 0, :, 0] +H = jax.jacfwd(jax.jacrev(negative_lpd))(opt_posterior, D).latent.latent[:, 0, :, 0] # LLᵀ = H L = jnp.linalg.cholesky(H + I(D.n) * jitter) @@ -181,25 +210,30 @@ L_inv = jsp.linalg.solve_triangular(L, I(D.n), lower=True) H_inv = jsp.linalg.solve_triangular(L.T, L_inv, lower=False) -laplace_approximation = dx.MultivariateNormalFullCovariance(f_hat.squeeze(), H_inv) +laplace_approximation = tfd.MultivariateNormalFullCovariance(f_hat.squeeze(), H_inv) # %% [markdown] -# For novel inputs, we must project the above approximating distribution through the Gaussian conditional distribution $p(f(\cdot)| \boldsymbol{f})$, +# For novel inputs, we must project the above approximating distribution through the +# Gaussian conditional distribution $p(f(\cdot)| \boldsymbol{f})$, # # \begin{align} # p(f(\cdot)| \mathcal{D}) \approx q_{Laplace}(f(\cdot)) := \int p(f(\cdot)| \boldsymbol{f}) q(\boldsymbol{f}) d \boldsymbol{f} = \mathcal{N}(\mathbf{K}_{\boldsymbol{(\cdot)x}} \mathbf{K}_{\boldsymbol{xx}}^{-1} \hat{\boldsymbol{f}}, \mathbf{K}_{\boldsymbol{(\cdot, \cdot)}} - \mathbf{K}_{\boldsymbol{(\cdot)\boldsymbol{x}}} \mathbf{K}_{\boldsymbol{xx}}^{-1} (\mathbf{K}_{\boldsymbol{xx}} - [-\nabla^2 \tilde{p}(\boldsymbol{y}|\boldsymbol{f})|_{\hat{\boldsymbol{f}}} ]^{-1}) \mathbf{K}_{\boldsymbol{xx}}^{-1} \mathbf{K}_{\boldsymbol{\boldsymbol{x}(\cdot)}}). # \end{align} # -# This is the same approximate distribution $q_{map}(f(\cdot))$, but we have pertubed the covariance by a curvature term of $\mathbf{K}_{\boldsymbol{(\cdot)\boldsymbol{x}}} \mathbf{K}_{\boldsymbol{xx}}^{-1} [-\nabla^2 \tilde{p}(\boldsymbol{y}|\boldsymbol{f})|_{\hat{\boldsymbol{f}}} ]^{-1} \mathbf{K}_{\boldsymbol{xx}}^{-1} \mathbf{K}_{\boldsymbol{\boldsymbol{x}(\cdot)}}$. We take the latent distribution computed in the previous section and add this term to the covariance to construct $q_{Laplace}(f(\cdot))$. +# This is the same approximate distribution $q_{map}(f(\cdot))$, but we have perturbed +# the covariance by a curvature term of +# $\mathbf{K}_{\boldsymbol{(\cdot)\boldsymbol{x}}} \mathbf{K}_{\boldsymbol{xx}}^{-1} [-\nabla^2 \tilde{p}(\boldsymbol{y}|\boldsymbol{f})|_{\hat{\boldsymbol{f}}} ]^{-1} \mathbf{K}_{\boldsymbol{xx}}^{-1} \mathbf{K}_{\boldsymbol{\boldsymbol{x}(\cdot)}}$. +# We take the latent distribution computed in the previous section and add this term +# to the covariance to construct $q_{Laplace}(f(\cdot))$. # %% -def construct_laplace(test_inputs: Float[Array, "N D"]) -> dx.MultivariateNormalTri: - - map_latent_dist = posterior(map_estimate, D)(test_inputs) +def construct_laplace(test_inputs: Float[Array, "N D"]) -> tfd.MultivariateNormalTriL: + map_latent_dist = opt_posterior.predict(xtest, train_data=D) - Kxt = cross_covariance(map_estimate["kernel"], x, test_inputs) - Kxx = gram(map_estimate["kernel"], x) + + Kxt = opt_posterior.prior.kernel.cross_covariance(x, test_inputs) + Kxx = opt_posterior.prior.kernel.gram(x) Kxx += I(D.n) * jitter Lx = Kxx.to_root() @@ -215,14 +249,13 @@ def construct_laplace(test_inputs: Float[Array, "N D"]) -> dx.MultivariateNormal mean = map_latent_dist.mean() covariance = map_latent_dist.covariance() + laplace_cov_term L = jnp.linalg.cholesky(covariance) - return dx.MultivariateNormalTri(jnp.atleast_1d(mean.squeeze()), L) - + return tfd.MultivariateNormalTriL(jnp.atleast_1d(mean.squeeze()), L) # %% [markdown] # From this we can construct the predictive distribution at the test points. # %% laplace_latent_dist = construct_laplace(xtest) -predictive_dist = likelihood(map_estimate, laplace_latent_dist) +predictive_dist = opt_posterior.likelihood(laplace_latent_dist) predictive_mean = predictive_dist.mean() predictive_std = predictive_dist.stddev() @@ -255,37 +288,62 @@ def construct_laplace(test_inputs: Float[Array, "N D"]) -> dx.MultivariateNormal ax.legend() # %% [markdown] -# However, the Laplace approximation is still limited by considering information about the posterior at a single location. On the other hand, through approximate sampling, MCMC methods allow us to learn all information about the posterior distribution. +# However, the Laplace approximation is still limited by considering information about +# the posterior at a single location. On the other hand, through approximate sampling, +# MCMC methods allow us to learn all information about the posterior distribution. # %% [markdown] # ## MCMC inference # -# At the high level, an MCMC sampler works by starting at an initial position and drawing a sample from a cheap-to-simulate distribution known as the _proposal_. The next step is to determine whether this sample could be considered a draw from the posterior. We accomplish this using an _acceptance probability_ determined via the sampler's _transition kernel_ which depends on the current position and the unnormalised target posterior distribution. If the new sample is more _likely_, we accept it; otherwise, we reject it and stay in our current position. Repeating these steps results in a Markov chain (a random sequence that depends only on the last state) whose stationary distribution (the long-run empirical distribution of the states visited) is the posterior. For a gentle introduction, see the first chapter of [A Handbook of Markov Chain Monte Carlo](https://www.mcmchandbook.net/HandbookChapter1.pdf). +# At the high level, an MCMC sampler works by starting at an initial position and +# drawing a sample from a cheap-to-simulate distribution known as the _proposal_. The +# next step is to determine whether this sample could be considered a draw from the +# posterior. We accomplish this using an _acceptance probability_ determined via the +# sampler's _transition kernel_ which depends on the current position and the +# unnormalised target posterior distribution. If the new sample is more _likely_, we +# accept it; otherwise, we reject it and stay in our current position. Repeating these +# steps results in a Markov chain (a random sequence that depends only on the last +# state) whose stationary distribution (the long-run empirical distribution of the +# states visited) is the posterior. For a gentle introduction, see the first chapter +# of [A Handbook of Markov Chain Monte Carlo](https://www.mcmchandbook.net/HandbookChapter1.pdf). # # ### MCMC through BlackJax # -# Rather than implementing a suite of MCMC samplers, GPJax relies on MCMC-specific libraries for sampling functionality. We focus on [BlackJax](https://github.com/blackjax-devs/blackjax/) in this notebook, which we recommend adopting for general applications. However, we also support TensorFlow Probability as demonstrated in the [TensorFlow Probability Integration notebook](https://gpjax.readthedocs.io/en/latest/nbs/tfp_integration.html). +# Rather than implementing a suite of MCMC samplers, GPJax relies on MCMC-specific +# libraries for sampling functionality. We focus on +# [BlackJax](https://github.com/blackjax-devs/blackjax/) in this notebook, which we +# recommend adopting for general applications. However, we also support TensorFlow +# Probability as demonstrated in the +# [TensorFlow Probability Integration notebook](https://gpjax.readthedocs.io/en/latest/nbs/tfp_integration.html). # -# We'll use the No U-Turn Sampler (NUTS) implementation given in BlackJax for sampling. For the interested reader, NUTS is a Hamiltonian Monte Carlo sampling scheme where the number of leapfrog integration steps is computed at each step of the change according to the NUTS algorithm. In general, samplers constructed under this framework are very efficient. +# We'll use the No U-Turn Sampler (NUTS) implementation given in BlackJax for sampling. +# For the interested reader, NUTS is a Hamiltonian Monte Carlo sampling scheme where +# the number of leapfrog integration steps is computed at each step of the change +# according to the NUTS algorithm. In general, samplers constructed under this +# framework are very efficient. # -# We begin by generating _sensible_ initial positions for our sampler before defining an inference loop and sampling 500 values from our Markov chain. In practice, drawing more samples will be necessary. +# We begin by generating _sensible_ initial positions for our sampler before defining +# an inference loop and sampling 500 values from our Markov chain. In practice, +# drawing more samples will be necessary. + +# %% +lpd = gpx.LogPosteriorDensity(negative=True) +lpd(opt_posterior, D) # %% # Adapted from BlackJax's introduction notebook. num_adapt = 500 num_samples = 500 -params, trainables, bijectors = gpx.initialise(posterior, key).unpack() -mll = posterior.marginal_log_likelihood(D, negative=False) -unconstrained_mll = jax.jit(lambda params: mll(gpx.constrain(params, bijectors))) +lpd = gpx.LogPosteriorDensity(negative=False) +unconstrained_lpd = jax.jit(lambda tree: lpd(tree.constrain(), D)) adapt = blackjax.window_adaptation( - blackjax.nuts, unconstrained_mll, num_adapt, target_acceptance_rate=0.65 + blackjax.nuts, unconstrained_lpd, num_adapt, target_acceptance_rate=0.65 ) # Initialise the chain -unconstrained_params = gpx.unconstrain(params, bijectors) -last_state, kernel, _ = adapt.run(key, unconstrained_params) +last_state, kernel, _ = adapt.run(key, posterior.unconstrain()) def inference_loop(rng_key, kernel, initial_state, num_samples): @@ -305,20 +363,31 @@ def one_step(state, rng_key): # %% [markdown] # ### Sampler efficiency # -# BlackJax gives us easy access to our sampler's efficiency through metrics such as the sampler's _acceptance probability_ (the number of times that our chain accepted a proposed sample, divided by the total number of steps run by the chain). For NUTS and Hamiltonian Monte Carlo sampling, we typically seek an acceptance rate of 60-70% to strike the right balance between having a chain which is _stuck_ and rarely moves versus a chain that is too jumpy with frequent small steps. +# BlackJax gives us easy access to our sampler's efficiency through metrics such as the +# sampler's _acceptance probability_ (the number of times that our chain accepted a +# proposed sample, divided by the total number of steps run by the chain). For NUTS and +# Hamiltonian Monte Carlo sampling, we typically seek an acceptance rate of 60-70% to +# strike the right balance between having a chain which is _stuck_ and rarely moves +# versus a chain that is too jumpy with frequent small steps. # %% acceptance_rate = jnp.mean(infos.acceptance_probability) print(f"Acceptance rate: {acceptance_rate:.2f}") # %% [markdown] -# Our acceptance rate is slightly too large, prompting an examination of the chain's trace plots. A well-mixing chain will have very few (if any) flat spots in its trace plot whilst also not having too many steps in the same direction. In addition to the model's hyperparameters, there will be 500 samples for each of the 100 latent function values in the `states.position` dictionary. We depict the chains that correspond to the model hyperparameters and the first value of the latent function for brevity. +# Our acceptance rate is slightly too large, prompting an examination of the chain's +# trace plots. A well-mixing chain will have very few (if any) flat spots in its trace +# plot whilst also not having too many steps in the same direction. In addition to +# the model's hyperparameters, there will be 500 samples for each of the 100 latent +# function values in the `states.position` dictionary. We depict the chains that +# correspond to the model hyperparameters and the first value of the latent function +# for brevity. # %% fig, (ax0, ax1, ax2) = plt.subplots(ncols=3, figsize=(15, 5), tight_layout=True) -ax0.plot(states.position["kernel"]["lengthscale"]) -ax1.plot(states.position["kernel"]["variance"]) -ax2.plot(states.position["latent"][:, 0, :]) +ax0.plot(states.position.prior.kernel.lengthscale) +ax1.plot(states.position.prior.kernel.variance) +ax2.plot(states.position.latent[:, 0, :]) ax0.set_title("Kernel Lengthscale") ax1.set_title("Kernel Variance") ax2.set_title("Latent Function (index = 1)") @@ -326,16 +395,24 @@ def one_step(state, rng_key): # %% [markdown] # ## Prediction # -# Having obtained samples from the posterior, we draw ten instances from our model's predictive distribution per MCMC sample. Using these draws, we will be able to compute credible values and expected values under our posterior distribution. +# Having obtained samples from the posterior, we draw ten instances from our model's +# predictive distribution per MCMC sample. Using these draws, we will be able to +# compute credible values and expected values under our posterior distribution. # -# An ideal Markov chain would have samples completely uncorrelated with their neighbours after a single lag. However, in practice, correlations often exist within our chain's sample set. A commonly used technique to try and reduce this correlation is _thinning_ whereby we select every $n$th sample where $n$ is the minimum lag length at which we believe the samples are uncorrelated. Although further analysis of the chain's autocorrelation is required to find appropriate thinning factors, we employ a thin factor of 10 for demonstration purposes. +# An ideal Markov chain would have samples completely uncorrelated with their +# neighbours after a single lag. However, in practice, correlations often exist +# within our chain's sample set. A commonly used technique to try and reduce this +# correlation is _thinning_ whereby we select every $n$th sample where $n$ is the +# minimum lag length at which we believe the samples are uncorrelated. Although further +# analysis of the chain's autocorrelation is required to find appropriate thinning +# factors, we employ a thin factor of 10 for demonstration purposes. # %% thin_factor = 10 samples = [] for i in range(0, num_samples, thin_factor): - ps = gpx.parameters.copy_dict_structure(params) + posterior.replace(states.position) ps["kernel"]["lengthscale"] = states.position["kernel"]["lengthscale"][i] ps["kernel"]["variance"] = states.position["kernel"]["variance"][i] ps["latent"] = states.position["latent"][i, :, :] @@ -352,7 +429,8 @@ def one_step(state, rng_key): # %% [markdown] # -# Finally, we end this tutorial by plotting the predictions obtained from our model against the observed data. +# Finally, we end this tutorial by plotting the predictions obtained from our model +# against the observed data. # %% fig, ax = plt.subplots(figsize=(16, 5), tight_layout=True) From 18ab3a99d6ea6cc21a8ad2b950cc20248d288270 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Tue, 4 Apr 2023 09:30:54 +0100 Subject: [PATCH 28/44] Collapsed VI --- examples/collapsed_vi.pct.py | 101 +++++++++++++++++++++-------------- 1 file changed, 60 insertions(+), 41 deletions(-) diff --git a/examples/collapsed_vi.pct.py b/examples/collapsed_vi.pct.py index a6cf8469f..f3a2915ad 100644 --- a/examples/collapsed_vi.pct.py +++ b/examples/collapsed_vi.pct.py @@ -16,8 +16,13 @@ # %% [markdown] # # Sparse Gaussian Process Regression # -# In this notebook we consider sparse Gaussian process regression (SGPR) Titsias (2009). This is a solution for medium- to large-scale conjugate regression problems. -# In order to arrive at a computationally tractable method, the approximate posterior is parameterized via a set of $m$ pseudo-points $\boldsymbol{z}$. Critically, the approach leads to $\mathcal{O}(nm^2)$ complexity for approximate maximum likelihood learning and $O(m^2)$ per test point for prediction. +# In this notebook we consider sparse Gaussian process regression (SGPR) +# Titsias (2009). This is a solution for +# medium to large-scale conjugate regression problems. +# In order to arrive at a computationally tractable method, the approximate posterior +# is parameterized via a set of $m$ pseudo-points $\boldsymbol{z}$. Critically, the +# approach leads to $\mathcal{O}(nm^2)$ complexity for approximate maximum likelihood +# learning and $O(m^2)$ per test point for prediction. # %% import jax.numpy as jnp @@ -26,10 +31,8 @@ import optax as ox from jax import jit from jax.config import config -from jaxutils import Dataset import gpjax as gpx -import gpjax.kernels as jk # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) @@ -38,11 +41,15 @@ # %% [markdown] # ## Dataset # -# With the necessary modules imported, we simulate a dataset $\mathcal{D} = (\boldsymbol{x}, \boldsymbol{y}) = \{(x_i, y_i)\}_{i=1}^{500}$ with inputs $\boldsymbol{x}$ sampled uniformly on $(-3., 3)$ and corresponding independent noisy outputs +# With the necessary modules imported, we simulate a dataset +# $\mathcal{D} = (\boldsymbol{x}, \boldsymbol{y}) = \{(x_i, y_i)\}_{i=1}^{500}$ +# with inputs $\boldsymbol{x}$ sampled uniformly on $(-3., 3)$ and corresponding +# independent noisy outputs # # $$\boldsymbol{y} \sim \mathcal{N} \left(\sin(7\boldsymbol{x}) + x \cos(2 \boldsymbol{x}), \textbf{I} * 0.5^2 \right).$$ # -# We store our data $\mathcal{D}$ as a GPJax `Dataset` and create test inputs and labels for later. +# We store our data $\mathcal{D}$ as a GPJax `Dataset` and create test inputs and +# labels for later. # %% n = 2500 @@ -54,13 +61,15 @@ signal = f(x) y = signal + jr.normal(subkey, shape=signal.shape) * noise -D = Dataset(X=x, y=y) +D = gpx.Dataset(X=x, y=y) xtest = jnp.linspace(-3.1, 3.1, 500).reshape(-1, 1) ytest = f(xtest) # %% [markdown] -# To better understand what we have simulated, we plot both the underlying latent function and the observed data that is subject to Gaussian noise. We also plot an initial set of inducing points over the space. +# To better understand what we have simulated, we plot both the underlying latent +# function and the observed data that is subject to Gaussian noise. We also plot an +# initial set of inducing points over the space. # %% n_inducing = 50 @@ -77,50 +86,59 @@ # Next we define the posterior model for the data. # %% -kernel = jk.RBF() +meanf = gpx.Constant() +kernel = gpx.RBF() likelihood = gpx.Gaussian(num_datapoints=D.n) -prior = gpx.Prior(kernel=kernel) -p = prior * likelihood +prior = gpx.Prior(mean_function=meanf, kernel=kernel) +posterior = prior * likelihood # %% [markdown] -# We now define the SGPR model through `CollapsedVariationalGaussian`. Since the form of the collapsed optimal posterior depends on the Gaussian likelihood's observation noise, we pass this to the constructer. +# We now define the SGPR model through `CollapsedVariationalGaussian`. Since the form +# of the collapsed optimal posterior depends on the Gaussian likelihood's observation +# noise, we pass this to the constructer. # %% q = gpx.CollapsedVariationalGaussian( - prior=prior, likelihood=likelihood, inducing_inputs=z + posterior=posterior, inducing_inputs=z ) # %% [markdown] -# We define our variational inference algorithm through `CollapsedVI`. This defines the collapsed variational free energy bound considered in Titsias (2009). +# We define our variational inference algorithm through `CollapsedVI`. This defines +# the collapsed variational free energy bound considered in +# Titsias (2009). # %% -sgpr = gpx.CollapsedVI(posterior=p, variational_family=q) +elbo = jit(gpx.CollapsedELBO(negative=True)) # %% [markdown] -# We now train our model akin to a Gaussian process regression model via the `fit` abstraction. Unlike the regression example given in the [conjugate regression notebook](https://gpjax.readthedocs.io/en/latest/nbs/regression.html), the inducing locations that induce our variational posterior distribution are now part of the model's parameters. Using a gradient-based optimiser, we can then _optimise_ their location such that the evidence lower bound is maximised. +# We now train our model akin to a Gaussian process regression model via the `fit` +# abstraction. Unlike the regression example given in the +# [conjugate regression notebook](https://gpjax.readthedocs.io/en/latest/nbs/regression.html), +# the inducing locations that induce our variational posterior distribution are now +# part of the model's parameters. Using a gradient-based optimiser, we can then +# _optimise_ their location such that the evidence lower bound is maximised. # %% -parameter_state = gpx.initialise(sgpr, key) - -negative_elbo = jit(sgpr.elbo(D, negative=True)) - -optimiser = ox.adam(learning_rate=5e-3) - -inference_state = gpx.fit( - objective=negative_elbo, - parameter_state=parameter_state, - optax_optim=optimiser, +opt_posterior, history = gpx.fit( + model=q, + objective=elbo, + train_data=D, + optim=ox.adamw(learning_rate=5e-3), num_iters=2000, ) -learned_params, training_history = inference_state.unpack() +plt.plot(history) # %% [markdown] # We show predictions of our model with the learned inducing points overlayed in grey. # %% -latent_dist = q(learned_params, D)(xtest) -predictive_dist = likelihood(learned_params, latent_dist) + +# %% +latent_dist = opt_posterior(xtest, train_data=D) +predictive_dist = opt_posterior.posterior.likelihood(latent_dist) + +inducing_points = opt_posterior.inducing_inputs samples = latent_dist.sample(seed=key, sample_shape=(20,)) @@ -167,7 +185,7 @@ ax.plot(xtest, samples.T, color="tab:blue", alpha=0.8, linewidth=0.2) [ ax.axvline(x=z_i, color="tab:gray", alpha=0.3, linewidth=1) - for z_i in learned_params["variational_family"]["inducing_inputs"] + for z_i in inducing_points ] ax.legend() plt.show() @@ -175,23 +193,24 @@ # %% [markdown] # ## Runtime comparison # -# Given the size of the data being considered here, inference in a GP with a full-rank covariance matrix is possible, albeit quite slow. We can therefore compare the speedup that we get from using the above sparse approximation with corresponding bound on the marginal log-likelihood against the marginal log-likelihood in the full model. +# Given the size of the data being considered here, inference in a GP with a full-rank +# covariance matrix is possible, albeit quite slow. We can therefore compare the +# speedup that we get from using the above sparse approximation with corresponding +# bound on the marginal log-likelihood against the marginal log-likelihood in the +# full model. # %% -full_rank_model = gpx.Prior(kernel=gpx.RBF()) * gpx.Gaussian(num_datapoints=D.n) -fr_params, *_ = gpx.initialise(full_rank_model, key).unpack() -negative_mll = jit(full_rank_model.marginal_log_likelihood(D, negative=True)) - -# %timeit negative_mll(fr_params).block_until_ready() +full_rank_model = gpx.Prior(mean_function=gpx.Zero(), kernel=gpx.RBF()) * gpx.Gaussian(num_datapoints=D.n) +negative_mll = jit(gpx.ConjugateMLL(negative=True)) +# %timeit negative_mll(full_rank_model, D).block_until_ready() # %% -params, *_ = gpx.initialise(sgpr, key).unpack() -negative_elbo = jit(sgpr.elbo(D, negative=True)) - -# %timeit negative_elbo(params).block_until_ready() +negative_elbo = jit(gpx.CollapsedELBO(negative=True)) +# %timeit negative_elbo(q, D).block_until_ready() # %% [markdown] -# As we can see, the sparse approximation given here is around 50 times faster when compared against a full-rank model. +# As we can see, the sparse approximation given here is around 50 times faster when +# compared against a full-rank model. # %% [markdown] # ## System configuration From 2557fe82fa7088df10934e6b6877501aca019e0b Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Tue, 4 Apr 2023 11:01:49 +0100 Subject: [PATCH 29/44] Sampling fixed --- examples/classification.pct.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/examples/classification.pct.py b/examples/classification.pct.py index f9b1f1d8d..144b8944c 100644 --- a/examples/classification.pct.py +++ b/examples/classification.pct.py @@ -33,6 +33,8 @@ import optax as ox from jax.config import config from jaxtyping import Array, Float +import jax.tree_util as jtu + import gpjax as gpx @@ -412,14 +414,9 @@ def one_step(state, rng_key): samples = [] for i in range(0, num_samples, thin_factor): - posterior.replace(states.position) - ps["kernel"]["lengthscale"] = states.position["kernel"]["lengthscale"][i] - ps["kernel"]["variance"] = states.position["kernel"]["variance"][i] - ps["latent"] = states.position["latent"][i, :, :] - ps = gpx.constrain(ps, bijectors) - - latent_dist = posterior(ps, D)(xtest) - predictive_dist = likelihood(ps, latent_dist) + sample = jtu.tree_map(lambda samples: samples[0], states.position) + latent_dist = sample.predict(xtest, train_data=D) + predictive_dist = sample.likelihood(latent_dist) samples.append(predictive_dist.sample(seed=key, sample_shape=(10,))) samples = jnp.vstack(samples) From 395f29536748d6366fe70070d5ba5238faf6cdde Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Tue, 4 Apr 2023 11:02:39 +0100 Subject: [PATCH 30/44] Sampling fixed --- examples/classification.pct.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/classification.pct.py b/examples/classification.pct.py index 144b8944c..530f58b56 100644 --- a/examples/classification.pct.py +++ b/examples/classification.pct.py @@ -233,7 +233,7 @@ def construct_laplace(test_inputs: Float[Array, "N D"]) -> tfd.MultivariateNormalTriL: map_latent_dist = opt_posterior.predict(xtest, train_data=D) - + Kxt = opt_posterior.prior.kernel.cross_covariance(x, test_inputs) Kxx = opt_posterior.prior.kernel.gram(x) Kxx += I(D.n) * jitter @@ -414,7 +414,7 @@ def one_step(state, rng_key): samples = [] for i in range(0, num_samples, thin_factor): - sample = jtu.tree_map(lambda samples: samples[0], states.position) + sample = jtu.tree_map(lambda samples: samples[i], states.position) latent_dist = sample.predict(xtest, train_data=D) predictive_dist = sample.likelihood(latent_dist) samples.append(predictive_dist.sample(seed=key, sample_shape=(10,))) From 0700628865d13258408a4398c360493db1f9097a Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Tue, 4 Apr 2023 11:12:12 +0100 Subject: [PATCH 31/44] Graph kernel --- examples/graph_kernels.pct.py | 61 ++++++++++++++++++++++++----------- 1 file changed, 43 insertions(+), 18 deletions(-) diff --git a/examples/graph_kernels.pct.py b/examples/graph_kernels.pct.py index f87059f65..ea8f5c211 100644 --- a/examples/graph_kernels.pct.py +++ b/examples/graph_kernels.pct.py @@ -17,7 +17,11 @@ # %% [markdown] # # Graph Kernels # -# This notebook demonstrates how regression models can be constructed on the vertices of a graph using a Gaussian process with a Matérn kernel presented in . For a general discussion of the kernels supported within GPJax, see the [kernels notebook](https://gpjax.readthedocs.io/en/latest/nbs/kernels.html). +# This notebook demonstrates how regression models can be constructed on the vertices +# of a graph using a Gaussian process with a Matérn kernel presented in +# . For a general discussion of the +# kernels supported within GPJax, see the +# [kernels notebook](https://gpjax.readthedocs.io/en/latest/nbs/kernels.html). # %% import random @@ -29,10 +33,8 @@ import optax as ox from jax import jit from jax.config import config -from jaxutils import Dataset import gpjax as gpx -import gpjax.kernels as jk # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) @@ -41,9 +43,16 @@ # %% [markdown] # ## Graph construction # -# Our graph $\mathcal{G}=\lbrace V, E \rbrace$ comprises a set of vertices $V = \lbrace v_1, v_2, \ldots, v_n\rbrace$ and edges $E=\lbrace (v_i, v_j)\in V \ : \ i \neq j\rbrace$. In particular, we will consider a [barbell graph](https://en.wikipedia.org/wiki/Barbell_graph) that is an undirected graph containing two clusters of vertices with a single shared edge between the two clusters. +# Our graph $\mathcal{G}=\lbrace V, E \rbrace$ comprises a set of vertices +# $V = \lbrace v_1, v_2, \ldots, v_n\rbrace$ and edges +# $E=\lbrace (v_i, v_j)\in V \ : \ i \neq j\rbrace$. In particular, we will consider +# a [barbell graph](https://en.wikipedia.org/wiki/Barbell_graph) that is an undirected +# graph containing two clusters of vertices with a single shared edge between the +# two clusters. # -# Contrary to the typical barbell graph, we'll randomly remove a subset of 30 edges within each of the two clusters. Given the 40 vertices within the graph, this results in 351 edges as shown below. +# Contrary to the typical barbell graph, we'll randomly remove a subset of 30 edges +# within each of the two clusters. Given the 40 vertices within the graph, this results +# in 351 edges as shown below. # %% vertex_per_side = 20 @@ -63,9 +72,13 @@ # # ### Computing the graph Laplacian # -# Graph kernels use the _Laplacian matrix_ $L$ to quantify the smoothness of a signal (or function) on a graph +# Graph kernels use the _Laplacian matrix_ $L$ to quantify the smoothness of a signal +# (or function) on a graph # $$L=D-A,$$ -# where $D$ is the diagonal _degree matrix_ containing each vertices' degree and $A$ is the _adjacency matrix_ that has an $(i,j)^{\text{th}}$ entry of 1 if $v_i, v_j$ are connected and 0 otherwise. [Networkx](https://networkx.org) gives us an easy way to compute this. +# where $D$ is the diagonal _degree matrix_ containing each vertices' degree and $A$ +# is the _adjacency matrix_ that has an $(i,j)^{\text{th}}$ entry of 1 if $v_i, v_j$ +# are connected and 0 otherwise. [Networkx](https://networkx.org) gives us an easy +# way to compute this. # %% L = nx.laplacian_matrix(G).toarray() @@ -74,16 +87,20 @@ # # ## Simulating a signal on the graph # -# Our task is to construct a Gaussian process $f(\cdot)$ that maps from the graph's vertex set $V$ onto the real line. -# To that end, we begin by simulating a signal on the graph's vertices that we will go on to try and predict. -# We use a single draw from a Gaussian process prior to draw our response values $\boldsymbol{y}$ where we hardcode parameter values. -# The corresponding input value set for this model, denoted $\boldsymbol{x}$, is the index set of the graph's vertices. +# Our task is to construct a Gaussian process $f(\cdot)$ that maps from the graph's +# vertex set $V$ onto the real line. +# To that end, we begin by simulating a signal on the graph's vertices that we will go +# on to try and predict. +# We use a single draw from a Gaussian process prior to draw our response values +# $\boldsymbol{y}$ where we hardcode parameter values. +# The corresponding input value set for this model, denoted $\boldsymbol{x}$, is the +# index set of the graph's vertices. # %% x = jnp.arange(G.number_of_nodes()).reshape(-1, 1) -kernel = jk.GraphKernel(laplacian=L) -prior = gpx.Prior(kernel=kernel) +kernel = gpx.GraphKernel(laplacian=L) +prior = gpx.Prior(mean_function = gpx.Zero(), kernel=kernel) true_params = prior.init_params(key) true_params["kernel"] = { @@ -121,9 +138,14 @@ # # ## Constructing a graph Gaussian process # -# With our dataset created, we proceed to define our posterior Gaussian process and optimise the model's hyperparameters. -# Whilst our underlying space is the graph's vertex set and is therefore non-Euclidean, our likelihood is still Gaussian and the model is still conjugate. -# For this reason, we simply perform gradient descent on the GP's marginal log-likelihood term as in the [regression notebook](https://gpjax.readthedocs.io/en/latest/nbs/regression.html). We do this using the Adam optimiser provided in `optax`. +# With our dataset created, we proceed to define our posterior Gaussian process and +# optimise the model's hyperparameters. +# Whilst our underlying space is the graph's vertex set and is therefore +# non-Euclidean, our likelihood is still Gaussian and the model is still conjugate. +# For this reason, we simply perform gradient descent on the GP's marginal +# log-likelihood term as in the +# [regression notebook](https://gpjax.readthedocs.io/en/latest/nbs/regression.html). +# We do this using the Adam optimiser provided in `optax`. # %% likelihood = gpx.Gaussian(num_datapoints=y.shape[0]) @@ -147,7 +169,9 @@ # ## Making predictions # # Having optimised our hyperparameters, we can now make predictions on the graph. -# Though we haven't defined a training and testing dataset here, we'll simply query the predictive posterior for the full graph to compare the root-mean-squared error (RMSE) of the model for the initialised parameters vs the optimised set. +# Though we haven't defined a training and testing dataset here, we'll simply query +# the predictive posterior for the full graph to compare the root-mean-squared error +# (RMSE) of the model for the initialised parameters vs the optimised set. # %% initial_params = parameter_state.params @@ -168,7 +192,8 @@ # %% [markdown] # -# We can also plot the source of error in our model's predictions on the graph by the following. +# We can also plot the source of error in our model's predictions on the graph by the +# following. # %% error = jnp.abs(learned_mean - y.squeeze()) From 5d29d7f473b170d19654506a5c753bceac208574 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Mon, 3 Apr 2023 09:01:18 +0100 Subject: [PATCH 32/44] RFF refactored --- gpjax/kernels/approximations/rff.py | 56 ++-- gpjax/kernels/base.py | 6 + gpjax/kernels/computations/basis_functions.py | 12 +- gpjax/kernels/non_euclidean/graph.py | 56 ++-- gpjax/kernels/stationary/matern12.py | 4 +- gpjax/kernels/stationary/matern32.py | 4 +- gpjax/kernels/stationary/matern52.py | 4 +- gpjax/kernels/stationary/rbf.py | 4 + gpjax/kernels/stationary/utils.py | 9 +- tests/test_kernels/test_approximations.py | 244 +++++++++--------- 10 files changed, 189 insertions(+), 210 deletions(-) diff --git a/gpjax/kernels/approximations/rff.py b/gpjax/kernels/approximations/rff.py index a1e3b97f7..6e2610054 100644 --- a/gpjax/kernels/approximations/rff.py +++ b/gpjax/kernels/approximations/rff.py @@ -1,10 +1,25 @@ from ..base import AbstractKernel from ..computations import BasisFunctionComputation -from jax.random import KeyArray +from jax.random import KeyArray, PRNGKey from typing import Dict, Any +from jaxtyping import Float, Array +from dataclasses import dataclass +from ...parameters import param_field +from ..computations import DenseKernelComputation, AbstractKernelComputation +from simple_pytree import static_field +import tensorflow_probability.substrates.jax as tfp +tfb = tfp.bijectors -class RFF(AbstractKernel): +@dataclass +class AbstractFourierKernel: + base_kernel: AbstractKernel + num_basis_fns: int + frequencies: Float[Array, "M 1"] = param_field(None, bijector=tfb.Identity) + key: KeyArray = static_field(PRNGKey(123)) + +@dataclass +class RFF(AbstractKernel, AbstractFourierKernel): """Computes an approximation of the kernel using Random Fourier Features. All stationary kernels are equivalent to the Fourier transform of a probability @@ -23,39 +38,22 @@ class RFF(AbstractKernel): AbstractKernel (_type_): _description_ """ - def __init__(self, base_kernel: AbstractKernel, num_basis_fns: int) -> None: - """Initialise the Random Fourier Features approximation. + def __post_init__(self) -> None: + """Post-initialisation function. - Args: - base_kernel (AbstractKernel): The kernel that is to be approximated. This kernel must be stationary. - num_basis_fns (int): The number of basis functions that should be used to approximate the kernel. + This function is called after the initialisation of the kernel. It is used to + set the computation engine to be the basis function computation engine. """ - self._check_valid_base_kernel(base_kernel) - self.base_kernel = base_kernel - self.num_basis_fns = num_basis_fns - # Set the computation engine to be basis function computation engine + self._check_valid_base_kernel(self.base_kernel) self.compute_engine = BasisFunctionComputation - # Inform the compute engine of the number of basis functions - self.compute_engine.num_basis_fns = num_basis_fns - - def init_params(self, key: KeyArray) -> Dict: - """Initialise the parameters of the RFF approximation. - Args: - key (KeyArray): A pseudo-random number generator key. - - Returns: - Dict: A dictionary containing the original kernel's parameters and the initial frequencies used in RFF approximation. - """ - base_params = self.base_kernel.init_params(key) - n_dims = self.base_kernel.ndims - frequencies = self.base_kernel.spectral_density.sample( - seed=key, sample_shape=(self.num_basis_fns, n_dims) + if self.frequencies is None: + n_dims = self.base_kernel.ndims + self.frequencies = self.base_kernel.spectral_density.sample( + seed=self.key, sample_shape=(self.num_basis_fns, n_dims) ) - base_params["frequencies"] = frequencies - return base_params - def __call__(self, *args: Any, **kwds: Any) -> Any: + def __call__(self, x: Array, y: Array) -> Array: pass def _check_valid_base_kernel(self, kernel: AbstractKernel): diff --git a/gpjax/kernels/base.py b/gpjax/kernels/base.py index 0a1aee259..f901c1f95 100644 --- a/gpjax/kernels/base.py +++ b/gpjax/kernels/base.py @@ -22,6 +22,8 @@ from functools import partial from simple_pytree import static_field from dataclasses import dataclass +from functools import partial +import tensorflow_probability.substrates.jax.distributions as tfd from ..base import Module, param_field from .computations import AbstractKernelComputation, DenseKernelComputation @@ -115,6 +117,10 @@ def __mul__( return ProductKernel(kernels=[self, Constant(other)]) + @property + def spectral_density(self) -> tfd.Distribution: + return None + @dataclass class Constant(AbstractKernel): diff --git a/gpjax/kernels/computations/basis_functions.py b/gpjax/kernels/computations/basis_functions.py index 97e12fe1c..c19807cad 100644 --- a/gpjax/kernels/computations/basis_functions.py +++ b/gpjax/kernels/computations/basis_functions.py @@ -10,7 +10,7 @@ class BasisFunctionComputation(AbstractKernelComputation): """Compute engine class for finite basis function approximations to a kernel.""" - num_basis_fns = None + num_basis_fns: int = None def cross_covariance( self, x: Float[Array, "N D"], y: Float[Array, "M D"] @@ -26,8 +26,8 @@ def cross_covariance( """ z1 = self.compute_features(x) z2 = self.compute_features(y) - z1 /= self.num_basis_fns - return self.kernel.variance * jnp.matmul(z1, z2.T) + z1 /= self.kernel.num_basis_fns + return self.kernel.base_kernel.variance * jnp.matmul(z1, z2.T) def gram(self, inputs: Float[Array, "N D"]) -> DenseLinearOperator: """For the Gram matrix, we can save computations by computing only one matrix multiplication between the inputs and the scaled frequencies. @@ -41,8 +41,8 @@ def gram(self, inputs: Float[Array, "N D"]) -> DenseLinearOperator: """ z1 = self.compute_features(inputs) matrix = jnp.matmul(z1, z1.T) # shape: (n_samples, n_samples) - matrix /= self.num_basis_fns - return DenseLinearOperator(self.kernel.variance * matrix) + matrix /= self.kernel.num_basis_fns + return DenseLinearOperator(self.kernel.base_kernel.variance * matrix) def compute_features(self, x: Float[Array, "N D"]) -> Float[Array, "N L"]: """Compute the features for the inputs. @@ -55,7 +55,7 @@ def compute_features(self, x: Float[Array, "N D"]) -> Float[Array, "N L"]: Float[Array, "N L"]: A N x L array of features where L = 2M. """ frequencies = self.kernel.frequencies - scaling_factor = self.kernel.lengthscale + scaling_factor = self.kernel.base_kernel.lengthscale z = jnp.matmul(x, (frequencies / scaling_factor).T) z = jnp.concatenate([jnp.cos(z), jnp.sin(z)], axis=-1) return z diff --git a/gpjax/kernels/non_euclidean/graph.py b/gpjax/kernels/non_euclidean/graph.py index a11c62ee1..2c31b4c78 100644 --- a/gpjax/kernels/non_euclidean/graph.py +++ b/gpjax/kernels/non_euclidean/graph.py @@ -18,49 +18,42 @@ import jax.numpy as jnp from jax.random import KeyArray from jaxtyping import Array, Float - +from dataclasses import dataclass from ..computations import EigenKernelComputation from ..base import AbstractKernel +from ...parameters import param_field from .utils import jax_gather_nd - +import tensorflow_probability.substrates.jax as tfp +tfb = tfp.bijectors ########################################## # Graph kernels ########################################## -class GraphKernel(AbstractKernel): - """A Matérn graph kernel defined on the vertices of a graph. The key reference for this object is borovitskiy et. al., (2020).""" +@dataclass +class AbstractGraphKernel: + laplacian: Float[Array, "N N"] - def __init__( - self, - laplacian: Float[Array, "N N"], - active_dims: Optional[List[int]] = None, - name: Optional[str] = "Matérn Graph kernel", - ) -> None: - """Initialize a Matérn graph kernel. - Args: - laplacian (Float[Array]): An N x N matrix representing the Laplacian matrix of a graph. - compute_engine (EigenKernelComputation, optional): The compute engine that should be used in the kernel to compute covariance matrices. Defaults to EigenKernelComputation. - active_dims (Optional[List[int]], optional): The dimensions of the input data for which the kernel should be evaluated on. Defaults to None. - stationary (Optional[bool], optional): _description_. Defaults to False. - name (Optional[str], optional): _description_. Defaults to "Graph kernel". - """ - super().__init__( - EigenKernelComputation, - active_dims, - spectral_density=None, - name=name, - ) - self.laplacian = laplacian +@dataclass +class GraphKernel(AbstractKernel, AbstractGraphKernel): + """A Matérn graph kernel defined on the vertices of a graph. The key reference for this object is borovitskiy et. al., (2020). + + Args: + laplacian (Float[Array]): An N x N matrix representing the Laplacian matrix of a graph. + compute_engine + """ + lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus) + variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus) + smoothness: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus) + + def __post_init__(self): evals, self.evecs = jnp.linalg.eigh(self.laplacian) self.evals = evals.reshape(-1, 1) self.compute_engine.eigensystem = self.evals, self.evecs self.compute_engine.num_vertex = self.laplacian.shape[0] - self._stationary = True def __call__( self, - params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"], **kwargs, @@ -68,7 +61,6 @@ def __call__( """Evaluate the graph kernel on a pair of vertices :math:`v_i, v_j`. Args: - params (Dict): Parameter set for which the kernel should be evaluated on. x (Float[Array, "1 D"]): Index of the ith vertex. y (Float[Array, "1 D"]): Index of the jth vertex. @@ -81,14 +73,6 @@ def __call__( ) # shape (n,n) return Kxx.squeeze() - def init_params(self, key: KeyArray) -> Dict: - """Initialise the lengthscale, variance and smoothness parameters of the kernel""" - return { - "lengthscale": jnp.array([1.0] * self.ndims), - "variance": jnp.array([1.0]), - "smoothness": jnp.array([1.0]), - } - @property def num_vertex(self) -> int: """The number of vertices within the graph. diff --git a/gpjax/kernels/stationary/matern12.py b/gpjax/kernels/stationary/matern12.py index 23ecb881b..6e4d3938f 100644 --- a/gpjax/kernels/stationary/matern12.py +++ b/gpjax/kernels/stationary/matern12.py @@ -18,12 +18,12 @@ import tensorflow_probability.substrates.jax.distributions as tfd from jaxtyping import Array, Float -from dataclasses import dataclass +import tensorflow_probability.substrates.jax as tfp from ...base import param_field from ..base import AbstractKernel from .utils import build_student_t_distribution, euclidean_distance - +tfd = tfp.distributions @dataclass class Matern12(AbstractKernel): diff --git a/gpjax/kernels/stationary/matern32.py b/gpjax/kernels/stationary/matern32.py index ca2068e81..f6efc685d 100644 --- a/gpjax/kernels/stationary/matern32.py +++ b/gpjax/kernels/stationary/matern32.py @@ -18,12 +18,12 @@ import tensorflow_probability.substrates.jax.distributions as tfd from jaxtyping import Array, Float -from dataclasses import dataclass +import tensorflow_probability.substrates.jax as tfp from ...base import param_field from ..base import AbstractKernel from .utils import build_student_t_distribution, euclidean_distance - +tfd = tfp.distributions @dataclass class Matern32(AbstractKernel): diff --git a/gpjax/kernels/stationary/matern52.py b/gpjax/kernels/stationary/matern52.py index 4bae5ca99..d96636c44 100644 --- a/gpjax/kernels/stationary/matern52.py +++ b/gpjax/kernels/stationary/matern52.py @@ -18,12 +18,12 @@ import tensorflow_probability.substrates.jax.distributions as tfd from jaxtyping import Array, Float -from dataclasses import dataclass +import tensorflow_probability.substrates.jax as tfp from ...base import param_field from ..base import AbstractKernel from .utils import build_student_t_distribution, euclidean_distance - +tfd = tfp.distributions @dataclass class Matern52(AbstractKernel): diff --git a/gpjax/kernels/stationary/rbf.py b/gpjax/kernels/stationary/rbf.py index ae78cb158..0d6a74baa 100644 --- a/gpjax/kernels/stationary/rbf.py +++ b/gpjax/kernels/stationary/rbf.py @@ -23,7 +23,11 @@ from ...base import param_field from ..base import AbstractKernel from .utils import squared_distance +import tensorflow_probability.substrates.jax as tfp +from dataclasses import dataclass +from ...parameters import param_field, Softplus +tfd = tfp.distributions @dataclass class RBF(AbstractKernel): diff --git a/gpjax/kernels/stationary/utils.py b/gpjax/kernels/stationary/utils.py index fa04f2310..e7aa67d13 100644 --- a/gpjax/kernels/stationary/utils.py +++ b/gpjax/kernels/stationary/utils.py @@ -15,23 +15,22 @@ import jax.numpy as jnp from jaxtyping import Array, Float -import distrax as dx import tensorflow_probability.substrates.jax as tfp tfd = tfp.distributions -def build_student_t_distribution(nu: int) -> dx.Distribution: +def build_student_t_distribution(nu: int) -> tfd.Distribution: """For a fixed half-integer smoothness parameter, compute the spectral density of a Matérn kernel; a Student's t distribution. Args: nu (int): The smoothness parameter of the Matérn kernel. Returns: - dx.Distribution: A Student's t distribution with the same smoothness parameter. + tfp.Distribution: A Student's t distribution with the same smoothness parameter. """ - tfp_dist = tfd.StudentT(df=nu, loc=0.0, scale=1.0) - return dx._src.distributions.distribution_from_tfp.distribution_from_tfp(tfp_dist) + dist = tfd.StudentT(df=nu, loc=0.0, scale=1.0) + return dist def squared_distance( diff --git a/tests/test_kernels/test_approximations.py b/tests/test_kernels/test_approximations.py index 1631145a9..c5674561c 100644 --- a/tests/test_kernels/test_approximations.py +++ b/tests/test_kernels/test_approximations.py @@ -1,119 +1,109 @@ -# import pytest -# from jaxkern.approximations import RFF -# from jaxkern.stationary import ( -# Matern12, -# Matern32, -# Matern52, -# RBF, -# RationalQuadratic, -# PoweredExponential, -# Periodic, -# ) -# from jaxkern.nonstationary import Polynomial, Linear -# from jaxkern.base import AbstractKernel -# import jax.random as jr -# from jax.config import config -# import jax.numpy as jnp -# from gpjax.linops import DenseLinearOperator -# from typing import Tuple -# import jax - -# config.update("jax_enable_x64", True) -# _jitter = 1e-5 - - -# @pytest.mark.parametrize("kernel", [RBF, Matern12, Matern32, Matern52]) -# @pytest.mark.parametrize("num_basis_fns", [2, 10, 20]) -# @pytest.mark.parametrize("n_dims", [1, 2, 5]) -# def test_frequency_sampler(kernel: AbstractKernel, num_basis_fns: int, n_dims: int): -# key = jr.PRNGKey(123) -# base_kernel = kernel(active_dims=list(range(n_dims))) -# approximate = RFF(base_kernel, num_basis_fns) - -# params = approximate.init_params(key) -# assert params["frequencies"].shape == (num_basis_fns, n_dims) - - -# @pytest.mark.parametrize("kernel", [RBF, Matern12, Matern32, Matern52]) -# @pytest.mark.parametrize("num_basis_fns", [2, 10, 20]) -# @pytest.mark.parametrize("n_dims", [1, 2, 5]) -# @pytest.mark.parametrize("n_data", [50, 100]) -# def test_gram(kernel: AbstractKernel, num_basis_fns: int, n_dims: int, n_data: int): -# key = jr.PRNGKey(123) -# x = jr.uniform(key, shape=(n_data, 1), minval=-3.0, maxval=3.0).reshape(-1, 1) -# if n_dims > 1: -# x = jnp.hstack([x] * n_dims) -# base_kernel = kernel(active_dims=list(range(n_dims))) -# approximate = RFF(base_kernel, num_basis_fns) - -# params = approximate.init_params(key) - -# linop = approximate.gram(params, x) - -# # Check the return type -# assert isinstance(linop, DenseLinearOperator) - -# Kxx = linop.to_dense() + jnp.eye(n_data) * _jitter - -# # Check that the shape is correct -# assert Kxx.shape == (n_data, n_data) - -# # Check that the Gram matrix is PSD -# evals, _ = jnp.linalg.eigh(Kxx) -# assert jnp.all(evals > 0) - - -# @pytest.mark.parametrize("kernel", [RBF, Matern12, Matern32, Matern52]) -# @pytest.mark.parametrize("num_basis_fns", [2, 10, 20]) -# @pytest.mark.parametrize("n_dims", [1, 2, 5]) -# @pytest.mark.parametrize("n_datas", [(50, 100), (100, 50)]) -# def test_cross_covariance( -# kernel: AbstractKernel, -# num_basis_fns: int, -# n_dims: int, -# n_datas: Tuple[int, int], -# ): -# nd1, nd2 = n_datas -# key = jr.PRNGKey(123) -# x1 = jr.uniform(key, shape=(nd1, 1), minval=-3.0, maxval=3.0) -# if n_dims > 1: -# x1 = jnp.hstack([x1] * n_dims) -# x2 = jr.uniform(key, shape=(nd2, 1), minval=-3.0, maxval=3.0) -# if n_dims > 1: -# x2 = jnp.hstack([x2] * n_dims) - -# base_kernel = kernel(active_dims=list(range(n_dims))) -# approximate = RFF(base_kernel, num_basis_fns) - -# params = approximate.init_params(key) - -# Kxx = approximate.cross_covariance(params, x1, x2) - -# # Check the return type -# assert isinstance(Kxx, jax.Array) - -# # Check that the shape is correct -# assert Kxx.shape == (nd1, nd2) - - -# @pytest.mark.parametrize("kernel", [RBF, Matern12, Matern32, Matern52]) -# @pytest.mark.parametrize("n_dim", [1, 2, 5]) -# def test_improvement(kernel, n_dim): -# n_data = 100 -# key = jr.PRNGKey(123) - -# x = jr.uniform(key, minval=-3.0, maxval=3.0, shape=(n_data, n_dim)) -# base_kernel = kernel(active_dims=list(range(n_dim))) -# exact_params = base_kernel.init_params(key) -# exact_linop = base_kernel.gram(exact_params, x).to_dense() - -# crude_approximation = RFF(base_kernel, num_basis_fns=10) -# c_params = crude_approximation.init_params(key) -# c_linop = crude_approximation.gram(c_params, x).to_dense() - -# better_approximation = RFF(base_kernel, num_basis_fns=50) -# b_params = better_approximation.init_params(key) -# b_linop = better_approximation.gram(b_params, x).to_dense() +import pytest +from gpjax.kernels.approximations import RFF +from gpjax.kernels.stationary import ( + Matern12, + Matern32, + Matern52, + RBF, + RationalQuadratic, + PoweredExponential, + Periodic, +) +from gpjax.kernels.nonstationary import Polynomial, Linear +from gpjax.kernels.base import AbstractKernel +import jax.random as jr +from jax.config import config +import jax.numpy as jnp +from gpjax.linops import DenseLinearOperator +from typing import Tuple +import jax + +config.update("jax_enable_x64", True) +_jitter = 1e-6 + + +@pytest.mark.parametrize("kernel", [RBF, Matern12, Matern32, Matern52]) +@pytest.mark.parametrize("num_basis_fns", [2, 10, 20]) +@pytest.mark.parametrize("n_dims", [1, 2, 5]) +def test_frequency_sampler(kernel: AbstractKernel, num_basis_fns: int, n_dims: int): + key = jr.PRNGKey(123) + base_kernel = kernel(active_dims=list(range(n_dims))) + approximate = RFF(base_kernel, num_basis_fns) + assert approximate.frequencies.shape == (num_basis_fns, n_dims) + + +@pytest.mark.parametrize("kernel", [RBF, Matern12, Matern32, Matern52]) +@pytest.mark.parametrize("num_basis_fns", [2, 10, 20]) +@pytest.mark.parametrize("n_dims", [1, 2, 5]) +@pytest.mark.parametrize("n_data", [50, 100]) +def test_gram(kernel: AbstractKernel, num_basis_fns: int, n_dims: int, n_data: int): + key = jr.PRNGKey(123) + x = jr.uniform(key, shape=(n_data, 1), minval=-3.0, maxval=3.0).reshape(-1, 1) + if n_dims > 1: + x = jnp.hstack([x] * n_dims) + base_kernel = kernel(active_dims=list(range(n_dims))) + approximate = RFF(base_kernel, num_basis_fns) + + linop = approximate.gram(x) + + # Check the return type + assert isinstance(linop, DenseLinearOperator) + + Kxx = linop.to_dense() + jnp.eye(n_data) * _jitter + + # Check that the shape is correct + assert Kxx.shape == (n_data, n_data) + + # Check that the Gram matrix is PSD + evals, _ = jnp.linalg.eigh(Kxx) + assert jnp.all(evals > 0) + + +@pytest.mark.parametrize("kernel", [RBF, Matern12, Matern32, Matern52]) +@pytest.mark.parametrize("num_basis_fns", [2, 10, 20]) +@pytest.mark.parametrize("n_dims", [1, 2, 5]) +@pytest.mark.parametrize("n_datas", [(50, 100), (100, 50)]) +def test_cross_covariance( + kernel: AbstractKernel, + num_basis_fns: int, + n_dims: int, + n_datas: Tuple[int, int], +): + nd1, nd2 = n_datas + key = jr.PRNGKey(123) + x1 = jr.uniform(key, shape=(nd1, 1), minval=-3.0, maxval=3.0) + if n_dims > 1: + x1 = jnp.hstack([x1] * n_dims) + x2 = jr.uniform(key, shape=(nd2, 1), minval=-3.0, maxval=3.0) + if n_dims > 1: + x2 = jnp.hstack([x2] * n_dims) + + base_kernel = kernel(active_dims=list(range(n_dims))) + approximate = RFF(base_kernel, num_basis_fns) + Kxx = approximate.cross_covariance(x1, x2) + + # Check the return type + assert isinstance(Kxx, jax.Array) + + # Check that the shape is correct + assert Kxx.shape == (nd1, nd2) + + +@pytest.mark.parametrize("kernel", [RBF, Matern12, Matern32, Matern52]) +@pytest.mark.parametrize("n_dim", [1, 2, 5]) +def test_improvement(kernel, n_dim): + n_data = 100 + key = jr.PRNGKey(123) + + x = jr.uniform(key, minval=-3.0, maxval=3.0, shape=(n_data, n_dim)) + base_kernel = kernel(active_dims=list(range(n_dim))) + exact_linop = base_kernel.gram(x).to_dense() + + crude_approximation = RFF(base_kernel, num_basis_fns=10) + c_linop = crude_approximation.gram(x).to_dense() + + better_approximation = RFF(base_kernel, num_basis_fns=50) + b_linop = better_approximation.gram(x).to_dense() # c_delta = jnp.linalg.norm(exact_linop - c_linop, ord="fro") # b_delta = jnp.linalg.norm(exact_linop - b_linop, ord="fro") @@ -123,21 +113,19 @@ # assert c_delta > b_delta -# @pytest.mark.parametrize("kernel", [RBF(), Matern12(), Matern32(), Matern52()]) -# def test_exactness(kernel): -# n_data = 100 -# key = jr.PRNGKey(123) +@pytest.mark.parametrize("kernel", [RBF(), Matern12(), Matern32(), Matern52()]) +def test_exactness(kernel): + n_data = 100 + key = jr.PRNGKey(123) -# x = jr.uniform(key, minval=-3.0, maxval=3.0, shape=(n_data, 1)) -# exact_params = kernel.init_params(key) -# exact_linop = kernel.gram(exact_params, x).to_dense() + x = jr.uniform(key, minval=-3.0, maxval=3.0, shape=(n_data, 1)) + exact_linop = kernel.gram(x).to_dense() -# better_approximation = RFF(kernel, num_basis_fns=500) -# b_params = better_approximation.init_params(key) -# b_linop = better_approximation.gram(b_params, x).to_dense() + better_approximation = RFF(kernel, num_basis_fns=500) + b_linop = better_approximation.gram(x).to_dense() -# max_delta = jnp.max(exact_linop - b_linop) -# assert max_delta < 0.1 + max_delta = jnp.max(exact_linop - b_linop) + assert max_delta < 0.1 # @pytest.mark.parametrize( From d955c98280ddac387a73996355a1bb7d78d78815 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Mon, 3 Apr 2023 14:06:07 +0100 Subject: [PATCH 33/44] Graph kernel refactored --- gpjax/base/module.py | 4 +- gpjax/kernels/__init__.py | 6 +- gpjax/kernels/approximations/rff.py | 27 ++++--- gpjax/kernels/base.py | 2 +- gpjax/kernels/computations/__init__.py | 2 +- gpjax/kernels/computations/base.py | 10 ++- gpjax/kernels/computations/basis_functions.py | 6 +- .../kernels/computations/constant_diagonal.py | 8 +-- gpjax/kernels/computations/dense.py | 1 + gpjax/kernels/computations/diagonal.py | 6 +- gpjax/kernels/computations/eigen.py | 17 ++--- gpjax/kernels/non_euclidean/graph.py | 40 +++++------ gpjax/kernels/non_euclidean/utils.py | 2 +- gpjax/kernels/nonstationary/linear.py | 2 + gpjax/kernels/nonstationary/polynomial.py | 2 + gpjax/kernels/stationary/matern12.py | 3 +- gpjax/kernels/stationary/matern32.py | 3 +- gpjax/kernels/stationary/matern52.py | 7 +- gpjax/kernels/stationary/rbf.py | 6 +- gpjax/kernels/stationary/utils.py | 10 +-- gpjax/linops/__init__.py | 11 ++- .../constant_diagonal_linear_operator.py | 4 +- gpjax/linops/dense_linear_operator.py | 2 +- gpjax/linops/diagonal_linear_operator.py | 5 +- gpjax/linops/identity_linear_operator.py | 2 +- gpjax/linops/linear_operator.py | 6 +- gpjax/linops/triangular_linear_operator.py | 2 +- gpjax/linops/utils.py | 7 +- gpjax/linops/zero_linear_operator.py | 6 +- tests/test_base/test_module.py | 12 ++-- tests/test_kernels/__init__.py | 2 +- tests/test_kernels/test_approximations.py | 44 +++++++----- tests/test_kernels/test_base.py | 7 +- tests/test_kernels/test_computation.py | 10 +-- tests/test_kernels/test_non_euclidean.py | 70 ++++++++----------- tests/test_kernels/test_nonstationary.py | 2 +- tests/test_kernels/test_stationary.py | 13 ++-- tests/test_kernels/test_utils.py | 1 + .../test_constant_linear_operator.py | 7 +- .../test_linops/test_dense_linear_operator.py | 5 +- .../test_diagonal_linear_operator.py | 5 +- .../test_identity_linear_operator.py | 8 ++- tests/test_linops/test_linear_operator.py | 19 ++--- .../test_triangular_linear_operator.py | 4 +- tests/test_linops/test_utils.py | 3 +- .../test_linops/test_zero_linear_operator.py | 5 +- 46 files changed, 219 insertions(+), 207 deletions(-) diff --git a/gpjax/base/module.py b/gpjax/base/module.py index dd8b97a18..34ba57e63 100644 --- a/gpjax/base/module.py +++ b/gpjax/base/module.py @@ -19,13 +19,13 @@ import dataclasses from copy import copy, deepcopy -from typing import Any, Callable, Dict, Iterable, Tuple, List -from typing_extensions import Self +from typing import Any, Callable, Dict, Iterable, List, Tuple import jax import jax.tree_util as jtu from jax._src.tree_util import _registry from simple_pytree import Pytree, static_field +from typing_extensions import Self import tensorflow_probability.substrates.jax.bijectors as tfb diff --git a/gpjax/kernels/__init__.py b/gpjax/kernels/__init__.py index 8e835064e..14f454bb8 100644 --- a/gpjax/kernels/__init__.py +++ b/gpjax/kernels/__init__.py @@ -15,7 +15,7 @@ """JaxKern.""" from .approximations import RFF -from .base import ProductKernel, SumKernel, AbstractKernel +from .base import AbstractKernel, ProductKernel, SumKernel from .computations import ( BasisFunctionComputation, ConstantDiagonalKernelComputation, @@ -23,18 +23,18 @@ DiagonalKernelComputation, EigenKernelComputation, ) +from .non_euclidean import GraphKernel from .nonstationary import Linear, Polynomial from .stationary import ( RBF, Matern12, Matern32, Matern52, - RationalQuadratic, Periodic, PoweredExponential, + RationalQuadratic, White, ) -from .non_euclidean import GraphKernel __all__ = [ "AbstractKernel", diff --git a/gpjax/kernels/approximations/rff.py b/gpjax/kernels/approximations/rff.py index 6e2610054..7f6a6b2c8 100644 --- a/gpjax/kernels/approximations/rff.py +++ b/gpjax/kernels/approximations/rff.py @@ -1,13 +1,19 @@ -from ..base import AbstractKernel -from ..computations import BasisFunctionComputation -from jax.random import KeyArray, PRNGKey -from typing import Dict, Any -from jaxtyping import Float, Array from dataclasses import dataclass -from ...parameters import param_field -from ..computations import DenseKernelComputation, AbstractKernelComputation -from simple_pytree import static_field +from typing import Any, Dict + import tensorflow_probability.substrates.jax as tfp +from jax.random import KeyArray, PRNGKey +from jaxtyping import Array, Float +from simple_pytree import static_field + +from ...parameters import param_field +from ..base import AbstractKernel +from ..computations import ( + AbstractKernelComputation, + BasisFunctionComputation, + DenseKernelComputation, +) + tfb = tfp.bijectors @@ -18,6 +24,7 @@ class AbstractFourierKernel: frequencies: Float[Array, "M 1"] = param_field(None, bijector=tfb.Identity) key: KeyArray = static_field(PRNGKey(123)) + @dataclass class RFF(AbstractKernel, AbstractFourierKernel): """Computes an approximation of the kernel using Random Fourier Features. @@ -50,8 +57,8 @@ def __post_init__(self) -> None: if self.frequencies is None: n_dims = self.base_kernel.ndims self.frequencies = self.base_kernel.spectral_density.sample( - seed=self.key, sample_shape=(self.num_basis_fns, n_dims) - ) + seed=self.key, sample_shape=(self.num_basis_fns, n_dims) + ) def __call__(self, x: Array, y: Array) -> Array: pass diff --git a/gpjax/kernels/base.py b/gpjax/kernels/base.py index f901c1f95..277d126b1 100644 --- a/gpjax/kernels/base.py +++ b/gpjax/kernels/base.py @@ -23,7 +23,7 @@ from simple_pytree import static_field from dataclasses import dataclass from functools import partial -import tensorflow_probability.substrates.jax.distributions as tfd +from typing import Callable, List, Union from ..base import Module, param_field from .computations import AbstractKernelComputation, DenseKernelComputation diff --git a/gpjax/kernels/computations/__init__.py b/gpjax/kernels/computations/__init__.py index 7937838f5..57aeeb244 100644 --- a/gpjax/kernels/computations/__init__.py +++ b/gpjax/kernels/computations/__init__.py @@ -14,11 +14,11 @@ # ============================================================================== from .base import AbstractKernelComputation +from .basis_functions import BasisFunctionComputation from .constant_diagonal import ConstantDiagonalKernelComputation from .dense import DenseKernelComputation from .diagonal import DiagonalKernelComputation from .eigen import EigenKernelComputation -from .basis_functions import BasisFunctionComputation __all__ = [ "AbstractKernelComputation", diff --git a/gpjax/kernels/computations/base.py b/gpjax/kernels/computations/base.py index af45826c2..ccd550aa4 100644 --- a/gpjax/kernels/computations/base.py +++ b/gpjax/kernels/computations/base.py @@ -14,15 +14,13 @@ # ============================================================================== import abc +from dataclasses import dataclass from typing import Any + from jax import vmap -from gpjax.linops import ( - DenseLinearOperator, - DiagonalLinearOperator, - LinearOperator, -) from jaxtyping import Array, Float -from dataclasses import dataclass + +from gpjax.linops import DenseLinearOperator, DiagonalLinearOperator, LinearOperator Kernel = Any diff --git a/gpjax/kernels/computations/basis_functions.py b/gpjax/kernels/computations/basis_functions.py index c19807cad..39aa44fc9 100644 --- a/gpjax/kernels/computations/basis_functions.py +++ b/gpjax/kernels/computations/basis_functions.py @@ -1,9 +1,11 @@ +from dataclasses import dataclass + import jax.numpy as jnp from jaxtyping import Array, Float -from .base import AbstractKernelComputation + from gpjax.linops import DenseLinearOperator -from dataclasses import dataclass +from .base import AbstractKernelComputation @dataclass diff --git a/gpjax/kernels/computations/constant_diagonal.py b/gpjax/kernels/computations/constant_diagonal.py index 58b8087a1..3e4baed26 100644 --- a/gpjax/kernels/computations/constant_diagonal.py +++ b/gpjax/kernels/computations/constant_diagonal.py @@ -13,13 +13,11 @@ # limitations under the License. # ============================================================================== import jax.numpy as jnp - from jax import vmap -from gpjax.linops import ( - ConstantDiagonalLinearOperator, - DiagonalLinearOperator, -) from jaxtyping import Array, Float + +from gpjax.linops import ConstantDiagonalLinearOperator, DiagonalLinearOperator + from .base import AbstractKernelComputation diff --git a/gpjax/kernels/computations/dense.py b/gpjax/kernels/computations/dense.py index c64981feb..7be0de06a 100644 --- a/gpjax/kernels/computations/dense.py +++ b/gpjax/kernels/computations/dense.py @@ -15,6 +15,7 @@ from jax import vmap from jaxtyping import Array, Float + from .base import AbstractKernelComputation class DenseKernelComputation(AbstractKernelComputation): diff --git a/gpjax/kernels/computations/diagonal.py b/gpjax/kernels/computations/diagonal.py index 999bac468..fe51be4a1 100644 --- a/gpjax/kernels/computations/diagonal.py +++ b/gpjax/kernels/computations/diagonal.py @@ -14,10 +14,10 @@ # ============================================================================== from jax import vmap -from gpjax.linops import ( - DiagonalLinearOperator, -) from jaxtyping import Array, Float + +from gpjax.linops import DiagonalLinearOperator + from .base import AbstractKernelComputation diff --git a/gpjax/kernels/computations/eigen.py b/gpjax/kernels/computations/eigen.py index ab3271336..c90f53c60 100644 --- a/gpjax/kernels/computations/eigen.py +++ b/gpjax/kernels/computations/eigen.py @@ -13,36 +13,31 @@ # limitations under the License. # ============================================================================== +from dataclasses import dataclass from typing import Dict import jax.numpy as jnp from jaxtyping import Array, Float + from .base import AbstractKernelComputation -from dataclasses import dataclass @dataclass class EigenKernelComputation(AbstractKernelComputation): - eigenvalues: Float[Array, "N"] = None - eigenvectors: Float[Array, "N N"] = None - num_verticies: int = None - def cross_covariance( - self, params: Dict, x: Float[Array, "N D"], y: Float[Array, "M D"] + self, x: Float[Array, "N D"], y: Float[Array, "M D"] ) -> Float[Array, "N M"]: - # Extract the graph Laplacian's eigenvalues - evals = self.eigenvalues # Transform the eigenvalues of the graph Laplacian according to the # RBF kernel's SPDE form. S = jnp.power( - evals + self.kernel.eigenvalues + 2 * self.kernel.smoothness / self.kernel.lengthscale / self.kernel.lengthscale, -self.kernel.smoothness, ) - S = jnp.multiply(S, self.num_vertex / jnp.sum(S)) + S = jnp.multiply(S, self.kernel.num_vertex / jnp.sum(S)) # Scale the transform eigenvalues by the kernel variance - S = jnp.multiply(S, params["variance"]) + S = jnp.multiply(S, self.kernel.variance) return self.kernel(x, y, S=S) diff --git a/gpjax/kernels/non_euclidean/graph.py b/gpjax/kernels/non_euclidean/graph.py index 2c31b4c78..e39010013 100644 --- a/gpjax/kernels/non_euclidean/graph.py +++ b/gpjax/kernels/non_euclidean/graph.py @@ -13,17 +13,20 @@ # limitations under the License. # ============================================================================== +from dataclasses import dataclass, replace from typing import Dict, List, Optional import jax.numpy as jnp +import tensorflow_probability.substrates.jax as tfp from jax.random import KeyArray from jaxtyping import Array, Float -from dataclasses import dataclass -from ..computations import EigenKernelComputation -from ..base import AbstractKernel +from simple_pytree import static_field + from ...parameters import param_field +from ..base import AbstractKernel +from ..computations import AbstractKernelComputation, EigenKernelComputation from .utils import jax_gather_nd -import tensorflow_probability.substrates.jax as tfp + tfb = tfp.bijectors ########################################## @@ -42,15 +45,21 @@ class GraphKernel(AbstractKernel, AbstractGraphKernel): laplacian (Float[Array]): An N x N matrix representing the Laplacian matrix of a graph. compute_engine """ - lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus) + + lengthscale: Float[Array, "D"] = param_field( + jnp.array([1.0]), bijector=tfb.Softplus + ) variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus) smoothness: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus) + eigenvalues: Float[Array, "N"] = static_field(None) + eigenvectors: Float[Array, "N N"] = static_field(None) + num_vertex: Float[Array, "1"] = static_field(None) + compute_engine: AbstractKernelComputation = static_field(EigenKernelComputation) def __post_init__(self): - evals, self.evecs = jnp.linalg.eigh(self.laplacian) - self.evals = evals.reshape(-1, 1) - self.compute_engine.eigensystem = self.evals, self.evecs - self.compute_engine.num_vertex = self.laplacian.shape[0] + evals, self.eigenvectors = jnp.linalg.eigh(self.laplacian) + self.eigenvalues = evals.reshape(-1, 1) + self.num_vertex = self.eigenvalues.shape[0] def __call__( self, @@ -68,16 +77,7 @@ def __call__( Float[Array, "1"]: The value of :math:`k(v_i, v_j)`. """ S = kwargs["S"] - Kxx = (jax_gather_nd(self.evecs, x) * S[None, :]) @ jnp.transpose( - jax_gather_nd(self.evecs, y) + Kxx = (jax_gather_nd(self.eigenvectors, x) * S.squeeze()) @ jnp.transpose( + jax_gather_nd(self.eigenvectors, y) ) # shape (n,n) return Kxx.squeeze() - - @property - def num_vertex(self) -> int: - """The number of vertices within the graph. - - Returns: - int: An integer representing the number of vertices within the graph. - """ - return self.compute_engine.num_vertex diff --git a/gpjax/kernels/non_euclidean/utils.py b/gpjax/kernels/non_euclidean/utils.py index b7c28e991..25db92d38 100644 --- a/gpjax/kernels/non_euclidean/utils.py +++ b/gpjax/kernels/non_euclidean/utils.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== -from jaxtyping import Num, Array, Int +from jaxtyping import Array, Int, Num def jax_gather_nd( diff --git a/gpjax/kernels/nonstationary/linear.py b/gpjax/kernels/nonstationary/linear.py index 66ada13fb..fafee3bfd 100644 --- a/gpjax/kernels/nonstationary/linear.py +++ b/gpjax/kernels/nonstationary/linear.py @@ -18,6 +18,8 @@ from jaxtyping import Array from dataclasses import dataclass + +import jax.numpy as jnp from jaxtyping import Array, Float from ..base import AbstractKernel diff --git a/gpjax/kernels/nonstationary/polynomial.py b/gpjax/kernels/nonstationary/polynomial.py index 4a25db188..3b6219b9b 100644 --- a/gpjax/kernels/nonstationary/polynomial.py +++ b/gpjax/kernels/nonstationary/polynomial.py @@ -13,6 +13,8 @@ # limitations under the License. # ============================================================================== +from dataclasses import dataclass + import jax.numpy as jnp import tensorflow_probability.substrates.jax.bijectors as tfb diff --git a/gpjax/kernels/stationary/matern12.py b/gpjax/kernels/stationary/matern12.py index 6e4d3938f..e4e0b6730 100644 --- a/gpjax/kernels/stationary/matern12.py +++ b/gpjax/kernels/stationary/matern12.py @@ -18,13 +18,14 @@ import tensorflow_probability.substrates.jax.distributions as tfd from jaxtyping import Array, Float -import tensorflow_probability.substrates.jax as tfp from ...base import param_field from ..base import AbstractKernel from .utils import build_student_t_distribution, euclidean_distance + tfd = tfp.distributions + @dataclass class Matern12(AbstractKernel): """The Matérn kernel with smoothness parameter fixed at 0.5.""" diff --git a/gpjax/kernels/stationary/matern32.py b/gpjax/kernels/stationary/matern32.py index f6efc685d..561195db0 100644 --- a/gpjax/kernels/stationary/matern32.py +++ b/gpjax/kernels/stationary/matern32.py @@ -18,13 +18,14 @@ import tensorflow_probability.substrates.jax.distributions as tfd from jaxtyping import Array, Float -import tensorflow_probability.substrates.jax as tfp from ...base import param_field from ..base import AbstractKernel from .utils import build_student_t_distribution, euclidean_distance + tfd = tfp.distributions + @dataclass class Matern32(AbstractKernel): """The Matérn kernel with smoothness parameter fixed at 1.5.""" diff --git a/gpjax/kernels/stationary/matern52.py b/gpjax/kernels/stationary/matern52.py index d96636c44..2cd5817d5 100644 --- a/gpjax/kernels/stationary/matern52.py +++ b/gpjax/kernels/stationary/matern52.py @@ -18,13 +18,14 @@ import tensorflow_probability.substrates.jax.distributions as tfd from jaxtyping import Array, Float -import tensorflow_probability.substrates.jax as tfp from ...base import param_field from ..base import AbstractKernel from .utils import build_student_t_distribution, euclidean_distance + tfd = tfp.distributions + @dataclass class Matern52(AbstractKernel): """The Matérn kernel with smoothness parameter fixed at 2.5.""" @@ -32,9 +33,7 @@ class Matern52(AbstractKernel): lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) - def __call__( - self, x: Float[Array, "D"], y: Float[Array, "D"] - ) -> Float[Array, "1"]: + def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "1"]: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with lengthscale parameter :math:`\\ell` and variance :math:`\\sigma^2` diff --git a/gpjax/kernels/stationary/rbf.py b/gpjax/kernels/stationary/rbf.py index 0d6a74baa..7393c5e96 100644 --- a/gpjax/kernels/stationary/rbf.py +++ b/gpjax/kernels/stationary/rbf.py @@ -13,6 +13,8 @@ # limitations under the License. # ============================================================================== +from dataclasses import dataclass + import jax.numpy as jnp import tensorflow_probability.substrates.jax.bijectors as tfb import tensorflow_probability.substrates.jax.distributions as tfd @@ -23,12 +25,10 @@ from ...base import param_field from ..base import AbstractKernel from .utils import squared_distance -import tensorflow_probability.substrates.jax as tfp -from dataclasses import dataclass -from ...parameters import param_field, Softplus tfd = tfp.distributions + @dataclass class RBF(AbstractKernel): """The Radial Basis Function (RBF) kernel.""" diff --git a/gpjax/kernels/stationary/utils.py b/gpjax/kernels/stationary/utils.py index e7aa67d13..9ff9248ad 100644 --- a/gpjax/kernels/stationary/utils.py +++ b/gpjax/kernels/stationary/utils.py @@ -14,8 +14,8 @@ # ============================================================================== import jax.numpy as jnp -from jaxtyping import Array, Float import tensorflow_probability.substrates.jax as tfp +from jaxtyping import Array, Float tfd = tfp.distributions @@ -33,9 +33,7 @@ def build_student_t_distribution(nu: int) -> tfd.Distribution: return dist -def squared_distance( - x: Float[Array, "D"], y: Float[Array, "D"] -) -> Float[Array, "1"]: +def squared_distance(x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "1"]: """Compute the squared distance between a pair of inputs. Args: @@ -49,9 +47,7 @@ def squared_distance( return jnp.sum((x - y) ** 2) -def euclidean_distance( - x: Float[Array, "D"], y: Float[Array, "D"] -) -> Float[Array, "1"]: +def euclidean_distance(x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "1"]: """Compute the euclidean distance between a pair of inputs. Args: diff --git a/gpjax/linops/__init__.py b/gpjax/linops/__init__.py index d41888f99..5e7d2212a 100644 --- a/gpjax/linops/__init__.py +++ b/gpjax/linops/__init__.py @@ -13,20 +13,17 @@ # limitations under the License. # ============================================================================== -from .linear_operator import LinearOperator +from .constant_diagonal_linear_operator import ConstantDiagonalLinearOperator from .dense_linear_operator import DenseLinearOperator from .diagonal_linear_operator import DiagonalLinearOperator -from .constant_diagonal_linear_operator import ConstantDiagonalLinearOperator from .identity_linear_operator import IdentityLinearOperator -from .zero_linear_operator import ZeroLinearOperator +from .linear_operator import LinearOperator from .triangular_linear_operator import ( LowerTriangularLinearOperator, UpperTriangularLinearOperator, ) -from .utils import ( - identity, - to_dense, -) +from .utils import identity, to_dense +from .zero_linear_operator import ZeroLinearOperator __all__ = [ "LinearOperator", diff --git a/gpjax/linops/constant_diagonal_linear_operator.py b/gpjax/linops/constant_diagonal_linear_operator.py index d3a53d19f..43ae533f2 100644 --- a/gpjax/linops/constant_diagonal_linear_operator.py +++ b/gpjax/linops/constant_diagonal_linear_operator.py @@ -15,15 +15,15 @@ from __future__ import annotations +from dataclasses import dataclass from typing import Any, Union import jax.numpy as jnp from jaxtyping import Array, Float from simple_pytree import static_field -from dataclasses import dataclass -from .linear_operator import LinearOperator from .diagonal_linear_operator import DiagonalLinearOperator +from .linear_operator import LinearOperator def _check_args(value: Any, size: Any) -> None: diff --git a/gpjax/linops/dense_linear_operator.py b/gpjax/linops/dense_linear_operator.py index 91e4d9e56..ff5d249e0 100644 --- a/gpjax/linops/dense_linear_operator.py +++ b/gpjax/linops/dense_linear_operator.py @@ -20,10 +20,10 @@ if TYPE_CHECKING: from .diagonal_linear_operator import DiagonalLinearOperator +from dataclasses import dataclass from typing import Union import jax.numpy as jnp -from dataclasses import dataclass from jaxtyping import Array, Float from .linear_operator import LinearOperator diff --git a/gpjax/linops/diagonal_linear_operator.py b/gpjax/linops/diagonal_linear_operator.py index f2be325df..5d1780396 100644 --- a/gpjax/linops/diagonal_linear_operator.py +++ b/gpjax/linops/diagonal_linear_operator.py @@ -15,13 +15,14 @@ from __future__ import annotations +from dataclasses import dataclass from typing import Any, Union import jax.numpy as jnp from jaxtyping import Array, Float -from dataclasses import dataclass -from .linear_operator import LinearOperator + from .dense_linear_operator import DenseLinearOperator +from .linear_operator import LinearOperator from .utils import to_linear_operator diff --git a/gpjax/linops/identity_linear_operator.py b/gpjax/linops/identity_linear_operator.py index af595d18f..4cd7db1bd 100644 --- a/gpjax/linops/identity_linear_operator.py +++ b/gpjax/linops/identity_linear_operator.py @@ -15,10 +15,10 @@ from __future__ import annotations +from dataclasses import dataclass from typing import Any import jax.numpy as jnp -from dataclasses import dataclass from jaxtyping import Array, Float from .constant_diagonal_linear_operator import ConstantDiagonalLinearOperator diff --git a/gpjax/linops/linear_operator.py b/gpjax/linops/linear_operator.py index c8e5dd508..f39e7dcc4 100644 --- a/gpjax/linops/linear_operator.py +++ b/gpjax/linops/linear_operator.py @@ -14,16 +14,18 @@ # ============================================================================== from __future__ import annotations + from typing import TYPE_CHECKING if TYPE_CHECKING: from .diagonal_linear_operator import DiagonalLinearOperator import abc -import jax.numpy as jnp from dataclasses import dataclass +from typing import Any, Generic, Iterable, Mapping, Tuple, TypeVar, Union + +import jax.numpy as jnp from jaxtyping import Array, Float -from typing import Any, TypeVar, Iterable, Mapping, Generic, Tuple, Union from simple_pytree import Pytree, static_field # Generic type. diff --git a/gpjax/linops/triangular_linear_operator.py b/gpjax/linops/triangular_linear_operator.py index 9a7d5b548..64a227e17 100644 --- a/gpjax/linops/triangular_linear_operator.py +++ b/gpjax/linops/triangular_linear_operator.py @@ -19,8 +19,8 @@ import jax.scipy as jsp from jaxtyping import Array, Float -from .linear_operator import LinearOperator from .dense_linear_operator import DenseLinearOperator +from .linear_operator import LinearOperator class LowerTriangularLinearOperator(DenseLinearOperator): diff --git a/gpjax/linops/utils.py b/gpjax/linops/utils.py index 292ed013f..efd381cbd 100644 --- a/gpjax/linops/utils.py +++ b/gpjax/linops/utils.py @@ -15,15 +15,14 @@ from __future__ import annotations -from typing import Union, Tuple, TYPE_CHECKING +from typing import TYPE_CHECKING, Tuple, Union if TYPE_CHECKING: from .identity_linear_operator import IdentityLinearOperator -from jaxtyping import Float, Array - -import jax.numpy as jnp import jax +import jax.numpy as jnp +from jaxtyping import Array, Float from .linear_operator import LinearOperator diff --git a/gpjax/linops/zero_linear_operator.py b/gpjax/linops/zero_linear_operator.py index af309ffd7..a58ddeb6c 100644 --- a/gpjax/linops/zero_linear_operator.py +++ b/gpjax/linops/zero_linear_operator.py @@ -15,15 +15,15 @@ from __future__ import annotations +from dataclasses import dataclass from typing import Any, Tuple, Union import jax.numpy as jnp from jaxtyping import Array, Float -from dataclasses import dataclass -from .linear_operator import LinearOperator from .diagonal_linear_operator import DiagonalLinearOperator -from .utils import check_shapes_match, to_linear_operator, default_dtype +from .linear_operator import LinearOperator +from .utils import check_shapes_match, default_dtype, to_linear_operator def _check_size(shape: Any) -> None: diff --git a/tests/test_base/test_module.py b/tests/test_base/test_module.py index 1e33d6747..105f9b2c2 100644 --- a/tests/test_base/test_module.py +++ b/tests/test_base/test_module.py @@ -368,11 +368,11 @@ def __init__(self, a, sub_tree, b): def loss(tree): t = tree.stop_gradient() return jnp.sum( - t.a**2 - + t.sub_tree.c**2 - + t.sub_tree.d**2 - + t.sub_tree.e**2 - + t.b**2 + t.a ** 2 + + t.sub_tree.c ** 2 + + t.sub_tree.d ** 2 + + t.sub_tree.e ** 2 + + t.b ** 2 ) g = jax.grad(loss)(new_tree) @@ -866,4 +866,4 @@ class Foo(Module, mutable=True): # test mutation pytree.x = 4 - assert pytree.x == 4 \ No newline at end of file + assert pytree.x == 4 diff --git a/tests/test_kernels/__init__.py b/tests/test_kernels/__init__.py index 837f63bb0..2589bbd13 100644 --- a/tests/test_kernels/__init__.py +++ b/tests/test_kernels/__init__.py @@ -13,4 +13,4 @@ # limitations under the License. # ============================================================================== -"""Test suite for the gpjax.kernels package.""" \ No newline at end of file +"""Test suite for the gpjax.kernels package.""" diff --git a/tests/test_kernels/test_approximations.py b/tests/test_kernels/test_approximations.py index c5674561c..a7e196ec5 100644 --- a/tests/test_kernels/test_approximations.py +++ b/tests/test_kernels/test_approximations.py @@ -1,22 +1,24 @@ +from typing import Tuple + +import jax +import jax.numpy as jnp +import jax.random as jr import pytest +from jax.config import config + from gpjax.kernels.approximations import RFF +from gpjax.kernels.base import AbstractKernel +from gpjax.kernels.nonstationary import Linear, Polynomial from gpjax.kernels.stationary import ( + RBF, Matern12, Matern32, Matern52, - RBF, - RationalQuadratic, - PoweredExponential, Periodic, + PoweredExponential, + RationalQuadratic, ) -from gpjax.kernels.nonstationary import Polynomial, Linear -from gpjax.kernels.base import AbstractKernel -import jax.random as jr -from jax.config import config -import jax.numpy as jnp from gpjax.linops import DenseLinearOperator -from typing import Tuple -import jax config.update("jax_enable_x64", True) _jitter = 1e-6 @@ -128,10 +130,18 @@ def test_exactness(kernel): assert max_delta < 0.1 -# @pytest.mark.parametrize( -# "kernel", -# [RationalQuadratic, PoweredExponential, Polynomial, Linear, Periodic], -# ) -# def test_value_error(kernel): -# with pytest.raises(ValueError): -# RFF(kernel(), num_basis_fns=10) +@pytest.mark.parametrize( + "kernel", + [RationalQuadratic, PoweredExponential, Polynomial, Linear, Periodic], +) +def test_value_error(kernel): + with pytest.raises(ValueError): + RFF(kernel(), num_basis_fns=10) + + +@pytest.mark.parametrize("kernel", [RBF(), Matern12(), Matern32(), Matern52()]) +def stochastic_init(kernel: AbstractKernel): + k1 = RFF(kernel, num_basis_fns=10, key=123) + k2 = RFF(kernel, num_basis_fns=10, key=42) + + assert (k1.frequencies != k2.frequencies).any() diff --git a/tests/test_kernels/test_base.py b/tests/test_kernels/test_base.py index 40b348931..f6c7ca1d1 100644 --- a/tests/test_kernels/test_base.py +++ b/tests/test_kernels/test_base.py @@ -13,6 +13,8 @@ # limitations under the License. # ============================================================================== +from dataclasses import dataclass + import jax.numpy as jnp import pytest from jax.config import config @@ -23,6 +25,7 @@ ProductKernel, SumKernel, ) +from gpjax.kernels.nonstationary import Linear, Polynomial from gpjax.kernels.stationary import ( RBF, Matern12, @@ -54,7 +57,9 @@ class DummyKernel(AbstractKernel): test_a: Float[Array, "1"] = jnp.array([1.0]) test_b: Float[Array, "1"] = param_field(jnp.array([2.0]), bijector=tfb.Softplus()) - def __call__(self, x: Float[Array, "1 D"], y: Float[Array, "1 D"]) -> Float[Array, "1"]: + def __call__( + self, x: Float[Array, "1 D"], y: Float[Array, "1 D"] + ) -> Float[Array, "1"]: return x * self.test_b * y # Initialise dummy kernel class and test __call__ method: diff --git a/tests/test_kernels/test_computation.py b/tests/test_kernels/test_computation.py index 38ba5202a..5ce35018b 100644 --- a/tests/test_kernels/test_computation.py +++ b/tests/test_kernels/test_computation.py @@ -2,19 +2,19 @@ import pytest from gpjax.kernels.computations import ( - DiagonalKernelComputation, ConstantDiagonalKernelComputation, + DiagonalKernelComputation, ) +from gpjax.kernels.nonstationary import Linear, Polynomial from gpjax.kernels.stationary import ( RBF, Matern12, Matern32, Matern52, + Periodic, PoweredExponential, RationalQuadratic, - Periodic, ) -from gpjax.kernels.nonstationary import Linear, Polynomial @pytest.mark.parametrize( @@ -39,7 +39,7 @@ def test_change_computation(kernel): dense_diagonals = jnp.diag(dense_matrix) # Let's now change the computation to DiagonalKernelComputation - kernel = kernel.replace(compute_engine = DiagonalKernelComputation) + kernel = kernel.replace(compute_engine=DiagonalKernelComputation) diagonal_matrix = kernel.gram(x).to_dense() diag_entries = jnp.diag(diagonal_matrix) @@ -50,7 +50,7 @@ def test_change_computation(kernel): assert jnp.allclose(diagonal_matrix - jnp.diag(diag_entries), 0.0) # Let's now change the computation to ConstantDiagonalKernelComputation - kernel = kernel.replace(compute_engine = ConstantDiagonalKernelComputation) + kernel = kernel.replace(compute_engine=ConstantDiagonalKernelComputation) constant_diagonal_matrix = kernel.gram(x).to_dense() constant_entries = jnp.diag(constant_diagonal_matrix) diff --git a/tests/test_kernels/test_non_euclidean.py b/tests/test_kernels/test_non_euclidean.py index 266115689..95fd52f26 100644 --- a/tests/test_kernels/test_non_euclidean.py +++ b/tests/test_kernels/test_non_euclidean.py @@ -13,54 +13,44 @@ # # limitations under the License. # # ============================================================================== -# import jax.numpy as jnp -# import jax.random as jr -# import networkx as nx -# from jax.config import config -# from jaxlinop import identity +import jax.numpy as jnp +import jax.random as jr +import networkx as nx +from jax.config import config -# from gpjax.kernels.non_euclidean import GraphKernel +from gpjax.kernels.non_euclidean import GraphKernel +from gpjax.linops import identity # # Enable Float64 for more stable matrix inversions. -# config.update("jax_enable_x64", True) -# _initialise_key = jr.PRNGKey(123) -# _jitter = 1e-6 +config.update("jax_enable_x64", True) -# def test_graph_kernel(): -# # Create a random graph, G, and verice labels, x, -# n_verticies = 20 -# n_edges = 40 -# G = nx.gnm_random_graph(n_verticies, n_edges, seed=123) -# x = jnp.arange(n_verticies).reshape(-1, 1) -# # Compute graph laplacian -# L = nx.laplacian_matrix(G).toarray() + jnp.eye(n_verticies) * 1e-12 +def test_graph_kernel(): + # Create a random graph, G, and verice labels, x, + n_verticies = 20 + n_edges = 40 + G = nx.gnm_random_graph(n_verticies, n_edges, seed=123) + x = jnp.arange(n_verticies).reshape(-1, 1) -# # Create graph kernel -# kern = GraphKernel(laplacian=L) -# assert isinstance(kern, GraphKernel) -# assert kern.num_vertex == n_verticies -# assert kern.evals.shape == (n_verticies, 1) -# assert kern.evecs.shape == (n_verticies, n_verticies) + # Compute graph laplacian + L = nx.laplacian_matrix(G).toarray() + jnp.eye(n_verticies) * 1e-12 -# # Unpack kernel computation -# kern.gram + # Create graph kernel + kern = GraphKernel(laplacian=L) + assert isinstance(kern, GraphKernel) + assert kern.num_vertex == n_verticies + assert kern.eigenvalues.shape == (n_verticies, 1) + assert kern.eigenvectors.shape == (n_verticies, n_verticies) -# # Initialise default parameters -# params = kern.init_params(_initialise_key) -# assert isinstance(params, dict) -# assert list(sorted(list(params.keys()))) == [ -# "lengthscale", -# "smoothness", -# "variance", -# ] + # Unpack kernel computation + kern.gram -# # Compute gram matrix -# Kxx = kern.gram(params, x) -# assert Kxx.shape == (n_verticies, n_verticies) + # Compute gram matrix + Kxx = kern.gram(x) + assert Kxx.shape == (n_verticies, n_verticies) -# # Check positive definiteness -# Kxx += identity(n_verticies) * _jitter -# eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) -# assert all(eigen_values > 0) + # Check positive definiteness + Kxx += identity(n_verticies) * 1e-6 + eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) + assert all(eigen_values > 0) diff --git a/tests/test_kernels/test_nonstationary.py b/tests/test_kernels/test_nonstationary.py index b3932855e..66db76212 100644 --- a/tests/test_kernels/test_nonstationary.py +++ b/tests/test_kernels/test_nonstationary.py @@ -20,10 +20,10 @@ import jax.tree_util as jtu import pytest from jax.config import config -from gpjax.linops import LinearOperator, identity from gpjax.kernels.base import AbstractKernel from gpjax.kernels.nonstationary import Linear, Polynomial +from gpjax.linops import LinearOperator, identity # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) diff --git a/tests/test_kernels/test_stationary.py b/tests/test_kernels/test_stationary.py index 20dc8f010..7088c49ad 100644 --- a/tests/test_kernels/test_stationary.py +++ b/tests/test_kernels/test_stationary.py @@ -22,17 +22,21 @@ import tensorflow_probability.substrates.jax.distributions as tfd import pytest -import distrax as dx +import tensorflow_probability.substrates.jax.distributions as tfd from jax.config import config from gpjax.linops import LinearOperator from gpjax.kernels.base import AbstractKernel +from gpjax.kernels.computations import ( + ConstantDiagonalKernelComputation, + DenseKernelComputation, + DiagonalKernelComputation, +) from gpjax.kernels.stationary import ( RBF, Matern12, Matern32, Matern52, - White, Periodic, PoweredExponential, RationalQuadratic, @@ -139,8 +143,7 @@ def test_spectral_density(self): kernel: AbstractKernel = self.kernel() if self.kernel not in [RBF, Matern12, Matern32, Matern52]: - with pytest.raises(AttributeError): - kernel.spectral_density + assert not kernel.spectral_density else: sdensity = kernel.spectral_density assert sdensity.name == self.spectral_density_name @@ -222,4 +225,4 @@ class TestRationalQuadratic(BaseTestKernel): @pytest.mark.parametrize("smoothness", [1, 2, 3]) def test_build_studentt_dist(smoothness: int) -> None: dist = build_student_t_distribution(smoothness) - assert isinstance(dist, dx.Distribution) + assert isinstance(dist, tfd.Distribution) diff --git a/tests/test_kernels/test_utils.py b/tests/test_kernels/test_utils.py index b707fb4a2..9b167df39 100644 --- a/tests/test_kernels/test_utils.py +++ b/tests/test_kernels/test_utils.py @@ -18,6 +18,7 @@ import jax.numpy as jnp import pytest from jaxtyping import Array, Float + from gpjax.kernels.stationary.utils import euclidean_distance diff --git a/tests/test_linops/test_constant_linear_operator.py b/tests/test_linops/test_constant_linear_operator.py index 0ec6707d2..48e2efa74 100644 --- a/tests/test_linops/test_constant_linear_operator.py +++ b/tests/test_linops/test_constant_linear_operator.py @@ -19,14 +19,15 @@ import pytest from jax.config import config - # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) _PRNGKey = jr.PRNGKey(42) -from gpjax.linops.diagonal_linear_operator import DiagonalLinearOperator +from gpjax.linops.constant_diagonal_linear_operator import ( + ConstantDiagonalLinearOperator, +) from gpjax.linops.dense_linear_operator import DenseLinearOperator -from gpjax.linops.constant_diagonal_linear_operator import ConstantDiagonalLinearOperator +from gpjax.linops.diagonal_linear_operator import DiagonalLinearOperator def approx_equal(res: jnp.ndarray, actual: jnp.ndarray) -> bool: diff --git a/tests/test_linops/test_dense_linear_operator.py b/tests/test_linops/test_dense_linear_operator.py index f46f688f6..37110c2f7 100644 --- a/tests/test_linops/test_dense_linear_operator.py +++ b/tests/test_linops/test_dense_linear_operator.py @@ -14,19 +14,18 @@ # ============================================================================== +import jax import jax.numpy as jnp import jax.random as jr -import jax import pytest from jax.config import config - # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) _PRNGKey = jr.PRNGKey(42) -from gpjax.linops.diagonal_linear_operator import DiagonalLinearOperator from gpjax.linops.dense_linear_operator import DenseLinearOperator +from gpjax.linops.diagonal_linear_operator import DiagonalLinearOperator from gpjax.linops.triangular_linear_operator import LowerTriangularLinearOperator diff --git a/tests/test_linops/test_diagonal_linear_operator.py b/tests/test_linops/test_diagonal_linear_operator.py index 5187af33a..981ec6064 100644 --- a/tests/test_linops/test_diagonal_linear_operator.py +++ b/tests/test_linops/test_diagonal_linear_operator.py @@ -14,19 +14,18 @@ # ============================================================================== +import jax import jax.numpy as jnp import jax.random as jr import pytest -import jax from jax.config import config - # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) _PRNGKey = jr.PRNGKey(42) -from gpjax.linops.diagonal_linear_operator import DiagonalLinearOperator from gpjax.linops.dense_linear_operator import DenseLinearOperator +from gpjax.linops.diagonal_linear_operator import DiagonalLinearOperator def approx_equal(res: jax.Array, actual: jax.Array) -> bool: diff --git a/tests/test_linops/test_identity_linear_operator.py b/tests/test_linops/test_identity_linear_operator.py index 5e10daaec..be1d8eba7 100644 --- a/tests/test_linops/test_identity_linear_operator.py +++ b/tests/test_linops/test_identity_linear_operator.py @@ -14,9 +14,9 @@ # ============================================================================== +import jax import jax.numpy as jnp import jax.random as jr -import jax import pytest from jax.config import config @@ -24,10 +24,12 @@ config.update("jax_enable_x64", True) _PRNGKey = jr.PRNGKey(42) +from gpjax.linops.constant_diagonal_linear_operator import ( + ConstantDiagonalLinearOperator, +) +from gpjax.linops.dense_linear_operator import DenseLinearOperator from gpjax.linops.diagonal_linear_operator import DiagonalLinearOperator from gpjax.linops.identity_linear_operator import IdentityLinearOperator -from gpjax.linops.constant_diagonal_linear_operator import ConstantDiagonalLinearOperator -from gpjax.linops.dense_linear_operator import DenseLinearOperator def approx_equal(res: jax.Array, actual: jax.Array) -> bool: diff --git a/tests/test_linops/test_linear_operator.py b/tests/test_linops/test_linear_operator.py index 0729e8ecb..b0b2cacb1 100644 --- a/tests/test_linops/test_linear_operator.py +++ b/tests/test_linops/test_linear_operator.py @@ -13,18 +13,21 @@ # limitations under the License. # ============================================================================== -import pytest -import jax.tree_util as jtu -import jax.numpy as jnp -from gpjax.linops.linear_operator import LinearOperator from dataclasses import dataclass + +import jax.numpy as jnp +import jax.tree_util as jtu +import pytest from simple_pytree import static_field +from gpjax.linops.linear_operator import LinearOperator + def test_covariance_operator() -> None: with pytest.raises(TypeError): LinearOperator() + @pytest.mark.parametrize("is_dataclass", [True, False]) @pytest.mark.parametrize("shape", [(1, 1), (2, 3), (4, 5, 6), [7, 8]]) @pytest.mark.parametrize("dtype", [jnp.float32, jnp.float64]) @@ -66,7 +69,7 @@ def from_dense(cls, *args, **kwargs): assert linop.shape == shape assert linop.dtype == dtype assert linop.ndim == len(shape) - assert jtu.tree_leaves(linop) == [] # shape and dtype are static! + assert jtu.tree_leaves(linop) == [] # shape and dtype are static! # if not is_dataclass: # assert linop.__repr__() == f"DummyLinearOperator(shape={shape}, dtype={dtype})" @@ -80,7 +83,7 @@ def test_instantiate_with_attributes(is_dataclass, shape, dtype) -> None: class DummyLinearOperator(LinearOperator): a: int - b: int = static_field() # Lets have a static attribute here. + b: int = static_field() # Lets have a static attribute here. c: int def __init__(self, shape, dtype, a=1, b=2, c=3): @@ -124,7 +127,7 @@ def from_dense(cls, *args, **kwargs): assert linop.shape == shape assert linop.dtype == dtype assert linop.ndim == len(shape) - assert jtu.tree_leaves(linop) == [1, 3] # b, shape, dtype are static! + assert jtu.tree_leaves(linop) == [1, 3] # b, shape, dtype are static! # if not is_dataclass: - # assert linop.__repr__() == f"DummyLinearOperator(shape={shape}, dtype={dtype})" \ No newline at end of file + # assert linop.__repr__() == f"DummyLinearOperator(shape={shape}, dtype={dtype})" diff --git a/tests/test_linops/test_triangular_linear_operator.py b/tests/test_linops/test_triangular_linear_operator.py index d03d1b991..5c563cb73 100644 --- a/tests/test_linops/test_triangular_linear_operator.py +++ b/tests/test_linops/test_triangular_linear_operator.py @@ -25,9 +25,9 @@ atol: float = 1e-6 config.update("jax_enable_x64", True) +from gpjax.linops.dense_linear_operator import DenseLinearOperator +from gpjax.linops.diagonal_linear_operator import DiagonalLinearOperator from gpjax.linops.triangular_linear_operator import ( LowerTriangularLinearOperator, UpperTriangularLinearOperator, ) -from gpjax.linops.dense_linear_operator import DenseLinearOperator -from gpjax.linops.diagonal_linear_operator import DiagonalLinearOperator diff --git a/tests/test_linops/test_utils.py b/tests/test_linops/test_utils.py index d65e9df47..f0fea6455 100644 --- a/tests/test_linops/test_utils.py +++ b/tests/test_linops/test_utils.py @@ -23,9 +23,8 @@ config.update("jax_enable_x64", True) _PRNGKey = jr.PRNGKey(42) -from gpjax.linops.identity_linear_operator import IdentityLinearOperator from gpjax.linops.dense_linear_operator import DenseLinearOperator - +from gpjax.linops.identity_linear_operator import IdentityLinearOperator from gpjax.linops.utils import identity, to_dense diff --git a/tests/test_linops/test_zero_linear_operator.py b/tests/test_linops/test_zero_linear_operator.py index 149f9a81d..769078b4f 100644 --- a/tests/test_linops/test_zero_linear_operator.py +++ b/tests/test_linops/test_zero_linear_operator.py @@ -14,19 +14,18 @@ # ============================================================================== +import jax import jax.numpy as jnp import jax.random as jr -import jax import pytest from jax.config import config - # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) _PRNGKey = jr.PRNGKey(42) -from gpjax.linops.diagonal_linear_operator import DiagonalLinearOperator from gpjax.linops.dense_linear_operator import DenseLinearOperator +from gpjax.linops.diagonal_linear_operator import DiagonalLinearOperator from gpjax.linops.zero_linear_operator import ZeroLinearOperator From d6a045bd7e97343cb14eff154c7720692feac329 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Tue, 4 Apr 2023 12:23:00 +0100 Subject: [PATCH 34/44] Fix imports and switch to FillTriangular for now, to avoid dtype error. --- gpjax/kernels/approximations/rff.py | 13 ++++--------- gpjax/kernels/non_euclidean/graph.py | 6 ++---- gpjax/kernels/stationary/matern12.py | 2 +- gpjax/kernels/stationary/matern32.py | 2 +- gpjax/kernels/stationary/matern52.py | 2 +- gpjax/kernels/stationary/periodic.py | 2 ++ gpjax/kernels/stationary/powered_exponential.py | 2 ++ gpjax/kernels/stationary/rbf.py | 2 -- gpjax/variational_families.py | 2 +- tests/test_fit.py | 6 +++++- tests/test_kernels/test_approximations.py | 10 +++++----- tests/test_kernels/test_stationary.py | 2 +- 12 files changed, 25 insertions(+), 26 deletions(-) diff --git a/gpjax/kernels/approximations/rff.py b/gpjax/kernels/approximations/rff.py index 7f6a6b2c8..249fc5c31 100644 --- a/gpjax/kernels/approximations/rff.py +++ b/gpjax/kernels/approximations/rff.py @@ -1,20 +1,15 @@ +import tensorflow_probability.substrates.jax.bijectors as tfb + from dataclasses import dataclass -from typing import Any, Dict -import tensorflow_probability.substrates.jax as tfp from jax.random import KeyArray, PRNGKey from jaxtyping import Array, Float from simple_pytree import static_field -from ...parameters import param_field +from ...base import param_field from ..base import AbstractKernel -from ..computations import ( - AbstractKernelComputation, - BasisFunctionComputation, - DenseKernelComputation, -) +from ..computations import BasisFunctionComputation -tfb = tfp.bijectors @dataclass diff --git a/gpjax/kernels/non_euclidean/graph.py b/gpjax/kernels/non_euclidean/graph.py index e39010013..40d09c867 100644 --- a/gpjax/kernels/non_euclidean/graph.py +++ b/gpjax/kernels/non_euclidean/graph.py @@ -13,16 +13,14 @@ # limitations under the License. # ============================================================================== -from dataclasses import dataclass, replace -from typing import Dict, List, Optional +from dataclasses import dataclass import jax.numpy as jnp import tensorflow_probability.substrates.jax as tfp -from jax.random import KeyArray from jaxtyping import Array, Float from simple_pytree import static_field -from ...parameters import param_field +from ...base import param_field from ..base import AbstractKernel from ..computations import AbstractKernelComputation, EigenKernelComputation from .utils import jax_gather_nd diff --git a/gpjax/kernels/stationary/matern12.py b/gpjax/kernels/stationary/matern12.py index e4e0b6730..9d0c4ce39 100644 --- a/gpjax/kernels/stationary/matern12.py +++ b/gpjax/kernels/stationary/matern12.py @@ -23,7 +23,7 @@ from ..base import AbstractKernel from .utils import build_student_t_distribution, euclidean_distance -tfd = tfp.distributions +from dataclasses import dataclass @dataclass diff --git a/gpjax/kernels/stationary/matern32.py b/gpjax/kernels/stationary/matern32.py index 561195db0..3484f2b70 100644 --- a/gpjax/kernels/stationary/matern32.py +++ b/gpjax/kernels/stationary/matern32.py @@ -23,7 +23,7 @@ from ..base import AbstractKernel from .utils import build_student_t_distribution, euclidean_distance -tfd = tfp.distributions +from dataclasses import dataclass @dataclass diff --git a/gpjax/kernels/stationary/matern52.py b/gpjax/kernels/stationary/matern52.py index 2cd5817d5..1db973047 100644 --- a/gpjax/kernels/stationary/matern52.py +++ b/gpjax/kernels/stationary/matern52.py @@ -23,7 +23,7 @@ from ..base import AbstractKernel from .utils import build_student_t_distribution, euclidean_distance -tfd = tfp.distributions +from dataclasses import dataclass @dataclass diff --git a/gpjax/kernels/stationary/periodic.py b/gpjax/kernels/stationary/periodic.py index b9fccf4ec..c5e35693a 100644 --- a/gpjax/kernels/stationary/periodic.py +++ b/gpjax/kernels/stationary/periodic.py @@ -23,6 +23,8 @@ from ...base import param_field from ..base import AbstractKernel +from dataclasses import dataclass + @dataclass class Periodic(AbstractKernel): diff --git a/gpjax/kernels/stationary/powered_exponential.py b/gpjax/kernels/stationary/powered_exponential.py index 725f17f3e..c2a704e8a 100644 --- a/gpjax/kernels/stationary/powered_exponential.py +++ b/gpjax/kernels/stationary/powered_exponential.py @@ -24,6 +24,8 @@ from ..base import AbstractKernel from .utils import euclidean_distance +from dataclasses import dataclass + @dataclass class PoweredExponential(AbstractKernel): diff --git a/gpjax/kernels/stationary/rbf.py b/gpjax/kernels/stationary/rbf.py index 7393c5e96..c552b60ed 100644 --- a/gpjax/kernels/stationary/rbf.py +++ b/gpjax/kernels/stationary/rbf.py @@ -26,8 +26,6 @@ from ..base import AbstractKernel from .utils import squared_distance -tfd = tfp.distributions - @dataclass class RBF(AbstractKernel): diff --git a/gpjax/variational_families.py b/gpjax/variational_families.py index 0f431391e..43810cd8f 100644 --- a/gpjax/variational_families.py +++ b/gpjax/variational_families.py @@ -95,7 +95,7 @@ class VariationalGaussian(AbstractVariationalGaussian): :math:`\\mu` and sqrt with S = sqrt sqrtᵀ. """ variational_mean: Float[Array, "N 1"] = param_field(None) - variational_root_covariance: Float[Array, "N N"] = param_field(None, bijector=tfb.FillScaleTriL(diag_shift=jnp.array(1e-6))) + variational_root_covariance: Float[Array, "N N"] = param_field(None, bijector=tfb.FillTriangular()) def __post_init__(self) -> None: diff --git a/tests/test_fit.py b/tests/test_fit.py index 6dcd0477e..a6746a678 100644 --- a/tests/test_fit.py +++ b/tests/test_fit.py @@ -18,7 +18,6 @@ import jax.random as jr import optax as ox import tensorflow_probability.substrates.jax.bijectors as tfb -import tensorflow_probability.substrates.jax.distributions as tfd import pytest @@ -35,6 +34,11 @@ from gpjax.objectives import ConjugateMLL, ELBO from gpjax.variational_families import VariationalGaussian +from jax.config import config + +# Enable Float64 for more stable matrix inversions. +config.update("jax_enable_x64", True) + def test_simple_linear_model() -> None: diff --git a/tests/test_kernels/test_approximations.py b/tests/test_kernels/test_approximations.py index a7e196ec5..96faba5dd 100644 --- a/tests/test_kernels/test_approximations.py +++ b/tests/test_kernels/test_approximations.py @@ -107,12 +107,12 @@ def test_improvement(kernel, n_dim): better_approximation = RFF(base_kernel, num_basis_fns=50) b_linop = better_approximation.gram(x).to_dense() -# c_delta = jnp.linalg.norm(exact_linop - c_linop, ord="fro") -# b_delta = jnp.linalg.norm(exact_linop - b_linop, ord="fro") + c_delta = jnp.linalg.norm(exact_linop - c_linop, ord="fro") + b_delta = jnp.linalg.norm(exact_linop - b_linop, ord="fro") -# # The frobenius norm of the difference between the exact and approximate -# # should improve as we increase the number of basis functions -# assert c_delta > b_delta + # The frobenius norm of the difference between the exact and approximate + # should improve as we increase the number of basis functions + assert c_delta > b_delta @pytest.mark.parametrize("kernel", [RBF(), Matern12(), Matern32(), Matern52()]) diff --git a/tests/test_kernels/test_stationary.py b/tests/test_kernels/test_stationary.py index 7088c49ad..8e2f3fc00 100644 --- a/tests/test_kernels/test_stationary.py +++ b/tests/test_kernels/test_stationary.py @@ -30,7 +30,6 @@ from gpjax.kernels.computations import ( ConstantDiagonalKernelComputation, DenseKernelComputation, - DiagonalKernelComputation, ) from gpjax.kernels.stationary import ( RBF, @@ -40,6 +39,7 @@ Periodic, PoweredExponential, RationalQuadratic, + White, ) from gpjax.kernels.computations import ( DenseKernelComputation, From 80abfd288e83b9d772a06cddd9f2cbbd14501f23 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Wed, 5 Apr 2023 22:00:47 +0100 Subject: [PATCH 35/44] Docs complete --- examples/barycentres.pct.py | 3 + examples/classification.pct.py | 8 +- examples/collapsed_vi.pct.py | 13 +- examples/graph_kernels.pct.py | 41 ++-- examples/haiku.pct.py | 7 +- examples/kernels.pct.py | 148 +++++------- examples/natgrads.pct.py | 15 +- examples/regression.pct.py | 3 +- examples/tfp_integration.pct.py | 1 - examples/uncollapsed_vi.pct.py | 220 +++++++++++------- examples/yacht.pct.py | 156 +++++++++---- gpjax/fit.py | 2 +- gpjax/kernels/approximations/rff.py | 1 + gpjax/kernels/base.py | 1 + gpjax/kernels/non_euclidean/graph.py | 19 +- gpjax/kernels/nonstationary/linear.py | 1 + gpjax/kernels/nonstationary/polynomial.py | 3 + gpjax/kernels/stationary/matern12.py | 1 + gpjax/kernels/stationary/matern32.py | 1 + gpjax/kernels/stationary/matern52.py | 1 + gpjax/kernels/stationary/periodic.py | 1 + .../kernels/stationary/powered_exponential.py | 1 + .../kernels/stationary/rational_quadratic.py | 1 + gpjax/kernels/stationary/rbf.py | 1 + gpjax/kernels/stationary/white.py | 1 + 25 files changed, 368 insertions(+), 282 deletions(-) diff --git a/examples/barycentres.pct.py b/examples/barycentres.pct.py index 280fd61aa..9762cbfca 100644 --- a/examples/barycentres.pct.py +++ b/examples/barycentres.pct.py @@ -134,6 +134,7 @@ # [Kernels notebook](https://gpjax.readthedocs.io/en/latest/nbs/kernels.html) for # advice on selecting an appropriate kernel. + # %% def fit_gp(x: jax.Array, y: jax.Array) -> tfd.MultivariateNormalFullCovariance: if y.ndim == 1: @@ -166,6 +167,7 @@ def fit_gp(x: jax.Array, y: jax.Array) -> tfd.MultivariateNormalFullCovariance: # following cell by utilising Jax's `vmap` operator to speed up large matrix operations # using broadcasting in `tensordot`. + # %% def sqrtm(A: jax.Array): return jnp.real(jsl.sqrtm(A)) @@ -219,6 +221,7 @@ def step(covariance_candidate: jax.Array, idx: None): # looks reasonable as it follows the sinusoidal curve of all the inferred GPs, and the # uncertainty bands are sensible. + # %% def plot( dist: tfd.MultivariateNormalTriL, diff --git a/examples/classification.pct.py b/examples/classification.pct.py index 530f58b56..7c1bc9c78 100644 --- a/examples/classification.pct.py +++ b/examples/classification.pct.py @@ -24,17 +24,16 @@ # %% import blackjax -import tensorflow_probability.substrates.jax as tfp import jax import jax.numpy as jnp import jax.random as jr import jax.scipy as jsp +import jax.tree_util as jtu import matplotlib.pyplot as plt import optax as ox +import tensorflow_probability.substrates.jax as tfp from jax.config import config from jaxtyping import Array, Float -import jax.tree_util as jtu - import gpjax as gpx @@ -229,11 +228,11 @@ # We take the latent distribution computed in the previous section and add this term # to the covariance to construct $q_{Laplace}(f(\cdot))$. + # %% def construct_laplace(test_inputs: Float[Array, "N D"]) -> tfd.MultivariateNormalTriL: map_latent_dist = opt_posterior.predict(xtest, train_data=D) - Kxt = opt_posterior.prior.kernel.cross_covariance(x, test_inputs) Kxx = opt_posterior.prior.kernel.gram(x) Kxx += I(D.n) * jitter @@ -253,6 +252,7 @@ def construct_laplace(test_inputs: Float[Array, "N D"]) -> tfd.MultivariateNorma L = jnp.linalg.cholesky(covariance) return tfd.MultivariateNormalTriL(jnp.atleast_1d(mean.squeeze()), L) + # %% [markdown] # From this we can construct the predictive distribution at the test points. # %% diff --git a/examples/collapsed_vi.pct.py b/examples/collapsed_vi.pct.py index f3a2915ad..dea8e9a22 100644 --- a/examples/collapsed_vi.pct.py +++ b/examples/collapsed_vi.pct.py @@ -98,9 +98,7 @@ # noise, we pass this to the constructer. # %% -q = gpx.CollapsedVariationalGaussian( - posterior=posterior, inducing_inputs=z -) +q = gpx.CollapsedVariationalGaussian(posterior=posterior, inducing_inputs=z) # %% [markdown] # We define our variational inference algorithm through `CollapsedVI`. This defines @@ -183,10 +181,7 @@ ax.plot(xtest, samples.T, color="tab:blue", alpha=0.8, linewidth=0.2) -[ - ax.axvline(x=z_i, color="tab:gray", alpha=0.3, linewidth=1) - for z_i in inducing_points -] +[ax.axvline(x=z_i, color="tab:gray", alpha=0.3, linewidth=1) for z_i in inducing_points] ax.legend() plt.show() @@ -200,7 +195,9 @@ # full model. # %% -full_rank_model = gpx.Prior(mean_function=gpx.Zero(), kernel=gpx.RBF()) * gpx.Gaussian(num_datapoints=D.n) +full_rank_model = gpx.Prior(mean_function=gpx.Zero(), kernel=gpx.RBF()) * gpx.Gaussian( + num_datapoints=D.n +) negative_mll = jit(gpx.ConjugateMLL(negative=True)) # %timeit negative_mll(full_rank_model, D).block_until_ready() diff --git a/examples/graph_kernels.pct.py b/examples/graph_kernels.pct.py index ea8f5c211..e89ee2ddd 100644 --- a/examples/graph_kernels.pct.py +++ b/examples/graph_kernels.pct.py @@ -100,25 +100,12 @@ x = jnp.arange(G.number_of_nodes()).reshape(-1, 1) kernel = gpx.GraphKernel(laplacian=L) -prior = gpx.Prior(mean_function = gpx.Zero(), kernel=kernel) +prior = gpx.Prior(mean_function=gpx.Zero(), kernel=kernel) -true_params = prior.init_params(key) -true_params["kernel"] = { - "lengthscale": jnp.array(2.3), - "variance": jnp.array(3.2), - "smoothness": jnp.array(6.1), -} - -fx = prior(true_params)(x) +fx = prior(x) y = fx.sample(seed=key).reshape(-1, 1) -D = Dataset(X=x, y=y) - -# %% -kernel.compute_engine.gram - -# %% -kernel.gram(params=kernel.init_params(key), inputs=x) +D = gpx.Dataset(X=x, y=y) # %% [markdown] # @@ -148,22 +135,22 @@ # We do this using the Adam optimiser provided in `optax`. # %% -likelihood = gpx.Gaussian(num_datapoints=y.shape[0]) -posterior = prior * likelihood +from gpjax.base.module import meta_leaves -parameter_state = gpx.initialise(posterior, key) -negative_mll = jit(posterior.marginal_log_likelihood(train_data=D, negative=True)) -optimiser = ox.adam(learning_rate=0.01) +meta_leaves(posterior)[1] -inference_state = gpx.fit( - objective=negative_mll, - parameter_state=parameter_state, - optax_optim=optimiser, +# %% +likelihood = gpx.Gaussian(num_datapoints=D.n) +posterior = prior * likelihood + +opt_posterior, training_history = gpx.fit( + model=posterior, + objective=gpx.ConjugateMLL(negative=True), + train_data=D, + optim=ox.adamw(learning_rate=0.01), num_iters=1000, ) -learned_params, training_history = inference_state.unpack() - # %% [markdown] # # ## Making predictions diff --git a/examples/haiku.pct.py b/examples/haiku.pct.py index 9355a50a7..d2f13ab2e 100644 --- a/examples/haiku.pct.py +++ b/examples/haiku.pct.py @@ -21,6 +21,7 @@ # %% import typing as tp +from dataclasses import dataclass from typing import Dict import haiku as hk @@ -31,7 +32,6 @@ import optax as ox from jax.config import config from jaxtyping import Array, Float -from jaxutils import Dataset from scipy.signal import sawtooth import gpjax as gpx @@ -59,7 +59,7 @@ signal = f(x) y = signal + jr.normal(subkey, shape=signal.shape) * noise -D = Dataset(X=x, y=y) +D = gpx.Dataset(X=x, y=y) xtest = jnp.linspace(-2.0, 2.0, 500).reshape(-1, 1) ytest = f(xtest) @@ -80,7 +80,9 @@ # # Although deep kernels are not currently supported natively in GPJax, defining one is straightforward as we now demonstrate. Using the base `AbstractKernel` object given in GPJax, we provide a mixin class named `_DeepKernelFunction` to facilitate the user supplying the neural network and base kernel of their choice. Kernel matrices are then computed using the regular `gram` and `cross_covariance` functions. + # %% +@dataclass class DeepKernelFunction(AbstractKernel): def __init__( self, @@ -122,6 +124,7 @@ def _initialise_params(self, key: jr.KeyArray) -> Dict: # With a deep kernel object created, we proceed to define a neural network. Here we consider a small multi-layer perceptron with two linear hidden layers and ReLU activation functions between the layers. The first hidden layer contains 32 units, while the second layer contains 64 units. Finally, we'll make the output of our network a single unit. However, it would be possible to project our data into a $d-$dimensional space for $d>1$. In these instances, making the [base kernel ARD](https://gpjax.readthedocs.io/en/latest/nbs/kernels.html#Active-dimensions) would be sensible. # Users may wish to design more intricate network structures for more complex tasks, which functionality is supported well in Haiku. + # %% def forward(x): mlp = hk.Sequential( diff --git a/examples/kernels.pct.py b/examples/kernels.pct.py index 009e05d06..4d9130fb3 100644 --- a/examples/kernels.pct.py +++ b/examples/kernels.pct.py @@ -15,33 +15,34 @@ # --- # %% [markdown] -# ```{note} -# This notebook is a duplicate of the one found in the [JaxKern documentation](https://jaxkern.readthedocs.io/en/latest/nbs/kernels.html). It is included here for completeness. -# ``` # # Kernel Guide # # In this guide, we introduce the kernels available in GPJax and demonstrate how to create custom ones. +# +# +# from typing import Dict - -from typing import Dict +from dataclasses import dataclass # %% import distrax as dx import jax.numpy as jnp import jax.random as jr import matplotlib.pyplot as plt +import optax as ox +import tensorflow_probability.substrates.jax as tfp from jax import jit from jax.config import config from jaxtyping import Array, Float -from jaxutils import Dataset -from optax import adam +from simple_pytree import static_field import gpjax as gpx -import gpjax.kernels as jk +from gpjax.base.param import param_field # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) key = jr.PRNGKey(123) +tfb = tfp.bijectors # %% [markdown] # ## Supported Kernels @@ -59,21 +60,22 @@ # %% kernels = [ - jk.Matern12(), - jk.Matern32(), - jk.Matern52(), - jk.RBF(), - jk.Polynomial(), - jk.Polynomial(degree=2), + gpx.kernels.Matern12(), + gpx.kernels.Matern32(), + gpx.kernels.Matern52(), + gpx.kernels.RBF(), + gpx.kernels.Polynomial(), + gpx.kernels.Polynomial(degree=2), ] -fig, axes = plt.subplots(ncols=3, nrows=2, figsize=(20, 10)) +fig, axes = plt.subplots(ncols=3, nrows=2, figsize=(10, 6), tight_layout=True) x = jnp.linspace(-3.0, 3.0, num=200).reshape(-1, 1) +meanf = gpx.mean_functions.Zero() + for k, ax in zip(kernels, axes.ravel()): - prior = gpx.Prior(kernel=k) - params, *_ = gpx.initialise(prior, key).unpack() - rv = prior(params)(x) + prior = gpx.Prior(mean_function=meanf, kernel=k) + rv = prior(x) y = rv.sample(seed=key, sample_shape=(10,)) ax.plot(x, y.T, alpha=0.7) ax.set_title(k.name) @@ -90,15 +92,14 @@ # like our RBF kernel to act on the first, second and fourth dimensions. # %% -slice_kernel = jk.RBF(active_dims=[0, 1, 3]) +slice_kernel = gpx.kernels.RBF(active_dims=[0, 1, 3]) # %% [markdown] # # The resulting kernel has one length-scale parameter per input dimension --- an ARD kernel. # %% -print(f"ARD: {slice_kernel.ard}") -print(f"Lengthscales: {slice_kernel.init_params(key)['lengthscale']}") +print(f"Lengthscales: {slice_kernel.lengthscale}") # %% [markdown] # We'll now simulate some data and evaluate the kernel on the previously selected input dimensions. @@ -107,11 +108,8 @@ # Inputs x_matrix = jr.normal(key, shape=(50, 5)) -# Default parameter dictionary -params = slice_kernel.init_params(key) - # Compute the Gram matrix -K = slice_kernel.gram(params, x_matrix) +K = slice_kernel.gram(x_matrix) print(K.shape) # %% [markdown] @@ -123,14 +121,14 @@ # can be created by applying the `+` operator as follows. # %% -k1 = jk.RBF() -k2 = jk.Polynomial() -sum_k = k1 + k2 +k1 = gpx.kernels.RBF() +k2 = gpx.kernels.Polynomial() +sum_k = gpx.kernels.ProductKernel(kernels=[k1, k2]) fig, ax = plt.subplots(ncols=3, figsize=(20, 5)) -im0 = ax[0].matshow(k1.gram(k1.init_params(key), x).to_dense()) -im1 = ax[1].matshow(k2.gram(k2.init_params(key), x).to_dense()) -im2 = ax[2].matshow(sum_k.gram(sum_k.init_params(key), x).to_dense()) +im0 = ax[0].matshow(k1.gram(x).to_dense()) +im1 = ax[1].matshow(k2.gram(x).to_dense()) +im2 = ax[2].matshow(sum_k.gram(x).to_dense()) fig.colorbar(im0, ax=ax[0]) fig.colorbar(im1, ax=ax[1]) @@ -140,28 +138,21 @@ # Similarily, products of kernels can be created through the `*` operator. # %% -k3 = jk.Matern32() +k3 = gpx.kernels.Matern32() -prod_k = k1 * k2 * k3 +prod_k = gpx.kernels.ProductKernel(kernels=[k1, k2, k3]) fig, ax = plt.subplots(ncols=4, figsize=(20, 5)) -im0 = ax[0].matshow(k1.gram(k1.init_params(key), x).to_dense()) -im1 = ax[1].matshow(k2.gram(k2.init_params(key), x).to_dense()) -im2 = ax[2].matshow(k3.gram(k3.init_params(key), x).to_dense()) -im3 = ax[3].matshow(prod_k.gram(prod_k.init_params(key), x).to_dense()) +im0 = ax[0].matshow(k1.gram(x).to_dense()) +im1 = ax[1].matshow(k2.gram(x).to_dense()) +im2 = ax[2].matshow(k3.gram(x).to_dense()) +im3 = ax[3].matshow(prod_k.gram(x).to_dense()) fig.colorbar(im0, ax=ax[0]) fig.colorbar(im1, ax=ax[1]) fig.colorbar(im2, ax=ax[2]) fig.colorbar(im3, ax=ax[3]) -# %% [markdown] -# Alternatively kernel sums and multiplications can be created by passing a list of kernels into the `SumKernel` `ProductKernel` objects respectively. - -# %% -sum_k = jk.SumKernel(kernel_set=[k1, k2]) -prod_k = jk.ProductKernel(kernel_set=[k1, k2, k3]) - # %% [markdown] # ## Custom kernel @@ -200,32 +191,32 @@ # # To implement this, one must write the following class. + # %% def angular_distance(x, y, c): return jnp.abs((x - y + c) % (c * 2) - c) -class Polar(jk.base.AbstractKernel): - def __init__(self) -> None: - super().__init__() - self.period: float = 2 * jnp.pi - self.c = self.period / 2.0 # in [0, \pi] +@dataclass +class _Polar: + period: float = static_field(2 * jnp.pi) + tau: float = param_field(jnp.array([4.0]), bijector=tfb.Softplus(low=4.0)) + + +@dataclass +class Polar(gpx.kernels.AbstractKernel, _Polar): + def __post_init__(self): + self.c = self.period / 2.0 def __call__( - self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"] + self, x: Float[Array, "1 D"], y: Float[Array, "1 D"] ) -> Float[Array, "1"]: - tau = params["tau"] t = angular_distance(x, y, self.c) - K = (1 + tau * t / self.c) * jnp.clip(1 - t / self.c, 0, jnp.inf) ** tau + K = (1 + self.tau * t / self.c) * jnp.clip( + 1 - t / self.c, 0, jnp.inf + ) ** self.tau return K.squeeze() - def init_params(self, key: jr.KeyArray) -> dict: - return {"tau": jnp.array([4.0])} - - # This is depreciated. Can be removed once JaxKern is updated. - def _initialise_params(self, key: jr.KeyArray) -> Dict: - return self.init_params(key) - # %% [markdown] # We unpack this now to make better sense of it. In the kernel's `__init__` @@ -253,18 +244,6 @@ def _initialise_params(self, key: jr.KeyArray) -> Dict: # transformation](https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.softplus.html) # where the lower bound is shifted by $4$. -# %% -from jax.nn import softplus - -from gpjax.config import add_parameter - -bij_fn = lambda x: softplus(x + jnp.array(4.0)) -bij = dx.Lambda( - forward=bij_fn, inverse=lambda y: -jnp.log(-jnp.expm1(-y - 4.0)) + y - 4.0 -) - -add_parameter("tau", bij) - # %% [markdown] # ### Using our polar kernel # @@ -279,29 +258,24 @@ def _initialise_params(self, key: jr.KeyArray) -> Dict: X = jnp.sort(jr.uniform(key, minval=0.0, maxval=jnp.pi * 2, shape=(n, 1)), axis=0) y = 4 + jnp.cos(2 * X) + jr.normal(key, shape=X.shape) * noise -D = Dataset(X=X, y=y) +D = gpx.Dataset(X=X, y=y) # Define polar Gaussian process PKern = Polar() +meanf = gpx.mean_functions.Zero() likelihood = gpx.Gaussian(num_datapoints=n) -circlular_posterior = gpx.Prior(kernel=PKern) * likelihood +circlular_posterior = gpx.Prior(mean_function=meanf, kernel=PKern) * likelihood -# Initialise parameter state: -parameter_state = gpx.initialise(circlular_posterior, key) # Optimise GP's marginal log-likelihood using Adam -negative_mll = jit(circlular_posterior.marginal_log_likelihood(D, negative=True)) -optimiser = adam(learning_rate=0.05) - -inference_state = gpx.fit( - objective=negative_mll, - parameter_state=parameter_state, - optax_optim=optimiser, - num_iters=1000, +opt_posterior, history = gpx.fit( + model=circlular_posterior, + objective=gpx.ConjugateMLL(negative=True), + train_data=D, + optim=ox.adamw(learning_rate=0.05), + num_iters=500, ) -learned_params, training_history = inference_state.unpack() - # %% [markdown] # ### Prediction # @@ -309,9 +283,7 @@ def _initialise_params(self, key: jr.KeyArray) -> Dict: # and illustrate the results. # %% -posterior_rv = likelihood( - learned_params, circlular_posterior(learned_params, D)(angles) -) +posterior_rv = opt_posterior.likelihood(opt_posterior.predict(angles, train_data=D)) mu = posterior_rv.mean() one_sigma = posterior_rv.stddev() diff --git a/examples/natgrads.pct.py b/examples/natgrads.pct.py index 7dc2c8a0c..d5df6474e 100644 --- a/examples/natgrads.pct.py +++ b/examples/natgrads.pct.py @@ -26,10 +26,8 @@ import matplotlib.pyplot as plt import optax as ox from jax.config import config -from jaxutils import Dataset import gpjax as gpx -import gpjax.kernels as jk # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) @@ -55,7 +53,7 @@ signal = f(x) y = signal + jr.normal(subkey, shape=signal.shape) * noise -D = Dataset(X=x, y=y) +D = gpx.Dataset(X=x, y=y) xtest = jnp.linspace(-5.5, 5.5, 500).reshape(-1, 1) # %% [markdown] @@ -78,15 +76,14 @@ # %% likelihood = gpx.Gaussian(num_datapoints=n) -kernel = jk.RBF() -prior = gpx.Prior(kernel=kernel) +meanf = gpx.mean_functions.Zero() +kernel = gpx.RBF() +prior = gpx.Prior(mean_function=meanf, kernel=kernel) p = prior * likelihood -natural_q = gpx.NaturalVariationalGaussian(prior=prior, inducing_inputs=z) -natural_svgp = gpx.StochasticVI(posterior=p, variational_family=natural_q) - -parameter_state = gpx.initialise(natural_svgp) +natural_q = gpx.NaturalVariationalGaussian(posterior=p, inducing_inputs=z) +natural_svgp = gpx.ELBO(negative=True) # %% [markdown] # Next, we can conduct natural gradients as follows: diff --git a/examples/regression.pct.py b/examples/regression.pct.py index 44b2e3c44..16c34bc60 100644 --- a/examples/regression.pct.py +++ b/examples/regression.pct.py @@ -27,7 +27,6 @@ import optax as ox from jax import jit from jax.config import config -from jaxutils import Dataset import gpjax as gpx @@ -56,7 +55,7 @@ signal = f(x) y = signal + jr.normal(subkey, shape=signal.shape) * noise -D = Dataset(X=x, y=y) +D = gpx.Dataset(X=x, y=y) xtest = jnp.linspace(-3.5, 3.5, 500).reshape(-1, 1) ytest = f(xtest) diff --git a/examples/tfp_integration.pct.py b/examples/tfp_integration.pct.py index 7b0f9bd04..1982a673c 100644 --- a/examples/tfp_integration.pct.py +++ b/examples/tfp_integration.pct.py @@ -135,7 +135,6 @@ def build_log_pi(log_mll, unconstrained_priors, mapper_fn): def array_mll(parameter_array): - # Convert parameter array to a dictionary: params_dict = mapper_fn([jnp.array(i) for i in parameter_array]) diff --git a/examples/uncollapsed_vi.pct.py b/examples/uncollapsed_vi.pct.py index d20734e26..63445fdff 100644 --- a/examples/uncollapsed_vi.pct.py +++ b/examples/uncollapsed_vi.pct.py @@ -7,7 +7,7 @@ # extension: .py # format_name: percent # format_version: '1.3' -# jupytext_version: 1.14.4 +# jupytext_version: 1.11.2 # kernelspec: # display_name: gpjax # language: python @@ -17,7 +17,17 @@ # %% [markdown] # # Sparse Stochastic Variational Inference # -# In this notebook we demonstrate how to implement sparse variational Gaussian processes (SVGPs) of Hensman et al. (2013); Hensman et al. (2015). In particular, this approximation framework provides a tractable option for working with non-conjugate Gaussian processes with more than ~5000 data points. However, for conjugate models of less than 5000 data points, we recommend using the marginal log-likelihood approach presented in the [regression notebook](https://gpjax.readthedocs.io/en/latest/nbs/regression.html). Though we illustrate SVGPs here with a conjugate regression example, the same GPJax code works for general likelihoods, such as a Bernoulli for classification. +# In this notebook we demonstrate how to implement sparse variational Gaussian +# processes (SVGPs) of +# Hensman et al. (2013); +# Hensman et al. (2015). In +# particular, this approximation framework provides a tractable option for working with +# non-conjugate Gaussian processes with more than ~5000 data points. However, for +# conjugate models of less than 5000 data points, we recommend using the marginal +# log-likelihood approach presented in the +# [regression notebook](https://gpjax.readthedocs.io/en/latest/nbs/regression.html). +# Though we illustrate SVGPs here with a conjugate regression example, the same GPJax +# code works for general likelihoods, such as a Bernoulli for classification. # %% import jax.numpy as jnp @@ -27,32 +37,28 @@ import tensorflow_probability.substrates.jax as tfp from jax import jit from jax.config import config -from jaxutils import Dataset - -import gpjax.kernels as jk - -tfb = tfp.bijectors - -import distrax as dx import gpjax as gpx -from gpjax.config import get_global_config, reset_global_config +import gpjax.kernels as jk # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) key = jr.PRNGKey(123) +tfb = tfp.bijectors # %% [markdown] # ## Dataset # -# With the necessary modules imported, we simulate a dataset $\mathcal{D} = (\boldsymbol{x}, \boldsymbol{y}) = \{(x_i, y_i)\}_{i=1}^{5000}$ with inputs $\boldsymbol{x}$ sampled uniformly on $(-5, 5)$ and corresponding binary outputs +# With the necessary modules imported, we simulate a dataset +# $\mathcal{D} = (\boldsymbol{x}, \boldsymbol{y}) = \{(x_i, y_i)\}_{i=1}^{5000}$ +# with inputs $\boldsymbol{x}$ sampled uniformly on $(-5, 5)$ and corresponding binary outputs # # $$\boldsymbol{y} \sim \mathcal{N} \left(\sin(4 * \boldsymbol{x}) + \sin(2 * \boldsymbol{x}), \textbf{I} * (0.2)^{2} \right).$$ # # We store our data $\mathcal{D}$ as a GPJax `Dataset` and create test inputs for later. # %% -n = 5000 +n = 50000 noise = 0.2 key, subkey = jr.split(key) @@ -61,28 +67,49 @@ signal = f(x) y = signal + jr.normal(subkey, shape=signal.shape) * noise -D = Dataset(X=x, y=y) +D = gpx.Dataset(X=x, y=y) xtest = jnp.linspace(-5.5, 5.5, 500).reshape(-1, 1) # %% [markdown] # ## Sparse GPs via inducing inputs # -# Despite their endowment with elegant theoretical properties, GPs are burdened with prohibitive $\mathcal{O}(n^3)$ inference and $\mathcal{O}(n^2)$ memory costs in the number of data points $n$ due to the necessity of computing inverses and determinants of the kernel Gram matrix $\mathbf{K}_{\boldsymbol{x}\boldsymbol{x}}$ during inference and hyperparameter learning. +# Despite their endowment with elegant theoretical properties, GPs are burdened with +# prohibitive $\mathcal{O}(n^3)$ inference and $\mathcal{O}(n^2)$ memory costs in the +# number of data points $n$ due to the necessity of computing inverses and determinants +# of the kernel Gram matrix $\mathbf{K}_{\boldsymbol{x}\boldsymbol{x}}$ during inference +# and hyperparameter learning. # Sparse GPs seek to resolve tractability through low-rank approximations. # -# Their name originates with the idea of using subsets of the data to approximate the kernel matrix, with _sparseness_ occurring through the selection of the data points. -# Given inputs $\boldsymbol{x}$ and outputs $\boldsymbol{y}$ the task was to select an $m review many popular approximation schemes in this vein. However, because the model and the approximation are intertwined, assigning performance and faults to one or the other becomes tricky. +# Their name originates with the idea of using subsets of the data to approximate the +# kernel matrix, with _sparseness_ occurring through the selection of the data points. +# Given inputs $\boldsymbol{x}$ and outputs $\boldsymbol{y}$ the task was to select an +# $m +# review many popular approximation schemes in this vein. However, because the model +# and the approximation are intertwined, assigning performance and faults to one or the +# other becomes tricky. # -# On the other hand, sparse variational Gaussian processes (SVGPs) [approximate the posterior, not the model](https://www.secondmind.ai/labs/sparse-gps-approximate-the-posterior-not-the-model/). -# These provide a low-rank approximation scheme via variational inference. Here we posit a family of densities parameterised by “variational parameters”. -# We then seek to find the closest family member to the posterior by minimising the Kullback-Leibler divergence over the variational parameters. +# On the other hand, sparse variational Gaussian processes (SVGPs) +# [approximate the posterior, not the model](https://www.secondmind.ai/labs/sparse-gps-approximate-the-posterior-not-the-model/). +# These provide a low-rank approximation scheme via variational inference. Here we +# posit a family of densities parameterised by “variational parameters”. +# We then seek to find the closest family member to the posterior by minimising the +# Kullback-Leibler divergence over the variational parameters. # The fitted variational density then serves as a proxy for the exact posterior. -# This procedure makes variational methods efficiently solvable via off-the-shelf optimisation techniques whilst retaining the true-underlying model. -# Furthermore, SVGPs offer further cost reductions with mini-batch stochastic gradient descent and address non-conjugacy . -# We show a cost comparison between the approaches below, where $b$ is the mini-batch size. +# This procedure makes variational methods efficiently solvable via off-the-shelf +# optimisation techniques whilst retaining the true-underlying model. +# Furthermore, SVGPs offer further cost reductions with mini-batch stochastic gradient +# descent and address non-conjugacy +# . +# We show a cost comparison between the approaches below, where $b$ is the mini-batch +# size. # # # @@ -92,7 +119,9 @@ # | Memory cost | $\mathcal{O}(n^2)$ | $\mathcal{O}(n m)$ | $\mathcal{O}(b m + m^2)$ | # # -# To apply SVGP inference to our dataset, we begin by initialising $m = 50$ equally spaced inducing inputs $\boldsymbol{z}$ across our observed data's support. These are depicted below via horizontal black lines. +# To apply SVGP inference to our dataset, we begin by initialising $m = 50$ equally +# spaced inducing inputs $\boldsymbol{z}$ across our observed data's support. These +# are depicted below via horizontal black lines. # %% z = jnp.linspace(-5.0, 5.0, 50).reshape(-1, 1) @@ -104,90 +133,123 @@ plt.show() # %% [markdown] -# The inducing inputs will summarise our dataset, and since they are treated as variational parameters, their locations will be optimised. The next step to SVGP is to define a variational family. +# The inducing inputs will summarise our dataset, and since they are treated as +# variational parameters, their locations will be optimised. The next step to SVGP is +# to define a variational family. # %% [markdown] # ## Defining the variational process # -# We begin by considering the form of the posterior distribution for all function values $f(\cdot)$ +# We begin by considering the form of the posterior distribution for all function +# values $f(\cdot)$ # # \begin{align} # p(f(\cdot) | \mathcal{D}) = \int p(f(\cdot)|f(\boldsymbol{x})) p(f(\boldsymbol{x})|\mathcal{D}) \text{d}f(\boldsymbol{x}). \qquad (\dagger) # \end{align} # -# To arrive at an approximation framework, we assume some redundancy in the data. Instead of predicting $f(\cdot)$ with function values at the datapoints $f(\boldsymbol{x})$, we assume this can be achieved with only function values at $m$ inducing inputs $\boldsymbol{z}$ +# To arrive at an approximation framework, we assume some redundancy in the data. +# Instead of predicting $f(\cdot)$ with function values at the datapoints +# $f(\boldsymbol{x})$, we assume this can be achieved with only function values at +# $m$ inducing inputs $\boldsymbol{z}$ # # $$ p(f(\cdot) | \mathcal{D}) \approx \int p(f(\cdot)|f(\boldsymbol{z})) p(f(\boldsymbol{z})|\mathcal{D}) \text{d}f(\boldsymbol{z}). \qquad (\star) $$ # -# This lower dimensional integral results in computational savings in the model's predictive component from $p(f(\cdot)|f(\boldsymbol{x}))$ to $p(f(\cdot)|f(\boldsymbol{z}))$ where inverting $\mathbf{K}_{\boldsymbol{x}\boldsymbol{x}}$ is replaced with inverting $\mathbf{K}_{\boldsymbol{z}\boldsymbol{z}}$. -# However, since we did not observe our data $\mathcal{D}$ at $\boldsymbol{z}$ we ask, what exactly is the posterior $p(f(\boldsymbol{z})|\mathcal{D})$? +# This lower dimensional integral results in computational savings in the model's +# predictive component from $p(f(\cdot)|f(\boldsymbol{x}))$ to +# $p(f(\cdot)|f(\boldsymbol{z}))$ where inverting +# $\mathbf{K}_{\boldsymbol{x}\boldsymbol{x}}$ is replaced with inverting +# $\mathbf{K}_{\boldsymbol{z}\boldsymbol{z}}$. +# However, since we did not observe our data $\mathcal{D}$ at $\boldsymbol{z}$ we ask, +# what exactly is the posterior $p(f(\boldsymbol{z})|\mathcal{D})$? # -# Notice this is simply obtained by substituting $\boldsymbol{z}$ into $(\dagger)$, but we arrive back at square one with computing the expensive integral. To side-step this, we consider replacing $p(f(\boldsymbol{z})|\mathcal{D})$ in $(\star)$ with a cheap-to-compute approximate distribution $q(f(\boldsymbol{z}))$ +# Notice this is simply obtained by substituting $\boldsymbol{z}$ into $(\dagger)$, +# but we arrive back at square one with computing the expensive integral. To side-step +# this, we consider replacing $p(f(\boldsymbol{z})|\mathcal{D})$ in $(\star)$ with a +# cheap-to-compute approximate distribution $q(f(\boldsymbol{z}))$ # # $$ q(f(\cdot)) = \int p(f(\cdot)|f(\boldsymbol{z})) q(f(\boldsymbol{z})) \text{d}f(\boldsymbol{z}). \qquad (\times) $$ # -# To measure the quality of the approximation, we consider the Kullback-Leibler divergence $\operatorname{KL}(\cdot || \cdot)$ from our approximate process $q(f(\cdot))$ to the true process $p(f(\cdot)|\mathcal{D})$. By parametrising $q(f(\boldsymbol{z}))$ over a variational family of distributions, we can optimise Kullback-Leibler divergence with respect to the variational parameters. Moreover, since inducing input locations $\boldsymbol{z}$ augment the model, they themselves can be treated as variational parameters without altering the true underlying model $p(f(\boldsymbol{z})|\mathcal{D})$. This is exactly what gives SVGPs great flexibility whilst retaining robustness to overfitting. +# To measure the quality of the approximation, we consider the Kullback-Leibler +# divergence $\operatorname{KL}(\cdot || \cdot)$ from our approximate process +# $q(f(\cdot))$ to the true process $p(f(\cdot)|\mathcal{D})$. By parametrising +# $q(f(\boldsymbol{z}))$ over a variational family of distributions, we can optimise +# Kullback-Leibler divergence with respect to the variational parameters. Moreover, +# since inducing input locations $\boldsymbol{z}$ augment the model, they themselves +# can be treated as variational parameters without altering the true underlying model +# $p(f(\boldsymbol{z})|\mathcal{D})$. This is exactly what gives SVGPs great +# flexibility whilst retaining robustness to overfitting. # -# It is popular to elect a Gaussian variational distribution $q(f(\boldsymbol{z})) = \mathcal{N}(f(\boldsymbol{z}); \mathbf{m}, \mathbf{S})$ with parameters $\{\boldsymbol{z}, \mathbf{m}, \mathbf{S}\}$, since conjugacy is provided between $q(f(\boldsymbol{z}))$ and $p(f(\cdot)|f(\boldsymbol{z}))$ so that the resulting variational process $q(f(\cdot))$ is a GP. We can implement this in GPJax by the following. +# It is popular to elect a Gaussian variational distribution +# $q(f(\boldsymbol{z})) = \mathcal{N}(f(\boldsymbol{z}); \mathbf{m}, \mathbf{S})$ +# with parameters $\{\boldsymbol{z}, \mathbf{m}, \mathbf{S}\}$, since conjugacy is +# provided between $q(f(\boldsymbol{z}))$ and $p(f(\cdot)|f(\boldsymbol{z}))$ so that +# the resulting variational process $q(f(\cdot))$ is a GP. We can implement this in +# GPJax by the following. # %% +meanf = gpx.mean_functions.Zero() likelihood = gpx.Gaussian(num_datapoints=n) -prior = gpx.Prior(kernel=jk.RBF()) +prior = gpx.Prior(mean_function=meanf, kernel=jk.RBF()) p = prior * likelihood -q = gpx.VariationalGaussian(prior=prior, inducing_inputs=z) - -# %% [markdown] -# Here, the variational process $q(\cdot)$ depends on the prior through $p(f(\cdot)|f(\boldsymbol{z}))$ in $(\times)$. +q = gpx.VariationalGaussian(posterior=p, inducing_inputs=z) # %% [markdown] -# -# We combine our true and approximate posterior Gaussian processes into an `StochasticVI` object to define the variational strategy that we will adopt in the forthcoming inference. - -# %% -svgp = gpx.StochasticVI(posterior=p, variational_family=q) +# Here, the variational process $q(\cdot)$ depends on the prior through +# $p(f(\cdot)|f(\boldsymbol{z}))$ in $(\times)$. # %% [markdown] # ## Inference # # ### Evidence lower bound # -# With our model defined, we seek to infer the optimal inducing inputs $\boldsymbol{z}$, variational mean $\mathbf{m}$ and covariance $\mathbf{S}$ that define our approximate posterior. To achieve this, we maximise the evidence lower bound (ELBO) with respect to $\{\boldsymbol{z}, \mathbf{m}, \mathbf{S} \}$, a proxy for minimising the Kullback-Leibler divergence. Moreover, as hinted by its name, the ELBO is a lower bound to the marginal log-likelihood, providing a tractable objective to optimise the model's hyperparameters akin to the conjugate setting. For further details on this, see Sections 3.1 and 4.1 of the excellent review paper . +# With our model defined, we seek to infer the optimal inducing inputs +# $\boldsymbol{z}$, variational mean $\mathbf{m}$ and covariance +# $\mathbf{S}$ that define our approximate posterior. To achieve this, we maximise the +# evidence lower bound (ELBO) with respect to +# $\{\boldsymbol{z}, \mathbf{m}, \mathbf{S} \}$, a proxy for minimising the +# Kullback-Leibler divergence. Moreover, as hinted by its name, the ELBO is a lower +# bound to the marginal log-likelihood, providing a tractable objective to optimise the +# model's hyperparameters akin to the conjugate setting. For further details on this, +# see Sections 3.1 and 4.1 of the excellent review paper +# . # -# Since Optax's optimisers work to minimise functions, to maximise the ELBO we return its negative. +# Since Optax's optimisers work to minimise functions, to maximise the ELBO we return +# its negative. # %% -negative_elbo = jit(svgp.elbo(D, negative=True)) +negative_elbo = gpx.ELBO(negative=True) # %% [markdown] # ### Mini-batching # -# Despite introducing inducing inputs into our model, inference can still be intractable with large datasets. To circumvent this, optimisation can be done using stochastic mini-batches. +# Despite introducing inducing inputs into our model, inference can still be +# intractable with large datasets. To circumvent this, optimisation can be done using +# stochastic mini-batches. # %% -reset_global_config() -parameter_state = gpx.initialise(svgp, key) -optimiser = ox.adam(learning_rate=0.01) - -inference_state = gpx.fit_batches( +opt_posterior, history = gpx.fit( + model=q, objective=negative_elbo, - parameter_state=parameter_state, train_data=D, - optax_optim=optimiser, + optim=ox.adam(learning_rate=0.01), num_iters=3000, key=jr.PRNGKey(42), batch_size=128, ) -learned_params, training_history = inference_state.unpack() +plt.plot(history) # %% [markdown] # ## Predictions # -# With optimisation complete, we can use our inferred parameter set to make predictions at novel inputs akin -# to all other models within GPJax on our variational process object $q(\cdot)$ (for example, see the [regression notebook](https://gpjax.readthedocs.io/en/latest/nbs/regression.html)). +# With optimisation complete, we can use our inferred parameter set to make +# predictions at novel inputs akin +# to all other models within GPJax on our variational process object $q(\cdot)$ (for +# example, see the +# [regression notebook](https://gpjax.readthedocs.io/en/latest/nbs/regression.html)). # %% -latent_dist = q(learned_params)(xtest) -predictive_dist = likelihood(learned_params, latent_dist) +latent_dist = opt_posterior(xtest) +predictive_dist = opt_posterior.posterior.likelihood(latent_dist) meanf = predictive_dist.mean() sigma = predictive_dist.stddev() @@ -198,45 +260,41 @@ ax.fill_between(xtest.flatten(), meanf - sigma, meanf + sigma, alpha=0.3) [ ax.axvline(x=z_i, color="black", alpha=0.3, linewidth=1) - for z_i in learned_params["variational_family"]["inducing_inputs"] + for z_i in opt_posterior.inducing_inputs ] plt.show() # %% [markdown] # ## Custom transformations # -# To train a covariance matrix, GPJax uses `tfb.FillScaleTriL` transformation by default. `tfb.FillScaleTriL` fills a 1d vector into a lower triangular matrix and then applies `Softplus` transformation on the diagonal to satisfy the necessary conditions for a valid Cholesky matrix. Users can change this default transformation with another valid transformation of their choice. For example, `Square` transformation on the diagonal can also serve the purpose. +# To train a covariance matrix, GPJax uses `tfb.FillScaleTriL` transformation by +# default. `tfb.FillScaleTriL` fills a 1d vector into a lower triangular matrix and +# then applies `Softplus` transformation on the diagonal to satisfy the necessary +# conditions for a valid Cholesky matrix. Users can change this default transformation +# with another valid transformation of their choice. For example, `Square` +# transformation on the diagonal can also serve the purpose. # %% -gpx_config = get_global_config() -transformations = gpx_config.transformations -jitter = gpx_config.jitter - -triangular_transform = dx.Chain( - [tfb.FillScaleTriL(diag_bijector=tfb.Square(), diag_shift=jnp.array(jitter))] +triangular_transform = tfb.FillScaleTriL( + diag_bijector=tfb.Square(), diag_shift=jnp.array(q.jitter) ) - -transformations.update({"triangular_transform": triangular_transform}) +reparameterised_q = q.replace_bijector(variational_root_covariance=triangular_transform) # %% -parameter_state = gpx.initialise(svgp, key) -optimiser = ox.adam(learning_rate=0.01) - -inference_state = gpx.fit_batches( +opt_rep, history = gpx.fit( + model=reparameterised_q, objective=negative_elbo, - parameter_state=parameter_state, train_data=D, - optax_optim=optimiser, + optim=ox.adam(learning_rate=0.01), num_iters=3000, key=jr.PRNGKey(42), batch_size=128, ) -learned_params, training_history = inference_state.unpack() # %% -latent_dist = q(learned_params)(xtest) -predictive_dist = likelihood(learned_params, latent_dist) +latent_dist = opt_rep(xtest) +predictive_dist = opt_rep.posterior.likelihood(latent_dist) meanf = predictive_dist.mean() sigma = predictive_dist.stddev() @@ -247,12 +305,14 @@ ax.fill_between(xtest.flatten(), meanf - sigma, meanf + sigma, alpha=0.3) [ ax.axvline(x=z_i, color="black", alpha=0.3, linewidth=1) - for z_i in learned_params["variational_family"]["inducing_inputs"] + for z_i in opt_rep.inducing_inputs ] plt.show() # %% [markdown] -# We can see that `Square` transformation is able to get relatively better fit compared to `Softplus` with the same number of iterations, but `Softplus` is recommended over `Square` for stability of optimisation. +# We can see that `Square` transformation is able to get relatively better fit +# compared to `Softplus` with the same number of iterations, but `Softplus` is +# recommended over `Square` for stability of optimisation. # %% [markdown] # ## System configuration diff --git a/examples/yacht.pct.py b/examples/yacht.pct.py index b72472725..8d69723dc 100644 --- a/examples/yacht.pct.py +++ b/examples/yacht.pct.py @@ -18,32 +18,37 @@ import matplotlib.pyplot as plt import numpy as np import optax as ox -from jax.config import config -from jaxutils import Dataset - -import gpjax.kernels as jk - -# Enable Float64 for more stable matrix inversions. -config.update("jax_enable_x64", True) -# %% [markdown] -# # UCI Data Benchmarking -# -# In this notebook, we will show how to apply GPJax on a benchmark UCI regression problem. These kind of tasks are often used in the research community to benchmark and assess new techniques against those already in the literature. Much of the code contained in this notebook can be adapted to applied problems concerning datasets other than the one presented here. -# %% import pandas as pd from jax import jit +from jax.config import config from sklearn.metrics import mean_squared_error, r2_score from sklearn.model_selection import train_test_split from sklearn.preprocessing import StandardScaler import gpjax as gpx +# Enable Float64 for more stable matrix inversions. +config.update("jax_enable_x64", True) key = jr.PRNGKey(123) - +# %% [markdown] +# # UCI Data Benchmarking +# +# In this notebook, we will show how to apply GPJax on a benchmark UCI regression +# problem. These kind of tasks are often used in the research community to benchmark +# and assess new techniques against those already in the literature. Much of the code +# contained in this notebook can be adapted to applied problems concerning datasets +# other than the one presented here. # %% [markdown] # ## Data Loading # -# We'll be using the [Yacht](https://archive.ics.uci.edu/ml/datasets/yacht+hydrodynamics) dataset from the UCI machine learning data repository. Each observation describes the hydrodynamic performance of a yacht through its resistance. The dataset contains 6 covariates and a single positive, real valued response variable. There are 308 observations in the dataset, so we can comfortably use a conjugate regression Gaussian process here (for more more details, checkout the [Regression notebook](https://gpjax.readthedocs.io/en/latest/nbs/regression.html)). +# We'll be using the +# [Yacht](https://archive.ics.uci.edu/ml/datasets/yacht+hydrodynamics) dataset from +# the UCI machine learning data repository. Each observation describes the +# hydrodynamic performance of a yacht through its resistance. The dataset contains 6 +# covariates and a single positive, real valued response variable. There are 308 +# observations in the dataset, so we can comfortably use a conjugate regression +# Gaussian process here (for more more details, checkout the +# [Regression notebook](https://gpjax.readthedocs.io/en/latest/nbs/regression.html)). # %% yacht = pd.read_fwf("data/yacht_hydrodynamics.data", header=None).values[:-1, :] @@ -53,11 +58,16 @@ # %% [markdown] # ## Preprocessing # -# With a dataset loaded, we'll now preprocess it such that it is more amenable to modelling with a Gaussian process. +# With a dataset loaded, we'll now preprocess it such that it is more amenable to +# modelling with a Gaussian process. # # ### Data Partitioning # -# We'll first partition our data into a _training_ and _testing_ split. We'll fit our Gaussian process to the training data and evaluate its performance on the test data. This allows us to investigate how effectively our Gaussian process generalises to out-of-sample datapoints and ensure that we are not overfitting. We'll hold 30% of our data back for testing purposes. +# We'll first partition our data into a _training_ and _testing_ split. We'll fit our +# Gaussian process to the training data and evaluate its performance on the test data. +# This allows us to investigate how effectively our Gaussian process generalises to +# out-of-sample datapoints and ensure that we are not overfitting. We'll hold 30% of +# our data back for testing purposes. # %% Xtr, Xte, ytr, yte = train_test_split(X, y, test_size=0.3, random_state=42) @@ -65,9 +75,19 @@ # %% [markdown] # ### Response Variable # -# We'll now process our response variable $\mathbf{y}$. As the below plots show, the data has a very long tail and is certainly not Gaussian. However, we would like to model a Gaussian response variable so that we can adopt a Gaussian likelihood function and leverage the model's conjugacy. To achieve this, we'll first log-scale the data, to bring the long right tail in closer to the data's mean. We'll then standardise the data such that is distributed according to a unit normal distribution. Both of these transformations are invertible through the log-normal expectation and variance formulae and the the inverse standardisation identity, should we ever need our model's predictions to be back on the scale of the original dataset. +# We'll now process our response variable $\mathbf{y}$. As the below plots show, the +# data has a very long tail and is certainly not Gaussian. However, we would like to +# model a Gaussian response variable so that we can adopt a Gaussian likelihood +# function and leverage the model's conjugacy. To achieve this, we'll first log-scale +# the data, to bring the long right tail in closer to the data's mean. We'll then +# standardise the data such that is distributed according to a unit normal +# distribution. Both of these transformations are invertible through the log-normal +# expectation and variance formulae and the the inverse standardisation identity, +# should we ever need our model's predictions to be back on the scale of the +# original dataset. # -# For transforming both the input and response variable, all transformations will be done with respect to the training data where relevant. +# For transforming both the input and response variable, all transformations will be +# done with respect to the training data where relevant. # %% log_ytr = np.log(ytr) @@ -92,7 +112,8 @@ # %% [markdown] # ### Input Variable # -# We'll now transform our input variable $\mathbf{X}$ to be distributed according to a unit Gaussian. +# We'll now transform our input variable $\mathbf{X}$ to be distributed according to a +# unit Gaussian. # %% x_scaler = StandardScaler().fit(Xtr) @@ -102,18 +123,32 @@ # %% [markdown] # ## Model fitting # -# With data now loaded and preprocessed, we'll proceed to defining a Gaussian process model and optimising its parameters. This notebook purposefully does not go into great detail on this process, so please see notebooks such as the [Regression notebook](https://gpjax.readthedocs.io/en/latest/nbs/regression.html) and [Classification notebook](https://gpjax.readthedocs.io/en/latest/nbs/classification.html) for further information. +# With data now loaded and preprocessed, we'll proceed to defining a Gaussian process +# model and optimising its parameters. This notebook purposefully does not go into +# great detail on this process, so please see notebooks such as the +# [Regression notebook](https://gpjax.readthedocs.io/en/latest/nbs/regression.html) +# and +# [Classification notebook](https://gpjax.readthedocs.io/en/latest/nbs/classification.html) +# for further information. # # ### Model specification # -# We'll use a radial basis function kernel to parameterise the Gaussian process in this notebook. As we have 5 covariates, we'll assign each covariate its own lengthscale parameter. This form of kernel is commonly known as an automatic relevance determination (ARD) kernel. +# We'll use a radial basis function kernel to parameterise the Gaussian process in this +# notebook. As we have 5 covariates, we'll assign each covariate its own lengthscale +# parameter. This form of kernel is commonly known as an automatic relevance +# determination (ARD) kernel. # -# In practice, the exact form of kernel used should be selected such that it represents your understanding of the data. For example, if you were to model temperature; a process that we know to be periodic, then you would likely wish to select a periodic kernel. Having _Gaussian-ised_ our data somewhat, we'll also adopt a Gaussian likelihood function. +# In practice, the exact form of kernel used should be selected such that it +# represents your understanding of the data. For example, if you were to model +# temperature; a process that we know to be periodic, then you would likely wish to +# select a periodic kernel. Having _Gaussian-ised_ our data somewhat, we'll also adopt +# a Gaussian likelihood function. # %% n_train, n_covariates = scaled_Xtr.shape -kernel = jk.RBF(active_dims=list(range(n_covariates))) -prior = gpx.Prior(kernel=kernel) +kernel = gpx.RBF(active_dims=list(range(n_covariates))) +meanf = gpx.mean_functions.Zero() +prior = gpx.Prior(mean_function=meanf, kernel=kernel) likelihood = gpx.Gaussian(num_datapoints=n_train) @@ -122,38 +157,34 @@ # %% [markdown] # ### Model Optimisation # -# With a model now defined, we can proceed to optimise the hyperparameters of our model using Optax. +# With a model now defined, we can proceed to optimise the hyperparameters of our +# model using Optax. # %% -training_data = Dataset(X=scaled_Xtr, y=scaled_ytr) +training_data = gpx.Dataset(X=scaled_Xtr, y=scaled_ytr) -parameter_state = gpx.initialise(posterior, key) -negative_mll = jit( - posterior.marginal_log_likelihood(train_data=training_data, negative=True) -) -optimiser = ox.adam(0.05) +negative_mll = gpx.ConjugateMLL(negative=True) +optimiser = ox.adamw(0.05) -inference_state = gpx.fit( +opt_posterior, history = gpx.fit( + model=posterior, objective=negative_mll, - parameter_state=parameter_state, - optax_optim=optimiser, - num_iters=1000, - log_rate=50, + train_data=training_data, + optim=ox.adamw(learning_rate=0.05), + num_iters=500, ) -learned_params, training_history = inference_state.unpack() - # %% [markdown] # ## Prediction # -# With an optimal set of parameters learned, we can make predictions on the set of data that we held back right at the start. We'll do this in the usual way by first computing the latent function's distribution before computing the predictive posterior distribution. +# With an optimal set of parameters learned, we can make predictions on the set of +# data that we held back right at the start. We'll do this in the usual way by first +# computing the latent function's distribution before computing the predictive +# posterior distribution. # %% -latent_dist = posterior( - learned_params, - training_data, -)(scaled_Xte) -predictive_dist = likelihood(learned_params, latent_dist) +latent_dist = opt_posterior(scaled_Xte, training_data) +predictive_dist = likelihood(latent_dist) predictive_mean = predictive_dist.mean() predictive_stddev = predictive_dist.stddev() @@ -161,11 +192,21 @@ # %% [markdown] # ## Evaluation # -# We'll now show how the performance of our Gaussian process can be evaluated by numerically and visually. +# We'll now show how the performance of our Gaussian process can be evaluated by +# numerically and visually. # # ### Metrics # -# To numerically assess the performance of our model, two commonly used metrics are root mean squared error (RMSE) and the R2 coefficient. RMSE is simply the square root of the squared difference between predictions and actuals. A value of 0 for this metric implies that our model has 0 generalisation error on the test set. R2 measures the amount of variation within the data that is explained by the model. This can be useful when designing variance reduction methods such as control variates as it allows you to understand what proportion of the data's variance will be soaked up. A perfect model here would score 1 for R2 score, whereas predicting the data's mean would score 0 and models doing worse than simple mean predictions can score less than 0. +# To numerically assess the performance of our model, two commonly used metrics are +# root mean squared error (RMSE) and the R2 coefficient. RMSE is simply the square +# root of the squared difference between predictions and actuals. A value of 0 for +# this metric implies that our model has 0 generalisation error on the test set. R2 +# measures the amount of variation within the data that is explained by the model. +# This can be useful when designing variance reduction methods such as control +# variates as it allows you to understand what proportion of the data's variance will +# be soaked up. A perfect model here would score 1 for R2 score, whereas predicting +# the data's mean would score 0 and models doing worse than simple mean predictions +# can score less than 0. # %% rmse = mean_squared_error(y_true=scaled_yte.squeeze(), y_pred=predictive_mean) @@ -173,17 +214,27 @@ print(f"Results:\n\tRMSE: {rmse: .4f}\n\tR2: {r2: .2f}") # %% [markdown] -# Both of these metrics seem very promising, so, based off these, we can be quite happy that our first attempt at modelling the Yacht data is promising. +# Both of these metrics seem very promising, so, based off these, we can be quite +# happy that our first attempt at modelling the Yacht data is promising. # # ### Diagnostic plots # -# To accompany the above metrics, we can also produce residual plots to explore exactly where our model's shortcomings lie. If we define a residual as the true value minus the prediction, then we can produce three plots: +# To accompany the above metrics, we can also produce residual plots to explore +# exactly where our model's shortcomings lie. If we define a residual as the true +# value minus the prediction, then we can produce three plots: # # 1. Predictions vs. actuals. # 2. Predictions vs. residuals. # 3. Residual density. # -# The first plot allows us to explore if our model struggles to predict well for larger or smaller values by observing where the model deviates more from the line $y=x$. In the second plot we can inspect whether or not there were outliers or structure within the errors of our model. A well-performing model would have predictions close to and symmetrically distributed either side of $y=0$. Such a plot can be useful for diagnosing heteroscedasticity. Finally, by plotting a histogram of our residuals we can observe whether or not there is any skew to our residuals. +# The first plot allows us to explore if our model struggles to predict well for +# larger or smaller values by observing where the model deviates more from the line +# $y=x$. In the second plot we can inspect whether or not there were outliers or +# structure within the errors of our model. A well-performing model would have +# predictions close to and symmetrically distributed either side of $y=0$. Such a +# plot can be useful for diagnosing heteroscedasticity. Finally, by plotting a +# histogram of our residuals we can observe whether or not there is any skew to +# our residuals. # %% residuals = scaled_yte.squeeze() - predictive_mean @@ -203,7 +254,12 @@ ax[2].set_title("Residuals") # %% [markdown] -# From this, we can see that our model is struggling to predict the smallest values of the Yacht's hydrodynamic and performs increasingly well as the Yacht's hydrodynamic performance increases. This is likely due to the original data's heavy right-skew, and successive modelling attempts may wish to introduce a heteroscedastic likelihood function that would enable more flexible modelling of the smaller response values. +# From this, we can see that our model is struggling to predict the smallest values +# of the Yacht's hydrodynamic and performs increasingly well as the Yacht's +# hydrodynamic performance increases. This is likely due to the original data's heavy +# right-skew, and successive modelling attempts may wish to introduce a +# heteroscedastic likelihood function that would enable more flexible modelling of +# the smaller response values. # # ## System configuration diff --git a/gpjax/fit.py b/gpjax/fit.py index b4dddda2f..4f23ff2ee 100644 --- a/gpjax/fit.py +++ b/gpjax/fit.py @@ -175,7 +175,7 @@ def _check_model(model: Any) -> None: def _check_objective(objective: Any) -> None: """Check that the objective is of type Objective.""" if not isinstance(objective, AbstractObjective): - raise TypeError("objective must be of type jaxutils.Objective") + raise TypeError(f"objective of type {type(objective)} must be of type jaxutils.Objective.") def _check_train_data(train_data: Any) -> None: diff --git a/gpjax/kernels/approximations/rff.py b/gpjax/kernels/approximations/rff.py index 249fc5c31..0399d2464 100644 --- a/gpjax/kernels/approximations/rff.py +++ b/gpjax/kernels/approximations/rff.py @@ -54,6 +54,7 @@ def __post_init__(self) -> None: self.frequencies = self.base_kernel.spectral_density.sample( seed=self.key, sample_shape=(self.num_basis_fns, n_dims) ) + self.name = f"{self.base_kernel.name} (RFF)" def __call__(self, x: Array, y: Array) -> Array: pass diff --git a/gpjax/kernels/base.py b/gpjax/kernels/base.py index 277d126b1..37a92d73b 100644 --- a/gpjax/kernels/base.py +++ b/gpjax/kernels/base.py @@ -35,6 +35,7 @@ class AbstractKernel(Module): compute_engine: AbstractKernelComputation = static_field(DenseKernelComputation) active_dims: List[int] = static_field(None) + name: str = static_field("AbstractKernel") @property def ndims(self): diff --git a/gpjax/kernels/non_euclidean/graph.py b/gpjax/kernels/non_euclidean/graph.py index 40d09c867..2188ca6d8 100644 --- a/gpjax/kernels/non_euclidean/graph.py +++ b/gpjax/kernels/non_euclidean/graph.py @@ -17,7 +17,7 @@ import jax.numpy as jnp import tensorflow_probability.substrates.jax as tfp -from jaxtyping import Array, Float +from jaxtyping import Array, Float, Int from simple_pytree import static_field from ...base import param_field @@ -32,7 +32,7 @@ ########################################## @dataclass class AbstractGraphKernel: - laplacian: Float[Array, "N N"] + laplacian: Float[Array, "N N"] = static_field() @dataclass @@ -44,20 +44,19 @@ class GraphKernel(AbstractKernel, AbstractGraphKernel): compute_engine """ - lengthscale: Float[Array, "D"] = param_field( - jnp.array([1.0]), bijector=tfb.Softplus - ) - variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus) - smoothness: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus) + lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) + variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) + smoothness: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) eigenvalues: Float[Array, "N"] = static_field(None) eigenvectors: Float[Array, "N N"] = static_field(None) - num_vertex: Float[Array, "1"] = static_field(None) + num_vertex: Int[Array, "1"] = static_field(None) compute_engine: AbstractKernelComputation = static_field(EigenKernelComputation) - + name: str = "Graph Matérn" def __post_init__(self): evals, self.eigenvectors = jnp.linalg.eigh(self.laplacian) self.eigenvalues = evals.reshape(-1, 1) - self.num_vertex = self.eigenvalues.shape[0] + if self.num_vertex is None: + self.num_vertex = self.eigenvalues.shape[0] def __call__( self, diff --git a/gpjax/kernels/nonstationary/linear.py b/gpjax/kernels/nonstationary/linear.py index fafee3bfd..f83ffc7b8 100644 --- a/gpjax/kernels/nonstationary/linear.py +++ b/gpjax/kernels/nonstationary/linear.py @@ -31,6 +31,7 @@ class Linear(AbstractKernel): """The linear kernel.""" variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) + name: str = "Linear" def __call__( self, diff --git a/gpjax/kernels/nonstationary/polynomial.py b/gpjax/kernels/nonstationary/polynomial.py index 3b6219b9b..c5a5027e0 100644 --- a/gpjax/kernels/nonstationary/polynomial.py +++ b/gpjax/kernels/nonstationary/polynomial.py @@ -35,6 +35,9 @@ class Polynomial(AbstractKernel): shift: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) + def __post_init__(self): + self.name = f"Polynomial (degree {self.degree})" + def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "1"]: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with shift parameter :math:`\\alpha` and variance :math:`\\sigma^2` through diff --git a/gpjax/kernels/stationary/matern12.py b/gpjax/kernels/stationary/matern12.py index 9d0c4ce39..6a06dfaf7 100644 --- a/gpjax/kernels/stationary/matern12.py +++ b/gpjax/kernels/stationary/matern12.py @@ -32,6 +32,7 @@ class Matern12(AbstractKernel): lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) + name: str = "Matérn12" def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "1"]: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with diff --git a/gpjax/kernels/stationary/matern32.py b/gpjax/kernels/stationary/matern32.py index 3484f2b70..30bc3e55a 100644 --- a/gpjax/kernels/stationary/matern32.py +++ b/gpjax/kernels/stationary/matern32.py @@ -32,6 +32,7 @@ class Matern32(AbstractKernel): lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) + name: str = "Matérn32" def __call__( self, diff --git a/gpjax/kernels/stationary/matern52.py b/gpjax/kernels/stationary/matern52.py index 1db973047..b5b4e4ef0 100644 --- a/gpjax/kernels/stationary/matern52.py +++ b/gpjax/kernels/stationary/matern52.py @@ -32,6 +32,7 @@ class Matern52(AbstractKernel): lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) + name: str = "Matérn52" def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "1"]: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with diff --git a/gpjax/kernels/stationary/periodic.py b/gpjax/kernels/stationary/periodic.py index c5e35693a..b01db6ae2 100644 --- a/gpjax/kernels/stationary/periodic.py +++ b/gpjax/kernels/stationary/periodic.py @@ -36,6 +36,7 @@ class Periodic(AbstractKernel): lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) period: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) + name: str = "Periodic" def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "1"]: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with length-scale parameter :math:`\\ell` and variance :math:`\\sigma` diff --git a/gpjax/kernels/stationary/powered_exponential.py b/gpjax/kernels/stationary/powered_exponential.py index c2a704e8a..0505d4ef5 100644 --- a/gpjax/kernels/stationary/powered_exponential.py +++ b/gpjax/kernels/stationary/powered_exponential.py @@ -38,6 +38,7 @@ class PoweredExponential(AbstractKernel): lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) power: Float[Array, "1"] = param_field(jnp.array([1.0])) + name: str = "Powered Exponential" def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "1"]: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with length-scale parameter :math:`\\ell`, :math:`\\sigma` and power :math:`\\kappa`. diff --git a/gpjax/kernels/stationary/rational_quadratic.py b/gpjax/kernels/stationary/rational_quadratic.py index f3887c71b..4ff52851d 100644 --- a/gpjax/kernels/stationary/rational_quadratic.py +++ b/gpjax/kernels/stationary/rational_quadratic.py @@ -31,6 +31,7 @@ class RationalQuadratic(AbstractKernel): lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) alpha: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) + name: str = "Rational Quadratic" def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "1"]: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with length-scale parameter :math:`\\ell` and variance :math:`\\sigma` diff --git a/gpjax/kernels/stationary/rbf.py b/gpjax/kernels/stationary/rbf.py index c552b60ed..8ffc1efe7 100644 --- a/gpjax/kernels/stationary/rbf.py +++ b/gpjax/kernels/stationary/rbf.py @@ -33,6 +33,7 @@ class RBF(AbstractKernel): lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) + name: str = "RBF" def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "1"]: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with diff --git a/gpjax/kernels/stationary/white.py b/gpjax/kernels/stationary/white.py index 83144513c..54076aecc 100644 --- a/gpjax/kernels/stationary/white.py +++ b/gpjax/kernels/stationary/white.py @@ -33,6 +33,7 @@ class White(AbstractKernel): compute_engine: AbstractKernelComputation = static_field( ConstantDiagonalKernelComputation ) + name: str = "White" def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "1"]: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with variance :math:`\\sigma` From e7472fa211254c0dcd32984f96c2df581a55f26a Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Thu, 6 Apr 2023 12:14:57 +0100 Subject: [PATCH 36/44] Docs outline --- docs/sharp_bits.md | 17 ++++ examples/baselines.pct.py | 62 ++++++++++++++ examples/haiku.pct.py | 168 ++++++++++++++++++++----------------- examples/pytree.pct.py | 37 ++++++++ examples/regression.pct.py | 26 +++--- 5 files changed, 221 insertions(+), 89 deletions(-) create mode 100644 docs/sharp_bits.md create mode 100644 examples/baselines.pct.py create mode 100644 examples/pytree.pct.py diff --git a/docs/sharp_bits.md b/docs/sharp_bits.md new file mode 100644 index 000000000..bf1573a7c --- /dev/null +++ b/docs/sharp_bits.md @@ -0,0 +1,17 @@ +# 🔪 The sharp bits + +## Pseudo-randomness + +Can briefly acknowledge and then point to the Jax docs for more information. + +## Float64 + +The need for Float64 when inverting the Gram matrix + +## Positive-definiteness + +The need for jitter in the kernel Gram matrix + +## Slow-to-evaluate + +More than several thousand data points will require the use of inducing points - don't try and use the ConjugateMLL objective on a million data points. \ No newline at end of file diff --git a/examples/baselines.pct.py b/examples/baselines.pct.py new file mode 100644 index 000000000..61bc1be6b --- /dev/null +++ b/examples/baselines.pct.py @@ -0,0 +1,62 @@ +# --- +# jupyter: +# jupytext: +# cell_metadata_filter: -all +# custom_cell_magics: kql +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.11.2 +# kernelspec: +# display_name: gpjax +# language: python +# name: python3 +# --- + +# %% +import jax.numpy as jnp +import jax.random as jr +import matplotlib.pyplot as plt +import optax as ox +from jax import jit, grad +from jax.config import config + +import gpjax as gpx + +# Enable Float64 for more stable matrix inversions. +config.update("jax_enable_x64", True) +key = jr.PRNGKey(123) + +# %% +n = 1000 +noise = 0.3 + +key, subkey = jr.split(key) +x = jr.uniform(key=key, minval=-3.0, maxval=3.0, shape=(n,)).reshape(-1, 1) +f = lambda x: jnp.sin(4 * x) + jnp.cos(2 * x) +signal = f(x) +y = signal + jr.normal(subkey, shape=signal.shape) * noise + +D = gpx.Dataset(X=x, y=y) + +# %% +kernel = gpx.kernels.RBF() +meanf = gpx.mean_functions.Constant(constant=0.0) +meanf = meanf.replace_trainable(constant=False) +prior = gpx.Prior(mean_function=meanf, kernel=kernel) +likelihood = gpx.Gaussian(num_datapoints=D.n) + +posterior = prior * likelihood + +negative_mll = gpx.objectives.ConjugateMLL(negative=True) + +# %timeit negative_mll(posterior, train_data=D).block_until_ready() + +# %% +# %timeit jit(negative_mll)(posterior, train_data=D).block_until_ready() + +# %% +# %timeit grad(negative_mll)(posterior, train_data=D) + +# %% diff --git a/examples/haiku.pct.py b/examples/haiku.pct.py index d2f13ab2e..ab72f034a 100644 --- a/examples/haiku.pct.py +++ b/examples/haiku.pct.py @@ -17,7 +17,12 @@ # %% [markdown] # # Deep Kernel Learning # -# In this notebook we demonstrate how GPJax can be used in conjunction with [Haiku](https://github.com/deepmind/dm-haiku) to build deep kernel Gaussian processes. Modelling data with discontinuities is a challenging task for regular Gaussian process models. However, as shown in , transforming the inputs to our Gaussian process model's kernel through a neural network can offer a solution to this. +# In this notebook we demonstrate how GPJax can be used in conjunction with +# [Haiku](https://github.com/deepmind/dm-haiku) to build deep kernel Gaussian +# processes. Modelling data with discontinuities is a challenging task for regular +# Gaussian process models. However, as shown in +# , transforming the inputs to our +# Gaussian process model's kernel through a neural network can offer a solution to this. # %% import typing as tp @@ -33,12 +38,15 @@ from jax.config import config from jaxtyping import Array, Float from scipy.signal import sawtooth +from flax import linen as nn +from simple_pytree import static_field import gpjax as gpx import gpjax.kernels as jk from gpjax.kernels import DenseKernelComputation from gpjax.kernels.base import AbstractKernel from gpjax.kernels.computations import AbstractKernelComputation +from gpjax.base import param_field # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) @@ -47,7 +55,8 @@ # %% [markdown] # ## Dataset # -# As previously mentioned, deep kernels are particularly useful when the data has discontinuities. To highlight this, we will use a sawtooth function as our data. +# As previously mentioned, deep kernels are particularly useful when the data has +# discontinuities. To highlight this, we will use a sawtooth function as our data. # %% n = 500 @@ -74,105 +83,112 @@ # # ### Details # -# Instead of applying a kernel $k(\cdot, \cdot')$ directly on some data, we seek to apply a _feature map_ $\phi(\cdot)$ that projects the data to learn more meaningful representations beforehand. In deep kernel learning, $\phi$ is a neural network whose parameters are learned jointly with the GP model's hyperparameters. The corresponding kernel is then computed by $k(\phi(\cdot), \phi(\cdot'))$. Here $k(\cdot,\cdot')$ is referred to as the _base kernel_. +# Instead of applying a kernel $k(\cdot, \cdot')$ directly on some data, we seek to +# apply a _feature map_ $\phi(\cdot)$ that projects the data to learn more meaningful +# representations beforehand. In deep kernel learning, $\phi$ is a neural network +# whose parameters are learned jointly with the GP model's hyperparameters. The +# corresponding kernel is then computed by $k(\phi(\cdot), \phi(\cdot'))$. Here +# $k(\cdot,\cdot')$ is referred to as the _base kernel_. # # ### Implementation # -# Although deep kernels are not currently supported natively in GPJax, defining one is straightforward as we now demonstrate. Using the base `AbstractKernel` object given in GPJax, we provide a mixin class named `_DeepKernelFunction` to facilitate the user supplying the neural network and base kernel of their choice. Kernel matrices are then computed using the regular `gram` and `cross_covariance` functions. +# Although deep kernels are not currently supported natively in GPJax, defining one is +# straightforward as we now demonstrate. Using the base `AbstractKernel` object given +# in GPJax, we provide a mixin class named `_DeepKernelFunction` to facilitate the +# user supplying the neural network and base kernel of their choice. Kernel matrices +# are then computed using the regular `gram` and `cross_covariance` functions. # %% +import flax + +@dataclass +class _DeepKernelFunction: + network: static_field(hk.Module) + base_kernel: AbstractKernel + dummy_x: jax.Array = static_field(None) + key: jr.PRNGKeyArray = static_field(jr.PRNGKey(123)) + + def __post_init__(self): + self.nn_params = param_field(flax.core.unfreeze(self.network.init(key, self.dummy_x)), bijector=None) + @dataclass -class DeepKernelFunction(AbstractKernel): - def __init__( - self, - network: hk.Module, - base_kernel: AbstractKernel, - compute_engine: AbstractKernelComputation = DenseKernelComputation, - active_dims: tp.Optional[tp.List[int]] = None, - ) -> None: - super().__init__(compute_engine, active_dims, True, False, "Deep Kernel") - self.network = network - self.base_kernel = base_kernel - - def __call__( - self, - params: Dict, - x: Float[Array, "1 D"], - y: Float[Array, "1 D"], - ) -> Float[Array, "1"]: - xt = self.network.apply(params=params, x=x) - yt = self.network.apply(params=params, x=y) - return self.base_kernel(params, xt, yt) - - def initialise(self, dummy_x: Float[Array, "1 D"], key: jr.KeyArray) -> None: - nn_params = self.network.init(rng=key, x=dummy_x) - base_kernel_params = self.base_kernel.init_params(key) - self._params = {**nn_params, **base_kernel_params} - - def init_params(self, key: jr.KeyArray) -> Dict: - return self._params - - # This is depreciated. Can be removed once JaxKern is updated. - def _initialise_params(self, key: jr.KeyArray) -> Dict: - return self.init_params(key) +class DeepKernelFunction(AbstractKernel, _DeepKernelFunction): + def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "1"]: + state = self.network.init(self.key, x) + xt = self.network.apply(state, x) + yt = self.network.apply(state, y) + return self.base_kernel(xt, yt) # %% [markdown] # ### Defining a network # -# With a deep kernel object created, we proceed to define a neural network. Here we consider a small multi-layer perceptron with two linear hidden layers and ReLU activation functions between the layers. The first hidden layer contains 32 units, while the second layer contains 64 units. Finally, we'll make the output of our network a single unit. However, it would be possible to project our data into a $d-$dimensional space for $d>1$. In these instances, making the [base kernel ARD](https://gpjax.readthedocs.io/en/latest/nbs/kernels.html#Active-dimensions) would be sensible. -# Users may wish to design more intricate network structures for more complex tasks, which functionality is supported well in Haiku. +# With a deep kernel object created, we proceed to define a neural network. Here we +# consider a small multi-layer perceptron with two linear hidden layers and ReLU +# activation functions between the layers. The first hidden layer contains 32 units, +# while the second layer contains 64 units. Finally, we'll make the output of our +# network a single unit. However, it would be possible to project our data into a +# $d-$dimensional space for $d>1$. In these instances, making the +# [base kernel ARD](https://gpjax.readthedocs.io/en/latest/nbs/kernels.html#Active-dimensions) +# would be sensible. +# Users may wish to design more intricate network structures for more complex tasks, +# which functionality is supported well in Haiku. # %% -def forward(x): - mlp = hk.Sequential( - [ - hk.Linear(32), - jax.nn.relu, - hk.Linear(64), - jax.nn.relu, - hk.Linear(1), - ] - ) - return mlp(x) - - -forward_linear1 = hk.transform(forward) -forward_linear1 = hk.without_apply_rng(forward_linear1) +class Network(nn.Module): + """A simple MLP.""" + @nn.compact + def __call__(self, x): + x = nn.Dense(features=128)(x) + x = nn.relu(x) + x = nn.Dense(features=64)(x) + x = nn.relu(x) + x = nn.Dense(features=1)(x) + return x + + +forward_linear = Network() +state = jax.jit(forward_linear.init)(key, jnp.ones(x.shape[-1])) # %% [markdown] # ## Defining a model # -# Having characterised the feature extraction network, we move to define a Gaussian process parameterised by this deep kernel. We consider a third-order Matérn base kernel and assume a Gaussian likelihood. Parameters, trainability status and transformations are initialised in the usual manner. +# Having characterised the feature extraction network, we move to define a Gaussian +# process parameterised by this deep kernel. We consider a third-order Matérn base +# kernel and assume a Gaussian likelihood. Parameters, trainability status and +# transformations are initialised in the usual manner. # %% -base_kernel = jk.RBF() -kernel = DeepKernelFunction(network=forward_linear1, base_kernel=base_kernel) -kernel.initialise(x, key) -prior = gpx.Prior(kernel=kernel) +base_kernel = gpx.RBF() +kernel = DeepKernelFunction(network=forward_linear, base_kernel=base_kernel, key=key, dummy_x=x) +meanf = gpx.Zero() +prior = gpx.Prior(mean_function=meanf, kernel=kernel) likelihood = gpx.Gaussian(num_datapoints=D.n) posterior = prior * likelihood # %% [markdown] # ### Optimisation # -# We train our model via maximum likelihood estimation of the marginal log-likelihood. The parameters of our neural network are learned jointly with the model's hyperparameter set. +# We train our model via maximum likelihood estimation of the marginal log-likelihood. +# The parameters of our neural network are learned jointly with the model's +# hyperparameter set. # -# With the inclusion of a neural network, we take this opportunity to highlight the additional benefits gleaned from using [Optax](https://optax.readthedocs.io/en/latest/) for optimisation. In particular, we showcase the ability to use a learning rate scheduler that decays the optimiser's learning rate throughout the inference. We decrease the learning rate according to a half-cosine curve over 1000 iterations, providing us with large step sizes early in the optimisation procedure before approaching more conservative values, ensuring we do not step too far. We also consider a linear warmup, where the learning rate is increased from 0 to 1 over 50 steps to get a reasonable initial learning rate value. +# With the inclusion of a neural network, we take this opportunity to highlight the +# additional benefits gleaned from using +# [Optax](https://optax.readthedocs.io/en/latest/) for optimisation. In particular, we +# showcase the ability to use a learning rate scheduler that decays the optimiser's +# learning rate throughout the inference. We decrease the learning rate according to a +# half-cosine curve over 1000 iterations, providing us with large step sizes early in +# the optimisation procedure before approaching more conservative values, ensuring we +# do not step too far. We also consider a linear warmup, where the learning rate is +# increased from 0 to 1 over 50 steps to get a reasonable initial learning rate value. # %% -parameter_state = gpx.initialise(posterior, key) - -negative_mll = jax.jit(posterior.marginal_log_likelihood(D, negative=True)) +negative_mll = gpx.ConjugateMLL(negative=True) # %% -parameter_state = gpx.initialise(posterior, key) - -negative_mll = jax.jit(posterior.marginal_log_likelihood(D, negative=True)) -negative_mll(parameter_state.params) - schedule = ox.warmup_cosine_decay_schedule( init_value=0.0, peak_value=0.01, @@ -186,19 +202,19 @@ def forward(x): ox.adamw(learning_rate=schedule), ) -inference_state = gpx.fit( - objective=negative_mll, - parameter_state=parameter_state, - optax_optim=optimiser, +opt_posterior, history = gpx.fit( + model=posterior, + objective=gpx.ConjugateMLL(negative=True), + train_data=D, + optim=optimiser, num_iters=2500, ) -learned_params, training_history = inference_state.unpack() - # %% [markdown] # ## Prediction # -# With a set of learned parameters, the only remaining task is to predict the output of the model. We can do this by simply applying the model to a test data set. +# With a set of learned parameters, the only remaining task is to predict the output +# of the model. We can do this by simply applying the model to a test data set. # %% latent_dist = posterior(learned_params, D)(xtest) diff --git a/examples/pytree.pct.py b/examples/pytree.pct.py new file mode 100644 index 000000000..d8d87ea0b --- /dev/null +++ b/examples/pytree.pct.py @@ -0,0 +1,37 @@ +# -*- coding: utf-8 -*- +# --- +# jupyter: +# jupytext: +# custom_cell_magics: kql +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.11.2 +# kernelspec: +# display_name: base +# language: python +# name: python3 +# --- + +# %% [markdown] +# # Guide to PyTrees +# ## Immutable strucutres (Dan) +# +# - Intro to PyTrees +# - Why PyTrees +# - Diagrams labelling the components of a PyTree +# +# ## Operations on PyTrees (Tom) +# +# - Import the RBF kernel +# - Computing gradients +# - Squaring the leaves +# - Adding two PyTrees +# +# ## Writing your own PyTree +# +# - Give an example of creating the `Constant` mean function +# - Demonstrate fixing the mean vs. learning it +# - Transforming the parameter's value +# - Changing the bijection diff --git a/examples/regression.pct.py b/examples/regression.pct.py index 16c34bc60..e5204c7e7 100644 --- a/examples/regression.pct.py +++ b/examples/regression.pct.py @@ -18,7 +18,7 @@ # # In this notebook we demonstate how to fit a Gaussian process regression model. -# %% +# %% vscode={"languageId": "python"} from pprint import PrettyPrinter import jax.numpy as jnp @@ -45,7 +45,7 @@ # # We store our data $\mathcal{D}$ as a GPJax `Dataset` and create test inputs and labels for later. -# %% +# %% vscode={"languageId": "python"} n = 100 noise = 0.3 @@ -64,7 +64,7 @@ # To better understand what we have simulated, we plot both the underlying latent # function and the observed data that is subject to Gaussian noise. -# %% +# %% vscode={"languageId": "python"} fig, ax = plt.subplots(figsize=(10, 5)) ax.plot(xtest, ytest, label="Latent function") ax.plot(x, y, "o", label="Observations") @@ -96,7 +96,7 @@ # we can reciprocate this process in GPJax via defining a `Prior` with our chosen `RBF` # kernel. -# %% +# %% vscode={"languageId": "python"} kernel = gpx.kernels.RBF() meanf = gpx.mean_functions.Constant(constant=0.0) meanf = meanf.replace_trainable(constant=False) @@ -109,7 +109,7 @@ # multivariate Gaussian distribution. Such functionality enables trivial sampling, and # mean and covariance evaluation of the GP. -# %% +# %% vscode={"languageId": "python"} prior_dist = prior.predict(xtest) prior_mean = prior_dist.mean() @@ -138,7 +138,7 @@ # $$p(\mathcal{D} | f(\cdot)) = \mathcal{N}(\boldsymbol{y}; f(\boldsymbol{x}), \textbf{I} \alpha^2).$$ # This is defined in GPJax through calling a `Gaussian` instance. -# %% +# %% vscode={"languageId": "python"} likelihood = gpx.Gaussian(num_datapoints=D.n) # %% [markdown] @@ -148,7 +148,7 @@ # # Mimicking this construct, the posterior is established in GPJax through the `*` operator. -# %% +# %% vscode={"languageId": "python"} posterior = prior * likelihood # %% [markdown] @@ -176,8 +176,8 @@ # in the following cell, we'll demonstrate how the kernel lengthscale can be # initialised to 0.5. -# %% -negative_mll = gpx.objectives.ConjugateMLL(negative=True) +# %% vscode={"languageId": "python"} +negative_mll = jit(gpx.objectives.ConjugateMLL(negative=True)) negative_mll(posterior, train_data=D) # %% [markdown] @@ -189,7 +189,7 @@ # We can now define an optimiser with `optax`. For this example we'll use the `adam` # optimiser. -# %% +# %% vscode={"languageId": "python"} opt_posterior, history = gpx.fit( model=posterior, objective=gpx.ConjugateMLL(negative=True), @@ -213,7 +213,7 @@ # the predictive distribution as a `Distrax` multivariate Gaussian upon which `mean` # and `stddev` can be used to extract the predictive mean and standard deviatation. -# %% +# %% vscode={"languageId": "python"} latent_dist = opt_posterior.predict(xtest, train_data=D) predictive_dist = opt_posterior.likelihood(latent_dist) @@ -225,7 +225,7 @@ # performance at explaining the data $\mathcal{D}$ and recovering the underlying # latent function of interest. -# %% +# %% vscode={"languageId": "python"} fig, ax = plt.subplots(figsize=(12, 5)) ax.plot(x, y, "o", label="Observations", color="tab:red") ax.plot(xtest, predictive_mean, label="Predictive mean", color="tab:blue") @@ -261,6 +261,6 @@ # %% [markdown] # ## System configuration -# %% +# %% vscode={"languageId": "python"} # %reload_ext watermark # %watermark -n -u -v -iv -w -a 'Thomas Pinder & Daniel Dodd' From 1616e1c77124d62258842a9c40833e7e17c1e07b Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Thu, 6 Apr 2023 14:47:10 +0100 Subject: [PATCH 37/44] Push fix. --- examples/haiku.pct.py | 30 ++++++++++++++++++++---------- gpjax/base/module.py | 13 +++++++++++++ 2 files changed, 33 insertions(+), 10 deletions(-) diff --git a/examples/haiku.pct.py b/examples/haiku.pct.py index ab72f034a..5f28fa125 100644 --- a/examples/haiku.pct.py +++ b/examples/haiku.pct.py @@ -29,7 +29,6 @@ from dataclasses import dataclass from typing import Dict -import haiku as hk import jax import jax.numpy as jnp import jax.random as jr @@ -101,19 +100,31 @@ # %% import flax +from dataclasses import field +from typing import Any +from simple_pytree import static_field + @dataclass -class _DeepKernelFunction: - network: static_field(hk.Module) - base_kernel: AbstractKernel +class DeepKernelFunction(AbstractKernel): + base_kernel: AbstractKernel = None + network: nn.Module = static_field(None) dummy_x: jax.Array = static_field(None) key: jr.PRNGKeyArray = static_field(jr.PRNGKey(123)) + nn_params: Any = field(init=False, repr=False) def __post_init__(self): - self.nn_params = param_field(flax.core.unfreeze(self.network.init(key, self.dummy_x)), bijector=None) -@dataclass -class DeepKernelFunction(AbstractKernel, _DeepKernelFunction): + if self.base_kernel is None: + raise ValueError("base_kernel must be specified") + + if self.network is None: + raise ValueError("network must be specified") + + + self.nn_params = flax.core.unfreeze(self.network.init(key, self.dummy_x)) + + def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "1"]: state = self.network.init(self.key, x) xt = self.network.apply(state, x) @@ -141,9 +152,9 @@ class Network(nn.Module): """A simple MLP.""" @nn.compact def __call__(self, x): - x = nn.Dense(features=128)(x) + x = nn.Dense(features=4)(x) x = nn.relu(x) - x = nn.Dense(features=64)(x) + x = nn.Dense(features=2)(x) x = nn.relu(x) x = nn.Dense(features=1)(x) return x @@ -167,7 +178,6 @@ def __call__(self, x): prior = gpx.Prior(mean_function=meanf, kernel=kernel) likelihood = gpx.Gaussian(num_datapoints=D.n) posterior = prior * likelihood - # %% [markdown] # ### Optimisation # diff --git a/gpjax/base/module.py b/gpjax/base/module.py index 34ba57e63..3eba2f5b6 100644 --- a/gpjax/base/module.py +++ b/gpjax/base/module.py @@ -22,6 +22,7 @@ from typing import Any, Callable, Dict, Iterable, List, Tuple import jax +from jax import lax import jax.tree_util as jtu from jax._src.tree_util import _registry from simple_pytree import Pytree, static_field @@ -124,6 +125,10 @@ def constrain(self) -> Self: def _apply_constrain(meta_leaf): meta, leaf = meta_leaf + + if meta is None: + return leaf + return meta.get("bijector", tfb.Identity()).forward(leaf) return meta_map(_apply_constrain, self) @@ -137,6 +142,10 @@ def unconstrain(self) -> Self: def _apply_unconstrain(meta_leaf): meta, leaf = meta_leaf + + if meta is None: + return leaf + return meta.get("bijector", tfb.Identity()).inverse(leaf) return meta_map(_apply_unconstrain, self) @@ -154,6 +163,10 @@ def _stop_grad(leaf: jax.Array, trainable: bool) -> jax.Array: def _apply_stop_grad(meta_leaf): meta, leaf = meta_leaf + + if meta is None: + return leaf + return _stop_grad(leaf, meta.get("trainable", True)) return meta_map(_apply_stop_grad, self) From 2496ec6f5f8154eac2c071148d193ad6b8ab3a02 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Fri, 7 Apr 2023 08:53:11 +0100 Subject: [PATCH 38/44] DKL fixed --- examples/haiku.pct.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/examples/haiku.pct.py b/examples/haiku.pct.py index 5f28fa125..6efc17a23 100644 --- a/examples/haiku.pct.py +++ b/examples/haiku.pct.py @@ -9,7 +9,7 @@ # format_version: '1.3' # jupytext_version: 1.11.2 # kernelspec: -# display_name: Python 3.9.7 ('gpjax') +# display_name: gpjax # language: python # name: python3 # --- @@ -148,20 +148,21 @@ def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, " # %% +feature_space_dim = 3 + class Network(nn.Module): """A simple MLP.""" @nn.compact def __call__(self, x): - x = nn.Dense(features=4)(x) + x = nn.Dense(features=32)(x) x = nn.relu(x) - x = nn.Dense(features=2)(x) + x = nn.Dense(features=64)(x) x = nn.relu(x) - x = nn.Dense(features=1)(x) + x = nn.Dense(features=feature_space_dim)(x) return x forward_linear = Network() -state = jax.jit(forward_linear.init)(key, jnp.ones(x.shape[-1])) # %% [markdown] # ## Defining a model @@ -172,7 +173,7 @@ def __call__(self, x): # transformations are initialised in the usual manner. # %% -base_kernel = gpx.RBF() +base_kernel = gpx.Matern52(active_dims=list(range(feature_space_dim))) kernel = DeepKernelFunction(network=forward_linear, base_kernel=base_kernel, key=key, dummy_x=x) meanf = gpx.Zero() prior = gpx.Prior(mean_function=meanf, kernel=kernel) @@ -202,8 +203,8 @@ def __call__(self, x): schedule = ox.warmup_cosine_decay_schedule( init_value=0.0, peak_value=0.01, - warmup_steps=50, - decay_steps=1_000, + warmup_steps=75, + decay_steps=700, end_value=0.0, ) @@ -217,7 +218,7 @@ def __call__(self, x): objective=gpx.ConjugateMLL(negative=True), train_data=D, optim=optimiser, - num_iters=2500, + num_iters=1000, ) # %% [markdown] @@ -227,8 +228,8 @@ def __call__(self, x): # of the model. We can do this by simply applying the model to a test data set. # %% -latent_dist = posterior(learned_params, D)(xtest) -predictive_dist = likelihood(learned_params, latent_dist) +latent_dist = opt_posterior(xtest, train_data=D) +predictive_dist = opt_posterior.likelihood(latent_dist) predictive_mean = predictive_dist.mean() predictive_std = predictive_dist.stddev() From 8e33bc0ae5f7de9882e3821a1f6efb908d0162c0 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Fri, 7 Apr 2023 09:04:05 +0100 Subject: [PATCH 39/44] Docs up-to-date --- docs/index.md | 13 +- examples/baselines.pct.py | 62 ----- examples/{haiku.pct.py => deep_kernels.py} | 13 +- examples/graph_kernels.pct.py | 19 +- examples/natgrads.pct.py | 213 ----------------- examples/{pytree.pct.py => pytrees.pct.py} | 0 examples/tfp_integration.pct.py | 258 --------------------- examples/uncollapsed_vi.pct.py | 1 - gpjax/kernels/non_euclidean/graph.py | 13 +- 9 files changed, 20 insertions(+), 572 deletions(-) delete mode 100644 examples/baselines.pct.py rename examples/{haiku.pct.py => deep_kernels.py} (99%) delete mode 100644 examples/natgrads.pct.py rename examples/{pytree.pct.py => pytrees.pct.py} (100%) delete mode 100644 examples/tfp_integration.pct.py diff --git a/docs/index.md b/docs/index.md index 01ec15d78..bcab6d302 100644 --- a/docs/index.md +++ b/docs/index.md @@ -60,6 +60,7 @@ installation design contributing examples/intro_to_gps +examples/pytree ``` ```{toctree} @@ -74,8 +75,7 @@ examples/uncollapsed_vi examples/collapsed_vi examples/graph_kernels examples/barycentres -examples/haiku -examples/tfp_integration +examples/deep_kernels ``` ```{toctree} @@ -88,15 +88,6 @@ examples/kernels examples/yacht ``` -```{toctree} ---- -maxdepth: 1 -caption: Experimental -hidden: ---- -examples/natgrads -``` - ```{toctree} --- maxdepth: 1 diff --git a/examples/baselines.pct.py b/examples/baselines.pct.py deleted file mode 100644 index 61bc1be6b..000000000 --- a/examples/baselines.pct.py +++ /dev/null @@ -1,62 +0,0 @@ -# --- -# jupyter: -# jupytext: -# cell_metadata_filter: -all -# custom_cell_magics: kql -# text_representation: -# extension: .py -# format_name: percent -# format_version: '1.3' -# jupytext_version: 1.11.2 -# kernelspec: -# display_name: gpjax -# language: python -# name: python3 -# --- - -# %% -import jax.numpy as jnp -import jax.random as jr -import matplotlib.pyplot as plt -import optax as ox -from jax import jit, grad -from jax.config import config - -import gpjax as gpx - -# Enable Float64 for more stable matrix inversions. -config.update("jax_enable_x64", True) -key = jr.PRNGKey(123) - -# %% -n = 1000 -noise = 0.3 - -key, subkey = jr.split(key) -x = jr.uniform(key=key, minval=-3.0, maxval=3.0, shape=(n,)).reshape(-1, 1) -f = lambda x: jnp.sin(4 * x) + jnp.cos(2 * x) -signal = f(x) -y = signal + jr.normal(subkey, shape=signal.shape) * noise - -D = gpx.Dataset(X=x, y=y) - -# %% -kernel = gpx.kernels.RBF() -meanf = gpx.mean_functions.Constant(constant=0.0) -meanf = meanf.replace_trainable(constant=False) -prior = gpx.Prior(mean_function=meanf, kernel=kernel) -likelihood = gpx.Gaussian(num_datapoints=D.n) - -posterior = prior * likelihood - -negative_mll = gpx.objectives.ConjugateMLL(negative=True) - -# %timeit negative_mll(posterior, train_data=D).block_until_ready() - -# %% -# %timeit jit(negative_mll)(posterior, train_data=D).block_until_ready() - -# %% -# %timeit grad(negative_mll)(posterior, train_data=D) - -# %% diff --git a/examples/haiku.pct.py b/examples/deep_kernels.py similarity index 99% rename from examples/haiku.pct.py rename to examples/deep_kernels.py index 6efc17a23..dac62226f 100644 --- a/examples/haiku.pct.py +++ b/examples/deep_kernels.py @@ -37,7 +37,7 @@ from jax.config import config from jaxtyping import Array, Float from scipy.signal import sawtooth -from flax import linen as nn +from flax import linen as nn from simple_pytree import static_field import gpjax as gpx @@ -103,7 +103,7 @@ from dataclasses import field from typing import Any from simple_pytree import static_field - + @dataclass class DeepKernelFunction(AbstractKernel): @@ -112,19 +112,14 @@ class DeepKernelFunction(AbstractKernel): dummy_x: jax.Array = static_field(None) key: jr.PRNGKeyArray = static_field(jr.PRNGKey(123)) nn_params: Any = field(init=False, repr=False) - - def __post_init__(self): + def __post_init__(self): if self.base_kernel is None: raise ValueError("base_kernel must be specified") - if self.network is None: raise ValueError("network must be specified") - - self.nn_params = flax.core.unfreeze(self.network.init(key, self.dummy_x)) - def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "1"]: state = self.network.init(self.key, x) xt = self.network.apply(state, x) @@ -160,7 +155,7 @@ def __call__(self, x): x = nn.relu(x) x = nn.Dense(features=feature_space_dim)(x) return x - + forward_linear = Network() diff --git a/examples/graph_kernels.pct.py b/examples/graph_kernels.pct.py index e89ee2ddd..913f2dbc8 100644 --- a/examples/graph_kernels.pct.py +++ b/examples/graph_kernels.pct.py @@ -9,7 +9,7 @@ # format_version: '1.3' # jupytext_version: 1.11.2 # kernelspec: -# display_name: Python 3.9.7 ('gpjax') +# display_name: gpjax # language: python # name: python3 # --- @@ -99,8 +99,8 @@ # %% x = jnp.arange(G.number_of_nodes()).reshape(-1, 1) -kernel = gpx.GraphKernel(laplacian=L) -prior = gpx.Prior(mean_function=gpx.Zero(), kernel=kernel) +true_kernel = gpx.GraphKernel(laplacian=L, lengthscale=jnp.array([2.3]), variance=jnp.array([3.2]), smoothness=jnp.array([6.1])) +prior = gpx.Prior(mean_function=gpx.Zero(), kernel=true_kernel) fx = prior(x) y = fx.sample(seed=key).reshape(-1, 1) @@ -134,13 +134,9 @@ # [regression notebook](https://gpjax.readthedocs.io/en/latest/nbs/regression.html). # We do this using the Adam optimiser provided in `optax`. -# %% -from gpjax.base.module import meta_leaves - -meta_leaves(posterior)[1] - # %% likelihood = gpx.Gaussian(num_datapoints=D.n) +prior = gpx.Prior(mean_function=gpx.Zero(), kernel=gpx.GraphKernel(laplacian=L)) posterior = prior * likelihood opt_posterior, training_history = gpx.fit( @@ -161,9 +157,8 @@ # (RMSE) of the model for the initialised parameters vs the optimised set. # %% -initial_params = parameter_state.params -initial_dist = likelihood(initial_params, posterior(initial_params, D)(x)) -predictive_dist = likelihood(learned_params, posterior(learned_params, D)(x)) +initial_dist = likelihood(posterior(x, D)) +predictive_dist = opt_posterior.likelihood(opt_posterior(x, D)) initial_mean = initial_dist.mean() learned_mean = predictive_dist.mean() @@ -204,3 +199,5 @@ # %% # %reload_ext watermark # %watermark -n -u -v -iv -w -a 'Thomas Pinder (edited by Daniel Dodd)' + +# %% diff --git a/examples/natgrads.pct.py b/examples/natgrads.pct.py deleted file mode 100644 index d5df6474e..000000000 --- a/examples/natgrads.pct.py +++ /dev/null @@ -1,213 +0,0 @@ -# -*- coding: utf-8 -*- -# --- -# jupyter: -# jupytext: -# custom_cell_magics: kql -# text_representation: -# extension: .py -# format_name: percent -# format_version: '1.3' -# jupytext_version: 1.11.2 -# kernelspec: -# display_name: base -# language: python -# name: python3 -# --- - -# %% [markdown] -# # Natural Gradients - -# %% [markdown] -# In this notebook, we show how to create natural gradients. Ordinary gradient descent algorithms are an undesirable for variational inference because we are minimising the KL divergence between distributions rather than a set of parameters directly. Natural gradients, on the other hand, accounts for the curvature induced by the KL divergence that has the capacity to considerably improve performance (see e.g., Salimbeni et al. (2018) for further details). - -# %% -import jax.numpy as jnp -import jax.random as jr -import matplotlib.pyplot as plt -import optax as ox -from jax.config import config - -import gpjax as gpx - -# Enable Float64 for more stable matrix inversions. -config.update("jax_enable_x64", True) -key = jr.PRNGKey(123) - -# %% [markdown] -# # Dataset: - -# %% [markdown] -# We simulate a dataset $\mathcal{D} = (\boldsymbol{x}, \boldsymbol{y}) = \{(x_i, y_i)\}_{i=1}^{5000}$ with inputs $\boldsymbol{x}$ sampled uniformly on $(-5, 5)$ and corresponding binary outputs -# -# $$\boldsymbol{y} \sim \mathcal{N} \left(\sin(4 * \boldsymbol{x}) + \sin(2 * \boldsymbol{x}), \textbf{I} * (0.2)^{2} \right).$$ -# -# We store our data $\mathcal{D}$ as a GPJax `Dataset` and create test inputs for later. - -# %% -n = 5000 -noise = 0.2 - -key, subkey = jr.split(key) -x = jr.uniform(key=key, minval=-5.0, maxval=5.0, shape=(n,)).reshape(-1, 1) -f = lambda x: jnp.sin(4 * x) + jnp.cos(2 * x) -signal = f(x) -y = signal + jr.normal(subkey, shape=signal.shape) * noise - -D = gpx.Dataset(X=x, y=y) -xtest = jnp.linspace(-5.5, 5.5, 500).reshape(-1, 1) - -# %% [markdown] -# Intialise inducing points: - -# %% -z = jnp.linspace(-5.0, 5.0, 20).reshape(-1, 1) - -fig, ax = plt.subplots(figsize=(12, 5)) -ax.plot(x, y, "o", alpha=0.3) -ax.plot(xtest, f(xtest)) -[ax.axvline(x=z_i, color="black", alpha=0.3, linewidth=1) for z_i in z] -plt.show() - -# %% [markdown] -# # Natural gradients: - -# %% [markdown] -# We begin by defining our model, variational family and variational inference strategy: - -# %% -likelihood = gpx.Gaussian(num_datapoints=n) -meanf = gpx.mean_functions.Zero() -kernel = gpx.RBF() -prior = gpx.Prior(mean_function=meanf, kernel=kernel) -p = prior * likelihood - - -natural_q = gpx.NaturalVariationalGaussian(posterior=p, inducing_inputs=z) -natural_svgp = gpx.ELBO(negative=True) - -# %% [markdown] -# Next, we can conduct natural gradients as follows: - -# %% -inference_state = gpx.fit_natgrads( - natural_svgp, - parameter_state=parameter_state, - train_data=D, - num_iters=5000, - batch_size=256, - key=jr.PRNGKey(42), - moment_optim=ox.sgd(0.01), - hyper_optim=ox.adam(1e-3), -) - -learned_params, training_history = inference_state.unpack() - -# %% [markdown] -# Here is the fitted model: - -# %% -latent_dist = natural_q(learned_params)(xtest) -predictive_dist = likelihood(learned_params, latent_dist) - -meanf = predictive_dist.mean() -sigma = predictive_dist.stddev() - -fig, ax = plt.subplots(figsize=(12, 5)) -ax.plot(x, y, "o", alpha=0.15, label="Training Data", color="tab:gray") -ax.plot(xtest, meanf, label="Posterior mean", color="tab:blue") -ax.fill_between(xtest.flatten(), meanf - sigma, meanf + sigma, alpha=0.3) -[ - ax.axvline(x=z_i, color="black", alpha=0.3, linewidth=1) - for z_i in learned_params["variational_family"]["inducing_inputs"] -] -plt.show() - -# %% [markdown] -# # Natural gradients and sparse varational Gaussian process regression: - -# %% [markdown] -# As mentioned in Hensman et al. (2013), in the case of a Gaussian likelihood, taking a step of unit length for natural gradients on a full batch of data recovers the same solution as Titsias (2009). We now illustrate this. - -# %% -n = 1000 -noise = 0.2 - -x = jr.uniform(key=key, minval=-5.0, maxval=5.0, shape=(n,)).sort().reshape(-1, 1) -f = lambda x: jnp.sin(4 * x) + jnp.cos(2 * x) -signal = f(x) -y = signal + jr.normal(key, shape=signal.shape) * noise - -D = Dataset(X=x, y=y) - -xtest = jnp.linspace(-5.5, 5.5, 500).reshape(-1, 1) - -# %% -z = jnp.linspace(-5.0, 5.0, 20).reshape(-1, 1) - -fig, ax = plt.subplots(figsize=(12, 5)) -ax.plot(x, y, "o", alpha=0.3) -ax.plot(xtest, f(xtest)) -[ax.axvline(x=z_i, color="black", alpha=0.3, linewidth=1) for z_i in z] -plt.show() - -# %% -likelihood = gpx.Gaussian(num_datapoints=n) -kernel = jk.RBF() -prior = gpx.Prior(kernel=kernel) -p = prior * likelihood - -# %% [markdown] -# We begin with natgrads: - -# %% -from gpjax.natural_gradients import natural_gradients - -q = gpx.NaturalVariationalGaussian(prior=prior, inducing_inputs=z) -svgp = gpx.StochasticVI(posterior=p, variational_family=q) -params, trainables, bijectors = gpx.initialise(svgp).unpack() - -params = gpx.unconstrain(params, bijectors) - -nat_grads_fn, hyper_grads_fn = natural_gradients(svgp, D, bijectors, trainables) - -moment_optim = ox.sgd(1.0) - -moment_state = moment_optim.init(params) - -# Natural gradients update: -loss_val, loss_gradient = nat_grads_fn(params, D) -print(loss_val) - -updates, moment_state = moment_optim.update(loss_gradient, moment_state, params) -params = ox.apply_updates(params, updates) - -loss_val, _ = nat_grads_fn(params, D) - -print(loss_val) - -# %% [markdown] -# Let us now run it for SGPR: - -# %% -q = gpx.CollapsedVariationalGaussian( - prior=prior, likelihood=likelihood, inducing_inputs=z -) -sgpr = gpx.CollapsedVI(posterior=p, variational_family=q) - -params, _, _ = gpx.initialise(svgp).unpack() - -loss_fn = sgpr.elbo(D, negative=True) - -loss_val = loss_fn(params) - -print(loss_val) - -# %% [markdown] -# The discrepancy is due to the quadrature approximation. - -# %% [markdown] -# ## System configuration - -# %% -# %reload_ext watermark -# %watermark -n -u -v -iv -w -a 'Daniel Dodd' diff --git a/examples/pytree.pct.py b/examples/pytrees.pct.py similarity index 100% rename from examples/pytree.pct.py rename to examples/pytrees.pct.py diff --git a/examples/tfp_integration.pct.py b/examples/tfp_integration.pct.py deleted file mode 100644 index 1982a673c..000000000 --- a/examples/tfp_integration.pct.py +++ /dev/null @@ -1,258 +0,0 @@ -# -*- coding: utf-8 -*- -# --- -# jupyter: -# jupytext: -# custom_cell_magics: kql -# text_representation: -# extension: .py -# format_name: percent -# format_version: '1.3' -# jupytext_version: 1.11.2 -# kernelspec: -# display_name: Python 3.10.0 ('base') -# language: python -# name: python3 -# --- - -# %% [markdown] -# # TensorFlow Probability Integration -# This notebook demonstrates how to perform Markov chain Monte Carlo (MCMC) inference for Gaussian process models using TensorFlow Probability Lao et al. (2020). - -# %% -from pprint import PrettyPrinter - -import jax -import jax.numpy as jnp -import jax.random as jr -import matplotlib.pyplot as plt -from jax.config import config -from jaxutils import Dataset - -import gpjax as gpx -import gpjax.kernels as jk -from gpjax.utils import dict_array_coercion - -# Enable Float64 for more stable matrix inversions. -config.update("jax_enable_x64", True) -pp = PrettyPrinter(indent=4) -key = jr.PRNGKey(123) - -# %% [markdown] -# ## Dataset -# -# In this tutorial we'll be trying to model a normalised sinc function -# $$f(x) = \frac{\sin(\pi x)}{\pi x}, \qquad x\in\mathbb{R}\setminus\{0\}, $$ -# -# through observations perturbed by Gaussian noise. We begin by simulating some data below. - -# %% -n = 100 -noise = 0.1 - -key, subkey = jr.split(key) -x = jr.uniform(key, minval=-5.0, maxval=5.0, shape=(n, 1)) -f = lambda x: jnp.sin(jnp.pi * x) / (jnp.pi * x) -y = f(x) + jr.normal(subkey, shape=x.shape) * noise - -fig, ax = plt.subplots(figsize=(12, 6)) -ax.plot(x, f(x), label="Latent fn") -ax.plot(x, y, "o", label="Observations", alpha=0.6) -ax.legend(loc="best") - -# %% [markdown] -# ## Define GPJax objects -# -# We'll wrap our pair of observed data arrays up into a `Dataset` object $\mathcal{D}$ and define a GP posterior. - -# %% -D = Dataset(X=x, y=y) -likelihood = gpx.Gaussian(num_datapoints=D.n) -posterior = gpx.Prior(kernel=jk.RBF()) * likelihood - -# %% [markdown] -# ## Initialise parameters -# -# Since our model hyperparameters are positive, our MCMC sampler will sample on the parameters' unconstrained space and the samples will then be back-transformed onto the original positive real line. GPJax's `initialise` function makes this straightforward. - -# %% -params, _, bijectors = gpx.initialise(posterior, key).unpack() - -# %% [markdown] -# #### Parameter type -# -# MCMC samplers supplied with TensorFlow probability require us to supply our parameters as an array. -# This is at odds with GPJax where our parameters are stored as dictionaries. -# To resolve this, we use the `dict_array_coercion` callable that returns two functions; one that maps from an array to a dictionary and a second that maps back to an array given a dictionary. -# These functions are order preserving. - -# %% -dict_to_array, array_to_dict = dict_array_coercion(params) - -# %% -parray = dict_to_array(params) -print(parray) - -# %% -array_to_dict(parray) == params - -# %% [markdown] -# ### Specifying priors -# -# We can define Gamma priors on our hyperparameters through TensorFlow Probability's `Distributions` module. We transform these to the unconstained space via `tfd.TransformedDistribution`. - -# %% -import tensorflow_probability.substrates.jax as tfp -import tensorflow_probability.substrates.jax.bijectors as tfb - -tfd = tfp.distributions - -priors = gpx.parameters.copy_dict_structure(params) -priors["kernel"]["lengthscale"] = tfd.TransformedDistribution( - tfd.Gamma(concentration=jnp.array(1.0), rate=jnp.array(1.0)), tfb.Softplus() -) -priors["kernel"]["variance"] = tfd.TransformedDistribution( - tfd.Gamma(concentration=jnp.array(1.0), rate=jnp.array(1.0)), tfb.Softplus() -) -priors["likelihood"]["obs_noise"] = tfd.TransformedDistribution( - tfd.Gamma(concentration=jnp.array(1.0), rate=jnp.array(1.0)), tfb.Softplus() -) - -# %% [markdown] -# ### Defining our target function -# -# We now define the marginal likelihood for the target distribution that our MCMC sampler will sample from. For our GP, this is the marginal log-likelihood that we specify below. - -# %% -log_mll = posterior.marginal_log_likelihood(D, negative=False) -log_mll(params) - -# %% [markdown] -# Since our model parameters are now an array, not a dictionary, we must define a function that maps the array back to a dictionary and then evaluates the marginal log-likelihood. Using the second return of `dict_array_coercion` this is straightforward as follows. - -# %% -from gpjax.parameters import evaluate_priors - - -def build_log_pi(log_mll, unconstrained_priors, mapper_fn): - def array_mll(parameter_array): - # Convert parameter array to a dictionary: - params_dict = mapper_fn([jnp.array(i) for i in parameter_array]) - - # Evaluate the log prior, log p(θ): - log_hyper_prior_eval = evaluate_priors(params_dict, unconstrained_priors) - - # Evaluate the log-likelihood probability kernel, log [p(y|f, θ) p(f| θ)]: - log_mll_eval = log_mll(gpx.constrain(params_dict, bijectors)) - - return log_mll_eval + log_hyper_prior_eval - - return array_mll - - -mll_array_form = build_log_pi(log_mll, priors, array_to_dict) - -# %% [markdown] -# ## Sample -# -# We now have all the necessary machinery in place. To sample from our target distribution, we'll use TensorFlow's Hamiltonian Monte-Carlo sampler equipped with the No U-Turn Sampler kernel to draw 500 samples for illustrative purposes (you will likely need more in practice). - -# %% -n_samples = 500 - - -def run_chain(key, state): - kernel = tfp.mcmc.NoUTurnSampler(mll_array_form, 1e-1) - return tfp.mcmc.sample_chain( - n_samples, - current_state=state, - kernel=kernel, - trace_fn=lambda _, results: results.target_log_prob, - seed=key, - ) - - -# %% [markdown] -# Since everything is pure Jax, we are free to JIT compile our sampling function and go. - -# %% -unconstrained_params = gpx.unconstrain(params, bijectors) -states, log_probs = jax.jit(run_chain)( - key, jnp.array(dict_to_array(unconstrained_params)) -) -states, log_probs = jax.jit(run_chain)(key, jnp.array(dict_to_array(params))) - -# %% [markdown] -# ## Inspecting samples -# -# We now assess the quality of our chains. To illustrate the acts of burn-in and thinning, we discard the first 50 samples as burn-in and thin the remaining samples by a factor of 2. - -# %% -burn_in = 50 -thin_factor = 2 -n_params = states.shape[1] - -samples = [states[burn_in:, i, :][::thin_factor] for i in range(n_params)] -sample_dict = array_to_dict(samples) -constrained_samples = gpx.constrain(sample_dict, bijectors) -constrained_sample_list = dict_to_array(constrained_samples) - -# %% [markdown] -# We observe reasonable performance for our chains as shown in the traceplots below. - -# %% -fig, axes = plt.subplots(figsize=(20, 10), ncols=n_params, nrows=2) -titles = ["Lengthscale", "Kernel Variance", "Obs. Noise"] - -for i in range(n_params): - axes[0, i].plot(samples[i], alpha=0.5, color="tab:orange") - axes[1, i].plot(constrained_sample_list[i], alpha=0.5, color="tab:blue") - axes[0, i].axhline(y=jnp.mean(samples[i]), color="tab:orange") - axes[1, i].axhline(y=jnp.mean(constrained_sample_list[i]), color="tab:blue") - axes[0, i].set_title(titles[i]) - axes[1, i].set_title(titles[i]) - -plt.tight_layout() - -# %% [markdown] -# ## Making predictions -# -# We’ll now use our MCMC samples to make predictions. For simplicity, we’ll take the average of the samples to give point estimate parameter values for prediction. However, you may wish to draw from the GP posterior for each sample collected during the MCMC phase. - -# %% -xtest = jnp.linspace(-5.2, 5.2, 500).reshape(-1, 1) -learned_params = array_to_dict([jnp.mean(i) for i in constrained_sample_list]) - -predictive_dist = likelihood(learned_params, posterior(learned_params, D)(xtest)) - -mu = predictive_dist.mean() -sigma = predictive_dist.stddev() - -# %% [markdown] -# Finally, we plot the learned posterior predictive distribution evaluated at the test points defined above. - -# %% -fig, ax = plt.subplots(figsize=(12, 5)) -ax.plot(x, y, "o", label="Obs", color="tab:red") -ax.plot(xtest, mu, label="pred", color="tab:blue") -ax.fill_between( - xtest.squeeze(), - mu.squeeze() - sigma, - mu.squeeze() + sigma, - alpha=0.2, - color="tab:blue", -) -ax.plot(xtest, mu.squeeze() - sigma, color="tab:blue", linestyle="--", linewidth=1) -ax.plot(xtest, mu.squeeze() + sigma, color="tab:blue", linestyle="--", linewidth=1) - -ax.legend() - -# %% [markdown] -# This concludes our tutorial on interfacing TensorFlow Probability with GPJax. -# The workflow demonstrated here only scratches the surface regarding the inference possible with a large number of samplers available in TensorFlow probability. - -# %% [markdown] -# ## System configuration - -# %% -# %load_ext watermark -# %watermark -n -u -v -iv -w -a "Thomas Pinder (edited by Daniel Dodd)" diff --git a/examples/uncollapsed_vi.pct.py b/examples/uncollapsed_vi.pct.py index 63445fdff..ed20e8d9e 100644 --- a/examples/uncollapsed_vi.pct.py +++ b/examples/uncollapsed_vi.pct.py @@ -19,7 +19,6 @@ # # In this notebook we demonstrate how to implement sparse variational Gaussian # processes (SVGPs) of -# Hensman et al. (2013); # Hensman et al. (2015). In # particular, this approximation framework provides a tractable option for working with # non-conjugate Gaussian processes with more than ~5000 data points. However, for diff --git a/gpjax/kernels/non_euclidean/graph.py b/gpjax/kernels/non_euclidean/graph.py index 2188ca6d8..bf89391f8 100644 --- a/gpjax/kernels/non_euclidean/graph.py +++ b/gpjax/kernels/non_euclidean/graph.py @@ -31,19 +31,14 @@ # Graph kernels ########################################## @dataclass -class AbstractGraphKernel: - laplacian: Float[Array, "N N"] = static_field() - - -@dataclass -class GraphKernel(AbstractKernel, AbstractGraphKernel): +class GraphKernel(AbstractKernel): """A Matérn graph kernel defined on the vertices of a graph. The key reference for this object is borovitskiy et. al., (2020). Args: laplacian (Float[Array]): An N x N matrix representing the Laplacian matrix of a graph. compute_engine """ - + laplacian: Float[Array, "N N"] = static_field(None) lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) smoothness: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) @@ -52,7 +47,11 @@ class GraphKernel(AbstractKernel, AbstractGraphKernel): num_vertex: Int[Array, "1"] = static_field(None) compute_engine: AbstractKernelComputation = static_field(EigenKernelComputation) name: str = "Graph Matérn" + def __post_init__(self): + if self.laplacian is None: + raise ValueError("Graph laplacian must be specified") + evals, self.eigenvectors = jnp.linalg.eigh(self.laplacian) self.eigenvalues = evals.reshape(-1, 1) if self.num_vertex is None: From aa571566596a9192bee0f00a94f2dad7a7edfbd0 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Fri, 7 Apr 2023 20:37:48 +0100 Subject: [PATCH 40/44] Add flax to reqs --- docs/requirements.txt | 2 +- requirements/dev.txt | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index 8cc65120f..ce70529ee 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -18,9 +18,9 @@ watermark sphinxext-opengraph blackjax>=0.8.2 jaxopt -dm-haiku ipywidgets pandas scikit-learn +flax # Install GPJax istself . \ No newline at end of file diff --git a/requirements/dev.txt b/requirements/dev.txt index b6081cc7e..495166edb 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -4,4 +4,5 @@ pylint flake8 pytest networkx -pytest-cov \ No newline at end of file +pytest-cov +flax \ No newline at end of file From 3f21cb9081249f2eba056140bfa2a70504747a87 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Fri, 7 Apr 2023 20:39:58 +0100 Subject: [PATCH 41/44] Drop beartype refs --- gpjax/mean_functions.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/gpjax/mean_functions.py b/gpjax/mean_functions.py index dde0302c0..671fc696a 100644 --- a/gpjax/mean_functions.py +++ b/gpjax/mean_functions.py @@ -18,7 +18,7 @@ import abc import dataclasses import jax.numpy as jnp -from beartype.typing import List, Callable, Union +from typing import List, Callable, Union from jaxtyping import Array, Float @@ -42,7 +42,7 @@ def __call__(self, x: Float[Array, "N D"]) -> Float[Array, "N 1"]: Float[Array, "1]: The evaluated mean function. """ raise NotImplementedError - + def __add__(self, other: Union[AbstractMeanFunction, Float[Array, "1"]]) -> AbstractMeanFunction: """Add two mean functions. @@ -55,9 +55,9 @@ def __add__(self, other: Union[AbstractMeanFunction, Float[Array, "1"]]) -> Abst if isinstance(other, AbstractMeanFunction): return SumMeanFunction([self, other]) - + return SumMeanFunction([self, Constant(other)]) - + def __radd__(self, other: Union[AbstractMeanFunction, Float[Array, "1"]]) -> AbstractMeanFunction: """Add two mean functions. @@ -68,27 +68,27 @@ def __radd__(self, other: Union[AbstractMeanFunction, Float[Array, "1"]]) -> Abs AbstractMeanFunction: The sum of the two mean functions. """ return self.__add__(other) - + def __mul__(self, other: Union[AbstractMeanFunction, Float[Array, "1"]]) -> AbstractMeanFunction: """Multiply two mean functions. Args: other (AbstractMeanFunction): The other mean function to multiply. - + Returns: AbstractMeanFunction: The product of the two mean functions. """ if isinstance(other, AbstractMeanFunction): return ProductMeanFunction([self, other]) - + return ProductMeanFunction([self, Constant(other)]) - + def __rmul__(self, other: Union[AbstractMeanFunction, Float[Array, "1"]]) -> AbstractMeanFunction: """Multiply two mean functions. Args: other (AbstractMeanFunction): The other mean function to multiply. - + Returns: AbstractMeanFunction: The product of the two mean functions. """ @@ -113,7 +113,7 @@ def __call__(self, x: Float[Array, "N D"]) -> Float[Array, "N 1"]: Float[Array, "1"]: The evaluated mean function. """ return jnp.ones((x.shape[0], 1)) * self.constant - + @dataclasses.dataclass class CombinationMeanFunction(AbstractMeanFunction): @@ -131,7 +131,7 @@ def __init__( #Add means to a list, flattening out instances of this class therein, as in GPFlow kernels. items_list: List[AbstractMeanFunction] = [] - + for item in means: if not isinstance(item, AbstractMeanFunction): raise TypeError("can only combine AbstractMeanFunction instances") # pragma: no cover @@ -154,7 +154,7 @@ def __call__(self, x: Float[Array, "N D"]) -> Float[Array, "N 1"]: Float[Array, "Q"]: The evaluated mean function. """ return self.operator(jnp.stack([m(x) for m in self.means])) - + SumMeanFunction = partial(CombinationMeanFunction, operator=partial(jnp.sum, axis=0)) ProductMeanFunction = partial(CombinationMeanFunction, operator=partial(jnp.sum, axis=0)) From 070dda3e75300a93d2b235a11903133b50ce25e2 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Fri, 7 Apr 2023 20:59:24 +0100 Subject: [PATCH 42/44] Fix link fn. tests --- gpjax/likelihoods.py | 35 +++++++++++++++++------------------ tests/test_likelihoods.py | 23 ++++++++++++----------- 2 files changed, 29 insertions(+), 29 deletions(-) diff --git a/gpjax/likelihoods.py b/gpjax/likelihoods.py index a2742dad1..44cab1fec 100644 --- a/gpjax/likelihoods.py +++ b/gpjax/likelihoods.py @@ -17,7 +17,6 @@ from typing import Any from .linops.utils import to_dense -import distrax as dx import tensorflow_probability.substrates.jax as tfp import jax.numpy as jnp import jax.scipy as jsp @@ -35,7 +34,7 @@ class AbstractLikelihood(Module): num_datapoints: int = static_field() - def __call__(self, *args: Any, **kwargs: Any) -> dx.Distribution: + def __call__(self, *args: Any, **kwargs: Any) -> tfd.Distribution: """Evaluate the likelihood function at a given predictive distribution. Args: @@ -43,12 +42,12 @@ def __call__(self, *args: Any, **kwargs: Any) -> dx.Distribution: **kwargs (Any): Keyword arguments to be passed to the likelihood's `predict` method. Returns: - dx.Distribution: The predictive distribution. + tfd.Distribution: The predictive distribution. """ return self.predict(*args, **kwargs) @abc.abstractmethod - def predict(self, *args: Any, **kwargs: Any) -> dx.Distribution: + def predict(self, *args: Any, **kwargs: Any) -> tfd.Distribution: """Evaluate the likelihood function at a given predictive distribution. Args: @@ -56,17 +55,17 @@ def predict(self, *args: Any, **kwargs: Any) -> dx.Distribution: **kwargs (Any): Keyword arguments to be passed to the likelihood's `predict` method. Returns: - dx.Distribution: The predictive distribution. + tfd.Distribution: The predictive distribution. """ raise NotImplementedError @property @abc.abstractmethod - def link_function(self) -> dx.Distribution: + def link_function(self) -> tfd.Distribution: """Return the link function of the likelihood function. Returns: - dx.Distribution: The distribution of observations, y, given values of the Gaussian process, f. + tfd.Distribution: The distribution of observations, y, given values of the Gaussian process, f. """ raise NotImplementedError @@ -76,7 +75,7 @@ class Gaussian(AbstractLikelihood): """Gaussian likelihood object.""" obs_noise: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) - def link_function(self, f: Float[Array, "N 1"]) -> dx.Normal: + def link_function(self, f: Float[Array, "N 1"]) -> tfd.Normal: """The link function of the Gaussian likelihood. Args: @@ -84,9 +83,9 @@ def link_function(self, f: Float[Array, "N 1"]) -> dx.Normal: f (Float[Array, "N 1"]): Function values. Returns: - dx.Normal: The likelihood function. + tfd.Normal: The likelihood function. """ - return dx.Normal(loc=f, scale=self.obs_noise) + return tfd.Normal(loc=f, scale=self.obs_noise.astype(f.dtype)) def predict(self, dist: tfd.MultivariateNormalTriL) -> tfd.MultivariateNormalFullCovariance: """ @@ -97,11 +96,11 @@ def predict(self, dist: tfd.MultivariateNormalTriL) -> tfd.MultivariateNormalFul Args: params (Dict): The parameters of the likelihood function. - dist (dx.Distribution): The Gaussian process posterior, + dist (tfd.Distribution): The Gaussian process posterior, evaluated at a finite set of test points. Returns: - dx.Distribution: The predictive distribution. + tfd.Distribution: The predictive distribution. """ n_data = dist.event_shape[0] cov = to_dense(dist.covariance()) @@ -113,28 +112,28 @@ def predict(self, dist: tfd.MultivariateNormalTriL) -> tfd.MultivariateNormalFul @dataclass class Bernoulli(AbstractLikelihood): - def link_function(self, f: Float[Array, "N 1"]) -> dx.Distribution: + def link_function(self, f: Float[Array, "N 1"]) -> tfd.Distribution: """The probit link function of the Bernoulli likelihood. Args: f (Float[Array, "N 1"]): Function values. Returns: - dx.Distribution: The likelihood function. + tfd.Distribution: The likelihood function. """ - return dx.Bernoulli(probs=inv_probit(f)) + return tfd.Bernoulli(probs=inv_probit(f)) - def predict(self, dist: dx.Distribution) -> dx.Distribution: + def predict(self, dist: tfd.Distribution) -> tfd.Distribution: """Evaluate the pointwise predictive distribution, given a Gaussian process posterior and likelihood parameters. Args: params (Dict): The parameters of the likelihood function. - dist (dx.Distribution): The Gaussian process posterior, evaluated + dist (tfd.Distribution): The Gaussian process posterior, evaluated at a finite set of test points. Returns: - dx.Distribution: The pointwise predictive distribution. + tfd.Distribution: The pointwise predictive distribution. """ variance = jnp.diag(dist.covariance()) mean = dist.mean().ravel() diff --git a/tests/test_likelihoods.py b/tests/test_likelihoods.py index 7486f0b15..e4da2154b 100644 --- a/tests/test_likelihoods.py +++ b/tests/test_likelihoods.py @@ -16,7 +16,7 @@ from typing import Callable import jax.tree_util as jtu -import distrax as dx +import tensorflow_probability.substrates.jax as tfp import jax.numpy as jnp import jax.random as jr import numpy as np @@ -32,6 +32,7 @@ inv_probit, ) +tfd = tfp.distributions # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) @@ -44,11 +45,11 @@ def test_abstract_likelihood(): # Create a dummy likelihood class with abstract methods implemented. class DummyLikelihood(AbstractLikelihood): - def predict(self, dist: dx.Distribution) -> dx.Distribution: - return dx.Normal(0.0, 1.0) + def predict(self, dist: tfd.Distribution) -> tfd.Distribution: + return tfd.Normal(0.0, 1.0) def link_function(self, f: Float[Array, "N 1"]) -> Float[Array, "N 1"]: - return dx.MultivariateNormalDiag(loc=f) + return tfd.MultivariateNormalDiag(loc=f) # Test that the dummy likelihood can be instantiated. dummy_likelihood = DummyLikelihood(num_datapoints=123) @@ -60,7 +61,7 @@ def link_function(self, f: Float[Array, "N 1"]) -> Float[Array, "N 1"]: def test_gaussian_init(n: int, noise: float) -> None: likelihood = Gaussian(num_datapoints=n, obs_noise=jnp.array([noise])) - + assert likelihood.obs_noise == jnp.array([noise]) assert likelihood.num_datapoints == n assert jtu.tree_leaves(likelihood) == [jnp.array([noise])] @@ -86,7 +87,7 @@ def test_link_fns(lik: AbstractLikelihood, n: int) -> None: # Test likelihood link function. assert isinstance(likelihood.link_function, Callable) - assert isinstance(likelihood.link_function(f), dx.Distribution) + assert isinstance(likelihood.link_function(f), tfd.Distribution) @pytest.mark.parametrize("noise", [0.1, 0.5, 1.0]) @@ -101,20 +102,20 @@ def test_call_gaussian(noise: float, n: int) -> None: latent_mean = jr.uniform(key, shape=(n,)) latent_sqrt = jr.uniform(key, shape=(n, n)) latent_cov = jnp.matmul(latent_sqrt, latent_sqrt.T) - latent_dist = dx.MultivariateNormalFullCovariance(latent_mean, latent_cov) + latent_dist = tfd.MultivariateNormalFullCovariance(latent_mean, latent_cov) # Test call method. pred_dist = likelihood(latent_dist) # Check that the distribution is a MultivariateNormalFullCovariance. - assert isinstance(pred_dist, dx.MultivariateNormalFullCovariance) + assert isinstance(pred_dist, tfd.MultivariateNormalFullCovariance) # Check predictive mean and variance. assert (pred_dist.mean() == latent_mean).all() noise_matrix = jnp.eye(n) * noise assert np.allclose( - pred_dist.scale_tri, jnp.linalg.cholesky(latent_cov + noise_matrix) + pred_dist.scale_tril, jnp.linalg.cholesky(latent_cov + noise_matrix) ) @@ -129,13 +130,13 @@ def test_call_bernoulli(n: int) -> None: latent_mean = jr.uniform(key, shape=(n,)) latent_sqrt = jr.uniform(key, shape=(n, n)) latent_cov = jnp.matmul(latent_sqrt, latent_sqrt.T) - latent_dist = dx.MultivariateNormalFullCovariance(latent_mean, latent_cov) + latent_dist = tfd.MultivariateNormalFullCovariance(latent_mean, latent_cov) # Test call method. pred_dist = likelihood(latent_dist) # Check that the distribution is a Bernoulli. - assert isinstance(pred_dist, dx.Bernoulli) + assert isinstance(pred_dist, tfd.Bernoulli) # Check predictive mean and variance. From fdff3184772df848eefec8f530b8e681c00e946d Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Fri, 7 Apr 2023 21:16:58 +0100 Subject: [PATCH 43/44] Add flax deps --- .github/workflows/tests.yml | 1 + setup.py | 1 + 2 files changed, 2 insertions(+) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index c14667b14..0722c5563 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -22,6 +22,7 @@ jobs: pip install -e . pip install -e .[dev] pytest --cov=./ --cov-report=xml + - name: Upload coverage to Codecov uses: codecov/codecov-action@v1 with: diff --git a/setup.py b/setup.py index 01808385b..22de32708 100644 --- a/setup.py +++ b/setup.py @@ -33,6 +33,7 @@ def get_versions(): "pytest", "networkx", "pytest-cov", + "flax" ], "cuda": ["jax[cuda]"], } From 38b5bf486bcc74d6ac2576369ef631fe5896ac9f Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Fri, 7 Apr 2023 22:51:12 +0100 Subject: [PATCH 44/44] Documentation text updates --- examples/barycentres.pct.py | 17 ++++---- examples/classification.pct.py | 12 ++---- examples/collapsed_vi.pct.py | 12 +++--- examples/deep_kernels.py | 37 ++++++----------- examples/kernels.pct.py | 75 +++++++++++++++------------------- examples/regression.pct.py | 44 +++++++++----------- examples/uncollapsed_vi.pct.py | 3 -- 7 files changed, 84 insertions(+), 116 deletions(-) diff --git a/examples/barycentres.pct.py b/examples/barycentres.pct.py index 9762cbfca..ade690825 100644 --- a/examples/barycentres.pct.py +++ b/examples/barycentres.pct.py @@ -28,7 +28,7 @@ # significantly more favourable uncertainty estimation. # -# %% +# %% vscode={"languageId": "python"} import typing as tp import jax @@ -99,7 +99,7 @@ # will be a sine function with a different vertical shift, periodicity, and quantity # of noise. -# %% +# %% vscode={"languageId": "python"} n = 100 n_test = 200 n_datasets = 5 @@ -135,7 +135,7 @@ # advice on selecting an appropriate kernel. -# %% +# %% vscode={"languageId": "python"} def fit_gp(x: jax.Array, y: jax.Array) -> tfd.MultivariateNormalFullCovariance: if y.ndim == 1: y = y.reshape(-1, 1) @@ -161,14 +161,15 @@ def fit_gp(x: jax.Array, y: jax.Array) -> tfd.MultivariateNormalFullCovariance: # ## Computing the barycentre # # In GPJax, the predictive distribution of a GP is given by a -# [Distrax](https://github.com/deepmind/distrax) distribution, making it +# [TensorFlow Probability](https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax) +# distribution, making it # straightforward to extract the mean vector and covariance matrix of each GP for # learning a barycentre. We implement the fixed point scheme given in (3) in the # following cell by utilising Jax's `vmap` operator to speed up large matrix operations # using broadcasting in `tensordot`. -# %% +# %% vscode={"languageId": "python"} def sqrtm(A: jax.Array): return jnp.real(jsl.sqrtm(A)) @@ -198,7 +199,7 @@ def step(covariance_candidate: jax.Array, idx: None): # difference between the previous and current iteration that we can confirm by # inspecting the `sequence` array in the following cell. -# %% +# %% vscode={"languageId": "python"} weights = jnp.ones((n_datasets,)) / n_datasets means = jnp.stack([d.mean() for d in posterior_preds]) @@ -222,7 +223,7 @@ def step(covariance_candidate: jax.Array, idx: None): # uncertainty bands are sensible. -# %% +# %% vscode={"languageId": "python"} def plot( dist: tfd.MultivariateNormalTriL, ax, @@ -265,6 +266,6 @@ def plot( # %% [markdown] # ## System configuration -# %% +# %% vscode={"languageId": "python"} # %reload_ext watermark # %watermark -n -u -v -iv -w -a 'Thomas Pinder (edited by Daniel Dodd)' diff --git a/examples/classification.pct.py b/examples/classification.pct.py index 7c1bc9c78..d1b2cb1fa 100644 --- a/examples/classification.pct.py +++ b/examples/classification.pct.py @@ -96,9 +96,7 @@ # marginal log-likelihood. # %% [markdown] -# To begin we obtain an initial parameter state through the `initialise` callable (see -# the [regression notebook](https://gpjax.readthedocs.io/en/latest/nbs/regression.html)). -# We can obtain a MAP estimate by optimising the marginal log-likelihood with +# We can obtain a MAP estimate by optimising the log-posterior density with # Optax's optimisers. # %% @@ -179,7 +177,7 @@ # \log\tilde{p}(\boldsymbol{f}|\mathcal{D}) = \log\tilde{p}(\hat{\boldsymbol{f}}|\mathcal{D}) + \left[\nabla \log\tilde{p}({\boldsymbol{f}}|\mathcal{D})|_{\hat{\boldsymbol{f}}}\right]^{T} (\boldsymbol{f}-\hat{\boldsymbol{f}}) + \frac{1}{2} (\boldsymbol{f}-\hat{\boldsymbol{f}})^{T} \left[\nabla^2 \tilde{p}(\boldsymbol{y}|\boldsymbol{f})|_{\hat{\boldsymbol{f}}} \right] (\boldsymbol{f}-\hat{\boldsymbol{f}}) + \mathcal{O}(\lVert \boldsymbol{f} - \hat{\boldsymbol{f}} \rVert^3). # \end{align} # -# Now since $\nabla \log\tilde{p}({\boldsymbol{f}}|\mathcal{D})$ is zero at the mode, +# Since $\nabla \log\tilde{p}({\boldsymbol{f}}|\mathcal{D})$ is zero at the mode, # this suggests the following approximation # \begin{align} # \tilde{p}(\boldsymbol{f}|\mathcal{D}) \approx \log\tilde{p}(\hat{\boldsymbol{f}}|\mathcal{D}) \exp\left\{ \frac{1}{2} (\boldsymbol{f}-\hat{\boldsymbol{f}})^{T} \left[-\nabla^2 \tilde{p}(\boldsymbol{y}|\boldsymbol{f})|_{\hat{\boldsymbol{f}}} \right] (\boldsymbol{f}-\hat{\boldsymbol{f}}) \right\} @@ -297,7 +295,7 @@ def construct_laplace(test_inputs: Float[Array, "N D"]) -> tfd.MultivariateNorma # %% [markdown] # ## MCMC inference # -# At the high level, an MCMC sampler works by starting at an initial position and +# An MCMC sampler works by starting at an initial position and # drawing a sample from a cheap-to-simulate distribution known as the _proposal_. The # next step is to determine whether this sample could be considered a draw from the # posterior. We accomplish this using an _acceptance probability_ determined via the @@ -314,9 +312,7 @@ def construct_laplace(test_inputs: Float[Array, "N D"]) -> tfd.MultivariateNorma # Rather than implementing a suite of MCMC samplers, GPJax relies on MCMC-specific # libraries for sampling functionality. We focus on # [BlackJax](https://github.com/blackjax-devs/blackjax/) in this notebook, which we -# recommend adopting for general applications. However, we also support TensorFlow -# Probability as demonstrated in the -# [TensorFlow Probability Integration notebook](https://gpjax.readthedocs.io/en/latest/nbs/tfp_integration.html). +# recommend adopting for general applications. # # We'll use the No U-Turn Sampler (NUTS) implementation given in BlackJax for sampling. # For the interested reader, NUTS is a Hamiltonian Monte Carlo sampling scheme where diff --git a/examples/collapsed_vi.pct.py b/examples/collapsed_vi.pct.py index dea8e9a22..a7e0525ff 100644 --- a/examples/collapsed_vi.pct.py +++ b/examples/collapsed_vi.pct.py @@ -83,7 +83,8 @@ plt.show() # %% [markdown] -# Next we define the posterior model for the data. +# Next we define the true posterior model for the data - note that whilst we can define +# this, it is intractable to evaluate. # %% meanf = gpx.Constant() @@ -93,9 +94,10 @@ posterior = prior * likelihood # %% [markdown] -# We now define the SGPR model through `CollapsedVariationalGaussian`. Since the form -# of the collapsed optimal posterior depends on the Gaussian likelihood's observation -# noise, we pass this to the constructer. +# We now define the SGPR model through `CollapsedVariationalGaussian`. Through a +# set of inducing points $\boldsymbol{z}$ this object builds an approximation to the +# true posterior distribution. Consequently, we pass the true posterior and initial +# inducing points into the constructor as arguments. # %% q = gpx.CollapsedVariationalGaussian(posterior=posterior, inducing_inputs=z) @@ -130,8 +132,6 @@ # %% [markdown] # We show predictions of our model with the learned inducing points overlayed in grey. -# %% - # %% latent_dist = opt_posterior(xtest, train_data=D) predictive_dist = opt_posterior.posterior.likelihood(latent_dist) diff --git a/examples/deep_kernels.py b/examples/deep_kernels.py index dac62226f..7beb9fc38 100644 --- a/examples/deep_kernels.py +++ b/examples/deep_kernels.py @@ -26,8 +26,8 @@ # %% import typing as tp -from dataclasses import dataclass -from typing import Dict +from dataclasses import dataclass, field +from typing import Dict, Any import jax import jax.numpy as jnp @@ -39,6 +39,7 @@ from scipy.signal import sawtooth from flax import linen as nn from simple_pytree import static_field +import flax import gpjax as gpx import gpjax.kernels as jk @@ -92,19 +93,12 @@ # ### Implementation # # Although deep kernels are not currently supported natively in GPJax, defining one is -# straightforward as we now demonstrate. Using the base `AbstractKernel` object given -# in GPJax, we provide a mixin class named `_DeepKernelFunction` to facilitate the -# user supplying the neural network and base kernel of their choice. Kernel matrices +# straightforward as we now demonstrate. Inheriting from the base `AbstractKernel` +# in GPJax, we create the `DeepKernelFunction` object that allows the +# user to supply the neural network and base kernel of their choice. Kernel matrices # are then computed using the regular `gram` and `cross_covariance` functions. - # %% -import flax -from dataclasses import field -from typing import Any -from simple_pytree import static_field - - @dataclass class DeepKernelFunction(AbstractKernel): base_kernel: AbstractKernel = None @@ -132,12 +126,11 @@ def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, " # # With a deep kernel object created, we proceed to define a neural network. Here we # consider a small multi-layer perceptron with two linear hidden layers and ReLU -# activation functions between the layers. The first hidden layer contains 32 units, -# while the second layer contains 64 units. Finally, we'll make the output of our -# network a single unit. However, it would be possible to project our data into a -# $d-$dimensional space for $d>1$. In these instances, making the -# [base kernel ARD](https://gpjax.readthedocs.io/en/latest/nbs/kernels.html#Active-dimensions) -# would be sensible. +# activation functions between the layers. The first hidden layer contains 64 units, +# while the second layer contains 32 units. Finally, we'll make the output of our +# network a three units wide. The corresponding kernel that we define will then be of +# [ARD form](https://gpjax.readthedocs.io/en/latest/nbs/kernels.html#Active-dimensions) +# to allow for different lengthscales in each dimension of the feature space. # Users may wish to design more intricate network structures for more complex tasks, # which functionality is supported well in Haiku. @@ -164,8 +157,7 @@ def __call__(self, x): # # Having characterised the feature extraction network, we move to define a Gaussian # process parameterised by this deep kernel. We consider a third-order Matérn base -# kernel and assume a Gaussian likelihood. Parameters, trainability status and -# transformations are initialised in the usual manner. +# kernel and assume a Gaussian likelihood. # %% base_kernel = gpx.Matern52(active_dims=list(range(feature_space_dim))) @@ -186,14 +178,11 @@ def __call__(self, x): # [Optax](https://optax.readthedocs.io/en/latest/) for optimisation. In particular, we # showcase the ability to use a learning rate scheduler that decays the optimiser's # learning rate throughout the inference. We decrease the learning rate according to a -# half-cosine curve over 1000 iterations, providing us with large step sizes early in +# half-cosine curve over 700 iterations, providing us with large step sizes early in # the optimisation procedure before approaching more conservative values, ensuring we # do not step too far. We also consider a linear warmup, where the learning rate is # increased from 0 to 1 over 50 steps to get a reasonable initial learning rate value. -# %% -negative_mll = gpx.ConjugateMLL(negative=True) - # %% schedule = ox.warmup_cosine_decay_schedule( init_value=0.0, diff --git a/examples/kernels.pct.py b/examples/kernels.pct.py index 4d9130fb3..7f883a36a 100644 --- a/examples/kernels.pct.py +++ b/examples/kernels.pct.py @@ -17,15 +17,13 @@ # %% [markdown] # # Kernel Guide # -# In this guide, we introduce the kernels available in GPJax and demonstrate how to create custom ones. -# -# -# from typing import Dict - -from dataclasses import dataclass +# In this guide, we introduce the kernels available in GPJax and demonstrate how to +# create custom kernels. # %% -import distrax as dx +from typing import Dict + +from dataclasses import dataclass import jax.numpy as jnp import jax.random as jr import matplotlib.pyplot as plt @@ -51,6 +49,11 @@ # # * Matérn 1/2, 3/2 and 5/2. # * RBF (or squared exponential). +# * Rational quadratic. +# * Powered exponential. +# * Polynomial. +# * White noise +# * Linear. # * Polynomial. # * [Graph kernels](https://gpjax.readthedocs.io/en/latest/nbs/graph_kernels.html). # @@ -102,7 +105,8 @@ print(f"Lengthscales: {slice_kernel.lengthscale}") # %% [markdown] -# We'll now simulate some data and evaluate the kernel on the previously selected input dimensions. +# We'll now simulate some data and evaluate the kernel on the previously selected +# input dimensions. # %% # Inputs @@ -117,13 +121,13 @@ # # The product or sum of two positive definite matrices yields a positive # definite matrix. Consequently, summing or multiplying sets of kernels is a -# valid operation that can give rich kernel functions. In GPJax, sums of kernels -# can be created by applying the `+` operator as follows. +# valid operation that can give rich kernel functions. In GPJax, functionality for +# a sum kernel is provided by the `SumKernel` class. # %% k1 = gpx.kernels.RBF() k2 = gpx.kernels.Polynomial() -sum_k = gpx.kernels.ProductKernel(kernels=[k1, k2]) +sum_k = gpx.kernels.SumKernel(kernels=[k1, k2]) fig, ax = plt.subplots(ncols=3, figsize=(20, 5)) im0 = ax[0].matshow(k1.gram(x).to_dense()) @@ -135,7 +139,7 @@ fig.colorbar(im2, ax=ax[2]) # %% [markdown] -# Similarily, products of kernels can be created through the `*` operator. +# Similarily, products of kernels can be created through the `ProductKernel` class. # %% k3 = gpx.kernels.Matern32() @@ -153,7 +157,6 @@ fig.colorbar(im2, ax=ax[2]) fig.colorbar(im3, ax=ax[3]) - # %% [markdown] # ## Custom kernel # @@ -171,7 +174,8 @@ # ### Circular kernel # # When the underlying space is polar, typical Euclidean kernels such as Matérn -# kernels are insufficient at the boundary as discontinuities will be present. +# kernels are insufficient at the boundary where discontinuities will present +# themselves. # This is due to the fact that for a polar space $\lvert 0, 2\pi\rvert=0$ i.e., # the space wraps. Euclidean kernels have no mechanism in them to represent this # logic and will instead treat $0$ and $2\pi$ and elements far apart. Circular @@ -198,13 +202,10 @@ def angular_distance(x, y, c): @dataclass -class _Polar: +class Polar(gpx.kernels.AbstractKernel): period: float = static_field(2 * jnp.pi) tau: float = param_field(jnp.array([4.0]), bijector=tfb.Softplus(low=4.0)) - -@dataclass -class Polar(gpx.kernels.AbstractKernel, _Polar): def __post_init__(self): self.c = self.period / 2.0 @@ -219,35 +220,25 @@ def __call__( # %% [markdown] -# We unpack this now to make better sense of it. In the kernel's `__init__` -# function we simply specify the length of a single period. As the underlying -# domain is a circle, this is $2\pi$. Next we define the kernel's `__call__` -# function which is a direct implementation of Equation (1). Finally, we define -# the Kernel's parameter property which contains just one value $\tau$ that we -# initialise to 4 in the kernel's `__init__`. -# -# -# ### Custom Parameter Bijection -# -# The constraint on $\tau$ makes optimisation challenging with gradient descent. -# It would be much easier if we could instead parameterise $\tau$ to be on the -# real line. Fortunately, this can be taken care of with GPJax's `add parameter` -# function, only requiring us to define the parameter's name and matching -# bijection (either a Distrax of TensorFlow probability bijector). Under the -# hood, calling this function updates a configuration object to register this -# parameter and its corresponding transform. +# We unpack this now to make better sense of it. In the kernel's initialiser +# we specify the length of a single period. As the underlying +# domain is a circle, this is $2\pi$. Next, we define +# the Kernel's half-period parameter. As the kernel is a `dataclass` and `c` is +# function of `period`, we must define it in the `__post_init__` method. +# Finally, we define the kernel's `__call__` +# function which is a direct implementation of Equation (1). # -# To define a bijector here we'll make use of the `Lambda` operator given in -# Distrax. This lets us convert any regular Jax function into a bijection. Given -# that we require $\tau$ to be strictly greater than $4.$, we'll apply a -# [softplus -# transformation](https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.softplus.html) -# where the lower bound is shifted by $4$. +# To constrain $\tau$ to be greater than 4, we use a `Softplus` bijector with a +# clipped lower bound of 4.0. This is done by specifying the `bijector` argument +# when we define the parameter field. # %% [markdown] # ### Using our polar kernel # -# We proceed to fit a GP with our custom circular kernel to a random sequence of points on a circle (see the [Regression notebook](https://gpjax.readthedocs.io/en/latest/nbs/regression.html) for further details on this process). +# We proceed to fit a GP with our custom circular kernel to a random sequence of +# points on a circle (see the +# [Regression notebook](https://gpjax.readthedocs.io/en/latest/nbs/regression.html) +# for further details on this process). # %% # Simulate data diff --git a/examples/regression.pct.py b/examples/regression.pct.py index e5204c7e7..3682173c9 100644 --- a/examples/regression.pct.py +++ b/examples/regression.pct.py @@ -43,7 +43,8 @@ # # $$\boldsymbol{y} \sim \mathcal{N} \left(\sin(4\boldsymbol{x}) + \cos(2 \boldsymbol{x}), \textbf{I} * 0.3^2 \right).$$ # -# We store our data $\mathcal{D}$ as a GPJax `Dataset` and create test inputs and labels for later. +# We store our data $\mathcal{D}$ as a GPJax `Dataset` and create test inputs and labels +# for later. # %% vscode={"languageId": "python"} n = 100 @@ -75,7 +76,6 @@ # observations $\mathcal{D}$ via Gaussian process regression. We begin by defining a # Gaussian process prior in the next section. -# %% [markdown] # ## Defining the prior # # A zero-mean Gaussian process (GP) places a prior distribution over real-valued @@ -98,16 +98,16 @@ # %% vscode={"languageId": "python"} kernel = gpx.kernels.RBF() -meanf = gpx.mean_functions.Constant(constant=0.0) -meanf = meanf.replace_trainable(constant=False) +meanf = gpx.mean_functions.Zero() prior = gpx.Prior(mean_function=meanf, kernel=kernel) # %% [markdown] # # The above construction forms the foundation for GPJax's models. Moreover, the GP prior -# we have just defined can be represented by a [Distrax](https://github.com/deepmind/distrax) +# we have just defined can be represented by a +# [TensorFlow Probability](https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax) # multivariate Gaussian distribution. Such functionality enables trivial sampling, and -# mean and covariance evaluation of the GP. +# the evaluation of the GP's mean and covariance . # %% vscode={"languageId": "python"} prior_dist = prior.predict(xtest) @@ -160,23 +160,15 @@ # # ## Parameter state # -# So far, all of the objects that we've defined have been stateless. To give our model -# state, we can use the `initialise` function provided in GPJax. Upon calling this, a -# `ParameterState` class is returned that contains four dictionaries: -# -# | Dictionary | Description | -# |---|---| -# | `params` | Initial parameter values. | -# | `trainable` | Boolean dictionary that determines the training status of parameters (`True` for being trained and `False` otherwise). | -# | `bijectors` | Bijectors that can map parameters between the _unconstrained space_ and their original _constrained space_. | -# -# Further, upon calling `initialise`, we can state specific initial values for some, or -# all, of the parameters within our model. By default, the kernel lengthscale and -# variance and the likelihood's variance parameter are all initialised to 1. However, -# in the following cell, we'll demonstrate how the kernel lengthscale can be -# initialised to 0.5. +# As outlined in the [PyTrees](https://jax.readthedocs.io/en/latest/pytrees.html) +# documentation, parameters are contained within the model and for the leaves of the +# PyTree. Consequently, in this particular model, we have three parameters: the +# kernel lengthscale, kernel variance and the observation noise variance. Whilst +# we have initialised each of these to 1, we can learn Type 2 MLEs for each of +# these parameters by optimising the marginal log-likelihood (MLL). # %% vscode={"languageId": "python"} +# TODO: drop this once `step` is implemented into `Objectives` negative_mll = jit(gpx.objectives.ConjugateMLL(negative=True)) negative_mll(posterior, train_data=D) @@ -199,10 +191,12 @@ ) # %% [markdown] -# Similar to the `ParameterState` object above, the returned variable from the `fit` -# function is a class, namely an `InferenceState` object that contains the parameters' -# final values and a tracked array of the evaluation of our objective function -# throughout optimisation. +# The calling of `fit` returns two objects: the optimised posterior and a history of +# training losses. We can plot the training loss to see how the optimisation has +# progressed. + +# %% vscode={"languageId": "python"} +plt.plot(history) # %% [markdown] # ## Prediction diff --git a/examples/uncollapsed_vi.pct.py b/examples/uncollapsed_vi.pct.py index ed20e8d9e..32f48e91e 100644 --- a/examples/uncollapsed_vi.pct.py +++ b/examples/uncollapsed_vi.pct.py @@ -110,14 +110,11 @@ # We show a cost comparison between the approaches below, where $b$ is the mini-batch # size. # -# -# # | | GPs | sparse GPs | SVGP | # | -- | -- | -- | -- | # | Inference cost | $\mathcal{O}(n^3)$ | $\mathcal{O}(n m^2)$ | $\mathcal{O}(b m^2 + m^3)$ | # | Memory cost | $\mathcal{O}(n^2)$ | $\mathcal{O}(n m)$ | $\mathcal{O}(b m + m^2)$ | # -# # To apply SVGP inference to our dataset, we begin by initialising $m = 50$ equally # spaced inducing inputs $\boldsymbol{z}$ across our observed data's support. These # are depicted below via horizontal black lines.