From 6d8ace97c0819d5d127f7491a0291b6c1c743648 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Sun, 30 Apr 2023 20:11:18 +0100 Subject: [PATCH 1/6] Add static_field to gpjax.base --- gpjax/base/__init__.py | 2 ++ gpjax/base/module.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/gpjax/base/__init__.py b/gpjax/base/__init__.py index b967bfc5..96551311 100644 --- a/gpjax/base/__init__.py +++ b/gpjax/base/__init__.py @@ -20,6 +20,7 @@ meta_flatten, meta_leaves, meta_map, + static_field, save_tree, ) from gpjax.base.param import param_field @@ -31,6 +32,7 @@ "meta_map", "meta", "param_field", + "static_field", "save_tree", "load_tree", ] diff --git a/gpjax/base/module.py b/gpjax/base/module.py index d6763a8d..c61e46b9 100644 --- a/gpjax/base/module.py +++ b/gpjax/base/module.py @@ -14,7 +14,7 @@ # ============================================================================== -__all__ = ["Module", "meta_leaves", "meta_flatten", "meta_map", "meta"] +__all__ = ["Module", "meta_leaves", "meta_flatten", "meta_map", "meta", "static_field"] from copy import ( copy, From acfc637613d32803fcde06b8fa28490a6c565728 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Sun, 30 Apr 2023 20:18:38 +0100 Subject: [PATCH 2/6] Update imports. --- examples/deep_kernels.pct.py | 2 +- examples/kernels.pct.py | 2 +- gpjax/gps.py | 2 +- gpjax/kernels/approximations/rff.py | 2 +- gpjax/kernels/base.py | 2 +- gpjax/kernels/non_euclidean/graph.py | 2 +- gpjax/kernels/nonstationary/arccosine.py | 2 +- gpjax/kernels/nonstationary/polynomial.py | 2 +- gpjax/kernels/stationary/white.py | 2 +- gpjax/likelihoods.py | 2 +- gpjax/linops/constant_diagonal_linear_operator.py | 2 +- gpjax/mean_functions.py | 2 +- gpjax/objectives.py | 2 +- gpjax/variational_families.py | 2 +- tests/test_base/test_module.py | 2 +- tests/test_linops/test_linear_operator.py | 2 +- 16 files changed, 16 insertions(+), 16 deletions(-) diff --git a/examples/deep_kernels.pct.py b/examples/deep_kernels.pct.py index 7d20a427..d3b739c1 100644 --- a/examples/deep_kernels.pct.py +++ b/examples/deep_kernels.pct.py @@ -22,7 +22,7 @@ from jax.config import config from jaxtyping import Array, Float from scipy.signal import sawtooth -from simple_pytree import static_field +from gpjax.base import static_field from jaxtyping import install_import_hook diff --git a/examples/kernels.pct.py b/examples/kernels.pct.py index c8b3fab1..851ec12a 100644 --- a/examples/kernels.pct.py +++ b/examples/kernels.pct.py @@ -17,7 +17,7 @@ from jax import jit from jax.config import config from jaxtyping import Array, Float -from simple_pytree import static_field +from gpjax.base import static_field from jaxtyping import install_import_hook diff --git a/gpjax/gps.py b/gpjax/gps.py index b0cc3a05..8b96cce2 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -30,7 +30,7 @@ Float, Num, ) -from simple_pytree import static_field +from gpjax.base import static_field from gpjax.base import ( Module, diff --git a/gpjax/kernels/approximations/rff.py b/gpjax/kernels/approximations/rff.py index f0a0edfb..00e7551b 100644 --- a/gpjax/kernels/approximations/rff.py +++ b/gpjax/kernels/approximations/rff.py @@ -2,7 +2,7 @@ from jax.random import PRNGKey from jaxtyping import Float -from simple_pytree import static_field +from gpjax.base import static_field import tensorflow_probability.substrates.jax.bijectors as tfb from gpjax.base import param_field diff --git a/gpjax/kernels/base.py b/gpjax/kernels/base.py index f874dc65..56b7095d 100644 --- a/gpjax/kernels/base.py +++ b/gpjax/kernels/base.py @@ -34,7 +34,7 @@ Float, Num, ) -from simple_pytree import static_field +from gpjax.base import static_field import tensorflow_probability.substrates.jax.distributions as tfd from gpjax.base import ( diff --git a/gpjax/kernels/non_euclidean/graph.py b/gpjax/kernels/non_euclidean/graph.py index 355d7879..02cfd8f2 100644 --- a/gpjax/kernels/non_euclidean/graph.py +++ b/gpjax/kernels/non_euclidean/graph.py @@ -22,7 +22,7 @@ Int, Num, ) -from simple_pytree import static_field +from gpjax.base import static_field import tensorflow_probability.substrates.jax as tfp from gpjax.base import param_field diff --git a/gpjax/kernels/nonstationary/arccosine.py b/gpjax/kernels/nonstationary/arccosine.py index fefbc996..f05b0bbf 100644 --- a/gpjax/kernels/nonstationary/arccosine.py +++ b/gpjax/kernels/nonstationary/arccosine.py @@ -18,7 +18,7 @@ from beartype.typing import Union import jax.numpy as jnp from jaxtyping import Float -from simple_pytree import static_field +from gpjax.base import static_field import tensorflow_probability.substrates.jax.bijectors as tfb from gpjax.base import param_field diff --git a/gpjax/kernels/nonstationary/polynomial.py b/gpjax/kernels/nonstationary/polynomial.py index ed2269af..9cd42eef 100644 --- a/gpjax/kernels/nonstationary/polynomial.py +++ b/gpjax/kernels/nonstationary/polynomial.py @@ -17,7 +17,7 @@ import jax.numpy as jnp from jaxtyping import Float -from simple_pytree import static_field +from gpjax.base import static_field import tensorflow_probability.substrates.jax.bijectors as tfb from gpjax.base import param_field diff --git a/gpjax/kernels/stationary/white.py b/gpjax/kernels/stationary/white.py index 37e70305..336b63b9 100644 --- a/gpjax/kernels/stationary/white.py +++ b/gpjax/kernels/stationary/white.py @@ -17,7 +17,7 @@ import jax.numpy as jnp from jaxtyping import Float -from simple_pytree import static_field +from gpjax.base import static_field import tensorflow_probability.substrates.jax.bijectors as tfb import tensorflow_probability.substrates.jax.distributions as tfd diff --git a/gpjax/likelihoods.py b/gpjax/likelihoods.py index 51e88308..b80c3b2f 100644 --- a/gpjax/likelihoods.py +++ b/gpjax/likelihoods.py @@ -23,7 +23,7 @@ import jax.numpy as jnp import jax.scipy as jsp from jaxtyping import Float -from simple_pytree import static_field +from gpjax.base import static_field import tensorflow_probability.substrates.jax as tfp from gpjax.base import ( diff --git a/gpjax/linops/constant_diagonal_linear_operator.py b/gpjax/linops/constant_diagonal_linear_operator.py index c388a730..4d77d962 100644 --- a/gpjax/linops/constant_diagonal_linear_operator.py +++ b/gpjax/linops/constant_diagonal_linear_operator.py @@ -22,7 +22,7 @@ ) import jax.numpy as jnp from jaxtyping import Float -from simple_pytree import static_field +from gpjax.base import static_field from gpjax.linops.diagonal_linear_operator import DiagonalLinearOperator from gpjax.linops.linear_operator import LinearOperator diff --git a/gpjax/mean_functions.py b/gpjax/mean_functions.py index a86efdb8..6200e7f0 100644 --- a/gpjax/mean_functions.py +++ b/gpjax/mean_functions.py @@ -28,7 +28,7 @@ Float, Num, ) -from simple_pytree import static_field +from gpjax.base import static_field from gpjax.base import ( Module, diff --git a/gpjax/objectives.py b/gpjax/objectives.py index a787efaf..6817c06e 100644 --- a/gpjax/objectives.py +++ b/gpjax/objectives.py @@ -6,7 +6,7 @@ import jax.scipy as jsp import jax.tree_util as jtu from jaxtyping import Float -from simple_pytree import static_field +from gpjax.base import static_field import tensorflow_probability.substrates.jax as tfp from gpjax.base import Module diff --git a/gpjax/variational_families.py b/gpjax/variational_families.py index 3d3f72b0..671349a8 100644 --- a/gpjax/variational_families.py +++ b/gpjax/variational_families.py @@ -20,7 +20,7 @@ import jax.numpy as jnp import jax.scipy as jsp from jaxtyping import Float -from simple_pytree import static_field +from gpjax.base import static_field import tensorflow_probability.substrates.jax.bijectors as tfb from gpjax.base import ( diff --git a/tests/test_base/test_module.py b/tests/test_base/test_module.py index fb25e470..71dbdd74 100644 --- a/tests/test_base/test_module.py +++ b/tests/test_base/test_module.py @@ -32,13 +32,13 @@ import pytest from simple_pytree import ( Pytree, - static_field, ) import tensorflow_probability.substrates.jax.bijectors as tfb from gpjax.base.module import ( Module, meta, + static_field, ) from gpjax.base.param import param_field diff --git a/tests/test_linops/test_linear_operator.py b/tests/test_linops/test_linear_operator.py index 8b7987b4..2548ce5a 100644 --- a/tests/test_linops/test_linear_operator.py +++ b/tests/test_linops/test_linear_operator.py @@ -21,7 +21,7 @@ import jax.numpy as jnp import jax.tree_util as jtu import pytest -from simple_pytree import static_field +from gpjax.base import static_field from gpjax.linops.linear_operator import LinearOperator From 583fd005b25a4c439c09b3b7c49d774300951215 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Sun, 30 Apr 2023 22:05:14 +0100 Subject: [PATCH 3/6] Add static_field to base --- gpjax/base/module.py | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/gpjax/base/module.py b/gpjax/base/module.py index c61e46b9..a0102373 100644 --- a/gpjax/base/module.py +++ b/gpjax/base/module.py @@ -33,6 +33,7 @@ Tuple, TypeVar, Union, + Mapping, ) import jax from jax import lax @@ -48,13 +49,41 @@ ) from simple_pytree import ( Pytree, - static_field, ) import tensorflow_probability.substrates.jax.bijectors as tfb Self = TypeVar("Self") +def static_field( + default: Any = dataclasses.MISSING, + *, + default_factory: Any = dataclasses.MISSING, + init: bool = True, + repr: bool = True, + hash: Optional[bool] = None, + compare: bool = True, + metadata: Optional[Mapping[str, Any]] = None, +): + + metadata = {} if metadata is None else dict(metadata) + + if "pytree_node" in metadata: + raise ValueError("Cannot use metadata with `pytree_node` already set.") + + metadata["pytree_node"] = False + + return dataclasses.field( + default=default, + default_factory=default_factory, + init=init, + repr=repr, + hash=hash, + compare=compare, + metadata=metadata, + ) + + class Module(Pytree): _pytree__meta: Dict[str, Any] = static_field() From c114035613212f8716f4170fe83547afcfb67e51 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Sun, 30 Apr 2023 22:13:10 +0100 Subject: [PATCH 4/6] Add fix for python 3.11 --- gpjax/base/module.py | 19 ++++++++++++------- gpjax/base/param.py | 19 ++++++++++++------- 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/gpjax/base/module.py b/gpjax/base/module.py index a0102373..b3cc66d8 100644 --- a/gpjax/base/module.py +++ b/gpjax/base/module.py @@ -73,14 +73,19 @@ def static_field( metadata["pytree_node"] = False + if default is not dataclasses.MISSING and default_factory is not dataclasses.MISSING: + raise ValueError("Cannot specify both default and default_factory.") + + if default is not dataclasses.MISSING: + default_factory = lambda: default + return dataclasses.field( - default=default, - default_factory=default_factory, - init=init, - repr=repr, - hash=hash, - compare=compare, - metadata=metadata, + default_factory=default_factory, + init=init, + repr=repr, + hash=hash, + compare=compare, + metadata=metadata, ) diff --git a/gpjax/base/param.py b/gpjax/base/param.py index dead54be..4ced1262 100644 --- a/gpjax/base/param.py +++ b/gpjax/base/param.py @@ -53,12 +53,17 @@ def param_field( metadata["trainable"] = trainable metadata["pytree_node"] = True + if default is not dataclasses.MISSING and default_factory is not dataclasses.MISSING: + raise ValueError("Cannot specify both default and default_factory.") + + if default is not dataclasses.MISSING: + default_factory = lambda: default + return dataclasses.field( - default=default, - default_factory=default_factory, - init=init, - repr=repr, - hash=hash, - compare=compare, - metadata=metadata, + default_factory=default_factory, + init=init, + repr=repr, + hash=hash, + compare=compare, + metadata=metadata, ) From 3e8321606ca10bedcc5374148964a319b111eb25 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Sun, 30 Apr 2023 22:14:07 +0100 Subject: [PATCH 5/6] Update tests to run on 3.11 --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 307e1041..ea237433 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -12,7 +12,7 @@ jobs: matrix: # Select the Python versions to test against os: ["ubuntu-latest", "macos-latest"] - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] fail-fast: true steps: - name: Check out the code From 0e7be970a533bff3e0342afbe63a443748069530 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Sun, 30 Apr 2023 22:25:04 +0100 Subject: [PATCH 6/6] Update test_base.py --- tests/test_kernels/test_base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_kernels/test_base.py b/tests/test_kernels/test_base.py index 91ddc9ad..06d1228e 100644 --- a/tests/test_kernels/test_base.py +++ b/tests/test_kernels/test_base.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== -from dataclasses import dataclass +from dataclasses import dataclass, field from jax.config import config import jax.numpy as jnp @@ -55,7 +55,7 @@ def test_abstract_kernel(): # Create a dummy kernel class with __call__ implemented: @dataclass class DummyKernel(AbstractKernel): - test_a: Float[Array, "1"] = jnp.array([1.0]) + test_a: Float[Array, "1"] = field(default_factory = lambda: jnp.array([1.0])) test_b: Float[Array, "1"] = param_field( jnp.array([2.0]), bijector=tfb.Softplus() )