Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC: permit class member functions as calcfunctions #4963

Merged
merged 2 commits into from
Feb 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion aiida/engine/processes/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,9 +280,10 @@ def _define(cls, spec): # pylint: disable=unused-argument
spec.outputs.valid_type = (Data, dict)

return type(
func.__name__, (FunctionProcess,), {
func.__qualname__, (FunctionProcess,), {
'__module__': func.__module__,
'__name__': func.__name__,
'__qualname__': func.__qualname__,
'_func': staticmethod(func),
Process.define.__name__: classmethod(_define),
'_func_args': args,
Expand Down
31 changes: 21 additions & 10 deletions aiida/orm/nodes/process/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,17 +219,28 @@ def process_class(self) -> Type['Process']:
except exceptions.EntryPointError as exception:
raise ValueError(
f'could not load process class for entry point `{self.process_type}` for Node<{self.pk}>: {exception}'
)
except ValueError:
try:
import importlib
module_name, class_name = self.process_type.rsplit('.', 1)
module = importlib.import_module(module_name)
process_class = getattr(module, class_name)
except (AttributeError, ValueError, ImportError) as exception:
) from exception
except ValueError as exception:
import importlib

def str_rsplit_iter(string, sep='.'):
    """Yield each way of splitting ``string`` into a leading part and its trailing components.

    The cut starts at the rightmost occurrence of ``sep`` and moves one separator to the left on every
    iteration. Nothing is yielded when ``string`` does not contain ``sep``.

    :param string: the string to split.
    :param sep: the separator on which ``string`` is split.
    :yields: tuples of the leading part (its components joined again by ``sep``) and the list of the
        remaining trailing components.
    """
    parts = string.split(sep)
    # ``cut`` is the number of components kept in the leading part; it shrinks from ``len(parts) - 1`` down to 1.
    for cut in range(len(parts) - 1, 0, -1):
        yield sep.join(parts[:cut]), parts[cut:]

for module_name, class_names in str_rsplit_iter(self.process_type):
try:
module = importlib.import_module(module_name)
process_class = module
for objname in class_names:
process_class = getattr(process_class, objname)
break
except (AttributeError, ValueError, ImportError):
pass
else:
raise ValueError(
f'could not load process class from `{self.process_type}` for Node<{self.pk}>: {exception}'
)
f'could not load process class from `{self.process_type}` for Node<{self.pk}>'
) from exception

return process_class

Expand Down
4 changes: 2 additions & 2 deletions aiida/plugins/entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,8 @@ def parse_entry_point_string(entry_point_string: str) -> Tuple[str, str]:

try:
group, name = entry_point_string.split(ENTRY_POINT_STRING_SEPARATOR)
except ValueError:
raise ValueError('invalid entry_point_string format')
except ValueError as exc:
raise ValueError(f'invalid entry_point_string format: {entry_point_string}') from exc

return group, name

Expand Down
33 changes: 33 additions & 0 deletions docs/source/topics/processes/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,39 @@ The question you should ask yourself is whether a potential problem merits throw
Or maybe, as in the example above, the problem is easily foreseeable and classifiable with a well defined exit status, in which case it might make more sense to return the exit code.
At the end one should think which solution makes it easier for a workflow calling the function to respond based on the result and what makes it easier to query for these specific failure modes.

As class member methods
=======================

.. versionadded:: 2.3

Process functions can also be declared as class member methods, for example as part of a :class:`~aiida.engine.processes.workchains.workchain.WorkChain`:

.. code-block:: python

class CalcFunctionWorkChain(WorkChain):

@classmethod
def define(cls, spec):
super().define(spec)
spec.input('x')
spec.input('y')
spec.output('sum')
spec.outline(
cls.run_compute_sum,
)

@staticmethod
@calcfunction
def compute_sum(x, y):
return x + y

def run_compute_sum(self):
self.out('sum', self.compute_sum(self.inputs.x, self.inputs.y))

In this example, the work chain declares a static method called ``compute_sum`` which is decorated with the ``calcfunction`` decorator to turn it into a calculation function.
It is important that the method is also decorated with the ``staticmethod`` decorator (see the `Python documentation <https://docs.python.org/3/library/functions.html#staticmethod>`_), placed above ``calcfunction`` as shown, such that the work chain instance is not passed when the method is invoked.
The calcfunction can be called from a work chain step like any other method of the class, as is shown in the last line.


Provenance
==========
Expand Down
10 changes: 10 additions & 0 deletions tests/engine/calcfunctions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# -*- coding: utf-8 -*-
"""Definition of a calculation function used in ``test_calcfunctions.py``."""
from aiida.engine import calcfunction
from aiida.orm import Int


@calcfunction
def add_calcfunction(data):
    """Return ``data`` incremented by two.

    Shares its name with a calcfunction of ``test_calcfunctions`` but has a slightly different implementation.
    """
    shifted = data.value + 2
    return Int(shifted)
22 changes: 13 additions & 9 deletions tests/engine/test_calcfunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,23 +100,27 @@ def test_calcfunction_caching(self):
assert cached.base.links.get_incoming().one().node.uuid == input_node.uuid

def test_calcfunction_caching_change_code(self):
"""Verify that changing the source codde of a calcfunction invalidates any existing cached nodes."""
result_original = self.test_calcfunction(self.default_int)
"""Verify that changing the source code of a calcfunction invalidates any existing cached nodes.

# Intentionally using the same name, to check that caching anyway
# distinguishes between the calcfunctions.
@calcfunction
def add_calcfunction(data): # pylint: disable=redefined-outer-name
"""This calcfunction has a different source code from the one created at the module level."""
return Int(data.value + 2)
The ``add_calcfunction`` of the ``calcfunctions`` module uses the exact same name as the one defined in this
test module, however, it has a slightly different implementation. Note that we have to define the duplicate in
a different module, because we cannot define it in the same module (as the name clashes, on purpose) and we
cannot inline the calcfunction in this test, since inlined process functions are not valid cache sources.
"""
from .calcfunctions import add_calcfunction # pylint: disable=redefined-outer-name

result_original = self.test_calcfunction(self.default_int)

with enable_caching(identifier='*.add_calcfunction'):
result_cached, cached = add_calcfunction.run_get_node(self.default_int)
assert result_original != result_cached
assert not cached.base.caching.is_created_from_cache
assert cached.is_valid_cache

# Test that the imported calcfunction can be cached in principle
result2_cached, cached2 = add_calcfunction.run_get_node(self.default_int)
assert result_original != result2_cached
assert result2_cached != result_original
assert result2_cached == result_cached
assert cached2.base.caching.is_created_from_cache

def test_calcfunction_do_not_store_provenance(self):
Expand Down
73 changes: 72 additions & 1 deletion tests/engine/test_work_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from aiida.common.utils import Capturing
from aiida.engine import ExitCode, Process, ToContext, WorkChain, append_, calcfunction, if_, launch, return_, while_
from aiida.engine.persistence import ObjectLoader
from aiida.manage import get_manager
from aiida.manage import enable_caching, get_manager
from aiida.orm import Bool, Float, Int, Str, load_node


Expand Down Expand Up @@ -146,6 +146,36 @@ def _set_finished(self, function_name):
self.finished_steps[function_name] = True


class CalcFunctionWorkChain(WorkChain):
    """Work chain whose outline steps call calcfunctions that are defined as members of the class itself."""

    @classmethod
    def define(cls, spec):
        """Declare the inputs ``a`` and ``b``, one output per member calcfunction and the two-step outline."""
        super().define(spec)
        for label in ('a', 'b'):
            spec.input(label)
        for label in ('out_member', 'out_static'):
            spec.output(label)
        spec.outline(cls.run_add_member, cls.run_add_static)

    # Declared without ``self`` and without ``staticmethod``: it is invoked through the class in ``run_add_member``.
    @calcfunction
    def add_member(a, b):  # pylint: disable=no-self-argument
        return a + b

    # Wrapped in ``staticmethod`` so that it can be invoked through the instance in ``run_add_static``.
    @staticmethod
    @calcfunction
    def add_static(a, b):
        return a + b

    def run_add_member(self):
        """Call ``add_member`` through the class and attach its result as the ``out_member`` output."""
        total = CalcFunctionWorkChain.add_member(self.inputs.a, self.inputs.b)
        self.out('out_member', total)

    def run_add_static(self):
        """Call ``add_static`` through the instance and attach its result as the ``out_static`` output."""
        total = self.add_static(self.inputs.a, self.inputs.b)
        self.out('out_static', total)


class PotentialFailureWorkChain(WorkChain):
"""Work chain that can finish with a non-zero exit code."""

Expand Down Expand Up @@ -1031,6 +1061,47 @@ def _run_with_checkpoints(wf_class, inputs=None):
proc = run_and_check_success(wf_class, **inputs)
return proc.finished_steps

def test_member_calcfunction(self):
    """Run ``CalcFunctionWorkChain`` and check that both member calcfunctions return the expected sum."""
    outputs, workchain_node = launch.run.get_node(CalcFunctionWorkChain, a=Int(1), b=Int(2))
    assert workchain_node.is_finished_ok
    for label in ('out_member', 'out_static'):
        assert outputs[label] == 3

@pytest.mark.usefixtures('aiida_profile_clean')
def test_member_calcfunction_caching(self):
    """Run ``CalcFunctionWorkChain`` twice, the second time with caching enabled.

    The calcfunctions called by the second run are expected to be created from the cache, with one of the
    calcfunctions called by the first run as their cache source.
    """
    outputs, node = launch.run.get_node(CalcFunctionWorkChain, a=Int(1), b=Int(2))
    assert node.is_finished_ok
    for label in ('out_member', 'out_static'):
        assert outputs[label] == 3

    with enable_caching():
        outputs, cached_node = launch.run.get_node(CalcFunctionWorkChain, a=Int(1), b=Int(2))
        assert cached_node.is_finished_ok
        for label in ('out_member', 'out_static'):
            assert outputs[label] == 3

    # The processes called by the first run are the only acceptable cache sources for the second run.
    source_uuids = [called.uuid for called in node.called]
    for called in cached_node.called:
        assert called.base.caching.is_created_from_cache
        assert called.base.caching.get_cache_source() in source_uuids

def test_member_calcfunction_daemon(self, entry_points, daemon_client, submit_and_await):
    """Submit ``CalcFunctionWorkChain`` to the daemon and check that both member calcfunctions return the sum.

    The work chain is first registered under a testing entry point and the daemon is started before submission.
    """
    entry_points.add(CalcFunctionWorkChain, 'aiida.workflows:testing.calcfunction.workchain')
    daemon_client.start_daemon()

    process_builder = CalcFunctionWorkChain.get_builder()
    process_builder.a = Int(1)
    process_builder.b = Int(2)

    finished = submit_and_await(process_builder)
    assert finished.is_finished_ok
    for label in ('out_member', 'out_static'):
        assert getattr(finished.outputs, label) == 3


@pytest.mark.requires_rmq
class TestWorkChainAbort:
Expand Down