Skip to content

Commit

Permalink
stationary kernels refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
frazane committed Mar 28, 2023
1 parent 6859304 commit ace9e5f
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 149 deletions.
39 changes: 13 additions & 26 deletions gpjax/kernels/stationary/matern32.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,34 +13,28 @@
# limitations under the License.
# ==============================================================================

from dataclasses import dataclass
from typing import Dict, List, Optional

import jax.numpy as jnp
from jax.random import KeyArray
from jaxtyping import Array, Float

from ...parameters import Softplus, param_field
from ..base import AbstractKernel
from ..computations import (
DenseKernelComputation,
)
from .utils import euclidean_distance, build_student_t_distribution
from ..computations import DenseKernelComputation
from .utils import build_student_t_distribution, euclidean_distance


@dataclass
class Matern32(AbstractKernel):
"""The Matérn kernel with smoothness parameter fixed at 1.5."""

def __init__(
self,
active_dims: Optional[List[int]] = None,
name: Optional[str] = "Matern 3/2",
) -> None:
spectral_density = build_student_t_distribution(nu=3)
super().__init__(DenseKernelComputation, active_dims, spectral_density, name)
self._stationary = True
lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=Softplus)
variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus)

def __call__(
self,
params: Dict,
x: Float[Array, "1 D"],
y: Float[Array, "1 D"],

This comment has been minimized.

Copy link
@daniel-dodd

daniel-dodd Mar 29, 2023

Member

The typing needs to be updated from Float[Array, "1 D"] to Float[Array, "D"] on all kernel inputs!

) -> Float[Array, "1"]:
Expand All @@ -51,25 +45,18 @@ def __call__(
k(x, y) = \\sigma^2 \\exp \\Bigg(1+ \\frac{\\sqrt{3}\\lvert x-y \\rvert}{\\ell^2} \\Bigg)\\exp\\Bigg(-\\frac{\\sqrt{3}\\lvert x-y\\rvert}{\\ell^2} \\Bigg)
Args:
params (Dict): Parameter set for which the kernel should be evaluated on.
x (Float[Array, "1 D"]): The left hand argument of the kernel function's call.
y (Float[Array, "1 D"]): The right hand argument of the kernel function's call.
Returns:
Float[Array, "1"]: The value of :math:`k(x, y)`.
"""
x = self.slice_input(x) / params["lengthscale"]
y = self.slice_input(y) / params["lengthscale"]
x = self.slice_input(x) / self.lengthscale
y = self.slice_input(y) / self.lengthscale
tau = euclidean_distance(x, y)
K = (
params["variance"]
* (1.0 + jnp.sqrt(3.0) * tau)
* jnp.exp(-jnp.sqrt(3.0) * tau)
)
K = self.variance * (1.0 + jnp.sqrt(3.0) * tau) * jnp.exp(-jnp.sqrt(3.0) * tau)
return K.squeeze()

def init_params(self, key: KeyArray) -> Dict:
return {
"lengthscale": jnp.array([1.0] * self.ndims),
"variance": jnp.array([1.0]),
}
@property
def spectral_density(self):
return build_student_t_distribution(nu=3)
35 changes: 14 additions & 21 deletions gpjax/kernels/stationary/matern52.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,32 +13,28 @@
# limitations under the License.
# ==============================================================================

from dataclasses import dataclass
from typing import Dict, List, Optional

import jax.numpy as jnp
from jax.random import KeyArray
from jaxtyping import Array, Float

from ...parameters import Softplus, param_field
from ..base import AbstractKernel
from ..computations import (
DenseKernelComputation,
)
from .utils import euclidean_distance, build_student_t_distribution
from ..computations import DenseKernelComputation
from .utils import build_student_t_distribution, euclidean_distance


@dataclass
class Matern52(AbstractKernel):
"""The Matérn kernel with smoothness parameter fixed at 2.5."""

def __init__(
self,
active_dims: Optional[List[int]] = None,
name: Optional[str] = "Matern 5/2",
) -> None:
spectral_density = build_student_t_distribution(nu=5)
super().__init__(DenseKernelComputation, active_dims, spectral_density, name)
lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=Softplus)
variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus)

def __call__(
self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"]
self, x: Float[Array, "1 D"], y: Float[Array, "1 D"]
) -> Float[Array, "1"]:
"""Evaluate the kernel on a pair of inputs :math:`(x, y)` with
lengthscale parameter :math:`\\ell` and variance :math:`\\sigma^2`
Expand All @@ -47,25 +43,22 @@ def __call__(
k(x, y) = \\sigma^2 \\exp \\Bigg(1+ \\frac{\\sqrt{5}\\lvert x-y \\rvert}{\\ell^2} + \\frac{5\\lvert x - y \\rvert^2}{3\\ell^2} \\Bigg)\\exp\\Bigg(-\\frac{\\sqrt{5}\\lvert x-y\\rvert}{\\ell^2} \\Bigg)
Args:
params (Dict): Parameter set for which the kernel should be evaluated on.
x (Float[Array, "1 D"]): The left hand argument of the kernel function's call.
y (Float[Array, "1 D"]): The right hand argument of the kernel function's call.
Returns:
Float[Array, "1"]: The value of :math:`k(x, y)`.
"""
x = self.slice_input(x) / params["lengthscale"]
y = self.slice_input(y) / params["lengthscale"]
x = self.slice_input(x) / self.lengthscale
y = self.slice_input(y) / self.lengthscale
tau = euclidean_distance(x, y)
K = (
params["variance"]
self.variance
* (1.0 + jnp.sqrt(5.0) * tau + 5.0 / 3.0 * jnp.square(tau))
* jnp.exp(-jnp.sqrt(5.0) * tau)
)
return K.squeeze()

def init_params(self, key: KeyArray) -> Dict:
return {
"lengthscale": jnp.array([1.0] * self.ndims),
"variance": jnp.array([1.0]),
}
@property
def spectral_density(self):
return build_student_t_distribution(nu=5)
37 changes: 14 additions & 23 deletions gpjax/kernels/stationary/periodic.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,54 +18,45 @@
import jax
import jax.numpy as jnp
from jax.random import KeyArray
from jaxtyping import Array
from jaxtyping import Array, Float

from ..base import AbstractKernel
from ..computations import (
DenseKernelComputation,
)

from dataclasses import dataclass
from ...parameters import param_field, Softplus

@dataclass
class Periodic(AbstractKernel):
"""The periodic kernel.
Key reference is MacKay 1998 - "Introduction to Gaussian processes".
"""

def __init__(
self,
active_dims: Optional[List[int]] = None,
name: Optional[str] = "Periodic",
) -> None:
super().__init__(
DenseKernelComputation, active_dims, spectral_density=None, name=name
)
self._stationary = True
lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=Softplus)
variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus)
period: Float[Array, "1"] = param_field(
jnp.array([1.0]), bijector=Softplus
) # NOTE: is bijector needed?

def __call__(self, params: dict, x: jax.Array, y: jax.Array) -> Array:
def __call__(self, x: jax.Array, y: jax.Array) -> Array:
"""Evaluate the kernel on a pair of inputs :math:`(x, y)` with length-scale parameter :math:`\\ell` and variance :math:`\\sigma`
TODO: write docstring
.. math::
k(x, y) = \\sigma^2 \\exp \\Bigg( -0.5 \\sum_{i=1}^{d} \\Bigg)
Args:
x (jax.Array): The left hand argument of the kernel function's call.
y (jax.Array): The right hand argument of the kernel function's call
params (dict): Parameter set for which the kernel should be evaluated on.
Returns:
Array: The value of :math:`k(x, y)`
"""
x = self.slice_input(x)
y = self.slice_input(y)
sine_squared = (
jnp.sin(jnp.pi * (x - y) / params["period"]) / params["lengthscale"]
) ** 2
K = params["variance"] * jnp.exp(-0.5 * jnp.sum(sine_squared, axis=0))
sine_squared = (jnp.sin(jnp.pi * (x - y) / self.period) / self.lengthscale) ** 2
K = self.variance * jnp.exp(-0.5 * jnp.sum(sine_squared, axis=0))
return K.squeeze()

def init_params(self, key: KeyArray) -> Dict:
return {
"lengthscale": jnp.array([1.0] * self.ndims),
"variance": jnp.array([1.0]),
"period": jnp.array([1.0] * self.ndims),
}
37 changes: 12 additions & 25 deletions gpjax/kernels/stationary/powered_exponential.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,38 +13,33 @@
# limitations under the License.
# ==============================================================================

from dataclasses import dataclass
from typing import Dict, List, Optional

import jax
import jax.numpy as jnp
from jax.random import KeyArray
from jaxtyping import Array
from jaxtyping import Array, Float

from ...parameters import Softplus, param_field
from ..base import AbstractKernel
from ..computations import (
DenseKernelComputation,
)
from ..computations import DenseKernelComputation
from .utils import euclidean_distance


@dataclass
class PoweredExponential(AbstractKernel):
"""The powered exponential family of kernels.
Key reference is Diggle and Ribeiro (2007) - "Model-based Geostatistics".
"""

def __init__(
self,
active_dims: Optional[List[int]] = None,
name: Optional[str] = "Powered exponential",
) -> None:
super().__init__(
DenseKernelComputation, active_dims, spectral_density=None, name=name
)
self._stationary = True
lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=Softplus)
variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus)
power: Float[Array, "1"] = param_field(jnp.array([1.0]))

def __call__(self, params: dict, x: jax.Array, y: jax.Array) -> Array:
def __call__(self, x: jax.Array, y: jax.Array) -> Array:
"""Evaluate the kernel on a pair of inputs :math:`(x, y)` with length-scale parameter :math:`\\ell`, :math:`\\sigma` and power :math:`\\kappa`.
.. math::
Expand All @@ -53,19 +48,11 @@ def __call__(self, params: dict, x: jax.Array, y: jax.Array) -> Array:
Args:
x (jax.Array): The left hand argument of the kernel function's call.
y (jax.Array): The right hand argument of the kernel function's call
params (dict): Parameter set for which the kernel should be evaluated on.
Returns:
Array: The value of :math:`k(x, y)`
"""
x = self.slice_input(x) / params["lengthscale"]
y = self.slice_input(y) / params["lengthscale"]
K = params["variance"] * jnp.exp(-euclidean_distance(x, y) ** params["power"])
x = self.slice_input(x) / self.lengthscale
y = self.slice_input(y) / self.lengthscale
K = self.variance * jnp.exp(-euclidean_distance(x, y) ** self.power)
return K.squeeze()

def init_params(self, key: KeyArray) -> Dict:
return {
"lengthscale": jnp.array([1.0] * self.ndims),
"variance": jnp.array([1.0]),
"power": jnp.array([1.0]),
}
42 changes: 15 additions & 27 deletions gpjax/kernels/stationary/rational_quadratic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,32 +13,28 @@
# limitations under the License.
# ==============================================================================

from dataclasses import dataclass
from typing import List, Optional

import jax
import jax.numpy as jnp
from jax.random import KeyArray
from jaxtyping import Array
from jaxtyping import Array, Float

from ...parameters import Softplus, param_field
from ..base import AbstractKernel
from ..computations import (
DenseKernelComputation,
)
from ..computations import DenseKernelComputation
from .utils import squared_distance


@dataclass
class RationalQuadratic(AbstractKernel):
def __init__(
self,
active_dims: Optional[List[int]] = None,
name: Optional[str] = "Rational Quadratic",
) -> None:
super().__init__(
DenseKernelComputation, active_dims, spectral_density=None, name=name
)
self._stationary = True

def __call__(self, params: dict, x: jax.Array, y: jax.Array) -> Array:
lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=Softplus)
variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus)
alpha: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus)

def __call__(self, x: jax.Array, y: jax.Array) -> Array:
"""Evaluate the kernel on a pair of inputs :math:`(x, y)` with length-scale parameter :math:`\\ell` and variance :math:`\\sigma`
.. math::
Expand All @@ -47,20 +43,12 @@ def __call__(self, params: dict, x: jax.Array, y: jax.Array) -> Array:
Args:
x (jax.Array): The left hand argument of the kernel function's call.
y (jax.Array): The right hand argument of the kernel function's call
params (dict): Parameter set for which the kernel should be evaluated on.
Returns:
Array: The value of :math:`k(x, y)`
"""
x = self.slice_input(x) / params["lengthscale"]
y = self.slice_input(y) / params["lengthscale"]
K = params["variance"] * (
1 + 0.5 * squared_distance(x, y) / params["alpha"]
) ** (-params["alpha"])
x = self.slice_input(x) / self.lengthscale
y = self.slice_input(y) / self.lengthscale
K = self.variance * (1 + 0.5 * squared_distance(x, y) / self.alpha) ** (
-self.alpha
)
return K.squeeze()

def init_params(self, key: KeyArray) -> dict:
return {
"lengthscale": jnp.array([1.0] * self.ndims),
"variance": jnp.array([1.0]),
"alpha": jnp.array([1.0]),
}
Loading

0 comments on commit ace9e5f

Please sign in to comment.