diff --git a/hypernets_processor/data_io/dataset_util.py b/hypernets_processor/data_io/dataset_util.py
index 126f2a09..0159f798 100644
--- a/hypernets_processor/data_io/dataset_util.py
+++ b/hypernets_processor/data_io/dataset_util.py
@@ -53,6 +53,8 @@ def create_default_array(dim_sizes, dtype, dim_names=None, fill_value=None):
 
         if dim_names is not None:
             default_array = DataArray(empty_array, dims=dim_names)
+        elif len(dim_sizes) == 0:  # scalar: DEFAULT_DIM_NAMES[-0:] below would select every name
+            default_array = DataArray(empty_array)
         else:
             default_array = DataArray(empty_array, dims=DEFAULT_DIM_NAMES[-len(dim_sizes):])
 
@@ -259,6 +261,50 @@ def unpack_flags(da):
 
         return ds
 
+    @staticmethod
+    def get_flags_mask_or(da, flags=None):
+        """
+        Returns boolean mask that is True where any of the selected flags is set
+
+        :type da: xarray.DataArray
+        :param da: flags variable
+
+        :type flags: list
+        :param flags: names of flags to combine (if unset all data flags selected)
+
+        :return: element-wise logical or of the selected flag masks
+        :rtype: numpy.ndarray
+        """
+
+        flags_ds = DatasetUtil.unpack_flags(da)
+
+        flags = flags if flags is not None else flags_ds.data_vars
+        mask_flags = [flags_ds[flag].values for flag in flags]
+
+        return np.logical_or.reduce(mask_flags)
+
+    @staticmethod
+    def get_flags_mask_and(da, flags=None):
+        """
+        Returns boolean mask that is True where all of the selected flags are set
+
+        :type da: xarray.DataArray
+        :param da: flags variable
+
+        :type flags: list
+        :param flags: names of flags to combine (if unset all data flags selected)
+
+        :return: element-wise logical and of the selected flag masks
+        :rtype: numpy.ndarray
+        """
+
+        flags_ds = DatasetUtil.unpack_flags(da)
+
+        flags = flags if flags is not None else flags_ds.data_vars
+        mask_flags = [flags_ds[flag].values for flag in flags]
+
+        return np.logical_and.reduce(mask_flags)
+
     @staticmethod
     def set_flag(da, flag_name, error_if_set=False):
         """
@@ -281,7 +327,9 @@ def set_flag(da, flag_name, error_if_set=False):
         flag_bit = flag_meanings.index(flag_name)
         flag_mask = flag_masks[flag_bit]
 
-        return da | flag_mask
+        da.values = da.values | flag_mask
+
+        return da
 
     @staticmethod
     def unset_flag(da, flag_name, error_if_unset=False):
@@ -305,7 +353,9 @@ def unset_flag(da, flag_name, error_if_unset=False):
         flag_bit = flag_meanings.index(flag_name)
         flag_mask = flag_masks[flag_bit]
 
-        return da & ~flag_mask
+        da.values = da.values & ~flag_mask
+
+        return da
 
     @staticmethod
     def get_set_flags(da):
diff --git a/hypernets_processor/data_io/tests/test_dataset_util.py b/hypernets_processor/data_io/tests/test_dataset_util.py
index 23b43a38..a82ca198 100644
--- a/hypernets_processor/data_io/tests/test_dataset_util.py
+++ b/hypernets_processor/data_io/tests/test_dataset_util.py
@@ -288,6 +288,62 @@ def test_unpack_flags(self):
         self.assertTrue((flags["flag7"].data == empty).all())
         self.assertTrue((flags["flag8"].data == empty).all())
 
+    def test_get_flags_mask_or(self):
+
+        ds = Dataset()
+        meanings = ["flag1", "flag2", "flag3", "flag4", "flag5", "flag6", "flag7", "flag8"]
+        flags_vector_variable = DatasetUtil.create_flags_variable([2, 3], meanings, dim_names=["dim1", "dim2"],
+                                                                  attributes={"standard_name": "std"})
+
+        ds["flags"] = flags_vector_variable
+        ds["flags"] = DatasetUtil.set_flag(ds["flags"], "flag4")
+        ds["flags"][0, 1] = DatasetUtil.set_flag(ds["flags"][0, 1], "flag5")
+        ds["flags"][1, 1] = DatasetUtil.set_flag(ds["flags"][1, 1], "flag2")
+        ds["flags"][1, 2] = DatasetUtil.set_flag(ds["flags"][1, 2], "flag7")
+
+        flags_mask = DatasetUtil.get_flags_mask_or(ds["flags"], flags=["flag5", "flag2", "flag7"])
+
+        expected_flags_mask = np.array([[False, True, False], [False, True, True]], dtype=bool)
+
+        np.testing.assert_array_equal(flags_mask, expected_flags_mask)
+
+    def test_get_flags_mask_all(self):
+        ds = Dataset()
+        meanings = ["flag1", "flag2", "flag3", "flag4", "flag5", "flag6", "flag7", "flag8"]
+        flags_vector_variable = DatasetUtil.create_flags_variable([2, 3], meanings, dim_names=["dim1", "dim2"],
+                                                                  attributes={"standard_name": "std"})
+
+        ds["flags"] = flags_vector_variable
+        ds["flags"] = DatasetUtil.set_flag(ds["flags"], "flag4")
+        ds["flags"][0, 1] = DatasetUtil.set_flag(ds["flags"][0, 1], "flag5")
+        ds["flags"][1, 1] = DatasetUtil.set_flag(ds["flags"][1, 1], "flag2")
+        ds["flags"][1, 2] = DatasetUtil.set_flag(ds["flags"][1, 2], "flag7")
+
+        flags_mask = DatasetUtil.get_flags_mask_or(ds["flags"])
+
+        expected_flags_mask = np.array([[True, True, True], [True, True, True]], dtype=bool)
+
+        np.testing.assert_array_equal(flags_mask, expected_flags_mask)
+
+    def test_get_flags_mask_and(self):
+
+        ds = Dataset()
+        meanings = ["flag1", "flag2", "flag3", "flag4", "flag5", "flag6", "flag7", "flag8"]
+        flags_vector_variable = DatasetUtil.create_flags_variable([2, 3], meanings, dim_names=["dim1", "dim2"],
+                                                                  attributes={"standard_name": "std"})
+
+        ds["flags"] = flags_vector_variable
+        ds["flags"] = DatasetUtil.set_flag(ds["flags"], "flag4")
+        ds["flags"][0, 1] = DatasetUtil.set_flag(ds["flags"][0, 1], "flag5")
+        ds["flags"][1, 1] = DatasetUtil.set_flag(ds["flags"][1, 1], "flag2")
+        ds["flags"][1, 1] = DatasetUtil.set_flag(ds["flags"][1, 1], "flag7")
+
+        flags_mask = DatasetUtil.get_flags_mask_and(ds["flags"], flags=["flag2", "flag7"])
+
+        expected_flags_mask = np.array([[False, False, False], [False, True, False]], dtype=bool)
+
+        np.testing.assert_array_equal(flags_mask, expected_flags_mask)
+
     def test_get_set_flags(self):
 
         ds = Dataset()