Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions src/mdio/builder/schemas/compressors.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,6 @@ class ZFP(CamelCaseStrictModel):
description="Fixed precision in terms of number of uncompressed bits per value.",
)

write_header: bool = Field(
default=True,
description="Encode array shape, scalar type, and compression parameters.",
)

@model_validator(mode="after")
def check_requirements(self) -> ZFP:
"""Check if ZFP parameters make sense."""
Expand Down
53 changes: 30 additions & 23 deletions src/mdio/builder/xarray_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@
try:
# zfpy is an optional dependency for ZFP compression
# It is not installed by default, so we check for its presence and import it only if available.
from zfpy import ZFPY as zfpy_ZFPY # noqa: N811
from numcodecs import ZFPY as zfpy_ZFPY # noqa: N811
from zarr.codecs.numcodecs import ZFPY as zarr_ZFPY # noqa: N811
except ImportError:
zfpy_ZFPY = None # noqa: N816
zarr_ZFPY = None # noqa: N816

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was looking through test_dataset_serializer.py and I only found imports for zfpy.ZFPY. Are there tests for zarr.codecs.numcodecs.ZFPY?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#747 will help with this. Lossy compression has not been a high priority.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh nice. You're on it already.


from mdio.builder.schemas.compressors import ZFP as mdio_ZFP # noqa: N811
from mdio.builder.schemas.compressors import Blosc as mdio_Blosc
Expand Down Expand Up @@ -121,33 +123,34 @@ def _get_zarr_chunks(var: Variable, all_named_dims: dict[str, NamedDimension]) -
return _get_zarr_shape(var, all_named_dims=all_named_dims)


def _compressor_to_encoding(
    compressor: mdio_Blosc | mdio_ZFP | None,
) -> dict[str, BloscCodec | Blosc | zfpy_ZFPY | zarr_ZFPY | None] | None:
    """Translate an MDIO compressor model into Zarr encoding entries.

    Args:
        compressor: MDIO Blosc or ZFP compressor model, or None when no compression is requested.

    Returns:
        None when no compressor is given. Otherwise a dictionary of encoding entries intended to
        be merged into a variable's encoding. Blosc always maps to the ``"compressors"`` key. ZFP
        maps to ``"compressors"`` when the default Zarr format is v2; otherwise it maps to
        ``"serializer"`` and ``"compressors"`` is explicitly set to None.

    Raises:
        TypeError: If the compressor is not a supported MDIO compressor model.
        ImportError: If ZFP is requested but the optional ZFP codec could not be imported.
    """
    if compressor is None:
        return None

    if not isinstance(compressor, (mdio_Blosc, mdio_ZFP)):
        msg = f"Unsupported compressor model: {type(compressor)}"
        raise TypeError(msg)

    is_v2 = zarr.config.get("default_zarr_format") == ZarrFormat.V2
    kwargs = compressor.model_dump(exclude={"name"}, mode="json")

    if isinstance(compressor, mdio_Blosc):
        # The v2 Blosc codec is handed -1 in place of an unset (None) shuffle.
        if is_v2 and kwargs["shuffle"] is None:
            kwargs["shuffle"] = -1
        codec_cls = Blosc if is_v2 else BloscCodec
        return {"compressors": codec_cls(**kwargs)}

    # must be ZFP beyond here
    # Guard on the codec class that will actually be instantiated for this Zarr format.
    zfp_cls = zfpy_ZFPY if is_v2 else zarr_ZFPY
    if zfp_cls is None:
        msg = "zfpy and numcodecs are required to use ZFP compression"
        raise ImportError(msg)

    # The codec is given the integer mode code rather than the dumped enum value.
    kwargs["mode"] = compressor.mode.int_code
    if is_v2:
        return {"compressors": zfp_cls(**kwargs)}

    # NOTE(review): "compressors" is set to None alongside the serializer, presumably so that no
    # default compressor is applied on top of the ZFP output - confirm against Zarr behavior.
    return {"serializer": zfp_cls(**kwargs), "compressors": None}


def _get_fill_value(data_type: ScalarType | StructuredType | str) -> any:
Expand Down Expand Up @@ -222,10 +225,14 @@ def to_xarray_dataset(mdio_ds: Dataset) -> xr_Dataset: # noqa: PLR0912

encoding = {
"chunks": original_chunks,
"compressors": _convert_compressor(v.compressor),
fill_value_key: fill_value,
}

compressor_encodings = _compressor_to_encoding(v.compressor)

if compressor_encodings is not None:
encoding.update(compressor_encodings)

if zarr_format == ZarrFormat.V2:
encoding["chunk_key_encoding"] = {"name": "v2", "configuration": {"separator": "/"}}

Expand Down
42 changes: 24 additions & 18 deletions tests/unit/v1/test_dataset_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from mdio.builder.schemas.v1.variable import Coordinate
from mdio.builder.schemas.v1.variable import Variable
from mdio.builder.schemas.v1.variable import VariableMetadata
from mdio.builder.xarray_builder import _convert_compressor
from mdio.builder.xarray_builder import _compressor_to_encoding
from mdio.builder.xarray_builder import _get_all_named_dimensions
from mdio.builder.xarray_builder import _get_coord_names
from mdio.builder.xarray_builder import _get_dimension_names
Expand Down Expand Up @@ -226,43 +226,49 @@ def test_get_fill_value() -> None:
assert result_none_input is None


def test_compressor_to_encoding() -> None:
    """Simple test for _compressor_to_encoding function covering basic scenarios."""
    # Test 1: None input - should return None
    result_none = _compressor_to_encoding(None)
    assert result_none is None

    # Test 2: mdio_Blosc compressor - should return a BloscCodec under the "compressors" key
    mdio_compressor = mdio_Blosc(cname=BloscCname.lz4, clevel=5, shuffle=BloscShuffle.bitshuffle, blocksize=1024)
    result_blosc = _compressor_to_encoding(mdio_compressor)

    assert isinstance(result_blosc, dict)
    assert "compressors" in result_blosc
    assert isinstance(result_blosc["compressors"], BloscCodec)
    assert result_blosc["compressors"].cname == BloscCname.lz4
    assert result_blosc["compressors"].clevel == 5
    assert result_blosc["compressors"].shuffle == BloscShuffle.bitshuffle
    assert result_blosc["compressors"].blocksize == 1024

    # Test 3: mdio_ZFP compressor - should return a "serializer" entry if the ZFP codec is available
    zfp_compressor = MDIO_ZFP(mode=mdio_ZFPMode.FIXED_RATE, tolerance=0.01, rate=8.0, precision=16)

    # TODO(BrianMichell): Update to also test zfp compression.
    # https://github.com/TGSAI/mdio-python/issues/747
    if HAS_ZFPY:  # pragma: no cover
        result_zfp = _compressor_to_encoding(zfp_compressor)
        assert isinstance(result_zfp, dict)
        # The serializer path returns an explicit "compressors": None entry, so the key is
        # present and must be checked for None rather than for absence.
        assert result_zfp["compressors"] is None
        # NOTE(review): the serializer is built from the zarr codec wrapper, while ZFPY here is
        # whatever this test module imports - confirm the two classes match.
        assert isinstance(result_zfp["serializer"], ZFPY)
        # The codec is constructed with the integer mode code, not the enum's string value.
        assert result_zfp["serializer"].mode == zfp_compressor.mode.int_code
        assert result_zfp["serializer"].tolerance == 0.01
        assert result_zfp["serializer"].rate == 8.0
        assert result_zfp["serializer"].precision == 16
    else:
        # Test 4: mdio_ZFP without zfpy installed - should raise ImportError
        with pytest.raises(ImportError) as exc_info:
            _compressor_to_encoding(zfp_compressor)
        error_message = str(exc_info.value)
        assert "zfpy and numcodecs are required to use ZFP compression" in error_message

    # Test 5: Unsupported compressor type - should raise TypeError
    unsupported_compressor = "invalid_compressor"
    with pytest.raises(TypeError) as exc_info:
        _compressor_to_encoding(unsupported_compressor)
    error_message = str(exc_info.value)
    assert "Unsupported compressor model" in error_message
    assert "<class 'str'>" in error_message
Expand Down