Skip to content

Commit

Permalink
WIP : support more strange numpy corner-cases
Browse files Browse the repository at this point in the history
  • Loading branch information
tacaswell committed Dec 7, 2014
1 parent 0d2668f commit fa7290f
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 51 deletions.
6 changes: 4 additions & 2 deletions vttools/tests/test_wrap_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ def test_enum_type():
'np.array', 'np.ndarray', '(N, M, P) array',
'(..., K) array',
'(..., M, N) array_like', '(N, M, P) ndarray',
'(M,) array_like', '(M) array_like', 'MxN array')
'(M,) array_like', '(M) array_like', 'MxN array',
'array_like, shape (M, N)', 'ndarray, float')
matrix_type_strings = (tuple('{}matrix'.format(p)
for p in ('np.', 'numpy.', '')) +
('(N, M) matrix', ))
Expand All @@ -163,7 +164,8 @@ def test_enum_type():

tuple_type_strings = ('tuple'),
seq_type_strings = ('sequence',)
dtype_type_strings = ('dtype', 'dtype like', 'np.dtype', 'numpy.dtype')
dtype_type_strings = ('dtype', 'dtype like', 'np.dtype', 'numpy.dtype',
'data-type')
bool_type_strings = ('bool', 'boolean')
file_type_strings = ('file',)
scalar_type_strings = ('scalar', )
Expand Down
100 changes: 51 additions & 49 deletions vttools/wrap_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,16 @@
import inspect
import importlib
import time
import sys
import logging
import re
from collections import OrderedDict
from numpydoc.docscrape import FunctionDoc, ClassDoc
from numpydoc.docscrape import FunctionDoc, ClassDoc, NumpyDocString
from vistrails.core.modules.vistrails_module import (Module, ModuleSettings,
ModuleError)
from vistrails.core.modules.config import IPort, OPort

from skxray.core import verbosedict

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -165,12 +166,10 @@ def docstring_func(pyobj):
Taken from:
https://github.com/numpy/numpydoc/blob/master/numpydoc/docscrape.py#L94
"""
if inspect.isfunction(pyobj) or inspect.ismethod(pyobj):
try:
return FunctionDoc(pyobj)
else:
raise ValueError("The pyobj input parameter is not a function."
"Your parameter returned {0} from "
"type(pyobj)".format(type(pyobj)))
except TypeError:
return NumpyDocString(pyobj.__doc__)


def _default_vals(pyobj):
Expand Down Expand Up @@ -232,16 +231,17 @@ def _type_optional(type_str):

_OR_REGEX = re.compile(r'\bor\b')
_OF_REGEX = re.compile(r'\bof\b')
_COMMA_REGEX = re.compile(r'\b, ?\b')
_ENUM_RE = re.compile('\{(.*)\}')
_RE_DICT = {
"object": re.compile('^(?i)(any|object)$'),
"array": re.compile('^(?i)(\(?((([A-Z0-9.]+(,|x)? *)+)|, ?)\)?)? *(((np|numpy)\.)?(nd)?array(_|-| )?(like)?)$'), # noqa,
"array": re.compile('^(?i)(\(?((([A-Z0-9.]+(,|x)? *)+)|, ?)\)?)? *(((np|numpy)\.)?(nd)?array(_|-| )?(like)?)(, shape \(([a-zA-Z],? *)+\))?$'), # noqa,
"matrix": re.compile('^(?i)(\((([A-Z0-9.]+,? *){2} ?)\))? *(((np|numpy)\.)?matrix(_|-| )?(like)?)$'), # noqa,
# note these three do not match end so 'list of ... ' matches
"list": re.compile('^(?i)list(-|_| )?(like)?'),
"tuple": re.compile('^(?i)tuple(-|_| )?(like)?'),
"seq": re.compile('^(?i)sequence(-|_| )?(like)?'),
"dtype": re.compile('^(?i)((np|numpy)\.)?dtype(-|_| )?(like)?$'),
"dtype": re.compile('^(?i)((np|numpy)\.)?d(ata-)?type(-|_| )?(like)?$'),
"bool": re.compile('^(?i)bool(ean)?$'),
"file": re.compile('^(?i)file?$'),
"scalar": re.compile('^(?i)scalar?$'),
Expand All @@ -253,7 +253,7 @@ def _type_optional(type_str):
'callable': re.compile('^(?i)(func(tion)?|callable)$'),
}

sig_map = {
sig_map = verbosedict({
'object': 'basic:Variant',
'array': 'basic:Variant',
'matrix': 'basic:Variant',
Expand All @@ -270,7 +270,7 @@ def _type_optional(type_str):
'dict': 'basic:Dictionary',
'str': 'basic:String',
'callable': 'basic:Variant'
}
})


precedence_list = ('list',
Expand Down Expand Up @@ -435,6 +435,10 @@ def _normalize_type(the_type):
if bool(_RE_DICT[n_type].search(the_type)):
return n_type

if _COMMA_REGEX.search(the_type):
left, right = the_type.split(',', 1)
return _type_precedence(left, right)

# of no patterns matched, return None to signal
# failure and let down-stream sort it out.
return None
Expand Down Expand Up @@ -472,14 +476,6 @@ def _type_precedence(left, right):
return left if left_i < right_i else right


def _generate_port_dicts(doc_struct, func):
"""
Process the docstring structure to format the
dictionaries need to
"""
pass


def _enums_equal(left, right):
"""
Compare two lists of enumn and determine if they are equivalent.
Expand All @@ -499,7 +495,7 @@ def _enums_equal(left, right):
'to use.')


def define_input_ports(docstring, func):
def define_input_ports(docstring, func, short_description_word_count=4):
"""Turn the 'Parameters' fields into VisTrails input ports
Parameters
Expand All @@ -517,18 +513,20 @@ def define_input_ports(docstring, func):
List of input_ports (Vistrails type IPort)
"""
input_ports = []
short_description_word_count = 4

default_dict = _default_vals(func)

for (the_name, the_type, the_description) in docstring['Parameters']:
# skip in-place returns
if the_name == 'output':
continue
# parse and normalize
the_type, is_optional = _type_optional(the_type)
the_type, is_enum, enum_list = _enum_type(the_type)
the_type = _normalize_type(the_type)
if the_type is None:
raise AutowrapError("")

# Trim parameter descriptions for incorporation into vistrails
short_description = _truncate_description(the_description,
short_description_word_count)
Expand Down Expand Up @@ -570,12 +568,10 @@ def define_input_ports(docstring, func):
logger.debug('port_param_dict: {0}'.format(pdict))
input_ports.append(IPort(**pdict))

if len(input_ports) == 0:
logger.debug('dir of input_ports[0]: {0}'.format(dir(input_ports[0])))
return input_ports


def define_output_ports(docstring):
def define_output_ports(docstring, short_description_word_count=4):
"""
Turn the 'Returns' fields into VisTrails output ports
Expand All @@ -593,36 +589,42 @@ def define_output_ports(docstring):

output_ports = []

# If the 'Returns' section is included, but does not have any
# parameters listed, then check the 'Parameters' section to see
# whether the output is actually included as an optional input
for (the_name, the_type, the_description) in docstring['Parameters']:
if the_name.lower() == 'output':
the_type = _normalize_type(the_type)
if the_type is None:
# TODO dillify
raise AutowrapError("Malformed type")
output_ports.append(OPort(name=the_name,
signature=sig_map[the_type]))

# now look at the return Returns section
for (the_name, the_type, the_description) in docstring['Returns']:
the_type, is_optional = _type_optional(the_type)
if is_optional:
raise AutowrapError("Returns can not be optional")
the_type = _normalize_type(the_type)

# Trim parameter descriptions for incorporation into vistrails
short_description = _truncate_description(the_description,
short_description_word_count)

if the_type is None:
# TODO dillify
raise AutowrapError("Malformed type")

logger.debug("the_name is {0}. \n\tthe_type is {1}. "
"\n\tthe_description is {2}"
"".format(the_name, the_type, the_description))
try:
output_ports.append(OPort(name=the_name,
signature=sig_map[the_type]))
except ValueError as ve:
logger.error('ValueError raised for Returns parameter with '
'name: {0}\n\ttype: {1}\n\tdescription: {2}'
''.format(the_name, the_type, the_description))
six.reraise(ValueError, ve, sys.exc_info()[2])
for port_name in (_.strip() for _ in the_name.split(',')):
if not port_name:
raise AutowrapError("A Port with no name")
pdict = {'name': port_name,
# 'label': short_description,
'signature': sig_map[the_type]}

output_ports.append(OPort(**pdict))

# some numpy functions lack a Returns section and have and 'output'
# optional input (mostly for in-place operations)
if len(output_ports) < 1:
for (the_name, the_type, the_description) in docstring['Parameters']:
if the_name.lower() == 'output':
the_type, _ = _type_optional(the_type)
the_type = _normalize_type(the_type)
if the_type is None:
# TODO dillify
raise AutowrapError("Malformed type")
output_ports.append(OPort(name=the_name,
signature=sig_map[the_type]))

return output_ports


Expand Down Expand Up @@ -763,7 +765,7 @@ def wrap_function(func_name, module_path,
# if we can get the source, use the whole thing as the
# docstring in vistrails
doc_string = obj_src(func)
except IOError:
except (IOError, TypeError):
# if we can't, just use the docstring
doc_string = func.__doc__
# create the VisTrails input ports
Expand Down

0 comments on commit fa7290f

Please sign in to comment.