Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 147 additions & 1 deletion hypernets_processor/data_io/dataset_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from hypernets_processor.version import __version__
import string
from xarray import Variable, DataArray
from xarray import Variable, DataArray, Dataset
import numpy as np


Expand Down Expand Up @@ -220,6 +220,152 @@ def get_default_fill_value(dtype):
elif dtype == np.float64:
return np.float64(9.969209968386869E36)

@staticmethod
def _get_flag_encoding(da):
    """
    Returns flag encoding for flag type data array

    The encoding is read from the ``flag_meanings`` attribute (whitespace separated names) and the
    ``flag_masks`` attribute (comma separated integers) of the data array.

    :type da: xarray.DataArray
    :param da: data array

    :return: flag meanings and flag masks, as a ``(flag_meanings, flag_masks)`` pair of lists
    :rtype: tuple

    :raises KeyError: if either of the two attributes is missing from the data array
    """

    try:
        flag_meanings = da.attrs["flag_meanings"].split()
        flag_masks = [int(fm) for fm in da.attrs["flag_masks"].split(",")]
    except KeyError as err:
        # str() keeps the intended KeyError for unnamed data arrays, whose name is None and
        # therefore cannot be concatenated with a string directly.
        raise KeyError(str(da.name) + " not a flag variable") from err

    return flag_meanings, flag_masks

@staticmethod
def unpack_flags(da):
    """
    Breaks down flag data array into dataset of boolean masks for each flag

    :type da: xarray.DataArray
    :param da: flag data array

    :return: dataset holding one boolean variable per flag meaning, True where that flag's mask bit is set
    :rtype: xarray.Dataset

    :raises KeyError: if the data array does not carry a flag encoding
    """

    flag_meanings, flag_masks = DatasetUtil._get_flag_encoding(da)

    ds = Dataset()
    for flag_meaning, flag_mask in zip(flag_meanings, flag_masks):
        # The bitwise test already produces the complete variable, so no placeholder variable is
        # created beforehand (it would be overwritten immediately).
        ds[flag_meaning] = (da & flag_mask).astype(bool)

    return ds

@staticmethod
def set_flag(da, flag_name, error_if_set=False):
    """
    Sets named flag for elements in data array

    The flag is set for every element. The result is returned as a new data array, ``da`` itself is
    not assigned to.

    :type da: xarray.DataArray
    :param da: flag data array

    :type flag_name: str
    :param flag_name: name of flag to set

    :type error_if_set: bool
    :param error_if_set: raises error if chosen flag is already set for any element

    :return: data array with the named flag set
    :rtype: xarray.DataArray

    :raises KeyError: if the data array is not a flag variable, or ``flag_name`` is not one of its flags
    :raises ValueError: if ``error_if_set`` is True and the flag is already set for any element
    """

    # Find flag mask
    flag_meanings, flag_masks = DatasetUtil._get_flag_encoding(da)
    try:
        flag_mask = flag_masks[flag_meanings.index(flag_name)]
    except (ValueError, IndexError):
        # Unknown flag names are reported as KeyError, as a lookup in the unpacked dataset would do
        raise KeyError(flag_name) from None

    # Only the requested bit is tested, and only when the caller asked for the check
    if error_if_set and np.any((da & flag_mask) != 0):
        raise ValueError("Flag " + flag_name + " already set for variable " + str(da.name))

    return da | flag_mask

@staticmethod
def unset_flag(da, flag_name, error_if_unset=False):
    """
    Unsets named flag for elements in data array

    The flag is cleared for every element. The result is returned as a new data array, ``da`` itself
    is not assigned to.

    :type da: xarray.DataArray
    :param da: flag data array

    :type flag_name: str
    :param flag_name: name of flag to unset

    :type error_if_unset: bool
    :param error_if_unset: raises error if chosen flag is already unset for any element

    :return: data array with the named flag unset
    :rtype: xarray.DataArray

    :raises KeyError: if the data array is not a flag variable, or ``flag_name`` is not one of its flags
    :raises ValueError: if ``error_if_unset`` is True and the flag is already unset for any element
    """

    # Find flag mask
    flag_meanings, flag_masks = DatasetUtil._get_flag_encoding(da)
    try:
        flag_mask = flag_masks[flag_meanings.index(flag_name)]
    except (ValueError, IndexError):
        # Unknown flag names are reported as KeyError, as a lookup in the unpacked dataset would do
        raise KeyError(flag_name) from None

    # Only the requested bit is tested, and only when the caller asked for the check
    if error_if_unset and np.any((da & flag_mask) == 0):
        raise ValueError("Flag " + flag_name + " already unset for variable " + str(da.name))

    # Inverting a Python int gives a negative number (~8 == -9). Combined with an unsigned integer
    # array that either widens the result dtype or raises OverflowError, depending on the NumPy
    # version. Inverting a scalar of the array's own dtype keeps the result in that dtype.
    if np.issubdtype(da.dtype, np.integer):
        flag_mask = da.dtype.type(flag_mask)

    return da & ~flag_mask

@staticmethod
def get_set_flags(da):
    """
    Lists the names of the flags that are set for a single element data array

    :type da: xarray.DataArray
    :param da: single element (zero-dimensional) flag data array

    :return: meanings of all flags whose mask bit is set
    :rtype: list

    :raises ValueError: if the data array has any dimensions
    """

    if da.shape != ():
        raise ValueError("Must pass single element data array")

    meanings, masks = DatasetUtil._get_flag_encoding(da)

    # A flag counts as set when its mask shares at least one bit with the element's value
    return [meaning for meaning, mask in zip(meanings, masks) if (da & mask)]

@staticmethod
def check_flag_set(da, flag_name):
    """
    Returns whether the named flag is set for a single element data array

    :type da: xarray.DataArray
    :param da: single element (zero-dimensional) flag data array

    :type flag_name: str
    :param flag_name: name of flag to check

    :return: True if the flag is set, False otherwise
    :rtype: bool

    :raises ValueError: if the data array has any dimensions
    """

    if da.shape != ():
        raise ValueError("Must pass single element data array")

    return flag_name in DatasetUtil.get_set_flags(da)


# Running this module directly performs no action.
if __name__ == "__main__":
    pass
155 changes: 154 additions & 1 deletion hypernets_processor/data_io/tests/test_dataset_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import unittest
import numpy as np
from xarray import DataArray, Variable
from xarray import DataArray, Variable, Dataset
from hypernets_processor.data_io.dataset_util import DatasetUtil
from hypernets_processor.version import __version__

Expand Down Expand Up @@ -241,6 +241,159 @@ def test_get_default_fill_value(self):
self.assertEqual(np.float32(9.96921E36), DatasetUtil.get_default_fill_value(np.float32))
self.assertEqual(9.969209968386869E36, DatasetUtil.get_default_fill_value(np.float64))

def test__get_flag_encoding(self):
    """Meanings and masks read back from a flags variable match the ones it was created with."""

    expected_meanings = ["flag%d" % number for number in range(1, 9)]
    expected_masks = [2 ** bit for bit in range(8)]

    dataset = Dataset()
    dataset["flags"] = DatasetUtil.create_flags_variable([5], expected_meanings, dim_names=["dim1"],
                                                         attributes={"standard_name": "std"})

    found_meanings, found_masks = DatasetUtil._get_flag_encoding(dataset["flags"])

    self.assertCountEqual(expected_meanings, found_meanings)
    self.assertCountEqual(expected_masks, found_masks)

def test__get_flag_encoding_not_flag_var(self):
    """Asking for the flag encoding of an ordinary array variable raises KeyError."""

    dataset = Dataset()
    dataset["array_variable"] = DatasetUtil.create_variable([7, 8, 3], np.int8,
                                                            attributes={"standard_name": "std"})
    plain_variable = dataset["array_variable"]

    with self.assertRaises(KeyError):
        DatasetUtil._get_flag_encoding(plain_variable)

def test_unpack_flags(self):
    """Unpacking yields an all-False mask per flag, except for the one bit that was set."""

    meanings = ["flag1", "flag2", "flag3", "flag4", "flag5", "flag6", "flag7", "flag8"]

    ds = Dataset()
    ds["flags"] = DatasetUtil.create_flags_variable([2, 3], meanings, dim_names=["dim1", "dim2"],
                                                    attributes={"standard_name": "std"})
    # Mask value 8 is expected to show up as "flag4" at element [0, 0] only
    ds["flags"][0, 0] = ds["flags"][0, 0] | 8

    flags = DatasetUtil.unpack_flags(ds["flags"])

    for meaning in meanings:
        expected = np.zeros((2, 3), bool)
        if meaning == "flag4":
            expected[0, 0] = True

        # Name the flag in the failure message so a failing mask can be identified directly
        self.assertTrue((flags[meaning].data == expected).all(), msg="unexpected mask for " + meaning)

def test_get_set_flags(self):
    """Both flags set on one element are reported for that element."""

    flag_names = ["flag1", "flag2", "flag3", "flag4", "flag5", "flag6", "flag7", "flag8"]

    dataset = Dataset()
    dataset["flags"] = DatasetUtil.create_flags_variable([5], flag_names, dim_names=["dim1"],
                                                         attributes={"standard_name": "std"})
    # Set two bits on element 3, one after the other
    for mask in (8, 32):
        dataset["flags"][3] = dataset["flags"][3] | mask

    reported = DatasetUtil.get_set_flags(dataset["flags"][3])

    self.assertCountEqual(reported, ["flag4", "flag6"])

def test_get_set_flags_2d(self):
    """Passing a data array that is not a single element raises ValueError."""

    flag_names = ["flag1", "flag2", "flag3", "flag4", "flag5", "flag6", "flag7", "flag8"]

    dataset = Dataset()
    dataset["flags"] = DatasetUtil.create_flags_variable([5], flag_names, dim_names=["dim1"],
                                                         attributes={"standard_name": "std"})
    # The whole five element variable is passed, not one element of it
    whole_variable = dataset["flags"]

    with self.assertRaises(ValueError):
        DatasetUtil.get_set_flags(whole_variable)

def test_check_flag_set_true(self):
    """A flag that was set on the element is reported as set."""

    flag_names = ["flag1", "flag2", "flag3", "flag4", "flag5", "flag6", "flag7", "flag8"]

    dataset = Dataset()
    dataset["flags"] = DatasetUtil.create_flags_variable([5], flag_names, dim_names=["dim1"],
                                                         attributes={"standard_name": "std"})
    # Set two bits on element 3, one after the other
    for mask in (8, 32):
        dataset["flags"][3] = dataset["flags"][3] | mask

    is_set = DatasetUtil.check_flag_set(dataset["flags"][3], "flag6")

    self.assertTrue(is_set)

def test_check_flag_set_false(self):
    """A flag that was not set on the element is reported as not set."""

    flag_names = ["flag1", "flag2", "flag3", "flag4", "flag5", "flag6", "flag7", "flag8"]

    dataset = Dataset()
    dataset["flags"] = DatasetUtil.create_flags_variable([5], flag_names, dim_names=["dim1"],
                                                         attributes={"standard_name": "std"})
    # Only mask 8 is set on element 3, the flag queried below is left untouched
    dataset["flags"][3] = dataset["flags"][3] | 8

    is_set = DatasetUtil.check_flag_set(dataset["flags"][3], "flag6")

    self.assertFalse(is_set)

def test_check_flag_set_2d(self):
    """Passing a data array that is not a single element raises ValueError."""

    flag_names = ["flag1", "flag2", "flag3", "flag4", "flag5", "flag6", "flag7", "flag8"]

    dataset = Dataset()
    dataset["flags"] = DatasetUtil.create_flags_variable([5], flag_names, dim_names=["dim1"],
                                                         attributes={"standard_name": "std"})
    # The whole five element variable is passed, not one element of it
    whole_variable = dataset["flags"]

    with self.assertRaises(ValueError):
        DatasetUtil.check_flag_set(whole_variable, "flag6")

def test_set_flag(self):
    """Setting a flag on a fresh flags variable gives that flag's mask value in every element."""

    flag_names = ["flag1", "flag2", "flag3", "flag4", "flag5", "flag6", "flag7", "flag8"]

    dataset = Dataset()
    dataset["flags"] = DatasetUtil.create_flags_variable([5, 4], flag_names, dim_names=["dim1", "dim2"],
                                                         attributes={"standard_name": "std"})

    dataset["flags"] = DatasetUtil.set_flag(dataset["flags"], "flag4")

    expected = np.full(dataset["flags"].shape, 0 | 8)

    self.assertTrue((dataset["flags"].data == expected).all())

def test_set_flag_error_if_set(self):
    """Setting a flag that is already set somewhere raises ValueError when error_if_set is True."""

    flag_names = ["flag1", "flag2", "flag3", "flag4", "flag5", "flag6", "flag7", "flag8"]

    dataset = Dataset()
    dataset["flags"] = DatasetUtil.create_flags_variable([5], flag_names, dim_names=["dim1"],
                                                         attributes={"standard_name": "std"})
    # Pre-set mask 8 on a single element
    dataset["flags"][3] = dataset["flags"][3] | 8
    flags_variable = dataset["flags"]

    with self.assertRaises(ValueError):
        DatasetUtil.set_flag(flags_variable, "flag4", error_if_set=True)

def test_unset_flag(self):
    """Unsetting the only flag that is set leaves every element at zero."""

    flag_names = ["flag1", "flag2", "flag3", "flag4", "flag5", "flag6", "flag7", "flag8"]

    dataset = Dataset()
    dataset["flags"] = DatasetUtil.create_flags_variable([5], flag_names, dim_names=["dim1"],
                                                         attributes={"standard_name": "std"})
    # Set mask 8 on every element before unsetting it again
    dataset["flags"][:] = dataset["flags"][:] | 8

    dataset["flags"] = DatasetUtil.unset_flag(dataset["flags"], "flag4")

    expected = np.zeros(dataset["flags"].shape)

    self.assertTrue((dataset["flags"].data == expected).all())

def test_set_flag_error_if_unset(self):
    """Unsetting a flag that is not set raises ValueError when error_if_unset is True.

    Despite its name, this test exercises ``DatasetUtil.unset_flag``.
    """

    flag_names = ["flag1", "flag2", "flag3", "flag4", "flag5", "flag6", "flag7", "flag8"]

    dataset = Dataset()
    dataset["flags"] = DatasetUtil.create_flags_variable([5], flag_names, dim_names=["dim1"],
                                                         attributes={"standard_name": "std"})
    flags_variable = dataset["flags"]

    with self.assertRaises(ValueError):
        DatasetUtil.unset_flag(flags_variable, "flag4", error_if_unset=True)


# Run the test suite when this file is executed directly.
if __name__ == '__main__':
    unittest.main()