Skip to content

Commit

Permalink
Add type hinting for aiida.orm.nodes.data.array.array
Browse files Browse the repository at this point in the history
  • Loading branch information
sphuber committed Sep 5, 2023
1 parent e1a5bd1 commit c19b142
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 18 deletions.
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
Expand Up @@ -123,7 +123,6 @@ repos:
aiida/orm/implementation/storage_backend.py|
aiida/orm/nodes/caching.py|
aiida/orm/nodes/comments.py|
aiida/orm/nodes/data/array/array.py|
aiida/orm/nodes/data/array/bands.py|
aiida/orm/nodes/data/array/trajectory.py|
aiida/orm/nodes/data/cif.py|
Expand Down
48 changes: 31 additions & 17 deletions aiida/orm/nodes/data/array/array.py
Expand Up @@ -10,8 +10,15 @@
"""
AiiDA ORM data class storing (numpy) arrays
"""
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Iterator

from ..data import Data

if TYPE_CHECKING:
from numpy import ndarray

__all__ = ('ArrayData',)


Expand All @@ -32,13 +39,20 @@ class ArrayData(Data):
cache with the :py:meth:`.clear_internal_cache` method.
"""
array_prefix = 'array|'
_cached_arrays = None

def __init__(self, **kwargs):
"""Construct a new instance and set one or multiple numpy arrays.
:param arrays: A single numpy array, or a dictionary of numpy arrays to store.
"""
super().__init__(**kwargs)
self._cached_arrays: dict[str, 'ndarray'] = {}

def initialize(self):
super().initialize()
self._cached_arrays = {}

def delete_array(self, name):
def delete_array(self, name: str) -> None:
"""
Delete an array from the node. Can only be called before storing.
Expand All @@ -56,7 +70,7 @@ def delete_array(self, name):
# Should not happen, but do not crash if for some reason the property was not set.
pass

def get_arraynames(self):
def get_arraynames(self) -> list[str]:
"""
Return a list of all arrays stored in the node, listing the files (and
not relying on the properties).
Expand All @@ -66,21 +80,21 @@ def get_arraynames(self):
"""
return self._arraynames_from_properties()

def _arraynames_from_files(self):
def _arraynames_from_files(self) -> list[str]:
"""
Return a list of all arrays stored in the node, listing the files (and
not relying on the properties).
"""
return [i[:-4] for i in self.base.repository.list_object_names() if i.endswith('.npy')]

def _arraynames_from_properties(self):
def _arraynames_from_properties(self) -> list[str]:
"""
Return a list of all arrays stored in the node, listing the attributes
starting with the correct prefix.
"""
return [i[len(self.array_prefix):] for i in self.base.attributes.keys() if i.startswith(self.array_prefix)]

def get_shape(self, name):
def get_shape(self, name: str) -> tuple[int, ...]:
"""
Return the shape of an array (read from the value cached in the
properties for efficiency reasons).
Expand All @@ -89,7 +103,7 @@ def get_shape(self, name):
"""
return tuple(self.base.attributes.get(f'{self.array_prefix}{name}'))

def get_iterarrays(self):
def get_iterarrays(self) -> Iterator[tuple[str, 'ndarray']]:
"""
Iterator that returns tuples (name, array) for each array stored in the node.
Expand All @@ -99,15 +113,15 @@ def get_iterarrays(self):
for name in self.get_arraynames():
yield (name, self.get_array(name))

def get_array(self, name):
def get_array(self, name: str) -> 'ndarray':
"""
Return an array stored in the node
:param name: The name of the array to return.
"""
import numpy

def get_array_from_file(self, name):
def get_array_from_file(self, name: str) -> 'ndarray':
"""Return the array stored in a .npy file"""
filename = f'{name}.npy'

Expand All @@ -127,7 +141,7 @@ def get_array_from_file(self, name):

return self._cached_arrays[name]

def clear_internal_cache(self):
def clear_internal_cache(self) -> None:
"""
Clear the internal memory cache where the arrays are stored after being
read from disk (used in order to reduce at minimum the readings from
Expand All @@ -137,7 +151,7 @@ def clear_internal_cache(self):
"""
self._cached_arrays = {}

def set_array(self, name, array):
def set_array(self, name: str, array: 'ndarray') -> None:
"""
Store a new numpy array inside the node. Possibly overwrite the array
if it already existed.
Expand Down Expand Up @@ -171,12 +185,12 @@ def set_array(self, name, array):
handle.seek(0)

# Write the numpy array to the repository, keeping the byte representation
self.base.repository.put_object_from_filelike(handle, f'{name}.npy')
self.base.repository.put_object_from_filelike(handle, f'{name}.npy') # type: ignore[arg-type]

# Store the array name and shape for querying purposes
self.base.attributes.set(f'{self.array_prefix}{name}', list(array.shape))

def _validate(self):
def _validate(self) -> bool:
"""
Check if the list of .npy files stored inside the node and the
list of properties match. Just a name check, no check on the size
Expand All @@ -192,9 +206,9 @@ def _validate(self):
raise ValidationError(
f'Mismatch of files and properties for ArrayData node (pk= {self.pk}): {files} vs. {properties}'
)
super()._validate()
return super()._validate()

def _get_array_entries(self):
def _get_array_entries(self) -> dict[str, Any]:
"""Return a dictionary with the different array entries.
The idea is that this dictionary contains the array name as a key and
Expand All @@ -208,7 +222,7 @@ def _get_array_entries(self):
array_dict[key] = clean_array(val)
return array_dict

def _prepare_json(self, main_file_name='', comments=True): # pylint: disable=unused-argument
def _prepare_json(self, main_file_name='', comments=True) -> tuple[bytes, dict]: # pylint: disable=unused-argument
"""Dump the content of the arrays stored in this node into JSON format.
:param comments: if True, includes comments (if it makes sense for the given format)
Expand All @@ -226,7 +240,7 @@ def _prepare_json(self, main_file_name='', comments=True): # pylint: disable=un
return json.dumps(json_dict).encode('utf-8'), {}


def clean_array(array):
def clean_array(array: 'ndarray') -> list:
"""
Replacing np.nan and np.inf/-np.inf for Nones.
Expand Down

0 comments on commit c19b142

Please sign in to comment.