diff --git a/README.md b/README.md index c8e812e0f0..8bfd444bab 100644 --- a/README.md +++ b/README.md @@ -208,6 +208,12 @@ - Enable builtin [Reverse Proxy Server](#reverse-proxy-plugins). Example: - `--enable-reverse-proxy --plugins proxy.plugin.ReverseProxyPlugin` - Plugin API is currently in *development phase*. Expect breaking changes. See [Deploying proxy.py in production](#deploying-proxypy-in-production) on how to ensure reliability across code changes. + +- Can listen on multiple ports + - Use `--ports` flag to provide additional ports + - Optionally, use `--port` flag to override default port `8899` + - Capable of serving multiple protocols over the same port + - Real-time Dashboard - Optionally, enable [proxy.py dashboard](#run-dashboard). - Use `--enable-dashboard` @@ -216,34 +222,45 @@ - [Chrome DevTools Protocol](#chrome-devtools-protocol) support - Extend dashboard frontend using `typescript` based [plugins](https://github.com/abhinavsingh/proxy.py/tree/develop/dashboard/src/plugins) - Dashboard is currently in *development phase* Expect breaking changes. + - Secure - Enable end-to-end encryption between clients and `proxy.py` - See [End-to-End Encryption](#end-to-end-encryption) + - Private - Protection against DNS based traffic blockers - Browse with malware and adult content protection enabled - See [DNS-over-HTTPS](#cloudflarednsresolverplugin) + - Man-In-The-Middle - Can decrypt TLS traffic between clients and upstream servers - See [TLS Interception](#tls-interception) + - Supported http protocols for proxy requests - `http(s)` - `http1` - `http1.1` with pipeline - `http2` - `websockets` + - Support for `HAProxy Protocol` - See `--enable-proxy-protocol` flag + - Static file server support - See `--enable-static-server` and `--static-server-dir` flags + - Optimized for large file uploads and downloads - See `--client-recvbuf-size` and `--server-recvbuf-size` flag + - `IPv4` and `IPv6` support - See `--hostname` flag + - Unix domain socket support - See `--unix-socket-path` flag + - Basic authentication support - See `--basic-auth` flag + - PAC (Proxy Auto-configuration) support - See `--pac-file` and `--pac-file-url-path` flags diff --git a/docs/conf.py b/docs/conf.py index 05d968a832..452a09a8c2 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -283,6 +283,7 @@ (_any_role, 'work_klass'), (_py_class_role, '_asyncio.Task'), (_py_class_role, 'asyncio.events.AbstractEventLoop'), + (_py_class_role, 'BaseListener'), (_py_class_role, 'CacheStore'), (_py_class_role, 'Channel'), (_py_class_role, 'HttpParser'), diff --git a/proxy/common/flag.py b/proxy/common/flag.py index ec846ee037..279752f6dc 100644 --- a/proxy/common/flag.py +++ b/proxy/common/flag.py @@ -14,6 +14,7 @@ import socket import argparse import ipaddress +import itertools import collections import multiprocessing from typing import Any, List, Optional, cast @@ -302,7 +303,12 @@ def initialize( # assert args.unix_socket_path is None args.family = socket.AF_INET6 if args.hostname.version == 6 else socket.AF_INET args.port = cast(int, opts.get('port', args.port)) - args.ports = cast(Optional[List[int]], opts.get('ports', args.ports)) + ports: List[List[int]] = opts.get('ports', args.ports) + args.ports = [ + int(port) for port in list( + itertools.chain.from_iterable([] if ports is None else ports), + ) + ] args.backlog = cast(int, opts.get('backlog', args.backlog)) num_workers = opts.get('num_workers', args.num_workers) args.num_workers = cast( diff --git a/proxy/core/acceptor/acceptor.py b/proxy/core/acceptor/acceptor.py index 9572c1b5b4..74721deb9d 100644 --- a/proxy/core/acceptor/acceptor.py +++ b/proxy/core/acceptor/acceptor.py @@ -19,7 +19,7 @@ import threading import multiprocessing import multiprocessing.synchronize -from typing import List, Tuple, Optional +from typing import Dict, List, Tuple, Optional from multiprocessing import connection from multiprocessing.reduction import recv_handle @@ -93,9 +93,8 @@ def __init__( # Selector self.running = multiprocessing.Event() self.selector: Optional[selectors.DefaultSelector] = None - # File descriptor used to accept new work - # Currently, a socket fd is assumed. - self.sock: Optional[socket.socket] = None + # File descriptors used to accept new work + self.socks: Dict[int, socket.socket] = {} # Internals self._total: Optional[int] = None self._local_work_queue: Optional['NonBlockingQueue'] = None @@ -107,11 +106,10 @@ def accept( events: List[Tuple[selectors.SelectorKey, int]], ) -> List[Tuple[socket.socket, Optional[Tuple[str, int]]]]: works = [] - for _, mask in events: - if mask & selectors.EVENT_READ and \ - self.sock is not None: + for key, mask in events: + if mask & selectors.EVENT_READ: try: - conn, addr = self.sock.accept() + conn, addr = self.socks[key.data].accept() logging.debug( 'Accepting new work#{0}'.format(conn.fileno()), ) @@ -158,33 +156,43 @@ def run(self) -> None: self.flags.log_format, ) self.selector = selectors.DefaultSelector() - # TODO: Use selector on fd_queue so that we can - # dynamically accept from new fds. - fileno = recv_handle(self.fd_queue) - self.fd_queue.close() - # TODO: Convert to socks i.e. list of fds - self.sock = socket.fromfd( - fileno, - family=self.flags.family, - type=socket.SOCK_STREAM, - ) + self._recv_and_setup_socks() try: if self.flags.threadless and self.flags.local_executor: self._start_local() - self.selector.register(self.sock, selectors.EVENT_READ) + for fileno in self.socks: + self.selector.register( + fileno, selectors.EVENT_READ, fileno, + ) while not self.running.is_set(): self.run_once() except KeyboardInterrupt: pass finally: - self.selector.unregister(self.sock) + for fileno in self.socks: + self.selector.unregister(fileno) if self.flags.threadless and self.flags.local_executor: self._stop_local() - self.sock.close() + for fileno in self.socks: + self.socks[fileno].close() + self.socks.clear() logger.debug('Acceptor#%d shutdown', self.idd) + def _recv_and_setup_socks(self) -> None: + # TODO: Use selector on fd_queue so that we can + # dynamically accept from new fds. + for _ in range(self.fd_queue.recv()): + fileno = recv_handle(self.fd_queue) + # TODO: Convert to socks i.e. list of fds + self.socks[fileno] = socket.fromfd( + fileno, + family=self.flags.family, + type=socket.SOCK_STREAM, + ) + self.fd_queue.close() + def _start_local(self) -> None: - assert self.sock + assert self.socks self._local_work_queue = NonBlockingQueue() self._local = LocalExecutor( iid=self.idd, diff --git a/proxy/core/acceptor/pool.py b/proxy/core/acceptor/pool.py index e7d16f99fc..09fb9f447f 100644 --- a/proxy/core/acceptor/pool.py +++ b/proxy/core/acceptor/pool.py @@ -109,16 +109,17 @@ def setup(self) -> None: ), ) # Send file descriptor to all acceptor processes. - for listener in self.listeners.pool: - fd = listener.fileno() - assert fd is not None - for index in range(self.flags.num_acceptors): + for index in range(self.flags.num_acceptors): + self.fd_queues[index].send(len(self.listeners.pool)) + for listener in self.listeners.pool: + fd = listener.fileno() + assert fd is not None send_handle( self.fd_queues[index], fd, self.acceptors[index].pid, ) - self.fd_queues[index].close() + self.fd_queues[index].close() def shutdown(self) -> None: logger.info('Shutting down %d acceptors' % self.flags.num_acceptors) diff --git a/proxy/core/listener/base.py b/proxy/core/listener/base.py index f3fab4a52b..899538b927 100644 --- a/proxy/core/listener/base.py +++ b/proxy/core/listener/base.py @@ -34,7 +34,7 @@ class BaseListener(ABC): For usage provide a listen method implementation.""" - def __init__(self, flags: argparse.Namespace) -> None: + def __init__(self, *args: Any, flags: argparse.Namespace, **kwargs: Any) -> None: self.flags = flags self._socket: Optional[socket.socket] = None diff --git a/proxy/core/listener/pool.py b/proxy/core/listener/pool.py index 25835fa984..5c743bc940 100644 --- a/proxy/core/listener/pool.py +++ b/proxy/core/listener/pool.py @@ -9,7 +9,7 @@ :license: BSD, see LICENSE for more details. """ import argparse -from typing import TYPE_CHECKING, Any, List +from typing import TYPE_CHECKING, Any, List, Type from .tcp import TcpSocketListener from .unix import UnixSocketListener @@ -36,15 +36,18 @@ def __exit__(self, *args: Any) -> None: def setup(self) -> None: if self.flags.unix_socket_path: - ulistener = UnixSocketListener(flags=self.flags) - ulistener.setup() - self.pool.append(ulistener) + self.add(UnixSocketListener) else: - listener = TcpSocketListener(flags=self.flags) - listener.setup() - self.pool.append(listener) + self.add(TcpSocketListener) + for port in self.flags.ports: + self.add(TcpSocketListener, port=port) def shutdown(self) -> None: for listener in self.pool: listener.shutdown() self.pool.clear() + + def add(self, klass: Type['BaseListener'], **kwargs: Any) -> None: + listener = klass(flags=self.flags, **kwargs) + listener.setup() + self.pool.append(listener) diff --git a/proxy/core/listener/tcp.py b/proxy/core/listener/tcp.py index 68df2f5546..3000ae3dd7 100644 --- a/proxy/core/listener/tcp.py +++ b/proxy/core/listener/tcp.py @@ -54,7 +54,10 @@ class TcpSocketListener(BaseListener): """Tcp listener.""" - def __init__(self, *args: Any, **kwargs: Any) -> None: + def __init__(self, *args: Any, port: Optional[int] = None, **kwargs: Any) -> None: + # Port if passed will be used, otherwise + # flag port value will be used. + self.port = port # Set after binding to a port. # # Stored here separately for ephemeral port discovery. @@ -66,7 +69,8 @@ def listen(self) -> socket.socket: sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) # s.setsockopt(socket.SOL_TCP, socket.TCP_FASTOPEN, 5) - sock.bind((str(self.flags.hostname), self.flags.port)) + port = self.port if self.port is not None else self.flags.port + sock.bind((str(self.flags.hostname), port)) sock.listen(self.flags.backlog) sock.setblocking(False) self._port = sock.getsockname()[1] diff --git a/tests/core/test_acceptor.py b/tests/core/test_acceptor.py index 376186c539..acfe0eab21 100644 --- a/tests/core/test_acceptor.py +++ b/tests/core/test_acceptor.py @@ -23,7 +23,7 @@ class TestAcceptor(unittest.TestCase): def setUp(self) -> None: self.acceptor_id = 1 - self.pipe = multiprocessing.Pipe() + self.pipe = mock.MagicMock() self.work_klass = mock.MagicMock() self.flags = FlagParser.initialize( threaded=True, @@ -82,15 +82,23 @@ def test_accepts_client_from_server_socket( mock_fromfd.return_value.accept.return_value = (conn, addr) mock_recv_handle.return_value = fileno + self.pipe[1].recv.return_value = 1 + mock_thread.return_value.start.side_effect = KeyboardInterrupt() + mock_key = mock.MagicMock() + type(mock_key).data = mock.PropertyMock(return_value=fileno) + selector = mock_selector.return_value - selector.select.return_value = [(None, selectors.EVENT_READ)] + selector.select.return_value = [(mock_key, selectors.EVENT_READ)] self.acceptor.run() - selector.register.assert_called_with(sock, selectors.EVENT_READ) - selector.unregister.assert_called_with(sock) + self.pipe[1].recv.assert_called_once() + selector.register.assert_called_with( + fileno, selectors.EVENT_READ, fileno, + ) + selector.unregister.assert_called_with(fileno) mock_recv_handle.assert_called_with(self.pipe[1]) mock_fromfd.assert_called_with( fileno, diff --git a/tests/core/test_acceptor_pool.py b/tests/core/test_acceptor_pool.py index d550901523..9965c56f3a 100644 --- a/tests/core/test_acceptor_pool.py +++ b/tests/core/test_acceptor_pool.py @@ -41,9 +41,11 @@ def test_setup_and_shutdown( ) self.assertEqual(flags.num_acceptors, num_acceptors) - mock_listener_pool.return_value.pool = [ - mock_tcp_socket_listener.return_value, - ] + type(mock_listener_pool.return_value).pool = mock.PropertyMock( + return_value=[ + mock_tcp_socket_listener.return_value, + ], + ) pool = AcceptorPool( flags=flags, listeners=mock_listener_pool.return_value, executor_queues=[], executor_pids=[], executor_locks=[], @@ -52,30 +54,49 @@ def test_setup_and_shutdown( self.assertEqual(mock_pipe.call_count, num_acceptors) self.assertEqual(mock_acceptor.call_count, num_acceptors) - mock_send_handle.assert_called() self.assertEqual(mock_send_handle.call_count, num_acceptors) self.assertEqual( - mock_acceptor.call_args_list[0][1]['idd'], 0, + mock_acceptor.call_args_list[0][1]['idd'], + 0, ) self.assertEqual( - mock_acceptor.call_args_list[0][1]['fd_queue'], mock_pipe.return_value[1], + mock_acceptor.call_args_list[0][1]['fd_queue'], + mock_pipe.return_value[1], ) self.assertEqual( - mock_acceptor.call_args_list[0][1]['flags'], flags, + mock_acceptor.call_args_list[0][1]['flags'], + flags, ) self.assertEqual( - mock_acceptor.call_args_list[0][1]['event_queue'], None, + mock_acceptor.call_args_list[0][1]['event_queue'], + None, ) # executor_queues=[], # executor_pids=[] self.assertEqual( mock_acceptor.call_args_list[1][1]['idd'], 1, ) + self.assertEqual( + mock_acceptor.call_args_list[1][1]['fd_queue'], + mock_pipe.return_value[2], + ) + self.assertEqual( + mock_acceptor.call_args_list[1][1]['flags'], + flags, + ) + self.assertEqual( + mock_acceptor.call_args_list[1][1]['event_queue'], + None, + ) acceptor1.start.assert_called_once() acceptor2.start.assert_called_once() - mock_tcp_socket_listener.return_value.fileno.assert_called_once() + + self.assertEqual( + mock_tcp_socket_listener.return_value.fileno.call_count, + num_acceptors, + ) acceptor1.join.assert_not_called() acceptor2.join.assert_not_called() diff --git a/tests/core/test_listener.py b/tests/core/test_listener.py index 6af6593812..3468060a0b 100644 --- a/tests/core/test_listener.py +++ b/tests/core/test_listener.py @@ -27,7 +27,7 @@ class TestListener(unittest.TestCase): def test_setup_and_teardown(self, mock_socket: mock.Mock) -> None: sock = mock_socket.return_value flags = FlagParser.initialize(port=0) - listener = TcpSocketListener(flags) + listener = TcpSocketListener(flags=flags) listener.setup() mock_socket.assert_called_with( socket.AF_INET6 if flags.hostname.version == 6 else socket.AF_INET, @@ -66,7 +66,7 @@ def test_unix_path_listener(self, mock_socket: mock.Mock, mock_remove: mock.Mock sock = mock_socket.return_value sock_path = os.path.join(tempfile.gettempdir(), 'proxy.sock') flags = FlagParser.initialize(unix_socket_path=sock_path) - listener = UnixSocketListener(flags) + listener = UnixSocketListener(flags=flags) listener.setup() mock_socket.assert_called_with(