Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

relative_to port on File module #1057

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
110 changes: 96 additions & 14 deletions vistrails/core/modules/basic_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
pipelines."""
from __future__ import division

import vistrails.core.cache.hasher
from vistrails.core.cache.hasher import Hasher
from vistrails.core.debug import format_exception
from vistrails.core.modules.module_registry import get_module_registry
from vistrails.core.modules.vistrails_module import Module, new_module, \
Expand Down Expand Up @@ -69,7 +69,7 @@

###############################################################################

version = '2.1.1'
version = '2.1.2'
name = 'Basic Modules'
identifier = 'org.vistrails.vistrails.basic'
old_identifiers = ['edu.utah.sci.vistrails.basic']
Expand Down Expand Up @@ -372,7 +372,10 @@ class Path(Constant):
_settings = ModuleSettings(constant_widget=("%s:PathChooserWidget" %
constant_config_path))
_input_ports = [IPort("value", "Path"),
IPort("name", "String", optional=True)]
IPort("name", "String", optional=True),
IPort("relative_to", "String", optional=True,
default='absolute', entry_type='enum',
values=['absolute', 'vtfile', 'vistrails', 'cwd'])]
_output_ports = [OPort("value", "Path")]

@staticmethod
Expand All @@ -390,10 +393,34 @@ def validate(v):
def get_name(self):
n = None
if self.has_input("value"):
n = self.get_input("value").name
n = os.path.abspath(self.get_input("value").name)
if n is None:
self.check_input("name")
n = self.get_input("name")

relative_to = self.get_input('relative_to')
if os.path.isabs(n):
pass
elif relative_to == 'absolute':
raise ModuleError(
self,
"Path is not absolute")
elif relative_to == 'vtfile':
locator = self.moduleInfo['locator']
if not locator.to_url().startswith('file://'):
raise ModuleError(self, "Locator does not refer to a file")
n = os.path.join(os.path.dirname(locator.name), n)
elif relative_to == 'cwd':
n = os.path.abspath(n)
elif relative_to == 'vistrails':
n = os.path.abspath(os.path.join(
vistrails.core.system.vistrails_root_directory(),
n))
else:
raise ModuleError(
self,
"Invalid value for 'relative_to': %s" % relative_to)

return n

def set_results(self, n):
Expand All @@ -407,6 +434,7 @@ def compute(self):
Path.default_value = PathObject('')

def path_parameter_hasher(p):
# Here 'p' is a constant parameter, therefore relative_to is 'absolute'
def get_mtime(path):
t = int(os.path.getmtime(path))
if os.path.isdir(path):
Expand All @@ -416,7 +444,7 @@ def get_mtime(path):
t = max(t, get_mtime(subpath))
return t

h = vistrails.core.cache.hasher.Hasher.parameter_signature(p)
h = Hasher.parameter_signature(p)
try:
# FIXME: This will break with aliases - I don't really care that much
t = get_mtime(p.strValue)
Expand All @@ -427,11 +455,47 @@ def get_mtime(path):
hasher.update(str(t))
return hasher.digest()

def path_module_hasher(pipeline, module, constant_hasher_map):
# If 'value' is set, just hash normally (it's an absolute path)
# If not, look at relative_to: if it's 'absolute' or not set, hash normally
# If it's one of the other values, hash the relevant info. If it's an input
# connection, hash everything just in case, and emit a warning
if ('value' in module.connected_input_ports or
any(f.name == 'value' for f in module.functions)):
return Hasher.module_signature(module, constant_hasher_map)
else:
rel_to = None
for function in module.functions:
if function.name == 'relative_to':
rel_to = function.parameters[0].strValue
break
hasher = sha_hash()
hasher.update(Hasher.module_signature(module,
constant_hasher_map))
if rel_to == 'absolute':
pass
elif rel_to == 'vtfile':
pass # TODO: where do I get this from, at the hasher level!?
elif rel_to == 'cwd':
hasher.update('\0')
hasher.update(os.getcwd())
elif rel_to == 'vistrails':
hasher.update('\0')
hasher.update(vistrails.core.system.vistrails_root_directory())
elif 'relative_to' in module.connected_input_ports:
# No choice, hash everything
hasher.update('\0')
hasher.update(vistrails.core.system.vistrails_root_directory())
hasher.update('\0')
hasher.update(os.getcwd())
return hasher.digest()

class File(Path):
"""File is a VisTrails Module that represents a file stored on a
file system local to the machine where VisTrails is running."""

_settings = ModuleSettings(constant_signature=path_parameter_hasher,
signature=path_module_hasher,
constant_widget=("%s:FileChooserWidget" %
constant_config_path))
_input_ports = [IPort("value", "File"),
Expand All @@ -451,6 +515,7 @@ def compute(self):
class Directory(Path):

_settings = ModuleSettings(constant_signature=path_parameter_hasher,
signature=path_module_hasher,
constant_widget=("%s:DirectoryChooserWidget" %
constant_config_path))
_input_ports = [IPort("value", "Directory"),
Expand Down Expand Up @@ -1383,8 +1448,13 @@ def outputName_remap(old_conn, new_module):
ops.append(('add', new_conn_2))
return ops

def add_cwd(fname, module):
new_function = controller.create_function(module, 'relative_to',
['cwd'])
return [('add', new_function, 'module', module.id)]

module_remap = {'FileSink':
[(None, '1.6', None,
[(None, '1.6', None,
{'dst_port_remap':
{'overrideFile': 'overwrite',
'outputName': outputName_remap},
Expand Down Expand Up @@ -1416,6 +1486,18 @@ def outputName_remap(old_conn, new_module):
[(None, '2.1.1', None, {})],
'Converter':
[(None, '2.1.1', None, {})],
'Path':
[(None, '2.1.2', None,
{'dst_port_remap': {None: add_cwd}})],
'File':
[(None, '2.1.2', None,
{'dst_port_remap': {None: add_cwd}})],
'Directory':
[(None, '2.1.2', None,
{'dst_port_remap': {None: add_cwd}})],
'OutputPath':
[(None, '2.1.2', None,
{'dst_port_remap': {None: add_cwd}})],
}

return UpgradeWorkflowHandler.remap_module(controller, module_id, pipeline,
Expand Down Expand Up @@ -1686,10 +1768,10 @@ def test_full(self):
class TestUnzip(unittest.TestCase):
def test_unzip_file(self):
from vistrails.tests.utils import execute, intercept_result
from vistrails.core.system import vistrails_root_directory
zipfile = os.path.join(vistrails_root_directory(),
'tests', 'resources',
'test_archive.zip')
zipfile = os.path.join(
vistrails.core.system.vistrails_root_directory(),
'tests', 'resources',
'test_archive.zip')
with intercept_result(Unzip, 'file') as outfiles:
self.assertFalse(execute([
('Unzip', 'org.vistrails.vistrails.basic', [
Expand All @@ -1703,10 +1785,10 @@ def test_unzip_file(self):

def test_unzip_all(self):
from vistrails.tests.utils import execute, intercept_result
from vistrails.core.system import vistrails_root_directory
zipfile = os.path.join(vistrails_root_directory(),
'tests', 'resources',
'test_archive.zip')
zipfile = os.path.join(
vistrails.core.system.vistrails_root_directory(),
'tests', 'resources',
'test_archive.zip')
with intercept_result(UnzipDirectory, 'directory') as outdir:
self.assertFalse(execute([
('UnzipDirectory', 'org.vistrails.vistrails.basic', [
Expand Down