Skip to content

Commit

Permalink
added test
Browse files Browse the repository at this point in the history
  • Loading branch information
naylor-b committed Jan 12, 2024
1 parent 41e1d16 commit 5f8b9cd
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 93 deletions.
2 changes: 0 additions & 2 deletions openmdao/core/group.py
Expand Up @@ -5110,8 +5110,6 @@ def _setup_iteration_lists(self):
grad_groups = self.comm.allgather(grad_groups)
grad_groups = {g for grp in grad_groups for g in grp}

print(f"GRAD GROUPS:", sorted(grad_groups))

if grad_groups:
remaining = set(grad_groups)
for name in sorted(grad_groups, key=lambda x: x.count('.')):
Expand Down
211 changes: 120 additions & 91 deletions openmdao/core/tests/test_pre_post_iter.py
Expand Up @@ -6,6 +6,12 @@
from openmdao.test_suite.components.exec_comp_for_test import ExecComp4Test
from openmdao.test_suite.components.sellar import SellarDis1withDerivatives, SellarDis2withDerivatives
from openmdao.utils.testing_utils import use_tempdirs
from openmdao.utils.mpi import MPI

try:
from openmdao.parallel_api import PETScVector
except ImportError:
PETScVector = None


class MissingPartialsComp(om.ExplicitComponent):
Expand All @@ -28,97 +34,109 @@ def compute_partials(self, inputs, partials, discrete_inputs=None):
partials['z', 'b'] = 3.0


@use_tempdirs
class TestPrePostIter(unittest.TestCase):
def setup_problem(do_pre_post_opt, mode, use_ivc=False, coloring=False, size=3, group=False,
force=(), approx=False, force_complex=False, recording=False, parallel=False):
prob = om.Problem()
prob.options['group_by_pre_opt_post'] = do_pre_post_opt

def setup_problem(self, do_pre_post_opt, mode, use_ivc=False, coloring=False, size=3, group=False,
force=(), approx=False, force_complex=False, recording=False):
prob = om.Problem()
prob.options['group_by_pre_opt_post'] = do_pre_post_opt
prob.driver = om.ScipyOptimizeDriver(optimizer='SLSQP', disp=False)
prob.set_solver_print(level=0)

prob.driver = om.ScipyOptimizeDriver(optimizer='SLSQP', disp=False)
prob.set_solver_print(level=0)
model = prob.model

model = prob.model
if approx:
model.approx_totals()

if approx:
model.approx_totals()

if use_ivc:
model.add_subsystem('ivc', om.IndepVarComp('x', np.ones(size)))
if use_ivc:
model.add_subsystem('ivc', om.IndepVarComp('x', np.ones(size)))

if group:
G1 = model.add_subsystem('G1', om.Group(), promotes=['*'])
G2 = model.add_subsystem('G2', om.Group(), promotes=['*'])
if parallel:
par = model.add_subsystem('par', om.ParallelGroup(), promotes=['*'])
par.nonlinear_solver = om.NonlinearBlockJac()
parent = par
else:
parent = model

if group:
G1 = parent.add_subsystem('G1', om.Group(), promotes=['*'])
G2 = parent.add_subsystem('G2', om.Group(), promotes=['*'])
if parallel:
G2.nonlinear_solver = om.NewtonSolver(solve_subsystems=False)
G2.linear_solver = om.DirectSolver(assemble_jac=True)
# this guy wouldn't be included in the iteration loop were it not under Newton
G1.add_subsystem('sub_post_comp', ExecComp4Test('y=.5*x', x=np.ones(size), y=np.zeros(size)))
else:
G1 = parent
G2 = parent

comps = {
'pre1': G1.add_subsystem('pre1', ExecComp4Test('y=2.*x', x=np.ones(size), y=np.zeros(size))),
'pre2': G1.add_subsystem('pre2', ExecComp4Test('y=3.*x - 7.*xx', x=np.ones(size), xx=np.ones(size), y=np.zeros(size))),

'iter1': G1.add_subsystem('iter1', ExecComp4Test('y=x1 + x2*4. + x3',
x1=np.ones(size), x2=np.ones(size),
x3=np.ones(size), y=np.zeros(size))),
'iter2': G1.add_subsystem('iter2', ExecComp4Test('y=.5*x', x=np.ones(size), y=np.zeros(size))),
'iter4': G2.add_subsystem('iter4', ExecComp4Test('y=7.*x', x=np.ones(size), y=np.zeros(size))),
'iter3': G2.add_subsystem('iter3', ExecComp4Test('y=6.*x', x=np.ones(size), y=np.zeros(size))),

'post1': G2.add_subsystem('post1', ExecComp4Test('y=8.*x', x=np.ones(size), y=np.zeros(size))),
'post2': G2.add_subsystem('post2', ExecComp4Test('y=x1*9. + x2*5. + x3*3.', x1=np.ones(size),
x2=np.ones(size), x3=np.zeros(size),
y=np.zeros(size))),
}

for name in force:
if name in comps:
comps[name].options['always_opt'] = True
else:
G1 = model
G2 = model
raise RuntimeError(f'"{name}" not in comps')

comps = {
'pre1': G1.add_subsystem('pre1', ExecComp4Test('y=2.*x', x=np.ones(size), y=np.zeros(size))),
'pre2': G1.add_subsystem('pre2', ExecComp4Test('y=3.*x - 7.*xx', x=np.ones(size), xx=np.ones(size), y=np.zeros(size))),
if use_ivc:
model.connect('ivc.x', 'iter1.x3')

'iter1': G1.add_subsystem('iter1', ExecComp4Test('y=x1 + x2*4. + x3',
x1=np.ones(size), x2=np.ones(size),
x3=np.ones(size), y=np.zeros(size))),
'iter2': G1.add_subsystem('iter2', ExecComp4Test('y=.5*x', x=np.ones(size), y=np.zeros(size))),
'iter4': G2.add_subsystem('iter4', ExecComp4Test('y=7.*x', x=np.ones(size), y=np.zeros(size))),
'iter3': G2.add_subsystem('iter3', ExecComp4Test('y=6.*x', x=np.ones(size), y=np.zeros(size))),
model.connect('pre1.y', ['iter1.x1', 'post2.x1', 'pre2.xx'])
model.connect('pre2.y', 'iter1.x2')
model.connect('iter1.y', ['iter2.x', 'iter4.x'])
model.connect('iter2.y', 'post2.x2')
model.connect('iter3.y', 'post1.x')
model.connect('iter4.y', 'iter3.x')
model.connect('post1.y', 'post2.x3')

'post1': G2.add_subsystem('post1', ExecComp4Test('y=8.*x', x=np.ones(size), y=np.zeros(size))),
'post2': G2.add_subsystem('post2', ExecComp4Test('y=x1*9. + x2*5. + x3*3.', x1=np.ones(size),
x2=np.ones(size), x3=np.zeros(size),
y=np.zeros(size))),
}
prob.model.add_design_var('iter1.x3', lower=-10, upper=10)
prob.model.add_constraint('iter2.y', upper=10.)
prob.model.add_objective('iter3.y', index=0)

for name in force:
if name in comps:
comps[name].options['always_opt'] = True
else:
raise RuntimeError(f'"{name}" not in comps')
if coloring:
prob.driver.declare_coloring()

if use_ivc:
model.connect('ivc.x', 'iter1.x3')
if recording:
model.recording_options['record_inputs'] = True
model.recording_options['record_outputs'] = True
model.recording_options['record_residuals'] = True

model.connect('pre1.y', ['iter1.x1', 'post2.x1', 'pre2.xx'])
model.connect('pre2.y', 'iter1.x2')
model.connect('iter1.y', ['iter2.x', 'iter4.x'])
model.connect('iter2.y', 'post2.x2')
model.connect('iter3.y', 'post1.x')
model.connect('iter4.y', 'iter3.x')
model.connect('post1.y', 'post2.x3')
recorder = om.SqliteRecorder("sqlite_test_pre_post", record_viewer_data=False)

prob.model.add_design_var('iter1.x3', lower=-10, upper=10)
prob.model.add_constraint('iter2.y', upper=10.)
prob.model.add_objective('iter3.y', index=0)
model.add_recorder(recorder)
prob.driver.add_recorder(recorder)

if coloring:
prob.driver.declare_coloring()
for comp in comps.values():
comp.add_recorder(recorder)

if recording:
model.recording_options['record_inputs'] = True
model.recording_options['record_outputs'] = True
model.recording_options['record_residuals'] = True
prob.setup(mode=mode, force_alloc_complex=force_complex)

recorder = om.SqliteRecorder("sqlite_test_pre_post", record_viewer_data=False)
# we don't want ExecComps to be colored because it makes the iter counting more complicated.
for comp in model.system_iter(recurse=True, typ=ExecComp4Test):
comp.options['do_coloring'] = False
comp.options['has_diag_partials'] = True

model.add_recorder(recorder)
prob.driver.add_recorder(recorder)

for comp in comps.values():
comp.add_recorder(recorder)

prob.setup(mode=mode, force_alloc_complex=force_complex)

# we don't want ExecComps to be colored because it makes the iter counting more complicated.
for comp in model.system_iter(recurse=True, typ=ExecComp4Test):
comp.options['do_coloring'] = False
comp.options['has_diag_partials'] = True
return prob

return prob
@use_tempdirs
class TestPrePostIter(unittest.TestCase):

def test_pre_post_iter_rev(self):
prob = self.setup_problem(do_pre_post_opt=True, mode='rev')
prob = setup_problem(do_pre_post_opt=True, mode='rev')
prob.run_driver()

self.assertEqual(prob.model.pre1.num_nl_solves, 1)
Expand All @@ -136,7 +154,7 @@ def test_pre_post_iter_rev(self):
assert_check_totals(data)

def test_pre_post_iter_rev_grouped(self):
prob = self.setup_problem(do_pre_post_opt=True, group=True, mode='rev')
prob = setup_problem(do_pre_post_opt=True, group=True, mode='rev')
prob.run_driver()

self.assertEqual(prob.model.G1.pre1.num_nl_solves, 1)
Expand All @@ -154,7 +172,7 @@ def test_pre_post_iter_rev_grouped(self):
assert_check_totals(data)

def test_pre_post_iter_rev_coloring(self):
prob = self.setup_problem(do_pre_post_opt=True, coloring=True, mode='rev')
prob = setup_problem(do_pre_post_opt=True, coloring=True, mode='rev')
prob.run_driver()

self.assertEqual(prob.model.pre1.num_nl_solves, 1)
Expand All @@ -172,7 +190,7 @@ def test_pre_post_iter_rev_coloring(self):
assert_check_totals(data)

def test_pre_post_iter_rev_coloring_grouped(self):
prob = self.setup_problem(do_pre_post_opt=True, coloring=True, group=True, mode='rev')
prob = setup_problem(do_pre_post_opt=True, coloring=True, group=True, mode='rev')
prob.run_driver()

self.assertEqual(prob.model.G1.pre1.num_nl_solves, 1)
Expand All @@ -190,7 +208,7 @@ def test_pre_post_iter_rev_coloring_grouped(self):
assert_check_totals(data)

def test_pre_post_iter_rev_ivc(self):
prob = self.setup_problem(do_pre_post_opt=True, use_ivc=True, mode='rev')
prob = setup_problem(do_pre_post_opt=True, use_ivc=True, mode='rev')
prob.run_driver()

self.assertEqual(prob.model.pre1.num_nl_solves, 1)
Expand All @@ -208,7 +226,7 @@ def test_pre_post_iter_rev_ivc(self):
assert_check_totals(data)

def test_pre_post_iter_rev_ivc_grouped(self):
prob = self.setup_problem(do_pre_post_opt=True, use_ivc=True, group=True, mode='rev')
prob = setup_problem(do_pre_post_opt=True, use_ivc=True, group=True, mode='rev')
prob.run_driver()

self.assertEqual(prob.model.G1.pre1.num_nl_solves, 1)
Expand All @@ -226,7 +244,7 @@ def test_pre_post_iter_rev_ivc_grouped(self):
assert_check_totals(data)

def test_pre_post_iter_rev_ivc_coloring(self):
prob = self.setup_problem(do_pre_post_opt=True, use_ivc=True, coloring=True, mode='rev')
prob = setup_problem(do_pre_post_opt=True, use_ivc=True, coloring=True, mode='rev')
prob.run_driver()

self.assertEqual(prob.model.pre1.num_nl_solves, 1)
Expand All @@ -244,7 +262,7 @@ def test_pre_post_iter_rev_ivc_coloring(self):
assert_check_totals(data)

def test_pre_post_iter_fwd(self):
prob = self.setup_problem(do_pre_post_opt=True, mode='fwd')
prob = setup_problem(do_pre_post_opt=True, mode='fwd')
prob.run_driver()

self.assertEqual(prob.model.pre1.num_nl_solves, 1)
Expand All @@ -262,7 +280,7 @@ def test_pre_post_iter_fwd(self):
assert_check_totals(data)

def test_pre_post_iter_fwd_grouped(self):
prob = self.setup_problem(do_pre_post_opt=True, group=True, mode='fwd')
prob = setup_problem(do_pre_post_opt=True, group=True, mode='fwd')
prob.run_driver()

self.assertEqual(prob.model.G1.pre1.num_nl_solves, 1)
Expand All @@ -280,7 +298,7 @@ def test_pre_post_iter_fwd_grouped(self):
assert_check_totals(data)

def test_pre_post_iter_fwd_coloring(self):
prob = self.setup_problem(do_pre_post_opt=True, coloring=True, mode='fwd')
prob = setup_problem(do_pre_post_opt=True, coloring=True, mode='fwd')
prob.run_driver()

self.assertEqual(prob.model.pre1.num_nl_solves, 1)
Expand All @@ -298,7 +316,7 @@ def test_pre_post_iter_fwd_coloring(self):
assert_check_totals(data)

def test_pre_post_iter_fwd_coloring_grouped(self):
prob = self.setup_problem(do_pre_post_opt=True, coloring=True, group=True, mode='fwd')
prob = setup_problem(do_pre_post_opt=True, coloring=True, group=True, mode='fwd')
prob.run_driver()

self.assertEqual(prob.model.G1.pre1.num_nl_solves, 1)
Expand All @@ -316,7 +334,7 @@ def test_pre_post_iter_fwd_coloring_grouped(self):
assert_check_totals(data)

def test_pre_post_iter_fwd_coloring_grouped_force_post(self):
prob = self.setup_problem(do_pre_post_opt=True, coloring=True, group=True,
prob = setup_problem(do_pre_post_opt=True, coloring=True, group=True,
force=['post2'], mode='fwd')
prob.run_driver()

Expand All @@ -338,7 +356,7 @@ def test_pre_post_iter_fwd_coloring_grouped_force_post(self):
assert_check_totals(data)

def test_pre_post_iter_fwd_coloring_grouped_force_pre(self):
prob = self.setup_problem(do_pre_post_opt=True, coloring=True, group=True,
prob = setup_problem(do_pre_post_opt=True, coloring=True, group=True,
force=['pre1'], mode='fwd')
prob.run_driver()

Expand All @@ -360,7 +378,7 @@ def test_pre_post_iter_fwd_coloring_grouped_force_pre(self):
assert_check_totals(data)

def test_pre_post_iter_fwd_ivc(self):
prob = self.setup_problem(do_pre_post_opt=True, use_ivc=True, mode='fwd')
prob = setup_problem(do_pre_post_opt=True, use_ivc=True, mode='fwd')
prob.run_driver()

self.assertEqual(prob.model.pre1.num_nl_solves, 1)
Expand All @@ -378,7 +396,7 @@ def test_pre_post_iter_fwd_ivc(self):
assert_check_totals(data)

def test_pre_post_iter_fwd_ivc_coloring(self):
prob = self.setup_problem(do_pre_post_opt=True, use_ivc=True, coloring=True, mode='fwd')
prob = setup_problem(do_pre_post_opt=True, use_ivc=True, coloring=True, mode='fwd')
prob.run_driver()

self.assertEqual(prob.model.pre1.num_nl_solves, 1)
Expand All @@ -396,7 +414,7 @@ def test_pre_post_iter_fwd_ivc_coloring(self):
assert_check_totals(data)

def test_pre_post_iter_approx(self):
prob = self.setup_problem(do_pre_post_opt=True, mode='fwd', approx=True, force_complex=True)
prob = setup_problem(do_pre_post_opt=True, mode='fwd', approx=True, force_complex=True)

prob.run_driver()

Expand All @@ -415,7 +433,7 @@ def test_pre_post_iter_approx(self):
assert_check_totals(data)

def test_pre_post_iter_approx_grouped(self):
prob = self.setup_problem(do_pre_post_opt=True, group=True, approx=True, force_complex=True, mode='fwd')
prob = setup_problem(do_pre_post_opt=True, group=True, approx=True, force_complex=True, mode='fwd')
prob.run_driver()

self.assertEqual(prob.model.G1.pre1.num_nl_solves, 1)
Expand All @@ -433,7 +451,7 @@ def test_pre_post_iter_approx_grouped(self):
assert_check_totals(data)

def test_pre_post_iter_approx_coloring(self):
prob = self.setup_problem(do_pre_post_opt=True, coloring=True, approx=True, force_complex=True, mode='fwd')
prob = setup_problem(do_pre_post_opt=True, coloring=True, approx=True, force_complex=True, mode='fwd')
prob.run_driver()

self.assertEqual(prob.model.pre1.num_nl_solves, 1)
Expand All @@ -451,7 +469,7 @@ def test_pre_post_iter_approx_coloring(self):
assert_check_totals(data)

def test_pre_post_iter_approx_ivc(self):
prob = self.setup_problem(do_pre_post_opt=True, use_ivc=True, approx=True, force_complex=True, mode='fwd')
prob = setup_problem(do_pre_post_opt=True, use_ivc=True, approx=True, force_complex=True, mode='fwd')
prob.run_driver()

self.assertEqual(prob.model.pre1.num_nl_solves, 1)
Expand All @@ -469,7 +487,7 @@ def test_pre_post_iter_approx_ivc(self):
assert_check_totals(data)

def test_pre_post_iter_approx_ivc_coloring(self):
prob = self.setup_problem(do_pre_post_opt=True, use_ivc=True, coloring=True, approx=True, force_complex=True, mode='fwd')
prob = setup_problem(do_pre_post_opt=True, use_ivc=True, coloring=True, approx=True, force_complex=True, mode='fwd')
prob.run_driver()

self.assertEqual(prob.model.pre1.num_nl_solves, 1)
Expand Down Expand Up @@ -528,7 +546,7 @@ def test_newton_with_densejac_under_full_model_fd(self):
assert_near_equal(J[('obj_cmp.obj', 'pz.z')], np.array([[9.62568658, 1.78576699]]), .00001)

def test_reading_system_cases_pre_opt_post(self):
prob = self.setup_problem(do_pre_post_opt=True, mode='fwd', recording=True)
prob = setup_problem(do_pre_post_opt=True, mode='fwd', recording=True)
prob.run_driver()
prob.cleanup()

Expand Down Expand Up @@ -646,5 +664,16 @@ def test_comp_multiple_iter_lists(self):
self.assertEqual(C4._run_on_opt, [False, True, False])
self.assertEqual(C5._run_on_opt, [False, False, True])


@unittest.skipUnless(MPI and PETScVector, "MPI and PETSc are required.")
class PrePostMPITestCase(unittest.TestCase):
N_PROCS = 2

def test_newton_on_one_rank(self):
prob = setup_problem(do_pre_post_opt=True, mode='fwd', parallel=True, group=True)

prob.run_driver()


if __name__ == "__main__":
unittest.main()

0 comments on commit 5f8b9cd

Please sign in to comment.