Skip to content

Commit

Permalink
Update _dataset_from_files to allow merging conflicting null and non-…
Browse files Browse the repository at this point in the history
…null values. Raise a warning if conflicting non-null values are merged, but supress the MergeError. Move null-masking into dataset creation util functions to ensure that nulls are set BEFORE merging to allow the above to work as designed. (#29)
  • Loading branch information
aazuspan committed Oct 5, 2021
1 parent 334b9f7 commit 94a951b
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 16 deletions.
4 changes: 2 additions & 2 deletions test/test_utils.py
Expand Up @@ -150,14 +150,14 @@ def test_parse_invalid_time_warns():
def test_dataarray_from_file():
"""Test that an xarray.DataArray can be created from a valid GeoTIFF."""
file_path = TEST_IMAGE_PATHS[0]
da = wxee.utils._dataarray_from_file(file_path)
da = wxee.utils._dataarray_from_file(file_path, masked=True, nodata=0)

assert da.name == "pr"


def test_dataset_from_files():
"""Test than an xarray.Dataset can be created from a list of valid GeoTIFFs."""
ds = wxee.utils._dataset_from_files(TEST_IMAGE_PATHS)
ds = wxee.utils._dataset_from_files(TEST_IMAGE_PATHS, masked=True, nodata=0)

assert ds.time.size == 3
assert all([var in ds.variables for var in ["pr", "rmax"]])
Expand Down
6 changes: 1 addition & 5 deletions wxee/collection.py
Expand Up @@ -131,11 +131,7 @@ def to_xarray(
max_attempts=max_attempts,
)

ds = _dataset_from_files(files)

# Mask the nodata values. This will convert int datasets to float.
if masked:
ds = ds.where(ds != nodata)
ds = _dataset_from_files(files, masked, nodata)

if path:
ds.to_netcdf(path, mode="w")
Expand Down
6 changes: 1 addition & 5 deletions wxee/image.py
Expand Up @@ -91,11 +91,7 @@ def to_xarray(
progress=progress,
)

ds = _dataset_from_files(files)

# Mask the nodata values. This will convert int datasets to float.
if masked:
ds = ds.where(ds != nodata)
ds = _dataset_from_files(files, masked, nodata)

if path:
ds.to_netcdf(path, mode="w")
Expand Down
22 changes: 18 additions & 4 deletions wxee/utils.py
Expand Up @@ -124,14 +124,24 @@ def _create_retry_session(max_attempts: int) -> requests.Session:
return session


def _dataset_from_files(files: List[str]) -> xr.Dataset:
def _dataset_from_files(files: List[str], masked: bool, nodata: int) -> xr.Dataset:
"""Create an xarray.Dataset from a list of raster files."""
das = [_dataarray_from_file(file) for file in files]
das = [_dataarray_from_file(file, masked, nodata) for file in files]

return xr.merge(das)
try:
# Allow conflicting values if one is null, take the non-null value
merged = xr.merge(das, compat="no_conflicts")
except xr.core.merge.MergeError:
# If non-null conflicting values occur, take the first value and warn the user
merged = xr.merge(das, compat="override")
warnings.warn(
"Different non-null values were encountered for the same variable at the same time coordinate. The first value was taken."
)

return merged

def _dataarray_from_file(file: str) -> xr.DataArray:

def _dataarray_from_file(file: str, masked: bool, nodata: int) -> xr.DataArray:
"""Create an xarray.DataArray from a single file by parsing datetimes and variables from the file name.
The file name must follow the format "{dimension}.{coordinate}.{variable}.{extension}".
Expand All @@ -141,6 +151,10 @@ def _dataarray_from_file(file: str) -> xr.DataArray:

da = da.expand_dims({dim: [coord]}).rename(var).squeeze("band").drop_vars("band")

# Mask the nodata values. This will convert int datasets to float.
if masked:
da = da.where(da != nodata)

return da


Expand Down

0 comments on commit 94a951b

Please sign in to comment.