-
Notifications
You must be signed in to change notification settings - Fork 240
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Case Recording Alpha #296
Case Recording Alpha #296
Changes from all commits
d588b95
6cafc36
43c3a38
7ce1fc5
1fba0fb
62e4a80
82f2da3
85b4a8a
6cf5c1d
19ef82c
0177cf8
b490841
f38b6ac
dc4cdd1
6508185
e3efe17
41b1409
83e7ef4
61d3cee
6fee221
e0fb8bd
710c0ad
f1f6866
d78aff3
b0e84d2
b95d268
2623db2
025de39
671df6a
e470947
6733675
9f95d10
4df4188
bec4575
11378d2
ce77ce6
975eda2
744b53c
00b86ca
1a005d0
4e034d7
b27d423
a5454cb
49f7326
11b6016
795e041
b2ca65e
d899f3c
14441f8
66817ea
50a19ec
1a76ddb
2d240d2
05b12e0
7a4b644
ed7feef
6e03358
2d50d73
27074ba
09fe794
00ae93a
106805d
2d2f7a9
7a08369
2a79c13
45e209b
5c2cd4e
92be28a
b458018
e18484d
164659e
f60b9e4
d237058
60fbbbc
e243cf2
1cd23c9
ec4e05a
4218ec7
660a64a
f7aa8d8
0a8dc07
3316d67
967d1a5
aece01a
9431dff
6c30250
2fe423d
cb58ae0
5987a59
093c913
3bf5247
e490bb6
706b996
788eff1
ff81a0c
6f5fb9c
a9a27fc
09f28de
afaec88
f663696
13bd26b
d439789
f4ce837
700a243
972ae35
0834eb4
7e57ca5
0c59aa1
a55324c
820f64d
234bbc4
3d8031e
ef2dead
f20d47a
d6527c2
f27535b
f4b5e66
4056313
0d24da2
7e7016b
0c98da5
9d18fda
46ee837
4a6c9cb
bc6511e
08084b1
71218ad
ea596f1
06ffa22
b38d55c
109357b
f2f5714
da1825f
c401f0b
336b57d
e0cbc45
06550b6
084cdf2
87b8a3b
7f3817b
237b996
9acc28d
9c2aef2
214874a
eaa5d78
2cae507
c3b12ac
8c767b0
13ada99
892790e
4d251e8
9606afb
b9244fc
e5f51f1
1134b35
93bc36e
c48ae20
24086ad
85676f0
5d2855a
facaff9
07a8b5c
fa66ad5
9646b0e
eab1c50
ea82069
29da032
0796f3a
797e798
6b2c3b5
54d525f
a7d4c49
4402b35
7ca9c49
5c6d60a
540ea8d
3c9dde8
5d6c0a0
60c0672
99f3b61
9418f83
9f3f3a7
2b47713
45a0c96
b1f43b5
5a2d2ca
bca720e
13a5042
f8a276a
6bfd5d5
d382c5e
c1dc504
0424f2f
f6b0595
1be7912
1d70138
8430b3a
1845dfc
70c03ea
d6590f3
4911653
56ba386
38786d6
281dc47
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,12 @@ | ||
"""Define a base class for all Drivers in OpenMDAO.""" | ||
|
||
from six import iteritems | ||
|
||
import numpy as np | ||
|
||
from openmdao.utils.record_util import create_local_meta | ||
from openmdao.utils.options_dictionary import OptionsDictionary | ||
from openmdao.recorders.recording_manager import RecordingManager | ||
from openmdao.recorders.recording_iteration_stack import Recording | ||
|
||
|
||
class Driver(object): | ||
|
@@ -15,6 +17,10 @@ class Driver(object): | |
---------- | ||
fail : bool | ||
Reports whether the driver ran successfully. | ||
iter_count : int | ||
Keep track of iterations for case recording. | ||
metadata : list | ||
List of metadata | ||
options : <OptionsDictionary> | ||
Dictionary with general pyoptsparse options. | ||
_problem : <Problem> | ||
|
@@ -29,12 +35,18 @@ class Driver(object): | |
Contains all objective info. | ||
_responses : dict | ||
Contains all response info. | ||
_rec_mgr : <RecordingManager> | ||
Object that manages all recorders added to this driver. | ||
_model_viewer_data : dict | ||
Structure of model, used to make n2 diagram. | ||
""" | ||
|
||
def __init__(self): | ||
""" | ||
Initialize the driver. | ||
""" | ||
self._rec_mgr = RecordingManager() | ||
|
||
self._problem = None | ||
self._designvars = None | ||
self._cons = None | ||
|
@@ -53,11 +65,32 @@ def __init__(self): | |
self.supports.declare('gradients', type_=bool, default=False) | ||
self.supports.declare('active_set', type_=bool, default=False) | ||
|
||
self.iter_count = 0 | ||
self.metadata = None | ||
self._model_viewer_data = None | ||
|
||
# TODO, support these in Openmdao blue | ||
self.supports.declare('integer_design_vars', type_=bool, default=False) | ||
|
||
self.fail = False | ||
|
||
def add_recorder(self, recorder): | ||
""" | ||
Add a recorder to the driver. | ||
|
||
Parameters | ||
---------- | ||
recorder : BaseRecorder | ||
A recorder instance. | ||
""" | ||
self._rec_mgr.append(recorder) | ||
|
||
def cleanup(self): | ||
""" | ||
Clean up resources prior to exit. | ||
""" | ||
self._rec_mgr.close() | ||
|
||
def _setup_driver(self, problem): | ||
""" | ||
Prepare the driver for execution. | ||
|
@@ -79,20 +112,42 @@ def _setup_driver(self, problem): | |
self._objs = model.get_objectives(recurse=True) | ||
self._cons = model.get_constraints(recurse=True) | ||
|
||
def get_design_var_values(self): | ||
self._rec_mgr.startup(self) | ||
if (self._rec_mgr._recorders): | ||
from openmdao.devtools.problem_viewer.problem_viewer import _get_viewer_data | ||
self._model_viewer_data = _get_viewer_data(problem) | ||
self._rec_mgr.record_metadata(self) | ||
|
||
def get_design_var_values(self, filter=None): | ||
""" | ||
Return the design variable values. | ||
|
||
This is called to gather the initial design variable state. | ||
|
||
Parameters | ||
---------- | ||
filter : list | ||
List of desvar names used by recorders. | ||
|
||
Returns | ||
------- | ||
dict | ||
Dictionary containing values of each design variable. | ||
""" | ||
designvars = {} | ||
|
||
if filter: | ||
# pull out designvars of those names into filtered dict. | ||
for inc in filter: | ||
designvars[inc] = self._designvars[inc] | ||
|
||
else: | ||
# use all the designvars | ||
designvars = self._designvars | ||
|
||
vec = self._problem.model._outputs._views_flat | ||
dv_dict = {} | ||
for name, meta in iteritems(self._designvars): | ||
for name, meta in iteritems(designvars): | ||
scaler = meta['scaler'] | ||
adder = meta['adder'] | ||
indices = meta['indices'] | ||
|
@@ -138,10 +193,15 @@ def set_design_var(self, name, value): | |
if adder is not None: | ||
desvar[indices] -= adder | ||
|
||
def get_response_values(self): | ||
def get_response_values(self, filter=None): | ||
""" | ||
Return response values. | ||
|
||
Parameters | ||
---------- | ||
filter : list | ||
List of response names used by recorders. | ||
|
||
Returns | ||
------- | ||
dict | ||
|
@@ -150,18 +210,34 @@ def get_response_values(self): | |
# TODO: finish this method when we have a driver that requires it. | ||
pass | ||
|
||
def get_objective_values(self): | ||
def get_objective_values(self, filter=None): | ||
""" | ||
Return objective values. | ||
|
||
Parameters | ||
---------- | ||
filter : list | ||
List of objective names used by recorders. | ||
|
||
Returns | ||
------- | ||
dict | ||
Dictionary containing values of each objective. | ||
""" | ||
objectives = {} | ||
|
||
if filter: | ||
# pull out objectives of those names into filtered dict. | ||
for inc in filter: | ||
objectives[inc] = self._objs[inc] | ||
|
||
else: | ||
# use all the objectives | ||
objectives = self._objs | ||
|
||
vec = self._problem.model._outputs._views_flat | ||
obj_dict = {} | ||
for name, meta in iteritems(self._objs): | ||
for name, meta in iteritems(objectives): | ||
scaler = meta['scaler'] | ||
adder = meta['adder'] | ||
indices = meta['indices'] | ||
|
@@ -180,7 +256,7 @@ def get_objective_values(self): | |
|
||
return obj_dict | ||
|
||
def get_constraint_values(self, ctype='all', lintype='all'): | ||
def get_constraint_values(self, ctype='all', lintype='all', filter=None): | ||
""" | ||
Return constraint values. | ||
|
||
|
@@ -194,15 +270,29 @@ def get_constraint_values(self, ctype='all', lintype='all'): | |
Default is 'all'. Optionally return just the linear constraints | ||
with 'linear' or the nonlinear constraints with 'nonlinear'. | ||
|
||
filter : list | ||
List of objective names used by recorders. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For this method in the doctstring, objective --> constraint There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ARGGGH, stupid copy/paste errors! Fixed locally. |
||
|
||
Returns | ||
------- | ||
dict | ||
Dictionary containing values of each constraint. | ||
""" | ||
constraints = {} | ||
|
||
if filter is not None: | ||
# pull out objectives of those names into filtered dict. | ||
for inc in filter: | ||
constraints[inc] = self._cons[inc] | ||
|
||
else: | ||
# use all the objectives | ||
constraints = self._cons | ||
|
||
vec = self._problem.model._outputs._views_flat | ||
con_dict = {} | ||
|
||
for name, meta in iteritems(self._cons): | ||
for name, meta in iteritems(constraints): | ||
|
||
if lintype == 'linear' and meta['linear'] is False: | ||
continue | ||
|
@@ -234,7 +324,6 @@ def get_constraint_values(self, ctype='all', lintype='all'): | |
# TODO: Need to get the allgathered values? Like: | ||
# cons[name] = self._get_distrib_var(name, meta, 'constraint') | ||
con_dict[name] = val | ||
|
||
return con_dict | ||
|
||
def run(self): | ||
|
@@ -249,7 +338,11 @@ def run(self): | |
boolean | ||
Failure flag; True if failed to converge, False is successful. | ||
""" | ||
return self._problem.model._solve_nonlinear() | ||
with Recording(self._get_name(), self.iter_count, self) as rec: | ||
failure_flag = self._problem.model._solve_nonlinear() | ||
|
||
self.iter_count += 1 | ||
return failure_flag | ||
|
||
def _compute_total_derivs(self, of=None, wrt=None, return_format='flat_dict', | ||
global_names=True): | ||
|
@@ -376,3 +469,21 @@ def get_req_procs(self, model): | |
max_procs can be None, indicating all available procs can be used. | ||
""" | ||
return model.get_req_procs() | ||
|
||
def record_iteration(self): | ||
""" | ||
Record an iteration of the current Driver. | ||
""" | ||
metadata = create_local_meta(self._get_name()) | ||
self._rec_mgr.record_iteration(self, metadata) | ||
|
||
def _get_name(self): | ||
""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why a method instead of an attribute? I guess it's no big deal though. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can switch if you'd prefer. |
||
Get name of current Driver. | ||
|
||
Returns | ||
------- | ||
str | ||
Name of current Driver. | ||
""" | ||
return "Driver" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So, the driver recorder uses this method? I guess that is okay, as long as the user is aware that this is the driver-scaled value. Still, that is probalby appropriate for recording on the driver. (same for cons and objs)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, before we record, we're allowing the users to filter design vars (cons and objs too), using include/excludes. So before we record, we grab the values of only the ones that they want. Do you feel like we need to do something with the scaling here before we record? (We record scaling factors in driver metadata, for what that's worth.)