Skip to content

Commit

Permalink
solve_bicgstab: cut use of s (#3629)
Browse files Browse the repository at this point in the history
## Summary

The MF named `s` seems to be an unnecessary usage of memory as the
residue `r` can fulfill its roles in the algorithm. This PR replaces `s`
with `r` and `LinComb` with `Saxpy` as appropriate.

## Additional background

This PR is part of a larger request to improve `solve_bicgstab` and
`solve_cg`.
  • Loading branch information
eebasso committed Nov 17, 2023
1 parent d93f344 commit 175b99d
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions Src/LinearSolvers/MLMG/AMReX_MLCGSolver.H
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ MLCGSolverT<MF>::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
MF sorig = Lp.make(amrlev, mglev, nghost);
MF p = Lp.make(amrlev, mglev, nghost);
MF r = Lp.make(amrlev, mglev, nghost);
MF s = Lp.make(amrlev, mglev, nghost);
MF rh = Lp.make(amrlev, mglev, nghost);
MF v = Lp.make(amrlev, mglev, nghost);
MF t = Lp.make(amrlev, mglev, nghost);
Expand Down Expand Up @@ -166,9 +165,9 @@ MLCGSolverT<MF>::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
ret = 2; break;
}
MF::Saxpy(sol, alpha, ph, 0, 0, ncomp, nghost); // sol += alpha * ph
MF::LinComb(s, RT(1.0), r, 0, -alpha, v, 0, 0, ncomp, nghost); // s = r - alpha * v
MF::Saxpy(r, -alpha, v, 0, 0, ncomp, nghost); // r += -alpha * v

rnorm = norm_inf(s);
rnorm = norm_inf(r);

if ( verbose > 2 && ParallelDescriptor::IOProcessor() )
{
Expand All @@ -180,15 +179,15 @@ MLCGSolverT<MF>::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)

if ( rnorm < eps_rel*rnorm0 || rnorm < eps_abs ) { break; }

sh.LocalCopy(s,0,0,ncomp,nghost);
sh.LocalCopy(r,0,0,ncomp,nghost);
Lp.apply(amrlev, mglev, t, sh, MLLinOpT<MF>::BCMode::Homogeneous, MLLinOpT<MF>::StateMode::Correction);
Lp.normalize(amrlev, mglev, t);
//
// This is a little funky. I want to elide one of the reductions
// in the following two dotxy()s. We do that by calculating the "local"
// values and then reducing the two local values at the same time.
//
RT tvals[2] = { dotxy(t,t,true), dotxy(t,s,true) };
RT tvals[2] = { dotxy(t,t,true), dotxy(t,r,true) };

BL_PROFILE_VAR("MLCGSolver::ParallelAllReduce", blp_par);
ParallelAllReduce::Sum(tvals,2,Lp.BottomCommunicator());
Expand All @@ -203,7 +202,7 @@ MLCGSolverT<MF>::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
ret = 3; break;
}
MF::Saxpy(sol, omega, sh, 0, 0, ncomp, nghost); // sol += omega * sh
MF::LinComb(r, RT(1.0), s, 0, -omega, t, 0, 0, ncomp, nghost); // r = s - omega * t
MF::Saxpy(r, -omega, t, 0, 0, ncomp, nghost); // r += -omega * t

rnorm = norm_inf(r);

Expand Down

0 comments on commit 175b99d

Please sign in to comment.