Skip to content

Commit

Permalink
Add tests and guard for parallel h5py
Browse files Browse the repository at this point in the history
  • Loading branch information
kburns committed Jun 2, 2023
1 parent 2f4bcbb commit fab25b9
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
3 changes: 3 additions & 0 deletions dedalus/core/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,9 @@ class H5ParallelFileHandler(H5FileHandlerBase):

def __init__(self, *args, **kw):
super().__init__(*args, **kw)
# Fail if not using MPI
if not h5py.get_config().mpi:
raise ValueError("H5ParallelFileHandler requires parallel build of h5py.")
# Set HDF5 property list for collective writing
self._property_list = h5py.h5p.create(h5py.h5p.DATASET_XFER)
self._property_list.set_dxpl_mpio(h5py.h5fd.MPIO_COLLECTIVE)
Expand Down
12 changes: 10 additions & 2 deletions dedalus/tests/test_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,20 @@
from dedalus.tools.cache import CachedFunction


# Check if parallel h5py is available
handler_options = ['gather', 'virtual']
if h5py.get_config().mpi:
handler_options.append('mpio')
else:
handler_options.append(pytest.param('mpio', marks=pytest.mark.xfail(reason="parallel h5py not available")))


@pytest.mark.parametrize('dtype', [np.float64, np.complex128])
@pytest.mark.parametrize('dealias', [1, 3/2])
@pytest.mark.parametrize('output_scales', [1, 3/2, 2,
pytest.param(1/2, marks=pytest.mark.xfail(reason="evaluator not copying correctly for scales < 1"))])
@pytest.mark.parametrize('output_layout', ['g', 'c'])
@pytest.mark.parametrize('parallel', ['gather', 'mpio', 'virtual'])
@pytest.mark.parametrize('parallel', handler_options)
def test_cartesian_output(dtype, dealias, output_scales, output_layout, parallel):
Nx = Ny = Nz = 16
Lx = Ly = Lz = 2 * np.pi
Expand Down Expand Up @@ -87,7 +95,7 @@ def build_shell(Nphi, Ntheta, Nr, k, dealias, dtype):
@pytest.mark.parametrize('dtype', [np.float64, np.complex128])
@pytest.mark.parametrize('output_scales', [1, 3/2, 2,
pytest.param(1/2, marks=pytest.mark.xfail(reason="evaluator not copying correctly for scales < 1"))])
@pytest.mark.parametrize('parallel', ['gather', 'virtual'])
@pytest.mark.parametrize('parallel', handler_options)
def test_spherical_output(Nphi, Ntheta, Nr, k, dealias, dtype, basis, output_scales, parallel):
# Basis
c, d, b, phi, theta, r, x, y, z = basis(Nphi, Ntheta, Nr, k, dealias, dtype)
Expand Down

0 comments on commit fab25b9

Please sign in to comment.