Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 31 additions & 14 deletions proxy/core/connection/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
reusability
"""
import logging
import selectors

from typing import Set, Dict, Tuple
from typing import TYPE_CHECKING, Set, Dict, Tuple

from ...common.flag import flags
from ...common.types import Readables, Writables
Expand Down Expand Up @@ -66,30 +67,36 @@ class UpstreamConnectionPool(Work[TcpServerConnection]):

def __init__(self) -> None:
# Pools of connection per upstream server
self.connections: Dict[int, TcpServerConnection] = {}
self.pools: Dict[Tuple[str, int], Set[TcpServerConnection]] = {}

def acquire(self, host: str, port: int) -> Tuple[bool, TcpServerConnection]:
def add(self, addr: Tuple[str, int]) -> TcpServerConnection:
# Create new connection
new_conn = TcpServerConnection(addr[0], addr[1])
new_conn.connect()
if addr not in self.pools:
self.pools[addr] = set()
self.pools[addr].add(new_conn)
self.connections[new_conn.connection.fileno()] = new_conn
return new_conn

def acquire(self, addr: Tuple[str, int]) -> Tuple[bool, TcpServerConnection]:
"""Returns a connection for use with the server."""
addr = (host, port)
# Return a reusable connection if available
if addr in self.pools:
for old_conn in self.pools[addr]:
if old_conn.is_reusable():
old_conn.mark_inuse()
logger.debug(
'Reusing connection#{2} for upstream {0}:{1}'.format(
host, port, id(old_conn),
addr[0], addr[1], id(old_conn),
),
)
return False, old_conn
# Create new connection
new_conn = TcpServerConnection(*addr)
if addr not in self.pools:
self.pools[addr] = set()
self.pools[addr].add(new_conn)
new_conn = self.add(addr)
logger.debug(
'Created new connection#{2} for upstream {0}:{1}'.format(
host, port, id(new_conn),
addr[0], addr[1], id(new_conn),
),
)
return True, new_conn
Expand Down Expand Up @@ -118,7 +125,17 @@ def release(self, conn: TcpServerConnection) -> None:
conn.reset()

async def get_events(self) -> Dict[int, int]:
return await super().get_events()

async def handle_events(self, readables: Readables, writables: Writables) -> bool:
return await super().handle_events(readables, writables)
events = {}
for connections in self.pools.values():
for conn in connections:
events[conn.connection.fileno()] = selectors.EVENT_READ
return events

async def handle_events(self, readables: Readables, _writables: Writables) -> bool:
for r in readables:
if TYPE_CHECKING:
assert isinstance(r, int)
conn = self.connections[r]
self.pools[conn.addr].remove(conn)
del self.connections[r]
return False
2 changes: 1 addition & 1 deletion proxy/core/connection/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class TcpServerConnection(TcpConnection):
def __init__(self, host: str, port: int) -> None:
super().__init__(tcpConnectionTypes.SERVER)
self._conn: Optional[Union[ssl.SSLSocket, socket.socket]] = None
self.addr: Tuple[str, int] = (host, int(port))
self.addr: Tuple[str, int] = (host, port)
self.closed = True

@property
Expand Down
64 changes: 32 additions & 32 deletions proxy/http/proxy/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,29 +586,6 @@ def handle_pipeline_response(self, raw: memoryview) -> None:
def connect_upstream(self) -> None:
host, port = self.request.host, self.request.port
if host and port:
if self.flags.enable_conn_pool:
assert self.upstream_conn_pool
with self.lock:
created, self.upstream = self.upstream_conn_pool.acquire(
text_(host), port,
)
else:
created, self.upstream = True, TcpServerConnection(
text_(host), port,
)
if not created:
# NOTE: Acquired connection might be in an unusable state.
#
# This can only be confirmed by reading from connection.
# For stale connections, we will receive None, indicating
# to drop the connection.
#
# If that happen, we must acquire a fresh connection.
logger.info(
'Reusing connection to upstream %s:%d' %
(text_(host), port),
)
return
try:
logger.debug(
'Connecting to upstream %s:%d' %
Expand All @@ -622,14 +599,37 @@ def connect_upstream(self) -> None:
)
if upstream_ip or source_addr:
break
# Connect with overridden upstream IP and source address
# if any of the plugin returned a non-null value.
self.upstream.connect(
addr=None if not upstream_ip else (
upstream_ip, port,
), source_address=source_addr,
)
self.upstream.connection.setblocking(False)
if self.flags.enable_conn_pool:
assert self.upstream_conn_pool
with self.lock:
created, self.upstream = self.upstream_conn_pool.acquire(
(text_(host), port),
)
else:
created, self.upstream = True, TcpServerConnection(
text_(host), port,
)
# Connect with overridden upstream IP and source address
# if any of the plugin returned a non-null value.
self.upstream.connect(
addr=None if not upstream_ip else (
upstream_ip, port,
), source_address=source_addr,
)
self.upstream.connection.setblocking(False)
if not created:
# NOTE: Acquired connection might be in an unusable state.
#
# This can only be confirmed by reading from connection.
# For stale connections, we will receive None, indicating
# to drop the connection.
#
# If that happen, we must acquire a fresh connection.
logger.info(
'Reusing connection to upstream %s:%d' %
(text_(host), port),
)
return
logger.debug(
'Connected to upstream %s:%s' %
(text_(host), port),
Expand All @@ -640,7 +640,7 @@ def connect_upstream(self) -> None:
text_(host), port, str(e),
),
)
if self.flags.enable_conn_pool:
if self.flags.enable_conn_pool and self.upstream:
assert self.upstream_conn_pool
with self.lock:
self.upstream_conn_pool.release(self.upstream)
Expand Down
49 changes: 45 additions & 4 deletions tests/core/test_conn_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@
:copyright: (c) 2013-present by Abhinav Singh and contributors.
:license: BSD, see LICENSE for more details.
"""
import pytest
import unittest
import selectors

from unittest import mock
from pytest_mock import MockerFixture

from proxy.core.connection import UpstreamConnectionPool

Expand All @@ -28,7 +31,7 @@ def test_acquire_and_release_and_reacquire(self, mock_tcp_server_connection: moc
]
mock_conn.closed = False
# Acquire
created, conn = pool.acquire(*addr)
created, conn = pool.acquire(addr)
self.assertTrue(created)
mock_tcp_server_connection.assert_called_once_with(*addr)
self.assertEqual(conn, mock_conn)
Expand All @@ -39,7 +42,7 @@ def test_acquire_and_release_and_reacquire(self, mock_tcp_server_connection: moc
self.assertEqual(len(pool.pools[addr]), 1)
self.assertTrue(conn in pool.pools[addr])
# Reacquire
created, conn = pool.acquire(*addr)
created, conn = pool.acquire(addr)
self.assertFalse(created)
mock_conn.reset.assert_called_once()
self.assertEqual(conn, mock_conn)
Expand All @@ -57,7 +60,7 @@ def test_closed_connections_are_removed_on_release(
mock_conn.closed = True
mock_conn.addr = addr
# Acquire
created, conn = pool.acquire(*addr)
created, conn = pool.acquire(addr)
self.assertTrue(created)
mock_tcp_server_connection.assert_called_once_with(*addr)
self.assertEqual(conn, mock_conn)
Expand All @@ -67,7 +70,45 @@ def test_closed_connections_are_removed_on_release(
pool.release(conn)
self.assertEqual(len(pool.pools[addr]), 0)
# Acquire
created, conn = pool.acquire(*addr)
created, conn = pool.acquire(addr)
self.assertTrue(created)
self.assertEqual(mock_tcp_server_connection.call_count, 2)
mock_conn.is_reusable.assert_not_called()


class TestConnectionPoolAsync:

@pytest.mark.asyncio # type: ignore[misc]
async def test_get_events(self, mocker: MockerFixture) -> None:
mock_tcp_server_connection = mocker.patch(
'proxy.core.connection.pool.TcpServerConnection',
)
pool = UpstreamConnectionPool()
addr = ('localhost', 1234)
mock_conn = mock_tcp_server_connection.return_value
pool.add(addr)
mock_tcp_server_connection.assert_called_once_with(*addr)
mock_conn.connect.assert_called_once()
events = await pool.get_events()
print(events)
assert events == {
mock_conn.connection.fileno.return_value: selectors.EVENT_READ,
}
assert pool.pools[addr].pop() == mock_conn
assert len(pool.pools[addr]) == 0
assert pool.connections[mock_conn.connection.fileno.return_value] == mock_conn

@pytest.mark.asyncio # type: ignore[misc]
async def test_handle_events(self, mocker: MockerFixture) -> None:
mock_tcp_server_connection = mocker.patch(
'proxy.core.connection.pool.TcpServerConnection',
)
pool = UpstreamConnectionPool()
mock_conn = mock_tcp_server_connection.return_value
addr = mock_conn.addr
pool.add(addr)
assert len(pool.pools[addr]) == 1
assert len(pool.connections) == 1
await pool.handle_events([mock_conn.connection.fileno.return_value], [])
assert len(pool.pools[addr]) == 0
assert len(pool.connections) == 0