-
Notifications
You must be signed in to change notification settings - Fork 49
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
38 changed files
with
2,802 additions
and
37 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# LinOps | ||
|
||
The `linops` submodule is a lightweight linear operator library written in [`jax`](https://github.com/google/jax). | ||
|
||
# Overview | ||
Consider solving a diagonal matrix $A$ against a vector $b$. | ||
|
||
```python | ||
import jax.numpy as jnp | ||
|
||
n = 1000 | ||
diag = jnp.linspace(1.0, 2.0, n) | ||
|
||
A = jnp.diag(diag) | ||
b = jnp.linspace(3.0, 4.0, n) | ||
|
||
# A⁻¹ b | ||
jnp.solve(A, b) | ||
``` | ||
Doing so is costly in large problems. Storing the matrix gives rise to memory costs of $O(n^2)$, and inverting the matrix costs $O(n^3)$ in the number of data points $n$. | ||
|
||
But hold on a second. Notice: | ||
|
||
- We only have to store the diagonal entries to determine the matrix $A$. Doing so, would reduce memory costs from $O(n^2)$ to $O(n)$. | ||
- To invert $A$, we only need to take the reciprocal of the diagonal, reducing inversion costs from $O(n^3)$, to $O(n)$. | ||
|
||
`JaxLinOp` is designed to exploit stucture of this kind. | ||
```python | ||
from gpjax import linops | ||
|
||
A = linops.DiagonalLinearOperator(diag = diag) | ||
|
||
# A⁻¹ b | ||
A.solve(b) | ||
``` | ||
`linops` is designed to automatically reduce cost savings in matrix addition, multiplication, computing log-determinants and more, for other matrix stuctures too! | ||
|
||
# Custom Linear Operator (details to come soon) | ||
|
||
The flexible design of `linops` will allow users to impliment their own custom linear operators. | ||
|
||
```python | ||
from gpjax.linops import LinearOperator | ||
|
||
class MyLinearOperator(LinearOperator): | ||
|
||
def __init__(self, ...) | ||
... | ||
|
||
# There will be a minimal number methods that users need to impliment for their custom operator. | ||
# For optimal efficiency, we'll make it easy for the user to add optional methods to their operator, | ||
# if they give better performance than the defaults. | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# Copyright 2022 The JaxLinOp Contributors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
|
||
from .linear_operator import LinearOperator | ||
from .dense_linear_operator import DenseLinearOperator | ||
from .diagonal_linear_operator import DiagonalLinearOperator | ||
from .constant_diagonal_linear_operator import ConstantDiagonalLinearOperator | ||
from .identity_linear_operator import IdentityLinearOperator | ||
from .zero_linear_operator import ZeroLinearOperator | ||
from .triangular_linear_operator import ( | ||
LowerTriangularLinearOperator, | ||
UpperTriangularLinearOperator, | ||
) | ||
from .utils import ( | ||
identity, | ||
to_dense, | ||
) | ||
|
||
__all__ = [ | ||
"LinearOperator", | ||
"DenseLinearOperator", | ||
"DiagonalLinearOperator", | ||
"ConstantDiagonalLinearOperator", | ||
"IdentityLinearOperator", | ||
"ZeroLinearOperator", | ||
"LowerTriangularLinearOperator", | ||
"UpperTriangularLinearOperator", | ||
"identity", | ||
"to_dense", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,190 @@ | ||
# Copyright 2022 The JaxLinOp Contributors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
|
||
from __future__ import annotations | ||
|
||
from typing import Any, Union | ||
|
||
import jax.numpy as jnp | ||
from jaxtyping import Array, Float | ||
from simple_pytree import static_field | ||
from dataclasses import dataclass | ||
|
||
from .linear_operator import LinearOperator | ||
from .diagonal_linear_operator import DiagonalLinearOperator | ||
|
||
|
||
def _check_args(value: Any, size: Any) -> None: | ||
|
||
if not isinstance(size, int): | ||
raise ValueError(f"`length` must be an integer, but `length = {size}`.") | ||
|
||
if value.ndim != 1: | ||
raise ValueError( | ||
f"`value` must be one dimensional scalar, but `value.shape = {value.shape}`." | ||
) | ||
|
||
|
||
@dataclass | ||
class ConstantDiagonalLinearOperator(DiagonalLinearOperator): | ||
value: Float[Array, "1"] | ||
size: int = static_field() | ||
|
||
def __init__( | ||
self, value: Float[Array, "1"], size: int, dtype: jnp.dtype = None | ||
) -> None: | ||
"""Initialize the constant diagonal linear operator. | ||
Args: | ||
value (Float[Array, "1"]): Constant value of the diagonal. | ||
size (int): Size of the diagonal. | ||
""" | ||
|
||
_check_args(value, size) | ||
|
||
if dtype is not None: | ||
value = value.astype(dtype) | ||
|
||
self.value = value | ||
self.size = size | ||
self.shape = (size, size) | ||
self.dtype = value.dtype | ||
|
||
def __add__( | ||
self, other: Union[Float[Array, "N N"], LinearOperator] | ||
) -> DiagonalLinearOperator: | ||
if isinstance(other, ConstantDiagonalLinearOperator): | ||
if other.size == self.size: | ||
return ConstantDiagonalLinearOperator( | ||
value=self.value + other.value, size=self.size | ||
) | ||
|
||
raise ValueError( | ||
f"`length` must be the same, but `length = {self.size}` and `length = {other.size}`." | ||
) | ||
|
||
else: | ||
return super().__add__(other) | ||
|
||
def __mul__(self, other: float) -> LinearOperator: | ||
"""Multiply covariance operator by scalar. | ||
Args: | ||
other (LinearOperator): Scalar. | ||
Returns: | ||
LinearOperator: Covariance operator multiplied by a scalar. | ||
""" | ||
|
||
return ConstantDiagonalLinearOperator(value=self.value * other, size=self.size) | ||
|
||
def _add_diagonal(self, other: DiagonalLinearOperator) -> LinearOperator: | ||
"""Add diagonal to the covariance operator, useful for computing, Kxx + Iσ². | ||
Args: | ||
other (DiagonalLinearOperator): Diagonal covariance operator to add to the covariance operator. | ||
Returns: | ||
LinearOperator: Covariance operator with the diagonal added. | ||
""" | ||
|
||
if isinstance(other, ConstantDiagonalLinearOperator): | ||
if other.size == self.size: | ||
return ConstantDiagonalLinearOperator( | ||
value=self.value + other.value, size=self.size | ||
) | ||
|
||
raise ValueError( | ||
f"`length` must be the same, but `length = {self.size}` and `length = {other.size}`." | ||
) | ||
|
||
else: | ||
return super()._add_diagonal(other) | ||
|
||
def diagonal(self) -> Float[Array, "N"]: | ||
"""Diagonal of the covariance operator.""" | ||
return self.value * jnp.ones(self.size) | ||
|
||
def to_root(self) -> ConstantDiagonalLinearOperator: | ||
""" | ||
Lower triangular. | ||
Returns: | ||
Float[Array, "N N"]: Lower triangular matrix. | ||
""" | ||
return ConstantDiagonalLinearOperator( | ||
value=jnp.sqrt(self.value), size=self.size | ||
) | ||
|
||
def log_det(self) -> Float[Array, "1"]: | ||
"""Log determinant. | ||
Returns: | ||
Float[Array, "1"]: Log determinant of the covariance matrix. | ||
""" | ||
return 2.0 * self.size * jnp.log(self.value) | ||
|
||
def inverse(self) -> ConstantDiagonalLinearOperator: | ||
"""Inverse of the covariance operator. | ||
Returns: | ||
DiagonalLinearOperator: Inverse of the covariance operator. | ||
""" | ||
return ConstantDiagonalLinearOperator(value=1.0 / self.value, size=self.size) | ||
|
||
def solve(self, rhs: Float[Array, "N M"]) -> Float[Array, "N M"]: | ||
"""Solve linear system. | ||
Args: | ||
rhs (Float[Array, "N M"]): Right hand side of the linear system. | ||
Returns: | ||
Float[Array, "N M"]: Solution of the linear system. | ||
""" | ||
|
||
return rhs / self.value | ||
|
||
@classmethod | ||
def from_dense(cls, dense: Float[Array, "N N"]) -> ConstantDiagonalLinearOperator: | ||
"""Construct covariance operator from dense matrix. | ||
Args: | ||
dense (Float[Array, "N N"]): Dense matrix. | ||
Returns: | ||
DiagonalLinearOperator: Covariance operator. | ||
""" | ||
return ConstantDiagonalLinearOperator( | ||
value=jnp.atleast_1d(dense[0, 0]), size=dense.shape[0] | ||
) | ||
|
||
@classmethod | ||
def from_root( | ||
cls, root: ConstantDiagonalLinearOperator | ||
) -> ConstantDiagonalLinearOperator: | ||
"""Construct covariance operator from root. | ||
Args: | ||
root (ConstantDiagonalLinearOperator): Root of the covariance operator. | ||
Returns: | ||
ConstantDiagonalLinearOperator: Covariance operator. | ||
""" | ||
return ConstantDiagonalLinearOperator(value=root.value**2, size=root.size) | ||
|
||
|
||
__all__ = [ | ||
"ConstantDiagonalLinearOperator", | ||
] |
Oops, something went wrong.