From 2c7ce957ec44ee7291bd6b6d4ffa41346a1d71ad Mon Sep 17 00:00:00 2001 From: Maciej Urbanski Date: Sat, 24 Feb 2024 23:13:21 +0100 Subject: [PATCH 1/2] allow set_thread_pool_size to be set after pool has been once used already --- b2sdk/utils/thread_pool.py | 67 ++++++++++++++++++++++-- b2sdk/v2/transfer.py | 37 ++----------- changelog.d/+set_threads.added.md | 1 + test/unit/{v2 => v_all}/test_transfer.py | 13 +++-- 4 files changed, 76 insertions(+), 42 deletions(-) create mode 100644 changelog.d/+set_threads.added.md rename test/unit/{v2 => v_all}/test_transfer.py (61%) diff --git a/b2sdk/utils/thread_pool.py b/b2sdk/utils/thread_pool.py index f07c11ccb..7726cdcb5 100644 --- a/b2sdk/utils/thread_pool.py +++ b/b2sdk/utils/thread_pool.py @@ -9,27 +9,84 @@ ###################################################################### from __future__ import annotations -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import Future, ThreadPoolExecutor +from typing import Callable from b2sdk.utils import B2TraceMetaAbstract +class LazyThreadPool: + """ + Lazily initialized thread pool. + """ + + _THREAD_POOL_FACTORY = ThreadPoolExecutor + + def __init__(self, max_workers: int | None = None, **kwargs): + self._max_workers = max_workers + self._thread_pool: ThreadPoolExecutor | None = None + super().__init__(**kwargs) + + def submit(self, fn: Callable, *args, **kwargs) -> Future: + if self._thread_pool is None: + self._thread_pool = self._THREAD_POOL_FACTORY(self._max_workers) + return self._thread_pool.submit(fn, *args, **kwargs) + + def set_size(self, max_workers: int) -> None: + """ + Set the size of the thread pool. + + This operation will block until all tasks in the current thread pool are completed. + + :param max_workers: New size of the thread pool + :return: None + """ + if self._max_workers == max_workers: + return + old_thread_pool = self._thread_pool + self._thread_pool = self._THREAD_POOL_FACTORY(max_workers=max_workers) + if old_thread_pool is not None: + old_thread_pool.shutdown(wait=True) + self._max_workers = max_workers + + def get_size(self) -> int | None: + return self._max_workers + + class ThreadPoolMixin(metaclass=B2TraceMetaAbstract): """ Mixin class with ThreadPoolExecutor. """ - DEFAULT_THREAD_POOL_CLASS = staticmethod(ThreadPoolExecutor) + + DEFAULT_THREAD_POOL_CLASS = LazyThreadPool def __init__( self, thread_pool: ThreadPoolExecutor | None = None, max_workers: int | None = None, - **kwargs + **kwargs, ): """ :param thread_pool: thread pool to be used :param max_workers: maximum number of worker threads (ignored if thread_pool is not None) """ - self._thread_pool = thread_pool if thread_pool is not None \ - else self.DEFAULT_THREAD_POOL_CLASS(max_workers=max_workers) + self._thread_pool = ( + thread_pool + if thread_pool is not None else self.DEFAULT_THREAD_POOL_CLASS(max_workers=max_workers) + ) + self._max_workers = max_workers super().__init__(**kwargs) + + def set_thread_pool_size(self, max_workers: int) -> None: + """ + Set the size of the thread pool. + + This operation will block until all tasks in the current thread pool are completed. + + :param max_workers: New size of the thread pool + :return: None + """ + return self._thread_pool.set_size(max_workers) + + def get_thread_pool_size(self) -> int | None: + return self._thread_pool.get_size() diff --git a/b2sdk/v2/transfer.py b/b2sdk/v2/transfer.py index da4d28b68..151f5ff1c 100644 --- a/b2sdk/v2/transfer.py +++ b/b2sdk/v2/transfer.py @@ -9,46 +9,17 @@ ###################################################################### from __future__ import annotations -from concurrent.futures import Future, ThreadPoolExecutor -from typing import Callable - from b2sdk import _v3 as v3 - - -class LazyThreadPool: - """ - Lazily initialized thread pool. - """ - - def __init__(self, max_workers: int | None = None, **kwargs): - self._max_workers = max_workers - self._thread_pool = None # type: 'Optional[ThreadPoolExecutor]' - super().__init__(**kwargs) - - def submit(self, fn: Callable, *args, **kwargs) -> Future: - if self._thread_pool is None: - self._thread_pool = ThreadPoolExecutor(self._max_workers) - return self._thread_pool.submit(fn, *args, **kwargs) - - def set_size(self, max_workers: int) -> None: - if self._max_workers == max_workers: - return - if self._thread_pool is not None: - raise RuntimeError('Thread pool already created') - self._max_workers = max_workers +from b2sdk.utils.thread_pool import LazyThreadPool # noqa: F401 class ThreadPoolMixin(v3.ThreadPoolMixin): - DEFAULT_THREAD_POOL_CLASS = staticmethod(LazyThreadPool) - - # This method is used in CLI even though it doesn't belong to the public API - def set_thread_pool_size(self, max_workers: int) -> None: - self._thread_pool.set_size(max_workers) + pass -class DownloadManager(v3.DownloadManager, ThreadPoolMixin): +class DownloadManager(v3.DownloadManager): pass -class UploadManager(v3.UploadManager, ThreadPoolMixin): +class UploadManager(v3.UploadManager): pass diff --git a/changelog.d/+set_threads.added.md b/changelog.d/+set_threads.added.md new file mode 100644 index 000000000..45ec116cf --- /dev/null +++ b/changelog.d/+set_threads.added.md @@ -0,0 +1 @@ +Add `set_thread_pool_size`, `get_thread_pool_size` to *Manger classes. diff --git a/test/unit/v2/test_transfer.py b/test/unit/v_all/test_transfer.py similarity index 61% rename from test/unit/v2/test_transfer.py rename to test/unit/v_all/test_transfer.py index 9636b199e..1492cac08 100644 --- a/test/unit/v2/test_transfer.py +++ b/test/unit/v_all/test_transfer.py @@ -1,6 +1,6 @@ ###################################################################### # -# File: test/unit/v2/test_transfer.py +# File: test/unit/v_all/test_transfer.py # # Copyright 2022 Backblaze Inc. All Rights Reserved. # @@ -11,19 +11,24 @@ from unittest.mock import Mock +from apiver_deps import DownloadManager, UploadManager + from ..test_base import TestBase -from .apiver.apiver_deps import DownloadManager, UploadManager class TestDownloadManager(TestBase): def test_set_thread_pool_size(self) -> None: download_manager = DownloadManager(services=Mock()) + assert download_manager.get_thread_pool_size() is None download_manager.set_thread_pool_size(21) - self.assertEqual(download_manager._thread_pool._max_workers, 21) + assert download_manager._thread_pool._max_workers == 21 + assert download_manager.get_thread_pool_size() == 21 class TestUploadManager(TestBase): def test_set_thread_pool_size(self) -> None: upload_manager = UploadManager(services=Mock()) + assert upload_manager.get_thread_pool_size() is None upload_manager.set_thread_pool_size(37) - self.assertEqual(upload_manager._thread_pool._max_workers, 37) + assert upload_manager._thread_pool._max_workers == 37 + assert upload_manager.get_thread_pool_size() == 37 From ed48c424290428d3acbb6965c3ee1709056b040b Mon Sep 17 00:00:00 2001 From: Maciej Urbanski Date: Mon, 26 Feb 2024 11:37:10 +0100 Subject: [PATCH 2/2] fix typing on thread_pool and add tests --- b2sdk/utils/thread_pool.py | 28 +++++++++++++++++-- test/unit/utils/test_thread_pool.py | 43 +++++++++++++++++++++++++++++ test/unit/v_all/test_transfer.py | 18 ++++++------ 3 files changed, 78 insertions(+), 11 deletions(-) create mode 100644 test/unit/utils/test_thread_pool.py diff --git a/b2sdk/utils/thread_pool.py b/b2sdk/utils/thread_pool.py index 7726cdcb5..b9a190c89 100644 --- a/b2sdk/utils/thread_pool.py +++ b/b2sdk/utils/thread_pool.py @@ -9,12 +9,29 @@ ###################################################################### from __future__ import annotations +import os from concurrent.futures import Future, ThreadPoolExecutor from typing import Callable +try: + from typing_extensions import Protocol +except ImportError: + from typing import Protocol + from b2sdk.utils import B2TraceMetaAbstract +class DynamicThreadPoolExecutorProtocol(Protocol): + def submit(self, fn: Callable, *args, **kwargs) -> Future: + ... + + def set_size(self, max_workers: int) -> None: + """Set the size of the thread pool.""" + + def get_size(self) -> int: + """Return the current size of the thread pool.""" + + class LazyThreadPool: """ Lazily initialized thread pool. @@ -23,6 +40,10 @@ class LazyThreadPool: _THREAD_POOL_FACTORY = ThreadPoolExecutor def __init__(self, max_workers: int | None = None, **kwargs): + if max_workers is None: + max_workers = min( + 32, (os.cpu_count() or 1) + 4 + ) # same default as in ThreadPoolExecutor self._max_workers = max_workers self._thread_pool: ThreadPoolExecutor | None = None super().__init__(**kwargs) @@ -49,7 +70,8 @@ def set_size(self, max_workers: int) -> None: old_thread_pool.shutdown(wait=True) self._max_workers = max_workers - def get_size(self) -> int | None: + def get_size(self) -> int: + """Return the current size of the thread pool.""" return self._max_workers @@ -62,7 +84,7 @@ class ThreadPoolMixin(metaclass=B2TraceMetaAbstract): def __init__( self, - thread_pool: ThreadPoolExecutor | None = None, + thread_pool: DynamicThreadPoolExecutorProtocol | None = None, max_workers: int | None = None, **kwargs, ): @@ -88,5 +110,5 @@ def set_thread_pool_size(self, max_workers: int) -> None: """ return self._thread_pool.set_size(max_workers) - def get_thread_pool_size(self) -> int | None: + def get_thread_pool_size(self) -> int: return self._thread_pool.get_size() diff --git a/test/unit/utils/test_thread_pool.py b/test/unit/utils/test_thread_pool.py new file mode 100644 index 000000000..5742d3e35 --- /dev/null +++ b/test/unit/utils/test_thread_pool.py @@ -0,0 +1,43 @@ +###################################################################### +# +# File: test/unit/utils/test_thread_pool.py +# +# Copyright 2024 Backblaze Inc. All Rights Reserved. +# +# License https://www.backblaze.com/using_b2_code.html +# +###################################################################### +from concurrent.futures import Future + +import pytest + +from b2sdk.utils.thread_pool import LazyThreadPool + + +class TestLazyThreadPool: + @pytest.fixture + def thread_pool(self): + return LazyThreadPool() + + def test_submit(self, thread_pool): + + future = thread_pool.submit(sum, (1, 2)) + assert isinstance(future, Future) + assert future.result() == 3 + + def test_set_size(self, thread_pool): + thread_pool.set_size(10) + assert thread_pool.get_size() == 10 + + def test_get_size(self, thread_pool): + assert thread_pool.get_size() > 0 + + def test_set_size__after_submit(self, thread_pool): + future = thread_pool.submit(sum, (1, 2)) + + thread_pool.set_size(7) + assert thread_pool.get_size() == 7 + + assert future.result() == 3 + + assert thread_pool.submit(sum, (1,)).result() == 1 diff --git a/test/unit/v_all/test_transfer.py b/test/unit/v_all/test_transfer.py index 1492cac08..584bf6d68 100644 --- a/test/unit/v_all/test_transfer.py +++ b/test/unit/v_all/test_transfer.py @@ -19,16 +19,18 @@ class TestDownloadManager(TestBase): def test_set_thread_pool_size(self) -> None: download_manager = DownloadManager(services=Mock()) - assert download_manager.get_thread_pool_size() is None - download_manager.set_thread_pool_size(21) - assert download_manager._thread_pool._max_workers == 21 - assert download_manager.get_thread_pool_size() == 21 + assert download_manager.get_thread_pool_size() > 0 + + pool_size = 21 + download_manager.set_thread_pool_size(pool_size) + assert download_manager.get_thread_pool_size() == pool_size class TestUploadManager(TestBase): def test_set_thread_pool_size(self) -> None: upload_manager = UploadManager(services=Mock()) - assert upload_manager.get_thread_pool_size() is None - upload_manager.set_thread_pool_size(37) - assert upload_manager._thread_pool._max_workers == 37 - assert upload_manager.get_thread_pool_size() == 37 + assert upload_manager.get_thread_pool_size() > 0 + + pool_size = 37 + upload_manager.set_thread_pool_size(pool_size) + assert upload_manager.get_thread_pool_size() == pool_size