Skip to content

Commit

Permalink
Verify preconditioner pseudoinverses explicitly
Browse files Browse the repository at this point in the history
  • Loading branch information
kburns committed Nov 17, 2023
1 parent 786bf11 commit f4803a6
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 6 deletions.
16 changes: 10 additions & 6 deletions dedalus/core/subsystems.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import uuid

from .domain import Domain
from ..tools.array import zeros_with_pattern, expand_pattern, sparse_block_diag, copyto, perm_matrix, drop_empty_rows, csr_matvecs
from ..tools.array import zeros_with_pattern, expand_pattern, sparse_block_diag, copyto, perm_matrix, drop_empty_rows, csr_matvecs, assert_sparse_pinv
from ..tools.cache import CachedAttribute, CachedMethod
from ..tools.general import replace, OrderedSet
from ..tools.progress import log_progress
Expand Down Expand Up @@ -339,7 +339,7 @@ def gather_inputs(self, fields, out=None):
if out is None:
out = self._compressed_buffer
out.fill(0)
csr_matvecs(self.pre_right_T, self._input_buffer, out)
csr_matvecs(self.pre_right_pinv, self._input_buffer, out)
return out

def gather_outputs(self, fields, out=None):
Expand Down Expand Up @@ -369,7 +369,7 @@ def scatter_outputs(self, data, fields):
"""Precondition and scatter subproblem data out to output-like field list."""
# Undo left preconditioner to expand outputs
self._output_buffer.fill(0)
csr_matvecs(self.pre_left_T, data, self._output_buffer)
csr_matvecs(self.pre_left_pinv, data, self._output_buffer)
# Scatter to fields
views = self._output_field_views(tuple(fields))
for buffer_view, field_view in views:
Expand Down Expand Up @@ -554,9 +554,13 @@ def build_matrices(self, names):

# Preconditioners
self.pre_left = drop_empty_rows(left_perm @ valid_eqn).tocsr()
self.pre_left_T = self.pre_left.T.tocsr()
self.pre_right_T = drop_empty_rows(right_perm @ valid_var).tocsr()
self.pre_right = self.pre_right_T.T.tocsr()
self.pre_left_pinv = self.pre_left.T.tocsr()
self.pre_right_pinv = drop_empty_rows(right_perm @ valid_var).tocsr()
self.pre_right = self.pre_right_pinv.T.tocsr()

# Check preconditioner pseudoinverses
assert_sparse_pinv(self.pre_left, self.pre_left_pinv)
assert_sparse_pinv(self.pre_right, self.pre_right_pinv)

# Precondition matrices
for name in matrices:
Expand Down
18 changes: 18 additions & 0 deletions dedalus/tools/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,3 +383,21 @@ def interleave_matrices(matrices):
P[i, i] = 0
return sum


def sparse_allclose(A, B):
A = A.tocsr()
B = B.tocsr()
return (np.allclose(A.data, B.data) and
np.allclose(A.indices, B.indices) and
np.allclose(A.indptr, B.indptr))

def assert_sparse_pinv(A, B):
if not sparse_allclose(A @ B @ A, A):
raise AssertionError("Not a pseudoinverse")
if not sparse_allclose(B @ A @ B, B):
raise AssertionError("Not a pseudoinverse")
if not sparse_allclose((A @ B).conj().T, A @ B):
raise AssertionError("Not a pseudoinverse")
if not sparse_allclose((B @ A).conj().T, B @ A):
raise AssertionError("Not a pseudoinverse")

0 comments on commit f4803a6

Please sign in to comment.