Skip to content

Commit

Permalink
ENH : special case ufuncs
Browse files Browse the repository at this point in the history
ufuncs (np.sin, np.log, ...) are not actually functions, they are
instances of the ufunc class which are callable.  They do not support
unpacking dictionaries in to keywords so we need to collect (in the
correct order) the args and unpack them from a tuple.
  • Loading branch information
tacaswell committed Dec 9, 2014
1 parent fa7290f commit 2665627
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 16 deletions.
22 changes: 19 additions & 3 deletions vt_config/NSLS-II/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,16 @@
import six
import logging
import sys
import yaml

import importlib
import collections
import os

from vttools import wrap_lib
from vttools.vtmods.import_lists import load_config
from vttools.wrap_lib import AutowrapError

import numpy

logger = logging.getLogger(__name__)

# get modules to import
Expand Down Expand Up @@ -84,7 +87,20 @@ def get_modules():

vtmods = [vtmod for mod in pymods for vtmod in mod.vistrails_modules()]

all_mods = vtmods + vtfuncs # + vtclasses
funcs_to_wrap = list(set([atr.__name__ for atr in
(getattr(numpy, atr_name) for atr_name in dir(numpy)
if not atr_name.startswith('_'))
if callable(atr) and type(atr) is not type]))

numpy_mods = []
for ftw in funcs_to_wrap:
try:
tmp = wrap_lib.wrap_function(ftw, 'numpy')
numpy_mods.append(tmp)
except Exception as e:
print(e)

all_mods = vtmods + vtfuncs + numpy_mods # + vtclasses
if len(all_mods) != len(set(all_mods)):
raise ValueError('Some modules have been imported multiple times.\n'
'Full list: {0}'
Expand Down
97 changes: 84 additions & 13 deletions vttools/wrap_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
from vistrails.core.modules.config import IPort, OPort

from skxray.core import verbosedict

import numpy
logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -596,18 +596,13 @@ def define_output_ports(docstring, short_description_word_count=4):
raise AutowrapError("Returns can not be optional")
the_type = _normalize_type(the_type)

# Trim parameter descriptions for incorporation into vistrails
short_description = _truncate_description(the_description,
short_description_word_count)

if the_type is None:
raise AutowrapError("Malformed type")

for port_name in (_.strip() for _ in the_name.split(',')):
if not port_name:
raise AutowrapError("A Port with no name")
pdict = {'name': port_name,
# 'label': short_description,
'signature': sig_map[the_type]}

output_ports.append(OPort(**pdict))
Expand All @@ -631,6 +626,27 @@ def define_output_ports(docstring, short_description_word_count=4):
def gen_module(input_ports, output_ports, docstring,
module_name, library_func, module_namespace,
dict_port=None):
"""
Parameters
----------
input_ports : list
List of input ports
output_ports : list
List of output ports
docstring : ?
module_name : str
The name of the module (as displayed in vistrails
library_func : callable
The callable object to be wrapped for VisTrails
module_namespace : str
Vistrails namespace to use
dict_port : ?
"""

mandatory = []
optional = []
Expand Down Expand Up @@ -683,7 +699,8 @@ def compute(self):
if hasattr(val, 'value'):
print('name [{0}] has attribute value [{1}]'.format(name, val))
params_dict[name] = val.value

print(library_func.__name__)
print(params_dict)
ret = library_func(**params_dict)
if len(output_ports) == 1:
self.set_output(output_ports[0].name, ret)
Expand All @@ -704,6 +721,52 @@ def compute(self):
return new_class


def gen_module_ufunc(input_ports, output_ports, docstring,
module_name, library_func, module_namespace):

# can't unpack dicts into ufuncs, assume all are
# mandatory
mandatory = input_ports
arg_names = [m.name for m in mandatory]
if len(mandatory) != library_func.nin:
raise ValueError("wrap {} : \n".format(library_func.__name__) +
"the docstring parsing went wrong " +
"ufunc should have {} args".format(library_func.nin) +
" parsing docstring has {}".format(len(mandatory)))

if len(output_ports) != library_func.nout:
raise ValueError("wrap {} : \n".format(library_func.__name__) +
"the docstring parsing went wrong" +
"ufunc should have {} out".format(library_func.nout) +
" parsing docstring has {}".format(len(output_ports)))

def compute(self):
args = list()
for arg_name in arg_names:
args.append(self.get_input(arg_name))

print(library_func.__name__)

ret = library_func(*args)
if len(output_ports) == 1:
self.set_output(output_ports[0].name, ret)
else:
for (out_port, ret_val) in zip(output_ports, ret):
self.set_output(out_port.name, ret_val)

_settings = ModuleSettings(namespace=module_namespace)

new_class = type(str(module_name),
(Module,), {'compute': compute,
'__module__': __name__,
'_settings': _settings,
'__doc__': docstring,
'__name__': module_name,
'_input_ports': input_ports,
'_output_ports': output_ports})
return new_class


def wrap_function(func_name, module_path,
add_input_dict=False, namespace=None):
"""Perform the wrapping of functions into VisTrails modules
Expand Down Expand Up @@ -783,12 +846,20 @@ def wrap_function(func_name, module_path,
dict_port = None

# actually create the VisTrail module
generated_module = gen_module(input_ports=input_ports,
output_ports=output_ports,
docstring=doc_string, module_name=func_name,
module_namespace=namespace,
library_func=func,
dict_port=dict_port)
if not isinstance(func, numpy.ufunc):
generated_module = gen_module(input_ports=input_ports,
output_ports=output_ports,
docstring=doc_string, module_name=func_name,
module_namespace=namespace,
library_func=func,
dict_port=dict_port)

else:
generated_module = gen_module_ufunc(input_ports=input_ports,
output_ports=output_ports,
docstring=doc_string, module_name=func_name,
module_namespace=namespace,
library_func=func)

logger.info('func_name {0}, module_name {1}. Time: {2}'
''.format(func_name, module_path, format(time.time() - t1)))
Expand Down

0 comments on commit 2665627

Please sign in to comment.