Skip to content

Commit

Permalink
Merge pull request #1627 from naylor-b/viewconn_bug
Browse files Browse the repository at this point in the history
fix for problem with view_connections when model has discrete variables
  • Loading branch information
swryan committed Aug 18, 2020
2 parents 51224f0 + 8640799 commit 8e1c8bf
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 14 deletions.
28 changes: 28 additions & 0 deletions openmdao/visualization/connection_viewer/tests/test_viewconns.py
@@ -1,4 +1,5 @@
import unittest
import openmdao.api as om
from openmdao.utils.testing_utils import use_tempdirs

@use_tempdirs
Expand All @@ -18,5 +19,32 @@ def test_feature_sellar(self):

om.view_connections(prob, outfile= "sellar_connections.html", show_browser=False)


class TestComp(om.ExplicitComponent):

def setup(self):
self.add_discrete_input('foo', val='4')
self.add_output('bar', val=0.)

def compute(self, inputs, outputs, discrete_inputs, discrete_outputs):

outputs['bar'] = float(discrete_inputs['foo'])


@use_tempdirs
class TestDiscreteViewConns(unittest.TestCase):
def test_discrete(self):
p = om.Problem()

ivc = p.model.add_subsystem('ivc', om.IndepVarComp(), promotes=['*'])
ivc.add_discrete_output('foo', val='3')

p.model.add_subsystem('test_comp', TestComp(), promotes=['*'])

p.setup()

om.view_connections(p, show_browser=False)


if __name__ == "__main__":
unittest.main()
25 changes: 11 additions & 14 deletions openmdao/visualization/connection_viewer/viewconns.py
Expand Up @@ -63,14 +63,10 @@ def view_connections(root, outfile='connections.html', show_browser=True,
else:
system = root

input_srcs = system._problem_meta['connections']

connections = {
tgt: src for tgt, src in input_srcs.items() if src is not None
}
connections = system._problem_meta['connections']

src2tgts = defaultdict(list)
units = {}
units = defaultdict(lambda: '')
for n, data in system._var_allprocs_abs2meta.items():
u = data.get('units', '')
if u is None:
Expand All @@ -80,10 +76,11 @@ def view_connections(root, outfile='connections.html', show_browser=True,
vals = {}

with printoptions(precision=precision, suppress=True, threshold=10000):

for t in system._var_abs_names['input']:
tmeta = system._var_abs2meta[t]
idxs = tmeta['src_indices']
for t in chain(system._var_abs_names['input'], system._var_abs_names_discrete['input']):
if t in system._var_abs2meta:
idxs = system._var_abs2meta[t]['src_indices']
else:
idxs = None

s = connections[t]
if show_values:
Expand All @@ -109,12 +106,12 @@ def view_connections(root, outfile='connections.html', show_browser=True,

src_systems = set()
tgt_systems = set()
for s in system._var_abs_names['output']:
for s in chain(system._var_abs_names['output'], system._var_abs_names_discrete['output']):
parts = s.split('.')
for i in range(len(parts)):
src_systems.add('.'.join(parts[:i]))

for t in system._var_abs_names['input']:
for t in chain(system._var_abs_names['input'], system._var_abs_names_discrete['input']):
parts = t.split('.')
for i in range(len(parts)):
tgt_systems.add('.'.join(parts[:i]))
Expand Down Expand Up @@ -146,10 +143,10 @@ def view_connections(root, outfile='connections.html', show_browser=True,
idx += 1

# add rows for unconnected sources
for src in system._var_abs_names['output']:
for src in chain(system._var_abs_names['output'], system._var_abs_names_discrete['output']):
if src not in src2tgts:
if show_values:
v = _val2str(system._outputs[src])
v = _val2str(system._abs_get_val(src))
else:
v = ''
row = {'id': idx, 'src': src, 'sprom': sprom[src], 'sunits': units[src],
Expand Down

0 comments on commit 8e1c8bf

Please sign in to comment.