diff --git a/pySDC/implementations/problem_classes/generic_spectral.py b/pySDC/implementations/problem_classes/generic_spectral.py index a8c4bc260b..0410cb3ad9 100644 --- a/pySDC/implementations/problem_classes/generic_spectral.py +++ b/pySDC/implementations/problem_classes/generic_spectral.py @@ -59,6 +59,7 @@ def __init__( left_preconditioner=True, solver_type='cached_direct', solver_args=None, + preconditioner_args=None, useGPU=False, max_cached_factorizations=12, spectral_space=True, @@ -83,11 +84,15 @@ def __init__( debug (bool): Make additional tests at extra computational cost """ solver_args = {} if solver_args is None else solver_args + preconditioner_args = {} if preconditioner_args is None else preconditioner_args + preconditioner_args['drop_tol'] = preconditioner_args.get('drop_tol', 1e-3) + preconditioner_args['fill_factor'] = preconditioner_args.get('fill_factor', 100) self._makeAttributeAndRegister( 'max_cached_factorizations', 'useGPU', 'solver_type', 'solver_args', + 'preconditioner_args', 'left_preconditioner', 'Dirichlet_recombination', 'comm', @@ -229,10 +234,14 @@ def solve_system(self, rhs, dt, u0=None, *args, skip_itransform=False, **kwargs) rhs_hat = rhs.copy() if u0 is not None: u0_hat = self.Pr.T @ u0.copy().flatten() + else: + u0_hat = None else: rhs_hat = self.spectral.transform(rhs) if u0 is not None: u0_hat = self.Pr.T @ self.spectral.transform(u0).flatten() + else: + u0_hat = None if self.useGPU: self.xp.cuda.Device().synchronize() @@ -257,6 +266,23 @@ def solve_system(self, rhs, dt, u0=None, *args, skip_itransform=False, **kwargs) # plt.colorbar(im) # plt.show() + if 'ilu' in self.solver_type.lower(): + if dt not in self.cached_factorizations.keys(): + if len(self.cached_factorizations) >= self.max_cached_factorizations: + to_evict = list(self.cached_factorizations.keys())[0] + self.cached_factorizations.pop(to_evict) + self.logger.debug(f'Evicted matrix factorization for {to_evict=:.6f} from cache') + iLU = self.linalg.spilu( + A, **{**self.preconditioner_args, 'drop_tol': dt * self.preconditioner_args['drop_tol']} + ) + self.cached_factorizations[dt] = self.linalg.LinearOperator(A.shape, iLU.solve) + self.logger.debug(f'Cached incomplete LU factorization for {dt=:.6f}') + self.work_counters['factorizations']() + M = self.cached_factorizations[dt] + else: + M = None + info = 0 + if self.solver_type.lower() == 'cached_direct': if dt not in self.cached_factorizations.keys(): if len(self.cached_factorizations) >= self.max_cached_factorizations: @@ -271,15 +297,7 @@ def solve_system(self, rhs, dt, u0=None, *args, skip_itransform=False, **kwargs) elif self.solver_type.lower() == 'direct': _sol_hat = sp.linalg.spsolve(A, rhs_hat) - elif self.solver_type.lower() == 'lsqr': - lsqr = sp.linalg.lsqr( - A, - rhs_hat, - x0=u0_hat, - **self.solver_args, - ) - _sol_hat = lsqr[0] - elif self.solver_type.lower() == 'gmres': + elif 'gmres' in self.solver_type.lower(): _sol_hat, _ = sp.linalg.gmres( A, rhs_hat, @@ -287,36 +305,27 @@ def solve_system(self, rhs, dt, u0=None, *args, skip_itransform=False, **kwargs) **self.solver_args, callback=self.work_counters[self.solver_type], callback_type='pr_norm', + M=M, ) - elif self.solver_type.lower() == 'gmres+ilu': - linalg = self.spectral.linalg - - if dt not in self.cached_factorizations.keys(): - if len(self.cached_factorizations) >= self.max_cached_factorizations: - to_evict = list(self.cached_factorizations.keys())[0] - self.cached_factorizations.pop(to_evict) - self.logger.debug(f'Evicted matrix factorization for {to_evict=:.6f} from cache') - iLU = linalg.spilu(A, drop_tol=dt * 1e-4, fill_factor=100) - self.cached_factorizations[dt] = linalg.LinearOperator(A.shape, iLU.solve) - self.logger.debug(f'Cached matrix factorization for {dt=:.6f}') - self.work_counters['factorizations']() - - _sol_hat, _ = linalg.gmres( + elif self.solver_type.lower() == 'cg': + _sol_hat, info = sp.linalg.cg( + A, rhs_hat, x0=u0_hat, **self.solver_args, callback=self.work_counters[self.solver_type] + ) + elif 'bicgstab' in self.solver_type.lower(): + _sol_hat, info = self.linalg.bicgstab( A, rhs_hat, x0=u0_hat, **self.solver_args, callback=self.work_counters[self.solver_type], - callback_type='pr_norm', - M=self.cached_factorizations[dt], - ) - elif self.solver_type.lower() == 'cg': - _sol_hat, _ = sp.linalg.cg( - A, rhs_hat, x0=u0_hat, **self.solver_args, callback=self.work_counters[self.solver_type] + M=M, ) else: raise NotImplementedError(f'Solver {self.solver_type=} not implemented in {type(self).__name__}!') + if info != 0: + self.logger.warn(f'{self.solver_type} not converged! {info=}') + sol_hat = self.spectral.u_init_forward sol_hat[...] = (self.Pr @ _sol_hat).reshape(sol_hat.shape) diff --git a/pySDC/tests/test_problems/test_heat_chebychev.py b/pySDC/tests/test_problems/test_heat_chebychev.py index 8a5f150d7b..4c4a0ca3ce 100644 --- a/pySDC/tests/test_problems/test_heat_chebychev.py +++ b/pySDC/tests/test_problems/test_heat_chebychev.py @@ -8,7 +8,8 @@ @pytest.mark.parametrize('noise', [0, 1e-3]) @pytest.mark.parametrize('use_ultraspherical', [True, False]) @pytest.mark.parametrize('spectral_space', [True, False]) -def test_heat1d_chebychev(a, b, f, noise, use_ultraspherical, spectral_space, nvars=2**4): +@pytest.mark.parametrize('solver_type', ['cached_direct', 'direct', 'gmres', 'bicgstab', 'gmres+ilu', 'bicgstab+ilu']) +def test_heat1d_chebychev(a, b, f, noise, use_ultraspherical, spectral_space, solver_type, nvars=2**4): import numpy as np if use_ultraspherical: @@ -25,6 +26,8 @@ def test_heat1d_chebychev(a, b, f, noise, use_ultraspherical, spectral_space, nv left_preconditioner=False, debug=True, spectral_space=spectral_space, + solver_type=solver_type, + solver_args={'rtol': 1e-12}, ) u0 = P.u_exact(0, noise=noise)