Skip to content

Commit

Permalink
Maestro environment + lazyness fix for data loading in spt_data
Browse files Browse the repository at this point in the history
  • Loading branch information
francoislaurent committed Jan 6, 2021
1 parent 04e0a88 commit 732187d
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 29 deletions.
43 changes: 40 additions & 3 deletions tramway/analyzer/env/environments.py
Original file line number Diff line number Diff line change
Expand Up @@ -1432,7 +1432,7 @@ class Tars(SlurmOverSSH):
Designed for server *tars.pasteur.fr*.
By default, makes singularity container *tramway-hpc-200928.sif* run on the remote host.
See also `available_images.rst <https://github.com/DecBayComp/TRamWAy/blob/slurmoverssh/containers/available_images.rst>`_.
See also `available_images.rst <https://github.com/DecBayComp/TRamWAy/blob/master/containers/available_images.rst>`_.
"""
def __init__(self, **kwargs):
SlurmOverSSH.__init__(self, **kwargs)
Expand Down Expand Up @@ -1469,7 +1469,7 @@ class GPULab(SlurmOverSSH):
Designed for server *adm.inception.hubbioit.pasteur.fr*.
By default, makes singularity container *tramway-hpc-200928.sif* run on the remote host.
See also `available_images.rst <https://github.com/DecBayComp/TRamWAy/blob/slurmoverssh/containers/available_images.rst>`_.
See also `available_images.rst <https://github.com/DecBayComp/TRamWAy/blob/master/containers/available_images.rst>`_.
"""
def __init__(self, **kwargs):
SlurmOverSSH.__init__(self, **kwargs)
Expand Down Expand Up @@ -1500,5 +1500,42 @@ def setup(self, *argv):
self.ssh.download_if_missing(self.container, self.container_url, self.logger)


__all__ = ['Environment', 'LocalHost', 'SlurmOverSSH', 'Tars', 'GPULab']
class Maestro(SlurmOverSSH):
"""
Designed for server *maestro.pasteur.fr*.
By default, makes singularity container *tramway-hpc-200928.sif* run on the remote host.
See also `available_images.rst <https://github.com/DecBayComp/TRamWAy/blob/master/containers/available_images.rst>`_.
"""
def __init__(self, **kwargs):
SlurmOverSSH.__init__(self, **kwargs)
self.interpreter = 'singularity exec -H $HOME -B /pasteur tramway-hpc-200928.sif python3.6 -s'
self.remote_dependencies = 'module load singularity'
@property
def username(self):
return None if self.ssh.host is None else self.ssh.host.split('@')[0]
@username.setter
def username(self, name):
self.ssh.host = None if name is None else name+'@maestro.pasteur.fr'
if self.wd is None:
self.wd = '/pasteur/sonic/scratch/users/'+name
@property
def container(self):
parts = self.interpreter.split()
return parts[parts.index('python3.6')-1]
@container.setter
def container(self, path):
parts = self.interpreter.split()
p = parts.index('python3.6')
self.interpreter = ' '.join(parts[:p-1]+[path]+parts[p:])
@property
def container_url(self):
return 'http://dl.pasteur.fr/fop/VsJYgkxP/tramway-hpc-200928.sif'
def setup(self, *argv):
SlurmOverSSH.setup(self, *argv)
if self.submit_side:
self.ssh.download_if_missing(self.container, self.container_url, self.logger)


__all__ = ['Environment', 'LocalHost', 'SlurmOverSSH', 'Tars', 'GPULab', 'Maestro']

41 changes: 16 additions & 25 deletions tramway/analyzer/roi/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ def set_contiguous_time_support_by_count(points, space_bounds, time_window, min_
units = roi.regions.region_to_units(r)

region_weight = sum([ len(u) for u in units.values() ]) # TODO: relative surface area instead
if not group_overlapping_roi:
assert region_weight == 1
threshold = start_stop_min_points * region_weight

times = roi.regions.crop(r, points)['t'].values
Expand Down Expand Up @@ -170,33 +172,22 @@ def set_contiguous_time_support_by_count(points, space_bounds, time_window, min_
start_times, end_times = [], []
for first_t, last_t in zip(first_ts, last_ts):

start_time = times[first_t]-duration
if start_time<times[0]:
start_time = None
min_start_time = times[first_t] - duration
max_start_time = max(times[0], min_start_time)

if last_t+1==times.size:
end_time = None
min_end_time = times[last_t]
if last_t+1 == times.size:
max_end_time = times[-max(1, threshold)] + duration
else:
end_time = times[last_t]

nsegments = None
if start_time is None:
if end_time is None:
start_time = times[0]
end_time = times[-1]
else:
nsegments = np.floor((end_time - times[first_t]) / shift) + 1
start_time = end_time - duration - (nsegments - 1) * shift
elif end_time is None:
end_time = times[-1]
else:
nsegments = np.floor((end_time - start_time - duration) / shift) + 1
total_duration = duration + (nsegments - 1) * shift
time_margin = .5 * (end_time - start_time - total_duration)
start_time += time_margin
end_time -= time_margin
if nsegments is None:
nsegments = np.floor((end_time - start_time - duration) / shift) + 1
max_end_time = min_end_time

max_total_duration = max_end_time - min_start_time
nsegments = np.floor((max_total_duration - duration) / shift) + 1
total_duration = duration + (nsegments - 1) * shift
time_margin = .5 * (max_total_duration - total_duration)
start_time = min(min_start_time + time_margin, max_start_time)
end_time = max_end_time# - time_margin # no need to discard the trailing data

if nsegments<min_segments:
continue

Expand Down
5 changes: 4 additions & 1 deletion tramway/analyzer/spt_data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,9 @@ def reload_from_rwa_files(self, skip_missing=False):
The rwa file that corresponds to an SPT file should be available at the
same path with the *.rwa* extension instead of the SPT file's extension.
This method is known to fail with a :class:`TypeError` exception in cases
where not any matching *.rwa* file can be found.
.. note::
As this operation modifies the SPT data `source` and `filepath` attributes,
Expand Down Expand Up @@ -476,7 +479,7 @@ def __init__(self, df=None, **kwargs):
self._frame_interval_cache = self._localization_error_cache = None
HasROI.__init__(self, **kwargs)
self.analyses = Analyses(df, standard_metadata(), autosave=True)
self.analyses.hooks.append(lambda _: self.commit_cache(autoload=True))
self._analyses.hooks.append(lambda _: self.commit_cache(autoload=True))
@property
def _frame_interval(self):
return self._frame_interval_cache
Expand Down

0 comments on commit 732187d

Please sign in to comment.