Skip to content

Commit

Permalink
Added missing attributes from Modulesettings/CIPort/COPort to specs
Browse files Browse the repository at this point in the history
  • Loading branch information
rexissimus committed Feb 10, 2015
1 parent a8c1335 commit 4753921
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 56 deletions.
5 changes: 2 additions & 3 deletions vistrails/packages/vtk/generate/parse.py
Expand Up @@ -127,9 +127,8 @@ def is_abstract():
is_algorithm = issubclass(node.klass, vtk.vtkAlgorithm)
module_spec = ModuleSpec(node.name, base_cls_name, "vtk.%s" % node.name,
node.klass.__doc__.decode('latin-1'),
input_ports, output_ports,
cacheable=cacheable,
is_algorithm=is_algorithm)
cacheable, input_ports, output_ports,
is_algorithm)

# FIXME deal with fix_classes, signatureCallable

Expand Down
152 changes: 114 additions & 38 deletions vistrails/packages/vtk/generate/specs.py
Expand Up @@ -78,36 +78,70 @@ class ModuleSpec(object):
This mirrors how the module will look in the vistrails registry
"""

attrs = ["name", "superklass", "docstring", "cacheable"]
def __init__(self, name, superklass, code_ref=None, docstring="",
port_specs=None, output_port_specs=None, cacheable=True):
if port_specs is None:
port_specs = []
# From Modulesettings. See core.modules.config._documentation
ms_attrs = ['name',
'configure_widget',
'constant_widget',
'constant_widgets',
'signature',
'constant_signature',
'color',
'fringe',
'left_fringe',
'right_fringe',
'abstract',
'namespace',
'package_version',
'hide_descriptor']
attrs = ['module_name', # Name of module (can be overridden by modulesettings)
'superklass', # class to inherit from
'code_ref', # reference to wrapped class/method
'docstring', # module __doc__
'cacheable'] # should this module be cached
attrs.extend(ms_attrs)

def __init__(self, module_name, superklass=None, code_ref=None, docstring="",
cacheable=True, input_port_specs=None, output_port_specs=None,
**kwargs):
if input_port_specs is None:
input_port_specs = []
if output_port_specs is None:
output_port_specs = []
self.name = name
self.superklass = superklass # parent module to subclass from
self.code_ref = code_ref # reference to wrapped method/class

self.module_name = module_name
self.superklass = superklass
self.code_ref = code_ref
self.docstring = docstring
self.port_specs = port_specs
self.output_port_specs = output_port_specs
self.cacheable = cacheable

self.input_port_specs = input_port_specs
self.output_port_specs = output_port_specs

for attr in self.ms_attrs:
setattr(self, attr, kwargs.get(attr, None))

self._mixin_class = None
self._mixin_functions = None

def to_xml(self, elt=None):
if elt is None:
elt = ET.Element("moduleSpec")
elt.set("name", self.name)
elt.set("module_name", self.module_name)
elt.set("superclass", self.superklass)
elt.set("code_ref", self.code_ref)
if self.cacheable is False:
elt.set("cacheable", unicode(self.cacheable))
subelt = ET.Element("docstring")
subelt.text = unicode(self.docstring)
elt.append(subelt)
for port_spec in self.port_specs:
subelt = port_spec.to_xml()
if self.cacheable is False:
elt.set("cacheable", unicode(self.cacheable))

for attr in self.ms_attrs:
value = getattr(self, attr)
if value is not None:
elt.set(attr, repr(value))

for input_port_spec in self.input_port_specs:
subelt = input_port_spec.to_xml()
elt.append(subelt)
for port_spec in self.output_port_specs:
subelt = port_spec.to_xml()
Expand All @@ -116,23 +150,30 @@ def to_xml(self, elt=None):

@classmethod
def from_xml(cls, elt):
name = elt.get("name", "")
module_name = elt.get("module_name", "")
superklass = elt.get("superclass", "")
code_ref = elt.get("code_ref", "")
cacheable = ast.literal_eval(elt.get("cacheable", "True"))

kwargs = {}
for attr in cls.ms_attrs:
value = elt.get(attr, None)
if value is not None:
kwargs[attr] = ast.literal_eval(value)

docstring = ""
port_specs = []
input_port_specs = []
output_port_specs = []
for child in elt.getchildren():
if child.tag == "inputPortSpec":
port_specs.append(InputPortSpec.from_xml(child))
input_port_specs.append(InputPortSpec.from_xml(child))
elif child.tag == "outputPortSpec":
output_port_specs.append(OutputPortSpec.from_xml(child))
elif child.tag == "docstring":
if child.text:
docstring = child.text
return cls(name, superklass, code_ref, docstring, port_specs,
output_port_specs, cacheable)
return cls(module_name, superklass, code_ref, docstring, cacheable,
input_port_specs, output_port_specs, **kwargs)

def get_output_port_spec(self, compute_name):
for ps in self.output_port_specs:
Expand All @@ -141,7 +182,7 @@ def get_output_port_spec(self, compute_name):
return None

def get_mixin_name(self):
return self.name + "Mixin"
return self.module_name + "Mixin"

def has_mixin(self):
if self._mixin_class is None:
Expand Down Expand Up @@ -175,21 +216,29 @@ def get_compute_after(self):
def get_init(self):
return self.get_mixin_function("__init__")

def get_module_settings(self):
""" Returns modulesettings dict
"""
attrs = {}
for attr in self.ms_attrs:
value = getattr(self, attr)
if value is not None:
attrs[attr] = value
return attrs

class VTKModuleSpec(ModuleSpec):
""" Represents specification of a vtk module
Adds attribute is_algorithm
"""

attrs = ["superklass"]
attrs.extend(ModuleSpec.attrs)

def __init__(self, name, superklass, code_ref, docstring="", port_specs=None,
output_port_specs=None, cacheable=True,
is_algorithm=False):
ModuleSpec.__init__(self, name, superklass, code_ref, docstring,
port_specs, output_port_specs,
cacheable)
def __init__(self, module_name, superklass=None, code_ref=None, docstring="",
cacheable=True, input_port_specs=None, output_port_specs=None,
is_algorithm=False, **kwargs):
ModuleSpec.__init__(self, module_name, superklass, code_ref, docstring,
cacheable, input_port_specs, output_port_specs,
**kwargs)
self.is_algorithm = is_algorithm

def to_xml(self, elt=None):
Expand All @@ -210,15 +259,20 @@ class PortSpec(object):
"""
xml_name = "portSpec"
# attrs tuple means (default value, [is subelement, [run eval]])
# Subelement: ?
# eval: serialize as string and use eval to get value back
# FIXME: subelement/eval not needed if using json
attrs = {"name": "", # port name
"method_name": "", # method/attribute name
"port_type": None, # type class in vistrails
"port_type": None, # type signature in vistrails
"docstring": ("", True), # documentation
"min_conns": (0, False, True), # set min_conns (1=required)
"max_conns": (-1, False, True), # Set max_conns (default -1)
"min_conns": (0, False, True), # set min_conns (1=required)
"max_conns": (-1, False, True), # Set max_conns (default -1)
"show_port": (False, False, True), # Set not optional (use connection)
"hide": (False, False, True), # hides/disables port (is this needed?)
"sort_key": (-1, False, True), # sort_key
"shape": (None, False, True), # physical shape
"depth": (0, False, True), # expected list depth
"other_params": (None, True, True)} # prepended params used with indexed methods

def __init__(self, arg, **kwargs):
Expand Down Expand Up @@ -386,8 +440,9 @@ class InputPortSpec(PortSpec):
xml_name = "inputPortSpec"
attrs = {"entry_types": (None, True, True),# custom entry type (like enum)
"values": (None, True, True), # values for enums
"labels": (None, True, True), # custom labels on enum values
"defaults": (None, True, True), # default value list
"translations": (None, True, True), # value translating method specified in the mako
"translations": (None, True, True), # value translating method
}
attrs.update(PortSpec.attrs)

Expand All @@ -399,8 +454,20 @@ def get_port_attrs(self):
"""
attrs = {}
if self.name:
attrs["name"] = self.name
if self.port_type:
attrs["signature"] = self.port_type
if self.sort_key != -1:
attrs["sort_key"] = self.sort_key
if self.shape:
attrs["shape"] = self.shape
if self.depth:
attrs["depth"] = self.depth
if self.values:
attrs["values"] = unicode(self.values)
if self.labels:
attrs["labels"] = unicode(self.labels)
if self.entry_types:
attrs["entry_types"] = unicode(self.entry_types)
if self.defaults:
Expand All @@ -409,11 +476,11 @@ def get_port_attrs(self):
attrs["docstring"] = self.docstring
if self.min_conns:
attrs["min_conns"] = self.min_conns
if self.max_conns:
if self.max_conns != -1:
attrs["max_conns"] = self.max_conns
if not self.show_port:
attrs["optional"] = True
return unicode(attrs)
return attrs

class OutputPortSpec(PortSpec):
xml_name = "outputPortSpec"
Expand All @@ -437,15 +504,24 @@ def get_port_attrs(self):
"""
attrs = {}
attrs["name"] = self.name
if self.port_type:
attrs["signature"] = self.port_type
if self.sort_key != -1:
attrs["sort_key"] = self.sort_key
if self.shape:
attrs["shape"] = self.shape
if self.depth:
attrs["depth"] = self.depth
if self.docstring:
attrs["docstring"] = self.docstring
if self.min_conns:
attrs["min_conns"] = self.min_conns
if self.max_conns:
if self.max_conns != -1:
attrs["max_conns"] = self.max_conns
if not self.show_port:
attrs["optional"] = True
return unicode(attrs)
return attrs

#def run():
# specs = SpecList.read_from_xml("mpl_plots_raw.xml")
Expand Down
34 changes: 19 additions & 15 deletions vistrails/packages/vtk/generate/vtk_template.py.mako
Expand Up @@ -4,6 +4,7 @@ import vtk

from vistrails.core import debug
from vistrails.core.modules.vistrails_module import Module, ModuleError
from vistrails.core.modules.config import CIPort, COPort, ModuleSettings

from bases import vtkObjectBase
<%def name="get_translate(t_spec, t_ps)">\
Expand All @@ -23,38 +24,41 @@ def translate_file(f):
return f.name

% for spec in specs.module_specs:
% for ps in spec.port_specs:
% for ps in spec.input_port_specs:
% if ps.translations and type(ps.translations) == dict:
def translate_${spec.name}_${ps.name}(val):
def translate_${spec.module_name}_${ps.name}(val):
translate_dict = ${ps.translations}
return translate_dict[val]
% endif
% endfor
% endfor

% for spec in specs.module_specs:
class ${spec.name}(${spec.superklass}):
class ${spec.module_name}(${spec.superklass}):
"""${spec.docstring}
"""

% if len(spec.get_module_settings()):
_module_settings = ModuleSettings(**${unicode(spec.get_module_settings())})
%endif

_input_ports = [
% for ps in spec.port_specs:
% for ps in spec.input_port_specs:
% if not ps.hide:
("${ps.name}", "${ps.get_port_type()}",
${ps.get_port_attrs()}),
CIPort(**${unicode(ps.get_port_attrs())}),
% endif
% endfor
]

_output_ports = [
("self", "(${spec.name})"),
("self", "(${spec.module_name})"),
% for ps in spec.output_port_specs:
("${ps.name}", "${ps.get_port_type()}",
${ps.get_port_attrs()}),
COPort(**${unicode(ps.get_port_attrs())}),
% endfor
]

set_method_table = {
% for ps in spec.port_specs:
% for ps in spec.input_port_specs:
"${ps.name}": ("${ps.method_name}", ${ps.get_port_shape()}, ${ps.get_other_params()}, ${get_translate(spec, ps)}),
% endfor
}
Expand All @@ -71,14 +75,14 @@ class ${spec.name}(${spec.superklass}):

@staticmethod
def get_set_method_info(port_name):
if port_name in ${spec.name}.set_method_table:
return ${spec.name}.set_method_table[port_name]
if port_name in ${spec.module_name}.set_method_table:
return ${spec.module_name}.set_method_table[port_name]
return ${spec.superklass}.get_set_method_info(port_name)

@staticmethod
def get_get_method_info(port_name):
if port_name in ${spec.name}.get_method_table:
return ${spec.name}.get_method_table[port_name]
if port_name in ${spec.module_name}.get_method_table:
return ${spec.module_name}.get_method_table[port_name]
return ${spec.superklass}.get_get_method_info(port_name)

def compute(self):
Expand Down Expand Up @@ -143,6 +147,6 @@ class ${spec.name}(${spec.superklass}):

_modules = [
% for spec in specs.module_specs:
${spec.name},
${spec.module_name},
% endfor
]

0 comments on commit 4753921

Please sign in to comment.