diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml
index 7815f0afc8..11cca3215a 100644
--- a/.github/workflows/testing.yml
+++ b/.github/workflows/testing.yml
@@ -27,8 +27,8 @@ jobs:
       run: |
         # The GitHub editor is 127 chars wide
        # W504 screams for line break after binary operators
-        flake8 --ignore=W504 --max-line-length=127 proxy.py plugin_examples.py tests.py setup.py
+        flake8 --ignore=W504 --max-line-length=127 proxy.py plugin_examples.py tests.py setup.py benchmark.py
         # mypy compliance check
-        mypy --strict --ignore-missing-imports proxy.py plugin_examples.py tests.py setup.py
+        mypy --strict --ignore-missing-imports proxy.py plugin_examples.py tests.py setup.py benchmark.py
     - name: Run Tests
       run: pytest tests.py
diff --git a/.gitignore b/.gitignore
index 9a8ba1c0db..eaabbb272a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -16,4 +16,3 @@ proxy.py.iml
 *.pyc
 ca-*.pem
 https-*.pem
-benchmark.py
diff --git a/Makefile b/Makefile
index bb44f08a1d..981bc25b75 100644
--- a/Makefile
+++ b/Makefile
@@ -44,8 +44,8 @@ coverage:
	open htmlcov/index.html

 lint:
-	flake8 --ignore=W504 --max-line-length=127 proxy.py plugin_examples.py tests.py setup.py
-	mypy --strict --ignore-missing-imports proxy.py plugin_examples.py tests.py setup.py
+	flake8 --ignore=W504 --max-line-length=127 proxy.py plugin_examples.py tests.py setup.py benchmark.py
+	mypy --strict --ignore-missing-imports proxy.py plugin_examples.py tests.py setup.py benchmark.py

 autopep8:
	autopep8 --recursive --in-place --aggressive proxy.py
diff --git a/README.md b/README.md
index ea57228e7a..3c23e58647 100644
--- a/README.md
+++ b/README.md
@@ -39,6 +39,7 @@ Table of Contents
     * [Stable version](#stable-version-from-docker-hub)
     * [Development version](#build-development-version-locally)
 * [Plugin Examples](#plugin-examples)
+    * [ShortLinkPlugin](#shortlinkplugin)
     * [ModifyPostDataPlugin](#modifypostdataplugin)
     * [ProposedRestApiPlugin](#proposedrestapiplugin)
     * [RedirectToCustomServerPlugin](#redirecttocustomserverplugin)
@@ -49,9 +50,16 @@ Table of Contents
 * [End-to-End Encryption](#end-to-end-encryption)
 * [TLS Interception](#tls-interception)
 * [import proxy.py](#import-proxypy)
-    * [proxy.new_socket_connection](#proxynew_socket_connection)
-    * [proxy.socket_connection](#proxysocket_connection)
-    * [proxy.build_http_request](#proxybuild_http_request)
+    * [TCP Sockets](#tcp-sockets)
+        * [proxy.new_socket_connection](#proxynew_socket_connection)
+        * [proxy.socket_connection](#proxysocket_connection)
+    * [Http Client](#http-client)
+        * [proxy.build_http_request](#proxybuild_http_request)
+        * [proxy.build_http_response](#proxybuild_http_response)
+    * [Websocket Client](#websocket-client)
+        * [proxy.WebsocketFrame](#proxywebsocketframe)
+        * [proxy.WebsocketClient](#proxywebsocketclient)
+    * [Embed proxy.py](#embed-proxypy)
 * [Plugin Developer and Contributor Guide](#plugin-developer-and-contributor-guide)
     * [Everything is a plugin](#everything-is-a-plugin)
     * [Internal Architecture](#internal-architecture)
@@ -68,6 +76,29 @@ Table of Contents
 Features
 ========

+- Fast & Scalable
+    - Scales by using all available cores on the system
+    - Threadless execution using coroutines
+    - Made to handle tens of thousands of connections per second
+    ```
+    # On Macbook Pro 2015 / 2.8 GHz Intel Core i7
+    $ hey -n 10000 -c 100 http://localhost:8899/
+
+    Summary:
+      Total:        0.6157 secs
+      Slowest:      0.1049 secs
+      Fastest:      0.0007 secs
+      Average:      0.0055 secs
+      Requests/sec: 16240.5444
+
+      Total data:   800000 bytes
+      Size/request: 80 bytes
+
+    Response time histogram:
+      0.001 [1]     |
+      0.011 [9565]  |■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■
+      0.022 [332]   |■
+    ```
 - Lightweight
     - Distributed as a single file module `~100KB`
     - Uses only `~5-20MB` RAM
@@ -204,6 +235,35 @@ See [plugin_examples.py](https://github.com/abhinavsingh/proxy.py/blob/develop/p

 All the examples below also works with `https` traffic but require additional
 flags and certificate generation. See [TLS Interception](#tls-interception).

+## ShortLinkPlugin
+
+Add support for short links in your favorite browsers / applications.
+
+Start `proxy.py` as:
+
+```
+$ proxy.py \
+    --plugins plugin_examples.ShortLinkPlugin
+```
+
+Now you can speed up your daily browsing experience by visiting your
+favorite websites using single-character domain names :). This works
+across all browsers.
+
+The following short links are enabled by default:
+
+Short Link | Destination URL
+:--------: | :---------------:
+a/         | amazon.com
+i/         | instagram.com
+l/         | linkedin.com
+f/         | facebook.com
+g/         | google.com
+t/         | twitter.com
+w/         | web.whatsapp.com
+y/         | youtube.com
+proxy/     | localhost:8899
+
 ## ModifyPostDataPlugin

 Modifies POST request body before sending request to upstream server.
@@ -599,7 +659,9 @@ $ python
 >>>
 ```

-## proxy.new_socket_connection
+## TCP Sockets
+
+### proxy.new_socket_connection

 Attempts to create an IPv4 connection, then IPv6 and
 finally a dual stack connection to provided address.

 ```
 >>> conn = proxy.new_socket_connection(('httpbin.org', 80))
 >>> conn.close()
 ```

-## proxy.socket_connection
+### proxy.socket_connection

 `socket_connection` is a convenient decorator + context manager
 around `new_socket_connection` which ensures `conn.close` is implicit.

 As a context manager:

 ```
 >>> with proxy.socket_connection(('httpbin.org', 80)) as conn:
 >>>   ... [ use connection ] ...
 ```

 As a decorator:

 ```
 >>> @proxy.socket_connection(('httpbin.org', 80))
 >>> def my_api_call(conn, *args, **kwargs):
 >>>   ... [ use connection ] ...
 ```

-## proxy.build_http_request
+## Http Client

-#### Generate HTTP GET request
+### proxy.build_http_request
+
+##### Generate HTTP GET request

 ```
 >>> proxy.build_http_request(b'GET', b'/')
 b'GET / HTTP/1.1\r\n\r\n'
 >>>
 ```

-#### Generate HTTP GET request with headers
+##### Generate HTTP GET request with headers

 ```
 >>> proxy.build_http_request(b'GET', b'/',
 ...   headers={b'Connection': b'close'})
 b'GET / HTTP/1.1\r\nConnection: close\r\n\r\n'
 >>>
 ```

-#### Generate HTTP POST request with headers and body
+##### Generate HTTP POST request with headers and body

 ```
 >>> import json
 >>> proxy.build_http_request(b'POST', b'/form',
 ...   headers={b'Content-type': b'application/json'},
 ...   body=proxy.bytes_(json.dumps({'email': 'hello@world.com'})))
 b'POST /form HTTP/1.1\r\nContent-type: application/json\r\n\r\n{"email": "hello@world.com"}'
 ```

+### proxy.build_http_response
+
+TODO
+
+## Websocket Client
+
+### proxy.WebsocketFrame
+
+TODO
+
+### proxy.WebsocketClient
+
+TODO
+
+## Embed proxy.py
+
 To start `proxy.py` server from imported `proxy.py` module, simply do:

 ```
@@ -710,14 +790,14 @@
 mechanism. Its responsibility is to establish connection between client
 and upstream [TcpServerConnection](https://github.com/abhinavsingh/proxy.py/blob/b03629fa0df1595eb4995427bc601063be7fdca9/proxy.py#L204-L227)
 and invoke `HttpProxyBasePlugin` lifecycle hooks.

-- `ProtocolHandler` threads are started by [Worker](https://github.com/abhinavsingh/proxy.py/blob/b03629fa0df1595eb4995427bc601063be7fdca9/proxy.py#L424-L472)
+- `ProtocolHandler` threads are started by [Acceptor](https://github.com/abhinavsingh/proxy.py/blob/b03629fa0df1595eb4995427bc601063be7fdca9/proxy.py#L424-L472)
   processes.
-- `--num-workers` `Worker` processes are started by
+- `--num-workers` `Acceptor` processes are started by
   [AcceptorPool](https://github.com/abhinavsingh/proxy.py/blob/b03629fa0df1595eb4995427bc601063be7fdca9/proxy.py#L368-L421)
   on start-up.

-- `AcceptorPool` listens on server socket and pass the handler to `Worker` processes.
+- `AcceptorPool` listens on the server socket and passes the socket handle to `Acceptor` processes.
   Workers are responsible for accepting new client connections and starting
   `ProtocolHandler` thread.
@@ -748,33 +828,23 @@ Example:

 ```
 $ pydoc3 proxy
-Help on module proxy:
-
-NAME
-    proxy
-
-DESCRIPTION
-    proxy.py
-    ~~~~~~~~
-    Lightweight, Programmable, TLS interceptor Proxy for HTTP(S), HTTP2, WebSockets protocols in a single Python file.
-
-    :copyright: (c) 2013-present by Abhinav Singh and contributors.
-    :license: BSD, see LICENSE for more details.

 CLASSES
     abc.ABC(builtins.object)
         HttpProxyBasePlugin
         HttpWebServerBasePlugin
-            DevtoolsFrontendPlugin
+            DevtoolsWebsocketPlugin
             HttpWebServerPacFilePlugin
         ProtocolHandlerPlugin
-            DevtoolsEventGeneratorPlugin
+            DevtoolsProtocolPlugin
             HttpProxyPlugin
             HttpWebServerPlugin
         TcpConnection
             TcpClientConnection
             TcpServerConnection
         WebsocketClient
+        ThreadlessWork
+            ProtocolHandler(threading.Thread, ThreadlessWork)
     builtins.Exception(builtins.BaseException)
         ProtocolException
             HttpRequestRejected
@@ -789,17 +859,20 @@
         WebsocketFrame
     builtins.tuple(builtins.object)
         ChunkParserStates
+        HttpMethods
         HttpParserStates
         HttpParserTypes
         HttpProtocolTypes
+        HttpStatusCodes
         TcpConnectionTypes
         WebsocketOpcodes
     contextlib.ContextDecorator(builtins.object)
         socket_connection
     multiprocessing.context.Process(multiprocessing.process.BaseProcess)
-        Worker
+        Acceptor
+        Threadless
     threading.Thread(builtins.object)
-        ProtocolHandler
+        ProtocolHandler(threading.Thread, ThreadlessWork)
 ```

 Frequently Asked Questions
@@ -905,8 +978,8 @@ usage: proxy.py [-h] [--backlog BACKLOG] [--basic-auth BASIC_AUTH]
                 [--pac-file-url-path PAC_FILE_URL_PATH] [--pid-file PID_FILE]
                 [--plugins PLUGINS] [--port PORT]
                 [--server-recvbuf-size SERVER_RECVBUF_SIZE]
-                [--static-server-dir STATIC_SERVER_DIR] [--timeout TIMEOUT]
-                [--version]
+                [--static-server-dir STATIC_SERVER_DIR] [--threadless]
+                [--timeout TIMEOUT] [--version]

 proxy.py v1.2.0
@@ -991,10 +1064,11 @@ optional arguments:
                         value for faster downloads at the expense of increased
                         RAM.
   --static-server-dir STATIC_SERVER_DIR
-                        Default: /Users/abhinav/Dev/proxy.py/public. Static
-                        server root directory. This option is only applicable
-                        when static server is also enabled. See --enable-
-                        static-server.
+                        Default: "public" folder in directory where proxy.py
+                        is placed. This option is only applicable when static
+                        server is also enabled. See --enable-static-server.
+  --threadless          Default: False. When disabled a new thread is spawned
+                        to handle each client connection.
   --timeout TIMEOUT     Default: 10. Number of seconds after which an inactive
                         connection must be dropped. Inactivity is defined by no
                         data sent or received by the client.
diff --git a/benchmark.py b/benchmark.py
new file mode 100755
index 0000000000..296658ce96
--- /dev/null
+++ b/benchmark.py
@@ -0,0 +1,95 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+    proxy.py
+    ~~~~~~~~
+    ⚡⚡⚡ Fast, Lightweight, Programmable Proxy Server in a single Python file.
+
+    :copyright: (c) 2013-present by Abhinav Singh and contributors.
+    :license: BSD, see LICENSE for more details.
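+
+    Example (assumes a proxy.py web server is already listening on port 8899):
+
+        $ python benchmark.py -n 100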
+""" +import argparse +import asyncio +import sys +from typing import List, Tuple + +import proxy + +DEFAULT_N = 10 + + +def init_parser() -> argparse.ArgumentParser: + """Initializes and returns argument parser.""" + parser = argparse.ArgumentParser( + description='Benchmark opens N concurrent connections ' + 'to proxy.py web server. Currently, HTTP/1.1 ' + 'keep-alive connections are opened. Over each opened ' + 'connection multiple pipelined request / response ' + 'packets are exchanged with proxy.py web server.', + epilog='Proxy.py not working? Report at: %s/issues/new' % proxy.__homepage__ + ) + parser.add_argument( + '--n', '-n', + type=int, + default=DEFAULT_N, + help='Default: ' + str(DEFAULT_N) + '. See description above for meaning of N.' + ) + return parser + + +class Benchmark: + + def __init__(self, n: int = DEFAULT_N) -> None: + self.n = n + self.clients: List[Tuple[asyncio.StreamReader, asyncio.StreamWriter]] = [] + + async def open_connections(self) -> None: + for _ in range(self.n): + self.clients.append(await asyncio.open_connection('::', 8899)) + print('Opened ' + str(self.n) + ' connections') + + def send_requests(self) -> None: + for _, writer in self.clients: + writer.write(proxy.build_http_request( + proxy.httpMethods.GET, b'/' + )) + + async def recv_responses(self) -> None: + for reader, _ in self.clients: + response = proxy.HttpParser(proxy.httpParserTypes.RESPONSE_PARSER) + while response.state != proxy.httpParserStates.COMPLETE: + response.parse(await reader.read(proxy.DEFAULT_BUFFER_SIZE)) + + async def close_connections(self) -> None: + for reader, writer in self.clients: + writer.close() + await writer.wait_closed() + print('Closed ' + str(self.n) + ' connections') + + async def run(self) -> None: + num_completed_requests_per_connection: int = 0 + try: + await self.open_connections() + print('Exchanging request / response packets') + while True: + self.send_requests() + await self.recv_responses() + num_completed_requests_per_connection += 1 + await asyncio.sleep(1) + finally: + await self.close_connections() + print('Exchanged ' + str(num_completed_requests_per_connection) + + ' request / response per connection') + + +def main(input_args: List[str]) -> None: + args = init_parser().parse_args(input_args) + benchmark = Benchmark(n=args.n) + try: + asyncio.run(benchmark.run()) + except KeyboardInterrupt: + pass + + +if __name__ == '__main__': + main(sys.argv[1:]) # pragma: no cover diff --git a/chrome_with_proxy.sh b/chrome_with_proxy.sh new file mode 100755 index 0000000000..66e00cb87d --- /dev/null +++ b/chrome_with_proxy.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +# proxy.py +# ~~~~~~~~ +# ⚡⚡⚡ Fast, Lightweight, Programmable Proxy Server in a single Python file. +# +# :copyright: (c) 2013-present by Abhinav Singh and contributors. +# :license: BSD, see LICENSE for more details. +# +# Usage +# ./chrome_with_proxy + +PROXY_PY_ADDR=$1 +if [ -z "$PROXY_PY_ADDR" ]; then + PROXY_PY_ADDR="localhost:8899" +fi + +/Applications/Google\ Chrome.app/Contents/MacOS/Google\ Chrome \ + --no-first-run \ + --no-default-browser-check \ + --user-data-dir="$(mktemp -d -t 'chrome-remote_data_dir')" \ + --proxy-server=$PROXY_PY_ADDR \ + --ignore-urlfetcher-cert-requests \ + --ignore-certificate-errors diff --git a/monitor_open_files.sh b/monitor_open_files.sh new file mode 100755 index 0000000000..a5001d002b --- /dev/null +++ b/monitor_open_files.sh @@ -0,0 +1,33 @@ +#!/bin/bash + +# proxy.py +# ~~~~~~~~ +# ⚡⚡⚡ Fast, Lightweight, Programmable Proxy Server in a single Python file. 
+#
+# :copyright: (c) 2013-present by Abhinav Singh and contributors.
+# :license: BSD, see LICENSE for more details.
+#
+# Usage
+# ./monitor_open_files.sh <PROXY_PY_PID>
+#
+# Alternately, just run:
+# watch -n 1 'lsof -i TCP:8899 | grep -v LISTEN'
+
+PROXY_PY_PID=$1
+if [ -z "$PROXY_PY_PID" ]; then
+  echo "PROXY_PY_PID required as argument."
+  exit 1
+fi
+
+OPEN_FILES_BY_MAIN=$(lsof -p "$PROXY_PY_PID" | wc -l)
+echo "[$PROXY_PY_PID] Main process: $OPEN_FILES_BY_MAIN"
+
+pgrep -P "$PROXY_PY_PID" | while read -r acceptorPid; do
+  OPEN_FILES_BY_ACCEPTOR=$(lsof -p "$acceptorPid" | wc -l)
+  echo "[$acceptorPid] Acceptor process: $OPEN_FILES_BY_ACCEPTOR"
+
+  pgrep -P "$acceptorPid" | while read -r threadlessPid; do
+    OPEN_FILES_BY_THREADLESS=$(lsof -p "$threadlessPid" | wc -l)
+    echo "  [$threadlessPid] Threadless process: $OPEN_FILES_BY_THREADLESS"
+  done
+done
diff --git a/plugin_examples.py b/plugin_examples.py
index 7ae2d73d31..cdd1f66946 100644
--- a/plugin_examples.py
+++ b/plugin_examples.py
@@ -14,6 +14,72 @@
 from urllib import parse as urlparse

 import proxy
+from proxy import HttpParser
+
+
+class ShortLinkPlugin(proxy.HttpProxyBasePlugin):
+    """Add support for short links in your favorite browsers / applications.
+
+    Enable ShortLinkPlugin and speed up your daily browsing experience.
+
+    Example:
+    * f/ for facebook.com
+    * g/ for google.com
+    * t/ for twitter.com
+    * y/ for youtube.com
+    * proxy/ for proxy.py internal web servers.
+    Customize the map below to your taste and needs.
+
+    Paths are also preserved. E.g. t/imoracle will
+    resolve to http://twitter.com/imoracle.
+    """
+
+    SHORT_LINKS = {
+        b'a': b'amazon.com',
+        b'i': b'instagram.com',
+        b'l': b'linkedin.com',
+        b'f': b'facebook.com',
+        b'g': b'google.com',
+        b't': b'twitter.com',
+        b'w': b'web.whatsapp.com',
+        b'y': b'youtube.com',
+        b'proxy': b'localhost:8899',
+    }
+
+    def before_upstream_connection(self, request: HttpParser) -> Optional[HttpParser]:
+        if request.host and proxy.DOT not in request.host:
+            # Avoid connecting to upstream
+            return None
+        return request
+
+    def handle_client_request(self, request: HttpParser) -> Optional[HttpParser]:
+        if request.host and proxy.DOT not in request.host:
+            if request.host in self.SHORT_LINKS:
+                path = proxy.SLASH if not request.path else request.path
+                self.client.queue(proxy.build_http_response(
+                    proxy.httpStatusCodes.SEE_OTHER, reason=b'See Other',
+                    headers={
+                        b'Location': b'http://' + self.SHORT_LINKS[request.host] + path,
+                        b'Content-Length': b'0',
+                        b'Connection': b'close',
+                    }
+                ))
+            else:
+                self.client.queue(proxy.build_http_response(
+                    proxy.httpStatusCodes.NOT_FOUND, reason=b'NOT FOUND',
+                    headers={
+                        b'Content-Length': b'0',
+                        b'Connection': b'close',
+                    }
+                ))
+            return None
+        return request
+
+    def handle_upstream_chunk(self, chunk: bytes) -> bytes:
+        return chunk
+
+    def on_upstream_connection_close(self) -> None:
+        pass


 class ModifyPostDataPlugin(proxy.HttpProxyBasePlugin):
diff --git a/proxy.py b/proxy.py
index b0e2a05a13..4324505605 100755
--- a/proxy.py
+++ b/proxy.py
@@ -9,9 +9,9 @@
     :license: BSD, see LICENSE for more details.
""" import argparse +import asyncio import base64 import contextlib -import datetime import errno import functools import hashlib @@ -39,7 +39,8 @@ from multiprocessing import connection from multiprocessing.reduction import send_handle, recv_handle from types import TracebackType -from typing import Any, Dict, List, Tuple, Optional, Union, NamedTuple, Callable, TYPE_CHECKING, Type, cast +from typing import Any, Dict, List, Tuple, Optional, Union, NamedTuple, Callable, Type, TypeVar +from typing import cast, Generator, TYPE_CHECKING from urllib import parse as urlparse from typing_extensions import Protocol @@ -90,6 +91,7 @@ DEFAULT_PORT = 8899 DEFAULT_SERVER_RECVBUF_SIZE = DEFAULT_BUFFER_SIZE DEFAULT_STATIC_SERVER_DIR = os.path.join(PROXY_PY_DIR, 'public') +DEFAULT_THREADLESS = False DEFAULT_TIMEOUT = 10 DEFAULT_VERSION = False UNDER_TEST = False # Set to True if under test @@ -122,7 +124,7 @@ def bytes_(s: Any, encoding: str = 'utf-8', errors: str = 'strict') -> Any: version = bytes_(__version__) -CRLF, COLON, WHITESPACE, COMMA, DOT, HTTP_1_1 = b'\r\n', b':', b' ', b',', b'.', b'HTTP/1.1' +CRLF, COLON, WHITESPACE, COMMA, DOT, SLASH, HTTP_1_1 = b'\r\n', b':', b' ', b',', b'.', b'/', b'HTTP/1.1' PROXY_AGENT_HEADER_KEY = b'Proxy-agent' PROXY_AGENT_HEADER_VALUE = b'proxy.py v' + version PROXY_AGENT_HEADER = PROXY_AGENT_HEADER_KEY + \ @@ -147,6 +149,11 @@ def bytes_(s: Any, encoding: str = 'utf-8', errors: str = 'strict') -> Any: ('SWITCHING_PROTOCOLS', int), # 2xx ('OK', int), + # 3xx + ('MOVED_PERMANENTLY', int), + ('SEE_OTHER', int), + ('TEMPORARY_REDIRECT', int), + ('PERMANENT_REDIRECT', int), # 4xx ('BAD_REQUEST', int), ('UNAUTHORIZED', int), @@ -166,6 +173,7 @@ def bytes_(s: Any, encoding: str = 'utf-8', errors: str = 'strict') -> Any: httpStatusCodes = HttpStatusCodes( 100, 101, 200, + 301, 303, 307, 308, 400, 401, 403, 404, 407, 408, 418, 500, 501, 502, 504, 598, 599 ) @@ -325,6 +333,7 @@ def find_http_line(raw: bytes) -> Tuple[Optional[bytes], bytes]: def new_socket_connection(addr: Tuple[str, int]) -> socket.socket: + conn = None try: ip = ipaddress.ip_address(addr[0]) if ip.version == 4: @@ -336,10 +345,13 @@ def new_socket_connection(addr: Tuple[str, int]) -> socket.socket: socket.AF_INET6, socket.SOCK_STREAM, 0) conn.connect((addr[0], addr[1], 0, 0)) except ValueError: - # Not a valid IP address, most likely its a domain name, - # try to establish dual stack IPv4/IPv6 connection. - conn = socket.create_connection(addr) - return conn + pass # does not appear to be an IPv4 or IPv6 address + + if conn is not None: + return conn + + # try to establish dual stack IPv4/IPv6 connection. 
+ return socket.create_connection(addr) class socket_connection(contextlib.ContextDecorator): @@ -405,21 +417,15 @@ def send(self, data: bytes) -> int: return self.connection.send(data) def recv(self, buffer_size: int = DEFAULT_BUFFER_SIZE) -> Optional[bytes]: - try: - data: bytes = self.connection.recv(buffer_size) - if len(data) > 0: - logger.debug( - 'received %d bytes from %s' % - (len(data), self.tag)) - return data - except socket.error as e: - if e.errno == errno.ECONNRESET: - logger.debug('%r' % e) - else: - logger.exception( - 'Exception while receiving from connection %s %r with reason %r' % - (self.tag, self.connection, e)) - return None + """Users must handle socket.error exceptions""" + data: bytes = self.connection.recv(buffer_size) + if len(data) == 0: + return None + logger.debug( + 'received %d bytes from %s' % + (len(data), self.tag)) + # logger.info(data) + return data def close(self) -> bool: if not self.closed: @@ -438,11 +444,11 @@ def queue(self, data: bytes) -> int: return len(data) def flush(self) -> int: + """Users must handle BrokenPipeError exceptions""" if self.buffer_size() == 0: return 0 - if self.closed: - raise BrokenPipeError() sent: int = self.send(self.buffer) + # logger.info(self.buffer[:sent]) self.buffer = self.buffer[sent:] logger.debug('flushed %d bytes to %s' % (sent, self.tag)) return sent @@ -543,6 +549,9 @@ def to_chunks(raw: bytes, chunk_size: int = DEFAULT_BUFFER_SIZE) -> bytes: return CRLF.join(chunks) + CRLF +T = TypeVar('T', bound='HttpParser') + + class HttpParser: """HTTP request/response parser.""" @@ -575,6 +584,18 @@ def __init__(self, parser_type: int) -> None: self.port: Optional[int] = None self.path: Optional[bytes] = None + @classmethod + def request(cls: Type[T], raw: bytes) -> T: + parser = cls(httpParserTypes.REQUEST_PARSER) + parser.parse(raw) + return parser + + @classmethod + def response(cls: Type[T], raw: bytes) -> T: + parser = cls(httpParserTypes.RESPONSE_PARSER) + parser.parse(raw) + return parser + def header(self, key: bytes) -> bytes: if key.lower() not in self.headers: raise KeyError('%s not found in headers', text_(key)) @@ -765,7 +786,10 @@ def __init__(self, hostname: Union[ipaddress.IPv4Address, ipaddress.IPv6Address], port: int, backlog: int, num_workers: int, - work_klass: type, **kwargs: Any) -> None: + threadless: bool, + work_klass: type, + **kwargs: Any) -> None: + self.threadless = threadless self.running: bool = False self.hostname: Union[ipaddress.IPv4Address, @@ -775,11 +799,9 @@ def __init__(self, self.backlog: int = backlog self.socket: Optional[socket.socket] = None - self.current_worker_id = 0 - self.num_workers = num_workers - self.workers: List[Worker] = [] - self.work_queues: List[Tuple[connection.Connection, - connection.Connection]] = [] + self.num_acceptors = num_workers + self.acceptors: List[Acceptor] = [] + self.work_queues: List[connection.Connection] = [] self.work_klass = work_klass self.kwargs = kwargs @@ -790,26 +812,31 @@ def listen(self) -> None: self.socket.bind((str(self.hostname), self.port)) self.socket.listen(self.backlog) self.socket.setblocking(False) - self.socket.settimeout(0) logger.info('Listening on %s:%d' % (self.hostname, self.port)) def start_workers(self) -> None: """Start worker processes.""" - for worker_id in range(self.num_workers): + for _ in range(self.num_acceptors): work_queue = multiprocessing.Pipe() - - worker = Worker(work_queue[1], self.work_klass, **self.kwargs) - worker.daemon = True - worker.start() - - self.workers.append(worker) - 
self.work_queues.append(work_queue)
-        logger.info('Started %d workers' % self.num_workers)
+            acceptor = Acceptor(
+                self.family,
+                self.threadless,
+                work_queue[1],
+                self.work_klass,
+                **self.kwargs
+            )
+            # acceptor.daemon = True
+            acceptor.start()
+            self.acceptors.append(acceptor)
+            self.work_queues.append(work_queue[0])
+        logger.info('Started %d workers' % self.num_acceptors)

     def shutdown(self) -> None:
-        logger.info('Shutting down %d workers' % self.num_workers)
-        for worker in self.workers:
-            worker.join()
+        logger.info('Shutting down %d workers' % self.num_acceptors)
+        for acceptor in self.acceptors:
+            acceptor.join()
+        for work_queue in self.work_queues:
+            work_queue.close()

     def setup(self) -> None:
         """Listen on port, setup workers and pass server socket to workers."""
@@ -817,16 +844,181 @@ def setup(self) -> None:
         self.listen()
         self.start_workers()

-        # Send server socket to workers.
+        # Send server socket to all acceptor processes.
         assert self.socket is not None
-        for work_queue in self.work_queues:
-            work_queue[0].send(self.family)
-            send_handle(work_queue[0], self.socket.fileno(),
-                        self.workers[self.current_worker_id].pid)
+        for index in range(self.num_acceptors):
+            send_handle(
+                self.work_queues[index],
+                self.socket.fileno(),
+                self.acceptors[index].pid
+            )
         self.socket.close()


-class Worker(multiprocessing.Process):
+class ThreadlessWork(ABC):
+    """Implement ThreadlessWork to hook into the event loop provided by Threadless process."""
+
+    @abstractmethod
+    def initialize(self) -> None:
+        pass    # pragma: no cover
+
+    @abstractmethod
+    def is_inactive(self) -> bool:
+        return False    # pragma: no cover
+
+    @abstractmethod
+    def get_events(self) -> Dict[socket.socket, int]:
+        return {}   # pragma: no cover
+
+    @abstractmethod
+    def handle_events(self,
+                      readables: List[Union[int, _HasFileno]],
+                      writables: List[Union[int, _HasFileno]]) -> bool:
+        """Return True to shutdown work."""
+        return False    # pragma: no cover
+
+    @abstractmethod
+    def shutdown(self) -> None:
+        """Must close any opened resources."""
+        pass    # pragma: no cover
+
+
+class Threadless(multiprocessing.Process):
+    """Threadless provides an event loop. Use it by implementing the ThreadlessWork class.
+
+    When --threadless option is enabled, each Acceptor process also
+    spawns one Threadless process. Instead of spawning a new thread
+    for each accepted client connection, the Acceptor process sends the
+    accepted client connection to the Threadless process over a pipe.
+
+    ProtocolHandler implements the ThreadlessWork class and hooks into the
+    event loop provided by Threadless.
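+
+    A minimal wiring sketch (this mirrors what Acceptor.run_once does when
+    --threadless is enabled; `Work` stands for any ThreadlessWork subclass):
+
+        pipe = multiprocessing.Pipe()
+        threadless = Threadless(pipe[1], Work, **kwargs)
+        threadless.start()
+        # Hand off an accepted client connection to the event loop
+        pipe[0].send(addr)
+        send_handle(pipe[0], conn.fileno(), threadless.pid)
+        conn.close()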
+ """ + + def __init__( + self, + client_queue: connection.Connection, + work_klass: type, + **kwargs: Any) -> None: + super().__init__() + self.client_queue = client_queue + self.work_klass = work_klass + self.kwargs = kwargs + + self.works: Dict[int, ThreadlessWork] = {} + self.selector: Optional[selectors.DefaultSelector] = None + self.loop: Optional[asyncio.AbstractEventLoop] = None + + @contextlib.contextmanager + def selected_events(self) -> Generator[Tuple[List[Union[int, _HasFileno]], + List[Union[int, _HasFileno]]], + None, None]: + events: Dict[socket.socket, int] = {} + for work in self.works.values(): + events.update(work.get_events()) + assert self.selector is not None + for fd in events: + self.selector.register(fd, events[fd]) + ev = self.selector.select(timeout=1) + readables = [] + writables = [] + for key, mask in ev: + if mask & selectors.EVENT_READ: + readables.append(key.fileobj) + if mask & selectors.EVENT_WRITE: + writables.append(key.fileobj) + yield (readables, writables) + for fd in events.keys(): + self.selector.unregister(fd) + + async def handle_events( + self, fileno: int, + readables: List[Union[int, _HasFileno]], + writables: List[Union[int, _HasFileno]]) -> bool: + return self.works[fileno].handle_events(readables, writables) + + # TODO: Use correct future typing annotations + async def wait_for_tasks( + self, tasks: Dict[int, Any]) -> None: + for work_id in tasks: + # TODO: Resolving one handle_events here can block resolution of other tasks + try: + teardown = await asyncio.wait_for(tasks[work_id], DEFAULT_TIMEOUT) + if teardown: + self.cleanup(work_id) + except asyncio.TimeoutError: + self.cleanup(work_id) + + def accept_client(self) -> None: + addr = self.client_queue.recv() + fileno = recv_handle(self.client_queue) + self.works[fileno] = self.work_klass( + fileno=fileno, + addr=addr, + **self.kwargs) + try: + self.works[fileno].initialize() + except ssl.SSLError as e: + logger.exception('ssl.SSLError', exc_info=e) + self.cleanup(fileno) + + def cleanup_inactive(self) -> None: + inactive_works: List[int] = [] + for work_id in self.works: + if self.works[work_id].is_inactive(): + inactive_works.append(work_id) + for work_id in inactive_works: + self.cleanup(work_id) + + def cleanup(self, work_id: int) -> None: + # TODO: ProtocolHandler.shutdown can call flush which may block + self.works[work_id].shutdown() + del self.works[work_id] + + def run_once(self) -> None: + assert self.loop is not None + readables: List[Union[int, _HasFileno]] = [] + writables: List[Union[int, _HasFileno]] = [] + with self.selected_events() as (readables, writables): + if len(readables) == 0 and len(writables) == 0: + # Remove and shutdown inactive connections + self.cleanup_inactive() + return + # Note that selector from now on is idle, + # until all the logic below completes. + # + # Invoke Threadless.handle_events + # TODO: Only send readable / writables that client originally registered. 
+ tasks = {} + for fileno in self.works: + tasks[fileno] = self.loop.create_task( + self.handle_events(fileno, readables, writables)) + # Accepted client connection from Acceptor + if self.client_queue in readables: + self.accept_client() + # Wait for Threadless.handle_events to complete + self.loop.run_until_complete(self.wait_for_tasks(tasks)) + # Remove and shutdown inactive connections + self.cleanup_inactive() + + def run(self) -> None: + try: + self.selector = selectors.DefaultSelector() + self.selector.register(self.client_queue, selectors.EVENT_READ) + self.loop = asyncio.get_event_loop() + while True: + self.run_once() + except KeyboardInterrupt: + pass + finally: + assert self.selector is not None + self.selector.unregister(self.client_queue) + self.client_queue.close() + assert self.loop is not None + self.loop.close() + + +class Acceptor(multiprocessing.Process): """Socket client acceptor. Accepts client connection over received server socket handle and @@ -837,46 +1029,98 @@ class Worker(multiprocessing.Process): def __init__( self, + family: socket.AddressFamily, + threadless: bool, work_queue: connection.Connection, work_klass: type, - **kwargs: Any): + **kwargs: Any) -> None: super().__init__() + self.family: socket.AddressFamily = family + self.threadless: bool = threadless self.work_queue: connection.Connection = work_queue self.work_klass = work_klass self.kwargs = kwargs - self.running = True + + self.running = False + self.selector: Optional[selectors.DefaultSelector] = None + self.sock: Optional[socket.socket] = None + self.threadless_process: Optional[multiprocessing.Process] = None + self.threadless_client_queue: Optional[connection.Connection] = None + + def start_threadless_process(self) -> None: + if not self.threadless: + return + pipe = multiprocessing.Pipe() + self.threadless_client_queue = pipe[0] + self.threadless_process = Threadless( + pipe[1], self.work_klass, **self.kwargs + ) + # self.threadless_process.daemon = True + self.threadless_process.start() + + def shutdown_threadless_process(self) -> None: + if not self.threadless: + return + assert self.threadless_process and self.threadless_client_queue + self.threadless_process.join() + self.threadless_client_queue.close() + + def run_once(self) -> None: + assert self.selector + with self.lock: + events = self.selector.select(timeout=1) + if len(events) == 0: + return + try: + assert self.sock + conn, addr = self.sock.accept() + except BlockingIOError: + return + if self.threadless and \ + self.threadless_client_queue and \ + self.threadless_process: + self.threadless_client_queue.send(addr) + send_handle( + self.threadless_client_queue, + conn.fileno(), + self.threadless_process.pid + ) + conn.close() + else: + # Starting a new thread per client request simply means + # we need 1 million threads to handle a million concurrent + # connections. Since most of the client requests are short + # lived (even with keep-alive), starting threads is excessive. 
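+            # This threaded path remains the default; pass --threadless
+            # to opt into the Threadless event loop hand-off above.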
+ work = self.work_klass( + fileno=conn.fileno(), + addr=addr, + **self.kwargs) + # work.setDaemon(True) + work.start() def run(self) -> None: - family = self.work_queue.recv() - sock = socket.fromfd( - recv_handle(self.work_queue), - family=family, + self.running = True + self.selector = selectors.DefaultSelector() + fileno = recv_handle(self.work_queue) + self.sock = socket.fromfd( + fileno, + family=self.family, type=socket.SOCK_STREAM ) - selector = selectors.DefaultSelector() + os.close(fileno) try: + self.selector.register(self.sock, selectors.EVENT_READ) + self.start_threadless_process() while self.running: - with self.lock: - selector.register(sock, selectors.EVENT_READ) - events = selector.select(timeout=1) - selector.unregister(sock) - if len(events) == 0: - continue - try: - conn, addr = sock.accept() - except BlockingIOError: # as e: - # logger.exception('BlockingIOError', exc_info=e) - continue - work = self.work_klass( - fileno=conn.fileno(), - addr=addr, - **self.kwargs) - work.setDaemon(True) - work.start() + self.run_once() except KeyboardInterrupt: pass finally: - sock.close() + self.selector.unregister(self.sock) + self.shutdown_threadless_process() + self.sock.close() + self.work_queue.close() + self.running = False class ProtocolException(Exception): @@ -1001,7 +1245,9 @@ def __init__( enable_static_server: bool = DEFAULT_ENABLE_STATIC_SERVER, devtools_event_queue: Optional[DevtoolsEventQueueType] = None, devtools_ws_path: bytes = DEFAULT_DEVTOOLS_WS_PATH, - timeout: int = DEFAULT_TIMEOUT) -> None: + timeout: int = DEFAULT_TIMEOUT, + threadless: bool = DEFAULT_THREADLESS) -> None: + self.threadless = threadless self.timeout = timeout self.auth_code = auth_code self.server_recvbuf_size = server_recvbuf_size @@ -1205,6 +1451,7 @@ def __init__( self.server: Optional[TcpServerConnection] = None self.response: HttpParser = HttpParser(httpParserTypes.RESPONSE_PARSER) self.pipeline_request: Optional[HttpParser] = None + self.pipeline_response: Optional[HttpParser] = None self.plugins: Dict[str, HttpProxyBasePlugin] = {} if b'HttpProxyBasePlugin' in self.config.plugins: @@ -1234,6 +1481,9 @@ def write_to_descriptors(self, w: List[Union[int, _HasFileno]]) -> bool: logger.debug('Server is write ready, flushing buffer') try: self.server.flush() + except OSError: + logger.error('OSError when flushing buffer to server') + return True except BrokenPipeError: logger.error( 'BrokenPipeError when flushing buffer for server') @@ -1243,8 +1493,23 @@ def write_to_descriptors(self, w: List[Union[int, _HasFileno]]) -> bool: def read_from_descriptors(self, r: List[Union[int, _HasFileno]]) -> bool: if self.request.has_upstream_server( ) and self.server and not self.server.closed and self.server.connection in r: - logger.debug('Server is ready for reads, reading') - raw = self.server.recv(self.config.server_recvbuf_size) + logger.debug('Server is ready for reads, reading...') + raw: Optional[bytes] = None + + try: + raw = self.server.recv(self.config.server_recvbuf_size) + except ssl.SSLWantReadError: # Try again later + # logger.warning('SSLWantReadError encountered while reading from server, will retry ...') + return False + except socket.error as e: + if e.errno == errno.ECONNRESET: + logger.warning('Connection reset by upstream: %r' % e) + else: + logger.exception( + 'Exception while receiving from %s connection %r with reason %r' % + (self.server.tag, self.server.connection, e)) + return True + if not raw: logger.debug('Server closed connection, tearing down...') return True @@ -1255,9 
+1520,19 @@ def read_from_descriptors(self, r: List[Union[int, _HasFileno]]) -> bool:
             # parse incoming response packet
             # only for non-https requests and when
             # tls interception is enabled
-            if self.request.method != httpMethods.CONNECT or \
-                    self.config.tls_interception_enabled():
-                self.response.parse(raw)
+            if self.request.method != httpMethods.CONNECT:
+                # See https://github.com/abhinavsingh/proxy.py/issues/127 for why
+                # currently response parsing is disabled when TLS interception is enabled.
+                #
+                # or self.config.tls_interception_enabled():
+                if self.response.state == httpParserStates.COMPLETE:
+                    if self.pipeline_response is None:
+                        self.pipeline_response = HttpParser(httpParserTypes.RESPONSE_PARSER)
+                    self.pipeline_response.parse(raw)
+                    if self.pipeline_response.state == httpParserStates.COMPLETE:
+                        self.pipeline_response = None
+                else:
+                    self.response.parse(raw)
             else:
                 self.response.total_size += len(raw)
             # queue raw data for client
@@ -1293,12 +1568,30 @@ def access_log(self) -> None:
     def on_client_connection_close(self) -> None:
         if not self.request.has_upstream_server():
             return
+        self.access_log()
+
+        # If server was never initialized, return
+        if self.server is None:
+            return
+
+        # Note that the server instance was initialized,
+        # but the connection object may not exist yet.
         # Invoke plugin.on_upstream_connection_close
-        if self.server and not self.server.closed:
-            for plugin in self.plugins.values():
-                plugin.on_upstream_connection_close()
-            self.server.close()
+        for plugin in self.plugins.values():
+            plugin.on_upstream_connection_close()
+
+        try:
+            try:
+                self.server.connection.shutdown(socket.SHUT_WR)
+            except OSError:
+                pass
+            finally:
+                # TODO: Unwrap if wrapped before close?
+                self.server.connection.close()
+        except TcpConnectionUninitializedException:
+            pass
+        finally:
             logger.debug(
                 'Closed server connection with pending server buffer size %d bytes' %
                 self.server.buffer_size())
@@ -1319,9 +1612,6 @@ def on_client_data(self, raw: bytes) -> Optional[bytes]:
             return raw

         if self.server and not self.server.closed:
-            # If 1st request did reach completion stage
-            # and 1st request was not a CONNECT request
-            # or if TLS interception was enabled
             if self.request.state == httpParserStates.COMPLETE and (
                     self.request.method != httpMethods.CONNECT or
                     self.config.tls_interception_enabled()):
@@ -1377,6 +1667,34 @@ def generate_upstream_certificate(self, _certificate: Optional[Dict[str, Any]])
         sign_cert.communicate(timeout=10)
         return cert_file_path

+    def wrap_server(self) -> None:
+        assert self.server is not None
+        assert isinstance(self.server.connection, socket.socket)
+        ctx = ssl.create_default_context(
+            ssl.Purpose.SERVER_AUTH)
+        ctx.options |= ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1
+        self.server.connection.setblocking(True)
+        self.server._conn = ctx.wrap_socket(
+            self.server.connection,
+            server_hostname=text_(self.request.host))
+        self.server.connection.setblocking(False)
+
+    def wrap_client(self) -> None:
+        assert self.server is not None
+        assert isinstance(self.server.connection, ssl.SSLSocket)
+        generated_cert = self.generate_upstream_certificate(
+            cast(Dict[str, Any], self.server.connection.getpeercert()))
+        self.client.connection.setblocking(True)
+        self.client.flush()
+        self.client._conn = ssl.wrap_socket(
+            self.client.connection,
+            server_side=True,
+            keyfile=self.config.ca_signing_key_file,
+            certfile=generated_cert)
+        self.client.connection.setblocking(False)
+        logger.debug(
+            'TLS interception using %s', generated_cert)
+
     def on_request_complete(self) -> Union[socket.socket, bool]:
         if not self.request.has_upstream_server():
             return False
@@ -1409,31 +1727,20 @@ def on_request_complete(self) -> Union[socket.socket, bool]:
                 HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT)
             # If interception is enabled
             if self.config.tls_interception_enabled():
-                assert self.server is not None
-                assert isinstance(self.server.connection, socket.socket)
                 # Perform SSL/TLS handshake with upstream
-                ctx = ssl.create_default_context(
-                    ssl.Purpose.SERVER_AUTH)
-                ctx.options |= ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1
-                self.server.connection.setblocking(True)
-                self.server._conn = ctx.wrap_socket(
-                    self.server.connection,
-                    server_hostname=text_(self.request.host))
-                self.server.connection.setblocking(False)
-                assert isinstance(self.server.connection, ssl.SSLSocket)
+                self.wrap_server()
                 # Generate certificate and perform handshake with client
-                generated_cert = self.generate_upstream_certificate(
-                    cast(Dict[str, Any], self.server.connection.getpeercert()))
-                self.client.flush()
-                self.client.connection.setblocking(True)
-                self.client._conn = ssl.wrap_socket(
-                    self.client.connection,
-                    server_side=True,
-                    keyfile=self.config.ca_signing_key_file,
-                    certfile=generated_cert)
-                self.client.connection.setblocking(False)
-                logger.info(
-                    'TLS interception using %s', generated_cert)
+                try:
+                    # wrap_client also flushes client data before wrapping;
+                    # sending to the client can raise, handle expected exceptions
+                    self.wrap_client()
+                except OSError:
+                    logger.error('OSError when wrapping client')
+                    return True
+                except BrokenPipeError:
+                    logger.error(
+                        'BrokenPipeError when wrapping client')
+                    return True
                 # Update all plugin connection reference
                 for plugin in self.plugins.values():
                     plugin.client._conn = self.client.connection
@@ -1688,7 +1995,7 @@ def run(self) -> None:
         finally:
             try:
                 self.selector.unregister(self.sock)
-                self.sock.shutdown(socket.SHUT_RDWR)
+                self.sock.shutdown(socket.SHUT_WR)
             except Exception as e:
                 logging.exception('Exception while shutdown of websocket client', exc_info=e)
             self.sock.close()
@@ -1755,7 +2062,7 @@ def start_dispatcher(self) -> None:
             args=(self.event_dispatcher_shutdown,
                   self.config.devtools_event_queue,
                   self.client))
-        self.event_dispatcher_thread.setDaemon(True)
+        # self.event_dispatcher_thread.setDaemon(True)
         self.event_dispatcher_thread.start()

     def stop_dispatcher(self) -> None:
@@ -1881,8 +2188,10 @@ def cache_pac_file_response(self) -> None:

     def routes(self) -> List[Tuple[int, bytes]]:
         if self.config.pac_file_url_path:
-            return [(httpProtocolTypes.HTTP, bytes_(
-                self.config.pac_file_url_path))]
+            return [
+                (httpProtocolTypes.HTTP, bytes_(self.config.pac_file_url_path)),
+                (httpProtocolTypes.HTTPS, bytes_(self.config.pac_file_url_path)),
+            ]
         return []   # pragma: no cover

     def handle_request(self, request: HttpParser) -> None:
@@ -1939,6 +2248,11 @@ def __init__(
             self.routes[protocol][path] = instance

     def serve_file_or_404(self, path: str) -> bool:
+        """Reads and serves a file from disk.
+
+        Queues 404 Not Found for IOError.
+        Shouldn't this be a server error?
+ """ try: with open(path, 'rb') as f: content = f.read() @@ -2030,15 +2344,17 @@ def on_client_data(self, raw: bytes) -> Optional[bytes]: frame.reset() return None # If 1st valid request was completed and it's a HTTP/1.1 keep-alive + # And only if we have a route, parse pipeline requests elif self.request.state == httpParserStates.COMPLETE and \ - self.request.is_http_1_1_keep_alive(): + self.request.is_http_1_1_keep_alive() and \ + self.route is not None: if self.pipeline_request is None: self.pipeline_request = HttpParser(httpParserTypes.REQUEST_PARSER) self.pipeline_request.parse(raw) if self.pipeline_request.state == httpParserStates.COMPLETE: - assert self.route is not None self.route.handle_request(self.pipeline_request) if not self.pipeline_request.is_http_1_1_keep_alive(): + logger.error('Pipelined request is not keep-alive, will teardown request...') raise ProtocolException() self.pipeline_request = None return raw @@ -2058,18 +2374,19 @@ def on_client_connection_close(self) -> None: def access_log(self) -> None: logger.info( - '%s:%s - %s %s' % + '%s:%s - %s %s - %.2f ms' % (self.client.addr[0], self.client.addr[1], text_(self.request.method), - text_(self.request.path))) + text_(self.request.path), + (time.time() - self.start_time) * 1000)) def get_descriptors( self) -> Tuple[List[socket.socket], List[socket.socket]]: return [], [] -class ProtocolHandler(threading.Thread): +class ProtocolHandler(threading.Thread, ThreadlessWork): """HTTP, HTTPS, HTTP2, WebSockets protocol handler. Accepts `Client` connection object and manages ProtocolHandlerPlugin invocations. @@ -2081,8 +2398,8 @@ def __init__(self, fileno: int, addr: Tuple[str, int], self.fileno: int = fileno self.addr: Tuple[str, int] = addr - self.start_time: datetime.datetime = self.now() - self.last_activity: datetime.datetime = self.start_time + self.start_time: float = time.time() + self.last_activity: float = self.start_time self.config: ProtocolConfig = config if config else ProtocolConfig() self.request: HttpParser = HttpParser(httpParserTypes.REQUEST_PARSER) @@ -2094,10 +2411,6 @@ def __init__(self, fileno: int, addr: Tuple[str, int], ) self.plugins: Dict[str, ProtocolHandlerPlugin] = {} - @staticmethod - def now() -> datetime.datetime: - return datetime.datetime.utcnow() - def initialize(self) -> None: """Optionally upgrades connection to HTTPS, set conn in non-blocking mode and initializes plugins.""" conn = self.optionally_wrap_socket(self.client.connection) @@ -2108,11 +2421,99 @@ def initialize(self) -> None: for klass in self.config.plugins[b'ProtocolHandlerPlugin']: instance = klass(self.config, self.client, self.request) self.plugins[instance.name()] = instance + logger.debug('Handling connection %r' % self.client.connection) + + def is_inactive(self) -> bool: + if not self.client.has_buffer() and \ + self.connection_inactive_for() > self.config.timeout: + return True + return False + + def get_events(self) -> Dict[socket.socket, int]: + events: Dict[socket.socket, int] = { + self.client.connection: selectors.EVENT_READ + } + if self.client.has_buffer(): + events[self.client.connection] |= selectors.EVENT_WRITE + + # ProtocolHandlerPlugin.get_descriptors + for plugin in self.plugins.values(): + plugin_read_desc, plugin_write_desc = plugin.get_descriptors() + for r in plugin_read_desc: + if r not in events: + events[r] = selectors.EVENT_READ + else: + events[r] |= selectors.EVENT_READ + for w in plugin_write_desc: + if w not in events: + events[w] = selectors.EVENT_WRITE + else: + events[w] |= 
selectors.EVENT_WRITE + + return events + + def handle_events( + self, + readables: List[Union[int, _HasFileno]], + writables: List[Union[int, _HasFileno]]) -> bool: + """Returns True if proxy must teardown.""" + # Flush buffer for ready to write sockets + teardown = self.handle_writables(writables) + if teardown: + return True + + # Invoke plugin.write_to_descriptors + for plugin in self.plugins.values(): + teardown = plugin.write_to_descriptors(writables) + if teardown: + return True + + # Read from ready to read sockets + teardown = self.handle_readables(readables) + if teardown: + return True + + # Invoke plugin.read_from_descriptors + for plugin in self.plugins.values(): + teardown = plugin.read_from_descriptors(readables) + if teardown: + return True + + return False + + def shutdown(self) -> None: + # Flush pending buffer if any + self.flush() + + # Invoke plugin.on_client_connection_close + for plugin in self.plugins.values(): + plugin.on_client_connection_close() + + logger.debug( + 'Closing client connection %r ' + 'at address %r with pending client buffer size %d bytes' % + (self.client.connection, self.client.addr, self.client.buffer_size())) + + conn = self.client.connection + try: + # Unwrap if wrapped before shutdown. + if self.config.encryption_enabled() and \ + isinstance(self.client.connection, ssl.SSLSocket): + conn = self.client.connection.unwrap() + conn.shutdown(socket.SHUT_WR) + logger.debug('Client connection shutdown successful') + except OSError: + pass + finally: + conn.close() + logger.debug('Client connection closed') def fromfd(self, fileno: int) -> socket.socket: - return socket.fromfd( + conn = socket.fromfd( fileno, family=socket.AF_INET if self.config.hostname.version == 4 else socket.AF_INET6, type=socket.SOCK_STREAM) + os.close(fileno) + return conn def optionally_wrap_socket( self, conn: socket.socket) -> Union[ssl.SSLSocket, socket.socket]: @@ -2132,16 +2533,28 @@ def optionally_wrap_socket( conn = ctx.wrap_socket(conn, server_side=True) return conn - def connection_inactive_for(self) -> int: - return (self.now() - self.last_activity).seconds + def connection_inactive_for(self) -> float: + return time.time() - self.last_activity - def is_connection_inactive(self) -> bool: - return self.connection_inactive_for() > self.config.timeout + def flush(self) -> None: + if not self.client.has_buffer(): + return + try: + self.selector.register(self.client.connection, selectors.EVENT_WRITE) + while self.client.has_buffer(): + ev: List[Tuple[selectors.SelectorKey, int]] = self.selector.select(timeout=1) + if len(ev) == 0: + continue + self.client.flush() + except BrokenPipeError: + pass + finally: + self.selector.unregister(self.client.connection) def handle_writables(self, writables: List[Union[int, _HasFileno]]) -> bool: if self.client.buffer_size() > 0 and self.client.connection in writables: logger.debug('Client is ready for writes, flushing buffer') - self.last_activity = self.now() + self.last_activity = time.time() # Invoke plugin.on_response_chunk chunk = self.client.buffer @@ -2152,6 +2565,9 @@ def handle_writables(self, writables: List[Union[int, _HasFileno]]) -> bool: try: self.client.flush() + except OSError: + logger.error('OSError when flushing buffer to client') + return True except BrokenPipeError: logger.error( 'BrokenPipeError when flushing buffer for client') @@ -2161,9 +2577,23 @@ def handle_writables(self, writables: List[Union[int, _HasFileno]]) -> bool: def handle_readables(self, readables: List[Union[int, _HasFileno]]) -> bool: if 
self.client.connection in readables: logger.debug('Client is ready for reads, reading') - self.last_activity = self.now() + self.last_activity = time.time() + client_data: Optional[bytes] = None + + try: + client_data = self.client.recv(self.config.client_recvbuf_size) + except ssl.SSLWantReadError: # Try again later + logger.warning('SSLWantReadError encountered while reading from client, will retry ...') + return False + except socket.error as e: + if e.errno == errno.ECONNRESET: + logger.warning('%r' % e) + else: + logger.exception( + 'Exception while receiving from %s connection %r with reason %r' % + (self.client.tag, self.client.connection, e)) + return True - client_data = self.client.recv(self.config.client_recvbuf_size) if not client_data: logger.debug('Client closed connection, tearing down...') self.client.closed = True @@ -2198,8 +2628,6 @@ def handle_readables(self, readables: List[Union[int, _HasFileno]]) -> bool: for plugin_ in self.plugins.values(): if plugin_ != plugin: plugin_.client._conn = upgraded_sock - logger.debug( - 'Upgraded client conn for plugin %s', str(plugin_)) elif isinstance(upgraded_sock, bool) and upgraded_sock is True: return True except ProtocolException as e: @@ -2211,116 +2639,43 @@ def handle_readables(self, readables: List[Union[int, _HasFileno]]) -> bool: return True return False - def get_events(self) -> Dict[socket.socket, int]: - events: Dict[socket.socket, int] = { - self.client.connection: selectors.EVENT_READ - } - if self.client.has_buffer(): - events[self.client.connection] |= selectors.EVENT_WRITE - - # ProtocolHandlerPlugin.get_descriptors - for plugin in self.plugins.values(): - plugin_read_desc, plugin_write_desc = plugin.get_descriptors() - for r in plugin_read_desc: - if r not in events: - events[r] = selectors.EVENT_READ - else: - events[r] |= selectors.EVENT_READ - for w in plugin_write_desc: - if w not in events: - events[w] = selectors.EVENT_WRITE - else: - events[w] |= selectors.EVENT_WRITE - - return events - - def handle_events(self, readables: List[Union[int, _HasFileno]], writables: List[Union[int, _HasFileno]]) -> bool: - """Returns True if proxy must teardown.""" - # Flush buffer for ready to write sockets - teardown = self.handle_writables(writables) - if teardown: - return True - - # Invoke plugin.write_to_descriptors - for plugin in self.plugins.values(): - teardown = plugin.write_to_descriptors(writables) - if teardown: - return True - - # Read from ready to read sockets - teardown = self.handle_readables(readables) - if teardown: - return True - - # Invoke plugin.read_from_descriptors - for plugin in self.plugins.values(): - teardown = plugin.read_from_descriptors(readables) - if teardown: - return True - - # Teardown if client buffer is empty and connection is inactive - if not self.client.has_buffer() and \ - self.is_connection_inactive(): - self.client.queue(build_http_response( - httpStatusCodes.REQUEST_TIMEOUT, reason=b'Request Timeout', - headers={ - b'Server': PROXY_AGENT_HEADER_VALUE, - b'Connection': b'close', - } - )) - logger.debug( - 'Client buffer is empty and maximum inactivity has reached ' - 'between client and server connection, tearing down...') - return True - - return False - - def run_once(self) -> bool: + @contextlib.contextmanager + def selected_events(self) -> \ + Generator[Tuple[List[Union[int, _HasFileno]], + List[Union[int, _HasFileno]]], + None, None]: events = self.get_events() for fd in events: self.selector.register(fd, events[fd]) - - # Select - e: List[Tuple[selectors.SelectorKey, int]] 
= self.selector.select(timeout=1) + ev = self.selector.select(timeout=1) readables = [] writables = [] - for key, mask in e: + for key, mask in ev: if mask & selectors.EVENT_READ: readables.append(key.fileobj) if mask & selectors.EVENT_WRITE: writables.append(key.fileobj) - - teardown = self.handle_events(readables, writables) - - # Unregister + yield (readables, writables) for fd in events.keys(): self.selector.unregister(fd) - if teardown: - return True - return False - - def flush(self) -> None: - if not self.client.has_buffer(): - return - try: - self.selector.register(self.client.connection, selectors.EVENT_WRITE) - while self.client.has_buffer(): - ev: List[Tuple[selectors.SelectorKey, int]] = self.selector.select(timeout=1) - if len(ev) == 0: - continue - self.client.flush() - except BrokenPipeError: - pass - finally: - self.selector.unregister(self.client.connection) + def run_once(self) -> bool: + with self.selected_events() as (readables, writables): + teardown = self.handle_events(readables, writables) + if teardown: + return True + return False def run(self) -> None: try: self.initialize() - logger.debug('Handling connection %r' % self.client.connection) - while True: + # Teardown if client buffer is empty and connection is inactive + if self.is_inactive(): + logger.debug( + 'Client buffer is empty and maximum inactivity has reached ' + 'between client and server connection, tearing down...') + break teardown = self.run_once() if teardown: break @@ -2333,27 +2688,7 @@ def run(self) -> None: 'Exception while handling connection %r' % self.client.connection, exc_info=e) finally: - # Flush pending buffer if any - self.flush() - - # Invoke plugin.on_client_connection_close - for plugin in self.plugins.values(): - plugin.on_client_connection_close() - - logger.debug( - 'Closing proxy for connection %r ' - 'at address %r with pending client buffer size %d bytes' % - (self.client.connection, self.client.addr, self.client.buffer_size())) - - if not self.client.closed: - try: - self.client.connection.shutdown(socket.SHUT_WR) - logger.debug('Client connection shutdown successful') - except OSError: - pass - finally: - self.client.connection.close() - logger.debug('Client connection closed') + self.shutdown() class DevtoolsProtocolPlugin(ProtocolHandlerPlugin): @@ -2755,10 +3090,17 @@ def init_parser() -> argparse.ArgumentParser: '--static-server-dir', type=str, default=DEFAULT_STATIC_SERVER_DIR, - help='Default: ' + DEFAULT_STATIC_SERVER_DIR + '. Static server root directory. ' + help='Default: "public" folder in directory where proxy.py is placed. ' 'This option is only applicable when static server is also enabled. ' 'See --enable-static-server.' ) + parser.add_argument( + '--threadless', + action='store_true', + default=DEFAULT_THREADLESS, + help='Default: False. When disabled a new thread is spawned ' + 'to handle each client connection.' 
+    )
     parser.add_argument(
         '--timeout',
         type=int,
@@ -2796,7 +3138,8 @@ def main(input_args: List[str]) -> None:

     if (args.cert_file and args.key_file) and \
             (args.ca_key_file and args.ca_cert_file and args.ca_signing_key_file):
-        print('HTTPS interception not supported when proxy.py is serving over HTTPS')
+        print('You can either enable end-to-end encryption OR TLS interception, '
+              'not both together.')
         sys.exit(0)

     try:
@@ -2848,7 +3191,8 @@ def main(input_args: List[str]) -> None:
             enable_static_server=args.enable_static_server,
             devtools_event_queue=devtools_event_queue,
             devtools_ws_path=args.devtools_ws_path,
-            timeout=args.timeout)
+            timeout=args.timeout,
+            threadless=args.threadless)

         config.plugins = load_plugins(
             bytes_(
@@ -2860,6 +3204,7 @@ def main(input_args: List[str]) -> None:
             port=config.port,
             backlog=config.backlog,
             num_workers=config.num_workers,
+            threadless=config.threadless,
             work_klass=ProtocolHandler,
             config=config)
         if args.pid_file:
diff --git a/tests.py b/tests.py
index 8111f0f718..dddcd2a2f9 100644
--- a/tests.py
+++ b/tests.py
@@ -8,7 +8,6 @@
     :license: BSD, see LICENSE for more details.
 """
 import base64
-import errno
 import ipaddress
 import json
 import logging
@@ -99,13 +98,6 @@ def connection(self) -> Union[ssl.SSLSocket, socket.socket]:
             raise proxy.TcpConnectionUninitializedException()
         return self._conn

-    def testFlushThrowsBrokenPipeIfClosed(self) -> None:
-        self.conn = TestTcpConnection.TcpConnectionToTest()
-        self.conn.queue(b'some data')
-        self.conn.closed = True
-        with self.assertRaises(BrokenPipeError):
-            self.conn.flush()
-
     def testThrowsKeyErrorIfNoConn(self) -> None:
         self.conn = TestTcpConnection.TcpConnectionToTest()
         with self.assertRaises(proxy.TcpConnectionUninitializedException):
@@ -115,28 +107,6 @@ def testThrowsKeyErrorIfNoConn(self) -> None:
         with self.assertRaises(proxy.TcpConnectionUninitializedException):
             self.conn.close()

-    def testHandlesIOError(self) -> None:
-        _conn = mock.MagicMock()
-        _conn.recv.side_effect = IOError()
-        self.conn = TestTcpConnection.TcpConnectionToTest(_conn)
-        with mock.patch('proxy.logger') as mock_logger:
-            self.conn.recv()
-            mock_logger.exception.assert_called()
-            logging.info(mock_logger.exception.call_args[0][0].startswith(
-                'Exception while receiving from connection'))
-
-    def testHandlesConnReset(self) -> None:
-        _conn = mock.MagicMock()
-        e = IOError()
-        e.errno = errno.ECONNRESET
-        _conn.recv.side_effect = e
-        self.conn = TestTcpConnection.TcpConnectionToTest(_conn)
-        with mock.patch('proxy.logger') as mock_logger:
-            self.conn.recv()
-            mock_logger.exception.assert_not_called()
-            mock_logger.debug.assert_called()
-            self.assertEqual(mock_logger.debug.call_args[0][0], '%r' % e)
-
     def testClosesIfNotClosed(self) -> None:
         _conn = mock.MagicMock()
         self.conn = TestTcpConnection.TcpConnectionToTest(_conn)
@@ -258,7 +228,7 @@ class TestAcceptorPool(unittest.TestCase):
     @mock.patch('proxy.send_handle')
     @mock.patch('multiprocessing.Pipe')
     @mock.patch('socket.socket')
-    @mock.patch('proxy.Worker')
+    @mock.patch('proxy.Acceptor')
     def test_setup_and_shutdown(
             self,
             mock_worker: mock.Mock,
@@ -278,6 +248,7 @@ def test_setup_and_shutdown(
             proxy.DEFAULT_PORT,
             proxy.DEFAULT_BACKLOG,
             num_workers,
+            threadless=proxy.DEFAULT_THREADLESS,
             work_klass=work_klass,
             **kwargs
         )
@@ -292,7 +263,6 @@ def test_setup_and_shutdown(
         sock.bind.assert_called_with((str(acceptor.hostname), acceptor.port))
         sock.listen.assert_called_with(acceptor.backlog)
         sock.setblocking.assert_called_with(False)
-        sock.settimeout.assert_called_with(0)
     parser.add_argument(
         '--timeout',
         type=int,
@@ -2796,7 +3138,8 @@ def main(input_args: List[str]) -> None:

     if (args.cert_file and args.key_file) and \
             (args.ca_key_file and args.ca_cert_file and args.ca_signing_key_file):
-        print('HTTPS interception not supported when proxy.py is serving over HTTPS')
+        print('You can either enable end-to-end encryption OR TLS interception, '
+              'not both together.')
         sys.exit(0)

     try:
@@ -2848,7 +3191,8 @@ def main(input_args: List[str]) -> None:
             enable_static_server=args.enable_static_server,
             devtools_event_queue=devtools_event_queue,
             devtools_ws_path=args.devtools_ws_path,
-            timeout=args.timeout)
+            timeout=args.timeout,
+            threadless=args.threadless)

         config.plugins = load_plugins(
             bytes_(
@@ -2860,6 +3204,7 @@ def main(input_args: List[str]) -> None:
             port=config.port,
             backlog=config.backlog,
             num_workers=config.num_workers,
+            threadless=config.threadless,
             work_klass=ProtocolHandler,
             config=config)
         if args.pid_file:
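Reviewer note: embedders who wire up `AcceptorPool` themselves instead of going through `main()` now have to forward the new keyword the same way. A hedged sketch mirroring the call above; the `try/finally` shape and `threadless=True` value are illustrative, not prescribed by this diff:

```
import proxy

config = proxy.ProtocolConfig(threadless=True)   # new kwarg in this diff
pool = proxy.AcceptorPool(
    hostname=config.hostname,
    port=config.port,
    backlog=config.backlog,
    num_workers=config.num_workers,
    threadless=config.threadless,                # forwarded to each Acceptor
    work_klass=proxy.ProtocolHandler,
    config=config)
try:
    pool.setup()        # bind, listen, fork acceptor processes
    # ... serve until interrupted ...
finally:
    pool.shutdown()
```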
diff --git a/tests.py b/tests.py
index 8111f0f718..dddcd2a2f9 100644
--- a/tests.py
+++ b/tests.py
@@ -8,7 +8,6 @@
     :license: BSD, see LICENSE for more details.
 """
 import base64
-import errno
 import ipaddress
 import json
 import logging
@@ -99,13 +98,6 @@ def connection(self) -> Union[ssl.SSLSocket, socket.socket]:
             raise proxy.TcpConnectionUninitializedException()
         return self._conn

-    def testFlushThrowsBrokenPipeIfClosed(self) -> None:
-        self.conn = TestTcpConnection.TcpConnectionToTest()
-        self.conn.queue(b'some data')
-        self.conn.closed = True
-        with self.assertRaises(BrokenPipeError):
-            self.conn.flush()
-
     def testThrowsKeyErrorIfNoConn(self) -> None:
         self.conn = TestTcpConnection.TcpConnectionToTest()
         with self.assertRaises(proxy.TcpConnectionUninitializedException):
@@ -115,28 +107,6 @@ def testThrowsKeyErrorIfNoConn(self) -> None:
         with self.assertRaises(proxy.TcpConnectionUninitializedException):
             self.conn.close()

-    def testHandlesIOError(self) -> None:
-        _conn = mock.MagicMock()
-        _conn.recv.side_effect = IOError()
-        self.conn = TestTcpConnection.TcpConnectionToTest(_conn)
-        with mock.patch('proxy.logger') as mock_logger:
-            self.conn.recv()
-            mock_logger.exception.assert_called()
-            logging.info(mock_logger.exception.call_args[0][0].startswith(
-                'Exception while receiving from connection'))
-
-    def testHandlesConnReset(self) -> None:
-        _conn = mock.MagicMock()
-        e = IOError()
-        e.errno = errno.ECONNRESET
-        _conn.recv.side_effect = e
-        self.conn = TestTcpConnection.TcpConnectionToTest(_conn)
-        with mock.patch('proxy.logger') as mock_logger:
-            self.conn.recv()
-            mock_logger.exception.assert_not_called()
-            mock_logger.debug.assert_called()
-            self.assertEqual(mock_logger.debug.call_args[0][0], '%r' % e)
-
     def testClosesIfNotClosed(self) -> None:
         _conn = mock.MagicMock()
         self.conn = TestTcpConnection.TcpConnectionToTest(_conn)
@@ -258,7 +228,7 @@ class TestAcceptorPool(unittest.TestCase):
     @mock.patch('proxy.send_handle')
     @mock.patch('multiprocessing.Pipe')
     @mock.patch('socket.socket')
-    @mock.patch('proxy.Worker')
+    @mock.patch('proxy.Acceptor')
     def test_setup_and_shutdown(
             self,
             mock_worker: mock.Mock,
@@ -278,6 +248,7 @@ def test_setup_and_shutdown(
             proxy.DEFAULT_PORT,
             proxy.DEFAULT_BACKLOG,
             num_workers,
+            threadless=proxy.DEFAULT_THREADLESS,
             work_klass=work_klass,
             **kwargs
         )
@@ -292,7 +263,6 @@ def test_setup_and_shutdown(
         sock.bind.assert_called_with((str(acceptor.hostname), acceptor.port))
         sock.listen.assert_called_with(acceptor.backlog)
         sock.setblocking.assert_called_with(False)
-        sock.settimeout.assert_called_with(0)
         self.assertTrue(mock_pipe.call_count, num_workers)
         self.assertTrue(mock_worker.call_count, num_workers)
@@ -311,15 +281,20 @@ def test_setup_and_shutdown(

 class TestWorker(unittest.TestCase):

     @mock.patch('proxy.ProtocolHandler')
-    def setUp(self, mock_protocol_handler: mock.Mock) -> None:
+    def setUp(
+            self,
+            mock_protocol_handler: mock.Mock) -> None:
+        self.mock_protocol_handler = mock_protocol_handler
         self.pipe = multiprocessing.Pipe()
         self.protocol_config = proxy.ProtocolConfig()
-        self.worker = proxy.Worker(
+        self.worker = proxy.Acceptor(
+            socket.AF_INET6,
+            proxy.DEFAULT_THREADLESS,
             self.pipe[1],
             mock_protocol_handler,
             config=self.protocol_config)
-        self.mock_protocol_handler = mock_protocol_handler

+    @mock.patch('os.close')
     @mock.patch('selectors.DefaultSelector')
     @mock.patch('socket.fromfd')
     @mock.patch('proxy.recv_handle')
@@ -327,7 +302,8 @@ def test_continues_when_no_events(
             self,
             mock_recv_handle: mock.Mock,
             mock_fromfd: mock.Mock,
-            mock_selector: mock.Mock) -> None:
+            mock_selector: mock.Mock,
+            mock_os_close: mock.Mock) -> None:
         fileno = 10
         conn = mock.MagicMock()
         addr = mock.MagicMock()
@@ -338,12 +314,12 @@ def test_continues_when_no_events(
         selector = mock_selector.return_value
         selector.select.side_effect = [[], KeyboardInterrupt()]

-        self.pipe[0].send(socket.AF_INET6)
         self.worker.run()

         sock.accept.assert_not_called()
         self.mock_protocol_handler.assert_not_called()

+    @mock.patch('os.close')
     @mock.patch('selectors.DefaultSelector')
     @mock.patch('socket.fromfd')
     @mock.patch('proxy.recv_handle')
@@ -351,7 +327,8 @@ def test_worker_doesnt_teardown_on_blocking_io_error(
             self,
             mock_recv_handle: mock.Mock,
             mock_fromfd: mock.Mock,
-            mock_selector: mock.Mock) -> None:
+            mock_selector: mock.Mock,
+            mock_os_close: mock.Mock) -> None:
         fileno = 10
         conn = mock.MagicMock()
         addr = mock.MagicMock()
@@ -363,11 +340,11 @@ def test_worker_doesnt_teardown_on_blocking_io_error(
         selector.select.side_effect = [(None, None), KeyboardInterrupt()]
         sock.accept.side_effect = BlockingIOError()

-        self.pipe[0].send(socket.AF_INET6)
         self.worker.run()

         self.mock_protocol_handler.assert_not_called()

+    @mock.patch('os.close')
     @mock.patch('selectors.DefaultSelector')
     @mock.patch('socket.fromfd')
     @mock.patch('proxy.recv_handle')
@@ -375,7 +352,8 @@ def test_accepts_client_from_server_socket(
             self,
             mock_recv_handle: mock.Mock,
             mock_fromfd: mock.Mock,
-            mock_selector: mock.Mock) -> None:
+            mock_selector: mock.Mock,
+            mock_os_close: mock.Mock) -> None:
         fileno = 10
         conn = mock.MagicMock()
         addr = mock.MagicMock()
@@ -388,7 +366,6 @@ def test_accepts_client_from_server_socket(
         selector = mock_selector.return_value
         selector.select.return_value = [(None, None)]

-        self.pipe[0].send(socket.AF_INET6)
         self.worker.run()

         selector.register.assert_called_with(sock, selectors.EVENT_READ)
@@ -404,7 +381,7 @@ def test_accepts_client_from_server_socket(
             addr=addr,
             **{'config': self.protocol_config}
         )
-        self.mock_protocol_handler.return_value.setDaemon.assert_called()
+        # self.mock_protocol_handler.return_value.setDaemon.assert_called()
         self.mock_protocol_handler.return_value.start.assert_called()
         sock.close.assert_called()
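Reviewer note: the `@mock.patch('os.close')` decorators that start appearing here reflect a behaviour change implied by these tests: the handler receives a raw file descriptor, duplicates it via `socket.fromfd()`, and closes the original so it does not leak. A standalone illustration of why both halves matter; `adopt_fd` is a hypothetical name, not proxy.py's code:

```
import os
import socket

def adopt_fd(fileno: int) -> socket.socket:
    # socket.fromfd() dup()s the descriptor: after this call there are
    # two OS-level handles to the same connection.
    conn = socket.fromfd(fileno, socket.AF_INET, socket.SOCK_STREAM)
    # Close the original, otherwise every accepted connection leaks a
    # descriptor in the worker process. This is the call the tests pin
    # down with mock_os_close.assert_called_with(self.fileno).
    os.close(fileno)
    return conn
```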
@@ -996,9 +973,13 @@ def test_handshake(self, mock_connect: mock.Mock, mock_b64encode: mock.Mock) ->

 class TestHttpProtocolHandler(unittest.TestCase):

+    @mock.patch('os.close')
     @mock.patch('selectors.DefaultSelector')
     @mock.patch('socket.fromfd')
-    def setUp(self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None:
+    def setUp(self,
+              mock_fromfd: mock.Mock,
+              mock_selector: mock.Mock,
+              mock_os_close: mock.Mock) -> None:
         self.fileno = 10
         self._addr = ('127.0.0.1', 54382)
         self._conn = mock_fromfd.return_value
@@ -1011,6 +992,7 @@ def setUp(self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None:
         self.mock_selector = mock_selector
         self.proxy = proxy.ProtocolHandler(
             self.fileno, self._addr, config=self.config)
+        mock_os_close.assert_called_with(self.fileno)
         self.proxy.initialize()

     @mock.patch('proxy.TcpServerConnection')
@@ -1124,10 +1106,14 @@ def test_proxy_connection_failed(self) -> None:
         self.proxy.run_once()
         self.assertEqual(self.proxy.client.buffer,
                          proxy.ProxyConnectionFailed.RESPONSE_PKT)

+    @mock.patch('os.close')
     @mock.patch('selectors.DefaultSelector')
     @mock.patch('socket.fromfd')
     def test_proxy_authentication_failed(
-            self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None:
+            self,
+            mock_fromfd: mock.Mock,
+            mock_selector: mock.Mock,
+            mock_os_close: mock.Mock) -> None:
         self._conn = mock_fromfd.return_value
         self.mock_selector_for_client_read(mock_selector)
         config = proxy.ProtocolConfig(
@@ -1137,6 +1123,7 @@ def test_proxy_authentication_failed(
             b'proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin')
         self.proxy = proxy.ProtocolHandler(
             self.fileno, self._addr, config=config)
+        mock_os_close.assert_called_with(self.fileno)
         self.proxy.initialize()
         self._conn.recv.return_value = proxy.CRLF.join([
             b'GET http://abhinavsingh.com HTTP/1.1',
@@ -1148,13 +1135,15 @@ def test_proxy_authentication_failed(
             self.proxy.client.buffer,
             proxy.ProxyAuthenticationFailed.RESPONSE_PKT)

+    @mock.patch('os.close')
     @mock.patch('selectors.DefaultSelector')
     @mock.patch('socket.fromfd')
     @mock.patch('proxy.TcpServerConnection')
     def test_authenticated_proxy_http_get(
             self, mock_server_connection: mock.Mock,
             mock_fromfd: mock.Mock,
-            mock_selector: mock.Mock) -> None:
+            mock_selector: mock.Mock,
+            mock_os_close: mock.Mock) -> None:
         self._conn = mock_fromfd.return_value
         self.mock_selector_for_client_read(mock_selector)
@@ -1170,6 +1159,7 @@ def test_authenticated_proxy_http_get(
         self.proxy = proxy.ProtocolHandler(
             self.fileno, addr=self._addr, config=config)
+        mock_os_close.assert_called_with(self.fileno)
         self.proxy.initialize()
         assert self.http_server_port is not None
@@ -1196,13 +1186,15 @@ def test_authenticated_proxy_http_get(
         ])
         self.assert_data_queued(mock_server_connection, server)

+    @mock.patch('os.close')
     @mock.patch('selectors.DefaultSelector')
     @mock.patch('socket.fromfd')
     @mock.patch('proxy.TcpServerConnection')
     def test_authenticated_proxy_http_tunnel(
             self, mock_server_connection: mock.Mock,
             mock_fromfd: mock.Mock,
-            mock_selector: mock.Mock) -> None:
+            mock_selector: mock.Mock,
+            mock_os_close: mock.Mock) -> None:
         server = mock_server_connection.return_value
         server.connect.return_value = True
         server.buffer_size.return_value = 0
@@ -1217,6 +1209,7 @@ def test_authenticated_proxy_http_tunnel(
         self.proxy = proxy.ProtocolHandler(
             self.fileno, self._addr, config=config)
+        mock_os_close.assert_called_with(self.fileno)
         self.proxy.initialize()
         assert self.http_server_port is not None
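Reviewer note on the pattern repeated above: `mock.patch` decorators apply bottom-up, so the topmost `@mock.patch('os.close')` supplies the *last* positional mock argument. A tiny reminder sketch (hypothetical test name):

```
from unittest import mock

@mock.patch('os.close')                 # outermost -> injected last
@mock.patch('selectors.DefaultSelector')
@mock.patch('socket.fromfd')            # innermost -> injected first
def test_something(mock_fromfd: mock.Mock,
                   mock_selector: mock.Mock,
                   mock_os_close: mock.Mock) -> None:
    ...
```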
@@ -1310,9 +1303,10 @@ def mock_selector_for_client_read(self, mock_selector: mock.Mock) -> None:

 class TestWebServerPlugin(unittest.TestCase):

+    @mock.patch('os.close')
     @mock.patch('selectors.DefaultSelector')
     @mock.patch('socket.fromfd')
-    def setUp(self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None:
+    def setUp(self, mock_fromfd: mock.Mock, mock_selector: mock.Mock, mock_os_close: mock.Mock) -> None:
         self.fileno = 10
         self._addr = ('127.0.0.1', 54382)
         self._conn = mock_fromfd.return_value
@@ -1322,16 +1316,20 @@ def setUp(self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None:
             b'proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin')
         self.proxy = proxy.ProtocolHandler(
             self.fileno, self._addr, config=self.config)
+        mock_os_close.assert_called_with(self.fileno)
         self.proxy.initialize()

+    @mock.patch('os.close')
     @mock.patch('selectors.DefaultSelector')
     @mock.patch('socket.fromfd')
     def test_pac_file_served_from_disk(
-            self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None:
+            self, mock_fromfd: mock.Mock, mock_selector: mock.Mock,
+            mock_os_close: mock.Mock) -> None:
         pac_file = 'proxy.pac'
         self._conn = mock_fromfd.return_value
         self.mock_selector_for_client_read(mock_selector)
         self.init_and_make_pac_file_request(pac_file)
+        mock_os_close.assert_called_with(self.fileno)
         self.proxy.run_once()
         self.assertEqual(
             self.proxy.request.state,
@@ -1344,14 +1342,17 @@ def test_pac_file_served_from_disk(
                 },
                 body=f.read()
             ))

+    @mock.patch('os.close')
     @mock.patch('selectors.DefaultSelector')
     @mock.patch('socket.fromfd')
     def test_pac_file_served_from_buffer(
-            self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None:
+            self, mock_fromfd: mock.Mock, mock_selector: mock.Mock,
+            mock_os_close: mock.Mock) -> None:
         self._conn = mock_fromfd.return_value
         self.mock_selector_for_client_read(mock_selector)
         pac_file_content = b'function FindProxyForURL(url, host) { return "PROXY localhost:8899; DIRECT"; }'
         self.init_and_make_pac_file_request(proxy.text_(pac_file_content))
+        mock_os_close.assert_called_with(self.fileno)
         self.proxy.run_once()
         self.assertEqual(
             self.proxy.request.state,
@@ -1363,10 +1364,12 @@ def test_pac_file_served_from_buffer(
                 },
                 body=pac_file_content
             ))

+    @mock.patch('os.close')
     @mock.patch('selectors.DefaultSelector')
     @mock.patch('socket.fromfd')
     def test_default_web_server_returns_404(
-            self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None:
+            self, mock_fromfd: mock.Mock, mock_selector: mock.Mock,
+            mock_os_close: mock.Mock) -> None:
         self._conn = mock_fromfd.return_value
         mock_selector.return_value.select.return_value = [(
             selectors.SelectorKey(
@@ -1379,6 +1382,7 @@ def test_default_web_server_returns_404(
             b'proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin')
         self.proxy = proxy.ProtocolHandler(
             self.fileno, self._addr, config=config)
+        mock_os_close.assert_called_with(self.fileno)
         self.proxy.initialize()
         self._conn.recv.return_value = proxy.CRLF.join([
             b'GET /hello HTTP/1.1',
@@ -1392,10 +1396,12 @@ def test_default_web_server_returns_404(
             self.proxy.client.buffer,
             proxy.HttpWebServerPlugin.DEFAULT_404_RESPONSE)

+    @mock.patch('os.close')
     @mock.patch('selectors.DefaultSelector')
     @mock.patch('socket.fromfd')
     def test_static_web_server_serves(
-            self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None:
+            self, mock_fromfd: mock.Mock, mock_selector: mock.Mock,
+            mock_os_close: mock.Mock) -> None:
         # Setup a static directory
         static_server_dir = os.path.join(tempfile.gettempdir(), 'static')
         index_file_path = os.path.join(static_server_dir, 'index.html')
@@ -1447,10 +1453,14 @@ def test_static_web_server_serves(
                 body=html_file_content
             ))

+    @mock.patch('os.close')
     @mock.patch('selectors.DefaultSelector')
     @mock.patch('socket.fromfd')
     def test_static_web_server_serves_404(
-            self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None:
+            self,
+            mock_fromfd: mock.Mock,
+            mock_selector: mock.Mock,
+            mock_os_close: mock.Mock) -> None:
         self._conn = mock_fromfd.return_value
         self._conn.recv.return_value = proxy.build_http_request(b'GET', b'/not-found.html')
@@ -1472,6 +1482,7 @@ def test_static_web_server_serves_404(
         self.proxy = proxy.ProtocolHandler(
             self.fileno, self._addr, config=config)
+        mock_os_close.assert_called_with(self.fileno)
         self.proxy.initialize()
         self.proxy.run_once()
@@ -1482,15 +1493,17 @@ def test_static_web_server_serves_404(
         self.assertEqual(self._conn.send.call_args[0][0],
                          proxy.HttpWebServerPlugin.DEFAULT_404_RESPONSE)

+    @mock.patch('os.close')
     @mock.patch('socket.fromfd')
     def test_on_client_connection_called_on_teardown(
-            self, mock_fromfd: mock.Mock) -> None:
+            self, mock_fromfd: mock.Mock, mock_os_close: mock.Mock) -> None:
         config = proxy.ProtocolConfig()
         plugin = mock.MagicMock()
         config.plugins = {b'ProtocolHandlerPlugin': [plugin]}
         self._conn = mock_fromfd.return_value
         self.proxy = proxy.ProtocolHandler(
             self.fileno, self._addr, config=config)
+        mock_os_close.assert_called_with(self.fileno)
         self.proxy.initialize()
         plugin.assert_called()
         with mock.patch.object(self.proxy, 'run_once') as mock_run_once:
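Reviewer note: `test_on_client_connection_called_on_teardown` above forces `run()` to exit on its first iteration by patching `run_once` on the instance. The same trick is handy when writing new handler tests; a self-contained sketch of the mechanism with a dummy `Handler` class (hypothetical, not proxy.py's):

```
from unittest import mock

class Handler:
    def run_once(self) -> bool:
        return False

    def run(self) -> None:
        while True:
            if self.run_once():
                break

handler = Handler()
with mock.patch.object(handler, 'run_once') as mock_run_once:
    mock_run_once.return_value = True   # force teardown on first iteration
    handler.run()                        # returns immediately
mock_run_once.assert_called_once()
```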
@@ -1522,11 +1535,13 @@ def mock_selector_for_client_read(self, mock_selector: mock.Mock) -> None:

 class TestHttpProxyPlugin(unittest.TestCase):

+    @mock.patch('os.close')
     @mock.patch('selectors.DefaultSelector')
     @mock.patch('socket.fromfd')
     def setUp(self,
               mock_fromfd: mock.Mock,
-              mock_selector: mock.Mock) -> None:
+              mock_selector: mock.Mock,
+              mock_os_close: mock.Mock) -> None:
         self.mock_fromfd = mock_fromfd
         self.mock_selector = mock_selector
@@ -1541,6 +1556,7 @@ def setUp(self,
         self._conn = mock_fromfd.return_value
         self.proxy = proxy.ProtocolHandler(
             self.fileno, self._addr, config=self.config)
+        mock_os_close.assert_called_with(self.fileno)
         self.proxy.initialize()

     def test_proxy_plugin_initialized(self) -> None:
@@ -1595,11 +1611,13 @@ def test_proxy_plugin_before_upstream_connection_can_teardown(

 class TestHttpProxyPluginExamples(unittest.TestCase):

+    @mock.patch('os.close')
     @mock.patch('selectors.DefaultSelector')
     @mock.patch('socket.fromfd')
     def setUp(self,
               mock_fromfd: mock.Mock,
-              mock_selector: mock.Mock) -> None:
+              mock_selector: mock.Mock,
+              mock_os_close: mock.Mock) -> None:
         self.fileno = 10
         self._addr = ('127.0.0.1', 54382)
         self.config = proxy.ProtocolConfig()
@@ -1617,6 +1635,7 @@ def setUp(self,
         self._conn = mock_fromfd.return_value
         self.proxy = proxy.ProtocolHandler(
             self.fileno, self._addr, config=self.config)
+        mock_os_close.assert_called_with(self.fileno)
         self.proxy.initialize()

     @mock.patch('proxy.TcpServerConnection')
@@ -1743,11 +1762,6 @@ def test_filter_by_upstream_host_plugin(
             )
         )

-    @mock.patch('proxy.TcpServerConnection')
-    def test_cache_responses_plugin(
-            self, mock_server_conn: mock.Mock) -> None:
-        pass
-
     @mock.patch('proxy.TcpServerConnection')
     def test_man_in_the_middle_plugin(
             self, mock_server_conn: mock.Mock) -> None:
@@ -1822,6 +1836,7 @@ def closed() -> bool:

 class TestHttpProxyTlsInterception(unittest.TestCase):

+    @mock.patch('os.close')
     @mock.patch('ssl.wrap_socket')
     @mock.patch('ssl.create_default_context')
     @mock.patch('proxy.TcpServerConnection')
@@ -1835,7 +1850,8 @@ def test_e2e(
             mock_popen: mock.Mock,
             mock_server_conn: mock.Mock,
             mock_ssl_context: mock.Mock,
-            mock_ssl_wrap: mock.Mock) -> None:
+            mock_ssl_wrap: mock.Mock,
+            mock_os_close: mock.Mock) -> None:
         host, port = uuid.uuid4().hex, 443
         netloc = '{0}:{1}'.format(host, port)
@@ -1875,6 +1891,7 @@ def mock_connection() -> Any:
         self._conn = mock_fromfd.return_value
         self.proxy = proxy.ProtocolHandler(
             self.fileno, self._addr, config=self.config)
+        mock_os_close.assert_called_with(self.fileno)
         self.proxy.initialize()

         self.plugin.assert_called()
@@ -1953,6 +1970,7 @@ def mock_connection() -> Any:

 class TestHttpProxyPluginExamplesWithTlsInterception(unittest.TestCase):

+    @mock.patch('os.close')
     @mock.patch('ssl.wrap_socket')
     @mock.patch('ssl.create_default_context')
     @mock.patch('proxy.TcpServerConnection')
@@ -1965,7 +1983,8 @@ def setUp(self,
             mock_popen: mock.Mock,
             mock_server_conn: mock.Mock,
             mock_ssl_context: mock.Mock,
-            mock_ssl_wrap: mock.Mock) -> None:
+            mock_ssl_wrap: mock.Mock,
+            mock_os_close: mock.Mock) -> None:
         self.mock_fromfd = mock_fromfd
         self.mock_selector = mock_selector
         self.mock_popen = mock_popen
@@ -1991,6 +2010,7 @@ def setUp(self,
         mock_fromfd.return_value = self._conn
         self.proxy = proxy.ProtocolHandler(
             self.fileno, self._addr, config=self.config)
+        mock_os_close.assert_called_with(self.fileno)
         self.proxy.initialize()

         self.server = self.mock_server_conn.return_value
@@ -2082,11 +2102,6 @@ def test_modify_post_data_plugin(self) -> None:
             )
         )

-    @mock.patch('proxy.TcpServerConnection')
-    def test_cache_responses_plugin(
-            self, mock_server_conn: mock.Mock) -> None:
-        pass
-
     @mock.patch('proxy.TcpServerConnection')
     def test_man_in_the_middle_plugin(
             self, mock_server_conn: mock.Mock) -> None:
@@ -2180,6 +2195,7 @@ def mock_default_args(mock_args: mock.Mock) -> None:
         mock_args.devtools_event_queue = None
         mock_args.devtools_ws_path = proxy.DEFAULT_DEVTOOLS_WS_PATH
         mock_args.timeout = proxy.DEFAULT_TIMEOUT
+        mock_args.threadless = proxy.DEFAULT_THREADLESS

     @mock.patch('time.sleep')
     @mock.patch('proxy.load_plugins')
@@ -2236,7 +2252,8 @@ def test_init_with_no_arguments(
             enable_static_server=mock_args.enable_static_server,
             devtools_event_queue=None,
             devtools_ws_path=proxy.DEFAULT_DEVTOOLS_WS_PATH,
-            timeout=proxy.DEFAULT_TIMEOUT
+            timeout=proxy.DEFAULT_TIMEOUT,
+            threadless=proxy.DEFAULT_THREADLESS,
         )

         mock_acceptor_pool.assert_called_with(
@@ -2245,6 +2262,7 @@ def test_init_with_no_arguments(
             backlog=mock_protocol_config.return_value.backlog,
             num_workers=mock_protocol_config.return_value.num_workers,
             work_klass=proxy.ProtocolHandler,
+            threadless=mock_protocol_config.return_value.threadless,
             config=mock_protocol_config.return_value,
         )
         mock_acceptor_pool.return_value.setup.assert_called()
@@ -2297,6 +2315,7 @@ def test_basic_auth(
             backlog=config.backlog,
             num_workers=config.num_workers,
             work_klass=proxy.ProtocolHandler,
+            threadless=config.threadless,
             config=config)
         self.assertEqual(mock_protocol_config.call_args[1]['auth_code'],
                          b'Basic dXNlcjpwYXNz')