diff --git a/plumpy/ports.py b/plumpy/ports.py index 7e5955ab..08ec3932 100644 --- a/plumpy/ports.py +++ b/plumpy/ports.py @@ -7,12 +7,13 @@ import logging import six +from plumpy.utils import is_mutable_property, type_check + if six.PY2: import collections else: import collections.abc as collections -from plumpy.utils import is_mutable_property _LOGGER = logging.getLogger(__name__) UNSPECIFIED = () @@ -525,7 +526,7 @@ def create_port_namespace(self, name, **kwargs): else: return self[port_name] - def absorb(self, port_namespace, exclude=(), include=None, namespace_options={}): + def absorb(self, port_namespace, exclude=None, include=None, namespace_options=None): """Absorb another PortNamespace instance into oneself, including all its mutable properties and ports. Mutable properties of self will be overwritten with those of the port namespace that is to be absorbed. @@ -543,6 +544,16 @@ def absorb(self, port_namespace, exclude=(), include=None, namespace_options={}) if not isinstance(port_namespace, PortNamespace): raise ValueError('port_namespace has to be an instance of PortNamespace') + if exclude is not None and include is not None: + raise ValueError('exclude and include are mutually exclusive') + elif exclude is not None: + type_check(exclude, (list, tuple)) + elif include is not None: + type_check(include, (list, tuple)) + + if namespace_options is None: + namespace_options = {} + # Overload mutable attributes of PortNamespace unless overridden by value in namespace_options for attr in dir(port_namespace): if is_mutable_property(PortNamespace, attr): @@ -557,22 +568,34 @@ def absorb(self, port_namespace, exclude=(), include=None, namespace_options={}) absorbed_ports = [] - for port_name, port in self._filter_ports(list(port_namespace.items()), exclude=exclude, include=include): + for port_name, port in port_namespace.items(): + + # If the current port name occurs in the exclude list, simply skip it entirely, there is no need to consider + # any of the nested ports it might have, even if it is a port namespace + if exclude and port_name in exclude: + continue if isinstance(port, PortNamespace): - # Strip the namespace's name from the exclude and include rules - stripped_exclude = self.strip_namespace(port_name, self.NAMESPACE_SEPARATOR, exclude) - stripped_include = self.strip_namespace(port_name, self.NAMESPACE_SEPARATOR, include) + # If the name does not appear at the start of any of the include rules we continue: + if include and not any([rule.startswith(port_name) for rule in include]): + continue - # Create a new namespace at `port_name` and absorb its ports into it, with the stripped exclude/include. - # Note that we copy the port namespace itself such that we keep all its mutable properties, but then - # reset its ports, because not all ports need to be included depending on the exclude/include rules. - # Instead the copying of the ports is taken care of by the recursive `absorb` call. + # Determine the sub exclude and include rules for this specific namespace + sub_exclude = self.strip_namespace(port_name, self.NAMESPACE_SEPARATOR, exclude) + sub_include = self.strip_namespace(port_name, self.NAMESPACE_SEPARATOR, include) + + # Create a new namespace at `port_name` and copy the original port namespace itself such that we keep + # all its mutable properties, but reset its ports, since those will be taken care of by the recursive + # absorb call that will properly consider the include and exclude rules self[port_name] = copy.copy(port) self[port_name]._ports = {} - self[port_name].absorb(port, stripped_exclude, stripped_include) + self[port_name].absorb(port, sub_exclude, sub_include) else: + # If include rules are specified but the port name does not appear, simply skip it + if include and port_name not in include: + continue + self[port_name] = copy.deepcopy(port) absorbed_ports.append(port_name) @@ -716,15 +739,15 @@ def validate_dynamic_ports(self, port_values, breadcrumbs=()): @staticmethod def strip_namespace(namespace, separator, rules=None): - """Strip the namespace from the given tuple of exclude/include rules. + """Filter given exclude/include rules staring with namespace and strip the first level. For example if the namespace is `base` and the rules are:: - ('base.a', 'relax.base.c', 'd') + ('base.a', 'base.sub.b','relax.base.c', 'd') the function will return:: - ('a', 'relax.base.c', 'd') + ('a', 'sub.c') If the rules are `None`, that is what is returned as well. @@ -733,7 +756,7 @@ def strip_namespace(namespace, separator, rules=None): :param rules: the list or tuple of exclude or include rules to strip :return: `None` if `rules=None` or the list of stripped rules """ - if not rules: + if rules is None: return rules stripped = [] @@ -743,40 +766,12 @@ def strip_namespace(namespace, separator, rules=None): for rule in rules: if rule.startswith(prefix): stripped.append(rule[len(prefix):]) - else: - stripped.append(rule) return stripped - @staticmethod - def _filter_ports(items, exclude, include): - """ - Convenience generator that will filter the items based on its keys and the exclude/include tuples. - The exclude and include tuples are mutually exclusive and only one should be defined. A key in items - will only be yielded if it appears in include or does not appear in exclude, otherwise it will be skipped - - :param items: a mapping of port names and Ports - :param exclude: a tuple of port names that are to be skipped - :param include: a tuple of port names that are the only ones to be yielded - :returns: tuple of port name and Port - """ - if exclude and include is not None: - raise ValueError('exclude and include are mutually exclusive') - - for name, port in items: - if include is not None: - if name not in include: - continue - else: - if name in exclude: - continue - - yield name, port - def breadcrumbs_to_port(breadcrumbs): - """ - Convert breadcrumbs to a string representing the port + """Convert breadcrumbs to a string representing the port :param breadcrumbs: a tuple of the path to the port :type breadcrumbs: typing.Tuple[str] diff --git a/plumpy/process_spec.py b/plumpy/process_spec.py index 196a42ea..cd9e0fd7 100644 --- a/plumpy/process_spec.py +++ b/plumpy/process_spec.py @@ -171,7 +171,7 @@ def has_output(self, name): """ return name in self.outputs - def expose_inputs(self, process_class, namespace=None, exclude=(), include=None, namespace_options={}): + def expose_inputs(self, process_class, namespace=None, exclude=None, include=None, namespace_options={}): """ This method allows one to automatically add the inputs from another Process to this ProcessSpec. The optional namespace argument can be used to group the exposed inputs in a separated PortNamespace. @@ -195,7 +195,7 @@ def expose_inputs(self, process_class, namespace=None, exclude=(), include=None, namespace_options=namespace_options, ) - def expose_outputs(self, process_class, namespace=None, exclude=(), include=None, namespace_options={}): + def expose_outputs(self, process_class, namespace=None, exclude=None, include=None, namespace_options={}): """ This method allows one to automatically add the ouputs from another Process to this ProcessSpec. The optional namespace argument can be used to group the exposed outputs in a separated PortNamespace. diff --git a/test/test_expose.py b/test/test_expose.py index 4a539e39..19118bc6 100644 --- a/test/test_expose.py +++ b/test/test_expose.py @@ -1,8 +1,10 @@ from __future__ import absolute_import -from . import utils + from plumpy.ports import PortNamespace +from plumpy.processes import Process from plumpy.process_spec import ProcessSpec from plumpy.test_utils import NewLoopProcess +from . import utils class TestExposeProcess(utils.TestCaseWithLoop): @@ -10,6 +12,20 @@ class TestExposeProcess(utils.TestCaseWithLoop): def setUp(self): super(TestExposeProcess, self).setUp() + def validator_function(input): + pass + + class BaseNamespaceProcess(NewLoopProcess): + + @classmethod + def define(cls, spec): + super(BaseNamespaceProcess, cls).define(spec) + spec.input('top') + spec.input('namespace.sub_one') + spec.input('namespace.sub_two') + spec.inputs['namespace'].valid_type = (int, float) + spec.inputs['namespace'].validator = validator_function + class BaseProcess(NewLoopProcess): @classmethod @@ -31,13 +47,40 @@ def define(cls, spec): spec.inputs.dynamic = True spec.inputs.valid_type = int + self.BaseNamespaceProcess = BaseNamespaceProcess self.BaseProcess = BaseProcess self.ExposeProcess = ExposeProcess + def check_ports(self, process, namespace, expected_port_names): + """Check the port namespace of a given process inputs spec for existence of set of expected port names.""" + port_namespace = process.spec().inputs + + if namespace is not None: + port_namespace = process.spec().inputs.get_port(namespace) + + self.assertEqual(set(port_namespace.keys()), set(expected_port_names)) + + def check_namespace_properties(self, process_left, namespace_left, process_right, namespace_right): + """Check that all properties, with exception of ports, of two port namespaces are equal.""" + if not issubclass(process_left, Process) or not issubclass(process_right, Process): + raise TypeError('`process_left` and `process_right` should be processes') + + port_namespace_left = process_left.spec().inputs.get_port(namespace_left) + port_namespace_right = process_right.spec().inputs.get_port(namespace_right) + + # Pop the ports in stored in the `_ports` attribute + port_namespace_left.__dict__.pop('_ports', None) + port_namespace_right.__dict__.pop('_ports', None) + + # The `_value_spec` is a nested dictionary so should be compared explicitly separately + value_spec_left = port_namespace_left._value_spec + value_spec_right = port_namespace_right._value_spec + + self.assertEqual(port_namespace_left.__dict__, port_namespace_right.__dict__) + self.assertEqual(value_spec_left.__dict__, value_spec_right.__dict__) + def test_expose_nested_namespace(self): - """ - Test that expose_inputs can create nested namespaces while maintaining own ports - """ + """Test that expose_inputs can create nested namespaces while maintaining own ports.""" inputs = self.ExposeProcess.spec().inputs # Verify that the nested namespaces are present @@ -57,9 +100,7 @@ def test_expose_nested_namespace(self): self.assertEqual(inputs['d'].default, 2) def test_expose_ports(self): - """ - Test that the exposed ports are present and properly deepcopied - """ + """Test that the exposed ports are present and properly deepcopied.""" exposed_inputs = self.ExposeProcess.spec().inputs.get_port('base.name.space') self.assertEqual(len(exposed_inputs), 2) @@ -74,9 +115,7 @@ def test_expose_ports(self): self.assertEqual(exposed_inputs['a'].default, 'a') def test_expose_attributes(self): - """ - Test that the attributes of the exposed PortNamespace are maintained and properly deepcopied - """ + """Test that the attributes of the exposed PortNamespace are maintained and properly deepcopied.""" inputs = self.ExposeProcess.spec().inputs exposed_inputs = self.ExposeProcess.spec().inputs.get_port('base.name.space') @@ -92,9 +131,7 @@ def test_expose_attributes(self): self.assertEqual(inputs.valid_type, int) def test_expose_exclude(self): - """ - Test that the exclude argument of exposed_inputs works correctly and excludes ports from being absorbed - """ + """Test that the exclude argument of exposed_inputs works correctly and excludes ports from being absorbed.""" BaseProcess = self.BaseProcess class ExcludeProcess(NewLoopProcess): @@ -112,9 +149,7 @@ def define(cls, spec): self.assertTrue('a' not in inputs) def test_expose_include(self): - """ - Test that the include argument of exposed_inputs works correctly and includes only specified ports - """ + """Test that the include argument of exposed_inputs works correctly and includes only specified ports.""" BaseProcess = self.BaseProcess class ExcludeProcess(NewLoopProcess): @@ -132,9 +167,7 @@ def define(cls, spec): self.assertTrue('a' not in inputs) def test_expose_exclude_include_mutually_exclusive(self): - """ - Test that passing both exclude and include raises - """ + """Test that passing both exclude and include raises.""" BaseProcess = self.BaseProcess class ExcludeProcess(NewLoopProcess): @@ -149,62 +182,6 @@ def define(cls, spec): with self.assertRaises(ValueError): ExcludeProcess.spec() - def test_expose_nested_exclude(self): - """Test the exclude rules can be nested and are properly unwrapped.""" - BaseProcess = self.BaseProcess - - def test_validator(self): - pass - - class BaseProcess(NewLoopProcess): - - @classmethod - def define(cls, spec): - super(BaseProcess, cls).define(spec) - spec.input('a', valid_type=str, default='a') - spec.input('b', valid_type=str, default='b') - spec.inputs.dynamic = True - spec.inputs.valid_type = str - spec.inputs.help = 'Base Process' - - class SubProcess(NewLoopProcess): - - @classmethod - def define(cls, spec): - super(SubProcess, cls).define(spec) - spec.expose_inputs(BaseProcess, namespace='base') - spec.input('c', valid_type=str, default='c') - spec.input('d', valid_type=str, default='d') - spec.inputs.valid_type = int - spec.inputs.help = 'Sub Process' - spec.inputs.validator = test_validator - - class ExcludeProcess(NewLoopProcess): - - @classmethod - def define(cls, spec): - super(ExcludeProcess, cls).define(spec) - spec.expose_inputs(SubProcess, exclude=('base.a', 'c')) - - inputs = ExcludeProcess.spec().inputs - - # Check that port `base.b` is present but `base.a` is not, as it was excluded - self.assertTrue('a' not in inputs['base']) - self.assertTrue('b' in inputs['base']) - self.assertTrue('c' not in inputs) - self.assertTrue('d' in inputs) - - # Properties of the exposed sub namespaces should have been preserved - self.assertEqual(inputs['base'].dynamic, True) - self.assertEqual(inputs['base'].valid_type, str) - self.assertEqual(inputs['base'].help, 'Base Process') - - # Properties of the top level should match that of the `SubProcess` because it was not exposed in a namespace - self.assertEqual(inputs.dynamic, True) - self.assertEqual(inputs.valid_type, int) - self.assertEqual(inputs.help, 'Sub Process') - self.assertEqual(inputs.validator, test_validator) - def test_expose_ports_top_level(self): """ Verify that exposing a sub process in top level correctly overrides the parent's namespace @@ -371,3 +348,127 @@ def test_expose_ports_namespace_options_non_existent(self): namespace_options={ 'non_existent': None, }) + + def test_expose_nested_include_top_level(self): + """Test the include rules can be nested and are properly unwrapped.""" + BaseNamespaceProcess = self.BaseNamespaceProcess + + class ExposeProcess(NewLoopProcess): + + @classmethod + def define(cls, spec): + super(ExposeProcess, cls).define(spec) + spec.expose_inputs(BaseNamespaceProcess, namespace='base', include=('top',)) + + self.check_ports(ExposeProcess, None, ['base']) + self.check_ports(ExposeProcess, 'base', ['top']) + + def test_expose_nested_include_namespace(self): + """Test the include rules can be nested and are properly unwrapped.""" + BaseNamespaceProcess = self.BaseNamespaceProcess + + class ExposeProcess(NewLoopProcess): + + @classmethod + def define(cls, spec): + super(ExposeProcess, cls).define(spec) + spec.expose_inputs(BaseNamespaceProcess, namespace='base', include=('namespace',)) + + self.check_ports(ExposeProcess, None, ['base']) + self.check_ports(ExposeProcess, 'base', ['namespace']) + self.check_ports(ExposeProcess, 'base.namespace', ['sub_one', 'sub_two']) + self.check_namespace_properties(BaseNamespaceProcess, 'namespace', ExposeProcess, 'base.namespace') + + def test_expose_nested_include_namespace_sub(self): + """Test the include rules can be nested and are properly unwrapped.""" + BaseNamespaceProcess = self.BaseNamespaceProcess + + class ExposeProcess(NewLoopProcess): + + @classmethod + def define(cls, spec): + super(ExposeProcess, cls).define(spec) + spec.expose_inputs(BaseNamespaceProcess, namespace='base', include=('namespace.sub_two',)) + + self.check_ports(ExposeProcess, None, ['base']) + self.check_ports(ExposeProcess, 'base', ['namespace']) + self.check_ports(ExposeProcess, 'base.namespace', ['sub_two']) + self.check_namespace_properties(BaseNamespaceProcess, 'namespace', ExposeProcess, 'base.namespace') + + def test_expose_nested_include_combination(self): + """Test the include rules can be nested and are properly unwrapped.""" + BaseNamespaceProcess = self.BaseNamespaceProcess + + class ExposeProcess(NewLoopProcess): + + @classmethod + def define(cls, spec): + super(ExposeProcess, cls).define(spec) + spec.expose_inputs(BaseNamespaceProcess, namespace='base', include=('namespace.sub_two', 'top')) + + self.check_ports(ExposeProcess, None, ['base']) + self.check_ports(ExposeProcess, 'base', ['namespace', 'top']) + self.check_ports(ExposeProcess, 'base.namespace', ['sub_two']) + self.check_namespace_properties(BaseNamespaceProcess, 'namespace', ExposeProcess, 'base.namespace') + + def test_expose_nested_exclude_top_level(self): + """Test the exclude rules can be nested and are properly unwrapped.""" + BaseNamespaceProcess = self.BaseNamespaceProcess + + class ExposeProcess(NewLoopProcess): + + @classmethod + def define(cls, spec): + super(ExposeProcess, cls).define(spec) + spec.expose_inputs(BaseNamespaceProcess, namespace='base', exclude=('top',)) + + self.check_ports(ExposeProcess, None, ['base']) + self.check_ports(ExposeProcess, 'base', ['namespace']) + self.check_ports(ExposeProcess, 'base.namespace', ['sub_one', 'sub_two']) + self.check_namespace_properties(BaseNamespaceProcess, 'namespace', ExposeProcess, 'base.namespace') + + def test_expose_nested_exclude_namespace(self): + """Test the exclude rules can be nested and are properly unwrapped.""" + BaseNamespaceProcess = self.BaseNamespaceProcess + + class ExposeProcess(NewLoopProcess): + + @classmethod + def define(cls, spec): + super(ExposeProcess, cls).define(spec) + spec.expose_inputs(BaseNamespaceProcess, namespace='base', exclude=('namespace',)) + + self.check_ports(ExposeProcess, None, ['base']) + self.check_ports(ExposeProcess, 'base', ['top']) + + def test_expose_nested_exclude_namespace_sub(self): + """Test the exclude rules can be nested and are properly unwrapped.""" + BaseNamespaceProcess = self.BaseNamespaceProcess + + class ExposeProcess(NewLoopProcess): + + @classmethod + def define(cls, spec): + super(ExposeProcess, cls).define(spec) + spec.expose_inputs(BaseNamespaceProcess, namespace='base', exclude=('namespace.sub_two',)) + + self.check_ports(ExposeProcess, None, ['base']) + self.check_ports(ExposeProcess, 'base', ['top', 'namespace']) + self.check_ports(ExposeProcess, 'base.namespace', ['sub_one']) + self.check_namespace_properties(BaseNamespaceProcess, 'namespace', ExposeProcess, 'base.namespace') + + def test_expose_nested_exclude_combination(self): + """Test the exclude rules can be nested and are properly unwrapped.""" + BaseNamespaceProcess = self.BaseNamespaceProcess + + class ExposeProcess(NewLoopProcess): + + @classmethod + def define(cls, spec): + super(ExposeProcess, cls).define(spec) + spec.expose_inputs(BaseNamespaceProcess, namespace='base', exclude=('namespace.sub_two', 'top')) + + self.check_ports(ExposeProcess, None, ['base']) + self.check_ports(ExposeProcess, 'base', ['namespace']) + self.check_ports(ExposeProcess, 'base.namespace', ['sub_one']) + self.check_namespace_properties(BaseNamespaceProcess, 'namespace', ExposeProcess, 'base.namespace')