Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@
(_py_class_role, 'unittest.case.TestCase'),
(_py_class_role, 'unittest.result.TestResult'),
(_py_class_role, 'UUID'),
(_py_class_role, 'UpstreamConnectionPool'),
(_py_class_role, 'Url'),
(_py_class_role, 'WebsocketFrame'),
(_py_class_role, 'Work'),
Expand Down
1 change: 1 addition & 0 deletions proxy/core/acceptor/executors.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def start_threaded_work(
TcpClientConnection(conn, addr),
flags=flags,
event_queue=event_queue,
upstream_conn_pool=None,
)
# TODO: Keep reference to threads and join during shutdown.
# This will ensure connections are not abruptly closed on shutdown
Expand Down
6 changes: 5 additions & 1 deletion proxy/core/acceptor/threadless.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from ...common.constants import DEFAULT_INACTIVE_CONN_CLEANUP_TIMEOUT, DEFAULT_SELECTOR_SELECT_TIMEOUT
from ...common.constants import DEFAULT_WAIT_FOR_TASKS_TIMEOUT

from ..connection import TcpClientConnection
from ..connection import TcpClientConnection, UpstreamConnectionPool
from ..event import eventNames, EventQueue

from .work import Work
Expand Down Expand Up @@ -87,6 +87,9 @@ def __init__(
self.wait_timeout: float = DEFAULT_WAIT_FOR_TASKS_TIMEOUT
self.cleanup_inactive_timeout: float = DEFAULT_INACTIVE_CONN_CLEANUP_TIMEOUT
self._total: int = 0
self._upstream_conn_pool: Optional[UpstreamConnectionPool] = None
if self.flags.enable_conn_pool:
self._upstream_conn_pool = UpstreamConnectionPool()

@property
@abstractmethod
Expand Down Expand Up @@ -134,6 +137,7 @@ def work_on_tcp_conn(
flags=self.flags,
event_queue=self.event_queue,
uid=uid,
upstream_conn_pool=self._upstream_conn_pool,
)
self.works[fileno].publish_event(
event_name=eventNames.WORK_STARTED,
Expand Down
7 changes: 6 additions & 1 deletion proxy/core/acceptor/work.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@

from abc import ABC, abstractmethod
from uuid import uuid4
from typing import Optional, Dict, Any, TypeVar, Generic
from typing import Optional, Dict, Any, TypeVar, Generic, TYPE_CHECKING

from ..event import eventNames, EventQueue
from ...common.types import Readables, Writables

if TYPE_CHECKING:
from ..connection import UpstreamConnectionPool

T = TypeVar('T')


Expand All @@ -33,6 +36,7 @@ def __init__(
flags: argparse.Namespace,
event_queue: Optional[EventQueue] = None,
uid: Optional[str] = None,
upstream_conn_pool: Optional['UpstreamConnectionPool'] = None,
) -> None:
# Work uuid
self.uid: str = uid if uid is not None else uuid4().hex
Expand All @@ -41,6 +45,7 @@ def __init__(
self.event_queue = event_queue
# Accept work
self.work = work
self.upstream_conn_pool = upstream_conn_pool

@abstractmethod
async def get_events(self) -> Dict[int, int]:
Expand Down
4 changes: 2 additions & 2 deletions proxy/core/connection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from .connection import TcpConnection, TcpConnectionUninitializedException
from .client import TcpClientConnection
from .server import TcpServerConnection
from .pool import ConnectionPool
from .pool import UpstreamConnectionPool
from .types import tcpConnectionTypes

__all__ = [
Expand All @@ -25,5 +25,5 @@
'TcpServerConnection',
'TcpClientConnection',
'tcpConnectionTypes',
'ConnectionPool',
'UpstreamConnectionPool',
]
19 changes: 14 additions & 5 deletions proxy/core/connection/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
from typing import Set, Dict, Tuple

from ...common.flag import flags
from ...common.types import Readables, Writables

from ..acceptor.work import Work

from .server import TcpServerConnection

Expand All @@ -31,10 +34,10 @@
)


class ConnectionPool:
class UpstreamConnectionPool(Work[TcpServerConnection]):
"""Manages connection pool to upstream servers.

`ConnectionPool` avoids need to reconnect with the upstream
`UpstreamConnectionPool` avoids need to reconnect with the upstream
servers repeatedly when a reusable connection is available
in the pool.

Expand All @@ -47,16 +50,16 @@ class ConnectionPool:
the pool users. Example, if acquired connection
is stale, reacquire.

TODO: Ideally, ConnectionPool must be shared across
TODO: Ideally, `UpstreamConnectionPool` must be shared across
all cores to make SSL session cache to also work
without additional out-of-bound synchronizations.

TODO: ConnectionPool currently WON'T work for
TODO: `UpstreamConnectionPool` currently WON'T work for
HTTPS connection. This is because of missing support for
session cache, session ticket, abbr TLS handshake
and other necessary features to make it work.

NOTE: However, for all HTTP only connections, ConnectionPool
NOTE: However, for all HTTP only connections, `UpstreamConnectionPool`
can be used to save upon connection setup time and
speed-up performance of requests.
"""
Expand Down Expand Up @@ -113,3 +116,9 @@ def release(self, conn: TcpServerConnection) -> None:
assert not conn.is_reusable()
# Reset for reusability
conn.reset()

async def get_events(self) -> Dict[int, int]:
return await super().get_events()

async def handle_events(self, readables: Readables, writables: Writables) -> bool:
return await super().handle_events(readables, writables)
1 change: 1 addition & 0 deletions proxy/http/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def initialize(self) -> None:
self.work,
self.request,
self.event_queue,
self.upstream_conn_pool,
)
self.plugins[instance.name()] = instance
logger.debug('Handling connection %r' % self.work.connection)
Expand Down
7 changes: 6 additions & 1 deletion proxy/http/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,17 @@
import argparse

from abc import ABC, abstractmethod
from typing import Tuple, List, Union, Optional
from typing import Tuple, List, Union, Optional, TYPE_CHECKING

from .parser import HttpParser

from ..common.types import Readables, Writables
from ..core.event import EventQueue
from ..core.connection import TcpClientConnection

if TYPE_CHECKING:
from ..core.connection import UpstreamConnectionPool


class HttpProtocolHandlerPlugin(ABC):
"""Base HttpProtocolHandler Plugin class.
Expand Down Expand Up @@ -50,12 +53,14 @@ def __init__(
client: TcpClientConnection,
request: HttpParser,
event_queue: EventQueue,
upstream_conn_pool: Optional['UpstreamConnectionPool'] = None,
):
self.uid: str = uid
self.flags: argparse.Namespace = flags
self.client: TcpClientConnection = client
self.request: HttpParser = request
self.event_queue = event_queue
self.upstream_conn_pool = upstream_conn_pool
super().__init__()

def name(self) -> str:
Expand Down
18 changes: 9 additions & 9 deletions proxy/http/proxy/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from ...common.pki import gen_public_key, gen_csr, sign_csr

from ...core.event import eventNames
from ...core.connection import TcpServerConnection, ConnectionPool
from ...core.connection import TcpServerConnection
from ...core.connection import TcpConnectionUninitializedException
from ...common.flag import flags

Expand Down Expand Up @@ -140,9 +140,6 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin):
# connection pool operations.
lock = threading.Lock()

# Shared connection pool
pool = ConnectionPool()

def __init__(
self,
*args: Any, **kwargs: Any,
Expand Down Expand Up @@ -200,10 +197,10 @@ def get_descriptors(self) -> Tuple[List[int], List[int]]:

def _close_and_release(self) -> bool:
if self.flags.enable_conn_pool:
assert self.upstream and not self.upstream.closed
assert self.upstream and not self.upstream.closed and self.upstream_conn_pool
self.upstream.closed = True
with self.lock:
self.pool.release(self.upstream)
self.upstream_conn_pool.release(self.upstream)
self.upstream = None
return True

Expand Down Expand Up @@ -391,9 +388,10 @@ def on_client_connection_close(self) -> None:
return

if self.flags.enable_conn_pool:
assert self.upstream_conn_pool
# Release the connection for reusability
with self.lock:
self.pool.release(self.upstream)
self.upstream_conn_pool.release(self.upstream)
return

try:
Expand Down Expand Up @@ -589,8 +587,9 @@ def connect_upstream(self) -> None:
host, port = self.request.host, self.request.port
if host and port:
if self.flags.enable_conn_pool:
assert self.upstream_conn_pool
with self.lock:
created, self.upstream = self.pool.acquire(
created, self.upstream = self.upstream_conn_pool.acquire(
text_(host), port,
)
else:
Expand Down Expand Up @@ -642,8 +641,9 @@ def connect_upstream(self) -> None:
),
)
if self.flags.enable_conn_pool:
assert self.upstream_conn_pool
with self.lock:
self.pool.release(self.upstream)
self.upstream_conn_pool.release(self.upstream)
raise ProxyConnectionFailed(
text_(host), port, repr(e),
) from e
Expand Down
2 changes: 1 addition & 1 deletion proxy/plugin/proxy_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def before_upstream_connection(
must be bootstrapped within it's own re-usable and garbage collected pool,
to avoid establishing a new upstream proxy connection for each client request.

See :class:`~proxy.core.connection.pool.ConnectionPool` which is a work
See :class:`~proxy.core.connection.pool.UpstreamConnectionPool` which is a work
in progress for SSL cache handling.
"""
# We don't want to send private IP requests to remote proxies
Expand Down
1 change: 1 addition & 0 deletions tests/core/test_acceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def test_accepts_client_from_server_socket(
mock_client.return_value,
flags=self.flags,
event_queue=None,
upstream_conn_pool=None,
)
mock_thread.assert_called_with(
target=self.flags.work_klass.return_value.run,
Expand Down
6 changes: 3 additions & 3 deletions tests/core/test_conn_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@

from unittest import mock

from proxy.core.connection import ConnectionPool
from proxy.core.connection import UpstreamConnectionPool


class TestConnectionPool(unittest.TestCase):

@mock.patch('proxy.core.connection.pool.TcpServerConnection')
def test_acquire_and_release_and_reacquire(self, mock_tcp_server_connection: mock.Mock) -> None:
pool = ConnectionPool()
pool = UpstreamConnectionPool()
addr = ('localhost', 1234)
# Mock
mock_conn = mock_tcp_server_connection.return_value
Expand Down Expand Up @@ -50,7 +50,7 @@ def test_acquire_and_release_and_reacquire(self, mock_tcp_server_connection: moc
def test_closed_connections_are_removed_on_release(
self, mock_tcp_server_connection: mock.Mock,
) -> None:
pool = ConnectionPool()
pool = UpstreamConnectionPool()
addr = ('localhost', 1234)
# Mock
mock_conn = mock_tcp_server_connection.return_value
Expand Down