Allow apply_mask to handle multi-channel Sv dataset #1010

Merged: 17 commits, Mar 25, 2023
144 changes: 106 additions & 38 deletions echopype/mask/api.py
@@ -19,6 +19,24 @@
}


def _validate_source_ds(source_ds, storage_options_ds):
"""
Validate the input ``source_ds`` and the associated ``storage_options_ds``.
"""
# Validate the source_ds type or path (if it is provided)
source_ds, file_type = validate_source_ds_da(source_ds, storage_options_ds)

if isinstance(source_ds, str):
# open up Dataset using source_ds path
source_ds = xr.open_dataset(source_ds, engine=file_type, chunks={}, **storage_options_ds)

# Check source_ds coordinates
if "ping_time" not in source_ds or "range_sample" not in source_ds:
raise ValueError("'source_ds' must have coordinates 'ping_time' and 'range_sample'!")

return source_ds


def _validate_and_collect_mask_input(
mask: Union[
Union[xr.DataArray, str, pathlib.Path], List[Union[xr.DataArray, str, pathlib.Path]]
@@ -82,6 +100,15 @@ def _validate_and_collect_mask_input(
mask_val, engine=file_type, chunks={}, **storage_options_mask[mask_ind]
)

# check mask dimensions
# the dimension order matters, so compare against fixed tuple forms
allowed_dims = [
("ping_time", "range_sample"),
("channel", "ping_time", "range_sample"),
]
if mask[mask_ind].dims not in allowed_dims:
raise ValueError("All masks must have dimensions ('ping_time', 'range_sample')!")

else:
if not isinstance(storage_options_mask, dict):
raise ValueError(
@@ -101,7 +128,7 @@

def _check_var_name_fill_value(
source_ds: xr.Dataset, var_name: str, fill_value: Union[int, float, np.ndarray, xr.DataArray]
) -> None:
) -> Union[int, float, np.ndarray, xr.DataArray]:
"""
Ensures that the inputs ``var_name`` and ``fill_value`` for the function
``apply_mask`` were appropriately provided.
@@ -115,6 +142,11 @@ def _check_var_name_fill_value(
fill_value: int or float or np.ndarray or xr.DataArray
Specifies the value(s) at false indices

Returns
-------
fill_value: int or float or np.ndarray or xr.DataArray
fill_value with sanitized dimensions

Raises
------
TypeError
@@ -139,11 +171,25 @@
"The input fill_value must be of type int or " "float or np.ndarray or xr.DataArray!"
)

# make sure that fill_values is the same shape as var_name, if it is an array
if isinstance(fill_value, (np.ndarray, xr.DataArray)) and (
fill_value.shape != source_ds[var_name].shape
):
raise ValueError("If fill_value is an array is must be of the same shape as var_name!")
# make sure that fill_value is the same shape as var_name
if isinstance(fill_value, (np.ndarray, xr.DataArray)):
if isinstance(fill_value, xr.DataArray):
fill_value = fill_value.data.squeeze() # squeeze out length=1 channel dimension
elif isinstance(fill_value, np.ndarray):
fill_value = fill_value.squeeze() # squeeze out length=1 channel dimension

source_ds_shape = (
source_ds[var_name].isel(channel=0).shape
if "channel" in source_ds[var_name].coords
else source_ds[var_name].shape
)

if fill_value.shape != source_ds_shape:
raise ValueError(
f"If fill_value is an array it must be of the same shape as {var_name}!"
)

return fill_value


def _variable_prov_attrs(
@@ -214,13 +260,24 @@ def apply_mask(
source_ds: xr.Dataset, str, or pathlib.Path
Points to a Dataset that contains the variable the mask should be applied to
mask: xr.DataArray, str, pathlib.Path, or a list of these datatypes
The mask(s) to be applied. Can be a single input or list that corresponds to
a DataArray or a path. If a path is provided this should point to a zarr or
netcdf file with only one data variable in it.
The mask(s) to be applied.
Can be a single input or list that corresponds to a DataArray or a path.
Each entry in the list must have dimensions ``('ping_time', 'range_sample')``.
Multi-channel masks are not currently supported.
If a path is provided this should point to a zarr or netcdf file with only
one data variable in it.
If the input ``mask`` is a list, a logical AND will be used to produce the final
mask that will be applied to ``var_name``.
var_name: str, default="Sv"
The Sv variable name in ``source_ds`` that the mask should be applied to
The Sv variable name in ``source_ds`` that the mask should be applied to.
This variable needs to have coordinates ``ping_time`` and ``range_sample``,
and can optionally also have coordinate ``channel``.
In the case of a multi-channel Sv data variable, the ``mask`` will be broadcast
to all channels.
fill_value: int, float, np.ndarray, or xr.DataArray, default=np.nan
Value(s) at masked indices
Value(s) at masked indices.
If ``fill_value`` is of type ``np.ndarray`` or ``xr.DataArray``,
it must have the same shape as each entry of ``mask``.
storage_options_ds: dict, default={}
Any additional parameters for the storage backend, corresponding to the
path provided for ``source_ds``
@@ -234,32 +291,18 @@
-------
xr.Dataset
A Dataset with the same format of ``source_ds`` with the mask(s) applied to ``var_name``

Notes
-----
If the input ``mask`` is a list, then a logical AND will be used to produce the final
mask that will be applied to ``var_name``.
"""

# validate the source_ds type or path (if it is provided)
source_ds, file_type = validate_source_ds_da(source_ds, storage_options_ds)
# Validate the source_ds
source_ds = _validate_source_ds(source_ds, storage_options_ds)

if isinstance(source_ds, str):
# open up Dataset using source_ds path
source_ds = xr.open_dataset(source_ds, engine=file_type, chunks={}, **storage_options_ds)

# validate and form the mask input to be used downstream
# Validate and form the mask input to be used downstream
mask = _validate_and_collect_mask_input(mask, storage_options_mask)

# ensure that var_name and fill_value were correctly provided
_check_var_name_fill_value(source_ds, var_name, fill_value)

# select data only, if fill_value is a DataArray (necessary since
# xr.where(keep_attrs=True) is not functioning correctly)
if isinstance(fill_value, xr.DataArray):
fill_value = fill_value.data
# Check var_name and sanitize fill_value dimensions if an array
fill_value = _check_var_name_fill_value(source_ds, var_name, fill_value)

# obtain final mask to be applied to var_name
# Obtain final mask to be applied to var_name
if isinstance(mask, list):
# perform a logical AND element-wise operation across the masks
final_mask = np.logical_and.reduce(mask)
@@ -269,18 +312,43 @@
else:
final_mask = mask

# sanity check to make sure final_mask is the same shape as source_ds[var_name]
if final_mask.shape != source_ds[var_name].shape:
raise ValueError("Final constructed mask is not the same shape as source_ds[var_name]!")
# Sanity check: final_mask should be of the same shape as source_ds[var_name]
# along the ping_time and range_sample dimensions
def get_ch_shape(da):
return da.isel(channel=0).shape if "channel" in da.dims else da.shape

# Below operate on the actual data array to be masked
source_da = source_ds[var_name]

source_da_shape = get_ch_shape(source_da)
final_mask_shape = get_ch_shape(final_mask)

if final_mask_shape != source_da_shape:
raise ValueError(
f"The final constructed mask is not of the same shape as source_ds[{var_name}] "
"along the ping_time and range_sample dimensions!"
)

# final_mask is always an xr.DataArray with at most length=1 channel dimension
if "channel" in final_mask.dims:
final_mask = final_mask.isel(channel=0)

# Make sure fill_value and final_mask are expanded in dimensions
if "channel" in source_da.dims:
if isinstance(fill_value, np.ndarray):
fill_value = np.array([fill_value] * source_da["channel"].size)
final_mask = np.array([final_mask.data] * source_da["channel"].size)

# apply the mask to var_name
var_name_masked = xr.where(final_mask, x=source_ds[var_name], y=fill_value, keep_attrs=True)
# Apply the mask to var_name
# Somehow keep_attrs=True errors out here, so will attach later
var_name_masked = xr.where(final_mask, x=source_da, y=fill_value)

# obtain a shallow copy of source_ds
# Obtain a shallow copy of source_ds
output_ds = source_ds.copy(deep=False)

# replace var_name with var_name_masked
# Replace var_name with var_name_masked
output_ds[var_name] = var_name_masked
output_ds[var_name] = output_ds[var_name].assign_attrs(source_da.attrs)

# Add or modify variable and global (dataset) provenance attributes
output_ds[var_name] = output_ds[var_name].assign_attrs(
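
Below is a minimal usage sketch of the behavior introduced by this PR; it is not part of the diff. It assumes apply_mask is importable as echopype.mask.apply_mask and builds a synthetic two-channel Sv dataset with made-up sizes and coordinate values; a 2-D mask with dims ("ping_time", "range_sample") is broadcast to every channel, as described in the updated docstring.

import numpy as np
import xarray as xr

from echopype.mask import apply_mask  # assumed public import path

# Synthetic multi-channel Sv dataset: dims ("channel", "ping_time", "range_sample")
ds = xr.Dataset(
    {"Sv": (("channel", "ping_time", "range_sample"), np.random.rand(2, 5, 10))},
    coords={
        "channel": ["ch1", "ch2"],
        "ping_time": np.arange(5),
        "range_sample": np.arange(10),
    },
)

# 2-D boolean mask; apply_mask broadcasts it to all channels of "Sv"
mask = xr.DataArray(
    np.random.rand(5, 10) > 0.5,
    dims=("ping_time", "range_sample"),
    coords={"ping_time": ds["ping_time"], "range_sample": ds["range_sample"]},
)

# Samples where the mask is False are replaced by fill_value (np.nan by default)
ds_masked = apply_mask(ds, mask, var_name="Sv", fill_value=np.nan)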
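A second sketch, continuing from the objects above, illustrates the list form: per the docstring, the masks in the list are combined with a logical AND before being applied (mask_deep is a second made-up mask).

# Second made-up mask; a sample survives only where *both* masks are True
mask_deep = xr.DataArray(
    np.random.rand(5, 10) > 0.2,
    dims=("ping_time", "range_sample"),
    coords={"ping_time": ds["ping_time"], "range_sample": ds["range_sample"]},
)

# Passing a list applies np.logical_and.reduce across the masks first
ds_masked_both = apply_mask(ds, [mask, mask_deep], var_name="Sv")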
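Finally, a sketch of the array-valued fill_value path: the array must match the per-channel (ping_time, range_sample) shape of the variable, and it is likewise broadcast across channels before masking.

# Replace masked-out samples with -999.0 instead of NaN, using a 2-D array
# whose shape matches one channel of "Sv" (ping_time x range_sample)
fill = np.full((5, 10), -999.0)
ds_filled = apply_mask(ds, mask, var_name="Sv", fill_value=fill)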