Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,16 @@ def determine_restart(self, controller, S, **kwargs):
if self.get_convergence(controller, S, **kwargs):
self.res_last_iter = np.inf

if self.params.restart_at_maxiter and S.levels[0].status.residual > S.levels[0].params.restol:
L = S.levels[0]
e_tol_converged = (
L.status.increment < L.params.e_tol if (L.params.get('e_tol') and L.status.get('increment')) else False
)

if (
self.params.restart_at_maxiter
and S.levels[0].status.residual > S.levels[0].params.restol
and not e_tol_converged
):
self.trigger_restart_upon_nonconvergence(S)
elif self.get_local_error_estimate(controller, S, **kwargs) > self.params.e_tol:
S.status.restart = True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,7 @@ def check_convergence(S, self=None):
iter_converged = S.status.iter >= S.params.maxiter
res_converged = L.status.residual <= L.params.restol
e_tol_converged = (
L.status.error_embedded_estimate < L.params.e_tol
if (L.params.get('e_tol') and L.status.get('error_embedded_estimate'))
else False
L.status.increment < L.params.e_tol if (L.params.get('e_tol') and L.status.get('increment')) else False
)
converged = (
iter_converged or res_converged or e_tol_converged or S.status.force_done
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,13 @@ def estimate_embedded_error_serial(self, L):

def setup_status_variables(self, controller, **kwargs):
"""
Add the embedded error variable to the error function.
Add the embedded error to the level status

Args:
controller (pySDC.Controller): The controller
"""
self.add_status_variable_to_level('error_embedded_estimate')
self.add_status_variable_to_level('increment')

def post_iteration_processing(self, controller, S, **kwargs):
"""
Expand All @@ -134,6 +135,7 @@ def post_iteration_processing(self, controller, S, **kwargs):
if S.status.iter > 0 or self.params.sweeper_type == "RK":
for L in S.levels:
L.status.error_embedded_estimate = max([self.estimate_embedded_error_serial(L), np.finfo(float).eps])
L.status.increment = L.status.error_embedded_estimate * 1
self.debug(f'L.status.error_embedded_estimate={L.status.error_embedded_estimate:.5e}', S)

return None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def post_iteration_processing(self, controller, S, **kwargs):
if self.comm:
buf = np.array(abs(u_inter - high_order_sol) if self.comm.rank == rank else 0.0)
self.comm.Bcast(buf, root=rank)
L.status.error_embedded_estimate = buf
L.status.error_embedded_estimate = float(buf)
else:
L.status.error_embedded_estimate = abs(u_inter - high_order_sol)

Expand Down
Loading