Skip to content

Commit

Permalink
Merge pull request #1598 from DKilkenny/om_slicer
Browse files Browse the repository at this point in the history
Added check to allow connections to flat arrays
  • Loading branch information
swryan committed Aug 27, 2020
2 parents 142a754 + b941f6b commit 42f1d7d
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 15 deletions.
40 changes: 25 additions & 15 deletions openmdao/core/group.py
Expand Up @@ -1477,7 +1477,8 @@ def _setup_connections(self):

elif src_indices is not None:
shape = False
if _is_slicer_op(src_indices):
is_slice = _is_slicer_op(src_indices)
if is_slice:
global_size = self._var_allprocs_abs2meta[abs_out]['global_size']
global_shape = self._var_allprocs_abs2meta[abs_out]['global_shape']
src_indices = _slice_indices(src_indices, global_size, global_shape)
Expand All @@ -1488,22 +1489,31 @@ def _setup_connections(self):
if np.prod(src_indices.shape) == 0:
continue

flat_array_slice_check = not(is_slice and src_indices.size == np.prod(in_shape))

if any('flat_src_indices' in subsys._var_abs2meta[name]
for name in subsys._var_abs2meta):
msg = ("%s: flat_src_indices has no effect when using om_slicer to "
"slice array." % (self.msginfo))
simple_warning(msg)

# initial dimensions of indices shape must be same shape as target
for idx_d, inp_d in zip(src_indices.shape, in_shape):
if idx_d != inp_d:
msg = f"{self.msginfo}: The source indices " + \
f"{src_indices} do not specify a " + \
f"valid shape for the connection '{abs_out}' to " + \
f"'{abs_in}'. The target shape is " + \
f"{in_shape} but indices are {src_indices.shape}."
if self._raise_connection_errors:
raise ValueError(msg)
else:
simple_warning(msg)
continue
if flat_array_slice_check:
for idx_d, inp_d in zip(src_indices.shape, in_shape):
if idx_d != inp_d:
msg = f"{self.msginfo}: The source indices " + \
f"{src_indices} do not specify a " + \
f"valid shape for the connection '{abs_out}' to " + \
f"'{abs_in}'. The target shape is " + \
f"{in_shape} but indices are {src_indices.shape}."
if self._raise_connection_errors:
raise ValueError(msg)
else:
simple_warning(msg)
continue

# any remaining dimension of indices must match shape of source
if len(src_indices.shape) > len(in_shape):
if len(src_indices.shape) > len(in_shape) and flat_array_slice_check:
source_dimensions = src_indices.shape[len(in_shape)]
if source_dimensions != len(out_shape):
str_indices = str(src_indices).replace('\n', '')
Expand Down Expand Up @@ -1549,7 +1559,7 @@ def _setup_connections(self):
else:
abs2meta[abs_in]['src_indices'] = src_indices

if src_indices.shape != in_shape:
if src_indices.shape != in_shape and flat_array_slice_check:
msg = f"{self.msginfo}: src_indices shape " + \
f"{src_indices.shape} does not match {abs_in} shape " + \
f"{in_shape}."
Expand Down
75 changes: 75 additions & 0 deletions openmdao/core/tests/test_group.py
Expand Up @@ -555,6 +555,81 @@ def setup(self):
with assert_warning(UserWarning, msg):
p.setup()

def test_connect_to_flat_array_with_slice(self):
class SlicerComp(om.ExplicitComponent):
def setup(self):
self.add_input('x', np.ones((12,)))
self.add_output('y', 1.0)

def compute(self, inputs, outputs):
outputs['y'] = np.sum(inputs['x']) ** 2.0

p = om.Problem()

p.model.add_subsystem('indep', om.IndepVarComp('x', arr_large_4x4))
p.model.add_subsystem('row123_comp', SlicerComp())

idxs = np.array([0, 2, 3], dtype=int)

p.model.connect('indep.x', 'row123_comp.x', src_indices=om.slicer[idxs, ...])

p.setup()
p.run_model()

assert_near_equal(p['row123_comp.x'], arr_large_4x4[(0, 2, 3), ...].ravel())
assert_near_equal(p['row123_comp.y'], np.sum(arr_large_4x4[(0, 2, 3), ...]) ** 2.0)

def test_connect_to_flat_src_indices_with_slice_user_warning(self):
class SlicerComp(om.ExplicitComponent):
def setup(self):
self.add_input('x', np.ones((12,)))
self.add_output('y', 1.0)

def compute(self, inputs, outputs):
outputs['y'] = np.sum(inputs['x']) ** 2.0

p = om.Problem()

p.model.add_subsystem('indep', om.IndepVarComp('x', arr_large_4x4))
p.model.add_subsystem('row123_comp', SlicerComp())

idxs = np.array([0, 2, 3], dtype=int)

p.model.connect('indep.x', 'row123_comp.x', src_indices=om.slicer[idxs, ...],
flat_src_indices=True)

msg = "Group (<model>): flat_src_indices has no effect when using om_slicer to slice array."
with assert_warning(UserWarning, msg):
p.setup()
p.run_model()

assert_near_equal(p['row123_comp.x'], arr_large_4x4[(0, 2, 3), ...].ravel())
assert_near_equal(p['row123_comp.y'], np.sum(arr_large_4x4[(0, 2, 3), ...]) ** 2.0)

def test_connect_to_flat_array(self):
class SlicerComp(om.ExplicitComponent):
def setup(self):
self.add_input('x', np.ones((4,)))
self.add_output('y', 1.0)

def compute(self, inputs, outputs):
outputs['y'] = np.sum(inputs['x'])

p = om.Problem()

p.model.add_subsystem('indep', om.IndepVarComp('x', val=arr_large_4x4))
p.model.add_subsystem('trace_comp', SlicerComp())

idxs = np.array([0, 5, 10, 15], dtype=int)

p.model.connect('indep.x', 'trace_comp.x', src_indices=idxs, flat_src_indices=True)

p.setup()
p.run_model()

assert_near_equal(p['trace_comp.x'], np.diag(arr_large_4x4))
assert_near_equal(p['trace_comp.y'], np.sum(np.diag(arr_large_4x4)))

def test_om_slice_in_connect(self):

p = om.Problem()
Expand Down

0 comments on commit 42f1d7d

Please sign in to comment.