Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change default pbar behavior (for multiple runs) #480

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
18 changes: 7 additions & 11 deletions strax/context.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import collections
import datetime
import logging
import fnmatch
Expand All @@ -11,19 +10,15 @@
import numpy as np
import pandas as pd
import strax
import sys
if any('jupyter' in arg for arg in sys.argv):
# In some cases we are not using any notebooks,
# Taken from 44952863 on stack overflow thanks!
from tqdm.notebook import tqdm
else:
from tqdm import tqdm

export, __all__ = strax.exporter()
__all__ += ['RUN_DEFAULTS_KEY']

RUN_DEFAULTS_KEY = 'strax_defaults'

# use tqdm as loaded in utils (from tqdm.notebook when in a jupyter env)
tqdm = strax.utils.tqdm


@strax.takes_config(
strax.Option(name='storage_converter', default=False,
Expand Down Expand Up @@ -1095,11 +1090,13 @@ def apply_selection(x,

def make(self, run_id: ty.Union[str, tuple, list],
targets, save=tuple(), max_workers=None,
progress_bar=False, _skip_if_built=True,
_skip_if_built=True,
**kwargs) -> None:
"""Compute target for run_id. Returns nothing (None).
{get_docs}
"""
kwargs.setdefault('progress_bar', False)

# Multi-run support
run_ids = strax.to_str_tuple(run_id)
if len(run_ids) == 0:
Expand All @@ -1108,14 +1105,12 @@ def make(self, run_id: ty.Union[str, tuple, list],
return strax.multi_run(
self.get_array, run_ids, targets=targets,
throw_away_result=True,
progress_bar=progress_bar,
save=save, max_workers=max_workers, **kwargs)

if _skip_if_built and self.is_stored(run_id, targets):
return

for _ in self.get_iter(run_ids[0], targets,
progress_bar=progress_bar,
save=save, max_workers=max_workers, **kwargs):
pass

Expand Down Expand Up @@ -1547,6 +1542,7 @@ def add_method(cls, f):
- skip: Do not select a time range, even if other arguments say so
:param _chunk_number: For internal use: return data from one chunk.
:param progress_bar: Display a progress bar if metadata exists.
:param multi_run_progress_bar: Display a progress bar for loading multiple runs.
"""

get_docs = """
Expand Down
31 changes: 22 additions & 9 deletions strax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import dill
import numba
import numpy as np
from tqdm import tqdm
import pandas as pd

# Change numba's caching backend from pickle to dill
Expand All @@ -26,6 +25,13 @@
# Numba < 0.49
numba.caching.pickle = dill

if any('jupyter' in arg for arg in sys.argv):
# In some cases we are not using any notebooks,
# Taken from 44952863 on stack overflow thanks!
from tqdm.notebook import tqdm
else:
from tqdm import tqdm


def exporter(export_self=False):
"""Export utility modified from https://stackoverflow.com/a/41895194
Expand Down Expand Up @@ -422,34 +428,41 @@ def dict_to_rec(x, dtype=None):


@export
def multi_run(fun, run_ids, *args, max_workers=None,
def multi_run(exec_function, run_ids, *args,
max_workers=None,
throw_away_result=False,
multi_run_progress_bar=True,
**kwargs):
"""Execute f(run_id, **kwargs) over multiple runs,
"""Execute exec_function(run_id, *args, **kwargs) over multiple runs,
then return list of result arrays, each with a run_id column added.

:param fun: Function to run
:param run_ids: list/tuple of runids
:param exec_function: Function to run
:param run_ids: list/tuple of run_ids
:param max_workers: number of worker threads/processes to spawn.
If set to None, defaults to 1.
If set to None, defaults to 1.
:param throw_away_result: instead of collecting result, return None.
:param multi_run_progress_bar: show a tqdm progressbar for multiple runs.

Other (kw)args will be passed to f
Other (kw)args will be passed to the exec_function.
"""
if max_workers is None:
max_workers = 1

# This will autocast all run ids to Unicode fixed-width
run_id_numpy = np.array(run_ids)

# Generally we don't want a per run pbar because of multi_run_progress_bar
kwargs.setdefault('progress_bar', False)

# Probably we'll want to use dask for this in the future,
# to enable cut history tracking and multiprocessing.
# For some reason the ProcessPoolExecutor doesn't work??
with ThreadPoolExecutor(max_workers=max_workers) as exc:
futures = [exc.submit(fun, r, *args, **kwargs)
futures = [exc.submit(exec_function, r, *args, **kwargs)
for r in run_ids]
for _ in tqdm(as_completed(futures),
desc="Loading %d runs" % len(run_ids)):
desc="Loading %d runs" % len(run_ids),
disable=not multi_run_progress_bar):
pass

result = []
Expand Down