Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion proxy/core/acceptor/acceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def run(self) -> None:
for fileno in self.socks:
self.socks[fileno].close()
self.socks.clear()
self.selector.close()
logger.debug('Acceptor#%d shutdown', self.idd)

def _recv_and_setup_socks(self) -> None:
Expand Down Expand Up @@ -207,7 +208,8 @@ def _start_local(self) -> None:
self._lthread.start()

def _stop_local(self) -> None:
if self._lthread is not None and self._local_work_queue is not None:
if self._lthread is not None and \
self._local_work_queue is not None:
self._local_work_queue.put(False)
self._lthread.join()

Expand Down
1 change: 1 addition & 0 deletions proxy/core/work/threadless.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,7 @@ def run(self) -> None:
if wqfileno is not None:
self.selector.unregister(wqfileno)
self.close_work_queue()
self.selector.close()
assert self.loop is not None
self.loop.run_until_complete(self.loop.shutdown_asyncgens())
self.loop.close()
2 changes: 2 additions & 0 deletions proxy/http/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,8 @@ def run(self) -> None:
)
finally:
self.shutdown()
if self.selector:
self.selector.close()
loop.close()

async def _run_once(self) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion proxy/http/parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def response(cls: Type[T], raw: bytes) -> T:
def header(self, key: bytes) -> bytes:
"""Convenient method to return original header value from internal data structure."""
if self.headers is None or key.lower() not in self.headers:
raise KeyError('%s not found in headers', text_(key))
raise KeyError('%s not found in headers' % text_(key))
return self.headers[key.lower()][1]

def has_header(self, key: bytes) -> bool:
Expand Down
14 changes: 8 additions & 6 deletions proxy/http/websocket/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@


class WebsocketClient(TcpConnection):
"""Websocket client connection.

TODO: Make me compatible with the work framework."""

def __init__(
self,
Expand Down Expand Up @@ -57,10 +60,14 @@ def connection(self) -> TcpOrTlsSocket:
return self.sock

def handshake(self) -> None:
"""Start websocket upgrade & handshake protocol"""
self.upgrade()
self.sock.setblocking(False)

def upgrade(self) -> None:
"""Creates a key and sends websocket handshake packet to upstream.
Receives response from the server and asserts that websocket
accept header is valid in the response."""
key = base64.b64encode(secrets.token_bytes(16))
self.sock.send(
build_websocket_handshake_request(
Expand All @@ -74,12 +81,6 @@ def upgrade(self) -> None:
accept = response.header(b'Sec-Websocket-Accept')
assert WebsocketFrame.key_to_accept(key) == accept

def ping(self, data: Optional[bytes] = None) -> None:
pass # pragma: no cover

def pong(self, data: Optional[bytes] = None) -> None:
pass # pragma: no cover

def shutdown(self, _data: Optional[bytes] = None) -> None:
"""Closes connection with the server."""
super().close()
Expand Down Expand Up @@ -121,3 +122,4 @@ def run(self) -> None:
except OSError:
pass
self.sock.close()
self.selector.close()
77 changes: 71 additions & 6 deletions tests/http/websocket/test_websocket_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,28 @@
:copyright: (c) 2013-present by Abhinav Singh and contributors.
:license: BSD, see LICENSE for more details.
"""
import selectors

import unittest
from unittest import mock

from proxy.common.utils import (
build_websocket_handshake_request, build_websocket_handshake_response,
)
from proxy.http.websocket import WebsocketFrame, WebsocketClient
from proxy.common.constants import DEFAULT_PORT
from proxy.common.constants import DEFAULT_PORT, DEFAULT_BUFFER_SIZE


class TestWebsocketClient(unittest.TestCase):

@mock.patch('proxy.http.websocket.client.socket.gethostbyname')
@mock.patch('base64.b64encode')
@mock.patch('proxy.http.websocket.client.socket.gethostbyname')
@mock.patch('proxy.http.websocket.client.new_socket_connection')
def test_handshake(
self, mock_connect: mock.Mock,
mock_b64encode: mock.Mock,
mock_gethostbyname: mock.Mock,
def test_handshake_success(
self,
mock_connect: mock.Mock,
mock_gethostbyname: mock.Mock,
mock_b64encode: mock.Mock,
) -> None:
key = b'MySecretKey'
mock_b64encode.return_value = key
Expand All @@ -35,9 +38,71 @@ def test_handshake(
build_websocket_handshake_response(
WebsocketFrame.key_to_accept(key),
)
mock_connect.assert_not_called()
client = WebsocketClient(b'localhost', DEFAULT_PORT)
mock_connect.assert_called_once()
mock_connect.return_value.send.assert_not_called()
client.handshake()
mock_connect.return_value.send.assert_called_with(
build_websocket_handshake_request(key),
)
mock_connect.return_value.recv.assert_called_once_with(
DEFAULT_BUFFER_SIZE,
)

@mock.patch('base64.b64encode')
@mock.patch('selectors.DefaultSelector')
@mock.patch('proxy.http.websocket.client.new_socket_connection')
def test_send_recv_frames_success(
self,
mock_connect: mock.Mock,
mock_selector: mock.Mock,
mock_b64encode: mock.Mock,
) -> None:
key = b'MySecretKey'
mock_b64encode.return_value = key
mock_connect.return_value.recv.side_effect = [
build_websocket_handshake_response(
WebsocketFrame.key_to_accept(key),
),
WebsocketFrame.text(b'world'),
]

def on_message(frame: WebsocketFrame) -> None:
assert frame.build() == WebsocketFrame.text(b'world')

client = WebsocketClient(
b'localhost', DEFAULT_PORT, on_message=on_message,
)
mock_selector.assert_called_once()
client.handshake()
client.queue(memoryview(WebsocketFrame.text(b'hello')))
mock_connect.return_value.send.assert_called_once()
mock_selector.return_value.select.side_effect = [
[
(mock.Mock(), selectors.EVENT_WRITE),
],
]
client.run_once()
self.assertEqual(mock_connect.return_value.send.call_count, 2)
mock_selector.return_value.select.side_effect = [
[
(mock.Mock(), selectors.EVENT_READ),
],
]
client.run_once()

@mock.patch('selectors.DefaultSelector')
@mock.patch('proxy.http.websocket.client.new_socket_connection')
def test_run(
self,
mock_connect: mock.Mock,
mock_selector: mock.Mock,
) -> None:
mock_selector.return_value.select.side_effect = KeyboardInterrupt
client = WebsocketClient(b'localhost', DEFAULT_PORT)
client.run()
mock_connect.return_value.shutdown.assert_called_once()
mock_connect.return_value.close.assert_called_once()
mock_selector.return_value.unregister.assert_called_once_with(mock_connect.return_value)
mock_selector.return_value.close.assert_called_once()