diff --git a/docs/changelog-fragments.d/1324.feature b/docs/changelog-fragments.d/1324.feature new file mode 100644 index 0000000000..408dad86e2 --- /dev/null +++ b/docs/changelog-fragments.d/1324.feature @@ -0,0 +1 @@ +Add `--listener-pool-klass`, `--acceptor-pool-klass`, `--threadless-pool-klass` diff --git a/proxy/common/constants.py b/proxy/common/constants.py index bd0a40e785..f09e2aaa7c 100644 --- a/proxy/common/constants.py +++ b/proxy/common/constants.py @@ -114,6 +114,8 @@ def _env_threadless_compliant() -> bool: '{response_bytes} bytes - {connection_time_ms}ms' DEFAULT_REVERSE_PROXY_ACCESS_LOG_FORMAT = '{client_ip}:{client_port} - ' + \ '{request_method} {request_path} -> {upstream_proxy_pass} - {connection_time_ms}ms' +DEFAULT_LISTENER_POOL_KLASS = 'proxy.core.listener.ListenerPool' +DEFAULT_ACCEPTOR_POOL_KLASS = 'proxy.core.acceptor.AcceptorPool' DEFAULT_NUM_ACCEPTORS = 0 DEFAULT_NUM_WORKERS = 0 DEFAULT_OPEN_FILE_LIMIT = 1024 @@ -127,6 +129,7 @@ def _env_threadless_compliant() -> bool: DEFAULT_STATIC_SERVER_DIR = os.path.join(PROXY_PY_DIR, "public") DEFAULT_MIN_COMPRESSION_LENGTH = 20 # In bytes DEFAULT_THREADLESS = _env_threadless_compliant() +DEFAULT_THREADLESS_POOL_KLASS = 'proxy.core.work.ThreadlessPool' DEFAULT_LOCAL_EXECUTOR = True DEFAULT_TIMEOUT = 10.0 DEFAULT_VERSION = False diff --git a/proxy/common/flag.py b/proxy/common/flag.py index f8395a6f62..5d3c986a38 100644 --- a/proxy/common/flag.py +++ b/proxy/common/flag.py @@ -138,6 +138,24 @@ def initialize( if isinstance(work_klass, str) \ else work_klass + # Load acceptor_pool_klass + acceptor_pool_klass = opts.get('acceptor_pool_klass', args.acceptor_pool_klass) + acceptor_pool_klass = Plugins.importer(bytes_(acceptor_pool_klass))[0] \ + if isinstance(acceptor_pool_klass, str) \ + else acceptor_pool_klass + + # Load listener_pool_klass + listener_pool_klass = opts.get('listener_pool_klass', args.listener_pool_klass) + listener_pool_klass = Plugins.importer(bytes_(listener_pool_klass))[0] \ + if isinstance(listener_pool_klass, str) \ + else listener_pool_klass + + # Load threadless_pool_klass + threadless_pool_klass = opts.get('threadless_pool_klass', args.threadless_pool_klass) + threadless_pool_klass = Plugins.importer(bytes_(threadless_pool_klass))[0] \ + if isinstance(threadless_pool_klass, str) \ + else threadless_pool_klass + # TODO: Plugin flag initialization logic must be moved within plugins. # # Generate auth_code required for basic authentication if enabled @@ -201,6 +219,8 @@ def initialize( # def option(t: object, key: str, default: Any) -> Any: # return cast(t, opts.get(key, default)) args.work_klass = work_klass + args.acceptor_pool_klass = acceptor_pool_klass + args.listener_pool_klass = listener_pool_klass args.plugins = plugins args.auth_code = cast( Optional[bytes], @@ -376,6 +396,7 @@ def initialize( # evaluates to False. args.threadless = cast(bool, opts.get('threadless', args.threadless)) args.threadless = is_threadless(args.threadless, args.threaded) + args.threadless_pool_klass = threadless_pool_klass args.pid_file = cast( Optional[str], opts.get( diff --git a/proxy/common/plugins.py b/proxy/common/plugins.py index f92193ee8c..96dba3e619 100644 --- a/proxy/common/plugins.py +++ b/proxy/common/plugins.py @@ -111,7 +111,7 @@ def locate_klass(klass_module_name: str, klass_path: List[str]) -> Union[type, N klass_container = getattr(klass_container, klass_path_part) except AttributeError: return None - if not isinstance(klass_container, type) or not inspect.isclass(klass_container): + if not callable(klass_container): return None return klass_container diff --git a/proxy/core/acceptor/pool.py b/proxy/core/acceptor/pool.py index 09fb9f447f..19871cef39 100644 --- a/proxy/core/acceptor/pool.py +++ b/proxy/core/acceptor/pool.py @@ -24,7 +24,9 @@ from .acceptor import Acceptor from ..listener import ListenerPool from ...common.flag import flags -from ...common.constants import DEFAULT_NUM_ACCEPTORS +from ...common.constants import ( + DEFAULT_NUM_ACCEPTORS, DEFAULT_ACCEPTOR_POOL_KLASS, +) if TYPE_CHECKING: # pragma: no cover @@ -33,6 +35,14 @@ logger = logging.getLogger(__name__) +flags.add_argument( + '--acceptor-pool-klass', + type=str, + default=DEFAULT_ACCEPTOR_POOL_KLASS, + help='Default: ' + DEFAULT_ACCEPTOR_POOL_KLASS + + '. Acceptor pool klass.', +) + flags.add_argument( '--num-acceptors', type=int, diff --git a/proxy/core/listener/pool.py b/proxy/core/listener/pool.py index b362ae558c..aef0b724fd 100644 --- a/proxy/core/listener/pool.py +++ b/proxy/core/listener/pool.py @@ -13,12 +13,23 @@ from .tcp import TcpSocketListener from .unix import UnixSocketListener +from ...common.flag import flags +from ...common.constants import DEFAULT_LISTENER_POOL_KLASS if TYPE_CHECKING: # pragma: no cover from .base import BaseListener +flags.add_argument( + '--listener-pool-klass', + type=str, + default=DEFAULT_LISTENER_POOL_KLASS, + help='Default: ' + DEFAULT_LISTENER_POOL_KLASS + + '. Listener pool klass.', +) + + class ListenerPool: """Provides abstraction around starting multiple listeners based upon flags.""" diff --git a/proxy/core/work/pool.py b/proxy/core/work/pool.py index 5458f0a89d..12d738f2c0 100644 --- a/proxy/core/work/pool.py +++ b/proxy/core/work/pool.py @@ -15,7 +15,9 @@ from multiprocessing import connection from ...common.flag import flags -from ...common.constants import DEFAULT_THREADLESS, DEFAULT_NUM_WORKERS +from ...common.constants import ( + DEFAULT_THREADLESS, DEFAULT_NUM_WORKERS, DEFAULT_THREADLESS_POOL_KLASS, +) if TYPE_CHECKING: # pragma: no cover @@ -54,6 +56,14 @@ help='Defaults to number of CPU cores.', ) +flags.add_argument( + '--threadless-pool-klass', + type=str, + default=DEFAULT_THREADLESS_POOL_KLASS, + help='Default: ' + DEFAULT_THREADLESS_POOL_KLASS + + '. Threadless pool klass.', +) + class ThreadlessPool: """Manages lifecycle of threadless pool and delegates work to them diff --git a/proxy/proxy.py b/proxy/proxy.py index d9d9f89798..3b4e23fa0c 100644 --- a/proxy/proxy.py +++ b/proxy/proxy.py @@ -199,7 +199,10 @@ def setup(self) -> None: self._write_pid_file() # We setup listeners first because of flags.port override # in case of ephemeral port being used - self.listeners = ListenerPool(flags=self.flags) + self.listeners = cast( + 'ListenerPool', + self.flags.listener_pool_klass(flags=self.flags), + ) self.listeners.setup() # Override flags.port to match the actual port # we are listening upon. This is necessary to preserve @@ -234,20 +237,26 @@ def setup(self) -> None: # Setup remote executors only if # --local-executor mode isn't enabled. if self.remote_executors_enabled: - self.executors = ThreadlessPool( - flags=self.flags, - event_queue=event_queue, - executor_klass=RemoteFdExecutor, + self.executors = cast( + 'ThreadlessPool', + self.flags.threadless_pool_klass( + flags=self.flags, + event_queue=event_queue, + executor_klass=RemoteFdExecutor, + ), ) self.executors.setup() # Setup acceptors - self.acceptors = AcceptorPool( - flags=self.flags, - listeners=self.listeners, - executor_queues=self.executors.work_queues if self.executors else [], - executor_pids=self.executors.work_pids if self.executors else [], - executor_locks=self.executors.work_locks if self.executors else [], - event_queue=event_queue, + self.acceptors = cast( + 'AcceptorPool', + self.flags.acceptor_pool_klass( + flags=self.flags, + listeners=self.listeners, + executor_queues=self.executors.work_queues if self.executors else [], + executor_pids=self.executors.work_pids if self.executors else [], + executor_locks=self.executors.work_locks if self.executors else [], + event_queue=event_queue, + ), ) self.acceptors.setup() # Start SSH tunnel acceptor if enabled diff --git a/tests/test_main.py b/tests/test_main.py index a328f6919b..4f3d3b84b2 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -32,10 +32,11 @@ DEFAULT_ENABLE_DASHBOARD, PLUGIN_DEVTOOLS_PROTOCOL, DEFAULT_ENABLE_SSH_TUNNEL, DEFAULT_ENABLE_WEB_SERVER, DEFAULT_DISABLE_HTTP_PROXY, PLUGIN_WEBSOCKET_TRANSPORT, - DEFAULT_CA_SIGNING_KEY_FILE, DEFAULT_CLIENT_RECVBUF_SIZE, + DEFAULT_ACCEPTOR_POOL_KLASS, DEFAULT_CA_SIGNING_KEY_FILE, + DEFAULT_CLIENT_RECVBUF_SIZE, DEFAULT_LISTENER_POOL_KLASS, DEFAULT_SERVER_RECVBUF_SIZE, DEFAULT_CACHE_DIRECTORY_PATH, DEFAULT_ENABLE_REVERSE_PROXY, DEFAULT_ENABLE_STATIC_SERVER, - _env_threadless_compliant, + DEFAULT_THREADLESS_POOL_KLASS, _env_threadless_compliant, ) @@ -58,6 +59,8 @@ def mock_default_args(mock_args: mock.Mock) -> None: mock_args.basic_auth = DEFAULT_BASIC_AUTH mock_args.hostname = DEFAULT_IPV6_HOSTNAME mock_args.port = DEFAULT_PORT + mock_args.listener_pool_klass = DEFAULT_LISTENER_POOL_KLASS + mock_args.acceptor_pool_klass = DEFAULT_ACCEPTOR_POOL_KLASS mock_args.num_acceptors = DEFAULT_NUM_ACCEPTORS mock_args.num_workers = DEFAULT_NUM_WORKERS mock_args.disable_http_proxy = DEFAULT_DISABLE_HTTP_PROXY @@ -71,6 +74,7 @@ def mock_default_args(mock_args: mock.Mock) -> None: mock_args.devtools_ws_path = DEFAULT_DEVTOOLS_WS_PATH mock_args.timeout = DEFAULT_TIMEOUT mock_args.threadless = DEFAULT_THREADLESS + mock_args.threadless_pool_klass = DEFAULT_THREADLESS_POOL_KLASS mock_args.threaded = not DEFAULT_THREADLESS mock_args.enable_web_server = DEFAULT_ENABLE_WEB_SERVER mock_args.enable_static_server = DEFAULT_ENABLE_STATIC_SERVER @@ -91,9 +95,9 @@ def mock_default_args(mock_args: mock.Mock) -> None: @mock.patch('time.sleep') @mock.patch('proxy.proxy.FlagParser.initialize') @mock.patch('proxy.proxy.EventManager') - @mock.patch('proxy.proxy.AcceptorPool') - @mock.patch('proxy.proxy.ThreadlessPool') - @mock.patch('proxy.proxy.ListenerPool') + @mock.patch(DEFAULT_ACCEPTOR_POOL_KLASS) + @mock.patch(DEFAULT_THREADLESS_POOL_KLASS) + @mock.patch(DEFAULT_LISTENER_POOL_KLASS) def test_entry_point( self, mock_listener_pool: mock.Mock, @@ -147,9 +151,9 @@ def test_entry_point( @mock.patch('time.sleep') @mock.patch('proxy.proxy.FlagParser.initialize') @mock.patch('proxy.proxy.EventManager') - @mock.patch('proxy.proxy.AcceptorPool') - @mock.patch('proxy.proxy.ThreadlessPool') - @mock.patch('proxy.proxy.ListenerPool') + @mock.patch(DEFAULT_ACCEPTOR_POOL_KLASS) + @mock.patch(DEFAULT_THREADLESS_POOL_KLASS) + @mock.patch(DEFAULT_LISTENER_POOL_KLASS) def test_main_with_no_flags( self, mock_listener_pool: mock.Mock, @@ -191,9 +195,9 @@ def test_main_with_no_flags( @mock.patch('time.sleep') @mock.patch('proxy.proxy.FlagParser.initialize') @mock.patch('proxy.proxy.EventManager') - @mock.patch('proxy.proxy.AcceptorPool') - @mock.patch('proxy.proxy.ThreadlessPool') - @mock.patch('proxy.proxy.ListenerPool') + @mock.patch(DEFAULT_ACCEPTOR_POOL_KLASS) + @mock.patch(DEFAULT_THREADLESS_POOL_KLASS) + @mock.patch(DEFAULT_LISTENER_POOL_KLASS) def test_enable_events( self, mock_listener_pool: mock.Mock, @@ -238,9 +242,9 @@ def test_enable_events( @mock.patch('proxy.common.plugins.Plugins.load') @mock.patch('proxy.common.flag.FlagParser.parse_args') @mock.patch('proxy.proxy.EventManager') - @mock.patch('proxy.proxy.AcceptorPool') - @mock.patch('proxy.proxy.ThreadlessPool') - @mock.patch('proxy.proxy.ListenerPool') + @mock.patch(DEFAULT_ACCEPTOR_POOL_KLASS) + @mock.patch(DEFAULT_THREADLESS_POOL_KLASS) + @mock.patch(DEFAULT_LISTENER_POOL_KLASS) def test_enable_dashboard( self, mock_listener_pool: mock.Mock, @@ -285,9 +289,9 @@ def test_enable_dashboard( @mock.patch('proxy.common.plugins.Plugins.load') @mock.patch('proxy.common.flag.FlagParser.parse_args') @mock.patch('proxy.proxy.EventManager') - @mock.patch('proxy.proxy.AcceptorPool') - @mock.patch('proxy.proxy.ThreadlessPool') - @mock.patch('proxy.proxy.ListenerPool') + @mock.patch(DEFAULT_ACCEPTOR_POOL_KLASS) + @mock.patch(DEFAULT_THREADLESS_POOL_KLASS) + @mock.patch(DEFAULT_LISTENER_POOL_KLASS) def test_enable_devtools( self, mock_listener_pool: mock.Mock, @@ -326,9 +330,9 @@ def test_enable_devtools( @mock.patch('proxy.common.plugins.Plugins.load') @mock.patch('proxy.common.flag.FlagParser.parse_args') @mock.patch('proxy.proxy.EventManager') - @mock.patch('proxy.proxy.AcceptorPool') - @mock.patch('proxy.proxy.ThreadlessPool') - @mock.patch('proxy.proxy.ListenerPool') + @mock.patch(DEFAULT_ACCEPTOR_POOL_KLASS) + @mock.patch(DEFAULT_THREADLESS_POOL_KLASS) + @mock.patch(DEFAULT_LISTENER_POOL_KLASS) @mock.patch('proxy.proxy.SshHttpProtocolHandler') @mock.patch('proxy.proxy.SshTunnelListener') def test_enable_ssh_tunnel(