-
Notifications
You must be signed in to change notification settings - Fork 521
/
wrappers.py
159 lines (139 loc) · 5.4 KB
/
wrappers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
# -*- coding: utf-8 -*-
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
"""
# changing to temporary directories
>>> tmp = getfixture('tmpdir')
>>> old = tmp.chdir()
"""
from ... import logging
from ..base import (
traits,
DynamicTraitedSpec,
Undefined,
isdefined,
BaseInterfaceInputSpec,
)
from ..io import IOBase, add_traits
from ...utils.filemanip import ensure_list
from ...utils.functions import getsource, create_function_from_source
# Logger named "nipype.interface", obtained through the package's own
# ``logging`` module (relative import ``from ... import logging``).
iflogger = logging.getLogger("nipype.interface")
class FunctionInputSpec(DynamicTraitedSpec, BaseInterfaceInputSpec):
    """Inputs of :class:`Function`.

    ``function_str`` is the only statically declared input; traits for the
    wrapped function's own arguments are attached to the instance at
    runtime by :class:`Function`.
    """

    function_str = traits.Str(desc="code for function", mandatory=True)
class Function(IOBase):
    """Runs arbitrary function as an interface

    Examples
    --------
    >>> func = 'def func(arg1, arg2=5): return arg1 + arg2'
    >>> fi = Function(input_names=['arg1', 'arg2'], output_names=['out'])
    >>> fi.inputs.function_str = func
    >>> res = fi.run(arg1=1)
    >>> res.outputs.out
    6

    """

    input_spec = FunctionInputSpec
    output_spec = DynamicTraitedSpec

    def __init__(
        self,
        input_names=None,
        output_names="out",
        function=None,
        imports=None,
        **inputs
    ):
        """
        Parameters
        ----------
        input_names: single str or list or None
            names corresponding to function inputs
            if ``None``, derive input names from function argument names;
            if no ``function`` is given either, the interface starts with
            no function inputs and picks them up when ``function_str`` is
            assigned later
        output_names: single str or list
            names corresponding to function outputs (default: 'out').
            if list of length > 1, has to match the number of outputs
        function : callable
            callable python object. must be able to execute in an
            isolated namespace (possibly in concert with the ``imports``
            parameter)
        imports : list of strings
            list of import statements that allow the function to execute
            in an otherwise empty namespace

        Raises
        ------
        Exception
            if ``function`` is neither callable nor a str/bytes source
            string, or if its source cannot be retrieved (e.g. it was
            defined interactively)
        """
        super(Function, self).__init__(**inputs)
        if function:
            fninfo = None
            if callable(function):
                try:
                    self.inputs.function_str = getsource(function)
                except IOError:
                    raise Exception(
                        "Interface Function does not accept "
                        "function objects defined interactively "
                        "in a python session"
                    )
                if input_names is None:
                    fninfo = function.__code__
            elif isinstance(function, (str, bytes)):
                self.inputs.function_str = function
                if input_names is None:
                    fninfo = create_function_from_source(function, imports).__code__
            else:
                raise Exception("Unknown type of function")
            if fninfo is not None:
                # Positional/keyword parameter names, in signature order.
                input_names = fninfo.co_varnames[: fninfo.co_argcount]

        # The trait-change handler compiles new source with ``self.imports``,
        # so make sure the attribute exists before the handler is hooked up.
        self.imports = imports
        self.inputs.on_trait_change(self._set_function_string, "function_str")
        # Copy into a fresh list: ``_set_function_string`` extends it, which
        # must never mutate a list the caller passed in. The ``or []`` guards
        # against ``ensure_list`` yielding None (presumably for
        # ``input_names=None`` with no function — confirm against filemanip).
        self._input_names = list(ensure_list(input_names) or [])
        self._output_names = ensure_list(output_names)
        add_traits(self.inputs, list(self._input_names))
        self._out = dict.fromkeys(self._output_names)

    def _set_function_string(self, obj, name, old, new):
        """Handle assignment to ``inputs.function_str``.

        Stores the function's source text back into ``function_str``
        without re-triggering this handler, then adds an input trait for
        every function argument not already registered, in signature order.

        Raises
        ------
        Exception
            if the new value is neither callable nor a str/bytes source string
        """
        if name != "function_str":
            return
        if callable(new):
            function_source = getsource(new)
            fninfo = new.__code__
        elif isinstance(new, (str, bytes)):
            function_source = new
            fninfo = create_function_from_source(new, self.imports).__code__
        else:
            raise Exception("Unknown type of function")
        self.inputs.trait_set(trait_change_notify=False, **{name: function_source})
        # Update input traits. A list comprehension (rather than a set
        # difference) keeps the new names in signature order.
        input_names = fninfo.co_varnames[: fninfo.co_argcount]
        known = set(self._input_names)
        new_names = [arg for arg in input_names if arg not in known]
        add_traits(self.inputs, new_names)
        self._input_names.extend(new_names)

    def _add_output_traits(self, base):
        """Add an untyped (``traits.Any``) trait to *base* for each output
        name, initialise them all to ``Undefined`` without change
        notification, and return *base*."""
        for key in self._output_names:
            base.add_trait(key, traits.Any)
        base.trait_set(
            trait_change_notify=False,
            **dict.fromkeys(self._output_names, Undefined)
        )
        return base

    def _run_interface(self, runtime):
        """Compile ``function_str`` with ``self.imports``, call it with every
        defined input, and store the result(s) in ``self._out``.

        With a single output name the whole return value is stored under
        it; otherwise the return value is indexed positionally, one item
        per output name.

        Raises
        ------
        RuntimeError
            if several outputs are declared and the function returns a
            tuple or list whose length differs from the number of outputs
        """
        function_handle = create_function_from_source(
            self.inputs.function_str, self.imports
        )
        # Pass only inputs that were actually set, so the function's own
        # default values apply to the rest.
        args = {}
        for name in self._input_names:
            value = getattr(self.inputs, name)
            if isdefined(value):
                args[name] = value
        out = function_handle(**args)

        if len(self._output_names) == 1:
            self._out[self._output_names[0]] = out
            return runtime

        if isinstance(out, (tuple, list)) and len(out) != len(self._output_names):
            raise RuntimeError("Mismatch in number of expected outputs")
        for idx, name in enumerate(self._output_names):
            self._out[name] = out[idx]
        return runtime

    def _list_outputs(self):
        """Return the (inherited) ``self._outputs()`` values, filled in with
        the results stored by the last ``_run_interface`` call."""
        outputs = self._outputs().get()
        for key in self._output_names:
            outputs[key] = self._out[key]
        return outputs