Skip to content

Commit

Permalink
Merge pull request #276 from DedalusProject/performance
Browse files Browse the repository at this point in the history
Linalg improvements
  • Loading branch information
kburns committed Dec 29, 2023
2 parents 527969c + 3b6412b commit fe693f0
Show file tree
Hide file tree
Showing 10 changed files with 501 additions and 64 deletions.
8 changes: 4 additions & 4 deletions dedalus/core/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5228,9 +5228,9 @@ def radial_matrix(self, spinindex_in, spinindex_out, m):
raise ValueError("This should never happen.")
n_size = self.input_basis.n_size(m)
if m == 0:
return sparse.identity(n_size)
return sparse.identity(n_size).tocsr()
else:
return sparse.csr_matrix((0, n_size), dtype=self.dtype)
return sparse.csr_matrix((0, n_size), dtype=self.dtype).tocsr()


class SphereAzimuthalAverage(AzimuthalAverage, operators.Average, operators.SpectralOperator):
Expand Down Expand Up @@ -5343,9 +5343,9 @@ def radial_matrix(self, regindex_in, regindex_out, ell):
raise ValueError("This should never happen.")
n_size = self.input_basis.n_size(ell)
if ell == 0:
return sparse.identity(n_size)
return sparse.identity(n_size).tocsr()
else:
return sparse.csr_matrix((0, n_size), dtype=self.dtype)
return sparse.csr_matrix((0, n_size), dtype=self.dtype).tocsr()


class IntegrateSpinBasis(operators.PolarMOperator):
Expand Down
2 changes: 1 addition & 1 deletion dedalus/core/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .evaluator import Evaluator
from ..libraries.matsolvers import matsolvers
from ..tools.config import config
from ..tools.array import csr_matvecs, scipy_sparse_eigs
from ..tools.array import scipy_sparse_eigs

import logging
logger = logging.getLogger(__name__.split('.')[-1])
Expand Down
24 changes: 10 additions & 14 deletions dedalus/core/subsystems.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from math import prod

from .domain import Domain
from ..tools.array import zeros_with_pattern, expand_pattern, sparse_block_diag, copyto, perm_matrix, drop_empty_rows, csr_matvecs, assert_sparse_pinv
from ..tools.array import zeros_with_pattern, expand_pattern, sparse_block_diag, copyto, perm_matrix, drop_empty_rows, apply_sparse, assert_sparse_pinv
from ..tools.cache import CachedAttribute, CachedMethod
from ..tools.general import replace, OrderedSet
from ..tools.progress import log_progress
Expand Down Expand Up @@ -329,7 +329,6 @@ def _output_field_views(self, fields):
views.append((buffer_view, field_view))
return tuple(views)

#@profile
def gather_inputs(self, fields, out=None):
"""Gather and precondition subproblem data from input-like field list."""
# Gather from fields
Expand All @@ -339,8 +338,7 @@ def gather_inputs(self, fields, out=None):
# Apply right preconditioner inverse to compress inputs
if out is None:
out = self._compressed_buffer
out.fill(0)
csr_matvecs(self.pre_right_pinv, self._input_buffer, out)
apply_sparse(self.pre_right_pinv, self._input_buffer, axis=0, out=out)
return out

def gather_outputs(self, fields, out=None):
Expand All @@ -352,15 +350,13 @@ def gather_outputs(self, fields, out=None):
# Apply left preconditioner to compress outputs
if out is None:
out = self._compressed_buffer
out.fill(0)
csr_matvecs(self.pre_left, self._output_buffer, out)
apply_sparse(self.pre_left, self._output_buffer, axis=0, out=out)
return out

def scatter_inputs(self, data, fields):
"""Precondition and scatter subproblem data out to input-like field list."""
# Undo right preconditioner inverse to expand inputs
self._input_buffer.fill(0)
csr_matvecs(self.pre_right, data, self._input_buffer)
apply_sparse(self.pre_right, data, axis=0, out=self._input_buffer)
# Scatter to fields
views = self._input_field_views(tuple(fields))
for buffer_view, field_view in views:
Expand All @@ -369,8 +365,7 @@ def scatter_inputs(self, data, fields):
def scatter_outputs(self, data, fields):
"""Precondition and scatter subproblem data out to output-like field list."""
# Undo left preconditioner to expand outputs
self._output_buffer.fill(0)
csr_matvecs(self.pre_left_pinv, data, self._output_buffer)
apply_sparse(self.pre_left_pinv, data, axis=0, out=self._output_buffer)
# Scatter to fields
views = self._output_field_views(tuple(fields))
for buffer_view, field_view in views:
Expand Down Expand Up @@ -554,10 +549,11 @@ def build_matrices(self, names):
right_perm = right_permutation(self, vars, tau_left=solver.tau_left, interleave_components=solver.interleave_components).tocsr()

# Preconditioners
self.pre_left = drop_empty_rows(left_perm @ valid_eqn).tocsr()
self.pre_left_pinv = self.pre_left.T.tocsr()
self.pre_right_pinv = drop_empty_rows(right_perm @ valid_var).tocsr()
self.pre_right = self.pre_right_pinv.T.tocsr()
# TODO: remove astype casting, requires dealing with used types in apply_sparse
self.pre_left = drop_empty_rows(left_perm @ valid_eqn).tocsr().astype(dtype)
self.pre_left_pinv = self.pre_left.T.tocsr().astype(dtype)
self.pre_right_pinv = drop_empty_rows(right_perm @ valid_var).tocsr().astype(dtype)
self.pre_right = self.pre_right_pinv.T.tocsr().astype(dtype)

# Check preconditioner pseudoinverses
assert_sparse_pinv(self.pre_left, self.pre_left_pinv)
Expand Down
19 changes: 6 additions & 13 deletions dedalus/core/timesteppers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from scipy.linalg import blas

from .system import CoeffSystem
from ..tools.array import csr_matvecs
from ..tools.array import apply_sparse

# Track implemented schemes
schemes = OrderedDict()
Expand Down Expand Up @@ -123,16 +123,13 @@ def step(self, dt, wall_time):
sp.LHS_solver = None

# Evaluate M.X0 and L.X0
MX0.data.fill(0)
LX0.data.fill(0)
evaluator.require_coeff_space(state_fields)
for sp in subproblems:
spX = sp.gather_inputs(state_fields)
csr_matvecs(sp.M_min, spX, MX0.get_subdata(sp))
csr_matvecs(sp.L_min, spX, LX0.get_subdata(sp))
apply_sparse(sp.M_min, spX, axis=0, out=MX0.get_subdata(sp))
apply_sparse(sp.L_min, spX, axis=0, out=LX0.get_subdata(sp))

# Evaluate F(X0)
F0.data.fill(0)
evaluator.evaluate_scheduled(iteration=iteration, wall_time=wall_time, sim_time=sim_time, timestep=dt)
evaluator.require_coeff_space(F_fields)
for sp in subproblems:
Expand Down Expand Up @@ -569,14 +566,12 @@ def step(self, dt, wall_time):
sp.LHS_solvers = [None] * (self.stages+1)

# Compute M.X(n,0) and L.X(n,0)
MX0.data.fill(0)
LX0.data.fill(0)
# Ensure coeff space before subsystem gathers
evaluator.require_coeff_space(state_fields)
for sp in subproblems:
spX = sp.gather_inputs(state_fields)
csr_matvecs(sp.M_min, spX, MX0.get_subdata(sp))
csr_matvecs(sp.L_min, spX, LX0.get_subdata(sp))
apply_sparse(sp.M_min, spX, axis=0, out=MX0.get_subdata(sp))
apply_sparse(sp.L_min, spX, axis=0, out=LX0.get_subdata(sp))

# Compute stages
# (M + k Hii L).X(n,i) = M.X(n,0) + k Aij F(n,j) - k Hij L.X(n,j)
Expand All @@ -585,20 +580,18 @@ def step(self, dt, wall_time):
# Compute L.X(n,i-1), already done for i=1
if i > 1:
LXi = LX[i-1]
LXi.data.fill(0)
# Ensure coeff space before subsystem gathers
evaluator.require_coeff_space(state_fields)
for sp in subproblems:
spX = sp.gather_inputs(state_fields)
csr_matvecs(sp.L_min, spX, LXi.get_subdata(sp))
apply_sparse(sp.L_min, spX, axis=0, out=LXi.get_subdata(sp))

# Compute F(n,i-1), only doing output on first evaluation
if i == 1:
evaluator.evaluate_scheduled(iteration=iteration, wall_time=wall_time, sim_time=solver.sim_time, timestep=dt)
else:
evaluator.evaluate_group('F')
Fi = F[i-1]
Fi.data.fill(0)
for sp in subproblems:
# F fields should be in coeff space from evaluator
sp.gather_outputs(F_fields, out=Fi.get_subdata(sp))
Expand Down

0 comments on commit fe693f0

Please sign in to comment.