Skip to content

Commit

Permalink
Test against different dask schedulers (JiaweiZhuang#42)
Browse files Browse the repository at this point in the history
  • Loading branch information
andersy005 committed Nov 5, 2020
1 parent 001d6e9 commit 9e76f18
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 17 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ jobs:
shell: bash -l {0}
run: |
conda activate xesmf
python -m pytest --cov=./ --cov-report=xml --verbose
python -m pytest --setup-show --cov=./ --cov-report=xml --verbose
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v1
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,4 @@ PET0.ESMF_LogFile
.cache
.coverage
coverage.xml
dask-worker-space/
27 changes: 27 additions & 0 deletions xesmf/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import dask
import pytest


@pytest.fixture(scope="function")
def threaded_scheduler():
    """Run the enclosed test with dask's single-machine threaded scheduler."""
    ctx = dask.config.set(scheduler="threads")
    with ctx:
        yield


@pytest.fixture(scope="function")
def processes_scheduler():
    """Run the enclosed test with dask's multiprocessing scheduler."""
    ctx = dask.config.set(scheduler="processes")
    with ctx:
        yield


@pytest.fixture(scope="module")
def distributed_scheduler():
    """Run tests against a local dask.distributed cluster.

    Creating the ``Client`` registers it as the default scheduler for the
    module, so tests only need to request this fixture for its side effect;
    the yielded value is not used.

    Teardown is wrapped in ``try``/``finally`` so the cluster is shut down
    even if ``Client`` construction or a test body raises — the original
    version leaked the ``LocalCluster`` on such failures.
    """
    from dask.distributed import Client, LocalCluster

    # processes=True with one thread per worker exercises genuine
    # inter-process scheduling rather than the threaded fallback.
    cluster = LocalCluster(threads_per_worker=1, n_workers=2, processes=True)
    try:
        client = Client(cluster)
        try:
            yield
        finally:
            client.close()
    finally:
        cluster.close()
50 changes: 34 additions & 16 deletions xesmf/tests/test_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
import xarray as xr
import xesmf as xe
from xesmf.frontend import as_2d_mesh

import dask
from numpy.testing import assert_equal, assert_almost_equal
import pytest

dask_schedulers = ['threaded_scheduler', 'processes_scheduler', 'distributed_scheduler']


# same test data as test_backend.py, but here we can use xarray DataSet
ds_in = xe.util.grid_global(20, 12)
ds_out = xe.util.grid_global(15, 9)
Expand Down Expand Up @@ -268,14 +271,17 @@ def test_regrid_dataarray_from_locstream():
regridder = xe.Regridder(ds_locs, ds_in, 'conservative', locstream_in=True)


def test_regrid_dask():
@pytest.mark.parametrize('scheduler', dask_schedulers)
def test_regrid_dask(request, scheduler):
# chunked dask array (no xarray metadata)

scheduler = request.getfixturevalue(scheduler)
regridder = xe.Regridder(ds_in, ds_out, 'conservative')

indata = ds_in_chunked['data4D'].data
outdata = regridder(indata)

assert dask.is_dask_collection(outdata)

# lazy dask arrays have incorrect shape attribute due to last chunk
assert outdata.compute().shape == indata.shape[:-2] + horiz_shape_out
assert outdata.chunksize == indata.chunksize[:-2] + horiz_shape_out
Expand All @@ -285,30 +291,38 @@ def test_regrid_dask():
assert np.max(np.abs(rel_err)) < 0.05


def test_regrid_dask_to_locstream():
@pytest.mark.parametrize('scheduler', dask_schedulers)
def test_regrid_dask_to_locstream(request, scheduler):
    """Regrid a chunked dask array (no xarray metadata) to a locstream
    under each dask scheduler and check the result stays lazy."""
    # Instantiate the fixture purely for its side effect (activating the
    # scheduler); its return value is None, so don't store it.
    request.getfixturevalue(scheduler)
    regridder = xe.Regridder(ds_in, ds_locs, 'bilinear', locstream_out=True)

    indata = ds_in_chunked['data4D'].data
    outdata = regridder(indata)
    assert dask.is_dask_collection(outdata)


def test_regrid_dask_from_locstream():
@pytest.mark.parametrize('scheduler', dask_schedulers)
def test_regrid_dask_from_locstream(request, scheduler):
    """Regrid a chunked dask array (no xarray metadata) from a locstream
    under each dask scheduler and check the result stays lazy."""
    # Instantiate the fixture purely for its side effect (activating the
    # scheduler); its return value is None, so don't store it.
    request.getfixturevalue(scheduler)
    regridder = xe.Regridder(ds_locs, ds_in, 'nearest_s2d', locstream_in=True)

    # Chunk first so the input is a dask array; the earlier un-chunked call
    # was a dead store whose result was immediately overwritten.
    outdata = regridder(ds_locs.chunk()['lat'].data)
    assert dask.is_dask_collection(outdata)


def test_regrid_dataarray_dask():
@pytest.mark.parametrize('scheduler', dask_schedulers)
def test_regrid_dataarray_dask(request, scheduler):
# xarray.DataArray containing chunked dask array

scheduler = request.getfixturevalue(scheduler)
regridder = xe.Regridder(ds_in, ds_out, 'conservative')

dr_in = ds_in_chunked['data4D']
dr_out = regridder(dr_in)
assert dask.is_dask_collection(dr_out)

assert dr_out.data.shape == dr_in.data.shape[:-2] + horiz_shape_out
assert dr_out.data.chunksize == dr_in.data.chunksize[:-2] + horiz_shape_out
Expand All @@ -324,22 +338,26 @@ def test_regrid_dataarray_dask():
assert_equal(dr_out['lat'].values, ds_out['lat'].values)
assert_equal(dr_out['lon'].values, ds_out['lon'].values)


def test_regrid_dataarray_dask_to_locstream():
@pytest.mark.parametrize('scheduler', dask_schedulers)
def test_regrid_dataarray_dask_to_locstream(request, scheduler):
    """Regrid an xarray.DataArray backed by a chunked dask array to a
    locstream under each dask scheduler and check the result stays lazy."""
    # Instantiate the fixture purely for its side effect (activating the
    # scheduler); its return value is None, so don't store it.
    request.getfixturevalue(scheduler)
    regridder = xe.Regridder(ds_in, ds_locs, 'bilinear', locstream_out=True)

    dr_in = ds_in_chunked['data4D']
    dr_out = regridder(dr_in)
    assert dask.is_dask_collection(dr_out)


def test_regrid_dataarray_dask_from_locstream():
@pytest.mark.parametrize('scheduler', dask_schedulers)
def test_regrid_dataarray_dask_from_locstream(request, scheduler):
    """Regrid an xarray.DataArray backed by a chunked dask array from a
    locstream under each dask scheduler and check the result stays lazy."""
    # Instantiate the fixture purely for its side effect (activating the
    # scheduler); its return value is None, so don't store it.
    request.getfixturevalue(scheduler)
    regridder = xe.Regridder(ds_locs, ds_in, 'nearest_s2d', locstream_in=True)

    # Chunk first so the input is dask-backed; the earlier un-chunked call
    # was a dead store whose result was immediately overwritten.
    outdata = regridder(ds_locs.chunk()['lat'])
    assert dask.is_dask_collection(outdata)


def test_regrid_dataset():
Expand Down

0 comments on commit 9e76f18

Please sign in to comment.