From 21cfd85a070978137da04abbff48b993ef707ae4 Mon Sep 17 00:00:00 2001 From: Abhinav Singh Date: Tue, 28 Dec 2021 23:33:36 +0530 Subject: [PATCH 1/5] Define work lifecycle events for pool --- proxy/core/connection/pool.py | 45 +++++++++++++++-------- proxy/core/connection/server.py | 2 +- proxy/http/proxy/server.py | 64 ++++++++++++++++----------------- tests/core/test_conn_pool.py | 44 ++++++++++++++++++++--- 4 files changed, 104 insertions(+), 51 deletions(-) diff --git a/proxy/core/connection/pool.py b/proxy/core/connection/pool.py index 5f92066b9d..3a9067b27c 100644 --- a/proxy/core/connection/pool.py +++ b/proxy/core/connection/pool.py @@ -13,8 +13,9 @@ reusability """ import logging +import selectors -from typing import Set, Dict, Tuple +from typing import TYPE_CHECKING, Set, Dict, Tuple from ...common.flag import flags from ...common.types import Readables, Writables @@ -66,11 +67,21 @@ class UpstreamConnectionPool(Work[TcpServerConnection]): def __init__(self) -> None: # Pools of connection per upstream server + self.connections: Dict[int, TcpServerConnection] = {} self.pools: Dict[Tuple[str, int], Set[TcpServerConnection]] = {} - def acquire(self, host: str, port: int) -> Tuple[bool, TcpServerConnection]: + def add(self, addr: Tuple[str, int]) -> TcpServerConnection: + # Create new connection + new_conn = TcpServerConnection(*addr) + new_conn.connect() + if addr not in self.pools: + self.pools[addr] = set() + self.pools[addr].add(new_conn) + self.connections[new_conn.connection.fileno()] = new_conn + return new_conn + + def acquire(self, addr: Tuple[str, int]) -> Tuple[bool, TcpServerConnection]: """Returns a connection for use with the server.""" - addr = (host, port) # Return a reusable connection if available if addr in self.pools: for old_conn in self.pools[addr]: @@ -78,18 +89,14 @@ def acquire(self, host: str, port: int) -> Tuple[bool, TcpServerConnection]: old_conn.mark_inuse() logger.debug( 'Reusing connection#{2} for upstream {0}:{1}'.format( - host, port, id(old_conn), + addr[0], addr[1], id(old_conn), ), ) return False, old_conn - # Create new connection - new_conn = TcpServerConnection(*addr) - if addr not in self.pools: - self.pools[addr] = set() - self.pools[addr].add(new_conn) + new_conn = self.add(addr) logger.debug( 'Created new connection#{2} for upstream {0}:{1}'.format( - host, port, id(new_conn), + addr[0], addr[1], id(new_conn), ), ) return True, new_conn @@ -118,7 +125,17 @@ def release(self, conn: TcpServerConnection) -> None: conn.reset() async def get_events(self) -> Dict[int, int]: - return await super().get_events() - - async def handle_events(self, readables: Readables, writables: Writables) -> bool: - return await super().handle_events(readables, writables) + events = {} + for connections in self.pools.values(): + for conn in connections: + events[conn.connection.fileno()] = selectors.EVENT_READ + return events + + async def handle_events(self, readables: Readables, _writables: Writables) -> bool: + for r in readables: + if TYPE_CHECKING: + assert type(r) == int + conn = self.connections[r] + self.pools[conn.addr].remove(conn) + del self.connections[r] + return False diff --git a/proxy/core/connection/server.py b/proxy/core/connection/server.py index 7aae5371cc..2c2af73f99 100644 --- a/proxy/core/connection/server.py +++ b/proxy/core/connection/server.py @@ -25,7 +25,7 @@ class TcpServerConnection(TcpConnection): def __init__(self, host: str, port: int) -> None: super().__init__(tcpConnectionTypes.SERVER) self._conn: Optional[Union[ssl.SSLSocket, socket.socket]] = None - self.addr: Tuple[str, int] = (host, int(port)) + self.addr: Tuple[str, int] = (host, port) self.closed = True @property diff --git a/proxy/http/proxy/server.py b/proxy/http/proxy/server.py index 228d8ecc34..fdb74d7176 100644 --- a/proxy/http/proxy/server.py +++ b/proxy/http/proxy/server.py @@ -586,29 +586,6 @@ def handle_pipeline_response(self, raw: memoryview) -> None: def connect_upstream(self) -> None: host, port = self.request.host, self.request.port if host and port: - if self.flags.enable_conn_pool: - assert self.upstream_conn_pool - with self.lock: - created, self.upstream = self.upstream_conn_pool.acquire( - text_(host), port, - ) - else: - created, self.upstream = True, TcpServerConnection( - text_(host), port, - ) - if not created: - # NOTE: Acquired connection might be in an unusable state. - # - # This can only be confirmed by reading from connection. - # For stale connections, we will receive None, indicating - # to drop the connection. - # - # If that happen, we must acquire a fresh connection. - logger.info( - 'Reusing connection to upstream %s:%d' % - (text_(host), port), - ) - return try: logger.debug( 'Connecting to upstream %s:%d' % @@ -622,14 +599,37 @@ def connect_upstream(self) -> None: ) if upstream_ip or source_addr: break - # Connect with overridden upstream IP and source address - # if any of the plugin returned a non-null value. - self.upstream.connect( - addr=None if not upstream_ip else ( - upstream_ip, port, - ), source_address=source_addr, - ) - self.upstream.connection.setblocking(False) + if self.flags.enable_conn_pool: + assert self.upstream_conn_pool + with self.lock: + created, self.upstream = self.upstream_conn_pool.acquire( + (text_(host), port), + ) + else: + created, self.upstream = True, TcpServerConnection( + text_(host), port, + ) + # Connect with overridden upstream IP and source address + # if any of the plugin returned a non-null value. + self.upstream.connect( + addr=None if not upstream_ip else ( + upstream_ip, port, + ), source_address=source_addr, + ) + self.upstream.connection.setblocking(False) + if not created: + # NOTE: Acquired connection might be in an unusable state. + # + # This can only be confirmed by reading from connection. + # For stale connections, we will receive None, indicating + # to drop the connection. + # + # If that happen, we must acquire a fresh connection. + logger.info( + 'Reusing connection to upstream %s:%d' % + (text_(host), port), + ) + return logger.debug( 'Connected to upstream %s:%s' % (text_(host), port), @@ -640,7 +640,7 @@ def connect_upstream(self) -> None: text_(host), port, str(e), ), ) - if self.flags.enable_conn_pool: + if self.flags.enable_conn_pool and self.upstream: assert self.upstream_conn_pool with self.lock: self.upstream_conn_pool.release(self.upstream) diff --git a/tests/core/test_conn_pool.py b/tests/core/test_conn_pool.py index db3de3d7c0..e0eb87172b 100644 --- a/tests/core/test_conn_pool.py +++ b/tests/core/test_conn_pool.py @@ -8,7 +8,9 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. """ +import pytest import unittest +import selectors from unittest import mock @@ -28,7 +30,7 @@ def test_acquire_and_release_and_reacquire(self, mock_tcp_server_connection: moc ] mock_conn.closed = False # Acquire - created, conn = pool.acquire(*addr) + created, conn = pool.acquire(addr) self.assertTrue(created) mock_tcp_server_connection.assert_called_once_with(*addr) self.assertEqual(conn, mock_conn) @@ -39,7 +41,7 @@ def test_acquire_and_release_and_reacquire(self, mock_tcp_server_connection: moc self.assertEqual(len(pool.pools[addr]), 1) self.assertTrue(conn in pool.pools[addr]) # Reacquire - created, conn = pool.acquire(*addr) + created, conn = pool.acquire(addr) self.assertFalse(created) mock_conn.reset.assert_called_once() self.assertEqual(conn, mock_conn) @@ -57,7 +59,7 @@ def test_closed_connections_are_removed_on_release( mock_conn.closed = True mock_conn.addr = addr # Acquire - created, conn = pool.acquire(*addr) + created, conn = pool.acquire(addr) self.assertTrue(created) mock_tcp_server_connection.assert_called_once_with(*addr) self.assertEqual(conn, mock_conn) @@ -67,7 +69,41 @@ def test_closed_connections_are_removed_on_release( pool.release(conn) self.assertEqual(len(pool.pools[addr]), 0) # Acquire - created, conn = pool.acquire(*addr) + created, conn = pool.acquire(addr) self.assertTrue(created) self.assertEqual(mock_tcp_server_connection.call_count, 2) mock_conn.is_reusable.assert_not_called() + + +class TestConnectionPoolAsync: + + @pytest.mark.asyncio # type: ignore[misc] + @mock.patch('proxy.core.connection.pool.TcpServerConnection') + async def test_get_events(self, mock_tcp_server_connection: mock.Mock) -> None: + pool = UpstreamConnectionPool() + addr = ('localhost', 1234) + mock_conn = mock_tcp_server_connection.return_value + pool.add(addr) + mock_tcp_server_connection.assert_called_once_with(*addr) + mock_conn.connect.assert_called_once() + events = await pool.get_events() + print(events) + assert events == { + mock_conn.connection.fileno.return_value: selectors.EVENT_READ + } + assert pool.pools[addr].pop() == mock_conn + assert len(pool.pools[addr]) == 0 + assert pool.connections[mock_conn.connection.fileno.return_value] == mock_conn + + @pytest.mark.asyncio # type: ignore[misc] + @mock.patch('proxy.core.connection.pool.TcpServerConnection') + async def test_handle_events(self, mock_tcp_server_connection: mock.Mock) -> None: + pool = UpstreamConnectionPool() + mock_conn = mock_tcp_server_connection.return_value + addr = mock_conn.addr + pool.add(addr) + assert len(pool.pools[addr]) == 1 + assert len(pool.connections) == 1 + await pool.handle_events([mock_conn.connection.fileno.return_value], []) + assert len(pool.pools[addr]) == 0 + assert len(pool.connections) == 0 From c7cd4666e6068c9a184dc7bacb0e69991df45c42 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 28 Dec 2021 18:04:43 +0000 Subject: [PATCH 2/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/core/test_conn_pool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/core/test_conn_pool.py b/tests/core/test_conn_pool.py index e0eb87172b..25ca118fbf 100644 --- a/tests/core/test_conn_pool.py +++ b/tests/core/test_conn_pool.py @@ -89,7 +89,7 @@ async def test_get_events(self, mock_tcp_server_connection: mock.Mock) -> None: events = await pool.get_events() print(events) assert events == { - mock_conn.connection.fileno.return_value: selectors.EVENT_READ + mock_conn.connection.fileno.return_value: selectors.EVENT_READ, } assert pool.pools[addr].pop() == mock_conn assert len(pool.pools[addr]) == 0 From 2d70372dbba01d163858810602eb92bc0fe83ae5 Mon Sep 17 00:00:00 2001 From: Abhinav Singh Date: Tue, 28 Dec 2021 23:51:18 +0530 Subject: [PATCH 3/5] Use isinstance --- proxy/core/connection/pool.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/proxy/core/connection/pool.py b/proxy/core/connection/pool.py index 3a9067b27c..a9b0585a01 100644 --- a/proxy/core/connection/pool.py +++ b/proxy/core/connection/pool.py @@ -72,7 +72,7 @@ def __init__(self) -> None: def add(self, addr: Tuple[str, int]) -> TcpServerConnection: # Create new connection - new_conn = TcpServerConnection(*addr) + new_conn = TcpServerConnection(addr[0], addr[1]) new_conn.connect() if addr not in self.pools: self.pools[addr] = set() @@ -134,7 +134,7 @@ async def get_events(self) -> Dict[int, int]: async def handle_events(self, readables: Readables, _writables: Writables) -> bool: for r in readables: if TYPE_CHECKING: - assert type(r) == int + assert isinstance(r, int) conn = self.connections[r] self.pools[conn.addr].remove(conn) del self.connections[r] From 32dde0097f34475841769e3a0bea0b023993ffab Mon Sep 17 00:00:00 2001 From: Abhinav Singh Date: Wed, 29 Dec 2021 00:14:02 +0530 Subject: [PATCH 4/5] Use mocker fixture to pass CI on 3.6 and 3.7 --- tests/core/test_conn_pool.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/core/test_conn_pool.py b/tests/core/test_conn_pool.py index 25ca118fbf..ef18b09c22 100644 --- a/tests/core/test_conn_pool.py +++ b/tests/core/test_conn_pool.py @@ -13,6 +13,7 @@ import selectors from unittest import mock +from pytest_mock import MockerFixture from proxy.core.connection import UpstreamConnectionPool @@ -78,8 +79,9 @@ def test_closed_connections_are_removed_on_release( class TestConnectionPoolAsync: @pytest.mark.asyncio # type: ignore[misc] - @mock.patch('proxy.core.connection.pool.TcpServerConnection') - async def test_get_events(self, mock_tcp_server_connection: mock.Mock) -> None: + async def test_get_events(self, mocker: MockerFixture) -> None: + mock_tcp_server_connection = mocker.patch( + 'proxy.core.connection.pool.TcpServerConnection') pool = UpstreamConnectionPool() addr = ('localhost', 1234) mock_conn = mock_tcp_server_connection.return_value @@ -96,8 +98,9 @@ async def test_get_events(self, mock_tcp_server_connection: mock.Mock) -> None: assert pool.connections[mock_conn.connection.fileno.return_value] == mock_conn @pytest.mark.asyncio # type: ignore[misc] - @mock.patch('proxy.core.connection.pool.TcpServerConnection') - async def test_handle_events(self, mock_tcp_server_connection: mock.Mock) -> None: + async def test_handle_events(self, mocker: MockerFixture) -> None: + mock_tcp_server_connection = mocker.patch( + 'proxy.core.connection.pool.TcpServerConnection') pool = UpstreamConnectionPool() mock_conn = mock_tcp_server_connection.return_value addr = mock_conn.addr From 4e58db0bc459c5d4fd2b200dd1b2196b8b7e1ebe Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 28 Dec 2021 18:44:51 +0000 Subject: [PATCH 5/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/core/test_conn_pool.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/core/test_conn_pool.py b/tests/core/test_conn_pool.py index ef18b09c22..e00436cd2f 100644 --- a/tests/core/test_conn_pool.py +++ b/tests/core/test_conn_pool.py @@ -81,7 +81,8 @@ class TestConnectionPoolAsync: @pytest.mark.asyncio # type: ignore[misc] async def test_get_events(self, mocker: MockerFixture) -> None: mock_tcp_server_connection = mocker.patch( - 'proxy.core.connection.pool.TcpServerConnection') + 'proxy.core.connection.pool.TcpServerConnection', + ) pool = UpstreamConnectionPool() addr = ('localhost', 1234) mock_conn = mock_tcp_server_connection.return_value @@ -100,7 +101,8 @@ async def test_get_events(self, mocker: MockerFixture) -> None: @pytest.mark.asyncio # type: ignore[misc] async def test_handle_events(self, mocker: MockerFixture) -> None: mock_tcp_server_connection = mocker.patch( - 'proxy.core.connection.pool.TcpServerConnection') + 'proxy.core.connection.pool.TcpServerConnection', + ) pool = UpstreamConnectionPool() mock_conn = mock_tcp_server_connection.return_value addr = mock_conn.addr