From 7f6bc50f39cbdd98fa786acd989ca7447dfb8638 Mon Sep 17 00:00:00 2001 From: Charles Smith Date: Wed, 11 Oct 2023 00:08:56 -0400 Subject: [PATCH] better api --- README.md | 18 ++++++++++++++++-- src/tentacule/i_process_pool.py | 8 ++++++-- src/tentacule/process_pool.py | 8 ++++++-- tests/test_simple_functions.py | 25 +++++++------------------ 4 files changed, 35 insertions(+), 24 deletions(-) diff --git a/README.md b/README.md index 61fae27..732283a 100644 --- a/README.md +++ b/README.md @@ -18,12 +18,26 @@ pool.start() def task(arg: int): return 2 + arg -task_id = pool.new_task(task, 3) -result = pool.get_result(task_id) # 5 +task_id = pool.submit(task, 3) +result = pool.fetch(task_id) # 5 pool.close() ``` +You can also do this in a single step: +```python +pool = ProcessPool(workers=3) +pool.start() + +def task(arg: int): + return 2 + arg + +result = pool.submit_and_fetch(task, 5) # 7 + +pool.close() +``` + + ## Restrictions - Anything [dill](https://github.com/uqfoundation/dill) can't pickle, can't be used as a task. - Arguments are pickled using the regular pickling for processes. diff --git a/src/tentacule/i_process_pool.py b/src/tentacule/i_process_pool.py index e32c578..7c8ac88 100644 --- a/src/tentacule/i_process_pool.py +++ b/src/tentacule/i_process_pool.py @@ -20,9 +20,13 @@ def close(self, force: bool): pass @abstractmethod - def new_task(self, task: Callable, *args, **kwargs) -> str: + def submit(self, task: Callable, *args, **kwargs) -> str: pass @abstractmethod - def get_result(self, task_id: str, timeout: int = 30) -> Any: + def fetch(self, task_id: str, timeout: int = 30) -> Any: + pass + + @abstractmethod + def submit_and_fetch(self, task: Callable, *args, timeout: int = 30, **kwargs): pass diff --git a/src/tentacule/process_pool.py b/src/tentacule/process_pool.py index d166086..a47b996 100644 --- a/src/tentacule/process_pool.py +++ b/src/tentacule/process_pool.py @@ -48,20 +48,24 @@ def close(self, force: bool = False): for process in self._pool: terminate_process_with_timeout(process.native_process, self.terminate_timeout) - def new_task(self, task: Callable, *args, **kwargs) -> str: + def submit(self, task: Callable, *args, **kwargs) -> str: task_id = generate_unique_id() self._task_queue.put((task_id, dill.dumps(task), args, kwargs)) self._result_events[task_id] = Event() return task_id - def get_result(self, task_id: str, timeout: int = 30) -> Any: + def fetch(self, task_id: str, timeout: int = 30) -> Any: try: self._result_events[task_id].wait(timeout) return self._results[task_id][0] finally: self._result_events.pop(task_id) + def submit_and_fetch(self, task: Callable, *args, timeout: int = 30, **kwargs): + task_id = self.submit(task, *args, **kwargs) + return self.fetch(task_id, timeout) + def _rebalance(self): self._pool = [p for p in self._pool if p.native_process.is_alive()] if self.workers == len(self._pool): diff --git a/tests/test_simple_functions.py b/tests/test_simple_functions.py index cbe1256..5e15fd3 100644 --- a/tests/test_simple_functions.py +++ b/tests/test_simple_functions.py @@ -6,9 +6,7 @@ def simple_function(): def test_simple_function(pool): - task_id = pool.new_task(simple_function) - - assert pool.get_result(task_id) == 5 + assert pool.submit_and_fetch(simple_function) == 5 def add(a: int, b: int): @@ -16,15 +14,11 @@ def add(a: int, b: int): def test_simple_function_with_args(pool): - task_id = pool.new_task(add, 6, 7) - - assert pool.get_result(task_id) == 13 + assert pool.submit_and_fetch(add, 6, 7) == 13 def test_simple_function_with_kwargs(pool): - task_id = pool.new_task(add, a=5, b=5) - - assert pool.get_result(task_id) == 10 + assert pool.submit_and_fetch(add, a=5, b=5) == 10 def default_a(a: int = 0): @@ -32,11 +26,8 @@ def default_a(a: int = 0): def test_simple_function_with_default(pool): - task_id = pool.new_task(default_a) - assert pool.get_result(task_id) == 0 - - task_id = pool.new_task(default_a, 80) - assert pool.get_result(task_id) == 80 + assert pool.submit_and_fetch(default_a) == 0 + assert pool.submit_and_fetch(default_a, 80) == 80 def recurse(a: int = 0): @@ -46,8 +37,7 @@ def recurse(a: int = 0): def test_recursive(pool): - task_id = pool.new_task(recurse) - assert pool.get_result(task_id) == 10 + assert pool.submit_and_fetch(recurse) == 10 def dependency_subtract(a: int, b: int): @@ -55,5 +45,4 @@ def dependency_subtract(a: int, b: int): def test_function_with_dependency(pool): - task_id = pool.new_task(dependency_subtract, 7, 2) - assert pool.get_result(task_id) == 5 + assert pool.submit_and_fetch(dependency_subtract, 7, 2) == 5