Skip to content

Commit

Permalink
bugfix in indexing in as_support_regions + SlurmOverSSH collection ru…
Browse files Browse the repository at this point in the history
…ns with srun
  • Loading branch information
francoislaurent committed Jan 7, 2021
1 parent 732187d commit 88a2913
Show file tree
Hide file tree
Showing 5 changed files with 167 additions and 71 deletions.
139 changes: 91 additions & 48 deletions tramway/analyzer/attribute/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,69 +120,112 @@ def null_index(i):


__all__.append('indexer')
def indexer(i, it, return_index=False):
def indexer(i, it, return_index=False, has_keys=False):
if i is None:
if return_index:
yield from enumerate(it)
if has_keys:
for k in it:
yield k, it[k]
else:
yield from enumerate(it)
else:
yield from it
if has_keys:
for k in it:
yield it[k]
else:
yield from it
elif callable(i):
for j, item in enumerate(it):
if i(j):
if return_index:
yield j, item
else:
yield item
if has_keys:
if return_index:
for j in it:
if i(j):
yield j, item
else:
for j in it:
if i(j):
yield item
else:
if return_index:
for j, item in enumerate(it):
if i(j):
yield j, item
else:
for j, item in enumerate(it):
if i(j):
yield item
elif isinstance(i, (Sequence, np.ndarray)):
# no need for `__getitem__`, but iteration follows the ordering in `i`
postponed = dict()
j, it = -1, iter(it)
for k in i:
try:
while j<k:
if j in i:
postponed[j] = item
j, item = j+1, next(it)
except StopIteration:
raise IndexError('index is out of bounds: {}'.format(k))
if j != k:
try:
item = postponed.pop(k)
except KeyError:
if k < 0:
raise IndexError('negative values are not supported in a sequence of indices')
else:
raise IndexError('duplicate index: {}'.format(k))
if has_keys:
if return_index:
yield k, item
for k in i:
yield k, it[k]
else:
yield item
elif isinstance(i, Set):
i = set(i) # copy and make mutable
for j, item in enumerate(it):
if j in i:
for k in i:
yield it[k]
else:
# no need for `__getitem__`, but iteration follows the ordering in `i`
postponed = dict()
j, it = -1, iter(it)
for k in i:
try:
while j<k:
if j in i:
postponed[j] = item
j, item = j+1, next(it)
except StopIteration:
raise IndexError('index is out of bounds: {}'.format(k))
if j != k:
try:
item = postponed.pop(k)
except KeyError:
if k < 0:
raise IndexError('negative values are not supported in a sequence of indices')
else:
raise IndexError('duplicate index: {}'.format(k))
if return_index:
yield j, item
yield k, item
else:
yield item
i.remove(j)
elif isinstance(i, Set):
i = set(i) # copy and make mutable
if has_keys:
for j in it:
if j in i:
if return_index:
yield j, it[j]
else:
yield it[j]
i.remove(j)
else:
for j, item in enumerate(it):
if j in i:
if return_index:
yield j, item
else:
yield item
i.remove(j)
if i:
raise IndexError(('some indices are out of bounds: '+', '.join(['{}'])).format(*tuple(i)))
elif np.isscalar(i):
it = iter(it)
if i == -1:
if has_keys:
try:
while True:
item = next(it)
except StopIteration:
pass
item = it[i]
except KeyError:
raise IndexError('index is out of bounds: {}'.format(i)) from None
else:
j = -1
try:
while j<i:
j, item = j+1, next(it)
except StopIteration:
raise IndexError('index is out of bounds: {}'.format(i))
it = iter(it)
if i == -1:
try:
while True:
item = next(it)
except StopIteration:
pass
else:
j = -1
try:
while j<i:
j, item = j+1, next(it)
except StopIteration:
raise IndexError('index is out of bounds: {}'.format(i))
if return_index:
yield i, item
else:
Expand Down
26 changes: 25 additions & 1 deletion tramway/analyzer/env/environments.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,8 @@ def _combine_analyses(cls, wd, data_location, logger, *args, directory_mapping={
end_result_files = []
for source in analyses:
logger.info('for source file: {}...'.format(source))
# TODO: if isinstance(spt_data, (StandaloneRWAFile, RWAFiles))
# pass the input rwa file paths to _combine_analyses
rwa_file = os.path.splitext(os.path.normpath(source))[0]+'.rwa'
#logger.info((rwa_file, os.path.isabs(rwa_file), directory_mapping))
if os.path.isabs(rwa_file):
Expand Down Expand Up @@ -1017,6 +1019,9 @@ def remote_dependencies(self):
def remote_dependencies(self, deps):
self._remote_dependencies = deps
@property
def collection_interpreter(self):
    """
    Interpreter command used to run the result-collection script.

    Defaults to the same interpreter as regular jobs; subclasses may
    override it, e.g. to wrap the call in a scheduler command.
    """
    return self.interpreter
@property
def wd_is_available(self):
return self.worker_side
def make_working_directory(self):
Expand Down Expand Up @@ -1133,7 +1138,7 @@ def collect_results(self, _log_pattern, stage_index=None, _parent_cls='Env'):
self.logger.debug(attrs)
cmd = '{}{} {}; rm {}'.format(
'' if self.remote_dependencies is None else self.remote_dependencies+'; ',
self.interpreter, remote_script, remote_script)
self.collection_interpreter, remote_script, remote_script)
out, err = self.ssh.exec(cmd, shell=True, logger=self.logger)
if err:
self.logger.error(err.rstrip())
Expand Down Expand Up @@ -1377,6 +1382,25 @@ def wait_for_job_completion(self):
self.logger.info('killing jobs with: scancel '+self.job_id)
self.ssh.exec('scancel '+self.job_id, shell=True)
raise
@property
def srun_options(self):
    """
    sbatch options that are forwarded to `srun` as well, listed in
    both their short (single-letter) and long spellings.
    """
    return 'p', 'partition', 'q', 'qos'
@property
def collection_interpreter(self):
    """
    Command line for running the collection script through Slurm.

    Prefixes `self.interpreter` with an `srun` call, forwarding the
    subset of `sbatch_options` listed in `srun_options`.  Single-letter
    options are rendered as ``-x value``, longer ones as ``--name=value``;
    string values containing spaces are double-quoted.
    """
    parts = ['srun']
    for name in self.sbatch_options:
        if name not in self.srun_options:
            # option not meaningful to srun; skip it
            continue
        val = self.sbatch_options[name]
        if isinstance(val, str) and ' ' in val:
            val = '"{}"'.format(val)
        if len(name) > 1:
            parts.append('--{}={}'.format(name, val))
        else:
            parts.append('-{} {}'.format(name, val))
    parts.append(self.interpreter)
    return ' '.join(parts)
def collect_results(self, stage_index=None):
# delegate to the generic remote-host implementation;
# '*.out' is the log pattern — presumably the Slurm job log files (TODO confirm)
RemoteHost.collect_results(self, '*.out', stage_index)
collect_results.__doc__ = RemoteHost.collect_results.__doc__
Expand Down
21 changes: 13 additions & 8 deletions tramway/analyzer/roi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@ def crop(self, df=None):
_min,_max = self._bounding_box
if df is None:
df = self._spt_data.dataframe
n_space_cols = len([ col for col in 'xyz' if col in df.columns ])
if n_space_cols < _min.size:
assert _min.size == n_space_cols + 1
df = df[(_min[-1] <= df['t']) & (df['t'] <= _max[-1])]
_min, _max = _min[:-1], _max[:-1]
df = crop(df, np.r_[_min, _max-_min])
return df
@property
Expand Down Expand Up @@ -487,28 +492,25 @@ def self_update(self, op):
def as_support_regions(self, index=None, source=None, return_index=False):
if return_index:
def bear_child(cls, r, *args):
i, r = r
return i, self._bear_child(cls, r, *args)
kwargs = dict(return_index=return_index)
return r, self._bear_child(cls, r, *args)
else:
bear_child = self._bear_child
kwargs = {}
try:
spt_data = self._parent.spt_data
except AttributeError:
# decentralized roi (single source)
if source is not None:
warnings.warn('ignoring argument `source`', helper.IgnoredInputWarning)
spt_data = self._parent
for r in indexer(index, self._collections.regions, **kwargs):
for r, _ in indexer(index, self._collections.regions, has_keys=True, return_index=True):
yield bear_child( SupportRegion, r, self._collections.regions, spt_data )
else:
# roi manager (one set of regions, multiple sources)
if isinstance(spt_data, Initializer):
raise RuntimeError('cannot iterate not-initialized SPT data')
if source is None:
for d in spt_data:
for r in indexer(index, self._collections.regions, **kwargs):
for r, _ in indexer(index, self._collections.regions, has_keys=True, return_index=True):
yield bear_child( SupportRegion, r, self._collections.regions, d )
else:
if callable(source):
Expand All @@ -517,7 +519,7 @@ def bear_child(cls, r, *args):
sfilter = lambda s: s==source
for d in spt_data:
if sfilter(d.source):
for r in indexer(index, self._collections.regions, **kwargs):
for r, _ in indexer(index, self._collections.regions, has_keys=True, return_index=True):
yield bear_child( SupportRegion, r, self._collections.regions, d )
as_support_regions.__doc__ = ROI.as_support_regions.__doc__
def __iter__(self):
Expand All @@ -538,7 +540,10 @@ def get_support_region(self, index, collection=None):
exception.
"""
return single(self.as_support_regions(index=self._collections.regions.unit_to_region(index, collection)))
sr_index = self._collections.regions.unit_to_region(index, collection)
sr = single(self.as_support_regions(index=sr_index))
assert sr._sr_index == sr_index
return sr


class BoundingBoxes(SpecializedROI):
Expand Down
50 changes: 37 additions & 13 deletions tramway/inference/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,15 +671,16 @@ def group(self, ngroups=None, max_cell_count=None, cell_centers=None, \

def run(self, function, *args, **kwargs):
"""
Apply a function to the groups (:class:`FiniteElements`) of terminal cells.
Apply a function that takes a group (:class:`FiniteElements`) of terminal cells
as input argument, plus args and kwargs, and must return a `pandas.DataFrame` array
with the indices referring to cells.
The results are merged into a single :class:`~pandas.DataFrame` array,
handling adjacency margins if any.
`function` is called for each group of terminal cells, adjacency margins
are removed if any, and the resulting `DataFrame` are merged into a single
`DataFrame`.
Although this method was designed for `FiniteElements` of `FiniteElements`, it can
also be used to call any function that returns a DataFrame indexed by cell indices.
Multiples processes may be spawned.
`function` may instead be applied to `self` in the case `self.cells` has been
overloaded to exhibit the output features as attributes.
Arguments:
Expand All @@ -693,39 +694,62 @@ def run(self, function, *args, **kwargs):
positional arguments for `function` after the first one.
kwargs (dict):
keyword arguments for `function` from which are removed the ones below.
keyword arguments for `function`;
the following arguments are popped out of `kwargs`:
returns (list):
attributes to be collected from all the individual cells as return values;
if defined, the values returned by `function` are ignored.
worker_count (int):
number of simultaneously working processing units.
number of simultaneously working processing units;
only honoured when the elements of `self.cells` are themselves `FiniteElements`.
profile (bool or str or tuple):
profile each child job if any;
if `str`, dump the output stats into *.prof* files;
if `tuple`, print a report with :func:`~pstats.Stats.print_stats` and
tuple elements as input arguments.
function_worker_count (int):
if `worker_count` is a reserved keyword argument for `function`,
`function_worker_count` can be specified and is passed to `function`
with name/keyword `worker_count`;
in this case, `worker_count` must be a multiple of `function_worker_count`.
Returns:
pandas.DataFrame:
single merged array of maps.
If `function` returns two output arguments, :meth:`run` also
returns a second merged array of posteriors.
returns a second merged array of posterior probabilities.
"""
# clear the caches
self.clear_caches()

returns = kwargs.pop('returns', None)

if all(isinstance(cell, FiniteElements) for cell in self.cells.values()):
# parallel for-loop over the subsets of cells
# if `worker_count` is `None`, `Pool` will use `multiprocessing.cpu_count()`
parallel = all(isinstance(cell, FiniteElements) for cell in self.cells.values())
if parallel:
worker_count = kwargs.pop('worker_count', None)
profile = kwargs.pop('profile', False)

for arg in ('returns', 'worker_count', 'profile'):
try:
val = kwargs.pop('function_'+arg)
except KeyError:
pass
else:
kwargs[arg] = val
if arg == 'worker_count' and \
isinstance(worker_count, int) and \
isinstance(val, int) and 0<val:
worker_count /= val

if parallel:
# parallel for-loop over the subsets of cells
# if `worker_count` is `None`, `Pool` will use `multiprocessing.cpu_count()`
pool = Pool(worker_count)
fargs = (function, args, kwargs)
if profile:
Expand Down
2 changes: 1 addition & 1 deletion tramway/tessellation/gwr/gas.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def batch_train(self, sample, eta_square=None, radius=None, grab=None, max_frame
dist_min = sqrt(dist2_min)
except ValueError:
precision = int(np.dtype(dist2_min.dtype).str[-1])
if (precision==8 and -1e-12 < dist2_min) or (precision==4 and -1e-3 < dist_min):
if (precision==8 and -1e-10 < dist2_min) or (precision==4 and -1e-3 < dist_min):
dist_min = 0
import warnings
warnings.warn('Rounding error: negative distance', RuntimeWarning)
Expand Down

0 comments on commit 88a2913

Please sign in to comment.