diff --git a/openmdao/core/group.py b/openmdao/core/group.py index 69309c844d..8096754ccd 100644 --- a/openmdao/core/group.py +++ b/openmdao/core/group.py @@ -25,7 +25,7 @@ _flatten_src_indices from openmdao.utils.general_utils import ContainsAll, all_ancestors, simple_warning, \ common_subpath, conditional_error, _is_slicer_op, _slice_indices -from openmdao.utils.units import is_compatible, unit_conversion, _has_val_mismatch +from openmdao.utils.units import is_compatible, unit_conversion, _has_val_mismatch, _find_unit from openmdao.utils.mpi import MPI, check_mpi_exceptions, multi_proc_exception_check from openmdao.utils.coloring import Coloring, _STD_COLORING_FNAME import openmdao.utils.coloring as coloring_mod @@ -3196,8 +3196,9 @@ def _resolve_ambiguous_input_meta(self): tmeta = abs2meta[tgt] if tgt in abs2meta else all_abs2meta[tgt] tunits = tmeta['units'] if 'units' in tmeta else None if 'units' not in gmeta and sunits != tunits: - errs.add('units') - metadata.add('units') + if _find_unit(sunits) != _find_unit(tunits): + errs.add('units') + metadata.add('units') if 'value' not in gmeta: if tval.shape == sval.shape: if _has_val_mismatch(tunits, tval, sunits, sval): diff --git a/openmdao/core/tests/test_units.py b/openmdao/core/tests/test_units.py index 3462050ac1..854fa598cb 100644 --- a/openmdao/core/tests/test_units.py +++ b/openmdao/core/tests/test_units.py @@ -880,6 +880,44 @@ def setup(self): #self.assertTrue(iter_count < 20) #self.assertTrue(not np.isnan(prob['sub.cc2.y'])) + def test_promotes_equivalent_units(self): + # multiple Group.set_input_defaults calls at same tree level with conflicting units args + p = om.Problem() + + g1 = p.model.add_subsystem("G1", om.Group(), promotes_inputs=['x']) + g1.add_subsystem("C1", om.ExecComp("y = 2. * x * z", + x={'value': 5.0, 'units': 'm/s/s'}, + y={'value': 1.0, 'units': None}, + z={'value': 1.0, 'units': 'W'}), + promotes_inputs=['x', 'z']) + g1.add_subsystem("C2", om.ExecComp("y = 3. * x * z", + x={'value': 5.0, 'units': 'm/s**2'}, + y={'value': 1.0, 'units': None}, + z={'value': 1.0, 'units': 'J/s'}), + promotes_inputs=['x', 'z']) + # converting m/s/s to m/s**2 is allowed + p.setup() + + def test_promotes_non_equivalent_units(self): + # multiple Group.set_input_defaults calls at same tree level with conflicting units args + p = om.Problem() + + g1 = p.model.add_subsystem("G1", om.Group(), promotes_inputs=['x']) + g1.add_subsystem("C1", om.ExecComp("y = 2. * x * z", + x={'value': 5.0, 'units': 'J/s/s'}, + y={'value': 1.0, 'units': None}, + z={'value': 1.0, 'units': 'W'}), + promotes_inputs=['x', 'z']) + g1.add_subsystem("C2", om.ExecComp("y = 3. * x * z", + x={'value': 5.0, 'units': 'm/s**2'}, + y={'value': 1.0, 'units': None}, + z={'value': 1.0, 'units': 'J/s'}), + promotes_inputs=['x', 'z']) + # trying to convert J/s/s to m/s**2 should cause Incompatible units TypeError exception + with self.assertRaises(TypeError) as e: + p.setup() + self.assertEqual(str(e.exception), 'Incompatible units') + if __name__ == "__main__": unittest.main()