Jaxlinop merge (#196)
Merge jaxlinop into main
thomaspinder committed Mar 23, 2023
1 parent 3388f16 commit aa3f5d2
Showing 38 changed files with 2,802 additions and 37 deletions.
2 changes: 1 addition & 1 deletion gpjax/gaussian_distribution.py
```diff
@@ -14,7 +14,7 @@
 # ==============================================================================
 
 import jax.numpy as jnp
-from jaxlinop import LinearOperator, IdentityLinearOperator
+from .linops import LinearOperator, IdentityLinearOperator
 
 from jaxtyping import Array, Float
 from jax import vmap
```
4 changes: 2 additions & 2 deletions gpjax/gps.py
```diff
@@ -21,8 +21,8 @@
 from jaxtyping import Array, Float
 from jax.random import KeyArray
 
-from jaxlinop import identity
-from jaxkern.base import AbstractKernel
+from .linops import identity
+from .kernels.base import AbstractKernel
 from jaxutils import PyTree
 
 from .config import get_global_config
```
7 changes: 4 additions & 3 deletions gpjax/kernels/computations/base.py
```diff
@@ -17,13 +17,14 @@
 from typing import Callable, Dict
 
 from jax import vmap
-from jaxlinop import (
+from jaxtyping import Array, Float
+from jaxutils import PyTree
+
+from ...linops import (
     DenseLinearOperator,
     DiagonalLinearOperator,
     LinearOperator,
 )
-from jaxtyping import Array, Float
-from jaxutils import PyTree
 
 
 class AbstractKernelComputation(PyTree):
```
2 changes: 1 addition & 1 deletion gpjax/kernels/computations/basis_functions.py
```diff
@@ -3,7 +3,7 @@
 import jax.numpy as jnp
 from jaxtyping import Array, Float
 from .base import AbstractKernelComputation
-from jaxlinop import DenseLinearOperator
+from ...linops import DenseLinearOperator
 
 
 class BasisFunctionComputation(AbstractKernelComputation):
```
7 changes: 4 additions & 3 deletions gpjax/kernels/computations/constant_diagonal.py
```diff
@@ -17,12 +17,13 @@
 import jax.numpy as jnp
 
 from jax import vmap
-from jaxlinop import (
+from jaxtyping import Array, Float
+from .base import AbstractKernelComputation
+
+from ...linops import (
     ConstantDiagonalLinearOperator,
     DiagonalLinearOperator,
 )
-from jaxtyping import Array, Float
-from .base import AbstractKernelComputation
 
 
 class ConstantDiagonalKernelComputation(AbstractKernelComputation):
```
9 changes: 3 additions & 6 deletions gpjax/kernels/computations/diagonal.py
```diff
@@ -13,15 +13,12 @@
 # limitations under the License.
 # ==============================================================================
 
-from typing import Callable, Dict
-
 from jax import vmap
-from jaxlinop import (
-    DiagonalLinearOperator,
-)
+from typing import Callable, Dict
 from jaxtyping import Array, Float
-from .base import AbstractKernelComputation
 
+from .base import AbstractKernelComputation
+from ...linops import DiagonalLinearOperator
 
 class DiagonalKernelComputation(AbstractKernelComputation):
     def __init__(
```
2 changes: 1 addition & 1 deletion gpjax/likelihoods.py
```diff
@@ -15,7 +15,7 @@
 
 import abc
 from typing import Any, Callable, Dict, Optional
-from jaxlinop.utils import to_dense
+from .linops.utils import to_dense
 from jaxutils import PyTree
 
 import distrax as dx
```
53 changes: 53 additions & 0 deletions gpjax/linops/README.md
# LinOps

The `linops` submodule is a lightweight linear operator library written in [`jax`](https://github.com/google/jax).

# Overview
Consider solving a linear system $Ax = b$, where $A$ is a diagonal matrix and $b$ is a vector.

```python
import jax.numpy as jnp

n = 1000
diag = jnp.linspace(1.0, 2.0, n)

A = jnp.diag(diag)
b = jnp.linspace(3.0, 4.0, n)

# A⁻¹ b
jnp.linalg.solve(A, b)
```
Doing so is costly for large problems: storing the matrix incurs an $O(n^2)$ memory cost, and solving the dense system costs $O(n^3)$ operations in the number of data points $n$.

But hold on a second. Notice:

- We only have to store the diagonal entries to determine the matrix $A$, reducing memory costs from $O(n^2)$ to $O(n)$.
- To invert $A$, we only need to take the reciprocal of the diagonal, reducing inversion costs from $O(n^3)$ to $O(n)$ (see the sketch below).
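
To make the second point concrete, the structured solve is nothing more than an elementwise division:

```python
import jax.numpy as jnp

n = 1000
diag = jnp.linspace(1.0, 2.0, n)
b = jnp.linspace(3.0, 4.0, n)

# O(n) solve: diag(d)⁻¹ b = b / d — the n × n matrix is never materialised.
x = b / diag

# Agrees with the dense O(n³) solve.
assert jnp.allclose(x, jnp.linalg.solve(jnp.diag(diag), b))
```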

`linops` is designed to exploit structure of this kind.
```python
from gpjax import linops

A = linops.DiagonalLinearOperator(diag=diag)

# A⁻¹ b
A.solve(b)
```
`linops` automatically exploits structure like this to reduce the cost of matrix addition, multiplication, log-determinant computation and more, for other matrix structures too!
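
For example, here is a short sketch of the savings on a diagonal operator (the method names follow the operators introduced in this commit; `log_det` and operator addition are assumed to behave as in `constant_diagonal_linear_operator.py` below):

```python
import jax.numpy as jnp
from gpjax import linops

diag = jnp.linspace(1.0, 2.0, 1000)
A = linops.DiagonalLinearOperator(diag=diag)

# log|A| needs only the diagonal: mathematically, log|A| = Σᵢ log dᵢ, an O(n) sum.
log_det_A = A.log_det()
assert jnp.allclose(log_det_A, jnp.sum(jnp.log(diag)))

# Adding two diagonal operators stays diagonal — nothing dense is ever formed.
B = A + A
```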

# Custom Linear Operator (details to come soon)

The flexible design of `linops` will allow users to implement their own custom linear operators.

```python
from gpjax.linops import LinearOperator

class MyLinearOperator(LinearOperator):

    def __init__(self, ...):
        ...

# There will be a minimal number of methods that users need to implement for their custom operator.
# For optimal efficiency, we'll make it easy for the user to add optional methods to their operator,
# if they give better performance than the defaults.
```
42 changes: 42 additions & 0 deletions gpjax/linops/__init__.py
```python
# Copyright 2022 The JaxLinOp Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from .linear_operator import LinearOperator
from .dense_linear_operator import DenseLinearOperator
from .diagonal_linear_operator import DiagonalLinearOperator
from .constant_diagonal_linear_operator import ConstantDiagonalLinearOperator
from .identity_linear_operator import IdentityLinearOperator
from .zero_linear_operator import ZeroLinearOperator
from .triangular_linear_operator import (
    LowerTriangularLinearOperator,
    UpperTriangularLinearOperator,
)
from .utils import (
    identity,
    to_dense,
)

__all__ = [
    "LinearOperator",
    "DenseLinearOperator",
    "DiagonalLinearOperator",
    "ConstantDiagonalLinearOperator",
    "IdentityLinearOperator",
    "ZeroLinearOperator",
    "LowerTriangularLinearOperator",
    "UpperTriangularLinearOperator",
    "identity",
    "to_dense",
]
```
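
As a quick smoke test of this public surface (a sketch; it assumes `identity(n)` returns an `IdentityLinearOperator` of size `n`, as its pairing with `to_dense` here suggests):

```python
import jax.numpy as jnp
from gpjax.linops import identity, to_dense

I = identity(3)  # assumed: an IdentityLinearOperator of size 3
assert jnp.allclose(to_dense(I), jnp.eye(3))  # the dense view is the 3×3 identity
```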
190 changes: 190 additions & 0 deletions gpjax/linops/constant_diagonal_linear_operator.py
```python
# Copyright 2022 The JaxLinOp Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import annotations

from typing import Any, Union

import jax.numpy as jnp
from jaxtyping import Array, Float
from simple_pytree import static_field
from dataclasses import dataclass

from .linear_operator import LinearOperator
from .diagonal_linear_operator import DiagonalLinearOperator


def _check_args(value: Any, size: Any) -> None:

    if not isinstance(size, int):
        raise ValueError(f"`size` must be an integer, but `size = {size}`.")

    if value.ndim != 1:
        raise ValueError(
            f"`value` must be a one-dimensional scalar array, but `value.shape = {value.shape}`."
        )


@dataclass
class ConstantDiagonalLinearOperator(DiagonalLinearOperator):
    value: Float[Array, "1"]
    size: int = static_field()

    def __init__(
        self, value: Float[Array, "1"], size: int, dtype: jnp.dtype = None
    ) -> None:
        """Initialize the constant diagonal linear operator.

        Args:
            value (Float[Array, "1"]): Constant value of the diagonal.
            size (int): Size of the diagonal.
            dtype (jnp.dtype, optional): Data type to cast `value` to. Defaults to None.
        """

        _check_args(value, size)

        if dtype is not None:
            value = value.astype(dtype)

        self.value = value
        self.size = size
        self.shape = (size, size)
        self.dtype = value.dtype

    def __add__(
        self, other: Union[Float[Array, "N N"], LinearOperator]
    ) -> DiagonalLinearOperator:
        if isinstance(other, ConstantDiagonalLinearOperator):
            if other.size == self.size:
                return ConstantDiagonalLinearOperator(
                    value=self.value + other.value, size=self.size
                )

            raise ValueError(
                f"`size` must match, but `self.size = {self.size}` and `other.size = {other.size}`."
            )

        else:
            return super().__add__(other)

    def __mul__(self, other: float) -> LinearOperator:
        """Multiply covariance operator by scalar.

        Args:
            other (float): Scalar.

        Returns:
            LinearOperator: Covariance operator multiplied by a scalar.
        """

        return ConstantDiagonalLinearOperator(value=self.value * other, size=self.size)

    def _add_diagonal(self, other: DiagonalLinearOperator) -> LinearOperator:
        """Add diagonal to the covariance operator, useful for computing Kxx + Iσ².

        Args:
            other (DiagonalLinearOperator): Diagonal covariance operator to add to the covariance operator.

        Returns:
            LinearOperator: Covariance operator with the diagonal added.
        """

        if isinstance(other, ConstantDiagonalLinearOperator):
            if other.size == self.size:
                return ConstantDiagonalLinearOperator(
                    value=self.value + other.value, size=self.size
                )

            raise ValueError(
                f"`size` must match, but `self.size = {self.size}` and `other.size = {other.size}`."
            )

        else:
            return super()._add_diagonal(other)

    def diagonal(self) -> Float[Array, "N"]:
        """Diagonal of the covariance operator."""
        return self.value * jnp.ones(self.size)

    def to_root(self) -> ConstantDiagonalLinearOperator:
        """Matrix square root of the operator.

        Returns:
            ConstantDiagonalLinearOperator: Root L such that L Lᵀ equals the operator.
        """
        return ConstantDiagonalLinearOperator(
            value=jnp.sqrt(self.value), size=self.size
        )

    def log_det(self) -> Float[Array, "1"]:
        """Log determinant.

        Returns:
            Float[Array, "1"]: Log determinant of the covariance matrix.
        """
        # det(value·Iₙ) = valueⁿ, so the log-determinant is size · log(value).
        return self.size * jnp.log(self.value)

    def inverse(self) -> ConstantDiagonalLinearOperator:
        """Inverse of the covariance operator.

        Returns:
            ConstantDiagonalLinearOperator: Inverse of the covariance operator.
        """
        return ConstantDiagonalLinearOperator(value=1.0 / self.value, size=self.size)

    def solve(self, rhs: Float[Array, "N M"]) -> Float[Array, "N M"]:
        """Solve linear system.

        Args:
            rhs (Float[Array, "N M"]): Right hand side of the linear system.

        Returns:
            Float[Array, "N M"]: Solution of the linear system.
        """

        return rhs / self.value

    @classmethod
    def from_dense(cls, dense: Float[Array, "N N"]) -> ConstantDiagonalLinearOperator:
        """Construct covariance operator from a dense matrix, assumed to have a constant diagonal.

        Args:
            dense (Float[Array, "N N"]): Dense matrix.

        Returns:
            ConstantDiagonalLinearOperator: Covariance operator.
        """
        return ConstantDiagonalLinearOperator(
            value=jnp.atleast_1d(dense[0, 0]), size=dense.shape[0]
        )

    @classmethod
    def from_root(
        cls, root: ConstantDiagonalLinearOperator
    ) -> ConstantDiagonalLinearOperator:
        """Construct covariance operator from its root.

        Args:
            root (ConstantDiagonalLinearOperator): Root of the covariance operator.

        Returns:
            ConstantDiagonalLinearOperator: Covariance operator.
        """
        return ConstantDiagonalLinearOperator(value=root.value**2, size=root.size)


__all__ = [
    "ConstantDiagonalLinearOperator",
]
```
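
For illustration, a brief usage sketch of the operator defined above (the import path assumes the `gpjax.linops` layout established by this commit):

```python
import jax.numpy as jnp
from gpjax.linops import ConstantDiagonalLinearOperator

# A = 2·I₄, stored as one scalar plus a size — O(1) memory instead of O(n²).
A = ConstantDiagonalLinearOperator(value=jnp.array([2.0]), size=4)

b = jnp.arange(4.0)
x = A.solve(b)  # elementwise division: b / 2
assert jnp.allclose(x, b / 2.0)

# log|A| = size · log(value) = 4·log 2
assert jnp.allclose(A.log_det(), 4.0 * jnp.log(2.0))

root = A.to_root()  # constant diagonal operator with value √2
assert jnp.allclose(root.value, jnp.sqrt(2.0))
```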