Skip to content

Commit

Permalink
change default pbar behavior (for multiple runs) (#480)
Browse files Browse the repository at this point in the history
* change default pbar behavior (for multiple runs)

* fix imports

* update the docstring

* disable for st.make
  • Loading branch information
JoranAngevaare committed Jul 8, 2021
1 parent a0240de commit 546ab58
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 20 deletions.
18 changes: 7 additions & 11 deletions strax/context.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import collections
import datetime
import logging
import fnmatch
Expand All @@ -11,19 +10,15 @@
import numpy as np
import pandas as pd
import strax
import sys
if any('jupyter' in arg for arg in sys.argv):
# In some cases we are not using any notebooks,
# Taken from 44952863 on stack overflow thanks!
from tqdm.notebook import tqdm
else:
from tqdm import tqdm

export, __all__ = strax.exporter()
__all__ += ['RUN_DEFAULTS_KEY']

RUN_DEFAULTS_KEY = 'strax_defaults'

# use tqdm as loaded in utils (from tqdm.notebook when in a juypyter env)
tqdm = strax.utils.tqdm


@strax.takes_config(
strax.Option(name='storage_converter', default=False,
Expand Down Expand Up @@ -1095,11 +1090,13 @@ def apply_selection(x,

def make(self, run_id: ty.Union[str, tuple, list],
targets, save=tuple(), max_workers=None,
progress_bar=False, _skip_if_built=True,
_skip_if_built=True,
**kwargs) -> None:
"""Compute target for run_id. Returns nothing (None).
{get_docs}
"""
kwargs.setdefault('progress_bar', False)

# Multi-run support
run_ids = strax.to_str_tuple(run_id)
if len(run_ids) == 0:
Expand All @@ -1108,14 +1105,12 @@ def make(self, run_id: ty.Union[str, tuple, list],
return strax.multi_run(
self.get_array, run_ids, targets=targets,
throw_away_result=True,
progress_bar=progress_bar,
save=save, max_workers=max_workers, **kwargs)

if _skip_if_built and self.is_stored(run_id, targets):
return

for _ in self.get_iter(run_ids[0], targets,
progress_bar=progress_bar,
save=save, max_workers=max_workers, **kwargs):
pass

Expand Down Expand Up @@ -1547,6 +1542,7 @@ def add_method(cls, f):
- skip: Do not select a time range, even if other arguments say so
:param _chunk_number: For internal use: return data from one chunk.
:param progress_bar: Display a progress bar if metadata exists.
:param multi_run_progress_bar: Display a progress bar for loading multiple runs
"""

get_docs = """
Expand Down
31 changes: 22 additions & 9 deletions strax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import dill
import numba
import numpy as np
from tqdm import tqdm
import pandas as pd

# Change numba's caching backend from pickle to dill
Expand All @@ -26,6 +25,13 @@
# Numba < 0.49
numba.caching.pickle = dill

if any('jupyter' in arg for arg in sys.argv):
# In some cases we are not using any notebooks,
# Taken from 44952863 on stack overflow thanks!
from tqdm.notebook import tqdm
else:
from tqdm import tqdm


def exporter(export_self=False):
"""Export utility modified from https://stackoverflow.com/a/41895194
Expand Down Expand Up @@ -422,34 +428,41 @@ def dict_to_rec(x, dtype=None):


@export
def multi_run(fun, run_ids, *args, max_workers=None,
def multi_run(exec_function, run_ids, *args,
max_workers=None,
throw_away_result=False,
multi_run_progress_bar=True,
**kwargs):
"""Execute f(run_id, **kwargs) over multiple runs,
"""Execute exec_function(run_id, *args, **kwargs) over multiple runs,
then return list of result arrays, each with a run_id column added.
:param fun: Function to run
:param run_ids: list/tuple of runids
:param exec_function: Function to run
:param run_ids: list/tuple of run_ids
:param max_workers: number of worker threads/processes to spawn.
If set to None, defaults to 1.
If set to None, defaults to 1.
:param throw_away_result: instead of collecting result, return None.
:param multi_run_progress_bar: show a tqdm progressbar for multiple runs.
Other (kw)args will be passed to f
Other (kw)args will be passed to the exec_function.
"""
if max_workers is None:
max_workers = 1

# This will autocast all run ids to Unicode fixed-width
run_id_numpy = np.array(run_ids)

# Generally we don't want a per run pbar because of multi_run_progress_bar
kwargs.setdefault('progress_bar', False)

# Probably we'll want to use dask for this in the future,
# to enable cut history tracking and multiprocessing.
# For some reason the ProcessPoolExecutor doesn't work??
with ThreadPoolExecutor(max_workers=max_workers) as exc:
futures = [exc.submit(fun, r, *args, **kwargs)
futures = [exc.submit(exec_function, r, *args, **kwargs)
for r in run_ids]
for _ in tqdm(as_completed(futures),
desc="Loading %d runs" % len(run_ids)):
desc="Loading %d runs" % len(run_ids),
disable=not multi_run_progress_bar):
pass

result = []
Expand Down

0 comments on commit 546ab58

Please sign in to comment.