From eb8f72b6388aa99cd9187974dc66aa4c35d274c1 Mon Sep 17 00:00:00 2001 From: shunt16 Date: Fri, 16 Oct 2020 15:34:47 +0100 Subject: [PATCH] add flag handling methods --- hypernets_processor/data_io/dataset_util.py | 148 ++++++++++++++++- .../data_io/tests/test_dataset_util.py | 155 +++++++++++++++++- 2 files changed, 301 insertions(+), 2 deletions(-) diff --git a/hypernets_processor/data_io/dataset_util.py b/hypernets_processor/data_io/dataset_util.py index fbcdef70..4feb2836 100644 --- a/hypernets_processor/data_io/dataset_util.py +++ b/hypernets_processor/data_io/dataset_util.py @@ -4,7 +4,7 @@ from hypernets_processor.version import __version__ import string -from xarray import Variable, DataArray +from xarray import Variable, DataArray, Dataset import numpy as np @@ -220,6 +220,152 @@ def get_default_fill_value(dtype): elif dtype == np.float64: return np.float64(9.969209968386869E36) + @staticmethod + def _get_flag_encoding(da): + """ + Returns flag encoding for flag type data array + + :type da: xarray.DataArray + :param da: data array + + :return: flag meanings + :rtype: list + + :return: flag masks + :rtype: list + """ + + try: + flag_meanings = da.attrs["flag_meanings"].split() + flag_masks = [int(fm) for fm in da.attrs["flag_masks"].split(",")] + except KeyError: + raise KeyError(da.name + " not a flag variable") + + return flag_meanings, flag_masks + + @staticmethod + def unpack_flags(da): + """ + Breaks down flag data array into dataset of boolean masks for each flag + + :type da: xarray.DataArray + :param da: dataset + + :return: flag masks + :rtype: xarray.Dataset + """ + + flag_meanings, flag_masks = DatasetUtil._get_flag_encoding(da) + + ds = Dataset() + for flag_meaning, flag_mask in zip(flag_meanings, flag_masks): + ds[flag_meaning] = DatasetUtil.create_variable(list(da.shape), bool, dim_names=list(da.dims)) + ds[flag_meaning] = (da & flag_mask).astype(bool) + + return ds + + @staticmethod + def set_flag(da, flag_name, error_if_set=False): + """ + Sets named flag for elements in data array + + :type da: xarray.DataArray + :param da: dataset + + :type flag_name: str + :param flag_name: name of flag to set + + :type error_if_set: bool + :param error_if_set: raises error if chosen flag is already set for any element + """ + + set_flags = DatasetUtil.unpack_flags(da)[flag_name] + + if np.any(set_flags == True) and error_if_set: + raise ValueError("Flag " + flag_name + " already set for variable " + da.name) + + # Find flag mask + flag_meanings, flag_masks = DatasetUtil._get_flag_encoding(da) + flag_bit = flag_meanings.index(flag_name) + flag_mask = flag_masks[flag_bit] + + return da | flag_mask + + @staticmethod + def unset_flag(da, flag_name, error_if_unset=False): + """ + Unsets named flag for specified index of dataset variable + + :type da: xarray.DataArray + :param da: data array + + :type flag_name: str + :param flag_name: name of flag to unset + + :type error_if_unset: bool + :param error_if_unset: raises error if chosen flag is already set at specified index + """ + + set_flags = DatasetUtil.unpack_flags(da)[flag_name] + + if np.any(set_flags == False) and error_if_unset: + raise ValueError("Flag " + flag_name + " already set for variable " + da.name) + + # Find flag mask + flag_meanings, flag_masks = DatasetUtil._get_flag_encoding(da) + flag_bit = flag_meanings.index(flag_name) + flag_mask = flag_masks[flag_bit] + + return da & ~flag_mask + + @staticmethod + def get_set_flags(da): + """ + Return list of set flags for single element data array + + :type da: xarray.DataArray + :param da: single element data array + + :return: set flags + :rtype: list + """ + + if da.shape != (): + raise ValueError("Must pass single element data array") + + flag_meanings, flag_masks = DatasetUtil._get_flag_encoding(da) + + set_flags = [] + for flag_meaning, flag_mask in zip(flag_meanings, flag_masks): + if (da & flag_mask): + set_flags.append(flag_meaning) + + return set_flags + + @staticmethod + def check_flag_set(da, flag_name): + """ + Returns if flag for single element data array + + :type da: xarray.DataArray + :param da: single element data array + + :type flag_name: str + :param flag_name: name of flag to set + + :return: set flags + :rtype: list + """ + + if da.shape != (): + raise ValueError("Must pass single element data array") + + set_flags = DatasetUtil.get_set_flags(da) + + if flag_name in set_flags: + return True + return False + if __name__ == "__main__": pass diff --git a/hypernets_processor/data_io/tests/test_dataset_util.py b/hypernets_processor/data_io/tests/test_dataset_util.py index 65ca025a..23b43a38 100644 --- a/hypernets_processor/data_io/tests/test_dataset_util.py +++ b/hypernets_processor/data_io/tests/test_dataset_util.py @@ -4,7 +4,7 @@ import unittest import numpy as np -from xarray import DataArray, Variable +from xarray import DataArray, Variable, Dataset from hypernets_processor.data_io.dataset_util import DatasetUtil from hypernets_processor.version import __version__ @@ -241,6 +241,159 @@ def test_get_default_fill_value(self): self.assertEqual(np.float32(9.96921E36), DatasetUtil.get_default_fill_value(np.float32)) self.assertEqual(9.969209968386869E36, DatasetUtil.get_default_fill_value(np.float64)) + def test__get_flag_encoding(self): + + ds = Dataset() + meanings = ["flag1", "flag2", "flag3", "flag4", "flag5", "flag6", "flag7", "flag8"] + masks = [1, 2, 4, 8, 16, 32, 64, 128] + flags_vector_variable = DatasetUtil.create_flags_variable([5], meanings, dim_names=["dim1"], + attributes={"standard_name": "std"}) + + ds["flags"] = flags_vector_variable + + meanings_out, masks_out = DatasetUtil._get_flag_encoding(ds["flags"]) + + self.assertCountEqual(meanings, meanings_out) + self.assertCountEqual(masks, masks_out) + + def test__get_flag_encoding_not_flag_var(self): + ds = Dataset() + ds["array_variable"] = DatasetUtil.create_variable([7, 8, 3], np.int8, attributes={"standard_name": "std"}) + + self.assertRaises(KeyError, DatasetUtil._get_flag_encoding, ds["array_variable"]) + + def test_unpack_flags(self): + + ds = Dataset() + meanings = ["flag1", "flag2", "flag3", "flag4", "flag5", "flag6", "flag7", "flag8"] + masks = [1, 2, 4, 8, 16, 32, 64, 128] + flags_vector_variable = DatasetUtil.create_flags_variable([2,3], meanings, dim_names=["dim1", "dim2"], + attributes={"standard_name": "std"}) + + ds["flags"] = flags_vector_variable + ds["flags"][0, 0] = ds["flags"][0, 0] | 8 + + empty = np.zeros((2, 3), bool) + flag4 = np.zeros((2, 3), bool) + flag4[0,0] = True + + flags = DatasetUtil.unpack_flags(ds["flags"]) + + self.assertTrue((flags["flag1"].data == empty).all()) + self.assertTrue((flags["flag2"].data == empty).all()) + self.assertTrue((flags["flag3"].data == empty).all()) + self.assertTrue((flags["flag4"].data == flag4).all()) + self.assertTrue((flags["flag5"].data == empty).all()) + self.assertTrue((flags["flag6"].data == empty).all()) + self.assertTrue((flags["flag7"].data == empty).all()) + self.assertTrue((flags["flag8"].data == empty).all()) + + def test_get_set_flags(self): + + ds = Dataset() + meanings = ["flag1", "flag2", "flag3", "flag4", "flag5", "flag6", "flag7", "flag8"] + flags_vector_variable = DatasetUtil.create_flags_variable([5], meanings, dim_names=["dim1"], + attributes={"standard_name": "std"}) + ds["flags"] = flags_vector_variable + ds["flags"][3] = ds["flags"][3] | 8 + ds["flags"][3] = ds["flags"][3] | 32 + + set_flags = DatasetUtil.get_set_flags(ds["flags"][3]) + + self.assertCountEqual(set_flags, ["flag4", "flag6"]) + + def test_get_set_flags_2d(self): + + ds = Dataset() + meanings = ["flag1", "flag2", "flag3", "flag4", "flag5", "flag6", "flag7", "flag8"] + flags_vector_variable = DatasetUtil.create_flags_variable([5], meanings, dim_names=["dim1"], + attributes={"standard_name": "std"}) + ds["flags"] = flags_vector_variable + + self.assertRaises(ValueError, DatasetUtil.get_set_flags, ds["flags"]) + + def test_check_flag_set_true(self): + ds = Dataset() + meanings = ["flag1", "flag2", "flag3", "flag4", "flag5", "flag6", "flag7", "flag8"] + flags_vector_variable = DatasetUtil.create_flags_variable([5], meanings, dim_names=["dim1"], + attributes={"standard_name": "std"}) + ds["flags"] = flags_vector_variable + ds["flags"][3] = ds["flags"][3] | 8 + ds["flags"][3] = ds["flags"][3] | 32 + + flag_set = DatasetUtil.check_flag_set(ds["flags"][3], "flag6") + + self.assertTrue(flag_set) + + def test_check_flag_set_false(self): + ds = Dataset() + meanings = ["flag1", "flag2", "flag3", "flag4", "flag5", "flag6", "flag7", "flag8"] + flags_vector_variable = DatasetUtil.create_flags_variable([5], meanings, dim_names=["dim1"], + attributes={"standard_name": "std"}) + ds["flags"] = flags_vector_variable + ds["flags"][3] = ds["flags"][3] | 8 + + flag_set = DatasetUtil.check_flag_set(ds["flags"][3], "flag6") + + self.assertFalse(flag_set) + + def test_check_flag_set_2d(self): + ds = Dataset() + meanings = ["flag1", "flag2", "flag3", "flag4", "flag5", "flag6", "flag7", "flag8"] + flags_vector_variable = DatasetUtil.create_flags_variable([5], meanings, dim_names=["dim1"], + attributes={"standard_name": "std"}) + ds["flags"] = flags_vector_variable + + self.assertRaises(ValueError, DatasetUtil.check_flag_set, ds["flags"], "flag6") + + def test_set_flag(self): + + ds = Dataset() + meanings = ["flag1", "flag2", "flag3", "flag4", "flag5", "flag6", "flag7", "flag8"] + flags_vector_variable = DatasetUtil.create_flags_variable([5, 4], meanings, dim_names=["dim1", "dim2"], + attributes={"standard_name": "std"}) + ds["flags"] = flags_vector_variable + + ds["flags"] = DatasetUtil.set_flag(ds["flags"], "flag4") + + flags = np.full(ds["flags"].shape, 0|8) + + self.assertTrue((ds["flags"].data == flags).all()) + + def test_set_flag_error_if_set(self): + ds = Dataset() + meanings = ["flag1", "flag2", "flag3", "flag4", "flag5", "flag6", "flag7", "flag8"] + flags_vector_variable = DatasetUtil.create_flags_variable([5], meanings, dim_names=["dim1"], + attributes={"standard_name": "std"}) + ds["flags"] = flags_vector_variable + ds["flags"][3] = ds["flags"][3] | 8 + + self.assertRaises(ValueError, DatasetUtil.set_flag, ds["flags"], "flag4", error_if_set=True) + + def test_unset_flag(self): + + ds = Dataset() + meanings = ["flag1", "flag2", "flag3", "flag4", "flag5", "flag6", "flag7", "flag8"] + flags_vector_variable = DatasetUtil.create_flags_variable([5], meanings, dim_names=["dim1"], + attributes={"standard_name": "std"}) + ds["flags"] = flags_vector_variable + ds["flags"][:] = ds["flags"][:] | 8 + + ds["flags"] = DatasetUtil.unset_flag(ds["flags"], "flag4") + + flags = np.zeros(ds["flags"].shape) + + self.assertTrue((ds["flags"].data == flags).all()) + + def test_set_flag_error_if_unset(self): + ds = Dataset() + meanings = ["flag1", "flag2", "flag3", "flag4", "flag5", "flag6", "flag7", "flag8"] + flags_vector_variable = DatasetUtil.create_flags_variable([5], meanings, dim_names=["dim1"], + attributes={"standard_name": "std"}) + ds["flags"] = flags_vector_variable + + self.assertRaises(ValueError, DatasetUtil.unset_flag, ds["flags"], "flag4", error_if_unset=True) + if __name__ == '__main__': unittest.main()