Skip to content

Commit

Permalink
Merge pull request #1693 from naylor-b/refactor
Browse files Browse the repository at this point in the history
Refactor of internal data structures and some cleanups.
  • Loading branch information
swryan committed Sep 21, 2020
2 parents 59521bf + 061e976 commit e18df7c
Show file tree
Hide file tree
Showing 48 changed files with 1,153 additions and 1,138 deletions.
29 changes: 18 additions & 11 deletions openmdao/approximation_schemes/approximation_scheme.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Base class used to define the interface for derivative approximation schemes."""
from collections import defaultdict

from itertools import chain
from scipy.sparse import coo_matrix
import numpy as np

Expand Down Expand Up @@ -146,7 +146,6 @@ def _init_colored_approximations(self, system):

outputs = system._outputs
inputs = system._inputs
abs2meta = system._var_allprocs_abs2meta
prom2abs_out = system._var_allprocs_prom2abs_list['output']
prom2abs_in = system._var_allprocs_prom2abs_list['input']
approx_wrt_idx = system._owns_approx_wrt_idx
Expand All @@ -171,8 +170,8 @@ def _init_colored_approximations(self, system):

if is_total and system.pathname == '': # top level approx totals
of_names = system._owns_approx_of
full_wrts = system._var_allprocs_abs_names['output'] + \
system._var_allprocs_abs_names['input']
full_wrts = list(chain(system._var_allprocs_abs2meta['output'],
system._var_allprocs_abs2meta['input']))
wrt_names = system._owns_approx_wrt
else:
of_names, wrt_names = system._get_partials_varlists()
Expand All @@ -190,7 +189,7 @@ def _init_colored_approximations(self, system):

# FIXME: need to deal with mix of local/remote indices

len_full_ofs = len(system._var_allprocs_abs_names['output'])
len_full_ofs = len(system._var_allprocs_abs2meta['output'])

full_idxs = []
approx_of_idx = system._owns_approx_of_idx
Expand All @@ -211,7 +210,10 @@ def _init_colored_approximations(self, system):

if len(full_wrts) != len(wrt_matches) or approx_wrt_idx:
if is_total and system.pathname == '': # top level approx totals
full_wrt_sizes = [abs2meta[wrt]['size'] for wrt in full_wrts]
a2mi = system._var_allprocs_abs2meta['input']
a2mo = system._var_allprocs_abs2meta['output']
full_wrt_sizes = [a2mi[wrt]['size'] if wrt in a2mi else a2mo[wrt]['size']
for wrt in full_wrts]
else:
_, full_wrt_sizes = system._get_partials_var_sizes()

Expand Down Expand Up @@ -274,7 +276,10 @@ def _init_approximations(self, system):
in_idx += slices[wrt].start
else:
if arr is None:
in_idx = range(abs2meta[wrt]['size'])
if wrt in abs2meta['input']:
in_idx = range(abs2meta['input'][wrt]['size'])
else:
in_idx = range(abs2meta['output'][wrt]['size'])
else:
in_idx = range(slices[wrt].start, slices[wrt].stop)

Expand Down Expand Up @@ -525,7 +530,8 @@ def _get_wrt_subjacs(system, approxs):
each subjac.
"""
abs2idx = system._var_allprocs_abs2idx['nonlinear']
abs2meta = system._var_allprocs_abs2meta
abs2meta_in = system._var_allprocs_abs2meta['input']
abs2meta_out = system._var_allprocs_abs2meta['output']
approx_of_idx = system._owns_approx_of_idx
approx_wrt_idx = system._owns_approx_wrt_idx
approx_of = system._owns_approx_of
Expand All @@ -534,7 +540,7 @@ def _get_wrt_subjacs(system, approxs):
ofdict = {}
nondense = {}
slicedict = system._outputs.get_slice_dict()
abs_out_names = [n for n in system._var_allprocs_abs_names['output'] if n in slicedict]
abs_out_names = [n for n in system._var_allprocs_abs2meta['output'] if n in slicedict]

for key, options in approxs:
of, wrt = key
Expand All @@ -550,7 +556,7 @@ def _get_wrt_subjacs(system, approxs):
out_idx = approx_of_idx[of]
out_size = len(out_idx)
else:
out_size = abs2meta[of]['size']
out_size = abs2meta_out[of]['size']
out_idx = _full_slice
ofdict[of] = (out_size, out_idx)
J[wrt]['tot_rows'] += out_size
Expand All @@ -566,7 +572,8 @@ def _get_wrt_subjacs(system, approxs):
elif wrt_idx is not _full_slice:
J[wrt]['data'] = arr = np.zeros((J[wrt]['tot_rows'], len(wrt_idx)))
else:
J[wrt]['data'] = arr = np.zeros((J[wrt]['tot_rows'], abs2meta[wrt]['size']))
sz = abs2meta_in[wrt]['size'] if wrt in abs2meta_in else abs2meta_out[wrt]['size']
J[wrt]['data'] = arr = np.zeros((J[wrt]['tot_rows'], sz))

# sort ofs into the proper order to match outputs/resids vecs
start = end = 0
Expand Down
13 changes: 4 additions & 9 deletions openmdao/components/exec_comp.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,18 +574,13 @@ def __init__(self, outputs, inputs):
self._inputs = inputs

def __getitem__(self, name):
if name in self._outputs:
return self._outputs[name]
else:
try:
return self._inputs[name]
except KeyError:
return self._outputs[name]

def __setitem__(self, name, value):
if name in self._outputs:
self._outputs[name] = value
elif name in self._inputs:
self._inputs[name] = value
else:
self._outputs[name] = value # will raise KeyError
self._outputs[name] = value

def __contains__(self, name):
return name in self._outputs or name in self._inputs
Expand Down
6 changes: 3 additions & 3 deletions openmdao/components/tests/test_balance_comp.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,9 @@ def test_balance_comp_with_units_kwarg_and_eq_units(self):

prob.run_model()
meta = prob.model._var_abs2meta
self.assertEqual(meta['balance.x']['units'], 'm')
self.assertEqual(meta['balance.rhs:x']['units'], 'm')
self.assertEqual(meta['balance.lhs:x']['units'], 'm')
self.assertEqual(meta['output']['balance.x']['units'], 'm')
self.assertEqual(meta['input']['balance.rhs:x']['units'], 'm')
self.assertEqual(meta['input']['balance.lhs:x']['units'], 'm')

def test_create_on_init(self):

Expand Down
Loading

0 comments on commit e18df7c

Please sign in to comment.