
Refactoring kernels #206

Merged: 8 commits merged into refactor_kernels on Mar 29, 2023

Conversation

@frazane (Contributor) commented Mar 28, 2023

Just contributing to #199.

Pull request type

Please check the type of change your PR introduces:

  • Bugfix
  • Feature
  • Code style update (formatting, renaming)
  • Refactoring (no functional changes, no API changes)
  • Build related changes
  • Documentation content changes
  • Other (please describe):

@codecov-commenter commented

Codecov Report

Merging #206 (b609c0b) into refactor_kernels (623be42) will increase coverage by 0.33%.
The diff coverage is 34.65%.


@@                 Coverage Diff                  @@
##           refactor_kernels     #206      +/-   ##
====================================================
+ Coverage             28.29%   28.63%   +0.33%     
====================================================
  Files                    70       70              
  Lines                  3262     3332      +70     
====================================================
+ Hits                    923      954      +31     
- Misses                 2339     2378      +39     
| Impacted Files | Coverage Δ |
|---|---|
| tests/test_kernels/test_stationary.py | 0.00% <0.00%> (ø) |
| gpjax/kernels/stationary/matern32.py | 72.72% <69.23%> (+17.72%) ⬆️ |
| gpjax/kernels/stationary/matern12.py | 77.27% <71.42%> (+19.37%) ⬆️ |
| gpjax/kernels/stationary/matern52.py | 72.72% <75.00%> (+14.83%) ⬆️ |
| gpjax/kernels/stationary/powered_exponential.py | 80.00% <75.00%> (+16.84%) ⬆️ |
| gpjax/kernels/stationary/rational_quadratic.py | 80.00% <75.00%> (+16.84%) ⬆️ |
| gpjax/kernels/stationary/periodic.py | 75.00% <80.00%> (+17.10%) ⬆️ |
| gpjax/kernels/stationary/white.py | 84.61% <85.71%> (+20.32%) ⬆️ |
| gpjax/parameters/module.py | 33.33% <100.00%> (+0.68%) ⬆️ |


@daniel-dodd marked this pull request as ready for review March 29, 2023 19:17
Comment on lines 41 to 37

```diff
-    self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"]
+    self, x: Float[Array, "1 D"], y: Float[Array, "1 D"]
```
Member

We need to replace all kernel inputs of Float[Array, "1 D"] with Float[Array, "D"] for the x's and y's!
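
As a minimal sketch of the new input shapes (illustrative only, not the actual GPJax code), a kernel's `__call__` would then take single D-dimensional points:

```python
import jax.numpy as jnp
from jaxtyping import Array, Float


class ExampleKernel:
    """Hypothetical RBF-style kernel, only to illustrate the new input shapes."""

    def __call__(
        self, x: Float[Array, "D"], y: Float[Array, "D"]
    ) -> Float[Array, ""]:
        # x and y are single D-dimensional points rather than "1 D" row vectors;
        # batching over N inputs is left to the gram/cross-covariance computation.
        return jnp.exp(-0.5 * jnp.sum((x - y) ** 2))
```

For example, `ExampleKernel()(jnp.ones(3), jnp.zeros(3))` evaluates the kernel on one pair of 3-dimensional inputs.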

Comment on lines 28 to 30

```python
class White(AbstractKernel):
    def __init__(
        self,
        compute_engine: AbstractKernelComputation = ConstantDiagonalKernelComputation,
        active_dims: Optional[List[int]] = None,
        name: Optional[str] = "White Noise Kernel",
    ) -> None:
        super().__init__(compute_engine, active_dims, spectral_density=None, name=name)
        self._stationary = True

    variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus)
```
Member

For the White kernel, the default computation should be ConstantDiagonalKernelComputation, so that:

```python
from simple_pytree import static_field

@dataclass
class White(AbstractKernel):

    variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus)
    compute_engine: AbstractKernelComputation = static_field(ConstantDiagonalKernelComputation)  # <- set the default
```
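
A white-noise kernel only contributes to the diagonal of the Gram matrix, so a constant-diagonal computation engine is the natural default here; with the field default set as above, `White()` would pick it up without the caller passing `compute_engine` explicitly.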

Comment on lines 114 to 121

```python
def test_gram(self, dim: int, n: int) -> None:
    kernel: AbstractKernel = self.kernel()
    kernel.gram
    x = jnp.linspace(0.0, 1.0, n * dim).reshape(n, dim)
    Kxx = kernel.gram(x)
    assert isinstance(Kxx, LinearOperator)
    assert Kxx.shape == (n, n)
    assert jnp.all(jnp.linalg.eigvalsh(Kxx.to_dense()) > 0.0)
```
Member

When you test positive definiteness, I would recommend adding a small jitter to the diagonal of the Gram matrix Kxx; I found that one of the tests failed for me locally.

I would also add a quick comment saying what the line `assert jnp.all(jnp.linalg.eigvalsh(Kxx.to_dense()) > 0.0)` is doing, to make it clear for others!
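
A minimal sketch of how the test could incorporate both suggestions, reusing the imports and `self.kernel` fixture from the test above (the 1e-6 jitter value is an assumption, not taken from the PR):

```python
def test_gram(self, dim: int, n: int) -> None:
    kernel: AbstractKernel = self.kernel()
    x = jnp.linspace(0.0, 1.0, n * dim).reshape(n, dim)
    Kxx = kernel.gram(x)
    assert isinstance(Kxx, LinearOperator)
    assert Kxx.shape == (n, n)

    # Add a small jitter to the diagonal so the eigenvalue check is numerically stable.
    dense = Kxx.to_dense() + 1e-6 * jnp.eye(n)

    # A symmetric matrix is positive definite iff all of its eigenvalues are strictly
    # positive, so this asserts that the (jittered) Gram matrix is positive definite.
    assert jnp.all(jnp.linalg.eigvalsh(dense) > 0.0)
```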

@daniel-dodd merged commit 48d7fdb into refactor_kernels Mar 29, 2023
@thomaspinder deleted the refactor_kernels_zaf branch April 7, 2023 18:36