Skip to content

Commit

Permalink
Merge pull request #707 from hschilling/P158975747-scaling-case-reader
Browse files Browse the repository at this point in the history
P158975747 user can get scaled or unscaled values from driver cases using case reader
  • Loading branch information
Keith Marsteller committed Aug 1, 2018
2 parents b2b8275 + 6382fc9 commit 0ae224a
Show file tree
Hide file tree
Showing 11 changed files with 292 additions and 47 deletions.
6 changes: 3 additions & 3 deletions openmdao/core/driver.py
Expand Up @@ -782,17 +782,17 @@ def record_iteration(self):
filt = self._filtered_vars_to_record

if opts['record_desvars']:
des_vars = self.get_design_var_values()
des_vars = self.get_design_var_values(unscaled=True)
else:
des_vars = {}

if opts['record_objectives']:
obj_vars = self.get_objective_values()
obj_vars = self.get_objective_values(unscaled=True)
else:
obj_vars = {}

if opts['record_constraints']:
con_vars = self.get_constraint_values()
con_vars = self.get_constraint_values(unscaled=True)
else:
con_vars = {}

Expand Down
22 changes: 12 additions & 10 deletions openmdao/core/problem.py
Expand Up @@ -1563,18 +1563,20 @@ def load_case(self, case):
A Case from a CaseRecorder file.
"""
inputs = case.inputs._values if case.inputs is not None else None
for name, val in zip(inputs.dtype.names, inputs):
if name not in self.model._var_abs_names['input']:
raise KeyError("Input variable, '{}', recorded in the case is not "
"found in the model".format(name))
self[name] = val
if inputs:
for name, val in zip(inputs.dtype.names, inputs):
if name not in self.model._var_abs_names['input']:
raise KeyError("Input variable, '{}', recorded in the case is not "
"found in the model".format(name))
self[name] = val

outputs = case.outputs._values if case.outputs is not None else None
for name, val in zip(outputs.dtype.names, outputs):
if name not in self.model._var_abs_names['output']:
raise KeyError("Output variable, '{}', recorded in the case is not "
"found in the model".format(name))
self[name] = val
if outputs:
for name, val in zip(outputs.dtype.names, outputs):
if name not in self.model._var_abs_names['output']:
raise KeyError("Output variable, '{}', recorded in the case is not "
"found in the model".format(name))
self[name] = val

return

Expand Down
2 changes: 1 addition & 1 deletion openmdao/devtools/problem_viewer/problem_viewer.py
Expand Up @@ -96,7 +96,7 @@ def _get_viewer_data(problem_or_rootgroup_or_filename):
model_text = cur.fetchone()
from six import PY2, PY3
if row is not None:
if format_version == 3:
if format_version >= 3:
return json.loads(model_text[0])
elif format_version in (1, 2):
if PY2:
Expand Down
23 changes: 22 additions & 1 deletion openmdao/recorders/case.py
@@ -1,6 +1,7 @@
"""
A Case class.
"""
from six import iteritems


class Case(object):
Expand Down Expand Up @@ -161,10 +162,15 @@ def _get_variables_of_type(self, var_type):
class DriverCase(Case):
"""
Wrap data from a single iteration of a Driver recording to make it more easily accessible.
Attributes
----------
_var_settings : dict
Dictionary mapping absolute variable names to variable settings.
"""

def __init__(self, filename, counter, iteration_coordinate, timestamp, success,
msg, inputs, outputs, prom2abs, abs2prom, meta):
msg, inputs, outputs, prom2abs, abs2prom, meta, var_settings):
"""
Initialize.
Expand Down Expand Up @@ -192,10 +198,25 @@ def __init__(self, filename, counter, iteration_coordinate, timestamp, success,
Dictionary mapping absolute names to promoted names.
meta : dict
Dictionary mapping absolute variable names to variable metadata.
var_settings : dict
Dictionary mapping absolute variable names to variable settings.
"""
super(DriverCase, self).__init__(filename, counter, iteration_coordinate,
timestamp, success, msg, prom2abs,
abs2prom, meta, inputs, outputs)
self._var_settings = var_settings

def scale(self):
"""
Scale the outputs array using _var_settings.
"""
for name, val in zip(self.outputs._values.dtype.names, self.outputs._values):
if name in self._var_settings:
# physical to scaled
if self._var_settings[name]['adder'] is not None:
self.outputs._values[name] += self._var_settings[name]['adder']
if self._var_settings[name]['scaler'] is not None:
self.outputs._values[name] *= self._var_settings[name]['scaler']


class SystemCase(Case):
Expand Down
4 changes: 3 additions & 1 deletion openmdao/recorders/cases.py
Expand Up @@ -55,7 +55,7 @@ def __init__(self, filename, format_version, abs2prom, abs2meta, prom2abs):
self._cases = {}

@abstractmethod
def get_case(self, case_id):
def get_case(self, case_id, scaled=False):
"""
Get cases.
Expand All @@ -64,6 +64,8 @@ def get_case(self, case_id):
case_id : str or int
If int, the index of the case to be read in the case iterations.
If given as a string, it is the identifier of the case.
scaled : bool
If True, return the scaled values.
Returns
-------
Expand Down
98 changes: 76 additions & 22 deletions openmdao/recorders/sqlite_reader.py
Expand Up @@ -3,6 +3,7 @@
"""
from __future__ import print_function, absolute_import

from copy import deepcopy
import os
import re
import sys
Expand Down Expand Up @@ -56,6 +57,8 @@ class SqliteCaseReader(BaseCaseReader):
Dictionary mapping promoted names to absolute names.
_coordinate_split_re : RegularExpression
Regular expression used for splitting iteration coordinates.
_var_settings : dict
Dictionary mapping absolute variable names to variable settings.
"""

def __init__(self, filename):
Expand All @@ -81,14 +84,28 @@ def __init__(self, filename):

with sqlite3.connect(self.filename) as con:
cur = con.cursor()
cur.execute("SELECT format_version, abs2prom, prom2abs, abs2meta FROM metadata")

# need to see what columns are in the metadata table before we query it
cursor = cur.execute('select * from metadata')
names = [description[0] for description in cursor.description]
if "var_settings" in names:
cur.execute(
"SELECT format_version, abs2prom, prom2abs, abs2meta, var_settings "
"FROM metadata")
else:
cur.execute(
"SELECT format_version, abs2prom, prom2abs, abs2meta FROM metadata")
row = cur.fetchone()
self.format_version = row[0]
self._abs2prom = None
self._prom2abs = None
self._abs2meta = None
self._var_settings = None

if self.format_version >= 4:
self._var_settings = json.loads(row[4])

if self.format_version == 3:
if self.format_version >= 3:
self._abs2prom = json.loads(row[1])
self._prom2abs = json.loads(row[2])
self._abs2meta = json.loads(row[3])
Expand Down Expand Up @@ -141,7 +158,7 @@ def _load(self):
the individual cases/iterations from the recorded file.
"""
self.driver_cases = DriverCases(self.filename, self.format_version, self._abs2prom,
self._abs2meta, self._prom2abs)
self._abs2meta, self._prom2abs, self._var_settings)
self.driver_derivative_cases = DriverDerivativeCases(self.filename, self.format_version,
self._abs2prom, self._abs2meta,
self._prom2abs)
Expand Down Expand Up @@ -201,7 +218,7 @@ def _load(self):
cur.execute("SELECT model_viewer_data FROM driver_metadata")
row = cur.fetchone()
if row is not None:
if self.format_version == 3:
if self.format_version >= 3:
self.driver_metadata = json.loads(row[0])
elif self.format_version in (1, 2):
if PY2:
Expand Down Expand Up @@ -781,8 +798,35 @@ def _write_outputs(self, in_or_out, comp_type, outputs, hierarchical, print_arra
class DriverCases(BaseCases):
"""
Case specific to the entries that might be recorded in a Driver iteration.
Attributes
----------
_var_settings : dict
Dictionary mapping absolute variable names to variable settings.
"""

def __init__(self, filename, format_version, abs2prom, abs2meta, prom2abs, var_settings):
"""
Initialize.
Parameters
----------
filename : str
The name of the recording file from which to instantiate the case reader.
format_version : int
The version of the format assumed when loading the file.
abs2prom : {'input': dict, 'output': dict}
Dictionary mapping absolute names to promoted names.
abs2meta : dict
Dictionary mapping absolute variable names to variable metadata.
prom2abs : {'input': dict, 'output': dict}
Dictionary mapping promoted names to absolute names.
var_settings : dict
Dictionary mapping absolute variable names to variable settings.
"""
super(DriverCases, self).__init__(filename, format_version, abs2prom, abs2meta, prom2abs)
self._var_settings = var_settings

def _extract_case_from_row(self, row):
"""
Pull data out of a queried SQLite row.
Expand All @@ -800,7 +844,7 @@ def _extract_case_from_row(self, row):
idx, counter, iteration_coordinate, timestamp, success, msg, inputs_text, \
outputs_text, = row

if self.format_version == 3:
if self.format_version >= 3:
inputs_array = json_to_np_array(inputs_text)
outputs_array = json_to_np_array(outputs_text)
elif self.format_version in (1, 2):
Expand All @@ -809,7 +853,7 @@ def _extract_case_from_row(self, row):

case = DriverCase(self.filename, counter, iteration_coordinate, timestamp,
success, msg, inputs_array, outputs_array,
self._prom2abs, self._abs2prom, self._abs2meta)
self._prom2abs, self._abs2prom, self._abs2meta, self._var_settings)
return case

def load_cases(self):
Expand All @@ -824,14 +868,16 @@ def load_cases(self):
case = self._extract_case_from_row(row)
self._cases[case.iteration_coordinate] = case

def get_case(self, case_id):
def get_case(self, case_id, scaled=False):
"""
Get a case from the database.
Parameters
----------
case_id : int or str
The integer index or string-identifier of the case to be retrieved.
scaled : bool
If True, return variables scaled. Otherwise, return physical values.
Returns
-------
Expand All @@ -841,21 +887,29 @@ def get_case(self, case_id):
# check to see if we've already cached this case
iteration_coordinate = self.get_iteration_coordinate(case_id)
if iteration_coordinate in self._cases:
return self._cases[iteration_coordinate]
case = self._cases[iteration_coordinate]
else:
# Get an unscaled case if does not already exist in _cases
with sqlite3.connect(self.filename) as con:
cur = con.cursor()
cur.execute("SELECT * FROM driver_iterations WHERE "
"iteration_coordinate=:iteration_coordinate",
{"iteration_coordinate": iteration_coordinate})
# Initialize the Case object from the iterations data
row = cur.fetchone()
con.close()

with sqlite3.connect(self.filename) as con:
cur = con.cursor()
cur.execute("SELECT * FROM driver_iterations WHERE "
"iteration_coordinate=:iteration_coordinate",
{"iteration_coordinate": iteration_coordinate})
# Initialize the Case object from the iterations data
row = cur.fetchone()
con.close()
case = self._extract_case_from_row(row)

case = self._extract_case_from_row(row)
# save so we don't query again
self._cases[case.iteration_coordinate] = case

if scaled:
# We have to do some scaling first before we return it
# Need to make a copy, otherwise we modify the object in the cache
case = deepcopy(case)
case.scale()

# save so we don't query again
self._cases[case.iteration_coordinate] = case
return case


Expand Down Expand Up @@ -956,7 +1010,7 @@ def _extract_case_from_row(self, row):
idx, counter, case_name, timestamp, success, msg, \
outputs_text, = row

if self.format_version == 3:
if self.format_version >= 3:
outputs_array = json_to_np_array(outputs_text)
elif self.format_version in (1, 2):
outputs_array = blob_to_array(outputs_text)
Expand Down Expand Up @@ -1034,7 +1088,7 @@ def _extract_case_from_row(self, row):
idx, counter, iteration_coordinate, timestamp, success, msg, inputs_text,\
outputs_text, residuals_text = row

if self.format_version == 3:
if self.format_version >= 3:
inputs_array = json_to_np_array(inputs_text)
outputs_array = json_to_np_array(outputs_text)
residuals_array = json_to_np_array(residuals_text)
Expand Down Expand Up @@ -1118,7 +1172,7 @@ def _extract_case_from_row(self, row):
idx, counter, iteration_coordinate, timestamp, success, msg, abs_err, rel_err, \
input_text, output_text, residuals_text = row

if self.format_version == 3:
if self.format_version >= 3:
input_array = json_to_np_array(input_text)
output_array = json_to_np_array(output_text)
residuals_array = json_to_np_array(residuals_text)
Expand Down

0 comments on commit 0ae224a

Please sign in to comment.