Skip to content

Commit

Permalink
ArrayData: Allow defining array(s) on construction
Browse files Browse the repository at this point in the history
Currently, the constructor does not allow to define any arrays to set
when constructing a new node, so one is forced to multi line code:

    node = ArrayData()
    node.set_array('a', np.array([1, 2]))
    node.set_array('b', np.array([3, 4]))

This commit allows initialization upon construction simplifying the code
above to:

    node = ArrayData({'a': np.array([1, 2]), 'b': np.array([3, 4])})

Note that it is also possible to pass a single array to the constructor,
in which case the array name is taken from the `default_array_name`
class attribute.

For backwards compatibility, it remains possible to construct an
`ArrayData` without any arrays.
  • Loading branch information
sphuber committed Sep 5, 2023
1 parent c19b142 commit 35e669f
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 2 deletions.
21 changes: 19 additions & 2 deletions aiida/orm/nodes/data/array/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,32 @@ class ArrayData(Data):
cache with the :py:meth:`.clear_internal_cache` method.
"""
array_prefix = 'array|'
default_array_name = 'default'

def __init__(self, **kwargs):
def __init__(self, arrays: 'ndarray' | dict[str, 'ndarray'] | None = None, **kwargs):
"""Construct a new instance and set one or multiple numpy arrays.
:param arrays: A single numpy array, or a dictionary of numpy arrays to store.
:param arrays: An optional single numpy array, or dictionary of numpy arrays to store.
"""
import numpy

super().__init__(**kwargs)
self._cached_arrays: dict[str, 'ndarray'] = {}

arrays = arrays if arrays is not None else {}

if isinstance(arrays, numpy.ndarray):
arrays = {self.default_array_name: arrays}

if (
not isinstance(arrays, dict) # type: ignore[redundant-expr]
or any(not isinstance(a, numpy.ndarray) for a in arrays.values())
):
raise TypeError(f'`arrays` should be a single numpy array or dictionary of numpy arrays but got: {arrays}')

for key, value in arrays.items():
self.set_array(key, value)

def initialize(self):
super().initialize()
self._cached_arrays = {}
Expand Down
17 changes: 17 additions & 0 deletions tests/orm/nodes/data/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,20 @@ def test_read_stored():

loaded = load_node(node.uuid)
assert numpy.array_equal(loaded.get_array('array'), array)


def test_constructor():
"""Test the various construction options."""
node = ArrayData()
assert node.get_arraynames() == []

arrays = numpy.array([1, 2])
node = ArrayData(arrays)
assert node.get_arraynames() == [ArrayData.default_array_name]
assert (node.get_array(ArrayData.default_array_name) == arrays).all()

arrays = {'a': numpy.array([1, 2]), 'b': numpy.array([3, 4])}
node = ArrayData(arrays)
assert sorted(node.get_arraynames()) == ['a', 'b']
assert (node.get_array('a') == arrays['a']).all()
assert (node.get_array('b') == arrays['b']).all()

0 comments on commit 35e669f

Please sign in to comment.