Skip to content

Commit

Permalink
move all_array_storage/compression into config
Browse files Browse the repository at this point in the history
  • Loading branch information
braingram committed Mar 10, 2023
1 parent a612e90 commit 07cb824
Show file tree
Hide file tree
Showing 9 changed files with 238 additions and 37 deletions.
3 changes: 3 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
------------------

The ASDF Standard is at v1.6.0
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

- Drop support for ASDF-in-FITS. [#1288]
- Add ``all_array_storage``, ``all_array_compression`` and
``all_array_compression_kwargs`` to ``asdf.config.AsdfConfig`` [#1468]

2.15.0 (unreleased)
-------------------
Expand Down
70 changes: 70 additions & 0 deletions asdf/_tests/test_array_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,26 @@ def test_update_expand_tree(tmp_path):
assert_array_equal(ff.tree["arrays"][1], my_array2)


def test_update_all_external(tmp_path):
fn = tmp_path / "test.asdf"

my_array = np.arange(64) * 1
my_array2 = np.arange(64) * 2
tree = {"arrays": [my_array, my_array2]}

af = asdf.AsdfFile(tree)
af.write_to(fn)

with asdf.config.config_context() as cfg:
cfg.array_inline_threshold = 10
cfg.all_array_storage = "external"
with asdf.open(fn, mode="rw") as af:
af.update()

assert "test0000.asdf" in os.listdir(tmp_path)
assert "test0001.asdf" in os.listdir(tmp_path)


def _get_update_tree():
return {"arrays": [np.arange(64) * 1, np.arange(64) * 2, np.arange(64) * 3]}

Expand Down Expand Up @@ -830,3 +850,53 @@ def test_block_allocation_on_validate():
assert len(list(af._blocks.blocks)) == 1
af.validate()
assert len(list(af._blocks.blocks)) == 1


@pytest.mark.parametrize("all_array_storage", ["internal", "external", "inline"])
@pytest.mark.parametrize("all_array_compression", [None, "", "zlib", "bzp2", "lz4", "input"])
@pytest.mark.parametrize("compression_kwargs", [None, {}])
def test_write_to_update_storage_options(tmp_path, all_array_storage, all_array_compression, compression_kwargs):
if all_array_compression == "bzp2" and compression_kwargs is not None:
compression_kwargs = {"compresslevel": 1}

def assert_result(ff, arr):
if all_array_storage == "external":
assert "test0000.asdf" in os.listdir(tmp_path)
else:
assert "test0000.asdf" not in os.listdir(tmp_path)
if all_array_storage == "internal":
assert len(ff._blocks._internal_blocks) == 1
else:
assert len(ff._blocks._internal_blocks) == 0
blk = ff._blocks[arr]

target_compression = all_array_compression or None
assert blk._output_compression == target_compression

target_compression_kwargs = compression_kwargs or {}
assert blk._output_compression_kwargs == target_compression_kwargs

arr1 = np.ones((8, 8))
tree = {"array": arr1}
fn = tmp_path / "test.asdf"

ff1 = asdf.AsdfFile(tree)
# first check write_to
ff1.write_to(
fn,
all_array_storage=all_array_storage,
all_array_compression=all_array_compression,
compression_kwargs=compression_kwargs,
)
assert_result(ff1, arr1)

# then reuse the file to check update
with asdf.open(fn, mode="rw") as ff2:
arr2 = np.ones((8, 8)) * 42
ff2["array"] = arr2
ff2.update(
all_array_storage=all_array_storage,
all_array_compression=all_array_compression,
compression_kwargs=compression_kwargs,
)
assert_result(ff2, arr2)
6 changes: 6 additions & 0 deletions asdf/_tests/test_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,9 +225,15 @@ def compressors(self):
def test_compression_with_extension(tmp_path):
tree = _get_large_tree()

with pytest.raises(ValueError, match="Supported compression types are"), config_context() as cfg:
cfg.all_array_compression = "lzma"

with config_context() as config:
config.add_extension(LzmaExtension())

with config_context() as cfg:
cfg.all_array_compression = "lzma"

with pytest.raises(lzma.LZMAError, match=r"Invalid or unsupported options"):
_roundtrip(tmp_path, tree, "lzma", write_options={"compression_kwargs": {"preset": 9000}})
fn = _roundtrip(tmp_path, tree, "lzma", write_options={"compression_kwargs": {"preset": 6}})
Expand Down
33 changes: 33 additions & 0 deletions asdf/_tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,39 @@ def test_array_inline_threshold():
assert get_config().array_inline_threshold is None


def test_all_array_storage():
with asdf.config_context() as config:
assert config.all_array_storage == asdf.config.DEFAULT_ALL_ARRAY_STORAGE
config.all_array_storage = "internal"
assert get_config().all_array_storage == "internal"
config.all_array_storage = None
assert get_config().all_array_storage is None
with pytest.raises(ValueError, match=r"Invalid value for all_array_storage"):
config.all_array_storage = "foo"


def test_all_array_compression():
with asdf.config_context() as config:
assert config.all_array_compression == asdf.config.DEFAULT_ALL_ARRAY_COMPRESSION
config.all_array_compression = "zlib"
assert get_config().all_array_compression == "zlib"
config.all_array_compression = None
assert get_config().all_array_compression is None
with pytest.raises(ValueError, match=r"Supported compression types are"):
config.all_array_compression = "foo"


def test_all_array_compression_kwargs():
with asdf.config_context() as config:
assert config.all_array_compression_kwargs == asdf.config.DEFAULT_ALL_ARRAY_COMPRESSION_KWARGS
config.all_array_compression_kwargs = {}
assert get_config().all_array_compression_kwargs == {}
config.all_array_compression_kwargs = None
assert get_config().all_array_compression_kwargs is None
with pytest.raises(ValueError, match=r"Invalid value for all_array_compression_kwargs"):
config.all_array_compression_kwargs = "foo"


def test_resource_mappings():
with asdf.config_context() as config:
core_mappings = get_json_schema_resource_mappings() + asdf_standard.integration.get_resource_mappings()
Expand Down
52 changes: 26 additions & 26 deletions asdf/asdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1092,16 +1092,7 @@ def _tree_finalizer(tagged_tree):
padding = util.calculate_padding(fd.tell(), pad_blocks, fd.block_size)
fd.fast_forward(padding)

def _pre_write(self, fd, all_array_storage, all_array_compression, compression_kwargs=None):
if all_array_storage not in (None, "internal", "external", "inline"):
msg = f"Invalid value for all_array_storage: '{all_array_storage}'"
raise ValueError(msg)

self._all_array_storage = all_array_storage

self._all_array_compression = all_array_compression
self._all_array_compression_kwargs = compression_kwargs

def _pre_write(self, fd):
if len(self._tree):
self._run_hook("pre_write")

Expand Down Expand Up @@ -1132,12 +1123,6 @@ def _post_write(self, fd):

def update(
self,
all_array_storage=None,
all_array_compression="input",
pad_blocks=False,
include_block_index=True,
version=None,
compression_kwargs=None,
**kwargs,
):
"""
Expand Down Expand Up @@ -1196,7 +1181,17 @@ def update(
in ``asdf.get_config().array_inline_threshold``.
"""

pad_blocks = kwargs.pop("pad_blocks", False)
include_block_index = kwargs.pop("include_block_index", True)
version = kwargs.pop("version", None)

with config_context() as config:
if "all_array_storage" in kwargs:
config.all_array_storage = kwargs.pop("all_array_storage")
if "all_array_compression" in kwargs:
config.all_array_compression = kwargs.pop("all_array_compression")
if "compression_kwargs" in kwargs:
config.all_array_compression_kwargs = kwargs.pop("compression_kwargs")
_handle_deprecated_kwargs(config, kwargs)

fd = self._fd
Expand All @@ -1216,10 +1211,10 @@ def update(
if version is not None:
self.version = version

if all_array_storage == "external":
if config.all_array_storage == "external":
# If the file is fully exploded, there's no benefit to
# update, so just use write_to()
self.write_to(fd, all_array_storage=all_array_storage)
self.write_to(fd)
fd.truncate()
return

Expand All @@ -1233,7 +1228,7 @@ def update(
if fd.can_memmap():
fd.flush_memmap()

self._pre_write(fd, all_array_storage, all_array_compression, compression_kwargs=compression_kwargs)
self._pre_write(fd)

try:
fd.seek(0)
Expand Down Expand Up @@ -1280,12 +1275,6 @@ def update(
def write_to(
self,
fd,
all_array_storage=None,
all_array_compression="input",
pad_blocks=False,
include_block_index=True,
version=None,
compression_kwargs=None,
**kwargs,
):
"""
Expand Down Expand Up @@ -1355,7 +1344,18 @@ def write_to(
``asdf.get_config().array_inline_threshold``.
"""

pad_blocks = kwargs.pop("pad_blocks", False)
include_block_index = kwargs.pop("include_block_index", True)
version = kwargs.pop("version", None)

with config_context() as config:
if "all_array_storage" in kwargs:
config.all_array_storage = kwargs.pop("all_array_storage")
if "all_array_compression" in kwargs:
config.all_array_compression = kwargs.pop("all_array_compression")
if "compression_kwargs" in kwargs:
config.all_array_compression_kwargs = kwargs.pop("compression_kwargs")
_handle_deprecated_kwargs(config, kwargs)

if version is not None:
Expand All @@ -1367,7 +1367,7 @@ def write_to(
# attribute of the AsdfFile.
if self._uri is None:
self._uri = fd.uri
self._pre_write(fd, all_array_storage, all_array_compression, compression_kwargs=compression_kwargs)
self._pre_write(fd)

try:
self._serial_write(fd, pad_blocks, include_block_index)
Expand Down
19 changes: 10 additions & 9 deletions asdf/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ def write_external_blocks(self, uri, pad_blocks=False):
blk._array_storage = "internal"
asdffile._blocks.add(blk)
blk._used = True
asdffile.write_to(subfd, pad_blocks=pad_blocks)
asdffile.write_to(subfd, pad_blocks=pad_blocks, all_array_storage="internal")

def write_block_index(self, fd, ctx):
"""
Expand Down Expand Up @@ -567,13 +567,14 @@ def _find_used_blocks(self, tree, ctx):
if getattr(block, "_used", 0) == 0 and block not in reserved_blocks:
self.remove(block)

def _handle_global_block_settings(self, ctx, block):
all_array_storage = getattr(ctx, "_all_array_storage", None)
def _handle_global_block_settings(self, block):
cfg = get_config()
all_array_storage = cfg.all_array_storage
if all_array_storage:
self.set_array_storage(block, all_array_storage)

all_array_compression = getattr(ctx, "_all_array_compression", "input")
all_array_compression_kwargs = getattr(ctx, "_all_array_compression_kwargs", {})
all_array_compression = cfg.all_array_compression
all_array_compression_kwargs = cfg.all_array_compression_kwargs
# Only override block compression algorithm if it wasn't explicitly set
# by AsdfFile.set_array_compression.
if all_array_compression != "input":
Expand Down Expand Up @@ -601,7 +602,7 @@ def finalize(self, ctx):
self._find_used_blocks(ctx.tree, ctx)

for block in list(self.blocks):
self._handle_global_block_settings(ctx, block)
self._handle_global_block_settings(block)

def get_block(self, source):
"""
Expand Down Expand Up @@ -714,7 +715,7 @@ def get_source(self, block):
msg = "block not found."
raise ValueError(msg)

def find_or_create_block_for_array(self, arr, ctx):
def find_or_create_block_for_array(self, arr):
"""
For a given array, looks for an existing block containing its
underlying data. If not found, adds a new block to the block
Expand Down Expand Up @@ -743,7 +744,7 @@ def find_or_create_block_for_array(self, arr, ctx):

block = Block(base)
self.add(block)
self._handle_global_block_settings(ctx, block)
self._handle_global_block_settings(block)

return block

Expand Down Expand Up @@ -787,7 +788,7 @@ def get_output_compression_extensions(self):
return ext

def __getitem__(self, arr):
return self.find_or_create_block_for_array(arr, object())
return self.find_or_create_block_for_array(arr)

def close(self):
for block in self.blocks:
Expand Down
Loading

0 comments on commit 07cb824

Please sign in to comment.