Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 52 additions & 2 deletions hypernets_processor/data_io/dataset_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def create_default_array(dim_sizes, dtype, dim_names=None, fill_value=None):

if dim_names is not None:
default_array = DataArray(empty_array, dims=dim_names)
elif (dim_names is None) and (dim_sizes == []):
default_array = DataArray(empty_array)
else:
default_array = DataArray(empty_array, dims=DEFAULT_DIM_NAMES[-len(dim_sizes):])

Expand Down Expand Up @@ -259,6 +261,50 @@ def unpack_flags(da):

return ds

@staticmethod
def get_flags_mask_or(da, flags=None):
"""
Returns boolean mask for set of flags, defined as logical or of flags

:type da: xarray.DataArray
:param da: dataset

:type flags: list
:param flags: list of flags (if unset all data flags selected)

:return: flag masks
:rtype: numpy.ndarray
"""

flags_ds = DatasetUtil.unpack_flags(da)

flags = flags if flags is not None else flags_ds.variables
mask_flags = [flags_ds[flag].values for flag in flags]

return np.logical_or.reduce(mask_flags)

@staticmethod
def get_flags_mask_and(da, flags=None):
"""
Returns boolean mask for set of flags, defined as logical and of flags

:type da: xarray.DataArray
:param da: dataset

:type flags: list
:param flags: list of flags (if unset all data flags selected)

:return: flag masks
:rtype: numpy.ndarray
"""

flags_ds = DatasetUtil.unpack_flags(da)

flags = flags if flags is not None else flags_ds.variables
mask_flags = [flags_ds[flag].values for flag in flags]

return np.logical_and.reduce(mask_flags)

@staticmethod
def set_flag(da, flag_name, error_if_set=False):
"""
Expand All @@ -281,7 +327,9 @@ def set_flag(da, flag_name, error_if_set=False):
flag_bit = flag_meanings.index(flag_name)
flag_mask = flag_masks[flag_bit]

return da | flag_mask
da.values = da.values | flag_mask

return da

@staticmethod
def unset_flag(da, flag_name, error_if_unset=False):
Expand All @@ -305,7 +353,9 @@ def unset_flag(da, flag_name, error_if_unset=False):
flag_bit = flag_meanings.index(flag_name)
flag_mask = flag_masks[flag_bit]

return da & ~flag_mask
da.values = da.values & ~flag_mask

return da

@staticmethod
def get_set_flags(da):
Expand Down
56 changes: 56 additions & 0 deletions hypernets_processor/data_io/tests/test_dataset_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,62 @@ def test_unpack_flags(self):
self.assertTrue((flags["flag7"].data == empty).all())
self.assertTrue((flags["flag8"].data == empty).all())

def test_get_flags_mask_or(self):

ds = Dataset()
meanings = ["flag1", "flag2", "flag3", "flag4", "flag5", "flag6", "flag7", "flag8"]
flags_vector_variable = DatasetUtil.create_flags_variable([2,3], meanings, dim_names=["dim1", "dim2"],
attributes={"standard_name": "std"})

ds["flags"] = flags_vector_variable
ds["flags"] = DatasetUtil.set_flag(ds["flags"], "flag4")
ds["flags"][0, 1] = DatasetUtil.set_flag(ds["flags"][0, 1], "flag5")
ds["flags"][1, 1] = DatasetUtil.set_flag(ds["flags"][1, 1], "flag2")
ds["flags"][1, 2] = DatasetUtil.set_flag(ds["flags"][1, 2], "flag7")

flags_mask = DatasetUtil.get_flags_mask_or(ds["flags"], flags=["flag5", "flag2", "flag7"])

expected_flags_mask = np.array([[False, True, False], [False, True, True]], dtype=bool)

np.testing.assert_array_almost_equal(flags_mask, expected_flags_mask)

def test_get_flags_mask_all(self):
ds = Dataset()
meanings = ["flag1", "flag2", "flag3", "flag4", "flag5", "flag6", "flag7", "flag8"]
flags_vector_variable = DatasetUtil.create_flags_variable([2, 3], meanings, dim_names=["dim1", "dim2"],
attributes={"standard_name": "std"})

ds["flags"] = flags_vector_variable
ds["flags"] = DatasetUtil.set_flag(ds["flags"], "flag4")
ds["flags"][0, 1] = DatasetUtil.set_flag(ds["flags"][0, 1], "flag5")
ds["flags"][1, 1] = DatasetUtil.set_flag(ds["flags"][1, 1], "flag2")
ds["flags"][1, 2] = DatasetUtil.set_flag(ds["flags"][1, 2], "flag7")

flags_mask = DatasetUtil.get_flags_mask_or(ds["flags"])

expected_flags_mask = np.array([[True, True, True], [True, True, True]], dtype=bool)

np.testing.assert_array_almost_equal(flags_mask, expected_flags_mask)

def test_get_flags_mask_and(self):

ds = Dataset()
meanings = ["flag1", "flag2", "flag3", "flag4", "flag5", "flag6", "flag7", "flag8"]
flags_vector_variable = DatasetUtil.create_flags_variable([2,3], meanings, dim_names=["dim1", "dim2"],
attributes={"standard_name": "std"})

ds["flags"] = flags_vector_variable
ds["flags"] = DatasetUtil.set_flag(ds["flags"], "flag4")
ds["flags"][0, 1] = DatasetUtil.set_flag(ds["flags"][0, 1], "flag5")
ds["flags"][1, 1] = DatasetUtil.set_flag(ds["flags"][1, 1], "flag2")
ds["flags"][1, 1] = DatasetUtil.set_flag(ds["flags"][1, 1], "flag7")

flags_mask = DatasetUtil.get_flags_mask_and(ds["flags"], flags=["flag2", "flag7"])

expected_flags_mask = np.array([[False, False, False], [False, True, False]], dtype=bool)

np.testing.assert_array_almost_equal(flags_mask, expected_flags_mask)

def test_get_set_flags(self):

ds = Dataset()
Expand Down