Fix: fn_kwarg for conc yield
StefanHeng committed May 25, 2024
1 parent 26c8fda commit 2a8211f
Showing 2 changed files with 25 additions and 12 deletions.
4 changes: 2 additions & 2 deletions setup.py
@@ -1,6 +1,6 @@
from setuptools import setup, find_packages

VERSION = '0.41.5'
VERSION = '0.41.6'
DESCRIPTION = 'Machine Learning project startup utilities'
LONG_DESCRIPTION = 'My commonly used utilities for machine learning projects'

@@ -13,7 +13,7 @@
description=DESCRIPTION,
long_description=LONG_DESCRIPTION,
url='https://github.com/StefanHeng/stef-util',
download_url='https://github.com/StefanHeng/stef-util/archive/refs/tags/v0.41.5.tar.gz',
download_url='https://github.com/StefanHeng/stef-util/archive/refs/tags/v0.41.6.tar.gz',
packages=find_packages(),
include_package_data=True,
install_requires=[
33 changes: 23 additions & 10 deletions stefutil/concurrency.py
@@ -182,7 +182,7 @@ def __call__(self, args: Iterable[T]) -> List[K]:


def conc_yield(
fn: MapFn, args: Iterable[T], fn_keyword: str = None, with_tqdm: Union[bool, Dict] = False,
fn: MapFn, args: Iterable[T], fn_kwarg: str = None, with_tqdm: Union[bool, Dict] = False,
n_worker: int = os.cpu_count()-1, # since the calling script consumes one process
mode: str = 'thread', batch_size: Union[int, bool] = None,
process_chunk_size: int = None, process_chunk_multiplier: int = None
@@ -193,7 +193,7 @@ def conc_yield(
:param fn: A function
:param args: A list of elements as input to the function
:param fn_keyword: If given, the keyword argument to pass to `fn`
:param fn_kwarg: If given, the keyword argument to pass to `fn`
:param with_tqdm: If true, progress bar is shown
If dict, treated as `tqdm` concurrent kwargs
note `chunksize` is helpful
@@ -223,7 +223,7 @@ def conc_yield(
if batch_size:
batch_size = 32 if isinstance(batch_size, bool) else batch_size
# pbar doesn't work w/ pickle hence multiprocessing
fn = BatchedFn(fn=fn, fn_keyword=fn_keyword, pbar=pbar if is_thread else None)
fn = BatchedFn(fn=fn, fn_keyword=fn_kwarg, pbar=pbar if is_thread else None)
if process_chunk_size is not None:
chunk = process_chunk_size
elif process_chunk_multiplier is not None:
@@ -255,7 +255,10 @@ def conc_yield(
pbar.update(len(res))
yield from res
else:
futures = set(executor.submit(fn, a) for a in args)
if fn_kwarg:
futures = set(executor.submit(fn, **{fn_kwarg: a}) for a in args)
else:
futures = set(executor.submit(fn, a) for a in args)
for f in concurrent.futures.as_completed(futures): # TODO: process chunking
res = f.result()
futures.remove(f)
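
The hunk above adds the non-batched keyword-dispatch path: when `fn_kwarg` is given, each element of `args` is forwarded to `fn` as a keyword argument instead of positionally. A minimal sketch of that pattern with plain concurrent.futures (the worker `work` and the keyword name 'x' are illustrative, not part of stefutil):

import concurrent.futures

def work(x: int) -> int:
    return x * x

fn_kwarg = 'x'
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
    # equivalent to executor.submit(work, x=a) for each element a
    futures = {executor.submit(work, **{fn_kwarg: a}) for a in range(8)}
    for f in concurrent.futures.as_completed(futures):
        print(f.result())
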
@@ -285,9 +288,10 @@ def _work(task_idx):

def _dummy_fn(xx: int):
x = xx
t = round(random.uniform(0.2, 1), 3)
logger.info(f'Calling dummy_fn w/ arg {s.i(x)}, will sleep for {s.i(t)}s')
time.sleep(t)
t_ms = random.randint(500, 3000)

logger.info(f'Calling dummy_fn w/ arg {s.i(x)}, will sleep for {s.i(t_ms)}ms')
time.sleep(t_ms / 1000)

return x, [random.random() for _ in range(int(1e6))]

@@ -354,12 +358,21 @@ def test_conc_mem_use():
# test_conc_mem_use()

def check_conc_process_chunk():
n = 200
# n = 200
n = 100
# chunk_mult = 4
bsz, n_worker = 4, 4
# bsz, n_worker = 4, 4
bsz = 8
# bsz = False
# args = dict(n_worker=4, mode='process', batch_size=bsz, process_chunk_size=chunk_mult * bsz * n_worker)
args = dict(n_worker=4, mode='process', batch_size=bsz, process_chunk_multiplier=4)
for i in conc_yield(fn=_dummy_fn, fn_keyword='xx', args=range(n), with_tqdm=dict(total=n), **args):

# for i in conc_yield(fn=_dummy_fn, fn_kwarg='xx', args=range(n), with_tqdm=dict(total=n), **args):
# i, _ = i
# logger.info(f'Process {s.i(i)} terminated')

from stefutil.prettier import rich_progress
for i in rich_progress(conc_yield(fn=_dummy_fn, fn_kwarg='xx', args=range(n), **args), total=n):
i, _ = i
logger.info(f'Process {s.i(i)} terminated')
check_conc_process_chunk()
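
For reference, a minimal usage sketch of the renamed parameter, assuming stefutil 0.41.6; the worker `square` and its parameter name `val` are illustrative. With `fn_kwarg='val'`, conc_yield calls the worker as square(val=a) for each element a of `args` and yields results as they complete:

from stefutil.concurrency import conc_yield

def square(val: int) -> int:
    return val * val

# thread mode avoids pickling requirements; results arrive in completion order
for res in conc_yield(fn=square, args=range(16), fn_kwarg='val', n_worker=4, mode='thread'):
    print(res)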
