Fix: conc yield enforce order
StefanHeng committed May 25, 2024
1 parent 2a8211f commit bb293d5
Showing 2 changed files with 101 additions and 38 deletions.
4 changes: 2 additions & 2 deletions setup.py
@@ -1,6 +1,6 @@
 from setuptools import setup, find_packages
 
-VERSION = '0.41.6'
+VERSION = '0.41.7'
 DESCRIPTION = 'Machine Learning project startup utilities'
 LONG_DESCRIPTION = 'My commonly used utilities for machine learning projects'
 
@@ -13,7 +13,7 @@
     description=DESCRIPTION,
     long_description=LONG_DESCRIPTION,
     url='https://github.com/StefanHeng/stef-util',
-    download_url='https://github.com/StefanHeng/stef-util/archive/refs/tags/v0.41.6.tar.gz',
+    download_url='https://github.com/StefanHeng/stef-util/archive/refs/tags/v0.41.7.tar.gz',
     packages=find_packages(),
     include_package_data=True,
     install_requires=[
135 changes: 99 additions & 36 deletions stefutil/concurrency.py
@@ -5,6 +5,7 @@
 """
 
 import os
+import heapq
 import concurrent.futures
 from typing import List, Tuple, Dict, Iterable, Callable, TypeVar, Union
 
@@ -31,8 +32,8 @@ def _check_conc_mode(mode: str):
 
 
 def conc_map(
-        fn: MapFn, args: Iterable[T], with_tqdm: Union[bool, Dict] = False, n_worker: int = os.cpu_count(),
-        mode: str = 'thread'
+        fn: MapFn, args: Iterable[T], with_tqdm: Union[bool, Dict] = False, n_worker: int = os.cpu_count() - 1,
+        mode: str = 'process'
 ) -> Iterable[K]:
     """
     Wrapper for `concurrent.futures.map`
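
For context: `conc_map` is a thin wrapper over the executor's `map`, which, unlike the `as_completed`-based `conc_yield` below, already returns results in input order. A minimal stdlib-only sketch of what it wraps (names here are illustrative, not stefutil's):

import concurrent.futures
import os

def plain_conc_map(fn, xs, n_worker: int = os.cpu_count() - 1):
    # executor.map preserves input order; the new default of cpu_count() - 1
    # leaves one core for the calling process
    with concurrent.futures.ProcessPoolExecutor(max_workers=n_worker) as executor:
        return list(executor.map(fn, xs))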
@@ -184,8 +185,8 @@ def __call__(self, args: Iterable[T]) -> List[K]:
 def conc_yield(
         fn: MapFn, args: Iterable[T], fn_kwarg: str = None, with_tqdm: Union[bool, Dict] = False,
         n_worker: int = os.cpu_count()-1,  # since the calling script consumes one process
-        mode: str = 'thread', batch_size: Union[int, bool] = None,
-        process_chunk_size: int = None, process_chunk_multiplier: int = None
+        mode: str = 'process', batch_size: Union[int, bool] = None,
+        process_chunk_size: int = None, process_chunk_multiplier: int = None, enforce_order: bool = False
 ) -> Iterable[K]:
     """
     Wrapper for `concurrent.futures`, yielding results as they become available, irrespective of order
@@ -207,6 +208,7 @@ def conc_yield(
     :param process_chunk_multiplier: If mode is `process`,
         elements are chunked into groups of `n_worker` x `batch_size` x `process_chunk_multiplier`
         for a sequence of concurrent pools
+    :param enforce_order: If true, results are yielded in the order of args passed
     :return: Iterator of `args` elements mapped by `fn` with the specified concurrency mode
     """
     _check_conc_mode(mode=mode)
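
The batching and chunking described in the docstring above rest on stefutil's `group_n` helper, which splits an iterable into fixed-size groups so each submitted future carries `batch_size` elements. A minimal stand-in with the same observable behavior (an assumption; the packaged helper may differ in details such as yielding tuples instead of lists):

from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar('T')

def group_n(it: Iterable[T], n: int) -> Iterator[List[T]]:
    # repeatedly slice off the next `n` elements until the iterable is exhausted
    it = iter(it)
    while True:
        group = list(islice(it, n))
        if not group:
            return
        yield group

assert list(group_n(range(7), 3)) == [[0, 1, 2], [3, 4, 5], [6]]

So with `n_worker=4`, `batch_size=8` and `process_chunk_multiplier=4`, each sequential pool processes 4 x 8 x 4 = 128 elements.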
@@ -234,41 +236,88 @@
             if not is_thread and chunk:
                 it_args = group_n(args, chunk)
                 for args_ in it_args:
-                    futures = set(executor.submit(fn, args_) for args_ in group_n(args_, batch_size))
+                    futures = {executor.submit(fn, args_): i for i, args_ in enumerate(group_n(args_, batch_size))}
+
+                    pq, index = ([], 0) if enforce_order else (None, None)
+
                     for f in concurrent.futures.as_completed(futures):
                         res = f.result()
-                        futures.remove(f)
-                        if pbar is not None:
-                            pbar.update(len(res))
-                        yield from res
+
+                        if enforce_order:
+                            heapq.heappush(pq, (futures[f], res))
+                            if with_tqdm:  # count the completed batch now, before `res` is reassigned by the pops below
+                                pbar.update(len(res))
+                            while pq and pq[0][0] == index:
+                                _, res = heapq.heappop(pq)
+                                yield from res
+                                index += 1
+
+                        else:
+                            if with_tqdm:
+                                pbar.update(len(res))
+                            yield from res
+
+                        del futures[f]
             else:
-                futures = set(executor.submit(fn, args_) for args_ in group_n(args, batch_size))
+                futures = {executor.submit(fn, args_): i for i, args_ in enumerate(group_n(args, batch_size))}
+
+                pq, index = ([], 0) if enforce_order else (None, None)
+
                 for f in concurrent.futures.as_completed(futures):
-                    if is_thread:
-                        res = f.result()
-                        futures.remove(f)
-                        yield from res
-                    else:
-                        res = f.result()
-                        futures.remove(f)
-                        if pbar is not None:
-                            pbar.update(len(res))
-                        yield from res
+                    res = f.result()
+
+                    if enforce_order:
+                        heapq.heappush(pq, (futures[f], res))
+                        if not is_thread and with_tqdm:  # count the completed batch now, before `res` is reassigned by the pops below
+                            pbar.update(len(res))
+                        while pq and pq[0][0] == index:
+                            _, res = heapq.heappop(pq)
+                            yield from res
+                            index += 1
+
+                    else:
+                        if not is_thread and with_tqdm:
+                            pbar.update(len(res))
+                        yield from res
+
+                    del futures[f]
         else:
             if fn_kwarg:
-                futures = set(executor.submit(fn, **{fn_kwarg: a}) for a in args)
+                futures = {executor.submit(fn, **{fn_kwarg: a}): i for i, a in enumerate(args)}
             else:
-                futures = set(executor.submit(fn, a) for a in args)
+                futures = {executor.submit(fn, a): i for i, a in enumerate(args)}
+
+            if enforce_order:
+                pq = []  # priority queue to hold (index, result) pairs
+                index = 0  # the next index to yield
+            else:
+                pq, index = None, None
+
             for f in concurrent.futures.as_completed(futures):  # TODO: process chunking
                 res = f.result()
-                futures.remove(f)
-                if with_tqdm:
-                    pbar.update(1)
-                yield res
+
+                if enforce_order:
+                    heapq.heappush(pq, (futures[f], res))  # prioritize by submission index
+                    while pq and pq[0][0] == index:
+                        _, res = heapq.heappop(pq)
+                        yield res
+                        index += 1
+
+                    if with_tqdm:
+                        pbar.update(1)
+                else:
+                    if with_tqdm:
+                        pbar.update(1)
+                    yield res
+
+                del futures[f]
 
 
-DEV = True
-# DEV = False
+# DEV = True
+DEV = False
 if DEV:
     import time
     import random
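
The core of the new `enforce_order` path, stripped of stefutil's batching and progress bars: key each future by its submission index, buffer out-of-order results in a min-heap, and drain the heap whenever the next expected index sits at the top. A self-contained sketch of that pattern (`slow_square` and the worker count are illustrative):

import heapq
import random
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

def slow_square(x: int) -> int:
    time.sleep(random.uniform(0, 0.05))  # futures complete in scrambled order
    return x * x

def yield_in_order(fn, args):
    with ThreadPoolExecutor(max_workers=4) as executor:
        futures = {executor.submit(fn, a): i for i, a in enumerate(args)}
        pq, index = [], 0  # min-heap of (submission index, result); next index to yield
        for f in as_completed(futures):
            heapq.heappush(pq, (futures[f], f.result()))
            while pq and pq[0][0] == index:  # drain every result that is now due
                _, res = heapq.heappop(pq)
                yield res
                index += 1

assert list(yield_in_order(slow_square, range(20))) == [x * x for x in range(20)]

Heap entries are (index, result) tuples, and since every index is unique, tuple comparison never falls through to the result, so unorderable results are safe.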
@@ -277,26 +326,28 @@
 
     logger = get_logger('Conc Dev')
 
-    def _work(task_idx):
-        t = round(random.uniform(1, 4), 3)
+    def _work(task_idx: int = None):
+        t = round(random.uniform(0.5, 3), 3)
         print(f'Task {s.i(task_idx)} launched, will sleep for {s.i(t)} s')
         time.sleep(t)
 
         print(f'Task {s.i(task_idx)} is done')
         return task_idx
 
 
-    def _dummy_fn(xx: int):
+    def _dummy_fn(xx: int = None):
         x = xx
         t_ms = random.randint(500, 3000)
 
-        logger.info(f'Calling dummy_fn w/ arg {s.i(x)}, will sleep for {s.i(t_ms)}ms')
+        # logger.info(f'Calling dummy_fn w/ arg {s.i(x)}, will sleep for {s.i(t_ms)}ms')
         time.sleep(t_ms / 1000)
 
         return x, [random.random() for _ in range(int(1e6))]
 
 
 if __name__ == '__main__':
+    from stefutil.prettier import rich_progress, rcl
+
     def try_concurrent_yield():
         # this will gather all results and return
         # with ThreadPoolExecutor() as executor:
@@ -311,6 +362,15 @@ def try_concurrent_yield():
         sic(res)
     # try_concurrent_yield()
 
+    def check_conc_map():
+        n = 10
+        # for res in conc_map(fn=_work, args=range(n), with_tqdm=dict(total=n), n_worker=4, mode='process'):
+        #     sic(res)
+        from stefutil.prettier import rcl
+        for res in rich_progress(conc_map(fn=_work, args=range(n), mode='process'), total=n):
+            rcl(res)
+    # check_conc_map()
+
     def check_conc_yield():
         batch = False
 
@@ -359,20 +419,23 @@ def test_conc_mem_use():
 
     def check_conc_process_chunk():
         # n = 200
-        n = 100
+        # n = 100
+        n = 50
         # chunk_mult = 4
         # bsz, n_worker = 4, 4
-        bsz = 8
+        bsz = 4
         # bsz = False
         # args = dict(n_worker=4, mode='process', batch_size=bsz, process_chunk_size=chunk_mult * bsz * n_worker)
-        args = dict(n_worker=4, mode='process', batch_size=bsz, process_chunk_multiplier=4)
+        args = dict(n_worker=4, mode='process', batch_size=bsz, process_chunk_multiplier=None)
 
         # for i in conc_yield(fn=_dummy_fn, fn_kwarg='xx', args=range(n), with_tqdm=dict(total=n), **args):
         #     i, _ = i
         #     logger.info(f'Process {s.i(i)} terminated')
 
-        from stefutil.prettier import rich_progress
-        for i in rich_progress(conc_yield(fn=_dummy_fn, fn_kwarg='xx', args=range(n), **args), total=n):
+        lst = []
+        for i in rich_progress(conc_yield(fn=_dummy_fn, fn_kwarg='xx', args=range(n), **args, enforce_order=True), total=n):
             i, _ = i
-            logger.info(f'Process {s.i(i)} terminated')
+            # logger.info(f'Process {s.i(i)} terminated')
+            lst.append(i)
+        rcl(lst)
     check_conc_process_chunk()
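
A quick way to exercise the new flag outside the DEV block (assuming `conc_yield` is re-exported at the package top level; otherwise import it from `stefutil.concurrency`):

from stefutil import conc_yield

def square(x: int) -> int:
    return x * x

# thread mode sidesteps pickling requirements for this toy `fn`
out = list(conc_yield(fn=square, args=range(32), mode='thread', enforce_order=True))
assert out == [x * x for x in range(32)]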
