Skip to content

Commit

Permalink
Allow the definition of None as default in process functions (#2582)
Browse files Browse the repository at this point in the history
It was impossible to define a process function with a keyword argument
that had `None` as default, because the dynamically created ports always
only specified `orm.Data` as valid types. So the validation of the
default value would fail during the spec definition. Here we detect if
`None` is passed as the default in the function signature, in which case
we define the tuple of `orm.Data` and `type(None)` to be a valid type.
Note that we need to use `type(None)` because the port validation later
on will call `isinstance(input, valid_types)` which will not work if one
of the values in `valid_types` is simply `None`.

The `Process._setup_inputs` had to be adjusted to skip values of `None`
because they cannot be linked to obviously but they can now potentially
be passed as inputs to a `Process`.
  • Loading branch information
sphuber committed Mar 7, 2019
1 parent 838c1dd commit e64a607
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 33 deletions.
88 changes: 56 additions & 32 deletions aiida/backends/tests/engine/test_process_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
from __future__ import print_function
from __future__ import absolute_import

from aiida import orm
from aiida.backends.testbase import AiidaTestCase
from aiida.engine import run, run_get_node, submit, calcfunction, workfunction, Process, ExitCode
from aiida.orm import Int, Str, WorkFunctionNode, CalcFunctionNode
from aiida.orm.nodes.data.bool import get_true_node

DEFAULT_INT = 256
Expand All @@ -36,6 +36,8 @@ class TestProcessFunction(AiidaTestCase):
function would complain as the dummy node class is not recognized as a valid process node.
"""

# pylint: disable=too-many-public-methods

def setUp(self):
super(TestProcessFunction, self).setUp()
self.assertIsNone(Process.current())
Expand All @@ -53,9 +55,15 @@ def function_args(data_a):
return data_a

@workfunction
def function_args_with_default(data_a=Int(DEFAULT_INT)):
def function_args_with_default(data_a=orm.Int(DEFAULT_INT)):
return data_a

@calcfunction
def function_with_none_default(int_a, int_b, int_c=None):
if int_c is not None:
return orm.Int(int_a + int_b + int_c)
return orm.Int(int_a + int_b)

@workfunction
def function_kwargs(**kwargs):
return kwargs
Expand All @@ -67,12 +75,12 @@ def function_args_and_kwargs(data_a, **kwargs):
return result

@workfunction
def function_args_and_default(data_a, data_b=Int(DEFAULT_INT)):
def function_args_and_default(data_a, data_b=orm.Int(DEFAULT_INT)):
return {'data_a': data_a, 'data_b': data_b}

@workfunction
def function_defaults(
data_a=Int(DEFAULT_INT), metadata={
data_a=orm.Int(DEFAULT_INT), metadata={
'label': DEFAULT_LABEL,
'description': DEFAULT_DESCRIPTION
}): # pylint: disable=unused-argument,dangerous-default-value,missing-docstring
Expand All @@ -90,6 +98,7 @@ def function_excepts(exception):
self.function_return_true = function_return_true
self.function_args = function_args
self.function_args_with_default = function_args_with_default
self.function_with_none_default = function_with_none_default
self.function_kwargs = function_kwargs
self.function_args_and_kwargs = function_args_and_kwargs
self.function_args_and_default = function_args_and_default
Expand Down Expand Up @@ -125,9 +134,9 @@ def test_source_code_attributes(self):

@calcfunction
def test_process_function(data):
return {'result': Int(data.value + 1)}
return {'result': orm.Int(data.value + 1)}

_, node = test_process_function.run_get_node(data=Int(5))
_, node = test_process_function.run_get_node(data=orm.Int(5))

# Read the source file of the calculation function that should be stored in the repository
function_source_code = node.get_function_source_code().split('\n')
Expand Down Expand Up @@ -157,25 +166,39 @@ def test_function_args(self):
with self.assertRaises(ValueError):
result = self.function_args() # pylint: disable=no-value-for-parameter

result = self.function_args(data_a=Int(arg))
self.assertTrue(isinstance(result, Int))
result = self.function_args(data_a=orm.Int(arg))
self.assertTrue(isinstance(result, orm.Int))
self.assertEqual(result, arg)

def test_function_args_with_default(self):
"""Simple process function that defines a single argument with a default."""
arg = 1

result = self.function_args_with_default()
self.assertTrue(isinstance(result, Int))
self.assertEqual(result, Int(DEFAULT_INT))
self.assertTrue(isinstance(result, orm.Int))
self.assertEqual(result, orm.Int(DEFAULT_INT))

result = self.function_args_with_default(data_a=Int(arg))
self.assertTrue(isinstance(result, Int))
result = self.function_args_with_default(data_a=orm.Int(arg))
self.assertTrue(isinstance(result, orm.Int))
self.assertEqual(result, arg)

def test_function_with_none_default(self):
"""Simple process function that defines a keyword with `None` as default value."""
int_a = orm.Int(1)
int_b = orm.Int(2)
int_c = orm.Int(3)

result = self.function_with_none_default(int_a, int_b)
self.assertTrue(isinstance(result, orm.Int))
self.assertEqual(result, orm.Int(3))

result = self.function_with_none_default(int_a, int_b, int_c)
self.assertTrue(isinstance(result, orm.Int))
self.assertEqual(result, orm.Int(6))

def test_function_kwargs(self):
"""Simple process function that defines keyword arguments."""
kwargs = {'data_a': Int(DEFAULT_INT)}
kwargs = {'data_a': orm.Int(DEFAULT_INT)}

result = self.function_kwargs()
self.assertTrue(isinstance(result, dict))
Expand All @@ -188,8 +211,8 @@ def test_function_kwargs(self):
def test_function_args_and_kwargs(self):
"""Simple process function that defines a positional argument and keyword arguments."""
arg = 1
args = (Int(DEFAULT_INT),)
kwargs = {'data_b': Int(arg)}
args = (orm.Int(DEFAULT_INT),)
kwargs = {'data_b': orm.Int(arg)}

result = self.function_args_and_kwargs(*args)
self.assertTrue(isinstance(result, dict))
Expand All @@ -202,12 +225,12 @@ def test_function_args_and_kwargs(self):
def test_function_args_and_kwargs_default(self):
"""Simple process function that defines a positional argument and an argument with a default."""
arg = 1
args_input_default = (Int(DEFAULT_INT),)
args_input_explicit = (Int(DEFAULT_INT), Int(arg))
args_input_default = (orm.Int(DEFAULT_INT),)
args_input_explicit = (orm.Int(DEFAULT_INT), orm.Int(arg))

result = self.function_args_and_default(*args_input_default)
self.assertTrue(isinstance(result, dict))
self.assertEqual(result, {'data_a': args_input_default[0], 'data_b': Int(DEFAULT_INT)})
self.assertEqual(result, {'data_a': args_input_default[0], 'data_b': orm.Int(DEFAULT_INT)})

result = self.function_args_and_default(*args_input_explicit)
self.assertTrue(isinstance(result, dict))
Expand All @@ -218,13 +241,13 @@ def test_function_args_passing_kwargs(self):
arg = 1

with self.assertRaises(ValueError):
self.function_args(data_a=Int(arg), data_b=Int(arg)) # pylint: disable=unexpected-keyword-arg
self.function_args(data_a=orm.Int(arg), data_b=orm.Int(arg)) # pylint: disable=unexpected-keyword-arg

def test_function_set_label_description(self):
"""Verify that the label and description can be set for all process function variants."""
metadata = {'label': CUSTOM_LABEL, 'description': CUSTOM_DESCRIPTION}

_, node = self.function_args.run_get_node(data_a=Int(DEFAULT_INT), metadata=metadata)
_, node = self.function_args.run_get_node(data_a=orm.Int(DEFAULT_INT), metadata=metadata)
self.assertEqual(node.label, CUSTOM_LABEL)
self.assertEqual(node.description, CUSTOM_DESCRIPTION)

Expand All @@ -236,19 +259,19 @@ def test_function_set_label_description(self):
self.assertEqual(node.label, CUSTOM_LABEL)
self.assertEqual(node.description, CUSTOM_DESCRIPTION)

_, node = self.function_args_and_kwargs.run_get_node(data_a=Int(DEFAULT_INT), metadata=metadata)
_, node = self.function_args_and_kwargs.run_get_node(data_a=orm.Int(DEFAULT_INT), metadata=metadata)
self.assertEqual(node.label, CUSTOM_LABEL)
self.assertEqual(node.description, CUSTOM_DESCRIPTION)

_, node = self.function_args_and_default.run_get_node(data_a=Int(DEFAULT_INT), metadata=metadata)
_, node = self.function_args_and_default.run_get_node(data_a=orm.Int(DEFAULT_INT), metadata=metadata)
self.assertEqual(node.label, CUSTOM_LABEL)
self.assertEqual(node.description, CUSTOM_DESCRIPTION)

def test_function_defaults(self):
"""Verify that a process function can define a default label and description but can be overriden."""
metadata = {'label': CUSTOM_LABEL, 'description': CUSTOM_DESCRIPTION}

_, node = self.function_defaults.run_get_node(data_a=Int(DEFAULT_INT))
_, node = self.function_defaults.run_get_node(data_a=orm.Int(DEFAULT_INT))
self.assertEqual(node.label, DEFAULT_LABEL)
self.assertEqual(node.description, DEFAULT_DESCRIPTION)

Expand All @@ -264,7 +287,7 @@ def test_launchers(self):
result, node = run_get_node(self.function_return_true)
self.assertTrue(result)
self.assertEqual(result, get_true_node())
self.assertTrue(isinstance(node, CalcFunctionNode))
self.assertTrue(isinstance(node, orm.CalcFunctionNode))

with self.assertRaises(AssertionError):
submit(self.function_return_true)
Expand All @@ -276,7 +299,8 @@ def test_return_exit_code(self):
exit_status = 418
exit_message = 'I am a teapot'

_, node = self.function_exit_code.run_get_node(exit_status=Int(exit_status), exit_message=Str(exit_message))
message = orm.Str(exit_message)
_, node = self.function_exit_code.run_get_node(exit_status=orm.Int(exit_status), exit_message=message)

self.assertTrue(node.is_finished)
self.assertFalse(node.is_finished_ok)
Expand All @@ -288,7 +312,7 @@ def test_normal_exception(self):
exception = 'This process function excepted'

with self.assertRaises(RuntimeError):
_, node = self.function_excepts.run_get_node(exception=Str(exception))
_, node = self.function_excepts.run_get_node(exception=orm.Str(exception))
self.assertTrue(node.is_excepted)
self.assertEqual(node.exception, exception)

Expand All @@ -307,23 +331,23 @@ def mul(data_a, data_b):
def add_mul_wf(data_a, data_b, data_c):
return mul(add(data_a, data_b), data_c)

result, node = add_mul_wf.run_get_node(Int(3), Int(4), Int(5))
result, node = add_mul_wf.run_get_node(orm.Int(3), orm.Int(4), orm.Int(5))

self.assertEqual(result, (3 + 4) * 5)
self.assertIsInstance(node, WorkFunctionNode)
self.assertIsInstance(node, orm.WorkFunctionNode)

def test_hashes(self):
"""Test that the hashes generated for identical process functions with identical inputs are the same."""
_, node1 = self.function_return_input.run_get_node(data=Int(2))
_, node2 = self.function_return_input.run_get_node(data=Int(2))
_, node1 = self.function_return_input.run_get_node(data=orm.Int(2))
_, node2 = self.function_return_input.run_get_node(data=orm.Int(2))
self.assertEqual(node1.get_hash(), node1.get_extra('_aiida_hash'))
self.assertEqual(node2.get_hash(), node2.get_extra('_aiida_hash'))
self.assertEqual(node1.get_hash(), node2.get_hash())

def test_hashes_different(self):
"""Test that the hashes generated for identical process functions with different inputs are the different."""
_, node1 = self.function_return_input.run_get_node(data=Int(2))
_, node2 = self.function_return_input.run_get_node(data=Int(3))
_, node1 = self.function_return_input.run_get_node(data=orm.Int(2))
_, node2 = self.function_return_input.run_get_node(data=orm.Int(3))
self.assertEqual(node1.get_hash(), node1.get_extra('_aiida_hash'))
self.assertEqual(node2.get_hash(), node2.get_extra('_aiida_hash'))
self.assertNotEqual(node1.get_hash(), node2.get_hash())
11 changes: 10 additions & 1 deletion aiida/engine/processes/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,10 +192,19 @@ def _define(cls, spec):
if i >= first_default_pos:
default = defaults[i - first_default_pos]

# If the keyword was already specified, simply override the default
if spec.has_input(arg):
spec.inputs[arg].default = default
else:
spec.input(arg, valid_type=orm.Data, default=default)
# If the default is `None` make sure that the port also accepts a `NoneType`
# Note that we cannot use `None` because the validation will call `isinstance` which does not work
# when passing `None`, but it does work with `NoneType` which is returned by calling `type(None)`
if default is None:
valid_type = (orm.Data, type(None))
else:
valid_type = (orm.Data,)

spec.input(arg, valid_type=valid_type, default=default)

# If the function support kwargs then allow dynamic inputs, otherwise disallow
spec.inputs.dynamic = keywords is not None
Expand Down
4 changes: 4 additions & 0 deletions aiida/engine/processes/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,10 @@ def _setup_inputs(self):

for name, node in self._flat_inputs().items():

# Certain processes allow to specify ports with `None` as acceptable values
if node is None:
continue

# Special exception: set computer if node is a remote Code and our node does not yet have a computer set
if isinstance(node, Code) and not node.is_local() and not self.node.computer:
self.node.computer = node.get_remote_computer()
Expand Down

0 comments on commit e64a607

Please sign in to comment.