Skip to content

Commit

Permalink
Merge pull request #256 from SCM-NV/lazy-loading
Browse files Browse the repository at this point in the history
MAINT: Clean up the template and RDKit lazy loading
  • Loading branch information
BvB93 committed Oct 14, 2021
2 parents fe8bd15 + c0daa45 commit e4c4772
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 74 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def readme():
]

tests_require = [
'assertionlib>=2.2.0',
'assertionlib>=2.3.0',
'mypy',
'pytest>=5.4',
'pytest-cov',
Expand Down
85 changes: 39 additions & 46 deletions src/qmflows/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
"""QMFlows API."""

import sys
import types
import importlib as _importlib
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING

from .__version__ import __version__

Expand All @@ -17,6 +15,13 @@
from . import templates
from .settings import Settings

try:
import rdkit
except ModuleNotFoundError as ex:
_RDKIT_EX: "None | ModuleNotFoundError" = ex
else:
_RDKIT_EX = None

__all__ = [
'__version__',
'logger',
Expand All @@ -26,57 +31,45 @@
'example_H2O2_TS', 'example_freqs', 'example_generic_constraints',
'example_partial_geometry_opt',
'freq', 'geometry', 'singlepoint', 'ts', 'md', 'cell_opt',
'find_first_job', 'select_max', 'select_min']
'find_first_job', 'select_max', 'select_min',
]

# Use `__getattr__` for loading (and copying) the templates in python >= 3.7
if TYPE_CHECKING or sys.version_info < (3, 7):
from .templates import freq, geometry, singlepoint, ts, md, cell_opt

# Use `__getattr__` to raise a more descriptive error if RDKit
# is not installed (requires python >= 3.7)
if TYPE_CHECKING or sys.version_info < (3, 7) or _RDKIT_EX is None:
from .components import (
Angle, Dihedral, Distance, find_first_job, select_max, select_min
Angle,
Dihedral,
Distance,
find_first_job,
select_max,
select_min,
)
from .examples import (
example_H2O2_TS, example_freqs, example_generic_constraints, example_partial_geometry_opt
example_H2O2_TS,
example_freqs,
example_generic_constraints,
example_partial_geometry_opt,
)
from . import components, examples
else:
_TEMPLATES = frozenset(templates.__all__)
_REQUIRES_RDKIT = types.MappingProxyType({
"components": "qmflows.components",
"Angle": "qmflows.components",
"Dihedral": "qmflows.components",
"Distance": "qmflows.components",
"find_first_job": "qmflows.components",
"select_max": "qmflows.components",
"select_min": "qmflows.components",
"examples": "qmflows.examples",
"example_H2O2_TS": "qmflows.examples",
"example_freqs": "qmflows.examples",
"example_generic_constraints": "qmflows.examples",
"example_partial_geometry_opt": "qmflows.examples",
})

_DIR_CACHE: "None | list[str]" = None

def __getattr__(name: str) -> Any:
"""Ensure that the qmflows templates are always copied before returning."""
if name in _TEMPLATES:
return getattr(templates, name).copy()

# Lazily load (and cache) the content of `qmflows.examples` and
# `qmflows.components` in order to avoid directly importing RDKit
module_name = _REQUIRES_RDKIT.get(name)
if module_name is not None:
globals()[module_name] = module = _importlib.import_module(module_name)
globals()[name] = ret = getattr(module, name, module)
return ret
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

def __dir__() -> "list[str]":
"""Manually insert the qmflows templates into :func:`dir`."""
global _DIR_CACHE
if _DIR_CACHE is None:
_DIR_CACHE = list(globals()) + templates.__all__ + list(_REQUIRES_RDKIT)
_DIR_CACHE.sort()
return _DIR_CACHE
if sys.version_info >= (3, 7):
from ._init_utils import (
getattr_method as __getattr__,
dir_method as __dir__,
RDKIT_SET,
)
if _RDKIT_EX is not None:
__all__ = [name for name in __all__ if name not in RDKIT_SET]
del RDKIT_SET
else:
# Initialize the sub-module such that `_RDKIT_EX` can enter its namespace
from . import _init_utils
del _init_utils

# Clean up the namespace
del sys, types, TYPE_CHECKING, Any
del sys, TYPE_CHECKING, _RDKIT_EX
74 changes: 74 additions & 0 deletions src/qmflows/_init_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
"""``__getattr__`` and ``__dir__`` implementations for the main QMFlows namespace."""

import sys
import types
import importlib
from typing import Any

import qmflows

__all__ = ["dir_method", "getattr_method", "TEMPLATE_DICT", "RDKIT_SET", "RDKIT_EX"]

RDKIT_EX = qmflows._RDKIT_EX

# Map template names to the (cached) template
TEMPLATE_DICT = types.MappingProxyType({
k: getattr(qmflows.templates, k) for k in qmflows.templates.__all__
})

# Map RDKit-requiring objects to their namespace
RDKIT_SET = frozenset({
"components",
"Angle",
"Dihedral",
"Distance",
"find_first_job",
"select_max",
"select_min",
"examples",
"example_H2O2_TS",
"example_freqs",
"example_generic_constraints",
"example_partial_geometry_opt",
})


def __getattr__(self: types.ModuleType, name: str) -> Any:
    """Resolve attribute lookups for templates and RDKit-requiring names.

    Parameters
    ----------
    self
        The module on whose behalf the lookup is performed; only its
        ``__name__`` is used (for the error message).
    name
        The attribute name being looked up.

    Returns
    -------
    A fresh copy of the template stored under *name* in ``TEMPLATE_DICT``.

    Raises
    ------
    ImportError
        If *name* is listed in ``RDKIT_SET``; the error is chained to ``RDKIT_EX``.
    AttributeError
        If *name* is neither a template nor listed in ``RDKIT_SET``.

    """
    # Hand out a copy rather than the stored template: in-place operations
    # by the caller would otherwise modify the original template
    try:
        template = TEMPLATE_DICT[name]
        duplicate = template.copy()
    except KeyError:
        pass
    else:
        return duplicate

    if name not in RDKIT_SET:
        raise AttributeError(f"module {self.__name__!r} has no attribute {name!r}")
    raise ImportError(f"{name!r} requires the optional RDKit package") from RDKIT_EX


def __dir__(self: types.ModuleType) -> "list[str]":
    """Return the sorted names reported by :func:`dir`, including the templates.

    The result is the union of ``object.__dir__(qmflows)`` and the keys of
    ``TEMPLATE_DICT``, extended with ``RDKIT_SET`` when ``RDKIT_EX`` is ``None``.
    The sorted list is stored on *self* as ``_DIR_CACHE`` after the first call;
    every call returns a copy so callers cannot mutate the cached list.
    """
    # Fast path: a previous call already stored the list on the module
    try:
        stored = self._DIR_CACHE
        snapshot = stored.copy()
    except AttributeError:
        pass
    else:
        return snapshot

    names = set(object.__dir__(qmflows))
    names.update(TEMPLATE_DICT.keys())
    if RDKIT_EX is None:
        names.update(RDKIT_SET)

    ordered = sorted(names)
    self._DIR_CACHE = ordered
    return list(ordered)


# Alias the functions under different names such that they don't
# trigger module-level `getattr`/`dir` calls
_getattr_func = __getattr__
_dir_func = __dir__
del __getattr__
del __dir__

getattr_method = types.MethodType(_getattr_func, qmflows)
dir_method = types.MethodType(_dir_func, qmflows)
45 changes: 18 additions & 27 deletions test/test_rdkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,46 +6,37 @@
import pytest
import qmflows
from qmflows.test_utils import HAS_RDKIT
from qmflows._init_utils import RDKIT_SET
from assertionlib import assertion

NAMES = frozenset({
'components',
'Angle',
'Dihedral',
'Distance',
'find_first_job',
'select_max',
'select_min',
'examples',
'example_H2O2_TS',
'example_freqs',
'example_generic_constraints',
'example_partial_geometry_opt',
})


@pytest.mark.parametrize("name", sorted(NAMES))

@pytest.mark.parametrize("name", sorted(RDKIT_SET))
def test_sub_module(name: str) -> None:
"""Test :func:`getattr` operations on rdkit-requiring objects."""
if HAS_RDKIT:
assert getattr(qmflows, name)
else:
with pytest.raises(ImportError):
match = f"{name!r} requires the optional RDKit package"
with pytest.raises(ImportError, match=match):
getattr(qmflows, name)


@pytest.mark.skipif(sys.version_info < (3, 7), reason="requires python 3.7")
def test_requires_rdkit() -> None:
"""Test that ``NAMES`` and ``qmflows._REQUIRES_RDKIT`` are synced."""
assertion.eq(set(NAMES), qmflows._REQUIRES_RDKIT.keys())


@pytest.mark.skipif(sys.version_info < (3, 7), reason="requires python 3.7")
@pytest.mark.skipif(not HAS_RDKIT, reason="requires RDKit")
def test_namespace() -> None:
"""Test that ``qmflows._REQUIRES_RDKIT`` and the sub-modules' ``__all__`` are synced."""
def test_rdkit() -> None:
    """Test that ``qmflows._init_utils.RDKIT_SET`` and the sub-modules' ``__all__`` are synced."""
    # The two sub-module names themselves plus everything they export
    expected = {
        "components",
        "examples",
        *qmflows.components.__all__,
        *qmflows.examples.__all__,
    }
    assertion.eq(expected, RDKIT_SET)

assertion.eq(name_set, qmflows._REQUIRES_RDKIT.keys())

def test_dir() -> None:
    """Test that the names in ``RDKIT_SET`` are in-/excluded from ``dir`` and ``__all__``."""
    # The sub-module names are checked against `dir` only, not against `__all__`
    exported = RDKIT_SET - {"components", "examples"}

    # With RDKit all names must be present; without it none of them may be
    check = assertion.issubset if HAS_RDKIT else assertion.isdisjoint
    check(RDKIT_SET, dir(qmflows))
    check(exported, qmflows.__all__)
4 changes: 4 additions & 0 deletions test/test_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,7 @@ def test_id(name: str) -> None:
s2 = getattr(qmflows, name)
assertion.eq(s1, s2)
assertion.is_not(s1, s2)


def test_namespace() -> None:
    """Test that every name in ``templates.__all__`` is reported by ``dir(qmflows)``."""
    template_names = templates.__all__
    visible_names = dir(qmflows)
    assertion.issubset(template_names, visible_names)

0 comments on commit e4c4772

Please sign in to comment.