Skip to content

Commit

Permalink
Caching: Try to import an identifier if it is a class path
Browse files Browse the repository at this point in the history
Raise a `ValueError` if the identifier cannot be imported, which will
help prevent accidental typos from appearing that the caching
configuration is being ignored.
  • Loading branch information
unkcpz authored and sphuber committed Sep 4, 2023
1 parent 4ef293a commit 2c56fc2
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 9 deletions.
29 changes: 21 additions & 8 deletions aiida/manage/caching.py
Expand Up @@ -175,14 +175,13 @@ def get_use_cache(*, identifier=None):
most_specific.append(PatternWithResult(pattern=specific_pattern, use_cache=False))

if len(most_specific) > 1:
raise exceptions.ConfigurationError((
'Invalid configuration: multiple matches for identifier {}'
', but the most specific identifier is not unique. Candidates: {}'
).format(identifier, [match.pattern for match in most_specific]))
raise exceptions.ConfigurationError(
f'Invalid configuration: multiple matches for identifier {identifier}'
f', but the most specific identifier is not unique. Candidates: {[match.pattern for match in most_specific]}'
)
if not most_specific:
raise exceptions.ConfigurationError(
'Invalid configuration: multiple matches for identifier {}, but none of them is most specific.'.
format(identifier)
f'Invalid configuration: multiple matches for identifier {identifier}, but none of them is most specific.'
)
return most_specific[0].use_cache
if enable_matches:
Expand Down Expand Up @@ -221,6 +220,8 @@ def _validate_identifier_pattern(*, identifier):
:param identifier: Process type string, or a pattern with '*' wildcard that matches it.
:type identifier: str
"""
import importlib

common_error_msg = f"Invalid identifier pattern '{identifier}': "
assert ENTRY_POINT_STRING_SEPARATOR not in '.*' # The logic of this function depends on this
# Check if it can be an entry point string
Expand All @@ -237,13 +238,14 @@ def _validate_identifier_pattern(*, identifier):
for group_name in ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP
):
raise ValueError(
common_error_msg + "Group name pattern '{}' does not match any of the AiiDA entry point group names.".
format(group_pattern)
common_error_msg +
f"Group name pattern '{group_pattern}' does not match any of the AiiDA entry point group names."
)
# The group name pattern matches, and there are no further
# entry point string separators in the identifier, hence it is
# a valid pattern.
return

# The separator might be swallowed in a wildcard, for example
# aiida.* or aiida.calculations*
if '*' in identifier:
Expand All @@ -267,3 +269,14 @@ def _validate_identifier_pattern(*, identifier):
raise ValueError(f"{common_error_msg}'{identifier_part}' is not a valid Python identifier.")
if keyword.iskeyword(identifier_part):
raise ValueError(f"{common_error_msg}'{identifier_part}' is a reserved Python keyword.")

# If there is no separator, it must be a fully qualified Python name.
try:
module_name = '.'.join(identifier.split('.')[:-1])
class_name = identifier.split('.')[-1]

spec = importlib.util.find_spec(module_name)
module = importlib.util.module_from_spec(spec)
cls = getattr(module, class_name)
except (ModuleNotFoundError, AttributeError) as exc:
raise ValueError(common_error_msg + f'{identifier} can not be imported.') from exc
2 changes: 1 addition & 1 deletion tests/manage/test_caching_config.py
Expand Up @@ -267,7 +267,7 @@ def test_disable_caching_global(configure_caching):
@pytest.mark.parametrize(
'identifier', [
'aiida.spam:Ni', 'aiida.calculations:With:second_separator', 'aiida.sp*:Ni', 'aiida.sp*!bar',
'startswith.number.2bad', 'some.thing.in.this.is.a.keyword'
'startswith.number.2bad', 'some.thing.in.this.is.a.keyword', 'invalid_module.AClass'
]
)
def test_enable_disable_invalid(identifier):
Expand Down

0 comments on commit 2c56fc2

Please sign in to comment.