Skip to content

Commit

Permalink
better api
Browse files Browse the repository at this point in the history
  • Loading branch information
BinarSkugga committed Oct 11, 2023
1 parent 6f2f72f commit 7f6bc50
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 24 deletions.
18 changes: 16 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,26 @@ pool.start()
def task(arg: int):
return 2 + arg

task_id = pool.new_task(task, 3)
result = pool.get_result(task_id) # 5
task_id = pool.submit(task, 3)
result = pool.fetch(task_id) # 5

pool.close()
```

You can also do this in a single step:
```python
pool = ProcessPool(workers=3)
pool.start()

def task(arg: int):
return 2 + arg

result = pool.submit_and_fetch(task, 5) # 7

pool.close()
```


## Restrictions
- Anything [dill](https://github.com/uqfoundation/dill) can't pickle, can't be used as a task.
- Arguments are pickled using the regular pickling for processes.
8 changes: 6 additions & 2 deletions src/tentacule/i_process_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,13 @@ def close(self, force: bool):
pass

@abstractmethod
def new_task(self, task: Callable, *args, **kwargs) -> str:
def submit(self, task: Callable, *args, **kwargs) -> str:
pass

@abstractmethod
def get_result(self, task_id: str, timeout: int = 30) -> Any:
def fetch(self, task_id: str, timeout: int = 30) -> Any:
pass

@abstractmethod
def submit_and_fetch(self, task: Callable, *args, timeout: int = 30, **kwargs):
pass
8 changes: 6 additions & 2 deletions src/tentacule/process_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,20 +48,24 @@ def close(self, force: bool = False):
for process in self._pool:
terminate_process_with_timeout(process.native_process, self.terminate_timeout)

def new_task(self, task: Callable, *args, **kwargs) -> str:
def submit(self, task: Callable, *args, **kwargs) -> str:
task_id = generate_unique_id()
self._task_queue.put((task_id, dill.dumps(task), args, kwargs))
self._result_events[task_id] = Event()

return task_id

def get_result(self, task_id: str, timeout: int = 30) -> Any:
def fetch(self, task_id: str, timeout: int = 30) -> Any:
try:
self._result_events[task_id].wait(timeout)
return self._results[task_id][0]
finally:
self._result_events.pop(task_id)

def submit_and_fetch(self, task: Callable, *args, timeout: int = 30, **kwargs):
task_id = self.submit(task, *args, **kwargs)
return self.fetch(task_id, timeout)

def _rebalance(self):
self._pool = [p for p in self._pool if p.native_process.is_alive()]
if self.workers == len(self._pool):
Expand Down
25 changes: 7 additions & 18 deletions tests/test_simple_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,37 +6,28 @@ def simple_function():


def test_simple_function(pool):
task_id = pool.new_task(simple_function)

assert pool.get_result(task_id) == 5
assert pool.submit_and_fetch(simple_function) == 5


def add(a: int, b: int):
return a + b


def test_simple_function_with_args(pool):
task_id = pool.new_task(add, 6, 7)

assert pool.get_result(task_id) == 13
assert pool.submit_and_fetch(add, 6, 7) == 13


def test_simple_function_with_kwargs(pool):
task_id = pool.new_task(add, a=5, b=5)

assert pool.get_result(task_id) == 10
assert pool.submit_and_fetch(add, a=5, b=5) == 10


def default_a(a: int = 0):
return a


def test_simple_function_with_default(pool):
task_id = pool.new_task(default_a)
assert pool.get_result(task_id) == 0

task_id = pool.new_task(default_a, 80)
assert pool.get_result(task_id) == 80
assert pool.submit_and_fetch(default_a) == 0
assert pool.submit_and_fetch(default_a, 80) == 80


def recurse(a: int = 0):
Expand All @@ -46,14 +37,12 @@ def recurse(a: int = 0):


def test_recursive(pool):
task_id = pool.new_task(recurse)
assert pool.get_result(task_id) == 10
assert pool.submit_and_fetch(recurse) == 10


def dependency_subtract(a: int, b: int):
return subtract(a, b)


def test_function_with_dependency(pool):
task_id = pool.new_task(dependency_subtract, 7, 2)
assert pool.get_result(task_id) == 5
assert pool.submit_and_fetch(dependency_subtract, 7, 2) == 5

0 comments on commit 7f6bc50

Please sign in to comment.