Skip to content

Commit

Permalink
Merge c846ca6 into 0a48bbc
Browse files Browse the repository at this point in the history
  • Loading branch information
aulemahal committed Sep 27, 2023
2 parents 0a48bbc + c846ca6 commit 1548222
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 21 deletions.
1 change: 1 addition & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Bug fixes
^^^^^^^^^
* Fixed an error in the `pytest` configuration that prevented copying of testing data to thread-safe caches of workers under certain conditions (this should always occur). (:pull:`1473`).
* Coincidentally, this also fixes an error that caused `pytest` to error out when invoked without an active internet connection. Running `pytest` without network access is now supported (requires cached testing data). (:issue:`1468`).
* Calling a ``sdba.map_blocks``-wrapped function with data chunked along the reduced dimensions will raise an error. This forbids chunking the trained dataset along the distribution dimensions, for example. (:issue:`1481`, :pull:`1482`).

Breaking changes
^^^^^^^^^^^^^^^^
Expand Down
13 changes: 13 additions & 0 deletions tests/test_sdba/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,3 +203,16 @@ def func(ds, *, dim):
).load()
assert set(data.data.dims) == {"dayofyear"}
assert "leftover" in data


def test_map_blocks_error(tas_series):
    """Chunking along a reduced dimension must make the wrapped call raise."""
    # One year of daily data, broadcast over a "lat" dimension that is
    # then chunked with a single element per chunk.
    da = tas_series(np.arange(366), start="2000-01-01")
    da = da.expand_dims(lat=[1, 2, 3, 4]).chunk({"lat": 1})

    # Test dim parsing: the decorator inspects the wrapped signature, so the
    # parameter names (group, lon) are meaningful and must not change.
    @map_blocks(reduces=["lat"], data=[])
    def func(ds, *, group, lon=None):
        return ds.tas.rename("data").to_dataset()

    with pytest.raises(ValueError, match="cannot be chunked"):
        func(xr.Dataset({"tas": da}), group="time")
52 changes: 31 additions & 21 deletions xclim/sdba/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ def duck_empty(dims, sizes, dtype="float64", chunks=None):
def _decode_cf_coords(ds):
"""Decode coords in-place."""
crds = xr.decode_cf(ds.coords.to_dataset())
for crdname in ds.coords.keys():
for crdname in list(ds.coords.keys()):
ds[crdname] = crds[crdname]
# decode_cf introduces an encoding key for the dtype, which can confuse the netCDF writer
dtype = ds[crdname].encoding.get("dtype")
Expand Down Expand Up @@ -557,26 +557,6 @@ def _map_blocks(ds, **kwargs):
) and group is None:
raise ValueError("Missing required `group` argument.")

if uses_dask(ds):
# Use dask if any of the input is dask-backed.
chunks = (
dict(ds.chunks)
if isinstance(ds, xr.Dataset)
else dict(zip(ds.dims, ds.chunks))
)
if group is not None:
badchunks = {
dim: chunks.get(dim)
for dim in group.add_dims + [group.dim]
if len(chunks.get(dim, [])) > 1
}
if badchunks:
raise ValueError(
f"The dimension(s) over which we group cannot be chunked ({badchunks})."
)
else:
chunks = None

# Make translation dict
if group is not None:
placeholders = {
Expand All @@ -602,6 +582,36 @@ def _map_blocks(ds, **kwargs):
f"Dimension {dim} is meant to be added by the computation but it is already on one of the inputs."
)

if uses_dask(ds):
# Use dask if any of the input is dask-backed.
chunks = (
dict(ds.chunks)
if isinstance(ds, xr.Dataset)
else dict(zip(ds.dims, ds.chunks))
)
badchunks = {}
if group is not None:
badchunks.update(
{
dim: chunks.get(dim)
for dim in group.add_dims + [group.dim]
if len(chunks.get(dim, [])) > 1
}
)
badchunks.update(
{
dim: chunks.get(dim)
for dim in reduced_dims
if len(chunks.get(dim)) > 1
}
)
if badchunks:
raise ValueError(
f"The dimension(s) over which we group, reduce or interpolate cannot be chunked ({badchunks})."
)
else:
chunks = None

# Dimensions untouched by the function.
base_dims = list(set(ds.dims) - set(new_dims) - set(reduced_dims))

Expand Down

0 comments on commit 1548222

Please sign in to comment.