Skip to content

Commit

Permalink
Performance: Cache the lookup of entry points (#6124)
Browse files Browse the repository at this point in the history
Entry points are looked up using the `entry_points` callable of the
`importlib_metadata` module. It is wrapped by the `eps` function in the
`aiida.plugins.entry_point` module. This call, and the `.select()` filter
that is used on it to find a specific entry point can be quite expensive
as it involves a double loop in the `importlib_metadata` code. Since it
is used throughout the `aiida-core` source code whenever an entry point
is looked up, this causes a significant slowdown of module imports.

The `eps` function now pre-sorts the entry points based on the group.
This guarantees that the entry points of groups starting with `aiida.`
come first in the lookup, giving a small performance boost. The result
is then cached so the sorting is performed just once, which takes on the
order of ~30 µs.

The most expensive part is still the looping over all entry points when
`eps().select()` is called. To alleviate this, the `eps_select` function
is added which simply calls through to `eps().select()`, but which allows
the calls to be cached.

In order to implement the changes, the `importlib_metadata` package,
which provides a backport implementation of the `importlib.metadata`
module of the standard lib, was updated to v6.0.
  • Loading branch information
danielhollas committed Oct 10, 2023
1 parent 2ea5087 commit 12cc930
Show file tree
Hide file tree
Showing 10 changed files with 68 additions and 56 deletions.
4 changes: 2 additions & 2 deletions aiida/manage/configuration/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
###########################################################################
"""Module that defines the configuration file of an AiiDA instance and functions to create and load it."""
import codecs
from functools import lru_cache
from functools import cache
from importlib.resources import files
import json
import os
Expand All @@ -28,7 +28,7 @@
SCHEMA_FILE = 'config-v9.schema.json'


@lru_cache(1)
@cache
def config_schema() -> Dict[str, Any]:
"""Return the configuration schema."""
return json.loads(files(schema_module).joinpath(SCHEMA_FILE).read_text(encoding='utf8'))
Expand Down
51 changes: 24 additions & 27 deletions aiida/manage/tests/pytest_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import asyncio
import contextlib
import copy
import inspect
import io
import os
Expand All @@ -33,6 +32,7 @@
import uuid
import warnings

from importlib_metadata import EntryPoints
import plumpy
import pytest
import wrapt
Expand Down Expand Up @@ -759,9 +759,16 @@ def suppress_deprecations(wrapped, _, args, kwargs):
class EntryPointManager:
"""Manager to temporarily add or remove entry points."""

@staticmethod
def eps():
return plugins.entry_point.eps()
def __init__(self, entry_points: EntryPoints):
self.entry_points = entry_points

def eps(self) -> EntryPoints:
return self.entry_points

def eps_select(self, group, name=None) -> EntryPoints:
if name is None:
return self.eps().select(group=group)
return self.eps().select(group=group, name=name)

@staticmethod
def _validate_entry_point(entry_point_string: str | None, group: str | None, name: str | None) -> tuple[str, str]:
Expand Down Expand Up @@ -791,7 +798,6 @@ def _validate_entry_point(entry_point_string: str | None, group: str | None, nam

return group, name

@suppress_deprecations
def add(
self,
value: type | str,
Expand All @@ -817,9 +823,8 @@ def add(

group, name = self._validate_entry_point(entry_point_string, group, name)
entry_point = plugins.entry_point.EntryPoint(name, value, group)
self.eps()[group].append(entry_point)
self.entry_points = EntryPoints(self.entry_points + (entry_point,))

@suppress_deprecations
def remove(
self, entry_point_string: str | None = None, *, name: str | None = None, group: str | None = None
) -> None:
Expand All @@ -835,31 +840,23 @@ def remove(
:raises ValueError: If `entry_point_string` is not a complete entry point string with group and name.
"""
group, name = self._validate_entry_point(entry_point_string, group, name)

for entry_point in self.eps()[group]:
if entry_point.name == name:
self.eps()[group].remove(entry_point)
break
else:
try:
self.entry_points[name]
except KeyError:
raise KeyError(f'entry point `{name}` does not exist in group `{group}`.')
self.entry_points = EntryPoints((ep for ep in self.entry_points if not (ep.name == name and ep.group == group)))


@pytest.fixture
def entry_points(monkeypatch) -> EntryPointManager:
"""Return an instance of the ``EntryPointManager`` which allows to temporarily add or remove entry points.
This fixture creates a deep copy of the entry point cache returned by the :func:`aiida.plugins.entry_point.eps`
method and then monkey patches that function to return the deepcopy. This ensures that the changes on the entry
point cache performed during the test through the manager are undone at the end of the function scope.
.. note:: This fixture does not use the ``suppress_deprecations`` decorator on purpose, but instead adds it manually
inside the fixture's body. The reason is that otherwise all deprecations would be suppressed for the entire
scope of the fixture, including those raised by the code run in the test using the fixture, which is not
desirable.
This fixture monkey patches the entry point caches returned by
the :func:`aiida.plugins.entry_point.eps` and :func:`aiida.plugins.entry_point.eps_select` functions
to class methods of the ``EntryPointManager`` so that we can dynamically add / remove entry points.
Note that we do not need a deepcopy here as ``eps()`` returns an immutable ``EntryPoints`` tuple type.
"""
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning)
eps_copy = copy.deepcopy(plugins.entry_point.eps())
monkeypatch.setattr(plugins.entry_point, 'eps', lambda: eps_copy)
yield EntryPointManager()
epm = EntryPointManager(plugins.entry_point.eps())
monkeypatch.setattr(plugins.entry_point, 'eps', epm.eps)
monkeypatch.setattr(plugins.entry_point, 'eps_select', epm.eps_select)
yield epm
2 changes: 1 addition & 1 deletion aiida/orm/autogroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def validate(strings: list[str] | None):
"""Validate the list of strings passed to set_include and set_exclude."""
if strings is None:
return
valid_prefixes = set(['aiida.node', 'aiida.calculations', 'aiida.workflows', 'aiida.data'])
valid_prefixes = {'aiida.node', 'aiida.calculations', 'aiida.workflows', 'aiida.data'}
for string in strings:
pieces = string.split(':')
if len(pieces) != 2:
Expand Down
48 changes: 32 additions & 16 deletions aiida/plugins/entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
# For further information please visit http://www.aiida.net #
###########################################################################
"""Module to manage loading entrypoints."""
from __future__ import annotations

import enum
import functools
import traceback
Expand All @@ -30,9 +32,30 @@
ENTRY_POINT_STRING_SEPARATOR = ':'


@functools.lru_cache(maxsize=1)
def eps():
return _eps()
@functools.cache
def eps() -> EntryPoints:
"""Cache around entry_points()
This call takes around 50ms!
NOTE: For faster lookups, we sort the ``EntryPoints`` alphabetically
by the group name so that 'aiida.' groups come up first.
Unfortunately, this does not help with the entry_points.select() filter,
which will always iterate over all entry points since it looks for
possible duplicate entries.
"""
entry_points = _eps()
return EntryPoints(sorted(entry_points, key=lambda x: x.group))


@functools.lru_cache(maxsize=100)
def eps_select(group: str, name: str | None = None) -> EntryPoints:
"""
A thin wrapper around entry_points.select() calls, which are
expensive so we want to cache them.
"""
if name is None:
return eps().select(group=group)
return eps().select(group=group, name=name)


class EntryPointFormat(enum.Enum):
Expand Down Expand Up @@ -254,8 +277,7 @@ def get_entry_point_groups() -> Set[str]:

def get_entry_point_names(group: str, sort: bool = True) -> List[str]:
"""Return the entry points within a group."""
all_eps = eps()
group_names = list(all_eps.select(group=group).names)
group_names = list(get_entry_points(group).names)
if sort:
return sorted(group_names)
return group_names
Expand All @@ -268,7 +290,7 @@ def get_entry_points(group: str) -> EntryPoints:
:param group: the entry point group
:return: a list of entry points
"""
return eps().select(group=group)
return eps_select(group=group)


def get_entry_point(group: str, name: str) -> EntryPoint:
Expand All @@ -283,7 +305,7 @@ def get_entry_point(group: str, name: str) -> EntryPoint:
"""
# The next line should be removed for ``aiida-core==3.0`` when the old deprecated entry points are fully removed.
name = convert_potentially_deprecated_entry_point(group, name)
found = eps().select(group=group, name=name)
found = eps_select(group=group, name=name)
if name not in found.names:
raise MissingEntryPointError(f"Entry point '{name}' not found in group '{group}'")
# If multiple entry points are found and they have different values we raise, otherwise if they all
Expand Down Expand Up @@ -326,15 +348,9 @@ def get_entry_point_from_class(class_module: str, class_name: str) -> Tuple[Opti
:param class_name: name of the class
:return: a tuple of the corresponding group and entry point or None if not found
"""
for group in get_entry_point_groups():
for entry_point in get_entry_points(group):

if entry_point.module != class_module:
continue

if entry_point.attr == class_name:
return group, entry_point

for entry_point in eps():
if entry_point.module == class_module and entry_point.attr == class_name:
return entry_point.group, entry_point
return None, None


Expand Down
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ dependencies:
- jinja2~=3.0
- jsonschema~=3.0
- kiwipy[rmq]~=0.7.7
- importlib-metadata~=4.13
- importlib-metadata~=6.0
- numpy~=1.21
- paramiko>=2.7.2,~=2.7
- plumpy~=0.21.6
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ dependencies = [
"jinja2~=3.0",
"jsonschema~=3.0",
"kiwipy[rmq]~=0.7.7",
"importlib-metadata~=4.13",
"importlib-metadata~=6.0",
"numpy~=1.21",
"paramiko~=2.7,>=2.7.2",
"plumpy~=0.21.6",
Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements-py-3.10.txt
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ graphviz==0.20.1
greenlet==2.0.2
idna==3.4
imagesize==1.4.1
importlib-metadata==4.13.0
importlib-metadata==6.8.0
iniconfig==2.0.0
ipykernel==6.23.2
ipython==8.14.0
Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements-py-3.11.txt
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ graphviz==0.20.1
greenlet==2.0.2
idna==3.4
imagesize==1.4.1
importlib-metadata==4.13.0
importlib-metadata==6.8.0
iniconfig==2.0.0
ipykernel==6.23.2
ipython==8.14.0
Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements-py-3.9.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ graphviz==0.20.1
greenlet==2.0.2
idna==3.4
imagesize==1.4.1
importlib-metadata==4.13.0
importlib-metadata==6.8.0
importlib-resources==5.12.0
iniconfig==2.0.0
ipykernel==6.23.2
Expand Down
9 changes: 4 additions & 5 deletions tests/plugins/test_entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,10 @@ def select(group, name): # pylint: disable=unused-argument

@pytest.mark.parametrize(
'eps, name, exception', (
((EP(name='ep', group='gr', value='x'),), None, None),
((EP(name='ep', group='gr', value='x'),), 'ep', None),
((EP(name='ep', group='gr', value='x'),), 'non-existing', MissingEntryPointError),
((EP(name='ep', group='gr', value='x'), EP(name='ep', group='gr', value='y')), None, MultipleEntryPointError),
((EP(name='ep', group='gr', value='x'), EP(name='ep', group='gr', value='x')), None, None),
((EP(name='ep', group='gr', value='x'), EP(name='ep', group='gr', value='y')), 'ep', MultipleEntryPointError),
((EP(name='ep', group='gr', value='x'), EP(name='ep', group='gr', value='x')), 'ep', None),
),
indirect=['eps']
)
Expand All @@ -91,8 +91,7 @@ def test_get_entry_point(eps, name, exception, monkeypatch):
"""
monkeypatch.setattr(entry_point, 'eps', eps)

name = name or 'ep' # Try to load the entry point with name ``ep`` unless the fixture provides one
entry_point.eps_select.cache_clear()

if exception:
with pytest.raises(exception):
Expand Down

0 comments on commit 12cc930

Please sign in to comment.