diff --git a/README.md b/README.md
index 273bb02418..fcd25b1818 100644
--- a/README.md
+++ b/README.md
@@ -2284,7 +2284,7 @@ usage: -m [-h] [--tunnel-hostname TUNNEL_HOSTNAME] [--tunnel-port TUNNEL_PORT]
           [--ca-cert-file CA_CERT_FILE] [--ca-file CA_FILE]
           [--ca-signing-key-file CA_SIGNING_KEY_FILE]
           [--auth-plugin AUTH_PLUGIN] [--cache-dir CACHE_DIR]
-          [--proxy-pool PROXY_POOL] [--enable-web-server]
+          [--cache-requests] [--proxy-pool PROXY_POOL] [--enable-web-server]
           [--enable-static-server] [--static-server-dir STATIC_SERVER_DIR]
           [--min-compression-length MIN_COMPRESSION_LENGTH]
           [--enable-reverse-proxy] [--pac-file PAC_FILE]
@@ -2425,6 +2425,7 @@ options:
                         Default: /Users/abhinavsingh/.proxy/cache. Flag only
                         applicable when cache plugin is used with on-disk
                         storage.
+  --cache-requests      Default: False. Whether to also cache request packets.
   --proxy-pool PROXY_POOL
                         List of upstream proxies to use in the pool
   --enable-web-server   Default: False. Whether to enable
diff --git a/proxy/common/constants.py b/proxy/common/constants.py
index 3ec1acb758..dfc630588e 100644
--- a/proxy/common/constants.py
+++ b/proxy/common/constants.py
@@ -148,6 +148,7 @@ def _env_threadless_compliant() -> bool:
 DEFAULT_CACHE_DIRECTORY_PATH = os.path.join(
     DEFAULT_DATA_DIRECTORY_PATH, 'cache',
 )
+DEFAULT_CACHE_REQUESTS = False

 # Cor plugins enabled by default or via flags
 DEFAULT_ABC_PLUGINS = [
diff --git a/proxy/http/handler.py b/proxy/http/handler.py
index bea158afba..158bc5a1ed 100644
--- a/proxy/http/handler.py
+++ b/proxy/http/handler.py
@@ -163,9 +163,7 @@ def handle_data(self, data: memoryview) -> Optional[bool]:
             self.work.closed = True
             return True
         try:
-            # Don't parse incoming data any further after 1st request has completed.
-            #
-            # This specially does happen for pipeline requests.
+            # We don't parse incoming data any further after 1st HTTP request packet.
             #
             # Plugins can utilize on_client_data for such cases and
             # apply custom logic to handle request data sent after 1st
@@ -173,11 +171,10 @@ def handle_data(self, data: memoryview) -> Optional[bool]:
             if self.request.state != httpParserStates.COMPLETE:
                 if self._parse_first_request(data):
                     return True
-            else:
-                # HttpProtocolHandlerPlugin.on_client_data
-                # Can raise HttpProtocolException to tear down the connection
-                if self.plugin:
-                    data = self.plugin.on_client_data(data) or data
+            # HttpProtocolHandlerPlugin.on_client_data
+            # Can raise HttpProtocolException to tear down the connection
+            elif self.plugin:
+                self.plugin.on_client_data(data)
         except HttpProtocolException as e:
             logger.info('HttpProtocolException: %s' % e)
             response: Optional[memoryview] = e.response(self.request)
diff --git a/proxy/http/plugin.py b/proxy/http/plugin.py
index ff37173773..754fb28024 100644
--- a/proxy/http/plugin.py
+++ b/proxy/http/plugin.py
@@ -71,9 +71,9 @@ def protocols() -> List[int]:
         raise NotImplementedError()

     @abstractmethod
-    def on_client_data(self, raw: memoryview) -> Optional[memoryview]:
+    def on_client_data(self, raw: memoryview) -> None:
         """Called only after original request has been completely received."""
-        return raw  # pragma: no cover
+        pass  # pragma: no cover

     @abstractmethod
     def on_request_complete(self) -> Union[socket.socket, bool]:
diff --git a/proxy/http/proxy/plugin.py b/proxy/http/proxy/plugin.py
index a71e553b33..7768e3c59d 100644
--- a/proxy/http/proxy/plugin.py
+++ b/proxy/http/proxy/plugin.py
@@ -121,7 +121,6 @@ def handle_client_request(
         Return None to drop the request data, e.g. in case a response has
         already been queued.
         Raise HttpRequestRejected or HttpProtocolException directly to tear down
         the connection with client.
-
         """
         return request  # pragma: no cover
diff --git a/proxy/http/proxy/server.py b/proxy/http/proxy/server.py
index 10932fde9e..dc9b2ed493 100644
--- a/proxy/http/proxy/server.py
+++ b/proxy/http/proxy/server.py
@@ -398,7 +398,7 @@ def on_response_chunk(self, chunk: List[memoryview]) -> List[memoryview]:
         return chunk

     # Can return None to tear down connection
-    def on_client_data(self, raw: memoryview) -> Optional[memoryview]:
+    def on_client_data(self, raw: memoryview) -> None:
         # For scenarios when an upstream connection was never established,
         # let plugin do whatever they wish to. These are special scenarios
         # where plugins are trying to do something magical. Within the core
@@ -413,7 +413,7 @@ def on_client_data(self, raw: memoryview) -> Optional[memoryview]:
             for plugin in self.plugins.values():
                 o = plugin.handle_client_data(raw)
                 if o is None:
-                    return None
+                    return
                 raw = o
         elif self.upstream and not self.upstream.closed:
             # For http proxy requests, handle pipeline case.
@@ -429,12 +429,17 @@ def on_client_data(self, raw: memoryview) -> Optional[memoryview]:
                 # upgrade request. Incoming client data now
                 # must be treated as WebSocket protocol packets.
                 self.upstream.queue(raw)
-                return None
+                return

             if self.pipeline_request is None:
                 # For pipeline requests, we never
                 # want to use --enable-proxy-protocol flag
                 # as proxy protocol header will not be present
+                #
+                # TODO: HTTP parser must be smart about detecting
+                # HA proxy protocol or we must always explicitly pass
+                # the flag when we are expecting HA proxy protocol
+                # request line before HTTP request lines.
                 self.pipeline_request = HttpParser(
                     httpParserTypes.REQUEST_PARSER,
                 )
@@ -447,7 +452,7 @@ def on_client_data(self, raw: memoryview) -> Optional[memoryview]:
                 assert self.pipeline_request is not None
                 r = plugin.handle_client_request(self.pipeline_request)
                 if r is None:
-                    return None
+                    return
                 self.pipeline_request = r
             assert self.pipeline_request is not None
             # TODO(abhinavsingh): Remove memoryview wrapping here after
@@ -463,8 +468,6 @@ def on_client_data(self, raw: memoryview) -> Optional[memoryview]:
             # simply queue for upstream server.
             else:
                 self.upstream.queue(raw)
-            return None
-        return raw

     def on_request_complete(self) -> Union[socket.socket, bool]:
         self.emit_request_complete()
diff --git a/proxy/http/server/web.py b/proxy/http/server/web.py
index f6b0a12fc3..34ab4d3fe4 100644
--- a/proxy/http/server/web.py
+++ b/proxy/http/server/web.py
@@ -194,7 +194,7 @@ async def read_from_descriptors(self, r: Readables) -> bool:
                 return True
         return False

-    def on_client_data(self, raw: memoryview) -> Optional[memoryview]:
+    def on_client_data(self, raw: memoryview) -> None:
         if self.switched_protocol == httpProtocolTypes.WEBSOCKET:
             # TODO(abhinavsingh): Remove .tobytes after websocket frame parser
             # is memoryview compliant
@@ -211,7 +211,7 @@ def on_client_data(self, raw: memoryview) -> Optional[memoryview]:
                 assert self.route
                 self.route.on_websocket_message(frame)
                 frame.reset()
-            return None
+            return
         # If 1st valid request was completed and it's a HTTP/1.1 keep-alive
         # And only if we have a route, parse pipeline requests
         if self.request.is_complete and \
@@ -231,7 +231,6 @@ def on_client_data(self, raw: memoryview) -> Optional[memoryview]:
                         'Pipelined request is not keep-alive, will tear down request...',
                     )
                 self.pipeline_request = None
-        return raw

     def on_response_chunk(self, chunk: List[memoryview]) -> List[memoryview]:
         return chunk
diff --git a/proxy/plugin/cache/cache_responses.py b/proxy/plugin/cache/cache_responses.py
index cbc57ebc16..45bd3174de 100644
--- a/proxy/plugin/cache/cache_responses.py
+++ b/proxy/plugin/cache/cache_responses.py
@@ -30,6 +30,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
                 self.flags.cache_dir, 'responses',
             ),
+            cache_requests=self.flags.cache_requests,
         )
         self.set_store(self.disk_store)

diff --git a/proxy/plugin/cache/store/disk.py b/proxy/plugin/cache/store/disk.py
index a9ac027ba4..da849b07d9 100644
--- a/proxy/plugin/cache/store/disk.py
+++ b/proxy/plugin/cache/store/disk.py
@@ -16,7 +16,9 @@
 from ....common.flag import flags
 from ....http.parser import HttpParser
 from ....common.utils import text_
-from ....common.constants import DEFAULT_CACHE_DIRECTORY_PATH
+from ....common.constants import (
+    DEFAULT_CACHE_REQUESTS, DEFAULT_CACHE_DIRECTORY_PATH,
+)


 logger = logging.getLogger(__name__)
@@ -30,12 +32,21 @@
     'Flag only applicable when cache plugin is used with on-disk storage.',
 )

+flags.add_argument(
+    '--cache-requests',
+    action='store_true',
+    default=DEFAULT_CACHE_REQUESTS,
+    help='Default: False. '
+    'Whether to also cache request packets.',
+)
+

 class OnDiskCacheStore(CacheStore):

-    def __init__(self, uid: str, cache_dir: str) -> None:
+    def __init__(self, uid: str, cache_dir: str, cache_requests: bool) -> None:
         super().__init__(uid)
         self.cache_dir = cache_dir
+        self.cache_requests = cache_requests
         self.cache_file_path: Optional[str] = None
         self.cache_file: Optional[BinaryIO] = None

@@ -47,6 +58,8 @@ def open(self, request: HttpParser) -> None:
         self.cache_file = open(self.cache_file_path, "wb")

     def cache_request(self, request: HttpParser) -> Optional[HttpParser]:
+        if self.cache_file and self.cache_requests:
+            self.cache_file.write(request.build())
         return request

     def cache_response_chunk(self, chunk: memoryview) -> memoryview:
diff --git a/tests/http/proxy/test_http_proxy_tls_interception.py b/tests/http/proxy/test_http_proxy_tls_interception.py
index e7e166f372..263265c77f 100644
--- a/tests/http/proxy/test_http_proxy_tls_interception.py
+++ b/tests/http/proxy/test_http_proxy_tls_interception.py
@@ -12,7 +12,7 @@
 import uuid
 import socket
 import selectors
-from typing import Any
+from typing import Any, TypeVar

 import pytest
 from unittest import mock
@@ -22,7 +22,10 @@
 from proxy.http import HttpProtocolHandler, HttpClientConnection, httpMethods
 from proxy.http.proxy import HttpProxyPlugin
 from proxy.common.flag import FlagParser
-from proxy.common.utils import bytes_, build_http_request
+from proxy.http.parser import HttpParser
+from proxy.common.utils import (
+    bytes_, build_http_request, tls_interception_enabled,
+)
 from proxy.http.responses import PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT
 from proxy.core.connection import TcpServerConnection
 from proxy.common.constants import DEFAULT_CA_FILE
@@ -46,21 +49,25 @@ async def test_e2e(self, mocker: MockerFixture) -> None:
         self.mock_server_conn = mocker.patch(
             'proxy.http.proxy.server.TcpServerConnection',
         )
-        self.mock_ssl_context = mocker.patch('ssl.create_default_context')
-        self.mock_ssl_wrap = mocker.patch('ssl.wrap_socket')
-
         self.mock_sign_csr.return_value = True
         self.mock_gen_csr.return_value = True
         self.mock_gen_public_key.return_value = True

-        ssl_connection = mock.MagicMock(spec=ssl.SSLSocket)
-        self.mock_ssl_context.return_value.wrap_socket.return_value = ssl_connection
-        self.mock_ssl_wrap.return_value = mock.MagicMock(spec=ssl.SSLSocket)
+        # Used for server side wrapping
+        self.mock_ssl_context = mocker.patch('ssl.create_default_context')
+        upstream_tls_sock = mock.MagicMock(spec=ssl.SSLSocket)
+        self.mock_ssl_context.return_value.wrap_socket.return_value = upstream_tls_sock
+
+        # Used for client wrapping
+        self.mock_ssl_wrap = mocker.patch('ssl.wrap_socket')
+        client_tls_sock = mock.MagicMock(spec=ssl.SSLSocket)
+        self.mock_ssl_wrap.return_value = client_tls_sock
+
         plain_connection = mock.MagicMock(spec=socket.socket)

         def mock_connection() -> Any:
             if self.mock_ssl_context.return_value.wrap_socket.called:
-                return ssl_connection
+                return upstream_tls_sock
             return plain_connection

         # Do not mock the original wrap method
@@ -72,6 +79,9 @@ def mock_connection() -> Any:
         type(self.mock_server_conn.return_value).connection = \
             mock.PropertyMock(side_effect=mock_connection)

+        type(self.mock_server_conn.return_value).closed = \
+            mock.PropertyMock(return_value=False)
+
         self.fileno = 10
         self._addr = ('127.0.0.1', 54382)
         self.flags = FlagParser.initialize(
@@ -80,6 +90,11 @@ def mock_connection() -> Any:
             ca_signing_key_file='ca-signing-key.pem',
             threaded=True,
         )
+        self.assertTrue(tls_interception_enabled(self.flags))
+        # In this test we enable a mock http protocol handler plugin
+        # and a mock http proxy plugin. Internally, http protocol
+        # handler will only initialize proxy plugin as we'll never
+        # make any other request.
         self.plugin = mock.MagicMock()
         self.proxy_plugin = mock.MagicMock()
         self.flags.plugins = {
@@ -96,36 +111,51 @@ def mock_connection() -> Any:
         self.plugin.assert_not_called()
         self.proxy_plugin.assert_not_called()

+        # Mock a CONNECT request followed by a GET request
+        # from client connection
+        headers = {
+            b'Host': bytes_(netloc),
+        }
         connect_request = build_http_request(
             httpMethods.CONNECT, bytes_(netloc),
-            headers={
-                b'Host': bytes_(netloc),
-            },
+            headers=headers,
         )
         self._conn.recv.return_value = connect_request
+        get_request = build_http_request(
+            httpMethods.GET, b'/',
+            headers=headers,
+        )
+        client_tls_sock.recv.return_value = get_request

-        async def asyncReturnBool(val: bool) -> bool:
-            return val
+        T = TypeVar('T')  # noqa: N806

-        # Prepare mocked HttpProtocolHandlerPlugin
-        # self.plugin.return_value.get_descriptors.return_value = ([], [])
-        # self.plugin.return_value.write_to_descriptors.return_value = asyncReturnBool(False)
-        # self.plugin.return_value.read_from_descriptors.return_value = asyncReturnBool(False)
-        # self.plugin.return_value.on_client_data.side_effect = lambda raw: raw
-        # self.plugin.return_value.on_request_complete.return_value = False
-        # self.plugin.return_value.on_response_chunk.side_effect = lambda chunk: chunk
-        # self.plugin.return_value.on_client_connection_close.return_value = None
+        async def asyncReturn(val: T) -> T:
+            return val

         # Prepare mocked HttpProxyBasePlugin
-        self.proxy_plugin.return_value.write_to_descriptors.return_value = \
-            asyncReturnBool(False)
-        self.proxy_plugin.return_value.read_from_descriptors.return_value = \
-            asyncReturnBool(False)
+        # 1. Mock descriptor mixin methods
+        #
+        # NOTE: We need multiple async result otherwise
+        # we will end up with cannot await on already
+        # awaited coroutine.
+        self.proxy_plugin.return_value.get_descriptors.side_effect = \
+            [asyncReturn(([], [])), asyncReturn(([], []))]
+        self.proxy_plugin.return_value.write_to_descriptors.side_effect = \
+            [asyncReturn(False), asyncReturn(False)]
+        self.proxy_plugin.return_value.read_from_descriptors.side_effect = \
+            [asyncReturn(False), asyncReturn(False)]
+        # 2. Mock plugin lifecycle methods
+        self.proxy_plugin.return_value.resolve_dns.return_value = None, None
         self.proxy_plugin.return_value.before_upstream_connection.side_effect = lambda r: r
+        self.proxy_plugin.return_value.handle_client_data.side_effect = lambda r: r
         self.proxy_plugin.return_value.handle_client_request.side_effect = lambda r: r
-        self.proxy_plugin.return_value.resolve_dns.return_value = None, None
+        self.proxy_plugin.return_value.handle_upstream_chunk.side_effect = lambda r: r
+        self.proxy_plugin.return_value.on_upstream_connection_close.return_value = None
+        self.proxy_plugin.return_value.on_access_log.side_effect = lambda r: r
+        self.proxy_plugin.return_value.do_intercept.return_value = True

         self.mock_selector.return_value.select.side_effect = [
+            # Trigger read on plain socket
             [(
                 selectors.SelectorKey(
                     fileobj=self._conn.fileno(),
@@ -135,6 +165,16 @@ async def asyncReturnBool(val: bool) -> bool:
                 ),
                 selectors.EVENT_READ,
             )],
+            # Trigger read on encrypted socket
+            [(
+                selectors.SelectorKey(
+                    fileobj=client_tls_sock.fileno(),
+                    fd=client_tls_sock.fileno(),
+                    events=selectors.EVENT_READ,
+                    data=None,
+                ),
+                selectors.EVENT_READ,
+            )],
         ]

         await self.protocol_handler._run_once()
@@ -150,20 +190,38 @@ async def asyncReturnBool(val: bool) -> bool:
         # for interception
         self.assertEqual(
             self.proxy_plugin.call_args[0][2].connection,
-            self.mock_ssl_wrap.return_value,
-        )
-
-        # Assert our mocked plugins invocations
-        # self.plugin.return_value.get_descriptors.assert_called()
-        # self.plugin.return_value.write_to_descriptors.assert_called_with([])
-        # # on_client_data is only called after initial request has completed
-        # self.plugin.return_value.on_client_data.assert_not_called()
-        # self.plugin.return_value.on_request_complete.assert_called()
-        # self.plugin.return_value.read_from_descriptors.assert_called_with([
-        #     self._conn.fileno(),
-        # ])
+            client_tls_sock,
+        )
+
+        # Invoked lifecycle callbacks
+        self.proxy_plugin.return_value.resolve_dns.assert_called_once_with(
+            host, port,
+        )
         self.proxy_plugin.return_value.before_upstream_connection.assert_called()
-        self.proxy_plugin.return_value.handle_client_request.assert_called()
+        self.proxy_plugin.return_value.handle_client_request.assert_called_once()
+        self.proxy_plugin.return_value.do_intercept.assert_called_once()
+        # All the invoked lifecycle callbacks will receive the CONNECT request
+        # packet with / as the path
+        callback_request: HttpParser = \
+            self.proxy_plugin.return_value.before_upstream_connection.call_args_list[0][0][0]
+        callback_request1: HttpParser = \
+            self.proxy_plugin.return_value.handle_client_request.call_args_list[0][0][0]
+        callback_request2: HttpParser = \
+            self.proxy_plugin.return_value.do_intercept.call_args_list[0][0][0]
+        self.assertEqual(callback_request.host, bytes_(host))
+        self.assertEqual(callback_request.port, 443)
+        self.assertEqual(callback_request.header(b'Host'), headers[b'Host'])
+        assert callback_request._url
+        self.assertEqual(callback_request._url.remainder, None)
+        self.assertEqual(callback_request.method, httpMethods.CONNECT)
+        self.assertEqual(callback_request.is_https_tunnel, True)
+        self.assertEqual(callback_request.build(), callback_request1.build())
+        self.assertEqual(callback_request.build(), callback_request2.build())
+        # Lifecycle callbacks not invoked
+        self.proxy_plugin.return_value.handle_client_data.assert_not_called()
+        self.proxy_plugin.return_value.handle_upstream_chunk.assert_not_called()
+        self.proxy_plugin.return_value.on_upstream_connection_close.assert_not_called()
+        self.proxy_plugin.return_value.on_access_log.assert_not_called()

         self.mock_server_conn.assert_called_with(host, port)
         self.mock_server_conn.return_value.connection.setblocking.assert_called_with(
@@ -173,9 +231,6 @@ async def asyncReturnBool(val: bool) -> bool:
         self.mock_ssl_context.assert_called_with(
             ssl.Purpose.SERVER_AUTH, cafile=str(DEFAULT_CA_FILE),
         )
-        # self.assertEqual(self.mock_ssl_context.return_value.options,
-        #                  ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1 |
-        #                  ssl.OP_NO_TLSv1_1)
         self.assertEqual(plain_connection.setblocking.call_count, 2)
         self.mock_ssl_context.return_value.wrap_socket.assert_called_with(
             plain_connection, server_hostname=host,
@@ -183,10 +238,10 @@ async def asyncReturnBool(val: bool) -> bool:
         self.assertEqual(self.mock_sign_csr.call_count, 1)
         self.assertEqual(self.mock_gen_csr.call_count, 1)
         self.assertEqual(self.mock_gen_public_key.call_count, 1)
-        self.assertEqual(ssl_connection.setblocking.call_count, 1)
+        self.assertEqual(upstream_tls_sock.setblocking.call_count, 1)
         self.assertEqual(
             self.mock_server_conn.return_value._conn,
-            ssl_connection,
+            upstream_tls_sock,
         )
         self._conn.send.assert_called_with(
             PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT,
@@ -204,15 +259,42 @@ async def asyncReturnBool(val: bool) -> bool:
         self.assertEqual(self._conn.setblocking.call_count, 2)
         self.assertEqual(
             self.protocol_handler.work.connection,
-            self.mock_ssl_wrap.return_value,
+            client_tls_sock,
         )

         # Assert connection references for all other plugins is updated
-        # self.assertEqual(
-        #     self.plugin.return_value.client._conn,
-        #     self.mock_ssl_wrap.return_value,
-        # )
         self.assertEqual(
             self.proxy_plugin.return_value.client._conn,
-            self.mock_ssl_wrap.return_value,
+            client_tls_sock,
         )
+
+        # Now process the GET request
+        await self.protocol_handler._run_once()
+        self.plugin.assert_not_called()
+        self.proxy_plugin.assert_called_once()
+
+        # Lifecycle callbacks still not invoked
+        self.proxy_plugin.return_value.handle_client_data.assert_not_called()
+        self.proxy_plugin.return_value.handle_upstream_chunk.assert_not_called()
+        self.proxy_plugin.return_value.on_upstream_connection_close.assert_not_called()
+        self.proxy_plugin.return_value.on_access_log.assert_not_called()
+        # Only handle client request lifecycle must be called again
+        self.proxy_plugin.return_value.resolve_dns.assert_called_once_with(
+            host, port,
+        )
+        self.proxy_plugin.return_value.before_upstream_connection.assert_called()
+        self.assertEqual(
+            self.proxy_plugin.return_value.handle_client_request.call_count,
+            2,
+        )
+        self.proxy_plugin.return_value.do_intercept.assert_called_once()
+
+        callback_request = \
+            self.proxy_plugin.return_value.handle_client_request.call_args_list[1][0][0]
+        self.assertEqual(callback_request.host, None)
+        self.assertEqual(callback_request.port, 80)
+        self.assertEqual(callback_request.header(b'Host'), headers[b'Host'])
+        assert callback_request._url
+        self.assertEqual(callback_request._url.remainder, b'/')
+        self.assertEqual(callback_request.method, httpMethods.GET)
+        self.assertEqual(callback_request.is_https_tunnel, False)
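
The plugin-facing contract exercised by this change stays at the HttpProxyBasePlugin level: handle_client_data() and handle_client_request() still return the (optionally modified) data or request, or None to stop further processing; only the handler-level on_client_data() now returns nothing. Below is a minimal sketch of a proxy plugin written against that contract, assuming only the APIs visible in this diff; the plugin class name itself is hypothetical and not part of the patch.

# Minimal sketch, assuming only the public HttpProxyBasePlugin API shown in
# this diff; the plugin class itself is hypothetical.
from typing import Optional

from proxy.http.parser import HttpParser
from proxy.http.proxy import HttpProxyBasePlugin


class InspectClientDataPlugin(HttpProxyBasePlugin):
    """Hypothetical plugin relying on handle_client_data() / handle_client_request()."""

    def handle_client_data(self, raw: memoryview) -> Optional[memoryview]:
        # Return raw (possibly modified) to continue processing,
        # or None to ask the core to stop handling this packet.
        return raw

    def handle_client_request(self, request: HttpParser) -> Optional[HttpParser]:
        # Return the request to dispatch it upstream, or None to drop it,
        # e.g. when the plugin has already queued a response.
        return request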
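
The new --cache-requests flag defaults to False and flows from the flag definition in proxy/plugin/cache/store/disk.py through CacheResponsesPlugin into OnDiskCacheStore.cache_request(), which then also writes request.build() into the cache file. A minimal sketch of turning it on from embedded mode follows; passing flag overrides as keyword arguments (keyed by the argparse dest names) to proxy.main() is an assumption about proxy.py's embedded API, not something introduced by this patch.

# Minimal sketch, assuming proxy.py's embedded entry point accepts flag
# overrides as keyword arguments (argparse dest names).
import proxy

if __name__ == '__main__':
    proxy.main(
        plugins=[b'proxy.plugin.CacheResponsesPlugin'],
        # Maps to the new --cache-requests flag registered in
        # proxy/plugin/cache/store/disk.py; default remains False.
        cache_requests=True,
    )

From the command line, the equivalent would be running the cache plugin with --cache-requests, as documented by the README hunk above.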