Skip to content

Commit

Permalink
Caching: Add the strict argument configuration validation
Browse files Browse the repository at this point in the history
So far, the caching configuration validation only considered whether the
defined identifiers were valid syntactically. This made it possible for
a user to specify a valid identifier but that didn't actually match a
class that can be imported or an entry point that cannot be loaded. If
this is due to a typo, the user may be confused why the caching config
seems to be ignored.

The caching control functionality adds the `strict` argument, which when
set to `True`, besides checking the syntax validity of an identifier,
will also try to import/load it and raise a `ValueError` if it fails. By
default it is set to `False` to maintain backwards compatibility.
  • Loading branch information
sphuber committed Sep 4, 2023
1 parent 2c56fc2 commit f272e19
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 41 deletions.
109 changes: 70 additions & 39 deletions aiida/manage/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
# For further information please visit http://www.aiida.net #
###########################################################################
"""Definition of caching mechanism and configuration for calculations."""
from __future__ import annotations

from collections import namedtuple
from contextlib import contextmanager, suppress
from enum import Enum
Expand Down Expand Up @@ -48,18 +50,22 @@ def enable_all(self):
def disable_all(self):
self._default_all = 'disable'

def enable(self, identifier):
def enable(self, identifier: str):
self._enable.append(identifier)
with suppress(ValueError):
self._disable.remove(identifier)

def disable(self, identifier):
def disable(self, identifier: str):
self._disable.append(identifier)
with suppress(ValueError):
self._enable.remove(identifier)

def get_options(self):
"""Return the options, applying any context overrides."""
def get_options(self, strict: bool = False):
"""Return the options, applying any context overrides.
:param strict: When set to ``True``, the function will actually try to resolve the identifier by loading it and
if it fails, an exception is raised.
"""

if self._default_all == 'disable':
return False, [], []
Expand All @@ -84,7 +90,7 @@ def get_options(self):
# Check validity of enabled and disabled entries
try:
for identifier in enabled + disabled:
_validate_identifier_pattern(identifier=identifier)
_validate_identifier_pattern(identifier=identifier, strict=strict)
except ValueError as exc:
raise exceptions.ConfigurationError('Invalid identifier pattern in enable or disable list.') from exc

Expand All @@ -95,59 +101,64 @@ def get_options(self):


@contextmanager
def enable_caching(*, identifier=None):
def enable_caching(*, identifier: str | None = None, strict: bool = False):
"""Context manager to enable caching, either for a specific node class, or globally.
.. warning:: this does not affect the behavior of the daemon, only the local Python interpreter.
:param identifier: Process type string of the node, or a pattern with '*' wildcard that matches it.
If not provided, caching is enabled for all classes.
:param strict: When set to ``True``, the function will actually try to resolve the identifier by loading it and if
it fails, an exception is raised.
:type identifier: str
"""
type_check(identifier, str, allow_none=True)

if identifier is None:
_CONTEXT_CACHE.enable_all()
else:
_validate_identifier_pattern(identifier=identifier)
_validate_identifier_pattern(identifier=identifier, strict=strict)
_CONTEXT_CACHE.enable(identifier)
yield
_CONTEXT_CACHE.clear()


@contextmanager
def disable_caching(*, identifier=None):
def disable_caching(*, identifier: str | None = None, strict: bool = False):
"""Context manager to disable caching, either for a specific node class, or globally.
.. warning:: this does not affect the behavior of the daemon, only the local Python interpreter.
:param identifier: Process type string of the node, or a pattern with '*' wildcard that matches it.
If not provided, caching is disabled for all classes.
:param strict: When set to ``True``, the function will actually try to resolve the identifier by loading it and if
it fails, an exception is raised.
:type identifier: str
"""
type_check(identifier, str, allow_none=True)

if identifier is None:
_CONTEXT_CACHE.disable_all()
else:
_validate_identifier_pattern(identifier=identifier)
_validate_identifier_pattern(identifier=identifier, strict=strict)
_CONTEXT_CACHE.disable(identifier)
yield
_CONTEXT_CACHE.clear()


def get_use_cache(*, identifier=None):
def get_use_cache(*, identifier: str | None = None, strict: bool = False) -> bool:
"""Return whether the caching mechanism should be used for the given process type according to the configuration.
:param identifier: Process type string of the node
:type identifier: str
:param strict: When set to ``True``, the function will actually try to resolve the identifier by loading it and if
it fails, an exception is raised.
:return: boolean, True if caching is enabled, False otherwise
:raises: `~aiida.common.exceptions.ConfigurationError` if the configuration is invalid, either due to a general
configuration error, or by defining the class both enabled and disabled
"""
type_check(identifier, str, allow_none=True)

default, enabled, disabled = _CONTEXT_CACHE.get_options()
default, enabled, disabled = _CONTEXT_CACHE.get_options(strict=strict)

if identifier is not None:
type_check(identifier, str)
Expand Down Expand Up @@ -176,12 +187,13 @@ def get_use_cache(*, identifier=None):

if len(most_specific) > 1:
raise exceptions.ConfigurationError(
f'Invalid configuration: multiple matches for identifier {identifier}'
f', but the most specific identifier is not unique. Candidates: {[match.pattern for match in most_specific]}'
f'Invalid configuration: multiple matches for identifier `{identifier}`, but the most specific '
f'identifier is not unique. Candidates: {[match.pattern for match in most_specific]}'
)
if not most_specific:
raise exceptions.ConfigurationError(
f'Invalid configuration: multiple matches for identifier {identifier}, but none of them is most specific.'
f'Invalid configuration: multiple matches for identifier `{identifier}`, but none of them is most '
'specific.'
)
return most_specific[0].use_cache
if enable_matches:
Expand All @@ -191,17 +203,20 @@ def get_use_cache(*, identifier=None):
return default


def _match_wildcard(*, string, pattern):
"""
Helper function to check whether a given name matches a pattern
which can contain '*' wildcards.
def _match_wildcard(*, string: str, pattern: str) -> bool:
"""Return whether a given name matches a pattern which can contain '*' wildcards.
:param string: The string to check.
:param pattern: The patter to match for.
:returns: ``True`` if ``string`` matches the ``pattern``, ``False`` otherwise.
"""
regexp = '.*'.join(re.escape(part) for part in pattern.split('*'))
return re.fullmatch(pattern=regexp, string=string) is not None


def _validate_identifier_pattern(*, identifier):
"""
def _validate_identifier_pattern(*, identifier: str, strict: bool = False):
"""Validate an caching identifier pattern.
The identifier (without wildcards) can have one of two forms:
1. <group_name><ENTRY_POINT_STRING_SEPARATOR><tail>
Expand All @@ -214,20 +229,27 @@ def _validate_identifier_pattern(*, identifier):
this is a colon-separated string, where each part satisfies
`part.isidentifier() and not keyword.iskeyword(part)`
This function checks if an identifier _with_ wildcards can possibly
match one of these two forms. If it can not, a `ValueError` is raised.
This function checks if an identifier _with_ wildcards can possibly match one of these two forms. If it can not,
a ``ValueError`` is raised.
:param identifier: Process type string, or a pattern with '*' wildcard that matches it.
:type identifier: str
:param strict: When set to ``True``, the function will actually try to resolve the identifier by loading it and if
it fails, an exception is raised.
:raises ValueError: If the identifier is an invalid identifier.
:raises ValueError: If ``strict=True`` and the identifier cannot be successfully loaded.
"""
# pylint: disable=too-many-branches
import importlib

common_error_msg = f"Invalid identifier pattern '{identifier}': "
from aiida.common.exceptions import EntryPointError
from aiida.plugins.entry_point import load_entry_point_from_string

common_error_msg = f'Invalid identifier pattern `{identifier}`: '
assert ENTRY_POINT_STRING_SEPARATOR not in '.*' # The logic of this function depends on this
# Check if it can be an entry point string
if identifier.count(ENTRY_POINT_STRING_SEPARATOR) > 1:
raise ValueError(
f"{common_error_msg}Can contain at most one entry point string separator '{ENTRY_POINT_STRING_SEPARATOR}'"
f'{common_error_msg}Can contain at most one entry point string separator `{ENTRY_POINT_STRING_SEPARATOR}`'
)
# If there is one separator, it must be an entry point string.
# Check if the left hand side is a matching pattern
Expand All @@ -239,11 +261,18 @@ def _validate_identifier_pattern(*, identifier):
):
raise ValueError(
common_error_msg +
f"Group name pattern '{group_pattern}' does not match any of the AiiDA entry point group names."
f'Group name pattern `{group_pattern}` does not match any of the AiiDA entry point group names.'
)
# The group name pattern matches, and there are no further
# entry point string separators in the identifier, hence it is
# a valid pattern.

# If strict mode is enabled and the identifier is explicit, i.e., doesn't contain a wildcard, try to load it.
if strict and '*' not in identifier:
try:
load_entry_point_from_string(identifier)
except EntryPointError as exception:
raise ValueError(common_error_msg + f'`{identifier}` cannot be loaded.') from exception

# The group name pattern matches, and there are no further entry point string separators in the identifier,
# hence it is a valid pattern.
return

# The separator might be swallowed in a wildcard, for example
Expand All @@ -252,6 +281,7 @@ def _validate_identifier_pattern(*, identifier):
group_part, _ = identifier.split('*', 1)
if any(group_name.startswith(group_part) for group_name in ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP):
return

# Finally, check if it could be a fully qualified Python name
for identifier_part in identifier.split('.'):
# If it contains a wildcard, we can not check for keywords.
Expand All @@ -262,21 +292,22 @@ def _validate_identifier_pattern(*, identifier):
if not identifier_part.replace('*', 'a').isidentifier():
raise ValueError(
common_error_msg +
f"Identifier part '{identifier_part}' can not match a fully qualified Python name."
f'Identifier part `{identifier_part}` can not match a fully qualified Python name.'
)
else:
if not identifier_part.isidentifier():
raise ValueError(f"{common_error_msg}'{identifier_part}' is not a valid Python identifier.")
raise ValueError(f'{common_error_msg}`{identifier_part}` is not a valid Python identifier.')
if keyword.iskeyword(identifier_part):
raise ValueError(f"{common_error_msg}'{identifier_part}' is a reserved Python keyword.")
raise ValueError(f'{common_error_msg}`{identifier_part}` is a reserved Python keyword.')

if not strict:
return

# If there is no separator, it must be a fully qualified Python name.
try:
module_name = '.'.join(identifier.split('.')[:-1])
class_name = identifier.split('.')[-1]

spec = importlib.util.find_spec(module_name)
module = importlib.util.module_from_spec(spec)
cls = getattr(module, class_name)
except (ModuleNotFoundError, AttributeError) as exc:
raise ValueError(common_error_msg + f'{identifier} can not be imported.') from exc
module = importlib.import_module(module_name)
getattr(module, class_name)
except (ModuleNotFoundError, AttributeError, IndexError) as exc:
raise ValueError(common_error_msg + f'`{identifier}` cannot be imported.') from exc
12 changes: 12 additions & 0 deletions docs/source/howto/run_codes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -915,6 +915,18 @@ Caching can be enabled or disabled on a case-by-case basis by using the :class:`
This affects only the current Python interpreter and won't change the behavior of the daemon workers.
This means that this technique is only useful when using :py:class:`~aiida.engine.run`, and **not** with :py:class:`~aiida.engine.submit`.

By default, the ``enable_caching`` context manager will just validate that the identifier is syntactically valid.
It *does not* validate that the identifier points to a class or entry point that actually exists and can be imported or loaded.
To make sure that the specified identifier is known to AiiDA, pass the ``strict=True`` keyword argument:

.. code-block:: python
from aiida.engine import run
from aiida.manage.caching import enable_caching
with enable_caching(identifier='aiida.calculations:core.templatereplacer', strict=True):
run(...)
When ``strict`` is set to ``True``, the function will raise a ``ValueError`` if the specified class or entry point cannot be imported or loaded.

Besides controlling which process classes are cached, it may be useful or necessary to control what already *stored* nodes are used as caching *sources*.
Section :ref:`topics:provenance:caching:control-caching` provides details how AiiDA decides which stored nodes are equivalent to the node being stored and which are considered valid caching sources.
Expand Down
37 changes: 35 additions & 2 deletions tests/manage/test_caching_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import yaml

from aiida.common import exceptions
from aiida.manage.caching import disable_caching, enable_caching, get_use_cache
from aiida.manage.caching import _validate_identifier_pattern, disable_caching, enable_caching, get_use_cache


@pytest.fixture
Expand Down Expand Up @@ -267,7 +267,7 @@ def test_disable_caching_global(configure_caching):
@pytest.mark.parametrize(
'identifier', [
'aiida.spam:Ni', 'aiida.calculations:With:second_separator', 'aiida.sp*:Ni', 'aiida.sp*!bar',
'startswith.number.2bad', 'some.thing.in.this.is.a.keyword', 'invalid_module.AClass'
'startswith.number.2bad', 'some.thing.in.this.is.a.keyword'
]
)
def test_enable_disable_invalid(identifier):
Expand All @@ -281,3 +281,36 @@ def test_enable_disable_invalid(identifier):
with pytest.raises(ValueError):
with disable_caching(identifier=identifier):
pass


@pytest.mark.parametrize(
'strict, identifier, matches', (
(False, 'aiida.calculations:core.arithmetic.add', None),
(False, 'aiida.calculations.arithmetic.add.ArithmeticAddCalculation', None),
(False, 'aiida.calculations:core.non_existent', None),
(False, 'aiida.calculations.arithmetic.non_existent.ArithmeticAddCalculation', None),
(False, 'aiida.spam:Ni', r'does not match any of the AiiDA entry point group names\.'),
(False, 'aiida.calculations:With:second_separator', r'Can contain at most one entry point string separator.*'),
(False, 'aiida.sp*:Ni', r'does not match any of the AiiDA entry point group names\.'),
(False, 'aiida.sp*!bar', r'Identifier part `sp\*!bar` can not match a fully qualified Python name.'),
(False, 'startswith.number.2bad', r'is not a valid Python identifier\.'),
(False, 'some.thing.in.this.is.a.keyword', r'is a reserved Python keyword\.'),
(True, 'aiida.calculations:core.arithmetic.add', None),
(True, 'aiida.calculations.arithmetic.add.ArithmeticAddCalculation', None),
(True, 'aiida.calculations:core.non_existent', r'cannot be loaded\.'),
(True, 'aiida.calculations.arithmetic.non_existent.ArithmeticAddCalculation', r'cannot be imported\.'),
(True, 'aiida.spam:Ni', r'does not match any of the AiiDA entry point group names\.'),
(True, 'aiida.calculations:With:second_separator', r'Can contain at most one entry point string separator.*'),
(True, 'aiida.sp*:Ni', r'does not match any of the AiiDA entry point group names\.'),
(True, 'aiida.sp*!bar', r'Identifier part `sp\*!bar` can not match a fully qualified Python name.'),
(True, 'startswith.number.2bad', r'is not a valid Python identifier\.'),
(True, 'some.thing.in.this.is.a.keyword', r'is a reserved Python keyword\.'),
)
)
def test_validate_identifier_pattern(strict, identifier, matches):
"""Test :func:`aiida.manage.caching._validate_identifier_pattern`."""
if matches:
with pytest.raises(ValueError, match=matches):
_validate_identifier_pattern(identifier=identifier, strict=strict)
else:
_validate_identifier_pattern(identifier=identifier, strict=strict)

0 comments on commit f272e19

Please sign in to comment.