diff --git a/proxy/core/connection/pool.py b/proxy/core/connection/pool.py index 5f92066b9d..a9b0585a01 100644 --- a/proxy/core/connection/pool.py +++ b/proxy/core/connection/pool.py @@ -13,8 +13,9 @@ reusability """ import logging +import selectors -from typing import Set, Dict, Tuple +from typing import TYPE_CHECKING, Set, Dict, Tuple from ...common.flag import flags from ...common.types import Readables, Writables @@ -66,11 +67,21 @@ class UpstreamConnectionPool(Work[TcpServerConnection]): def __init__(self) -> None: # Pools of connection per upstream server + self.connections: Dict[int, TcpServerConnection] = {} self.pools: Dict[Tuple[str, int], Set[TcpServerConnection]] = {} - def acquire(self, host: str, port: int) -> Tuple[bool, TcpServerConnection]: + def add(self, addr: Tuple[str, int]) -> TcpServerConnection: + # Create new connection + new_conn = TcpServerConnection(addr[0], addr[1]) + new_conn.connect() + if addr not in self.pools: + self.pools[addr] = set() + self.pools[addr].add(new_conn) + self.connections[new_conn.connection.fileno()] = new_conn + return new_conn + + def acquire(self, addr: Tuple[str, int]) -> Tuple[bool, TcpServerConnection]: """Returns a connection for use with the server.""" - addr = (host, port) # Return a reusable connection if available if addr in self.pools: for old_conn in self.pools[addr]: @@ -78,18 +89,14 @@ def acquire(self, host: str, port: int) -> Tuple[bool, TcpServerConnection]: old_conn.mark_inuse() logger.debug( 'Reusing connection#{2} for upstream {0}:{1}'.format( - host, port, id(old_conn), + addr[0], addr[1], id(old_conn), ), ) return False, old_conn - # Create new connection - new_conn = TcpServerConnection(*addr) - if addr not in self.pools: - self.pools[addr] = set() - self.pools[addr].add(new_conn) + new_conn = self.add(addr) logger.debug( 'Created new connection#{2} for upstream {0}:{1}'.format( - host, port, id(new_conn), + addr[0], addr[1], id(new_conn), ), ) return True, new_conn @@ -118,7 +125,17 @@ def release(self, conn: TcpServerConnection) -> None: conn.reset() async def get_events(self) -> Dict[int, int]: - return await super().get_events() - - async def handle_events(self, readables: Readables, writables: Writables) -> bool: - return await super().handle_events(readables, writables) + events = {} + for connections in self.pools.values(): + for conn in connections: + events[conn.connection.fileno()] = selectors.EVENT_READ + return events + + async def handle_events(self, readables: Readables, _writables: Writables) -> bool: + for r in readables: + if TYPE_CHECKING: + assert isinstance(r, int) + conn = self.connections[r] + self.pools[conn.addr].remove(conn) + del self.connections[r] + return False diff --git a/proxy/core/connection/server.py b/proxy/core/connection/server.py index 7aae5371cc..2c2af73f99 100644 --- a/proxy/core/connection/server.py +++ b/proxy/core/connection/server.py @@ -25,7 +25,7 @@ class TcpServerConnection(TcpConnection): def __init__(self, host: str, port: int) -> None: super().__init__(tcpConnectionTypes.SERVER) self._conn: Optional[Union[ssl.SSLSocket, socket.socket]] = None - self.addr: Tuple[str, int] = (host, int(port)) + self.addr: Tuple[str, int] = (host, port) self.closed = True @property diff --git a/proxy/http/proxy/server.py b/proxy/http/proxy/server.py index 228d8ecc34..fdb74d7176 100644 --- a/proxy/http/proxy/server.py +++ b/proxy/http/proxy/server.py @@ -586,29 +586,6 @@ def handle_pipeline_response(self, raw: memoryview) -> None: def connect_upstream(self) -> None: host, port = self.request.host, self.request.port if host and port: - if self.flags.enable_conn_pool: - assert self.upstream_conn_pool - with self.lock: - created, self.upstream = self.upstream_conn_pool.acquire( - text_(host), port, - ) - else: - created, self.upstream = True, TcpServerConnection( - text_(host), port, - ) - if not created: - # NOTE: Acquired connection might be in an unusable state. - # - # This can only be confirmed by reading from connection. - # For stale connections, we will receive None, indicating - # to drop the connection. - # - # If that happen, we must acquire a fresh connection. - logger.info( - 'Reusing connection to upstream %s:%d' % - (text_(host), port), - ) - return try: logger.debug( 'Connecting to upstream %s:%d' % @@ -622,14 +599,37 @@ def connect_upstream(self) -> None: ) if upstream_ip or source_addr: break - # Connect with overridden upstream IP and source address - # if any of the plugin returned a non-null value. - self.upstream.connect( - addr=None if not upstream_ip else ( - upstream_ip, port, - ), source_address=source_addr, - ) - self.upstream.connection.setblocking(False) + if self.flags.enable_conn_pool: + assert self.upstream_conn_pool + with self.lock: + created, self.upstream = self.upstream_conn_pool.acquire( + (text_(host), port), + ) + else: + created, self.upstream = True, TcpServerConnection( + text_(host), port, + ) + # Connect with overridden upstream IP and source address + # if any of the plugin returned a non-null value. + self.upstream.connect( + addr=None if not upstream_ip else ( + upstream_ip, port, + ), source_address=source_addr, + ) + self.upstream.connection.setblocking(False) + if not created: + # NOTE: Acquired connection might be in an unusable state. + # + # This can only be confirmed by reading from connection. + # For stale connections, we will receive None, indicating + # to drop the connection. + # + # If that happen, we must acquire a fresh connection. + logger.info( + 'Reusing connection to upstream %s:%d' % + (text_(host), port), + ) + return logger.debug( 'Connected to upstream %s:%s' % (text_(host), port), @@ -640,7 +640,7 @@ def connect_upstream(self) -> None: text_(host), port, str(e), ), ) - if self.flags.enable_conn_pool: + if self.flags.enable_conn_pool and self.upstream: assert self.upstream_conn_pool with self.lock: self.upstream_conn_pool.release(self.upstream) diff --git a/tests/core/test_conn_pool.py b/tests/core/test_conn_pool.py index db3de3d7c0..e00436cd2f 100644 --- a/tests/core/test_conn_pool.py +++ b/tests/core/test_conn_pool.py @@ -8,9 +8,12 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. """ +import pytest import unittest +import selectors from unittest import mock +from pytest_mock import MockerFixture from proxy.core.connection import UpstreamConnectionPool @@ -28,7 +31,7 @@ def test_acquire_and_release_and_reacquire(self, mock_tcp_server_connection: moc ] mock_conn.closed = False # Acquire - created, conn = pool.acquire(*addr) + created, conn = pool.acquire(addr) self.assertTrue(created) mock_tcp_server_connection.assert_called_once_with(*addr) self.assertEqual(conn, mock_conn) @@ -39,7 +42,7 @@ def test_acquire_and_release_and_reacquire(self, mock_tcp_server_connection: moc self.assertEqual(len(pool.pools[addr]), 1) self.assertTrue(conn in pool.pools[addr]) # Reacquire - created, conn = pool.acquire(*addr) + created, conn = pool.acquire(addr) self.assertFalse(created) mock_conn.reset.assert_called_once() self.assertEqual(conn, mock_conn) @@ -57,7 +60,7 @@ def test_closed_connections_are_removed_on_release( mock_conn.closed = True mock_conn.addr = addr # Acquire - created, conn = pool.acquire(*addr) + created, conn = pool.acquire(addr) self.assertTrue(created) mock_tcp_server_connection.assert_called_once_with(*addr) self.assertEqual(conn, mock_conn) @@ -67,7 +70,45 @@ def test_closed_connections_are_removed_on_release( pool.release(conn) self.assertEqual(len(pool.pools[addr]), 0) # Acquire - created, conn = pool.acquire(*addr) + created, conn = pool.acquire(addr) self.assertTrue(created) self.assertEqual(mock_tcp_server_connection.call_count, 2) mock_conn.is_reusable.assert_not_called() + + +class TestConnectionPoolAsync: + + @pytest.mark.asyncio # type: ignore[misc] + async def test_get_events(self, mocker: MockerFixture) -> None: + mock_tcp_server_connection = mocker.patch( + 'proxy.core.connection.pool.TcpServerConnection', + ) + pool = UpstreamConnectionPool() + addr = ('localhost', 1234) + mock_conn = mock_tcp_server_connection.return_value + pool.add(addr) + mock_tcp_server_connection.assert_called_once_with(*addr) + mock_conn.connect.assert_called_once() + events = await pool.get_events() + print(events) + assert events == { + mock_conn.connection.fileno.return_value: selectors.EVENT_READ, + } + assert pool.pools[addr].pop() == mock_conn + assert len(pool.pools[addr]) == 0 + assert pool.connections[mock_conn.connection.fileno.return_value] == mock_conn + + @pytest.mark.asyncio # type: ignore[misc] + async def test_handle_events(self, mocker: MockerFixture) -> None: + mock_tcp_server_connection = mocker.patch( + 'proxy.core.connection.pool.TcpServerConnection', + ) + pool = UpstreamConnectionPool() + mock_conn = mock_tcp_server_connection.return_value + addr = mock_conn.addr + pool.add(addr) + assert len(pool.pools[addr]) == 1 + assert len(pool.connections) == 1 + await pool.handle_events([mock_conn.connection.fileno.return_value], []) + assert len(pool.pools[addr]) == 0 + assert len(pool.connections) == 0