Skip to content

Commit

Permalink
Tidying up code
Browse files Browse the repository at this point in the history
  • Loading branch information
astrojoni89 committed Dec 18, 2023
1 parent 04bfda6 commit 310d5b7
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions astrosaber/parallel_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@
import multiprocessing
from typing import List, Tuple, Callable, Iterable
from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm import tqdm

from .training import saberTraining
from .prepare_training import saberPrepare
from .hisa import HisaExtraction
#from .utils.aslsq_fit import baseline_als_optimized
#from .utils.quality_checks import goodness_of_fit, get_max_consecutive_channels, determine_peaks, mask_channels
from tqdm import tqdm



Expand Down Expand Up @@ -76,7 +74,7 @@ def parallel_process(array : np.ndarray, function : Callable[[int], Tuple], n_jo
"""
# We run the first few iterations serially to catch bugs
if front_num > 0:
front = [function(**a) if use_kwargs else function(a) for a in array[:front_num]] #, lam1_updt=lam1_updt, p1_updt=p1_updt, lam2_updt=lam2_updt, p2_updt=p2_updt
front = [function(**a) if use_kwargs else function(a) for a in array[:front_num]]
# If we set n_jobs to 1, just run a list comprehension. This is useful for benchmarking and debugging.
if n_jobs == 1:
return front + [function(**a) if use_kwargs else function(a) for a in tqdm(array[front_num:])]
Expand All @@ -86,7 +84,7 @@ def parallel_process(array : np.ndarray, function : Callable[[int], Tuple], n_jo
if use_kwargs:
futures = [pool.submit(function, **a) for a in array[front_num:]]
else:
futures = [pool.submit(function, a) for a in array[front_num:]] # , lam1_updt=lam1_updt, p1_updt=p1_updt, lam2_updt=lam2_updt, p2_updt=p2_updt
futures = [pool.submit(function, a) for a in array[front_num:]]
kwargs = {
'total': len(futures),
'unit': 'spec',
Expand All @@ -98,7 +96,7 @@ def parallel_process(array : np.ndarray, function : Callable[[int], Tuple], n_jo
pass
out = []
# Get the results from the futures.
for i, future in enumerate(futures): #tqdm(enumerate(futures)):
for i, future in enumerate(futures):
try:
out.append(future.result())
except Exception as e:
Expand Down Expand Up @@ -128,7 +126,7 @@ def parallel_process_wo_bar(array : np.ndarray, function : Callable[[int], Tuple
"""
# We run the first few iterations serially to catch bugs
if front_num > 0:
front = [function(**a) if use_kwargs else function(a) for a in array[:front_num]] #, lam1_updt=lam1_updt, p1_updt=p1_updt, lam2_updt=lam2_updt, p2_updt=p2_updt
front = [function(**a) if use_kwargs else function(a) for a in array[:front_num]]
# If we set n_jobs to 1, just run a list comprehension. This is useful for benchmarking and debugging.
if n_jobs == 1:
return front + [function(**a) if use_kwargs else function(a) for a in tqdm(array[front_num:])]
Expand All @@ -138,7 +136,7 @@ def parallel_process_wo_bar(array : np.ndarray, function : Callable[[int], Tuple
if use_kwargs:
futures = [pool.submit(function, **a) for a in array[front_num:]]
else:
futures = [pool.submit(function, a) for a in array[front_num:]] # , lam1_updt=lam1_updt, p1_updt=p1_updt, lam2_updt=lam2_updt, p2_updt=p2_updt
futures = [pool.submit(function, a) for a in array[front_num:]]
kwargs = {
'total': len(futures),
'unit': 'it',
Expand All @@ -150,7 +148,7 @@ def parallel_process_wo_bar(array : np.ndarray, function : Callable[[int], Tuple
# pass
out = []
# Get the results from the futures.
for i, future in enumerate(futures): #tqdm(enumerate(futures)):
for i, future in enumerate(futures):
try:
out.append(future.result())
except Exception as e:
Expand Down

0 comments on commit 310d5b7

Please sign in to comment.