Skip to content

Commit

Permalink
Matsolver updates (#251)
Browse files Browse the repository at this point in the history
Add colamd factorized transpose, and other matsolver updates
  • Loading branch information
kburns committed Jun 8, 2023
1 parent fab25b9 commit 5ca7577
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 32 deletions.
90 changes: 60 additions & 30 deletions dedalus/libraries/matsolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,29 +73,39 @@ def __init__(self, matrix, solver=None):
self.matrix = matrix.copy()

def solve(self, vector):
return spla.spsolve(self.matrix, vector, use_umfpack=True)
out = spla.spsolve(self.matrix, vector, use_umfpack=True)
# Fix return shape for matrices
if vector.ndim == 2 and out.ndim == 1:
out = out[:, None]
return out


@add_solver
class SuperluNaturalSpsolve(SparseSolver):
"""SuperLU+NATURAL spsolve."""
class _SuperluSpsolveBase(SparseSolver):
""""SuperLU spsolve base class."""

permc_spec = None

def __init__(self, matrix, solver=None):
self.matrix = matrix.copy()

def solve(self, vector):
return spla.spsolve(self.matrix, vector, permc_spec='NATURAL', use_umfpack=False)
out = spla.spsolve(self.matrix, vector, permc_spec=self.permc_spec, use_umfpack=False)
# Fix return shape for matrices
if vector.ndim == 2 and out.ndim == 1:
out = out[:, None]
return out


@add_solver
class SuperluColamdSpsolve(SparseSolver):
"""SuperLU+COLAMD spsolve."""
class SuperluNaturalSpsolve(_SuperluSpsolveBase):
"""SuperLU spsolve with 'NATURAL' column permutation."""
permc_spec = "NATURAL"

def __init__(self, matrix, solver=None):
self.matrix = matrix.copy()

def solve(self, vector):
return spla.spsolve(self.matrix, vector, permc_spec='COLAMD', use_umfpack=False)
@add_solver
class SuperluColamdSpsolve(_SuperluSpsolveBase):
"""SuperLU spsolve with 'COLAMD' column permutation."""
permc_spec = "COLAMD"


@add_solver
Expand All @@ -104,43 +114,62 @@ class UmfpackFactorized(SparseSolver):

def __init__(self, matrix, solver=None):
from scikits import umfpack
self.LU = spla.factorized(matrix.tocsc())
self.LU = umfpack.splu(matrix.tocsc())

def solve(self, vector):
return self.LU(vector)
return self.LU.solve(vector)


@add_solver
class SuperluNaturalFactorized(SparseSolver):
"""SuperLU+NATURAL LU factorized solve."""
class _SuperluFactorizedBase(SparseSolver):
"""SuperLU factorized solver base class."""

permc_spec = None
diag_pivot_thresh = None
relax = None
panel_size = None
options = {}
trans = "N"

def __init__(self, matrix, solver=None):
self.LU = spla.splu(matrix.tocsc(), permc_spec='NATURAL')
if self.trans == "T":
matrix = matrix.T
elif self.trans == "H":
matrix = matrix.H
self.LU = spla.splu(matrix.tocsc(),
permc_spec=self.permc_spec,
diag_pivot_thresh=self.diag_pivot_thresh,
relax=self.relax,
panel_size=self.panel_size,
options=self.options)

def solve(self, vector):
return self.LU.solve(vector)
return self.LU.solve(vector, trans=self.trans)


@add_solver
class SuperluNaturalFactorizedTranspose(SparseSolver):
"""SuperLU+NATURAL LU factorized solve."""
class SuperluNaturalFactorized(_SuperluFactorizedBase):
"""SuperLU factorized solve with 'NATURAL' column permutation."""
permc_spec = "NATURAL"

def __init__(self, matrix, solver=None):
self.LU = spla.splu(matrix.T.tocsc(), permc_spec='NATURAL')

def solve(self, vector):
return self.LU.solve(vector, trans='T')
@add_solver
class SuperluNaturalFactorizedTranspose(_SuperluFactorizedBase):
"""SuperLU factorized solve with 'NATURAL' row permutation."""
permc_spec = "NATURAL"
trans = "T"


@add_solver
class SuperluColamdFactorized(SparseSolver):
"""SuperLU+COLAMD LU factorized solve."""
class SuperluColamdFactorized(_SuperluFactorizedBase):
"""SuperLU factorized solve with 'COLAMD' column permutation."""
permc_spec = "COLAMD"

def __init__(self, matrix, solver=None):
self.LU = spla.splu(matrix.tocsc(), permc_spec='COLAMD')

def solve(self, vector):
return self.LU.solve(vector)
@add_solver
class SuperluColamdFactorizedTranspose(_SuperluFactorizedBase):
"""SuperLU factorized solve with 'COLAMD' row permutation."""
permc_spec = "COLAMD"
trans = "T"


@add_solver
Expand Down Expand Up @@ -230,6 +259,7 @@ def _solve_block(self, vector):
def _solve_diag(self, vector):
return self.matrix_inv_diagonal * vector


@add_solver
class ScipyDenseLU(DenseSolver):
"""Scipy dense LU factorized solve."""
Expand Down
4 changes: 2 additions & 2 deletions dedalus/tools/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,10 @@ def csr_matvecs(A_csr, x_vec, out_vec):
raise ValueError("Matrix must be in CSR format.")
# Check shapes
M, N = A_csr.shape
n, kx = x_vec.shape
m, ko = out_vec.shape
if x_vec.ndim != 2 or out_vec.ndim != 2:
raise ValueError("Only matrices allowed for input and output.")
n, kx = x_vec.shape
m, ko = out_vec.shape
if M != m or N != n:
raise ValueError(f"Matrix shape {(M,N)} does not match input {(n,)} and output {(m,)} shapes.")
if kx != ko:
Expand Down

0 comments on commit 5ca7577

Please sign in to comment.