Skip to content

Commit

Permalink
Added class to handle various setup statuses
Browse files Browse the repository at this point in the history
  • Loading branch information
DKilkenny committed Aug 10, 2020
1 parent 39618c2 commit 8449a37
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 31 deletions.
32 changes: 23 additions & 9 deletions openmdao/core/constants.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,25 @@
"""File to define different setup status constants."""

# PRE_SETUP: Newly initialized problem or newly added model.
# POST_CONFIGURE: Configure has been called.
# POST_SETUP: The `setup` method has been called, but vectors not initialized.
# POST_FINAL_SETUP: The `final_setup` has been run, everything ready to run.

PRE_SETUP = 0
POST_CONFIGURE = 1
POST_SETUP = 2
POST_FINAL_SETUP = 3
from enum import IntEnum


class _SetupStatus(IntEnum):
"""
Class used to define different states of the setup status.
Attributes
----------
PRE_SETUP : int
Newly initialized problem or newly added model.
POST_CONFIGURE : int
Configure has been called.
POST_SETUP : int
The `setup` method has been called, but vectors not initialized.
POST_FINAL_SETUP : int
The `final_setup` has been run, everything ready to run.
"""

PRE_SETUP = 0
POST_CONFIGURE = 1
POST_SETUP = 2
POST_FINAL_SETUP = 3
9 changes: 5 additions & 4 deletions openmdao/core/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from openmdao.utils.coloring import Coloring, _STD_COLORING_FNAME
import openmdao.utils.coloring as coloring_mod
from openmdao.utils.options_dictionary import _undefined
from openmdao.core.constants import PRE_SETUP, POST_CONFIGURE
from openmdao.core.constants import _SetupStatus

# regex to check for valid names.
import re
Expand Down Expand Up @@ -345,7 +345,7 @@ def _configure(self):
if subsys.matrix_free:
self.matrix_free = True

self._problem_meta['_setup_status'] = POST_CONFIGURE
self._problem_meta['_setup_status'] = _SetupStatus.POST_CONFIGURE
self.configure()

def _setup_procs(self, pathname, comm, mode, prob_meta):
Expand Down Expand Up @@ -1902,7 +1902,8 @@ def set_order(self, new_order):
new_order : list of str
List of system names in desired new execution order.
"""
if self._problem_meta is not None and self._problem_meta['_setup_status'] == POST_CONFIGURE:
if self._problem_meta is not None and \
self._problem_meta['_setup_status'] == _SetupStatus.POST_CONFIGURE:
raise RuntimeError("%s: Cannot call set_order in the configure method" % (self.msginfo))

# Make sure the new_order is valid. It must contain all subsystems
Expand Down Expand Up @@ -1940,7 +1941,7 @@ def set_order(self, new_order):

self._order_set = True
if self._problem_meta is not None:
self._problem_meta['_setup_status'] = PRE_SETUP
self._problem_meta['_setup_status'] = _SetupStatus.PRE_SETUP

def _get_subsystem(self, name):
"""
Expand Down
32 changes: 16 additions & 16 deletions openmdao/core/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from openmdao.utils.options_dictionary import OptionsDictionary, _undefined
from openmdao.utils.units import convert_units
from openmdao.utils import coloring as coloring_mod
from openmdao.core.constants import PRE_SETUP, POST_CONFIGURE, POST_SETUP, POST_FINAL_SETUP
from openmdao.core.constants import _SetupStatus
from openmdao.utils.name_maps import abs_key2rel_key
from openmdao.vectors.vector import _full_slice, INT_DTYPE
from openmdao.vectors.default_vector import DefaultVector
Expand Down Expand Up @@ -367,7 +367,7 @@ def get_val(self, name, units=None, indices=None, get_remote=False):
object
The value of the requested output/input variable.
"""
if self._metadata['_setup_status'] == POST_SETUP:
if self._metadata['_setup_status'] == _SetupStatus.POST_SETUP:
val = self._get_cached_val(name, get_remote=get_remote)
if val is not _undefined:
if indices is not None:
Expand Down Expand Up @@ -467,7 +467,7 @@ def set_val(self, name, value, units=None, indices=None):

if units is None:
# avoids double unit conversion
if self._metadata['_setup_status'] > POST_SETUP:
if self._metadata['_setup_status'] > _SetupStatus.POST_SETUP:
ivalue = value
if sunits is not None:
if gunits is not None and gunits != tunits:
Expand All @@ -479,7 +479,7 @@ def set_val(self, name, value, units=None, indices=None):
ivalue = model.convert_from_units(abs_name, value, units)
else:
ivalue = model.convert_units(name, value, units, gunits)
if self._metadata['_setup_status'] == POST_SETUP:
if self._metadata['_setup_status'] == _SetupStatus.POST_SETUP:
value = ivalue
else:
value = model.convert_from_units(src, value, units)
Expand All @@ -489,7 +489,7 @@ def set_val(self, name, value, units=None, indices=None):
value = model.convert_from_units(abs_name, value, units)

# Caching only needed if vectors aren't allocated yet.
if self._metadata['_setup_status'] == POST_SETUP:
if self._metadata['_setup_status'] == _SetupStatus.POST_SETUP:
if indices is not None:
self._get_cached_val(name)
try:
Expand Down Expand Up @@ -752,7 +752,7 @@ def record(self, case_name):
case_name : str
Name used to identify this Problem case.
"""
if self._metadata['_setup_status'] < POST_FINAL_SETUP:
if self._metadata['_setup_status'] < _SetupStatus.POST_FINAL_SETUP:
raise RuntimeError(f"{self.msginfo}: Problem.record() cannot be called before "
"`Problem.run_model()`, `Problem.run_driver()`, or "
"`Problem.final_setup()`.")
Expand Down Expand Up @@ -864,15 +864,15 @@ def setup(self, check=False, logger=None, mode='auto', force_alloc_complex=False
'remote_systems': {},
'remote_vars': {}, # does not include distrib vars
'prom2abs': {'input': {}, 'output': {}}, # includes ALL promotes including buried ones
'_setup_status': PRE_SETUP
'_setup_status': _SetupStatus.PRE_SETUP
}
model._setup(model_comm, mode, self._metadata)

# Cache all args for final setup.
self._check = check
self._logger = logger

self._metadata['_setup_status'] = POST_SETUP
self._metadata['_setup_status'] = _SetupStatus.POST_SETUP

return self

Expand All @@ -897,7 +897,7 @@ def final_setup(self):
else:
mode = self._orig_mode

if self._metadata['_setup_status'] < POST_FINAL_SETUP:
if self._metadata['_setup_status'] < _SetupStatus.POST_FINAL_SETUP:
self.model._final_setup(self.comm)

driver._setup_driver(self)
Expand All @@ -924,20 +924,20 @@ def final_setup(self):
"(objectives and nonlinear constraints)." %
(mode, desvar_size, response_size), RuntimeWarning)

if self._metadata['_setup_status'] == PRE_SETUP and \
if self._metadata['_setup_status'] == _SetupStatus.PRE_SETUP and \
hasattr(self.model, '_order_set') and self.model._order_set:
raise RuntimeError("%s: Cannot call set_order without calling "
"setup after" % (self.msginfo))

# we only want to set up recording once, after problem setup
if self._metadata['_setup_status'] == POST_SETUP:
if self._metadata['_setup_status'] == _SetupStatus.POST_SETUP:
driver._setup_recording()
self._setup_recording()
record_viewer_data(self)
record_system_options(self)

if self._metadata['_setup_status'] < POST_FINAL_SETUP:
self._metadata['_setup_status'] = POST_FINAL_SETUP
if self._metadata['_setup_status'] < _SetupStatus.POST_FINAL_SETUP:
self._metadata['_setup_status'] = _SetupStatus.POST_FINAL_SETUP
self._set_initial_conditions()

if self._check:
Expand Down Expand Up @@ -1009,7 +1009,7 @@ def check_partials(self, out_stream=_DEFAULT_OUT_STREAM, includes=None, excludes
For 'J_fd', 'J_fwd', 'J_rev' the value is: A numpy array representing the computed
Jacobian for the three different methods of computation.
"""
if self._metadata['_setup_status'] < POST_FINAL_SETUP:
if self._metadata['_setup_status'] < _SetupStatus.POST_FINAL_SETUP:
self.final_setup()

model = self.model
Expand Down Expand Up @@ -1427,7 +1427,7 @@ def check_totals(self, of=None, wrt=None, out_stream=_DEFAULT_OUT_STREAM, compac
For 'rel error', 'abs error', 'magnitude' the value is: A tuple containing norms for
forward - fd, adjoint - fd, forward - adjoint.
"""
if self._metadata['_setup_status'] < POST_FINAL_SETUP:
if self._metadata['_setup_status'] < _SetupStatus.POST_FINAL_SETUP:
raise RuntimeError(self.msginfo + ": run_model must be called before total "
"derivatives can be checked.")

Expand Down Expand Up @@ -1533,7 +1533,7 @@ def compute_totals(self, of=None, wrt=None, return_format='flat_dict', debug_pri
derivs : object
Derivatives in form requested by 'return_format'.
"""
if self._metadata['_setup_status'] < POST_FINAL_SETUP:
if self._metadata['_setup_status'] < _SetupStatus.POST_FINAL_SETUP:
self.final_setup()

if self.model._owns_approx_jac:
Expand Down
4 changes: 2 additions & 2 deletions openmdao/devtools/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from openmdao.utils.mpi import MPI
from openmdao.utils.name_maps import abs_key2rel_key, rel_key2abs_key
from openmdao.utils.general_utils import simple_warning
from openmdao.core.constants import POST_FINAL_SETUP
from openmdao.core.constants import _SetupStatus

# an object used to detect when a named value isn't found
_notfound = object()
Expand Down Expand Up @@ -235,7 +235,7 @@ def config_summary(problem, stream=sys.stdout):
if s.nonlinear_solver is not None]

max_depth = max([len(name.split('.')) for name in sysnames])
setup_done = model._problem_meta['_setup_status'] == POST_FINAL_SETUP
setup_done = model._problem_meta['_setup_status'] == _SetupStatus.POST_FINAL_SETUP

if problem.comm.size > 1:
local_max = np.array([max_depth])
Expand Down

0 comments on commit 8449a37

Please sign in to comment.