Skip to content

Commit

Permalink
Merge branch 'slurmoverssh'
Browse files Browse the repository at this point in the history
  • Loading branch information
francoislaurent committed Nov 2, 2020
2 parents 9aa35cf + 0538485 commit 27b3bd8
Show file tree
Hide file tree
Showing 8 changed files with 213 additions and 21 deletions.
2 changes: 1 addition & 1 deletion scripts/tramway-browse
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def main():
parser = argparse.ArgumentParser(prog='tramway-browse',
description='Browse TRamWAy-generated .rwa files')
parser.add_argument('files', nargs='*', help='for example: *.rwa or */*.rwa')
parser.add_argument('--browser', default='Firefox', choices=['Firefox','Chrome','Edge','Ie','Opera','Safari','WebKitGTK','PhantomJS'])
parser.add_argument('--browser', default='Firefox', choices=['Firefox','Chrome','Edge','Ie','Opera','Safari','WebKitGTK'])
print(browse(**parser.parse_args().__dict__))

if __name__ == '__main__':
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
extras_require = {
'animate': ['opencv-python', 'scikit-image', 'tqdm'],
'roi': ['polytope', 'cvxopt', 'tqdm'],
'webui': ['bokeh', 'selenium']}
'webui': ['bokeh>=2.0.2', 'selenium']}
setup_requires = ['pytest-runner']
tests_require = ['pytest']

Expand All @@ -22,7 +22,7 @@

setup(
name = 'tramway',
version = '0.5-beta2',
version = '0.5-beta3',
description = 'TRamWAy',
long_description = long_description,
url = 'https://github.com/DecBayComp/TRamWAy',
Expand Down
74 changes: 69 additions & 5 deletions tramway/analyzer/env/environments.py
Original file line number Diff line number Diff line change
Expand Up @@ -1017,18 +1017,43 @@ def submit_jobs(self):
self.pending_jobs = []
def wait_for_job_completion(self):
try:
cmd = 'squeue -j {} -h -o "%.2t %.10M %R"'.format(self.job_id)
cmd = 'squeue -j {} -h -o "%K %t %M %R"'.format(self.job_id)
while True:
time.sleep(self.refresh_interval)
out, err = self.ssh.exec(cmd, shell=True)
if err:
err = err.rstrip()
if err == 'slurm_load_jobs error: Invalid job id specified':
# complete
break
self.logger.error(err.rstrip())
elif out:
out = out.splitlines()[0]
out = out.split()
status, time_used, reason = out[0], out[1], ' '.join(out[2:])
self.logger.info('status: {} time used: {} reason: {}'.format(status, time_used, reason))
# parse and print progress info
out = out.splitlines()
try:
start, stop = out[0].split()[0].split('-')
except ValueError:
raise RuntimeError('unexpected squeue message: \n'+'\n'.join(out))
stop = stop.split('%')[0]
start, stop = int(start), int(stop)
total = stop
pending = stop - start
running = 0
other = 0
for out in out[1:]:
out = out.split()
array_ix, status, time_used = int(out[0]), out[1], out[2]
reason = ' '.join(out[3:])
if status == 'R':
running += 1
else:
other += 1
#self.logger.debug(task: {:d} status: {} time used: {} reason: {}'.format(array_ix, status, time_used, reason))
self.logger.info('tasks:\t{} done,\t{} running,\t{} pending{}'.format(
total-pending-running-other, running, pending,
',\t{} in abnormal state'.format(other) if other else ''))
else:
# complete
break
except:
self.logger.info('killing jobs with: scancel '+self.job_id)
Expand All @@ -1039,6 +1064,45 @@ def collect_results(self, stage_index=None):
def delete_temporary_data(self):
RemoteHost.delete_temporary_data(self)
Slurm.delete_temporary_data(self)
def resume(self, log=None, wd=None, stage_index=None, job_id=None):
    """
    Parses log output of the disconnected instance, looks for the current stage index,
    and tries to collect the resulting files.
    This completes the current stage only. Further stages are not run.

    The log is parsed only if any of `wd`, `stage_index` or `job_id` is ``None``.
    It is read from the last line backward, looking in turn for:

    * the ``Submitted batch job <job_id>`` line,
    * the preceding ``running: sbatch <script>`` line; the directory of the
      script is taken as the working directory,
    * the preceding line that features a ``--stage-index=<i>[,<j>...]`` option.

    Arguments:

        log (str): log output of the disconnected instance;
            if ``None`` and parsing is required, it is requested on the standard input.

        wd (str): remote working directory.

        stage_index (list of int): indices of the stage(s) to be completed.

        job_id (str): identifier of the submitted Slurm job.

    """
    if wd is None or stage_index is None or job_id is None:
        if log is None:
            # NOTE(review): input() returns at the first newline, so a pasted
            # multi-line log is likely truncated -- passing `log` explicitly is safer
            log = input('please copy-paste below the log output of the disconnected instance\n(job progress information can be omitted):\n')
        stage_option = ' --stage-index='
        sbatch_prefix = 'running: sbatch '
        submitted_prefix = 'Submitted batch job '
        for line in reversed(log.splitlines()):
            if wd:
                # last step: look for the command line that mentions the stage index
                opt = line.find(stage_option)
                if opt < 0:
                    continue
                tokens = line[opt+len(stage_option):].split()
                if not tokens:
                    continue
                stage_index = [ int(s) for s in tokens[0].split(',') ]
                break
            elif job_id:
                # second step: the job script lies in the working directory;
                # lines other than the sbatch command (e.g. blank lines) are skipped
                if line.startswith(sbatch_prefix):
                    script = line[len(sbatch_prefix):].rstrip()
                    wd = '/'.join(script.split('/')[:-1])
            elif line.startswith(submitted_prefix):
                # first step: the most recently submitted job
                job_id = line[len(submitted_prefix):].rstrip()
    if stage_index:
        self.setup(sys.executable)
        assert self.submit_side
        self.delete_temporary_data() # undo wd creation during setup
        self.working_directory = wd
        self.job_id = job_id
        self.logger.info('trying to complete stage(s): '+', '.join(str(i) for i in stage_index))
        self.wait_for_job_completion()
        self.collect_results(stage_index=stage_index)
    else:
        self.logger.info('cannot identify an execution point where to resume from')


Environment.register(SlurmOverSSH)
Expand Down
17 changes: 17 additions & 0 deletions tramway/analyzer/pipeline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,23 @@ def run(self):
def add_collectible(self, filepath):
    """Add `filepath` to the collectibles of the `env` attribute."""
    collectibles = self.env.collectibles
    collectibles.add(filepath)

def resume(self, **kwargs):
    """
    Looks for orphaned remote jobs and collects the generated files.
    Works as a replacement for the :meth:`run` method to recover
    after connection loss.
    Recovery procedures featured by the `env` backend may fail or recover
    some of the generated files only.
    See also the *resume* method of the `env` attribute, if available.

    Keyword arguments are passed as is to the *resume* method of the `env` attribute.
    If `env` does not feature any such method, an error message is logged instead.
    """
    # look the method up first, so that an AttributeError raised *inside*
    # the recovery procedure is not mistaken for a missing procedure
    recover = getattr(self.env, 'resume', None)
    if callable(recover):
        recover(**kwargs)
    else:
        self.logger.error('no recovery procedure available')


__all__ = ['Pipeline']

6 changes: 5 additions & 1 deletion tramway/analyzer/spt_data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from tramway.core.xyt import load_xyt, load_mat, discard_static_trajectories
from tramway.core.analyses.auto import Analyses, AutosaveCapable
from tramway.core.hdf5.store import load_rwa
from tramway.core.exceptions import RWAFileException
from math import sqrt
import numpy as np
import pandas as pd
Expand Down Expand Up @@ -646,7 +647,10 @@ def set_analyses(self, tree):
self._analyses = tree
def load(self):
# ~ expansion is no longer necessary from rwa-python==0.8.4
self.analyses = load_rwa(os.path.expanduser(self.filepath), lazy=True)
try:
self.analyses = load_rwa(os.path.expanduser(self.filepath), lazy=True)
except KeyError as e:
raise RWAFileException(self.filepath, e) from None
self._trigger_discard_static_trajectories()
self._trigger_reset_origin()

Expand Down
10 changes: 10 additions & 0 deletions tramway/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,13 @@ def __str__(self):
class MisplacedAttributeWarning(UserWarning):
pass

class RWAFileException(IOError):
    """
    Raised when no analysis tree can be found in a file.

    Attributes:

        filepath: path of the faulty file, if known; any object that can be
            converted to `str` (e.g. a path-like object) is suitable.

        exc: underlying exception, if any.

    """
    def __init__(self, filepath=None, exc=None):
        self.filepath = filepath
        self.exc = exc
    def __str__(self):
        if self.filepath is None:
            return 'cannot find any analysis tree'
        # format instead of concatenating, so that path-like objects
        # do not raise a TypeError while the error is being reported
        return 'cannot find any analysis tree in file: {}'.format(self.filepath)

107 changes: 97 additions & 10 deletions tramway/plot/bokeh/analyzer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@

from tramway.core.analyses.browser import AnalysisBrowser
from tramway.core.exceptions import RWAFileException
from tramway.tessellation.time import TimeLattice
from .map import *
import sys
import os
import numpy as np
import pandas as pd
import time
import traceback
Expand Down Expand Up @@ -67,6 +70,19 @@ def select_mapping(self, mapping_label):
self.release_mapping()
self.current_analyses.select_child(mapping_label)
self.current_mapping = self.current_analyses.artefact
if self.has_time_segments():
t = self.current_sampling.tessellation
m = self.current_mapping
self.clim = {}
for feature in m.features:
_m = m[feature].values
if _m.shape[1:]:
if 1<_m.shape[1]:
_m = np.sqrt(np.sum(_m*_m, axis=1))
else:
_m = _m.flatten()
self.clim[feature] = [_m.min(), _m.max()]
self.current_mapping = { feature: t.split_segments(m[feature]) for feature in m.features }
return self.current_mapping
def release_mapping(self, release_feature=True):
if self.current_feature is not None and release_feature:
Expand All @@ -75,11 +91,21 @@ def release_mapping(self, release_feature=True):
self.current_mapping = None
@property
def features(self):
return list(self.current_mapping.features)
assert self.current_mapping is not None
if self.has_time_segments():
return list(self.current_mapping.keys())
else:
return list(self.current_mapping.features)
def select_feature(self, feature):
    """Record `feature` as the current feature."""
    self.current_feature = feature
def release_feature(self):
    """Reset the current feature to ``None``."""
    self.current_feature = None
def has_time_segments(self):
    """
    Tell whether the tessellation of the current sampling is a `TimeLattice`.

    A sampling must have been selected beforehand.
    """
    sampling = self.current_sampling
    assert sampling is not None
    tessellation = sampling.tessellation
    return isinstance(tessellation, TimeLattice)
def n_time_segments(self):
    """
    Return the length of the `time_lattice` attribute of the current tessellation.

    The current sampling must have time segments (see `has_time_segments`).
    """
    assert self.has_time_segments()
    lattice = self.current_sampling.tessellation.time_lattice
    return len(lattice)

class Controller(object):
"""
Expand Down Expand Up @@ -120,7 +146,15 @@ def load_source(self, source_name):
if source_name == 'None':
#self.unload_source()
return
self.model.select_spt_data(source_name)
try:
self.model.select_spt_data(source_name)
except RWAFileException:
import traceback
traceback.print_exc()
self.source_dropdown.options = [ src for src in self.source_dropdown.options if src != source_name ]
if not (self.source_dropdown.options and (self.source_dropdown.options[1:] or self.source_dropdown.options[0] != 'None')):
self.source_dropdown.disabled = True
return
self.model.sampling_labels = list(self.model.current_analyses.labels())
options = self.model.sampling_labels
self.sampling_dropdown.options = ['None'] + options if options[1:] else options
Expand All @@ -146,6 +180,10 @@ def load_sampling(self, sampling_label):
options = self.model.mapping_labels
self.mapping_dropdown.options = ['None'] + options if options[1:] else options
self.mapping_dropdown.disabled = False
# time segmentation support
if self.model.has_time_segments():
self.time_slider.update(start=1, end=self.model.n_time_segments())
#
if not self.model.mapping_labels[1:]:
self.load_mapping(self.model.mapping_labels[0])
def unload_sampling(self):
Expand Down Expand Up @@ -181,10 +219,10 @@ def load_feature(self, feature):
return
self.unset_export_status('figure')
self.model.select_feature(feature)
if self.model.has_time_segments():
self.enable_time_view()
self.draw_map(feature)
self.draw_trajectories()
if 1 < self.model.analyzer.time.n_time_segments(self.model.current_sampling):
self.enable_time_view()
self.enable_space_view()
self.enable_side_panel()
finally:
Expand All @@ -200,6 +238,18 @@ def unload_feature(self):
print(self.colorbar_figure.renderers)
self.colorbar_figure.renderers = []
self.feature_dropdown.value = 'None'
def refresh_map(self):
    """
    Redraw the map and the trajectories for the current feature.

    Does nothing if no feature is selected.
    Document updates are held while drawing, and released in any case.
    """
    feature = self.model.current_feature
    if feature is not None and feature != 'None':
        document = curdoc()
        document.hold()
        try:
            # the figure changes, so a previous export is no longer up to date
            self.unset_export_status('figure')
            self.draw_map(feature)
            self.draw_trajectories()
        finally:
            document.unhold()
def make_main_view(self):
"""
Makes the main view `browse_maps` adds as document root.
Expand All @@ -214,6 +264,7 @@ def make_main_view(self):
return column(main_view)
def make_time_view(self):
    """
    Create the time slider, initially disabled, and return it.

    The map is refreshed on changes of the throttled slider value.
    """
    slider = Slider(disabled=True, start=0, end=1, step=1)
    self.time_slider = slider
    slider.on_change('value_throttled', lambda attr, old, new: self.refresh_map())
    return slider
def make_space_view(self):
self.main_figure = f = figure(disabled=True, toolbar_location=None, active_drag=None,
Expand All @@ -223,7 +274,7 @@ def make_space_view(self):
min_border=0, outline_line_color=None, title_location='right', plot_width=112)
f.background_fill_color = f.border_fill_color = None
f.title.align = 'center'
f.visible = False
#f.visible = False
self.overlaying_markers = CheckboxGroup(disabled=True, labels=['Localizations','Trajectories'], active=[])
def _update(attr, old, new):
if 0 in old and 0 not in new:
Expand All @@ -242,16 +293,17 @@ def disable_space_view(self):
self.colorbar_figure.disabled = True
self.overlaying_markers.disabled = True
def enable_space_view(self):
self.colorbar_figure.visible = True
#self.colorbar_figure.visible = True
self.main_figure.disabled = False
self.colorbar_figure.disabled = False
self.overlaying_markers.disabled = False
def disable_time_view(self):
    """Disable the time slider."""
    slider = self.time_slider
    slider.disabled = True
def enable_time_view(self):
    """
    Enable the time slider and set its value to 1.

    The slider range is expected to start at a positive value.
    """
    slider = self.time_slider
    slider.disabled = False
    assert 0 < slider.start
    slider.value = 1
def draw_map(self, feature):
# TODO: support for time segments
kwargs = self.map_kwargs
if kwargs.get('unit', None) == 'std':
kwargs = dict(kwargs)
Expand All @@ -268,16 +320,37 @@ def draw_map(self, feature):
kwargs['unit'] = unit.get(feature, None)
if self.main_figure.renderers:
self.main_figure.renderers = []
_cells = self.model.current_sampling
_map = self.model.current_mapping[feature]
scalar_map_2d(self.model.current_sampling, _map,
if self.model.has_time_segments():
_cells = _cells.tessellation.split_segments(_cells)
_seg = self.time_slider.value
if _seg is None:
import warnings
warnings.warn('could not read time slider value', RuntimeWarning)
_seg = 0
else:
_seg -= 1
_cells = _cells[_seg]
_map = _map[_seg]
kwargs['clim'] = self.model.clim[feature]
scalar_map_2d(_cells, _map,
figure=self.main_figure, colorbar_figure=self.colorbar_figure, **kwargs)
if _map.shape[1] == 2:
field_map_2d(self.model.current_sampling, _map,
figure=self.main_figure, inferencemap=True)
elif _map.shape[1] != 1:
raise NotImplementedError('neither a scalar map nor a 2D-vector map')
def draw_trajectories(self):
traj_handles = plot_trajectories(self.model.current_sampling.points,
sampling = self.model.current_sampling
if self.model.has_time_segments():
try:
seg = self.time_slider.value-1
except TypeError:
pass
else:
sampling = sampling.tessellation.split_segments(sampling)[seg]
traj_handles = plot_trajectories(sampling.points,
figure=self.main_figure, **self.trajectories_kwargs)
self.trajectory_handles = traj_handles[0::2]
self.location_handles = traj_handles[1::2]
Expand Down Expand Up @@ -354,7 +427,21 @@ def export_figure(self, output_file):
"""
export_kwargs = {}
if self.selenium_webdriver is not None:
export_kwargs['webdriver'] = self.selenium_webdriver()
try:
from importlib import import_module
options = import_module(self.selenium_webdriver.__module__[:-9]+'options')
options = options.Options()
options.headless = True
webdriver = self.selenium_webdriver(options=options)
except (ImportError, AttributeError):
import selenium
if self.selenium_webdriver in (selenium.webdriver.Safari, selenium.webdriver.Edge):
pass
else:
import warnings, traceback
warnings.warn('could not access the webdriver''s options:\n'+traceback.format_exc(), ImportWarning)
webdriver = self.selenium_driver()
export_kwargs['webdriver'] = webdriver
if self.figure_export_width is not None:
export_kwargs['width'] = self.figure_export_width
if self.figure_export_height is not None:
Expand Down

0 comments on commit 27b3bd8

Please sign in to comment.