diff --git a/src/mdio/builder/schemas/compressors.py b/src/mdio/builder/schemas/compressors.py index 7794277a..f661712b 100644 --- a/src/mdio/builder/schemas/compressors.py +++ b/src/mdio/builder/schemas/compressors.py @@ -71,11 +71,6 @@ class ZFP(CamelCaseStrictModel): description="Fixed precision in terms of number of uncompressed bits per value.", ) - write_header: bool = Field( - default=True, - description="Encode array shape, scalar type, and compression parameters.", - ) - @model_validator(mode="after") def check_requirements(self) -> ZFP: """Check if ZFP parameters make sense.""" diff --git a/src/mdio/builder/xarray_builder.py b/src/mdio/builder/xarray_builder.py index 18ff31ef..58501cba 100644 --- a/src/mdio/builder/xarray_builder.py +++ b/src/mdio/builder/xarray_builder.py @@ -14,9 +14,11 @@ try: # zfpy is an optional dependency for ZFP compression # It is not installed by default, so we check for its presence and import it only if available. - from zfpy import ZFPY as zfpy_ZFPY # noqa: N811 + from numcodecs import ZFPY as zfpy_ZFPY # noqa: N811 + from zarr.codecs.numcodecs import ZFPY as zarr_ZFPY # noqa: N811 except ImportError: zfpy_ZFPY = None # noqa: N816 + zarr_ZFPY = None # noqa: N816 from mdio.builder.schemas.compressors import ZFP as mdio_ZFP # noqa: N811 from mdio.builder.schemas.compressors import Blosc as mdio_Blosc @@ -121,33 +123,34 @@ def _get_zarr_chunks(var: Variable, all_named_dims: dict[str, NamedDimension]) - return _get_zarr_shape(var, all_named_dims=all_named_dims) -def _convert_compressor( +def _compressor_to_encoding( compressor: mdio_Blosc | mdio_ZFP | None, -) -> BloscCodec | Blosc | zfpy_ZFPY | None: +) -> dict[str, BloscCodec | Blosc | zfpy_ZFPY | zarr_ZFPY] | None: """Convert a compressor to a numcodecs compatible format.""" if compressor is None: return None + if not isinstance(compressor, (mdio_Blosc, mdio_ZFP)): + msg = f"Unsupported compressor model: {type(compressor)}" + raise TypeError(msg) + + is_v2 = zarr.config.get("default_zarr_format") == ZarrFormat.V2 + kwargs = compressor.model_dump(exclude={"name"}, mode="json") + if isinstance(compressor, mdio_Blosc): - blosc_kwargs = compressor.model_dump(exclude={"name"}, mode="json") - if zarr.config.get("default_zarr_format") == ZarrFormat.V2: - blosc_kwargs["shuffle"] = -1 if blosc_kwargs["shuffle"] is None else blosc_kwargs["shuffle"] - return Blosc(**blosc_kwargs) - return BloscCodec(**blosc_kwargs) - - if isinstance(compressor, mdio_ZFP): - if zfpy_ZFPY is None: - msg = "zfpy and numcodecs are required to use ZFP compression" - raise ImportError(msg) - return zfpy_ZFPY( - mode=compressor.mode.value, - tolerance=compressor.tolerance, - rate=compressor.rate, - precision=compressor.precision, - ) - - msg = f"Unsupported compressor model: {type(compressor)}" - raise TypeError(msg) + if is_v2 and kwargs["shuffle"] is None: + kwargs["shuffle"] = -1 + codec_cls = Blosc if is_v2 else BloscCodec + return {"compressors": codec_cls(**kwargs)} + + # must be ZFP beyond here + if zfpy_ZFPY is None: + msg = "zfpy and numcodecs are required to use ZFP compression" + raise ImportError(msg) + kwargs["mode"] = compressor.mode.int_code + if is_v2: + return {"compressors": zfpy_ZFPY(**kwargs)} + return {"serializer": zarr_ZFPY(**kwargs), "compressors": None} def _get_fill_value(data_type: ScalarType | StructuredType | str) -> any: @@ -222,10 +225,14 @@ def to_xarray_dataset(mdio_ds: Dataset) -> xr_Dataset: # noqa: PLR0912 encoding = { "chunks": original_chunks, - "compressors": _convert_compressor(v.compressor), fill_value_key: fill_value, } + compressor_encodings = _compressor_to_encoding(v.compressor) + + if compressor_encodings is not None: + encoding.update(compressor_encodings) + if zarr_format == ZarrFormat.V2: encoding["chunk_key_encoding"] = {"name": "v2", "configuration": {"separator": "/"}} diff --git a/tests/unit/v1/test_dataset_serializer.py b/tests/unit/v1/test_dataset_serializer.py index 45c19665..81cc5e83 100644 --- a/tests/unit/v1/test_dataset_serializer.py +++ b/tests/unit/v1/test_dataset_serializer.py @@ -19,7 +19,7 @@ from mdio.builder.schemas.v1.variable import Coordinate from mdio.builder.schemas.v1.variable import Variable from mdio.builder.schemas.v1.variable import VariableMetadata -from mdio.builder.xarray_builder import _convert_compressor +from mdio.builder.xarray_builder import _compressor_to_encoding from mdio.builder.xarray_builder import _get_all_named_dimensions from mdio.builder.xarray_builder import _get_coord_names from mdio.builder.xarray_builder import _get_dimension_names @@ -226,43 +226,49 @@ def test_get_fill_value() -> None: assert result_none_input is None -def test_convert_compressor() -> None: - """Simple test for _convert_compressor function covering basic scenarios.""" +def test_compressor_to_encoding() -> None: + """Simple test for _compressor_to_encoding function covering basic scenarios.""" # Test 1: None input - should return None - result_none = _convert_compressor(None) + result_none = _compressor_to_encoding(None) assert result_none is None # Test 2: mdio_Blosc compressor - should return nc_Blosc mdio_compressor = mdio_Blosc(cname=BloscCname.lz4, clevel=5, shuffle=BloscShuffle.bitshuffle, blocksize=1024) - result_blosc = _convert_compressor(mdio_compressor) + result_blosc = _compressor_to_encoding(mdio_compressor) - assert isinstance(result_blosc, BloscCodec) - assert result_blosc.cname == BloscCname.lz4 - assert result_blosc.clevel == 5 - assert result_blosc.shuffle == BloscShuffle.bitshuffle - assert result_blosc.blocksize == 1024 + assert isinstance(result_blosc, dict) + assert "compressors" in result_blosc + assert isinstance(result_blosc["compressors"], BloscCodec) + assert result_blosc["compressors"].cname == BloscCname.lz4 + assert result_blosc["compressors"].clevel == 5 + assert result_blosc["compressors"].shuffle == BloscShuffle.bitshuffle + assert result_blosc["compressors"].blocksize == 1024 # Test 3: mdio_ZFP compressor - should return zfpy_ZFPY if available zfp_compressor = MDIO_ZFP(mode=mdio_ZFPMode.FIXED_RATE, tolerance=0.01, rate=8.0, precision=16) + # TODO(BrianMichell): Update to also test zfp compression. + # https://github.com/TGSAI/mdio-python/issues/747 if HAS_ZFPY: # pragma: no cover - result_zfp = _convert_compressor(zfp_compressor) - assert isinstance(result_zfp, ZFPY) - assert result_zfp.mode == 1 # ZFPMode.FIXED_RATE.value = "fixed_rate" - assert result_zfp.tolerance == 0.01 - assert result_zfp.rate == 8.0 - assert result_zfp.precision == 16 + result_zfp = _compressor_to_encoding(zfp_compressor) + assert isinstance(result_zfp, dict) + assert "compressors" not in result_zfp + assert isinstance(result_zfp["serializer"], ZFPY) + assert result_zfp["serializer"].mode == 1 # ZFPMode.FIXED_RATE.value = "fixed_rate" + assert result_zfp["serializer"].tolerance == 0.01 + assert result_zfp["serializer"].rate == 8.0 + assert result_zfp["serializer"].precision == 16 else: # Test 5: mdio_ZFP without zfpy installed - should raise ImportError with pytest.raises(ImportError) as exc_info: - _convert_compressor(zfp_compressor) + _compressor_to_encoding(zfp_compressor) error_message = str(exc_info.value) assert "zfpy and numcodecs are required to use ZFP compression" in error_message # Test 6: Unsupported compressor type - should raise TypeError unsupported_compressor = "invalid_compressor" with pytest.raises(TypeError) as exc_info: - _convert_compressor(unsupported_compressor) + _compressor_to_encoding(unsupported_compressor) error_message = str(exc_info.value) assert "Unsupported compressor model" in error_message assert "" in error_message