From ed0fc7e2dea85e43d3f06cf4c71f31f4778d6d90 Mon Sep 17 00:00:00 2001 From: Abhinav Singh Date: Mon, 24 Jan 2022 16:13:46 +0530 Subject: [PATCH 1/3] `selector.close` when done --- proxy/core/acceptor/acceptor.py | 4 +++- proxy/core/work/threadless.py | 1 + proxy/http/handler.py | 2 ++ proxy/http/websocket/client.py | 1 + 4 files changed, 7 insertions(+), 1 deletion(-) diff --git a/proxy/core/acceptor/acceptor.py b/proxy/core/acceptor/acceptor.py index 0df4c2a555..299c0efbd0 100644 --- a/proxy/core/acceptor/acceptor.py +++ b/proxy/core/acceptor/acceptor.py @@ -178,6 +178,7 @@ def run(self) -> None: for fileno in self.socks: self.socks[fileno].close() self.socks.clear() + self.selector.close() logger.debug('Acceptor#%d shutdown', self.idd) def _recv_and_setup_socks(self) -> None: @@ -207,7 +208,8 @@ def _start_local(self) -> None: self._lthread.start() def _stop_local(self) -> None: - if self._lthread is not None and self._local_work_queue is not None: + if self._lthread is not None and \ + self._local_work_queue is not None: self._local_work_queue.put(False) self._lthread.join() diff --git a/proxy/core/work/threadless.py b/proxy/core/work/threadless.py index 1ad84ac645..f43c0a4732 100644 --- a/proxy/core/work/threadless.py +++ b/proxy/core/work/threadless.py @@ -419,6 +419,7 @@ def run(self) -> None: if wqfileno is not None: self.selector.unregister(wqfileno) self.close_work_queue() + self.selector.close() assert self.loop is not None self.loop.run_until_complete(self.loop.shutdown_asyncgens()) self.loop.close() diff --git a/proxy/http/handler.py b/proxy/http/handler.py index dc2d7cfbf1..bea158afba 100644 --- a/proxy/http/handler.py +++ b/proxy/http/handler.py @@ -351,6 +351,8 @@ def run(self) -> None: ) finally: self.shutdown() + if self.selector: + self.selector.close() loop.close() async def _run_once(self) -> bool: diff --git a/proxy/http/websocket/client.py b/proxy/http/websocket/client.py index c8c53ebee8..742eab8556 100644 --- a/proxy/http/websocket/client.py +++ b/proxy/http/websocket/client.py @@ -121,3 +121,4 @@ def run(self) -> None: except OSError: pass self.sock.close() + self.selector.close() From 9686a4c2a1e68e1adca6d776d36dcda5717761ef Mon Sep 17 00:00:00 2001 From: Abhinav Singh Date: Mon, 24 Jan 2022 19:08:49 +0530 Subject: [PATCH 2/3] Websocket client test --- proxy/http/parser/parser.py | 2 +- proxy/http/websocket/client.py | 13 ++-- tests/http/websocket/test_websocket_client.py | 77 +++++++++++++++++-- 3 files changed, 79 insertions(+), 13 deletions(-) diff --git a/proxy/http/parser/parser.py b/proxy/http/parser/parser.py index 6573cabaf6..00790c32d8 100644 --- a/proxy/http/parser/parser.py +++ b/proxy/http/parser/parser.py @@ -114,7 +114,7 @@ def response(cls: Type[T], raw: bytes) -> T: def header(self, key: bytes) -> bytes: """Convenient method to return original header value from internal data structure.""" if self.headers is None or key.lower() not in self.headers: - raise KeyError('%s not found in headers', text_(key)) + raise KeyError('%s not found in headers' % text_(key)) return self.headers[key.lower()][1] def has_header(self, key: bytes) -> bool: diff --git a/proxy/http/websocket/client.py b/proxy/http/websocket/client.py index 742eab8556..13e70e7ab9 100644 --- a/proxy/http/websocket/client.py +++ b/proxy/http/websocket/client.py @@ -27,6 +27,9 @@ class WebsocketClient(TcpConnection): + """Websocket client connection. + + TODO: Make me compatible with the work framework.""" def __init__( self, @@ -57,10 +60,14 @@ def connection(self) -> TcpOrTlsSocket: return self.sock def handshake(self) -> None: + """Start websocket upgrade & handshake protocol""" self.upgrade() self.sock.setblocking(False) def upgrade(self) -> None: + """Creates a key and sends websocket handshake packet to upstream. + Receives response from the server and asserts that websocket + accept header is valid in the response.""" key = base64.b64encode(secrets.token_bytes(16)) self.sock.send( build_websocket_handshake_request( @@ -74,12 +81,6 @@ def upgrade(self) -> None: accept = response.header(b'Sec-Websocket-Accept') assert WebsocketFrame.key_to_accept(key) == accept - def ping(self, data: Optional[bytes] = None) -> None: - pass # pragma: no cover - - def pong(self, data: Optional[bytes] = None) -> None: - pass # pragma: no cover - def shutdown(self, _data: Optional[bytes] = None) -> None: """Closes connection with the server.""" super().close() diff --git a/tests/http/websocket/test_websocket_client.py b/tests/http/websocket/test_websocket_client.py index 4e31e48e30..981392200f 100644 --- a/tests/http/websocket/test_websocket_client.py +++ b/tests/http/websocket/test_websocket_client.py @@ -8,6 +8,8 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. """ +import selectors + import unittest from unittest import mock @@ -15,18 +17,19 @@ build_websocket_handshake_request, build_websocket_handshake_response, ) from proxy.http.websocket import WebsocketFrame, WebsocketClient -from proxy.common.constants import DEFAULT_PORT +from proxy.common.constants import DEFAULT_PORT, DEFAULT_BUFFER_SIZE class TestWebsocketClient(unittest.TestCase): - @mock.patch('proxy.http.websocket.client.socket.gethostbyname') @mock.patch('base64.b64encode') + @mock.patch('proxy.http.websocket.client.socket.gethostbyname') @mock.patch('proxy.http.websocket.client.new_socket_connection') - def test_handshake( - self, mock_connect: mock.Mock, - mock_b64encode: mock.Mock, - mock_gethostbyname: mock.Mock, + def test_handshake_success( + self, + mock_connect: mock.Mock, + mock_gethostbyname: mock.Mock, + mock_b64encode: mock.Mock, ) -> None: key = b'MySecretKey' mock_b64encode.return_value = key @@ -35,9 +38,71 @@ def test_handshake( build_websocket_handshake_response( WebsocketFrame.key_to_accept(key), ) + mock_connect.assert_not_called() client = WebsocketClient(b'localhost', DEFAULT_PORT) + mock_connect.assert_called_once() mock_connect.return_value.send.assert_not_called() client.handshake() mock_connect.return_value.send.assert_called_with( build_websocket_handshake_request(key), ) + mock_connect.return_value.recv.assert_called_once_with( + DEFAULT_BUFFER_SIZE, + ) + + @mock.patch('base64.b64encode') + @mock.patch('selectors.DefaultSelector') + @mock.patch('proxy.http.websocket.client.new_socket_connection') + def test_send_recv_frames_success( + self, + mock_connect: mock.Mock, + mock_selector: mock.Mock, + mock_b64encode: mock.Mock, + ): + key = b'MySecretKey' + mock_b64encode.return_value = key + mock_connect.return_value.recv.side_effect = [ + build_websocket_handshake_response( + WebsocketFrame.key_to_accept(key), + ), + WebsocketFrame.text(b'world'), + ] + + def on_message(frame: WebsocketFrame): + assert frame.build() == WebsocketFrame.text(b'world') + + client = WebsocketClient( + b'localhost', DEFAULT_PORT, on_message=on_message, + ) + mock_selector.assert_called_once() + client.handshake() + client.queue(memoryview(WebsocketFrame.text(b'hello'))) + mock_connect.return_value.send.assert_called_once() + mock_selector.return_value.select.side_effect = [ + [ + (mock.Mock(), selectors.EVENT_WRITE), + ], + ] + client.run_once() + self.assertEqual(mock_connect.return_value.send.call_count, 2) + mock_selector.return_value.select.side_effect = [ + [ + (mock.Mock(), selectors.EVENT_READ), + ], + ] + client.run_once() + + @mock.patch('selectors.DefaultSelector') + @mock.patch('proxy.http.websocket.client.new_socket_connection') + def test_run( + self, + mock_connect: mock.Mock, + mock_selector: mock.Mock, + ) -> None: + mock_selector.return_value.select.side_effect = KeyboardInterrupt + client = WebsocketClient(b'localhost', DEFAULT_PORT) + client.run() + mock_connect.return_value.shutdown.assert_called_once() + mock_connect.return_value.close.assert_called_once() + mock_selector.return_value.unregister.assert_called_once_with(mock_connect.return_value) + mock_selector.return_value.close.assert_called_once() From defbaa2b7101b4da7e1945ceea6a36b6fa37d790 Mon Sep 17 00:00:00 2001 From: Abhinav Singh Date: Mon, 24 Jan 2022 19:10:59 +0530 Subject: [PATCH 3/3] type info --- tests/http/websocket/test_websocket_client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/http/websocket/test_websocket_client.py b/tests/http/websocket/test_websocket_client.py index 981392200f..92df52eb93 100644 --- a/tests/http/websocket/test_websocket_client.py +++ b/tests/http/websocket/test_websocket_client.py @@ -58,7 +58,7 @@ def test_send_recv_frames_success( mock_connect: mock.Mock, mock_selector: mock.Mock, mock_b64encode: mock.Mock, - ): + ) -> None: key = b'MySecretKey' mock_b64encode.return_value = key mock_connect.return_value.recv.side_effect = [ @@ -68,7 +68,7 @@ def test_send_recv_frames_success( WebsocketFrame.text(b'world'), ] - def on_message(frame: WebsocketFrame): + def on_message(frame: WebsocketFrame) -> None: assert frame.build() == WebsocketFrame.text(b'world') client = WebsocketClient(