diff --git a/Makefile b/Makefile index 17152b4e8f..2fac9c297d 100644 --- a/Makefile +++ b/Makefile @@ -25,6 +25,7 @@ devtools: pushd dashboard && npm run devtools && popd autopep8: + autopep8 --recursive --in-place --aggressive examples autopep8 --recursive --in-place --aggressive proxy autopep8 --recursive --in-place --aggressive tests autopep8 --recursive --in-place --aggressive setup.py @@ -73,8 +74,8 @@ lib-clean: rm -rf .hypothesis lib-lint: - flake8 --ignore=W504 --max-line-length=127 --max-complexity=19 proxy/ tests/ setup.py - mypy --strict --ignore-missing-imports proxy/ tests/ setup.py + flake8 --ignore=W504 --max-line-length=127 --max-complexity=19 examples/ proxy/ tests/ setup.py + mypy --strict --ignore-missing-imports examples/ proxy/ tests/ setup.py lib-test: lib-clean lib-version lib-lint pytest -v tests/ diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 0000000000..82765d9588 --- /dev/null +++ b/examples/README.md @@ -0,0 +1,5 @@ +# Proxy.py Library Examples + +This directory contains examples that demonstrate `proxy.py` core library capabilities. + +Looking for `proxy.py` plugin examples? Check [proxy/plugin](https://github.com/abhinavsingh/proxy.py/tree/develop/proxy/plugin) directory. diff --git a/examples/tcp_echo_server.py b/examples/tcp_echo_server.py new file mode 100644 index 0000000000..02d7ca3bd0 --- /dev/null +++ b/examples/tcp_echo_server.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +import time +import socket +import selectors + +from typing import Dict + +from proxy.core.acceptor import AcceptorPool, Work +from proxy.common.flags import Flags +from proxy.common.types import Readables, Writables + + +class EchoServerHandler(Work): + """EchoServerHandler implements Work interface. + + An instance of EchoServerHandler is created for each client + connection. EchoServerHandler lifecycle is controlled by + Threadless core using asyncio. Implementation must provide + get_events and handle_events method. Optionally, also implement + intialize, is_inactive and shutdown method. + """ + + def get_events(self) -> Dict[socket.socket, int]: + # We always want to read from client + # Register for EVENT_READ events + events = {self.client.connection: selectors.EVENT_READ} + # If there is pending buffer for client + # also register for EVENT_WRITE events + if self.client.has_buffer(): + events[self.client.connection] |= selectors.EVENT_WRITE + return events + + def handle_events( + self, + readables: Readables, + writables: Writables) -> bool: + """Return True to shutdown work.""" + if self.client.connection in readables: + data = self.client.recv() + if data is None: + # Client closed connection, signal shutdown + return True + # Queue data back to client + self.client.queue(data) + + if self.client.connection in writables: + self.client.flush() + + return False + + +def main() -> None: + # This example requires `threadless=True` + pool = AcceptorPool( + flags=Flags(num_workers=1, threadless=True), + work_klass=EchoServerHandler) + try: + pool.setup() + while True: + time.sleep(1) + finally: + pool.shutdown() + + +if __name__ == '__main__': + main() diff --git a/examples/websocket_client.py b/examples/websocket_client.py new file mode 100644 index 0000000000..5509d10dac --- /dev/null +++ b/examples/websocket_client.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +import time +from proxy.http.websocket import WebsocketClient, WebsocketFrame, websocketOpcodes + + +# globals +client: WebsocketClient +last_dispatch_time: float +static_frame = memoryview(WebsocketFrame.text(b'hello')) +num_echos = 10 + + +def on_message(frame: WebsocketFrame) -> None: + """WebsocketClient on_message callback.""" + global client, num_echos, last_dispatch_time + print('Received %r after %d millisec' % (frame.data, (time.time() - last_dispatch_time) * 1000)) + assert(frame.data == b'hello' and frame.opcode == websocketOpcodes.TEXT_FRAME) + if num_echos > 0: + client.queue(static_frame) + last_dispatch_time = time.time() + num_echos -= 1 + else: + client.close() + + +if __name__ == '__main__': + # Constructor establishes socket connection + client = WebsocketClient(b'echo.websocket.org', 80, b'/', on_message=on_message) + # Perform handshake + client.handshake() + # Queue some data for client + client.queue(static_frame) + last_dispatch_time = time.time() + # Start event loop + client.run() diff --git a/proxy/common/flags.py b/proxy/common/flags.py index 26b4470971..e22ae5d038 100644 --- a/proxy/common/flags.py +++ b/proxy/common/flags.py @@ -21,8 +21,9 @@ import sys import inspect -from typing import Optional, Union, Dict, List, TypeVar, Type, cast, Any, Tuple +from typing import Optional, Dict, List, TypeVar, Type, cast, Any, Tuple +from .types import IpAddress from .utils import text_, bytes_ from .constants import DEFAULT_LOG_LEVEL, DEFAULT_LOG_FILE, DEFAULT_LOG_FORMAT, DEFAULT_BACKLOG, DEFAULT_BASIC_AUTH from .constants import DEFAULT_TIMEOUT, DEFAULT_DEVTOOLS_WS_PATH, DEFAULT_DISABLE_HTTP_PROXY, DEFAULT_DISABLE_HEADERS @@ -67,8 +68,7 @@ def __init__( ca_signing_key_file: Optional[str] = None, ca_file: Optional[str] = None, num_workers: int = 0, - hostname: Union[ipaddress.IPv4Address, - ipaddress.IPv6Address] = DEFAULT_IPV6_HOSTNAME, + hostname: IpAddress = DEFAULT_IPV6_HOSTNAME, port: int = DEFAULT_PORT, backlog: int = DEFAULT_BACKLOG, static_server_dir: str = DEFAULT_STATIC_SERVER_DIR, @@ -99,8 +99,7 @@ def __init__( self.ca_signing_key_file: Optional[str] = ca_signing_key_file self.ca_file = ca_file self.num_workers: int = num_workers if num_workers > 0 else multiprocessing.cpu_count() - self.hostname: Union[ipaddress.IPv4Address, - ipaddress.IPv6Address] = hostname + self.hostname: IpAddress = hostname self.family: socket.AddressFamily = socket.AF_INET6 if hostname.version == 6 else socket.AF_INET self.port: int = port self.backlog: int = backlog @@ -161,7 +160,8 @@ def initialize( # Setup limits Flags.set_open_file_limit(args.open_file_limit) - # Prepare list of plugins to load based upon --enable-* and --disable-* flags + # Prepare list of plugins to load based upon --enable-* and --disable-* + # flags default_plugins: List[Tuple[str, bool]] = [] if args.enable_dashboard: default_plugins.append((PLUGIN_WEB_SERVER, True)) @@ -249,8 +249,7 @@ def initialize( opts.get( 'ca_file', args.ca_file)), - hostname=cast(Union[ipaddress.IPv4Address, - ipaddress.IPv6Address], + hostname=cast(IpAddress, opts.get('hostname', ipaddress.ip_address(args.hostname))), port=cast(int, opts.get('port', args.port)), backlog=cast(int, opts.get('backlog', args.backlog)), diff --git a/proxy/common/types.py b/proxy/common/types.py index c411048443..279211e422 100644 --- a/proxy/common/types.py +++ b/proxy/common/types.py @@ -9,8 +9,9 @@ :license: BSD, see LICENSE for more details. """ import queue +import ipaddress -from typing import TYPE_CHECKING, Dict, Any +from typing import TYPE_CHECKING, Dict, Any, List, Union from typing_extensions import Protocol @@ -23,3 +24,8 @@ class HasFileno(Protocol): def fileno(self) -> int: ... # pragma: no cover + + +Readables = List[Union[int, HasFileno]] +Writables = List[Union[int, HasFileno]] +IpAddress = Union[ipaddress.IPv4Address, ipaddress.IPv6Address] diff --git a/proxy/common/utils.py b/proxy/common/utils.py index ecdc4e9fe2..078bbbcbe7 100644 --- a/proxy/common/utils.py +++ b/proxy/common/utils.py @@ -101,7 +101,8 @@ def build_http_pkt(line: List[bytes], def build_websocket_handshake_request( key: bytes, method: bytes = b'GET', - url: bytes = b'/') -> bytes: + url: bytes = b'/', + host: bytes = b'localhost') -> bytes: """ Build and returns a Websocket handshake request packet. @@ -112,6 +113,7 @@ def build_websocket_handshake_request( return build_http_request( method, url, headers={ + b'Host': host, b'Connection': b'upgrade', b'Upgrade': b'websocket', b'Sec-WebSocket-Key': key, diff --git a/proxy/core/acceptor/__init__.py b/proxy/core/acceptor/__init__.py index 9c0a97b332..cca3bcdb4d 100644 --- a/proxy/core/acceptor/__init__.py +++ b/proxy/core/acceptor/__init__.py @@ -10,8 +10,12 @@ """ from .acceptor import Acceptor from .pool import AcceptorPool +from .work import Work +from .threadless import Threadless __all__ = [ 'Acceptor', 'AcceptorPool', + 'Work', + 'Threadless', ] diff --git a/proxy/core/acceptor/acceptor.py b/proxy/core/acceptor/acceptor.py index eeb2955269..d5134b4e13 100644 --- a/proxy/core/acceptor/acceptor.py +++ b/proxy/core/acceptor/acceptor.py @@ -14,13 +14,15 @@ import selectors import socket import threading -# import time + from multiprocessing import connection from multiprocessing.reduction import send_handle, recv_handle from typing import Optional, Type, Tuple +from .work import Work +from .threadless import Threadless + from ..connection import TcpClientConnection -from ..threadless import ThreadlessWork, Threadless from ..event import EventQueue, eventNames from ...common.flags import Flags @@ -28,10 +30,12 @@ class Acceptor(multiprocessing.Process): - """Socket client acceptor. + """Socket server acceptor process. - Accepts client connection over received server socket handle and - starts a new work thread. + Accepts client connection over received server socket handle at startup. Spawns a separate + thread to handle each client request. However, when `--threadless` is enabled, Acceptor also + pre-spawns a `Threadless` process at startup. Accepted client connections are passed to + `Threadless` process which internally uses asyncio event loop to handle client connections. """ def __init__( @@ -39,7 +43,7 @@ def __init__( idd: int, work_queue: connection.Connection, flags: Flags, - work_klass: Type[ThreadlessWork], + work_klass: Type[Work], lock: multiprocessing.synchronize.Lock, event_queue: Optional[EventQueue] = None) -> None: super().__init__() @@ -108,11 +112,7 @@ def run_once(self) -> None: if len(events) == 0: return conn, addr = self.sock.accept() - - # now = time.time() - # fileno: int = conn.fileno() self.start_work(conn, addr) - # logger.info('Work started for fd %d in %f seconds', fileno, time.time() - now) def run(self) -> None: self.selector = selectors.DefaultSelector() diff --git a/proxy/core/acceptor/pool.py b/proxy/core/acceptor/pool.py index 48cedaf96f..259c5d3049 100644 --- a/proxy/core/acceptor/pool.py +++ b/proxy/core/acceptor/pool.py @@ -18,7 +18,8 @@ from typing import List, Optional, Type from .acceptor import Acceptor -from ..threadless import ThreadlessWork +from .work import Work + from ..event import EventQueue, EventDispatcher from ...common.flags import Flags @@ -31,11 +32,20 @@ class AcceptorPool: """AcceptorPool. Pre-spawns worker processes to utilize all cores available on the system. Server socket connection is - dispatched over a pipe to workers. Each worker accepts incoming client request and spawns a - separate thread to handle the client request. + dispatched over a pipe to workers. Each Acceptor instance accepts for new client connection. + + Example usage: + + pool = AcceptorPool(flags=..., work_klass=...) + try: + pool.setup() + while True: + time.sleep(1) + finally: + pool.shutdown() """ - def __init__(self, flags: Flags, work_klass: Type[ThreadlessWork]) -> None: + def __init__(self, flags: Flags, work_klass: Type[Work]) -> None: self.flags = flags self.socket: Optional[socket.socket] = None self.acceptors: List[Acceptor] = [] diff --git a/proxy/core/threadless.py b/proxy/core/acceptor/threadless.py similarity index 69% rename from proxy/core/threadless.py rename to proxy/core/acceptor/threadless.py index 87be7e5b86..78a59cc9f0 100644 --- a/proxy/core/threadless.py +++ b/proxy/core/acceptor/threadless.py @@ -18,81 +18,18 @@ from multiprocessing import connection from multiprocessing.reduction import recv_handle -from abc import ABC, abstractmethod -from typing import Dict, Optional, Tuple, List, Union, Generator, Any, Type -from uuid import uuid4, UUID +from typing import Dict, Optional, Tuple, List, Generator, Any, Type -from .connection import TcpClientConnection -from .event import EventQueue, eventNames +from .work import Work -from ..common.flags import Flags -from ..common.types import HasFileno -from ..common.constants import DEFAULT_TIMEOUT +from ..connection import TcpClientConnection +from ..event import EventQueue, eventNames -logger = logging.getLogger(__name__) - - -class ThreadlessWork(ABC): - """Implement ThreadlessWork to hook into the event loop provided by Threadless process.""" - - @abstractmethod - def __init__( - self, - client: TcpClientConnection, - flags: Optional[Flags], - event_queue: Optional[EventQueue] = None, - uid: Optional[UUID] = None) -> None: - self.client = client - self.flags = flags if flags else Flags() - self.event_queue = event_queue - self.uid: UUID = uid if uid is not None else uuid4() - - @abstractmethod - def initialize(self) -> None: - pass # pragma: no cover - - @abstractmethod - def is_inactive(self) -> bool: - return False # pragma: no cover - - @abstractmethod - def get_events(self) -> Dict[socket.socket, int]: - return {} # pragma: no cover +from ...common.flags import Flags +from ...common.types import Readables, Writables +from ...common.constants import DEFAULT_TIMEOUT - @abstractmethod - def handle_events( - self, - readables: List[Union[int, HasFileno]], - writables: List[Union[int, HasFileno]]) -> bool: - """Return True to shutdown work.""" - return False # pragma: no cover - - @abstractmethod - def run(self) -> None: - pass - - def publish_event( - self, - event_name: int, - event_payload: Dict[str, Any], - publisher_id: Optional[str] = None) -> None: - if not self.flags.enable_events: - return - assert self.event_queue - self.event_queue.publish( - self.uid.hex, - event_name, - event_payload, - publisher_id - ) - - def shutdown(self) -> None: - """Must close any opened resources and call super().shutdown().""" - self.publish_event( - event_name=eventNames.WORK_FINISHED, - event_payload={}, - publisher_id=self.__class__.__name__ - ) +logger = logging.getLogger(__name__) class Threadless(multiprocessing.Process): @@ -103,15 +40,15 @@ class Threadless(multiprocessing.Process): for each accepted client connection, Acceptor process sends accepted client connection to Threadless process over a pipe. - HttpProtocolHandler implements ThreadlessWork class and hooks into the - event loop provided by Threadless. + Example, HttpProtocolHandler implements Work class to hooks into the + event loop provided by Threadless process. """ def __init__( self, client_queue: connection.Connection, flags: Flags, - work_klass: Type[ThreadlessWork], + work_klass: Type[Work], event_queue: Optional[EventQueue] = None) -> None: super().__init__() self.client_queue = client_queue @@ -120,13 +57,12 @@ def __init__( self.event_queue = event_queue self.running = multiprocessing.Event() - self.works: Dict[int, ThreadlessWork] = {} + self.works: Dict[int, Work] = {} self.selector: Optional[selectors.DefaultSelector] = None self.loop: Optional[asyncio.AbstractEventLoop] = None @contextlib.contextmanager - def selected_events(self) -> Generator[Tuple[List[Union[int, HasFileno]], - List[Union[int, HasFileno]]], + def selected_events(self) -> Generator[Tuple[Readables, Writables], None, None]: events: Dict[socket.socket, int] = {} for work in self.works.values(): @@ -148,8 +84,8 @@ def selected_events(self) -> Generator[Tuple[List[Union[int, HasFileno]], async def handle_events( self, fileno: int, - readables: List[Union[int, HasFileno]], - writables: List[Union[int, HasFileno]]) -> bool: + readables: Readables, + writables: Writables) -> bool: return self.works[fileno].handle_events(readables, writables) # TODO: Use correct future typing annotations diff --git a/proxy/core/acceptor/work.py b/proxy/core/acceptor/work.py new file mode 100644 index 0000000000..bcd251f7b0 --- /dev/null +++ b/proxy/core/acceptor/work.py @@ -0,0 +1,90 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +import socket + +from abc import ABC, abstractmethod +from uuid import uuid4, UUID +from typing import Optional, Dict, Any + +from ..event import eventNames, EventQueue +from ..connection import TcpClientConnection +from ...common.flags import Flags +from ...common.types import Readables, Writables + + +class Work(ABC): + """Implement Work to hook into the event loop provided by Threadless process.""" + + def __init__( + self, + client: TcpClientConnection, + flags: Optional[Flags], + event_queue: Optional[EventQueue] = None, + uid: Optional[UUID] = None) -> None: + self.client = client + self.flags = flags if flags else Flags() + self.event_queue = event_queue + self.uid: UUID = uid if uid is not None else uuid4() + + @abstractmethod + def get_events(self) -> Dict[socket.socket, int]: + """Return sockets and events (read or write) that we are interested in.""" + return {} # pragma: no cover + + @abstractmethod + def handle_events( + self, + readables: Readables, + writables: Writables) -> bool: + """Handle readable and writable sockets. + + Return True to shutdown work.""" + return False # pragma: no cover + + def initialize(self) -> None: + """Perform any resource initialization.""" + pass # pragma: no cover + + def is_inactive(self) -> bool: + """Return True if connection should be considered inactive.""" + return False # pragma: no cover + + def shutdown(self) -> None: + """Implementation must close any opened resources here + and call super().shutdown().""" + self.publish_event( + event_name=eventNames.WORK_FINISHED, + event_payload={}, + publisher_id=self.__class__.__name__ + ) + + def run(self) -> None: + """run() method is not used by Threadless. It's here for backward + compatibility with threaded mode where work class is started as + a separate thread. + """ + pass + + def publish_event( + self, + event_name: int, + event_payload: Dict[str, Any], + publisher_id: Optional[str] = None) -> None: + """Convenience method provided to publish events into the global event queue.""" + if not self.flags.enable_events: + return + assert self.event_queue + self.event_queue.publish( + self.uid.hex, + event_name, + event_payload, + publisher_id + ) diff --git a/proxy/http/handler.py b/proxy/http/handler.py index d5538a2672..4a1b67be9e 100644 --- a/proxy/http/handler.py +++ b/proxy/http/handler.py @@ -24,15 +24,15 @@ from .exception import HttpProtocolException from ..common.flags import Flags -from ..common.types import HasFileno -from ..core.threadless import ThreadlessWork +from ..common.types import Readables, Writables +from ..core.acceptor.work import Work from ..core.event import EventQueue from ..core.connection import TcpClientConnection logger = logging.getLogger(__name__) -class HttpProtocolHandler(ThreadlessWork): +class HttpProtocolHandler(Work): """HTTP, HTTPS, HTTP2, WebSockets protocol handler. Accepts `Client` connection object and manages HttpProtocolHandlerPlugin invocations. @@ -100,8 +100,8 @@ def get_events(self) -> Dict[socket.socket, int]: def handle_events( self, - readables: List[Union[int, HasFileno]], - writables: List[Union[int, HasFileno]]) -> bool: + readables: Readables, + writables: Writables) -> bool: """Returns True if proxy must teardown.""" # Flush buffer for ready to write sockets teardown = self.handle_writables(writables) @@ -197,7 +197,7 @@ def flush(self) -> None: finally: self.selector.unregister(self.client.connection) - def handle_writables(self, writables: List[Union[int, HasFileno]]) -> bool: + def handle_writables(self, writables: Writables) -> bool: if self.client.has_buffer() and self.client.connection in writables: logger.debug('Client is ready for writes, flushing buffer') self.last_activity = time.time() @@ -222,7 +222,7 @@ def handle_writables(self, writables: List[Union[int, HasFileno]]) -> bool: return True return False - def handle_readables(self, readables: List[Union[int, HasFileno]]) -> bool: + def handle_readables(self, readables: Readables) -> bool: if self.client.connection in readables: logger.debug('Client is ready for reads, reading') self.last_activity = time.time() @@ -292,8 +292,7 @@ def handle_readables(self, readables: List[Union[int, HasFileno]]) -> bool: @contextlib.contextmanager def selected_events(self) -> \ - Generator[Tuple[List[Union[int, HasFileno]], - List[Union[int, HasFileno]]], + Generator[Tuple[Readables, Writables], None, None]: events = self.get_events() for fd in events: diff --git a/proxy/http/parser.py b/proxy/http/parser.py index e3e23d2a00..d40885ca7c 100644 --- a/proxy/http/parser.py +++ b/proxy/http/parser.py @@ -171,7 +171,8 @@ def parse(self, raw: bytes) -> None: self.state = httpParserStates.COMPLETE more = False else: - raise NotImplementedError('Parser shouldn\'t have reached here') + raise NotImplementedError( + 'Parser shouldn\'t have reached here') else: more, raw = self.process(raw) self.buffer = raw @@ -258,7 +259,8 @@ def build_response(self) -> bytes: status_code=int(self.code), protocol_version=self.version, reason=self.reason, - headers={} if not self.headers else {self.headers[k][0]: self.headers[k][1] for k in self.headers}, + headers={} if not self.headers else { + self.headers[k][0]: self.headers[k][1] for k in self.headers}, body=self.body if not self.is_chunked_encoded() else ChunkParser.to_chunks(self.body)) def has_upstream_server(self) -> bool: diff --git a/proxy/http/plugin.py b/proxy/http/plugin.py index 86232ca514..b8d95b2863 100644 --- a/proxy/http/plugin.py +++ b/proxy/http/plugin.py @@ -17,7 +17,7 @@ from .parser import HttpParser from ..common.flags import Flags -from ..common.types import HasFileno +from ..common.types import Readables, Writables from ..core.event import EventQueue from ..core.connection import TcpClientConnection @@ -71,11 +71,11 @@ def get_descriptors( return [], [] # pragma: no cover @abstractmethod - def write_to_descriptors(self, w: List[Union[int, HasFileno]]) -> bool: + def write_to_descriptors(self, w: Writables) -> bool: return False # pragma: no cover @abstractmethod - def read_from_descriptors(self, r: List[Union[int, HasFileno]]) -> bool: + def read_from_descriptors(self, r: Readables) -> bool: return False # pragma: no cover @abstractmethod diff --git a/proxy/http/proxy/server.py b/proxy/http/proxy/server.py index 0396bf0432..3fad47e7ff 100644 --- a/proxy/http/proxy/server.py +++ b/proxy/http/proxy/server.py @@ -25,7 +25,7 @@ from ..parser import HttpParser, httpParserStates, httpParserTypes from ..methods import httpMethods -from ...common.types import HasFileno +from ...common.types import Readables, Writables from ...common.constants import PROXY_AGENT_HEADER_VALUE from ...common.utils import build_http_response, text_ from ...common.pki import gen_public_key, gen_csr, sign_csr @@ -82,7 +82,7 @@ def get_descriptors( w.append(self.server.connection) return r, w - def write_to_descriptors(self, w: List[Union[int, HasFileno]]) -> bool: + def write_to_descriptors(self, w: Writables) -> bool: if self.request.has_upstream_server() and \ self.server and not self.server.closed and \ self.server.has_buffer() and \ @@ -91,18 +91,20 @@ def write_to_descriptors(self, w: List[Union[int, HasFileno]]) -> bool: try: self.server.flush() except ssl.SSLWantWriteError: - logger.warning('SSLWantWriteError while trying to flush to server, will retry') + logger.warning( + 'SSLWantWriteError while trying to flush to server, will retry') return False except BrokenPipeError: logger.error( 'BrokenPipeError when flushing buffer for server') return True except OSError as e: - logger.exception('OSError when flushing buffer to server', exc_info=e) + logger.exception( + 'OSError when flushing buffer to server', exc_info=e) return True return False - def read_from_descriptors(self, r: List[Union[int, HasFileno]]) -> bool: + def read_from_descriptors(self, r: Readables) -> bool: if self.request.has_upstream_server( ) and self.server and not self.server.closed and self.server.connection in r: logger.debug('Server is ready for reads, reading...') @@ -289,7 +291,8 @@ def on_request_complete(self) -> Union[socket.socket, bool]: # sending to client can raise, handle expected exceptions self.wrap_client() except subprocess.TimeoutExpired as e: # Popen communicate timeout - logger.exception('TimeoutExpired during certificate generation', exc_info=e) + logger.exception( + 'TimeoutExpired during certificate generation', exc_info=e) return True except BrokenPipeError: logger.error( @@ -360,7 +363,8 @@ def access_log(self) -> None: self.response.total_size, connection_time_ms)) - def gen_ca_signed_certificate(self, cert_file_path: str, certificate: Dict[str, Any]) -> None: + def gen_ca_signed_certificate( + self, cert_file_path: str, certificate: Dict[str, Any]) -> None: '''CA signing key (default) is used for generating a public key for common_name, if one already doesn't exist. Using generated public key a CSR request is generated, which is then signed by @@ -388,7 +392,8 @@ def gen_ca_signed_certificate(self, cert_file_path: str, certificate: Dict[str, subject = '' for key in keys: if upstream_subject.get(keys[key], None): - subject += '/{0}={1}'.format(key, upstream_subject.get(keys[key])) + subject += '/{0}={1}'.format(key, + upstream_subject.get(keys[key])) alt_subj_names = [text_(self.request.host), ] validity_in_days = 365 * 2 timeout = 10 diff --git a/proxy/http/server/web.py b/proxy/http/server/web.py index ea53756d78..54be0ab65b 100644 --- a/proxy/http/server/web.py +++ b/proxy/http/server/web.py @@ -27,7 +27,7 @@ from ...common.utils import bytes_, text_, build_http_response, build_websocket_handshake_response from ...common.constants import PROXY_AGENT_HEADER_VALUE -from ...common.types import HasFileno +from ...common.types import Readables, Writables logger = logging.getLogger(__name__) @@ -166,10 +166,10 @@ def on_request_complete(self) -> Union[socket.socket, bool]: self.client.queue(self.DEFAULT_404_RESPONSE) return True - def write_to_descriptors(self, w: List[Union[int, HasFileno]]) -> bool: + def write_to_descriptors(self, w: Writables) -> bool: pass - def read_from_descriptors(self, r: List[Union[int, HasFileno]]) -> bool: + def read_from_descriptors(self, r: Readables) -> bool: pass def on_client_data(self, raw: memoryview) -> Optional[memoryview]: diff --git a/proxy/http/websocket/__init__.py b/proxy/http/websocket/__init__.py new file mode 100644 index 0000000000..2870e3b26f --- /dev/null +++ b/proxy/http/websocket/__init__.py @@ -0,0 +1,18 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +from .frame import WebsocketFrame, websocketOpcodes +from .client import WebsocketClient + +__all__ = [ + 'websocketOpcodes', + 'WebsocketFrame', + 'WebsocketClient', +] diff --git a/proxy/http/websocket/client.py b/proxy/http/websocket/client.py new file mode 100644 index 0000000000..16e440dec4 --- /dev/null +++ b/proxy/http/websocket/client.py @@ -0,0 +1,105 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +import base64 +import selectors +import socket +import secrets +import ssl + +from typing import Optional, Union, Callable + +from .frame import WebsocketFrame + +from ..parser import httpParserTypes, HttpParser + +from ...common.constants import DEFAULT_BUFFER_SIZE +from ...common.utils import new_socket_connection, build_websocket_handshake_request, text_ +from ...core.connection import tcpConnectionTypes, TcpConnection + + +class WebsocketClient(TcpConnection): + + def __init__(self, + hostname: bytes, + port: int, + path: bytes = b'/', + on_message: Optional[Callable[[WebsocketFrame], None]] = None) -> None: + super().__init__(tcpConnectionTypes.CLIENT) + self.hostname: bytes = hostname + self.port: int = port + self.path: bytes = path + self.sock: socket.socket = new_socket_connection( + (socket.gethostbyname(text_(self.hostname)), self.port)) + self.on_message: Optional[Callable[[ + WebsocketFrame], None]] = on_message + self.selector: selectors.DefaultSelector = selectors.DefaultSelector() + + @property + def connection(self) -> Union[ssl.SSLSocket, socket.socket]: + return self.sock + + def handshake(self) -> None: + self.upgrade() + self.sock.setblocking(False) + + def upgrade(self) -> None: + key = base64.b64encode(secrets.token_bytes(16)) + self.sock.send(build_websocket_handshake_request(key, url=self.path, host=self.hostname)) + response = HttpParser(httpParserTypes.RESPONSE_PARSER) + response.parse(self.sock.recv(DEFAULT_BUFFER_SIZE)) + accept = response.header(b'Sec-Websocket-Accept') + assert WebsocketFrame.key_to_accept(key) == accept + + def ping(self, data: Optional[bytes] = None) -> None: + pass + + def pong(self, data: Optional[bytes] = None) -> None: + pass + + def shutdown(self, _data: Optional[bytes] = None) -> None: + """Closes connection with the server.""" + super().close() + + def run_once(self) -> bool: + ev = selectors.EVENT_READ + if self.has_buffer(): + ev |= selectors.EVENT_WRITE + self.selector.register(self.sock.fileno(), ev) + events = self.selector.select(timeout=1) + self.selector.unregister(self.sock) + for _, mask in events: + if mask & selectors.EVENT_READ and self.on_message: + raw = self.recv() + if raw is None or raw.tobytes() == b'': + self.closed = True + return True + frame = WebsocketFrame() + # TODO(abhinavsingh): Remove .tobytes after parser is + # memoryview compliant + frame.parse(raw.tobytes()) + self.on_message(frame) + elif mask & selectors.EVENT_WRITE: + self.flush() + return False + + def run(self) -> None: + try: + while not self.closed: + teardown = self.run_once() + if teardown: + break + except KeyboardInterrupt: + pass + finally: + if not self.closed: + self.selector.unregister(self.sock) + self.sock.shutdown(socket.SHUT_WR) + self.sock.close() diff --git a/proxy/http/websocket.py b/proxy/http/websocket/frame.py similarity index 58% rename from proxy/http/websocket.py rename to proxy/http/websocket/frame.py index a6eb5a3372..55f9d91b16 100644 --- a/proxy/http/websocket.py +++ b/proxy/http/websocket/frame.py @@ -10,22 +10,12 @@ """ import hashlib import base64 -import selectors import struct -import socket import secrets -import ssl -import ipaddress import logging import io -from typing import TypeVar, Type, Optional, NamedTuple, Union, Callable - -from .parser import httpParserTypes, HttpParser - -from ..common.constants import DEFAULT_BUFFER_SIZE -from ..common.utils import new_socket_connection, build_websocket_handshake_request -from ..core.connection import tcpConnectionTypes, TcpConnection +from typing import TypeVar, Type, Optional, NamedTuple WebsocketOpcodes = NamedTuple('WebsocketOpcodes', [ @@ -180,89 +170,3 @@ def key_to_accept(key: bytes) -> bytes: sha1 = hashlib.sha1() sha1.update(key + WebsocketFrame.GUID) return base64.b64encode(sha1.digest()) - - -class WebsocketClient(TcpConnection): - - def __init__(self, - hostname: Union[ipaddress.IPv4Address, ipaddress.IPv6Address], - port: int, - path: bytes = b'/', - on_message: Optional[Callable[[WebsocketFrame], None]] = None) -> None: - super().__init__(tcpConnectionTypes.CLIENT) - self.hostname: Union[ipaddress.IPv4Address, - ipaddress.IPv6Address] = hostname - self.port: int = port - self.path: bytes = path - self.sock: socket.socket = new_socket_connection( - (str(self.hostname), self.port)) - self.on_message: Optional[Callable[[ - WebsocketFrame], None]] = on_message - self.upgrade() - self.sock.setblocking(False) - self.selector: selectors.DefaultSelector = selectors.DefaultSelector() - - @property - def connection(self) -> Union[ssl.SSLSocket, socket.socket]: - return self.sock - - def upgrade(self) -> None: - key = base64.b64encode(secrets.token_bytes(16)) - self.sock.send(build_websocket_handshake_request(key, url=self.path)) - response = HttpParser(httpParserTypes.RESPONSE_PARSER) - response.parse(self.sock.recv(DEFAULT_BUFFER_SIZE)) - accept = response.header(b'Sec-Websocket-Accept') - assert WebsocketFrame.key_to_accept(key) == accept - - def ping(self, data: Optional[bytes] = None) -> None: - pass - - def pong(self, data: Optional[bytes] = None) -> None: - pass - - def shutdown(self, _data: Optional[bytes] = None) -> None: - """Closes connection with the server.""" - super().close() - - def run_once(self) -> bool: - ev = selectors.EVENT_READ - if self.has_buffer(): - ev |= selectors.EVENT_WRITE - self.selector.register(self.sock.fileno(), ev) - events = self.selector.select(timeout=1) - self.selector.unregister(self.sock) - for _, mask in events: - if mask & selectors.EVENT_READ and self.on_message: - raw = self.recv() - if raw is None or raw.tobytes() == b'': - self.closed = True - logger.debug('Websocket connection closed by server') - return True - frame = WebsocketFrame() - # TODO(abhinavsingh): Remove .tobytes after parser is - # memoryview compliant - frame.parse(raw.tobytes()) - self.on_message(frame) - elif mask & selectors.EVENT_WRITE: - logger.debug(self.buffer) - self.flush() - return False - - def run(self) -> None: - logger.debug('running') - try: - while not self.closed: - teardown = self.run_once() - if teardown: - break - except KeyboardInterrupt: - pass - finally: - try: - self.selector.unregister(self.sock) - self.sock.shutdown(socket.SHUT_WR) - except Exception as e: - logging.exception( - 'Exception while shutdown of websocket client', exc_info=e) - self.sock.close() - logger.info('done') diff --git a/proxy/proxy.py b/proxy/proxy.py index 5479032689..643e85dcb9 100644 --- a/proxy/proxy.py +++ b/proxy/proxy.py @@ -26,6 +26,11 @@ class Proxy: + """Context manager for controlling core AcceptorPool server lifecycle. + + By default this context manager starts AcceptorPool with HttpProtocolHandler + worker class. + """ def __init__(self, input_args: Optional[List[str]], **opts: Any) -> None: self.flags = Flags.initialize(input_args, **opts) @@ -78,7 +83,8 @@ def main( **opts: Any) -> None: try: with Proxy(input_args=input_args, **opts): - # TODO: Introduce cron feature instead of mindless sleep + # TODO: Introduce cron feature + # https://github.com/abhinavsingh/proxy.py/issues/392 while True: time.sleep(1) except KeyboardInterrupt: diff --git a/setup.py b/setup.py index 8570fb2952..63df325b7d 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,8 @@ author_email=__author_email__, url=__homepage__, description=__description__, - long_description=open('README.md', 'r', encoding='utf-8').read().strip(), + long_description=open( + 'README.md', 'r', encoding='utf-8').read().strip(), long_description_content_type='text/markdown', download_url=__download_url__, license=__license__, diff --git a/tests/common/test_pki.py b/tests/common/test_pki.py index ebeb767dba..e55c063795 100644 --- a/tests/common/test_pki.py +++ b/tests/common/test_pki.py @@ -112,7 +112,9 @@ def _gen_public_private_key(self) -> Tuple[str, str, str]: def _gen_private_key(self) -> Tuple[str, str]: key_path = os.path.join(tempfile.gettempdir(), 'test_gen_private.key') - nopass_key_path = os.path.join(tempfile.gettempdir(), 'test_gen_private_nopass.key') + nopass_key_path = os.path.join( + tempfile.gettempdir(), + 'test_gen_private_nopass.key') pki.gen_private_key(key_path, 'password') pki.remove_passphrase(key_path, 'password', nopass_key_path) return (key_path, nopass_key_path) diff --git a/tests/http/test_http_parser.py b/tests/http/test_http_parser.py index c98e6381ec..e8d78cbea9 100644 --- a/tests/http/test_http_parser.py +++ b/tests/http/test_http_parser.py @@ -139,7 +139,8 @@ def test_get_full_parse(self) -> None: self.assertEqual(self.parser.url.port, None) self.assertEqual(self.parser.version, b'HTTP/1.1') self.assertEqual(self.parser.state, httpParserStates.COMPLETE) - self.assertEqual(self.parser.headers[b'host'], (b'Host', b'example.com')) + self.assertEqual( + self.parser.headers[b'host'], (b'Host', b'example.com')) self.parser.del_headers([b'host']) self.parser.add_headers([(b'Host', b'example.com')]) self.assertEqual( @@ -193,7 +194,10 @@ def test_get_partial_parse1(self) -> None: self.parser.parse(CRLF * 2) self.assertEqual(self.parser.total_size, len(pkt) + (3 * len(CRLF)) + len(host_hdr)) - self.assertEqual(self.parser.headers[b'host'], (b'Host', b'localhost:8080')) + self.assertEqual( + self.parser.headers[b'host'], + (b'Host', + b'localhost:8080')) self.assertEqual(self.parser.state, httpParserStates.COMPLETE) def test_get_partial_parse2(self) -> None: @@ -210,7 +214,10 @@ def test_get_partial_parse2(self) -> None: self.assertEqual(self.parser.state, httpParserStates.LINE_RCVD) self.parser.parse(b'localhost:8080' + CRLF) - self.assertEqual(self.parser.headers[b'host'], (b'Host', b'localhost:8080')) + self.assertEqual( + self.parser.headers[b'host'], + (b'Host', + b'localhost:8080')) self.assertEqual(self.parser.buffer, b'') self.assertEqual( self.parser.state, diff --git a/tests/http/test_http_proxy_tls_interception.py b/tests/http/test_http_proxy_tls_interception.py index dac97cd65e..2790f8e530 100644 --- a/tests/http/test_http_proxy_tls_interception.py +++ b/tests/http/test_http_proxy_tls_interception.py @@ -146,7 +146,8 @@ def mock_connection() -> Any: self.mock_server_conn.return_value.connection.setblocking.assert_called_with( False) - self.mock_ssl_context.assert_called_with(ssl.Purpose.SERVER_AUTH, cafile=None) + self.mock_ssl_context.assert_called_with( + ssl.Purpose.SERVER_AUTH, cafile=None) # self.assertEqual(self.mock_ssl_context.return_value.options, # ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1 | # ssl.OP_NO_TLSv1_1) diff --git a/tests/http/test_websocket_client.py b/tests/http/test_websocket_client.py index 060fba1068..faf18b2ac6 100644 --- a/tests/http/test_websocket_client.py +++ b/tests/http/test_websocket_client.py @@ -13,21 +13,26 @@ from proxy.common.utils import build_websocket_handshake_response, build_websocket_handshake_request from proxy.http.websocket import WebsocketClient, WebsocketFrame -from proxy.common.constants import DEFAULT_IPV4_HOSTNAME, DEFAULT_PORT +from proxy.common.constants import DEFAULT_PORT class TestWebsocketClient(unittest.TestCase): + @mock.patch('proxy.http.websocket.client.socket.gethostbyname') @mock.patch('base64.b64encode') - @mock.patch('proxy.http.websocket.new_socket_connection') + @mock.patch('proxy.http.websocket.client.new_socket_connection') def test_handshake(self, mock_connect: mock.Mock, - mock_b64encode: mock.Mock) -> None: + mock_b64encode: mock.Mock, + mock_gethostbyname: mock.Mock) -> None: key = b'MySecretKey' mock_b64encode.return_value = key + mock_gethostbyname.return_value = '127.0.0.1' mock_connect.return_value.recv.return_value = \ build_websocket_handshake_response( WebsocketFrame.key_to_accept(key)) - _ = WebsocketClient(DEFAULT_IPV4_HOSTNAME, DEFAULT_PORT) + client = WebsocketClient(b'localhost', DEFAULT_PORT) + mock_connect.return_value.send.assert_not_called() + client.handshake() mock_connect.return_value.send.assert_called_with( build_websocket_handshake_request(key) ) diff --git a/tests/testing/test_embed.py b/tests/testing/test_embed.py index 6a55af2694..e69c5ffe56 100644 --- a/tests/testing/test_embed.py +++ b/tests/testing/test_embed.py @@ -21,7 +21,8 @@ from proxy.http.methods import httpMethods -@unittest.skipIf(os.name == 'nt', 'Disabled for Windows due to weird permission issues.') +@unittest.skipIf( + os.name == 'nt', 'Disabled for Windows due to weird permission issues.') class TestProxyPyEmbedded(TestCase): """This test case is a demonstration of proxy.TestCase and also serves as integration test suite for proxy.py."""