Skip to content

Commit

Permalink
Merge pull request #1456 from Kenneth-T-Moore/master
Browse files Browse the repository at this point in the history
rearranged a test for CI fail
  • Loading branch information
swryan committed Jun 9, 2020
2 parents b4598bf + 23723ed commit 8143d68
Showing 1 changed file with 21 additions and 16 deletions.
37 changes: 21 additions & 16 deletions openmdao/drivers/tests/test_doe_driver.py
Expand Up @@ -1819,21 +1819,15 @@ class TestParallelDistribDOE(unittest.TestCase):

def setUp(self):
self.startdir = os.getcwd()
# Have to hard code the tempdir so that all procs run in the same place.
self.tempdir = 'TestParallelDistribDOE_one_dir_only'
try:
os.mkdir(self.tempdir)
except OSError:
pass
self.tempdir = tempfile.mkdtemp(prefix='TestDOEDriver-')
os.chdir(self.tempdir)

def tearDown(self):
os.chdir(self.startdir)
if MPI.COMM_WORLD.rank == 0:
try:
shutil.rmtree(self.tempdir)
except OSError:
pass
try:
shutil.rmtree(self.tempdir)
except OSError:
pass

def test_doe_distributed_var(self):
size = 3
Expand Down Expand Up @@ -1872,7 +1866,6 @@ def test_doe_distributed_var(self):
rank = prob.comm.rank
if rank == 0:
filename0 = "cases.sql_0"
filename1 = "cases.sql_1"
values = []

cr = om.CaseReader(filename0)
Expand All @@ -1881,16 +1874,28 @@ def test_doe_distributed_var(self):
outputs = cr.get_case(case).outputs
values.append(outputs)

cr = om.CaseReader(filename1)
# 2**6 cases, half on each rank
self.assertEqual(len(values), 32)
x_inputs = [list(val['x']) for val in values]
for n1 in [-50.]:
for n2 in [-50., 50.]:
for n3 in [-50., 50.]:
self.assertEqual(x_inputs.count([n1, n2, n3]), 8)

elif rank == 1:
filename0 = "cases.sql_1"
values = []

cr = om.CaseReader(filename0)
cases = cr.list_cases('driver')
for case in cases:
outputs = cr.get_case(case).outputs
values.append(outputs)

# 2**6 cases
self.assertEqual(len(values), 64)
# 2**6 cases, half on each rank
self.assertEqual(len(values), 32)
x_inputs = [list(val['x']) for val in values]
for n1 in [-50., 50.]:
for n1 in [50.]:
for n2 in [-50., 50.]:
for n3 in [-50., 50.]:
self.assertEqual(x_inputs.count([n1, n2, n3]), 8)
Expand Down

0 comments on commit 8143d68

Please sign in to comment.