diff --git a/README.md b/README.md index 7346c35482..dd13814a05 100644 --- a/README.md +++ b/README.md @@ -3,9 +3,9 @@ [![License](https://img.shields.io/github/license/abhinavsingh/proxy.py.svg)](https://opensource.org/licenses/BSD-3-Clause) [![Build Status](https://travis-ci.org/abhinavsingh/proxy.py.svg?branch=develop)](https://travis-ci.org/abhinavsingh/proxy.py/) [![No Dependencies](https://img.shields.io/static/v1?label=dependencies&message=none&color=green)](https://github.com/abhinavsingh/proxy.py) -[![Coverage](https://codecov.io/gh/abhinavsingh/proxy.py/branch/develop/graph/badge.svg)](https://codecov.io/gh/abhinavsingh/proxy.py) [![PyPi Monthly](https://img.shields.io/pypi/dm/proxy.py.svg?color=green)](https://pypi.org/project/proxy.py/) [![Docker Pulls](https://img.shields.io/docker/pulls/abhinavsingh/proxy.py?color=green)](https://hub.docker.com/r/abhinavsingh/proxy.py) +[![Coverage](https://codecov.io/gh/abhinavsingh/proxy.py/branch/develop/graph/badge.svg)](https://codecov.io/gh/abhinavsingh/proxy.py) [![Tested With MacOS, Ubuntu, Windows, Android, Android Emulator, iOS, iOS Simulator](https://img.shields.io/static/v1?label=tested%20with&message=mac%20OS%20%F0%9F%92%BB%20%7C%20Ubuntu%20%F0%9F%96%A5%20%7C%20Windows%20%F0%9F%92%BB&color=brightgreen)](https://abhinavsingh.com/proxy-py-a-lightweight-single-file-http-proxy-server-in-python/) [![Android, Android Emulator](https://img.shields.io/static/v1?label=tested%20with&message=Android%20%F0%9F%93%B1%20%7C%20Android%20Emulator%20%F0%9F%93%B1&color=brightgreen)](https://abhinavsingh.com/proxy-py-a-lightweight-single-file-http-proxy-server-in-python/) @@ -58,6 +58,9 @@ Table of Contents * [Plugin Ordering](#plugin-ordering) * [End-to-End Encryption](#end-to-end-encryption) * [TLS Interception](#tls-interception) +* [Proxy Over SSH Tunnel](#proxy-over-ssh-tunnel) + * [Proxy Remote Requests Locally](#proxy-remote-requests-locally) + * [Proxy Local Requests Remotely](#proxy-local-requests-remotely) * [Embed proxy.py](#embed-proxypy) * [Blocking Mode](#blocking-mode) * [Non-blocking Mode](#non-blocking-mode) @@ -798,6 +801,92 @@ cached file instead of plain text. Now use CA flags with other [plugin examples](#plugin-examples) to see them work with `https` traffic. +Proxy Over SSH Tunnel +===================== + +Requires `paramiko` to work. See [requirements-tunnel.txt](https://github.com/abhinavsingh/proxy.py/blob/develop/requirements-tunnel.txt) + +## Proxy Remote Requests Locally + + | + +------------+ | +----------+ + | LOCAL | | | REMOTE | + | HOST | <== SSH ==== :8900 == | SERVER | + +------------+ | +----------+ + :8899 proxy.py | + | + FIREWALL + (allow tcp/22) + +## What + +Proxy HTTP(s) requests made on a `remote` server through `proxy.py` server +running on `localhost`. + +### How + +* Requested `remote` port is forwarded over the SSH connection. +* `proxy.py` running on the `localhost` handles and responds to + `remote` proxy requests. + +### Requirements + +1. `localhost` MUST have SSH access to the `remote` server +2. `remote` server MUST be configured to proxy HTTP(s) requests + through the forwarded port number e.g. `:8900`. + - `remote` and `localhost` ports CAN be same e.g. `:8899`. + - `:8900` is chosen in ascii art for differentiation purposes. + +### Try it + +Start `proxy.py` as: + +``` +$ # On localhost +$ proxy --enable-tunnel \ + --tunnel-username username \ + --tunnel-hostname ip.address.or.domain.name \ + --tunnel-port 22 \ + --tunnel-remote-host 127.0.0.1 + --tunnel-remote-port 8899 +``` + +Make a HTTP proxy request on `remote` server and +verify that response contains public IP address of `localhost` as origin: + +``` +$ # On remote +$ curl -x 127.0.0.1:8899 http://httpbin.org/get +{ + "args": {}, + "headers": { + "Accept": "*/*", + "Host": "httpbin.org", + "User-Agent": "curl/7.54.0" + }, + "origin": "x.x.x.x, y.y.y.y", + "url": "https://httpbin.org/get" +} +``` + +Also, verify that `proxy.py` logs on `localhost` contains `remote` IP as client IP. + +``` +access_log:328 - remote:52067 - GET httpbin.org:80 +``` + +## Proxy Local Requests Remotely + + | + +------------+ | +----------+ + | LOCAL | | | REMOTE | + | HOST | === SSH =====> | SERVER | + +------------+ | +----------+ + | :8899 proxy.py + | + FIREWALL + (allow tcp/22) + Embed proxy.py ============== diff --git a/proxy/core/acceptor/__init__.py b/proxy/core/acceptor/__init__.py new file mode 100644 index 0000000000..9c0a97b332 --- /dev/null +++ b/proxy/core/acceptor/__init__.py @@ -0,0 +1,17 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +from .acceptor import Acceptor +from .pool import AcceptorPool + +__all__ = [ + 'Acceptor', + 'AcceptorPool', +] diff --git a/proxy/core/acceptor/acceptor.py b/proxy/core/acceptor/acceptor.py new file mode 100644 index 0000000000..648f51edc5 --- /dev/null +++ b/proxy/core/acceptor/acceptor.py @@ -0,0 +1,137 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +import logging +import multiprocessing +import selectors +import socket +import threading +# import time +from multiprocessing import connection +from multiprocessing.reduction import send_handle, recv_handle +from typing import Optional, Type, Tuple + +from ..connection import TcpClientConnection +from ..threadless import ThreadlessWork, Threadless +from ..event import EventQueue, eventNames +from ...common.flags import Flags + +logger = logging.getLogger(__name__) + + +class Acceptor(multiprocessing.Process): + """Socket client acceptor. + + Accepts client connection over received server socket handle and + starts a new work thread. + """ + + lock = multiprocessing.Lock() + + def __init__( + self, + idd: int, + work_queue: connection.Connection, + flags: Flags, + work_klass: Type[ThreadlessWork], + event_queue: Optional[EventQueue] = None) -> None: + super().__init__() + self.idd = idd + self.work_queue: connection.Connection = work_queue + self.flags = flags + self.work_klass = work_klass + self.event_queue = event_queue + + self.running = multiprocessing.Event() + self.selector: Optional[selectors.DefaultSelector] = None + self.sock: Optional[socket.socket] = None + self.threadless_process: Optional[Threadless] = None + self.threadless_client_queue: Optional[connection.Connection] = None + + def start_threadless_process(self) -> None: + pipe = multiprocessing.Pipe() + self.threadless_client_queue = pipe[0] + self.threadless_process = Threadless( + client_queue=pipe[1], + flags=self.flags, + work_klass=self.work_klass, + event_queue=self.event_queue + ) + self.threadless_process.start() + logger.debug('Started process %d', self.threadless_process.pid) + + def shutdown_threadless_process(self) -> None: + assert self.threadless_process and self.threadless_client_queue + logger.debug('Stopped process %d', self.threadless_process.pid) + self.threadless_process.running.set() + self.threadless_process.join() + self.threadless_client_queue.close() + + def start_work(self, conn: socket.socket, addr: Tuple[str, int]) -> None: + if self.flags.threadless and \ + self.threadless_client_queue and \ + self.threadless_process: + self.threadless_client_queue.send(addr) + send_handle( + self.threadless_client_queue, + conn.fileno(), + self.threadless_process.pid + ) + conn.close() + else: + work = self.work_klass( + TcpClientConnection(conn, addr), + flags=self.flags, + event_queue=self.event_queue + ) + work_thread = threading.Thread(target=work.run) + work_thread.daemon = True + work.publish_event( + event_name=eventNames.WORK_STARTED, + event_payload={'fileno': conn.fileno(), 'addr': addr}, + publisher_id=self.__class__.__name__ + ) + work_thread.start() + + def run_once(self) -> None: + assert self.selector and self.sock + with self.lock: + events = self.selector.select(timeout=1) + if len(events) == 0: + return + conn, addr = self.sock.accept() + # now = time.time() + # fileno: int = conn.fileno() + self.start_work(conn, addr) + # logger.info('Work started for fd %d in %f seconds', fileno, time.time() - now) + + def run(self) -> None: + self.selector = selectors.DefaultSelector() + fileno = recv_handle(self.work_queue) + self.work_queue.close() + self.sock = socket.fromfd( + fileno, + family=self.flags.family, + type=socket.SOCK_STREAM + ) + try: + self.selector.register(self.sock, selectors.EVENT_READ) + if self.flags.threadless: + self.start_threadless_process() + while not self.running.is_set(): + self.run_once() + except KeyboardInterrupt: + pass + finally: + self.selector.unregister(self.sock) + if self.flags.threadless: + self.shutdown_threadless_process() + self.sock.close() + logger.debug('Acceptor#%d shutdown', self.idd) diff --git a/proxy/core/acceptor.py b/proxy/core/acceptor/pool.py similarity index 51% rename from proxy/core/acceptor.py rename to proxy/core/acceptor/pool.py index 2c376dbbce..4b83b9ae95 100644 --- a/proxy/core/acceptor.py +++ b/proxy/core/acceptor/pool.py @@ -10,17 +10,17 @@ """ import logging import multiprocessing -import selectors import socket import threading # import time from multiprocessing import connection -from multiprocessing.reduction import send_handle, recv_handle -from typing import List, Optional, Type, Tuple +from multiprocessing.reduction import send_handle +from typing import List, Optional, Type -from .threadless import ThreadlessWork, Threadless -from .event import EventQueue, EventDispatcher, eventNames -from ..common.flags import Flags +from .acceptor import Acceptor +from ..threadless import ThreadlessWork +from ..event import EventQueue, EventDispatcher +from ...common.flags import Flags logger = logging.getLogger(__name__) @@ -125,115 +125,3 @@ def setup(self) -> None: ) self.work_queues[index].close() self.socket.close() - - -class Acceptor(multiprocessing.Process): - """Socket client acceptor. - - Accepts client connection over received server socket handle and - starts a new work thread. - """ - - lock = multiprocessing.Lock() - - def __init__( - self, - idd: int, - work_queue: connection.Connection, - flags: Flags, - work_klass: Type[ThreadlessWork], - event_queue: Optional[EventQueue] = None) -> None: - super().__init__() - self.idd = idd - self.work_queue: connection.Connection = work_queue - self.flags = flags - self.work_klass = work_klass - self.event_queue = event_queue - - self.running = multiprocessing.Event() - self.selector: Optional[selectors.DefaultSelector] = None - self.sock: Optional[socket.socket] = None - self.threadless_process: Optional[Threadless] = None - self.threadless_client_queue: Optional[connection.Connection] = None - - def start_threadless_process(self) -> None: - pipe = multiprocessing.Pipe() - self.threadless_client_queue = pipe[0] - self.threadless_process = Threadless( - client_queue=pipe[1], - flags=self.flags, - work_klass=self.work_klass, - event_queue=self.event_queue - ) - self.threadless_process.start() - logger.debug('Started process %d', self.threadless_process.pid) - - def shutdown_threadless_process(self) -> None: - assert self.threadless_process and self.threadless_client_queue - logger.debug('Stopped process %d', self.threadless_process.pid) - self.threadless_process.running.set() - self.threadless_process.join() - self.threadless_client_queue.close() - - def start_work(self, conn: socket.socket, addr: Tuple[str, int]) -> None: - if self.flags.threadless and \ - self.threadless_client_queue and \ - self.threadless_process: - self.threadless_client_queue.send(addr) - send_handle( - self.threadless_client_queue, - conn.fileno(), - self.threadless_process.pid - ) - conn.close() - else: - work = self.work_klass( - fileno=conn.fileno(), - addr=addr, - flags=self.flags, - event_queue=self.event_queue - ) - work_thread = threading.Thread(target=work.run) - work_thread.daemon = True - work.publish_event( - event_name=eventNames.WORK_STARTED, - event_payload={'fileno': conn.fileno(), 'addr': addr}, - publisher_id=self.__class__.__name__ - ) - work_thread.start() - - def run_once(self) -> None: - assert self.selector and self.sock - with self.lock: - events = self.selector.select(timeout=1) - if len(events) == 0: - return - conn, addr = self.sock.accept() - # now = time.time() - # fileno: int = conn.fileno() - self.start_work(conn, addr) - # logger.info('Work started for fd %d in %f seconds', fileno, time.time() - now) - - def run(self) -> None: - self.selector = selectors.DefaultSelector() - fileno = recv_handle(self.work_queue) - self.work_queue.close() - self.sock = socket.fromfd( - fileno, - family=self.flags.family, - type=socket.SOCK_STREAM - ) - try: - self.selector.register(self.sock, selectors.EVENT_READ) - if self.flags.threadless: - self.start_threadless_process() - while not self.running.is_set(): - self.run_once() - except KeyboardInterrupt: - pass - finally: - self.selector.unregister(self.sock) - if self.flags.threadless: - self.shutdown_threadless_process() - self.sock.close() - logger.debug('Acceptor#%d shutdown', self.idd) diff --git a/proxy/core/threadless.py b/proxy/core/threadless.py index 750588906f..d13f973aa8 100644 --- a/proxy/core/threadless.py +++ b/proxy/core/threadless.py @@ -22,6 +22,7 @@ from abc import ABC, abstractmethod from typing import Dict, Optional, Tuple, List, Union, Generator, Any, Type +from .connection import TcpClientConnection from .event import EventQueue, eventNames from ..common.flags import Flags @@ -37,15 +38,12 @@ class ThreadlessWork(ABC): @abstractmethod def __init__( self, - fileno: int, - addr: Tuple[str, int], + client: TcpClientConnection, flags: Optional[Flags], event_queue: Optional[EventQueue] = None, uid: Optional[str] = None) -> None: - self.fileno = fileno - self.addr = addr + self.client = client self.flags = flags if flags else Flags() - self.event_queue = event_queue self.uid: str = uid if uid is not None else uuid.uuid4().hex @@ -167,12 +165,16 @@ async def wait_for_tasks( except asyncio.TimeoutError: self.cleanup(work_id) + def fromfd(self, fileno: int) -> socket.socket: + return socket.fromfd( + fileno, family=socket.AF_INET if self.flags.hostname.version == 4 else socket.AF_INET6, + type=socket.SOCK_STREAM) + def accept_client(self) -> None: addr = self.client_queue.recv() fileno = recv_handle(self.client_queue) self.works[fileno] = self.work_klass( - fileno=fileno, - addr=addr, + TcpClientConnection(conn=self.fromfd(fileno), addr=addr), flags=self.flags, event_queue=self.event_queue ) diff --git a/proxy/http/handler.py b/proxy/http/handler.py index 569a23c929..a19a2b65d2 100644 --- a/proxy/http/handler.py +++ b/proxy/http/handler.py @@ -113,20 +113,18 @@ class HttpProtocolHandler(ThreadlessWork): Accepts `Client` connection object and manages HttpProtocolHandlerPlugin invocations. """ - def __init__(self, fileno: int, addr: Tuple[str, int], + def __init__(self, client: TcpClientConnection, flags: Optional[Flags] = None, event_queue: Optional[EventQueue] = None, uid: Optional[str] = None): - super().__init__(fileno, addr, flags, event_queue, uid) + super().__init__(client, flags, event_queue, uid) self.start_time: float = time.time() self.last_activity: float = self.start_time self.request: HttpParser = HttpParser(httpParserTypes.REQUEST_PARSER) self.response: HttpParser = HttpParser(httpParserTypes.RESPONSE_PARSER) self.selector = selectors.DefaultSelector() - self.client: TcpClientConnection = TcpClientConnection( - self.fromfd(self.fileno), self.addr - ) + self.client: TcpClientConnection = client self.plugins: Dict[str, HttpProtocolHandlerPlugin] = {} def initialize(self) -> None: @@ -134,7 +132,7 @@ def initialize(self) -> None: conn = self.optionally_wrap_socket(self.client.connection) conn.setblocking(False) if self.flags.encryption_enabled(): - self.client = TcpClientConnection(conn=conn, addr=self.addr) + self.client = TcpClientConnection(conn=conn, addr=self.client.addr) if b'HttpProtocolHandlerPlugin' in self.flags.plugins: for klass in self.flags.plugins[b'HttpProtocolHandlerPlugin']: instance = klass( @@ -232,12 +230,6 @@ def shutdown(self) -> None: logger.debug('Client connection closed') super().shutdown() - def fromfd(self, fileno: int) -> socket.socket: - conn = socket.fromfd( - fileno, family=socket.AF_INET if self.flags.hostname.version == 4 else socket.AF_INET6, - type=socket.SOCK_STREAM) - return conn - def optionally_wrap_socket( self, conn: socket.socket) -> Union[ssl.SSLSocket, socket.socket]: """Attempts to wrap accepted client connection using provided certificates. diff --git a/proxy/http/proxy/server.py b/proxy/http/proxy/server.py index 14b475f818..229187cd0c 100644 --- a/proxy/http/proxy/server.py +++ b/proxy/http/proxy/server.py @@ -106,7 +106,9 @@ def read_from_descriptors(self, r: List[Union[int, HasFileno]]) -> bool: raw = self.server.recv(self.flags.server_recvbuf_size) except TimeoutError as e: if e.errno == errno.ETIMEDOUT: - logger.warning('%s:%d timed out on recv' % self.server.addr) + logger.warning( + '%s:%d timed out on recv' % + self.server.addr) return True else: raise e @@ -115,7 +117,9 @@ def read_from_descriptors(self, r: List[Union[int, HasFileno]]) -> bool: return False except OSError as e: if e.errno == errno.EHOSTUNREACH: - logger.warning('%s:%d unreachable on recv' % self.server.addr) + logger.warning( + '%s:%d unreachable on recv' % + self.server.addr) return True elif e.errno == errno.ECONNRESET: logger.warning('Connection reset by upstream: %r' % e) diff --git a/tests/core/test_acceptor.py b/tests/core/test_acceptor.py index 92ae2fb66e..537339d537 100644 --- a/tests/core/test_acceptor.py +++ b/tests/core/test_acceptor.py @@ -33,7 +33,7 @@ def setUp(self) -> None: @mock.patch('selectors.DefaultSelector') @mock.patch('socket.fromfd') - @mock.patch('proxy.core.acceptor.recv_handle') + @mock.patch('proxy.core.acceptor.acceptor.recv_handle') def test_continues_when_no_events( self, mock_recv_handle: mock.Mock, @@ -54,16 +54,18 @@ def test_continues_when_no_events( sock.accept.assert_not_called() self.mock_protocol_handler.assert_not_called() + @mock.patch('proxy.core.acceptor.acceptor.TcpClientConnection') @mock.patch('threading.Thread') @mock.patch('selectors.DefaultSelector') @mock.patch('socket.fromfd') - @mock.patch('proxy.core.acceptor.recv_handle') + @mock.patch('proxy.core.acceptor.acceptor.recv_handle') def test_accepts_client_from_server_socket( self, mock_recv_handle: mock.Mock, mock_fromfd: mock.Mock, mock_selector: mock.Mock, - mock_thread: mock.Mock) -> None: + mock_thread: mock.Mock, + mock_client: mock.Mock) -> None: fileno = 10 conn = mock.MagicMock() addr = mock.MagicMock() @@ -87,8 +89,7 @@ def test_accepts_client_from_server_socket( type=socket.SOCK_STREAM ) self.mock_protocol_handler.assert_called_with( - fileno=conn.fileno(), - addr=addr, + mock_client.return_value, flags=self.flags, event_queue=None, ) diff --git a/tests/core/test_acceptor_pool.py b/tests/core/test_acceptor_pool.py index 51f10d6095..e8192495ed 100644 --- a/tests/core/test_acceptor_pool.py +++ b/tests/core/test_acceptor_pool.py @@ -18,49 +18,50 @@ class TestAcceptorPool(unittest.TestCase): - @mock.patch('proxy.core.acceptor.send_handle') + @mock.patch('proxy.core.acceptor.pool.send_handle') @mock.patch('multiprocessing.Pipe') @mock.patch('socket.socket') - @mock.patch('proxy.core.acceptor.Acceptor') + @mock.patch('proxy.core.acceptor.pool.Acceptor') def test_setup_and_shutdown( self, - mock_worker: mock.Mock, + mock_acceptor: mock.Mock, mock_socket: mock.Mock, mock_pipe: mock.Mock, - _mock_send_handle: mock.Mock) -> None: - mock_worker1 = mock.MagicMock() - mock_worker2 = mock.MagicMock() - mock_worker.side_effect = [mock_worker1, mock_worker2] + mock_send_handle: mock.Mock) -> None: + acceptor1 = mock.MagicMock() + acceptor2 = mock.MagicMock() + mock_acceptor.side_effect = [acceptor1, acceptor2] num_workers = 2 sock = mock_socket.return_value work_klass = mock.MagicMock() flags = Flags(num_workers=2) - acceptor = AcceptorPool(flags=flags, work_klass=work_klass) - acceptor.setup() + pool = AcceptorPool(flags=flags, work_klass=work_klass) + pool.setup() + mock_send_handle.assert_called() work_klass.assert_not_called() mock_socket.assert_called_with( - socket.AF_INET6 if acceptor.flags.hostname.version == 6 else socket.AF_INET, + socket.AF_INET6 if pool.flags.hostname.version == 6 else socket.AF_INET, socket.SOCK_STREAM ) sock.setsockopt.assert_called_with( socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock.bind.assert_called_with( - (str(acceptor.flags.hostname), acceptor.flags.port)) - sock.listen.assert_called_with(acceptor.flags.backlog) + (str(pool.flags.hostname), pool.flags.port)) + sock.listen.assert_called_with(pool.flags.backlog) sock.setblocking.assert_called_with(False) self.assertTrue(mock_pipe.call_count, num_workers) - self.assertTrue(mock_worker.call_count, num_workers) - mock_worker1.start.assert_called() - mock_worker1.join.assert_not_called() - mock_worker2.start.assert_called() - mock_worker2.join.assert_not_called() + self.assertTrue(mock_acceptor.call_count, num_workers) + acceptor1.start.assert_called() + acceptor2.start.assert_called() + acceptor1.join.assert_not_called() + acceptor2.join.assert_not_called() sock.close.assert_called() - acceptor.shutdown() - mock_worker1.join.assert_called() - mock_worker2.join.assert_called() + pool.shutdown() + acceptor1.join.assert_called() + acceptor2.join.assert_called() diff --git a/tests/http/test_http_proxy.py b/tests/http/test_http_proxy.py index 3bb2648ae4..60024d3f00 100644 --- a/tests/http/test_http_proxy.py +++ b/tests/http/test_http_proxy.py @@ -14,6 +14,7 @@ from proxy.common.constants import DEFAULT_HTTP_PORT from proxy.common.flags import Flags +from proxy.core.connection import TcpClientConnection from proxy.http.proxy import HttpProxyPlugin from proxy.http.handler import HttpProtocolHandler from proxy.http.exception import HttpProtocolException @@ -40,7 +41,8 @@ def setUp(self, } self._conn = mock_fromfd.return_value self.protocol_handler = HttpProtocolHandler( - self.fileno, self._addr, flags=self.flags) + TcpClientConnection(self._conn, self._addr), + flags=self.flags) self.protocol_handler.initialize() def test_proxy_plugin_initialized(self) -> None: diff --git a/tests/http/test_http_proxy_tls_interception.py b/tests/http/test_http_proxy_tls_interception.py index 87c3eed780..ff01a7c85f 100644 --- a/tests/http/test_http_proxy_tls_interception.py +++ b/tests/http/test_http_proxy_tls_interception.py @@ -17,6 +17,7 @@ from typing import Any from unittest import mock +from proxy.core.connection import TcpClientConnection from proxy.http.handler import HttpProtocolHandler from proxy.http.proxy import HttpProxyPlugin from proxy.http.methods import httpMethods @@ -78,7 +79,8 @@ def mock_connection() -> Any: } self._conn = mock_fromfd.return_value self.protocol_handler = HttpProtocolHandler( - self.fileno, self._addr, flags=self.flags) + TcpClientConnection(self._conn, self._addr), + flags=self.flags) self.protocol_handler.initialize() self.plugin.assert_called() diff --git a/tests/http/test_protocol_handler.py b/tests/http/test_protocol_handler.py index f4140fe998..ca4dac037d 100644 --- a/tests/http/test_protocol_handler.py +++ b/tests/http/test_protocol_handler.py @@ -15,15 +15,16 @@ from typing import cast from unittest import mock +from proxy.common.version import __version__ from proxy.common.flags import Flags from proxy.common.utils import bytes_ from proxy.common.constants import CRLF +from proxy.core.connection import TcpClientConnection from proxy.http.parser import HttpParser from proxy.http.proxy import HttpProxyPlugin from proxy.http.parser import httpParserStates, httpParserTypes from proxy.http.exception import ProxyAuthenticationFailed, ProxyConnectionFailed from proxy.http.handler import HttpProtocolHandler -from proxy.common.version import __version__ class TestHttpProtocolHandler(unittest.TestCase): @@ -44,7 +45,7 @@ def setUp(self, self.mock_selector = mock_selector self.protocol_handler = HttpProtocolHandler( - self.fileno, self._addr, flags=self.flags) + TcpClientConnection(self._conn, self._addr), flags=self.flags) self.protocol_handler.initialize() @mock.patch('proxy.http.proxy.server.TcpServerConnection') @@ -175,7 +176,7 @@ def test_proxy_authentication_failed( flags.plugins = Flags.load_plugins( b'proxy.http.proxy.HttpProxyPlugin,proxy.http.server.HttpWebServerPlugin') self.protocol_handler = HttpProtocolHandler( - self.fileno, self._addr, flags=flags) + TcpClientConnection(self._conn, self._addr), flags=flags) self.protocol_handler.initialize() self._conn.recv.return_value = CRLF.join([ b'GET http://abhinavsingh.com HTTP/1.1', @@ -208,7 +209,7 @@ def test_authenticated_proxy_http_get( b'proxy.http.proxy.HttpProxyPlugin,proxy.http.server.HttpWebServerPlugin') self.protocol_handler = HttpProtocolHandler( - self.fileno, addr=self._addr, flags=flags) + TcpClientConnection(self._conn, self._addr), flags=flags) self.protocol_handler.initialize() assert self.http_server_port is not None @@ -256,7 +257,7 @@ def test_authenticated_proxy_http_tunnel( b'proxy.http.proxy.HttpProxyPlugin,proxy.http.server.HttpWebServerPlugin') self.protocol_handler = HttpProtocolHandler( - self.fileno, self._addr, flags=flags) + TcpClientConnection(self._conn, self._addr), flags=flags) self.protocol_handler.initialize() assert self.http_server_port is not None diff --git a/tests/http/test_web_server.py b/tests/http/test_web_server.py index 72042b40a9..3a21fc9e1f 100644 --- a/tests/http/test_web_server.py +++ b/tests/http/test_web_server.py @@ -16,6 +16,7 @@ from unittest import mock from proxy.common.flags import Flags +from proxy.core.connection import TcpClientConnection from proxy.http.handler import HttpProtocolHandler from proxy.http.parser import httpParserStates from proxy.common.utils import build_http_response, build_http_request, bytes_, text_ @@ -36,7 +37,8 @@ def setUp(self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None: self.flags.plugins = Flags.load_plugins( b'proxy.http.proxy.HttpProxyPlugin,proxy.http.server.HttpWebServerPlugin') self.protocol_handler = HttpProtocolHandler( - self.fileno, self._addr, flags=self.flags) + TcpClientConnection(self._conn, self._addr), + flags=self.flags) self.protocol_handler.initialize() @mock.patch('selectors.DefaultSelector') @@ -96,7 +98,8 @@ def test_default_web_server_returns_404( flags.plugins = Flags.load_plugins( b'proxy.http.proxy.HttpProxyPlugin,proxy.http.server.HttpWebServerPlugin') self.protocol_handler = HttpProtocolHandler( - self.fileno, self._addr, flags=flags) + TcpClientConnection(self._conn, self._addr), + flags=flags) self.protocol_handler.initialize() self._conn.recv.return_value = CRLF.join([ b'GET /hello HTTP/1.1', @@ -147,7 +150,8 @@ def test_static_web_server_serves( b'proxy.http.proxy.HttpProxyPlugin,proxy.http.server.HttpWebServerPlugin') self.protocol_handler = HttpProtocolHandler( - self.fileno, self._addr, flags=flags) + TcpClientConnection(self._conn, self._addr), + flags=flags) self.protocol_handler.initialize() self.protocol_handler.run_once() @@ -194,7 +198,8 @@ def test_static_web_server_serves_404( b'proxy.http.proxy.HttpProxyPlugin,proxy.http.server.HttpWebServerPlugin') self.protocol_handler = HttpProtocolHandler( - self.fileno, self._addr, flags=flags) + TcpClientConnection(self._conn, self._addr), + flags=flags) self.protocol_handler.initialize() self.protocol_handler.run_once() @@ -213,7 +218,8 @@ def test_on_client_connection_called_on_teardown( flags.plugins = {b'HttpProtocolHandlerPlugin': [plugin]} self._conn = mock_fromfd.return_value self.protocol_handler = HttpProtocolHandler( - self.fileno, self._addr, flags=flags) + TcpClientConnection(self._conn, self._addr), + flags=flags) self.protocol_handler.initialize() plugin.assert_called() with mock.patch.object(self.protocol_handler, 'run_once') as mock_run_once: @@ -228,7 +234,8 @@ def init_and_make_pac_file_request(self, pac_file: str) -> None: b'proxy.http.proxy.HttpProxyPlugin,proxy.http.server.HttpWebServerPlugin,' b'proxy.http.server.HttpWebServerPacFilePlugin') self.protocol_handler = HttpProtocolHandler( - self.fileno, self._addr, flags=flags) + TcpClientConnection(self._conn, self._addr), + flags=flags) self.protocol_handler.initialize() self._conn.recv.return_value = CRLF.join([ b'GET / HTTP/1.1', diff --git a/tests/plugin/test_http_proxy_plugins.py b/tests/plugin/test_http_proxy_plugins.py index a768c1013e..84ca5a6970 100644 --- a/tests/plugin/test_http_proxy_plugins.py +++ b/tests/plugin/test_http_proxy_plugins.py @@ -17,6 +17,7 @@ from typing import cast from proxy.common.flags import Flags +from proxy.core.connection import TcpClientConnection from proxy.http.handler import HttpProtocolHandler from proxy.http.proxy import HttpProxyPlugin from proxy.common.utils import build_http_request, bytes_, build_http_response @@ -51,7 +52,8 @@ def setUp(self, } self._conn = mock_fromfd.return_value self.protocol_handler = HttpProtocolHandler( - self.fileno, self._addr, flags=self.flags) + TcpClientConnection(self._conn, self._addr), + flags=self.flags) self.protocol_handler.initialize() @mock.patch('proxy.http.proxy.server.TcpServerConnection') diff --git a/tests/plugin/test_http_proxy_plugins_with_tls_interception.py b/tests/plugin/test_http_proxy_plugins_with_tls_interception.py index ad05b2b1a3..2976869d97 100644 --- a/tests/plugin/test_http_proxy_plugins_with_tls_interception.py +++ b/tests/plugin/test_http_proxy_plugins_with_tls_interception.py @@ -19,6 +19,7 @@ from proxy.common.utils import bytes_ from proxy.common.flags import Flags from proxy.common.utils import build_http_request, build_http_response +from proxy.core.connection import TcpClientConnection from proxy.http.codes import httpStatusCodes from proxy.http.methods import httpMethods from proxy.http.handler import HttpProtocolHandler @@ -66,7 +67,7 @@ def setUp(self, self._conn = mock.MagicMock(spec=socket.socket) mock_fromfd.return_value = self._conn self.protocol_handler = HttpProtocolHandler( - self.fileno, self._addr, flags=self.flags) + TcpClientConnection(self._conn, self._addr), flags=self.flags) self.protocol_handler.initialize() self.server = self.mock_server_conn.return_value