Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add static field to base, update tests to run on 3.11, fix 3.11 compatibility #246

Merged
merged 6 commits into from
Apr 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
matrix:
# Select the Python versions to test against
os: ["ubuntu-latest", "macos-latest"]
python-version: ["3.8", "3.9", "3.10"]
python-version: ["3.8", "3.9", "3.10", "3.11"]
fail-fast: true
steps:
- name: Check out the code
Expand Down
2 changes: 1 addition & 1 deletion examples/deep_kernels.pct.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from jax.config import config
from jaxtyping import Array, Float
from scipy.signal import sawtooth
from simple_pytree import static_field
from gpjax.base import static_field

from jaxtyping import install_import_hook

Expand Down
2 changes: 1 addition & 1 deletion examples/kernels.pct.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from jax import jit
from jax.config import config
from jaxtyping import Array, Float
from simple_pytree import static_field
from gpjax.base import static_field

from jaxtyping import install_import_hook

Expand Down
2 changes: 2 additions & 0 deletions gpjax/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
meta_flatten,
meta_leaves,
meta_map,
static_field,
save_tree,
)
from gpjax.base.param import param_field
Expand All @@ -31,6 +32,7 @@
"meta_map",
"meta",
"param_field",
"static_field",
"save_tree",
"load_tree",
]
38 changes: 36 additions & 2 deletions gpjax/base/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# ==============================================================================


__all__ = ["Module", "meta_leaves", "meta_flatten", "meta_map", "meta"]
__all__ = ["Module", "meta_leaves", "meta_flatten", "meta_map", "meta", "static_field"]

from copy import (
copy,
Expand All @@ -33,6 +33,7 @@
Tuple,
TypeVar,
Union,
Mapping,
)
import jax
from jax import lax
Expand All @@ -48,13 +49,46 @@
)
from simple_pytree import (
Pytree,
static_field,
)
import tensorflow_probability.substrates.jax.bijectors as tfb

Self = TypeVar("Self")


def static_field(
default: Any = dataclasses.MISSING,
*,
default_factory: Any = dataclasses.MISSING,
init: bool = True,
repr: bool = True,
hash: Optional[bool] = None,
compare: bool = True,
metadata: Optional[Mapping[str, Any]] = None,
):

metadata = {} if metadata is None else dict(metadata)

if "pytree_node" in metadata:
raise ValueError("Cannot use metadata with `pytree_node` already set.")

metadata["pytree_node"] = False

if default is not dataclasses.MISSING and default_factory is not dataclasses.MISSING:
raise ValueError("Cannot specify both default and default_factory.")

if default is not dataclasses.MISSING:
default_factory = lambda: default

return dataclasses.field(
default_factory=default_factory,
init=init,
repr=repr,
hash=hash,
compare=compare,
metadata=metadata,
)


class Module(Pytree):
_pytree__meta: Dict[str, Any] = static_field()

Expand Down
19 changes: 12 additions & 7 deletions gpjax/base/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,17 @@ def param_field(
metadata["trainable"] = trainable
metadata["pytree_node"] = True

if default is not dataclasses.MISSING and default_factory is not dataclasses.MISSING:
raise ValueError("Cannot specify both default and default_factory.")

if default is not dataclasses.MISSING:
default_factory = lambda: default

return dataclasses.field(
default=default,
default_factory=default_factory,
init=init,
repr=repr,
hash=hash,
compare=compare,
metadata=metadata,
default_factory=default_factory,
init=init,
repr=repr,
hash=hash,
compare=compare,
metadata=metadata,
)
2 changes: 1 addition & 1 deletion gpjax/gps.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
Float,
Num,
)
from simple_pytree import static_field
from gpjax.base import static_field

from gpjax.base import (
Module,
Expand Down
2 changes: 1 addition & 1 deletion gpjax/kernels/approximations/rff.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from jax.random import PRNGKey
from jaxtyping import Float
from simple_pytree import static_field
from gpjax.base import static_field
import tensorflow_probability.substrates.jax.bijectors as tfb

from gpjax.base import param_field
Expand Down
2 changes: 1 addition & 1 deletion gpjax/kernels/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
Float,
Num,
)
from simple_pytree import static_field
from gpjax.base import static_field
import tensorflow_probability.substrates.jax.distributions as tfd

from gpjax.base import (
Expand Down
2 changes: 1 addition & 1 deletion gpjax/kernels/non_euclidean/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
Int,
Num,
)
from simple_pytree import static_field
from gpjax.base import static_field
import tensorflow_probability.substrates.jax as tfp

from gpjax.base import param_field
Expand Down
2 changes: 1 addition & 1 deletion gpjax/kernels/nonstationary/arccosine.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from beartype.typing import Union
import jax.numpy as jnp
from jaxtyping import Float
from simple_pytree import static_field
from gpjax.base import static_field
import tensorflow_probability.substrates.jax.bijectors as tfb

from gpjax.base import param_field
Expand Down
2 changes: 1 addition & 1 deletion gpjax/kernels/nonstationary/polynomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import jax.numpy as jnp
from jaxtyping import Float
from simple_pytree import static_field
from gpjax.base import static_field
import tensorflow_probability.substrates.jax.bijectors as tfb

from gpjax.base import param_field
Expand Down
2 changes: 1 addition & 1 deletion gpjax/kernels/stationary/white.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import jax.numpy as jnp
from jaxtyping import Float
from simple_pytree import static_field
from gpjax.base import static_field
import tensorflow_probability.substrates.jax.bijectors as tfb
import tensorflow_probability.substrates.jax.distributions as tfd

Expand Down
2 changes: 1 addition & 1 deletion gpjax/likelihoods.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import jax.numpy as jnp
import jax.scipy as jsp
from jaxtyping import Float
from simple_pytree import static_field
from gpjax.base import static_field
import tensorflow_probability.substrates.jax as tfp

from gpjax.base import (
Expand Down
2 changes: 1 addition & 1 deletion gpjax/linops/constant_diagonal_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
)
import jax.numpy as jnp
from jaxtyping import Float
from simple_pytree import static_field
from gpjax.base import static_field

from gpjax.linops.diagonal_linear_operator import DiagonalLinearOperator
from gpjax.linops.linear_operator import LinearOperator
Expand Down
2 changes: 1 addition & 1 deletion gpjax/mean_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
Float,
Num,
)
from simple_pytree import static_field
from gpjax.base import static_field

from gpjax.base import (
Module,
Expand Down
2 changes: 1 addition & 1 deletion gpjax/objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import jax.scipy as jsp
import jax.tree_util as jtu
from jaxtyping import Float
from simple_pytree import static_field
from gpjax.base import static_field
import tensorflow_probability.substrates.jax as tfp

from gpjax.base import Module
Expand Down
2 changes: 1 addition & 1 deletion gpjax/variational_families.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import jax.numpy as jnp
import jax.scipy as jsp
from jaxtyping import Float
from simple_pytree import static_field
from gpjax.base import static_field
import tensorflow_probability.substrates.jax.bijectors as tfb

from gpjax.base import (
Expand Down
2 changes: 1 addition & 1 deletion tests/test_base/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@
import pytest
from simple_pytree import (
Pytree,
static_field,
)
import tensorflow_probability.substrates.jax.bijectors as tfb

from gpjax.base.module import (
Module,
meta,
static_field,
)
from gpjax.base.param import param_field

Expand Down
4 changes: 2 additions & 2 deletions tests/test_kernels/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================

from dataclasses import dataclass
from dataclasses import dataclass, field

from jax.config import config
import jax.numpy as jnp
Expand Down Expand Up @@ -55,7 +55,7 @@ def test_abstract_kernel():
# Create a dummy kernel class with __call__ implemented:
@dataclass
class DummyKernel(AbstractKernel):
test_a: Float[Array, "1"] = jnp.array([1.0])
test_a: Float[Array, "1"] = field(default_factory = lambda: jnp.array([1.0]))
test_b: Float[Array, "1"] = param_field(
jnp.array([2.0]), bijector=tfb.Softplus()
)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_linops/test_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import jax.numpy as jnp
import jax.tree_util as jtu
import pytest
from simple_pytree import static_field
from gpjax.base import static_field

from gpjax.linops.linear_operator import LinearOperator

Expand Down