Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes a relevance related bug in the krylov solvers that was introduced in a recent relevance PR #3165

Merged
merged 4 commits into from Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 5 additions & 3 deletions openmdao/core/group.py
Expand Up @@ -3437,7 +3437,8 @@ def _solve_nonlinear(self):
name = self.pathname if self.pathname else 'root'

with Recording(name + '._solve_nonlinear', self.iter_count, self):
self._nonlinear_solver._solve_with_cache_check()
with self._relevance.active(self._nonlinear_solver.use_relevance()):
self._nonlinear_solver._solve_with_cache_check()

# Iteration counter is incremented in the Recording context manager at exit.

Expand Down Expand Up @@ -3638,7 +3639,8 @@ def _solve_linear(self, mode, scope_out=_UNDEFINED, scope_in=_UNDEFINED):
d_residuals *= -1.0
else:
self._linear_solver._set_matvec_scope(scope_out, scope_in)
self._linear_solver.solve(mode, None)
with self._relevance.active(self._linear_solver.use_relevance()):
self._linear_solver.solve(mode, None)

def _linearize(self, jac, sub_do_ln=True):
"""
Expand Down Expand Up @@ -3677,7 +3679,7 @@ def _linearize(self, jac, sub_do_ln=True):
jac = self._assembled_jac

relevance = self._relevance
with relevance.active(self.linear_solver.use_relevance()):
with relevance.active(self._linear_solver.use_relevance()):
subs = list(relevance.filter(self._subsystems_myproc))

# Only linearize subsystems if we aren't approximating the derivs at this level.
Expand Down
4 changes: 2 additions & 2 deletions openmdao/solvers/linear/direct.py
Expand Up @@ -228,12 +228,12 @@ def _linearize_children(self):

def use_relevance(self):
"""
Return True if relevance is should be active.
Return True if relevance should be active.

Returns
-------
bool
True if relevance is should be active.
True if relevance should be active.
"""
return False

Expand Down
11 changes: 11 additions & 0 deletions openmdao/solvers/linear/petsc_ksp.py
Expand Up @@ -228,6 +228,17 @@ def _assembled_jac_solver_iter(self):
for s in self.precon._assembled_jac_solver_iter():
yield s

def use_relevance(self):
"""
Return True if relevance should be active.

Returns
-------
bool
True if relevance should be active.
"""
return False

def _setup_solvers(self, system, depth):
"""
Assign system instance, set depth, and optionally perform setup.
Expand Down
11 changes: 11 additions & 0 deletions openmdao/solvers/linear/scipy_iter_solver.py
Expand Up @@ -277,3 +277,14 @@ def _apply_precon(self, in_vec):

# return resulting value of x vector
return x_vec.asarray(copy=True)

def use_relevance(self):
"""
Return True if relevance should be active.

Returns
-------
bool
True if relevance should be active.
"""
return False
4 changes: 2 additions & 2 deletions openmdao/solvers/nonlinear/newton.py
Expand Up @@ -281,11 +281,11 @@ def cleanup(self):

def use_relevance(self):
"""
Return True if relevance is should be active.
Return True if relevance should be active.

Returns
-------
bool
True if relevance is should be active.
True if relevance should be active.
"""
return False
166 changes: 82 additions & 84 deletions openmdao/solvers/solver.py
Expand Up @@ -550,12 +550,12 @@ def get_reports_dir(self):

def use_relevance(self):
"""
Return True if relevance is should be active.
Return True if relevance should be active.

Returns
-------
bool
True if relevance is should be active.
True if relevance should be active.
"""
return True

Expand Down Expand Up @@ -698,83 +698,82 @@ def _solve(self):
"""
system = self._system()

with system._relevance.active(self.use_relevance()):
maxiter = self.options['maxiter']
atol = self.options['atol']
rtol = self.options['rtol']
iprint = self.options['iprint']
stall_limit = self.options['stall_limit']
stall_tol = self.options['stall_tol']
stall_tol_type = self.options['stall_tol_type']
maxiter = self.options['maxiter']
atol = self.options['atol']
rtol = self.options['rtol']
iprint = self.options['iprint']
stall_limit = self.options['stall_limit']
stall_tol = self.options['stall_tol']
stall_tol_type = self.options['stall_tol_type']

self._mpi_print_header()
self._mpi_print_header()

self._iter_count = 0
norm0, norm = self._iter_initialize()
self._iter_count = 0
norm0, norm = self._iter_initialize()

self._norm0 = norm0
self._norm0 = norm0

self._mpi_print(self._iter_count, norm, norm / norm0)
self._mpi_print(self._iter_count, norm, norm / norm0)

stalled = False
stall_count = 0
if stall_limit > 0:
stall_norm = norm0

stalled = False
stall_count = 0
if stall_limit > 0:
stall_norm = norm0
force_one_iteration = system.under_complex_step

force_one_iteration = system.under_complex_step
while ((self._iter_count < maxiter and norm > atol and norm / norm0 > rtol and
not stalled) or force_one_iteration):

while ((self._iter_count < maxiter and norm > atol and norm / norm0 > rtol and
not stalled) or force_one_iteration):
if system.under_complex_step:
force_one_iteration = False

if system.under_complex_step:
force_one_iteration = False
with Recording(type(self).__name__, self._iter_count, self) as rec:
ls = self.linesearch
if stall_count == 3 and ls and not ls.options['print_bound_enforce']:

with Recording(type(self).__name__, self._iter_count, self) as rec:
ls = self.linesearch
if stall_count == 3 and ls and not ls.options['print_bound_enforce']:
self.linesearch.options['print_bound_enforce'] = True

self.linesearch.options['print_bound_enforce'] = True
if self._system().pathname:
pathname = f"{self._system().pathname}."
else:
pathname = ""

if self._system().pathname:
pathname = f"{self._system().pathname}."
else:
pathname = ""
msg = ("Your model has stalled three times and may be violating the bounds."
" In the future, turn on print_bound_enforce in your solver options "
f"here: \n{pathname}nonlinear_solver.linesearch.options"
"['print_bound_enforce']=True. \nThe bound(s) being violated now "
"are:\n")
issue_warning(msg, category=SolverWarning)

msg = ("Your model has stalled three times and may be violating the bounds."
" In the future, turn on print_bound_enforce in your solver options "
f"here: \n{pathname}nonlinear_solver.linesearch.options"
"['print_bound_enforce']=True. \nThe bound(s) being violated now "
"are:\n")
issue_warning(msg, category=SolverWarning)
self._single_iteration()
self.linesearch.options['print_bound_enforce'] = False
else:
self._single_iteration()

self._single_iteration()
self.linesearch.options['print_bound_enforce'] = False
self._iter_count += 1
self._run_apply()
norm = self._iter_get_norm()

# Save the norm values in the context manager so they can also be recorded.
rec.abs = norm
if norm0 == 0:
norm0 = 1
rec.rel = norm / norm0

# Check if convergence is stalled.
if stall_limit > 0:
norm_for_stall = rec.rel if stall_tol_type == 'rel' else rec.abs
norm_diff = np.abs(stall_norm - norm_for_stall)
if norm_diff <= stall_tol:
stall_count += 1
if stall_count >= stall_limit:
stalled = True
else:
self._single_iteration()

self._iter_count += 1
self._run_apply()
norm = self._iter_get_norm()

# Save the norm values in the context manager so they can also be recorded.
rec.abs = norm
if norm0 == 0:
norm0 = 1
rec.rel = norm / norm0

# Check if convergence is stalled.
if stall_limit > 0:
norm_for_stall = rec.rel if stall_tol_type == 'rel' else rec.abs
norm_diff = np.abs(stall_norm - norm_for_stall)
if norm_diff <= stall_tol:
stall_count += 1
if stall_count >= stall_limit:
stalled = True
else:
stall_count = 0
stall_norm = norm_for_stall

self._mpi_print(self._iter_count, norm, norm / norm0)
stall_count = 0
stall_norm = norm_for_stall

self._mpi_print(self._iter_count, norm, norm / norm0)

# flag for the print statements. we only print on root if USE_PROC_FILES is not set to True
print_flag = system.comm.rank == 0 or os.environ.get('USE_PROC_FILES')
Expand Down Expand Up @@ -1018,33 +1017,32 @@ def _solve(self):
rtol = self.options['rtol']
iprint = self.options['iprint']

with self._system()._relevance.active(self.use_relevance()):
self._mpi_print_header()
self._mpi_print_header()

self._iter_count = 0
norm0, norm = self._iter_initialize()
self._iter_count = 0
norm0, norm = self._iter_initialize()

self._norm0 = norm0
self._norm0 = norm0

system = self._system()
system = self._system()

self._mpi_print(self._iter_count, norm, norm / norm0)
self._mpi_print(self._iter_count, norm, norm / norm0)

while self._iter_count < maxiter and norm > atol and norm / norm0 > rtol:
while self._iter_count < maxiter and norm > atol and norm / norm0 > rtol:

with Recording(type(self).__name__, self._iter_count, self) as rec:
self._single_iteration()
self._iter_count += 1
self._run_apply()
norm = self._iter_get_norm()
with Recording(type(self).__name__, self._iter_count, self) as rec:
self._single_iteration()
self._iter_count += 1
self._run_apply()
norm = self._iter_get_norm()

# Save the norm values in the context manager so they can also be recorded.
rec.abs = norm
if norm0 == 0:
norm0 = 1
rec.rel = norm / norm0
# Save the norm values in the context manager so they can also be recorded.
rec.abs = norm
if norm0 == 0:
norm0 = 1
rec.rel = norm / norm0

self._mpi_print(self._iter_count, norm, norm / norm0)
self._mpi_print(self._iter_count, norm, norm / norm0)

# flag for the print statements. we only print on root if USE_PROC_FILES is not set to True
print_flag = system.comm.rank == 0 or os.environ.get('USE_PROC_FILES')
Expand Down