## Determine the most efficient parallelization approach

In [1]:
import xarray as xr 
import numpy as np 
from glob import glob
import os
from os.path import join
import multiprocessing as mp
import itertools
import joblib
from tqdm import tqdm  
from wofs_ml_severe.wofs_ml_severe.common.multiprocessing_utils import run_parallel, to_iterator

In [2]:
# Directory containing the input summary files (wofs_ENV_*) to process.
base_path = '/work/mflora/SummaryFiles/20210504/2200'
# glob() returns matches in arbitrary, filesystem-dependent order; sort so that
# every benchmark below processes the same files in the same order.
file_paths = sorted(glob(join(base_path, 'wofs_ENV_*')))
# Directory the worker writes its rewritten files into.
out_path = '/work/mflora/testdata/'
# Worker count passed to every parallel benchmark below.
N_JOBS = 12

In [3]:
def worker(path):
    """Scale the ``mslp`` field of one file by 1000 and write it to ``out_path``.

    The file is loaded fully into memory with ``decode_times=False``. Only the
    raw ``mslp`` array is kept; it is multiplied by 1000 in place and
    re-wrapped as a new Dataset with dims ``('NE', 'NY', 'NX')``, so the
    original coordinates and attributes are not carried over. The output keeps
    the input's base file name and is written under the module-level
    ``out_path`` directory.

    Returns the path of the written file.
    """
    scaled = xr.load_dataset(path, decode_times=False)['mslp'].values
    scaled *= 1000

    target = join(out_path, os.path.basename(path))
    xr.Dataset({'mslp': (('NE', 'NY', 'NX'), scaled)}).to_netcdf(target)
    return target

In [5]:
%%time
# Baseline: the project helper run_parallel (wofs_ml_severe.common.
# multiprocessing_utils) applied to `worker` over the file list, wrapped by
# to_iterator. N_JOBS is presumably the worker-process count — confirm
# against run_parallel's signature.
run_parallel(worker, to_iterator(file_paths), N_JOBS)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 73/73 [00:13<00:00,  5.32it/s]

CPU times: user 74.4 ms, sys: 105 ms, total: 179 ms
Wall time: 13.7 s





In [11]:
import joblib
from tqdm.auto import tqdm

class ProgressParallel(joblib.Parallel):
    """joblib.Parallel that mirrors its task counters onto a tqdm bar.

    ``print_progress`` is overridden to copy joblib's dispatched/completed task
    counts onto the bar instead of printing. It is presumably invoked by joblib
    as batches complete (joblib internals; confirm for the installed version).
    """

    def __call__(self, *args, **kwargs):
        # The bar exists only for the duration of this call; print_progress
        # reaches it through self._pbar.
        with tqdm() as bar:
            self._pbar = bar
            return super().__call__(*args, **kwargs)

    def print_progress(self):
        bar = self._pbar
        bar.total = self.n_dispatched_tasks
        bar.n = self.n_completed_tasks
        bar.refresh()
        
        
class MPProgressParallel:
    """Map a function over inputs with a multiprocessing.Pool and a tqdm bar.

    Usage: ``results = MPProgressParallel(n_jobs)(func, iterable)``.

    This is a plain wrapper rather than a Pool subclass. ``mp.Pool`` is a
    factory method of the default context, not a class, so it cannot be
    subclassed.

    Parameters
    ----------
    n_jobs : int
        Number of worker processes in the pool.
    """

    def __init__(self, n_jobs):
        self.n_jobs = n_jobs
        self.n_dispatched_tasks = 0
        self.n_completed_tasks = 0
        self._pbar = None

    def __call__(self, func, iterable, chunksize=1):
        """Apply ``func`` to every item of ``iterable`` in parallel.

        ``func`` must be picklable, e.g. defined at module level.

        Returns a list of results in the same order as ``iterable``.
        Exceptions raised in a worker propagate to the caller.
        """
        tasks = list(iterable)
        self.n_dispatched_tasks = len(tasks)
        self.n_completed_tasks = 0
        results = []
        # imap yields results in input order as they become available, so the
        # bar can be advanced from the parent process. All results have been
        # consumed before Pool.__exit__ (which terminates the pool) runs.
        with mp.Pool(self.n_jobs) as pool, \
                tqdm(total=self.n_dispatched_tasks) as self._pbar:
            for result in pool.imap(func, tasks, chunksize=chunksize):
                results.append(result)
                self.n_completed_tasks += 1
                self.print_progress()
        return results

    def print_progress(self):
        """Copy the task counters onto the progress bar and redraw it."""
        self._pbar.total = self.n_dispatched_tasks
        self._pbar.n = self.n_completed_tasks
        self._pbar.refresh()

In [12]:
%%time
# Parallel processing with joblib (loky backend), using ProgressParallel
# (defined above) to show a tqdm progress bar.
backend = 'loky'
iterator = file_paths
results = ProgressParallel(n_jobs = N_JOBS,
                backend=backend,
                verbose=0)(joblib.delayed(worker)(args,) for args in iterator)

0it [00:00, ?it/s]



CPU times: user 260 ms, sys: 209 ms, total: 469 ms
Wall time: 17 s


In [5]:
%%time
# Parallel processing with joblib using the 'multiprocessing' backend
# (plain joblib.Parallel, no progress bar), for comparison with loky.
backend = 'multiprocessing'
iterator = file_paths
results = joblib.Parallel(n_jobs = N_JOBS,
                backend=backend,
                verbose=0)(joblib.delayed(worker)(args,) for args in iterator)

CPU times: user 106 ms, sys: 89.2 ms, total: 196 ms
Wall time: 13.3 s


In [14]:
%%time

pbar = tqdm(total=len(file_paths))
def update(*a):
    pbar.update()

# Parallel processing with multiprocessing
iterator = file_paths #itertools.zip_longest(*file_paths)
pool = mp.Pool(processes=N_JOBS)
for args in iterator:
    pool.apply_async(worker, args=(args,), callback=update)
pool.close()
pool.join()

  0%|          | 0/73 [00:00<?, ?it/s]

CPU times: user 120 ms, sys: 88.4 ms, total: 208 ms
Wall time: 13.2 s
