diff --git a/docs/conf.py b/docs/conf.py index 8172d57f55..402e263b2d 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -300,6 +300,7 @@ (_py_class_role, 'unittest.case.TestCase'), (_py_class_role, 'unittest.result.TestResult'), (_py_class_role, 'UUID'), + (_py_class_role, 'UpstreamConnectionPool'), (_py_class_role, 'Url'), (_py_class_role, 'WebsocketFrame'), (_py_class_role, 'Work'), diff --git a/proxy/core/acceptor/executors.py b/proxy/core/acceptor/executors.py index 9512762671..d0c5a912c3 100644 --- a/proxy/core/acceptor/executors.py +++ b/proxy/core/acceptor/executors.py @@ -131,6 +131,7 @@ def start_threaded_work( TcpClientConnection(conn, addr), flags=flags, event_queue=event_queue, + upstream_conn_pool=None, ) # TODO: Keep reference to threads and join during shutdown. # This will ensure connections are not abruptly closed on shutdown diff --git a/proxy/core/acceptor/threadless.py b/proxy/core/acceptor/threadless.py index 5140c9345b..9858712adc 100644 --- a/proxy/core/acceptor/threadless.py +++ b/proxy/core/acceptor/threadless.py @@ -25,7 +25,7 @@ from ...common.constants import DEFAULT_INACTIVE_CONN_CLEANUP_TIMEOUT, DEFAULT_SELECTOR_SELECT_TIMEOUT from ...common.constants import DEFAULT_WAIT_FOR_TASKS_TIMEOUT -from ..connection import TcpClientConnection +from ..connection import TcpClientConnection, UpstreamConnectionPool from ..event import eventNames, EventQueue from .work import Work @@ -87,6 +87,9 @@ def __init__( self.wait_timeout: float = DEFAULT_WAIT_FOR_TASKS_TIMEOUT self.cleanup_inactive_timeout: float = DEFAULT_INACTIVE_CONN_CLEANUP_TIMEOUT self._total: int = 0 + self._upstream_conn_pool: Optional[UpstreamConnectionPool] = None + if self.flags.enable_conn_pool: + self._upstream_conn_pool = UpstreamConnectionPool() @property @abstractmethod @@ -134,6 +137,7 @@ def work_on_tcp_conn( flags=self.flags, event_queue=self.event_queue, uid=uid, + upstream_conn_pool=self._upstream_conn_pool, ) self.works[fileno].publish_event( event_name=eventNames.WORK_STARTED, diff --git a/proxy/core/acceptor/work.py b/proxy/core/acceptor/work.py index 5a7ba0723d..37f0bf591b 100644 --- a/proxy/core/acceptor/work.py +++ b/proxy/core/acceptor/work.py @@ -16,11 +16,14 @@ from abc import ABC, abstractmethod from uuid import uuid4 -from typing import Optional, Dict, Any, TypeVar, Generic +from typing import Optional, Dict, Any, TypeVar, Generic, TYPE_CHECKING from ..event import eventNames, EventQueue from ...common.types import Readables, Writables +if TYPE_CHECKING: + from ..connection import UpstreamConnectionPool + T = TypeVar('T') @@ -33,6 +36,7 @@ def __init__( flags: argparse.Namespace, event_queue: Optional[EventQueue] = None, uid: Optional[str] = None, + upstream_conn_pool: Optional['UpstreamConnectionPool'] = None, ) -> None: # Work uuid self.uid: str = uid if uid is not None else uuid4().hex @@ -41,6 +45,7 @@ def __init__( self.event_queue = event_queue # Accept work self.work = work + self.upstream_conn_pool = upstream_conn_pool @abstractmethod async def get_events(self) -> Dict[int, int]: diff --git a/proxy/core/connection/__init__.py b/proxy/core/connection/__init__.py index 952ee08f9e..58d100a81b 100644 --- a/proxy/core/connection/__init__.py +++ b/proxy/core/connection/__init__.py @@ -16,7 +16,7 @@ from .connection import TcpConnection, TcpConnectionUninitializedException from .client import TcpClientConnection from .server import TcpServerConnection -from .pool import ConnectionPool +from .pool import UpstreamConnectionPool from .types import tcpConnectionTypes __all__ = [ @@ -25,5 +25,5 @@ 'TcpServerConnection', 'TcpClientConnection', 'tcpConnectionTypes', - 'ConnectionPool', + 'UpstreamConnectionPool', ] diff --git a/proxy/core/connection/pool.py b/proxy/core/connection/pool.py index 16cd5096b1..5f92066b9d 100644 --- a/proxy/core/connection/pool.py +++ b/proxy/core/connection/pool.py @@ -17,6 +17,9 @@ from typing import Set, Dict, Tuple from ...common.flag import flags +from ...common.types import Readables, Writables + +from ..acceptor.work import Work from .server import TcpServerConnection @@ -31,10 +34,10 @@ ) -class ConnectionPool: +class UpstreamConnectionPool(Work[TcpServerConnection]): """Manages connection pool to upstream servers. - `ConnectionPool` avoids need to reconnect with the upstream + `UpstreamConnectionPool` avoids need to reconnect with the upstream servers repeatedly when a reusable connection is available in the pool. @@ -47,16 +50,16 @@ class ConnectionPool: the pool users. Example, if acquired connection is stale, reacquire. - TODO: Ideally, ConnectionPool must be shared across + TODO: Ideally, `UpstreamConnectionPool` must be shared across all cores to make SSL session cache to also work without additional out-of-bound synchronizations. - TODO: ConnectionPool currently WON'T work for + TODO: `UpstreamConnectionPool` currently WON'T work for HTTPS connection. This is because of missing support for session cache, session ticket, abbr TLS handshake and other necessary features to make it work. - NOTE: However, for all HTTP only connections, ConnectionPool + NOTE: However, for all HTTP only connections, `UpstreamConnectionPool` can be used to save upon connection setup time and speed-up performance of requests. """ @@ -113,3 +116,9 @@ def release(self, conn: TcpServerConnection) -> None: assert not conn.is_reusable() # Reset for reusability conn.reset() + + async def get_events(self) -> Dict[int, int]: + return await super().get_events() + + async def handle_events(self, readables: Readables, writables: Writables) -> bool: + return await super().handle_events(readables, writables) diff --git a/proxy/http/handler.py b/proxy/http/handler.py index 38429df7de..ae6c0d66ca 100644 --- a/proxy/http/handler.py +++ b/proxy/http/handler.py @@ -100,6 +100,7 @@ def initialize(self) -> None: self.work, self.request, self.event_queue, + self.upstream_conn_pool, ) self.plugins[instance.name()] = instance logger.debug('Handling connection %r' % self.work.connection) diff --git a/proxy/http/plugin.py b/proxy/http/plugin.py index eafcd0539d..0180e5f9d9 100644 --- a/proxy/http/plugin.py +++ b/proxy/http/plugin.py @@ -12,7 +12,7 @@ import argparse from abc import ABC, abstractmethod -from typing import Tuple, List, Union, Optional +from typing import Tuple, List, Union, Optional, TYPE_CHECKING from .parser import HttpParser @@ -20,6 +20,9 @@ from ..core.event import EventQueue from ..core.connection import TcpClientConnection +if TYPE_CHECKING: + from ..core.connection import UpstreamConnectionPool + class HttpProtocolHandlerPlugin(ABC): """Base HttpProtocolHandler Plugin class. @@ -50,12 +53,14 @@ def __init__( client: TcpClientConnection, request: HttpParser, event_queue: EventQueue, + upstream_conn_pool: Optional['UpstreamConnectionPool'] = None, ): self.uid: str = uid self.flags: argparse.Namespace = flags self.client: TcpClientConnection = client self.request: HttpParser = request self.event_queue = event_queue + self.upstream_conn_pool = upstream_conn_pool super().__init__() def name(self) -> str: diff --git a/proxy/http/proxy/server.py b/proxy/http/proxy/server.py index 55604abf63..228d8ecc34 100644 --- a/proxy/http/proxy/server.py +++ b/proxy/http/proxy/server.py @@ -44,7 +44,7 @@ from ...common.pki import gen_public_key, gen_csr, sign_csr from ...core.event import eventNames -from ...core.connection import TcpServerConnection, ConnectionPool +from ...core.connection import TcpServerConnection from ...core.connection import TcpConnectionUninitializedException from ...common.flag import flags @@ -140,9 +140,6 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin): # connection pool operations. lock = threading.Lock() - # Shared connection pool - pool = ConnectionPool() - def __init__( self, *args: Any, **kwargs: Any, @@ -200,10 +197,10 @@ def get_descriptors(self) -> Tuple[List[int], List[int]]: def _close_and_release(self) -> bool: if self.flags.enable_conn_pool: - assert self.upstream and not self.upstream.closed + assert self.upstream and not self.upstream.closed and self.upstream_conn_pool self.upstream.closed = True with self.lock: - self.pool.release(self.upstream) + self.upstream_conn_pool.release(self.upstream) self.upstream = None return True @@ -391,9 +388,10 @@ def on_client_connection_close(self) -> None: return if self.flags.enable_conn_pool: + assert self.upstream_conn_pool # Release the connection for reusability with self.lock: - self.pool.release(self.upstream) + self.upstream_conn_pool.release(self.upstream) return try: @@ -589,8 +587,9 @@ def connect_upstream(self) -> None: host, port = self.request.host, self.request.port if host and port: if self.flags.enable_conn_pool: + assert self.upstream_conn_pool with self.lock: - created, self.upstream = self.pool.acquire( + created, self.upstream = self.upstream_conn_pool.acquire( text_(host), port, ) else: @@ -642,8 +641,9 @@ def connect_upstream(self) -> None: ), ) if self.flags.enable_conn_pool: + assert self.upstream_conn_pool with self.lock: - self.pool.release(self.upstream) + self.upstream_conn_pool.release(self.upstream) raise ProxyConnectionFailed( text_(host), port, repr(e), ) from e diff --git a/proxy/plugin/proxy_pool.py b/proxy/plugin/proxy_pool.py index cfc8017820..641d95d622 100644 --- a/proxy/plugin/proxy_pool.py +++ b/proxy/plugin/proxy_pool.py @@ -88,7 +88,7 @@ def before_upstream_connection( must be bootstrapped within it's own re-usable and garbage collected pool, to avoid establishing a new upstream proxy connection for each client request. - See :class:`~proxy.core.connection.pool.ConnectionPool` which is a work + See :class:`~proxy.core.connection.pool.UpstreamConnectionPool` which is a work in progress for SSL cache handling. """ # We don't want to send private IP requests to remote proxies diff --git a/tests/core/test_acceptor.py b/tests/core/test_acceptor.py index 2a4d089898..89bbce46aa 100644 --- a/tests/core/test_acceptor.py +++ b/tests/core/test_acceptor.py @@ -101,6 +101,7 @@ def test_accepts_client_from_server_socket( mock_client.return_value, flags=self.flags, event_queue=None, + upstream_conn_pool=None, ) mock_thread.assert_called_with( target=self.flags.work_klass.return_value.run, diff --git a/tests/core/test_conn_pool.py b/tests/core/test_conn_pool.py index 3eaad052f3..db3de3d7c0 100644 --- a/tests/core/test_conn_pool.py +++ b/tests/core/test_conn_pool.py @@ -12,14 +12,14 @@ from unittest import mock -from proxy.core.connection import ConnectionPool +from proxy.core.connection import UpstreamConnectionPool class TestConnectionPool(unittest.TestCase): @mock.patch('proxy.core.connection.pool.TcpServerConnection') def test_acquire_and_release_and_reacquire(self, mock_tcp_server_connection: mock.Mock) -> None: - pool = ConnectionPool() + pool = UpstreamConnectionPool() addr = ('localhost', 1234) # Mock mock_conn = mock_tcp_server_connection.return_value @@ -50,7 +50,7 @@ def test_acquire_and_release_and_reacquire(self, mock_tcp_server_connection: moc def test_closed_connections_are_removed_on_release( self, mock_tcp_server_connection: mock.Mock, ) -> None: - pool = ConnectionPool() + pool = UpstreamConnectionPool() addr = ('localhost', 1234) # Mock mock_conn = mock_tcp_server_connection.return_value