Skip to content

Commit

Permalink
Merge pull request #1293 from qstokkink/add_get_task
Browse files Browse the repository at this point in the history
Added methods for task retrieval to `TaskManager`
  • Loading branch information
qstokkink committed Apr 11, 2024
2 parents c17400e + 0c1ec0f commit 3169dbd
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 14 deletions.
26 changes: 21 additions & 5 deletions ipv8/taskmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __init__(self) -> None:
self._counter = 0
self._logger = logging.getLogger(self.__class__.__name__)

self._checker = self.register_task('_check_tasks', self._check_tasks,
self._checker = self.register_task("_check_tasks", self._check_tasks,
interval=MAX_TASK_AGE, delay=MAX_TASK_AGE * 1.5)

def _check_tasks(self) -> None:
Expand Down Expand Up @@ -139,8 +139,8 @@ def register_task(self, name: Hashable, task: Callable | Coroutine | Future, #
task.start_time = time.time() # type: ignore[attr-defined]
task.interval = interval # type: ignore[attr-defined]
# The set_name function is only available in Python 3.8+
task_name = f'{self.__class__.__name__}:{name}'
if hasattr(task, 'set_name'):
task_name = f"{self.__class__.__name__}:{name}"
if hasattr(task, "set_name"):
task.set_name(task_name)
else:
task.name = task_name # type: ignore[attr-defined]
Expand All @@ -154,7 +154,7 @@ def done_cb(future: Future) -> None:
except CancelledError:
pass
except ignore as e: # type: ignore[misc]
self._logger.exception('Task resulted in error: %s\n%s', e, ''.join(traceback.format_exc()))
self._logger.exception("Task resulted in error: %s\n%s", e, "".join(traceback.format_exc()))

self._pending_tasks[name] = task
task.add_done_callback(done_cb)
Expand All @@ -165,7 +165,7 @@ def register_anonymous_task(self, basename: str, task: Callable | Coroutine | Fu
Wrapper for register_task to derive a unique name from the basename.
"""
self._counter += 1
return self.register_task(basename + ' ' + str(self._counter), task, *args, **kwargs)
return self.register_task(basename + " " + str(self._counter), task, *args, **kwargs)

def register_executor_task(self, name: str, func: Callable, *args: Any, # noqa: ANN401
executor: ThreadPoolExecutor | None = None, anon: bool = False, **kwargs) -> Future:
Expand Down Expand Up @@ -212,13 +212,29 @@ def is_pending_task_active(self, name: Hashable) -> bool:
task = self._pending_tasks.get(name, None)
return not task.done() if task else False

def get_task(self, name: Hashable) -> Future | None:
"""
Return a task if it exists. Otherwise, return None.
"""
with self._task_lock:
return self._pending_tasks.get(name, None)

def get_tasks(self) -> list[Future]:
"""
Returns a list of all registered tasks, excluding tasks the are created by the TaskManager itself.
"""
with self._task_lock:
return [t for t in self._pending_tasks.values() if t != self._checker]

def get_anonymous_tasks(self, base_name: str) -> list[Future]:
"""
Return all tasks with a given base name.
Note that this method will return ALL tasks that start with the given base name, including non-anonymous ones.
"""
with self._task_lock:
return [t[1] for t in self._pending_tasks.items() if isinstance(t[0], str) and t[0].startswith(base_name)]

async def wait_for_tasks(self) -> None:
"""
Waits until all registered tasks are done.
Expand Down
82 changes: 73 additions & 9 deletions ipv8/test/test_taskmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,18 @@ async def tearDown(self) -> None:
await self.tm.shutdown_task_manager()
return await super().tearDown()

def count(self) -> None:
"""
A function used to increment the local counter.
"""
self.counter += 1

def set_counter(self, value: int) -> None:
"""
Set the testing counter to a fixed value.
"""
self.counter = value

def test_call_later(self) -> None:
"""
Check that tasks can be sheduled for the future.
Expand Down Expand Up @@ -147,7 +159,7 @@ async def test_shutdown(self) -> None:
"""
await self.tm.shutdown_task_manager()
task = self.tm.register_anonymous_task("test", lambda: None)
self.assertFalse(self.tm.is_pending_task_active('test'))
self.assertFalse(self.tm.is_pending_task_active("test"))
self.assertFalse(task.cancelled())

async def test_cleanup(self) -> None:
Expand Down Expand Up @@ -176,7 +188,7 @@ def exception_handler(_: AbstractEventLoop, __: dict[str, Any]) -> None:

get_running_loop().set_exception_handler(exception_handler)
with suppress(ZeroDivisionError):
await self.tm.register_task('test', lambda: 1 / 0)
await self.tm.register_task("test", lambda: 1 / 0)
self.assertTrue(exception_handler.called)

async def test_task_with_exception_ignore(self) -> None:
Expand All @@ -191,7 +203,7 @@ def exception_handler(_: AbstractEventLoop, __: dict[str, Any]) -> None:

get_running_loop().set_exception_handler(exception_handler)
with suppress(ZeroDivisionError):
await self.tm.register_task('test', lambda: 1 / 0, ignore=(ZeroDivisionError,))
await self.tm.register_task("test", lambda: 1 / 0, ignore=(ZeroDivisionError,))
self.assertFalse(exception_handler.called)

async def test_task_decorator_coro(self) -> None:
Expand Down Expand Up @@ -279,14 +291,66 @@ async def test_register_executor_task_anon(self) -> None:
_ = self.tm.register_executor_task("test", test, anon=True)
self.assertEqual(2, len(self.tm.get_tasks()))

def count(self) -> None:
async def test_get_task_existing_pending(self) -> None:
"""
A function used to increment the local counter.
Check if an existing pending task can be retrieved.
"""
self.counter += 1
registered = self.tm.register_task("test", lambda: None, delay=10.0)
await sleep(0)

def set_counter(self, value: int) -> None:
retrieved = self.tm.get_task("test")

self.assertEqual(registered, retrieved)
self.assertFalse(retrieved.done())

async def test_get_task_existing_finished(self) -> None:
"""
Set the testing counter to a fixed value.
Check if an existing finished task can be retrieved.
"""
self.counter = value
registered = self.tm.register_task("test", lambda: None)
await sleep(0)

retrieved = self.tm.get_task("test")

self.assertEqual(registered, retrieved)
self.assertTrue(retrieved.done())

def test_get_task_non_existent(self) -> None:
"""
Check if retrieving an unknown task returns None.
"""
retrieved = self.tm.get_task("test")

self.assertIsNone(retrieved)

async def test_get_anon_tasks_existing_pending(self) -> None:
"""
Check if existing pending anonymous tasks can be retrieved.
"""
registered = self.tm.register_anonymous_task("test", lambda: None, delay=10.0)
await sleep(0)

retrieved = self.tm.get_anonymous_tasks("test")

self.assertListEqual([registered], retrieved)
self.assertFalse(retrieved[0].done())

async def test_get_anon_tasks_existing_finished(self) -> None:
"""
Check if existing finished anonymous tasks can be retrieved.
"""
registered = self.tm.register_anonymous_task("test", lambda: None)
await sleep(0)

retrieved = self.tm.get_anonymous_tasks("test")

self.assertListEqual([registered], retrieved)
self.assertTrue(retrieved[0].done())

def test_get_anon_tasks_non_existent(self) -> None:
"""
Check if retrieving anonymous tasks with an unknown base name returns an empty list.
"""
retrieved = self.tm.get_anonymous_tasks("test")

self.assertListEqual([], retrieved)

0 comments on commit 3169dbd

Please sign in to comment.