Skip to content

Commit

Permalink
Merge pull request #15 from FrescolinoGroup/add_numpy_special_type
Browse files Browse the repository at this point in the history
Add special type tag for numpy arrays
  • Loading branch information
mskoenz committed Mar 11, 2020
2 parents b92ada1 + dfc2e81 commit 34cb73b
Show file tree
Hide file tree
Showing 9 changed files with 50 additions and 5 deletions.
8 changes: 6 additions & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@ matrix:
env: TEST_TYPE="compliance"
install:
- pip install .
- pip install codecov
- pip install pytest-cov
script:
- if [ "$TEST_TYPE" == "compliance" ] ; then pip install .[dev]; pre-commit run --all-files ; fi
- if [ "$TEST_TYPE" == "test" ] ; then cd tests; py.test ; fi
- if [ "$TEST_TYPE" == "test_sympy" ] ; then pip install sympy; cd tests; py.test ; fi
- if [ "$TEST_TYPE" == "test" ] ; then pytest --cov=fsc.hdf5_io --cov-config=.coveragerc ; fi
- if [ "$TEST_TYPE" == "test_sympy" ] ; then pip install sympy; pytest --cov=fsc.hdf5_io --cov-config=.coveragerc ; fi
after_success:
- codecov
27 changes: 26 additions & 1 deletion fsc/hdf5_io/_special_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from functools import singledispatch
from collections.abc import Iterable, Mapping, Hashable

import numpy as np

from ._base_classes import Deserializable

from ._save_load import from_hdf5, to_hdf5, to_hdf5_singledispatch
Expand All @@ -25,6 +27,7 @@ class _SpecialTypeTags(SimpleNamespace):
NUMBER = 'builtins.number'
STR = 'builtins.str'
NONE = 'builtins.none'
NUMPY_ARRAY = 'numpy.ndarray'
SYMPY = 'sympy.object' # defined in _sympy_load.py and _sympy_save.py


Expand Down Expand Up @@ -74,6 +77,16 @@ def from_hdf5(cls, hdf5_handle):
return hdf5_handle['value'][()]


@subscribe_hdf5(_SpecialTypeTags.NUMPY_ARRAY)
class _NumpyArraryDeserializer(Deserializable):
"""Helper class to de-serialize numpy arrays."""
@classmethod
def from_hdf5(cls, hdf5_handle):
if 'value' in hdf5_handle:
return hdf5_handle['value'][()]
return np.array(_deserialize_iterable(hdf5_handle))


@subscribe_hdf5(_SpecialTypeTags.NONE)
class _NoneDeserializer(Deserializable):
"""Helper class to de-serialize ``None``."""
Expand Down Expand Up @@ -128,9 +141,10 @@ def _(obj, hdf5_handle):


@to_hdf5_singledispatch.register(str)
@to_hdf5_singledispatch.register(np.str_)
@add_type_tag(_SpecialTypeTags.STR)
def _(obj, hdf5_handle):
_value_serializer(obj, hdf5_handle)
_value_serializer(str(obj), hdf5_handle)


@to_hdf5_singledispatch.register(type(None))
Expand All @@ -139,6 +153,17 @@ def _(obj, hdf5_handle):
pass


@to_hdf5_singledispatch.register(np.ndarray)
@add_type_tag(_SpecialTypeTags.NUMPY_ARRAY)
def _(obj, hdf5_handle): # pylint: disable=missing-docstring
try:
_value_serializer(obj, hdf5_handle)
except TypeError:
# if the numpy dtype does not have a native HDF5 equivalent,
# treat it as an iterable instead
_serialize_iterable(obj, hdf5_handle)


def _value_serializer(obj, hdf5_handle):
hdf5_handle['value'] = obj

Expand Down
Binary file modified tests/samples/test_save_load_test_number[permanent].hdf5
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
20 changes: 18 additions & 2 deletions tests/test_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import h5py
import pytest
import numpy as np
from numpy.testing import assert_equal

from fsc.hdf5_io import save, load

Expand All @@ -22,7 +23,7 @@ def inner_tempfile(x):
with tempfile.NamedTemporaryFile() as named_file:
save(x, named_file.name)
y = load(named_file.name)
assert x == y
assert_equal(x, y)

def inner_permanent(x):
"""
Expand All @@ -31,7 +32,7 @@ def inner_permanent(x):
file_name = sample((test_name + '.hdf5').replace('/', '_'))
try:
y = load(file_name)
assert x == y
assert_equal(x, y)
except IOError:
save(x, file_name)
raise ValueError("Sample file did not exist")
Expand Down Expand Up @@ -159,3 +160,18 @@ def test_legacyclass_notag(sample):
x = LegacyClass.from_hdf5_file(sample('no_tag.hdf5'), y=1.2)
assert x.x == 10
assert x.y == 1.2


@pytest.mark.parametrize(
'obj', [
np.array([[1, 2, 3], [4, 5, 6]]),
np.array([1, 2., None, 'foo'], dtype=object),
np.array(['foo', 'bar', 'baz']), (np.array(['foo', 'bar', 'baz']), ),
np.array([[1, 2], [4, 5]], dtype=[('age', 'i4'), ('weight', 'f4')])
]
)
def test_numpy_array(check_save_load, obj): # pylint: disable=redefined-outer-name
"""
Check save / load for numpy arrays
"""
check_save_load(obj)

0 comments on commit 34cb73b

Please sign in to comment.