Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions proxy/core/connection/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"""
import logging
from abc import ABC, abstractmethod
from typing import List, Optional
from typing import List, Union, Optional

from .types import tcpConnectionTypes
from ...common.types import TcpOrTlsSocket
Expand Down Expand Up @@ -47,7 +47,7 @@ def connection(self) -> TcpOrTlsSocket:
"""Must return the socket connection to use in this class."""
raise TcpConnectionUninitializedException() # pragma: no cover

def send(self, data: bytes) -> int:
def send(self, data: Union[memoryview, bytes]) -> int:
"""Users must handle BrokenPipeError exceptions"""
# logger.info(data)
return self.connection.send(data)
Expand Down Expand Up @@ -83,16 +83,16 @@ def flush(self, max_send_size: Optional[int] = None) -> int:
"""Users must handle BrokenPipeError exceptions"""
if not self.has_buffer():
return 0
mv = self.buffer[0].tobytes()
max_send_size = max_send_size or DEFAULT_MAX_SEND_SIZE
mv = self.buffer[0]
# TODO: Assemble multiple packets if total
# size remains below max send size.
max_send_size = max_send_size or DEFAULT_MAX_SEND_SIZE
sent: int = self.send(mv[:max_send_size])
if sent == len(mv):
self.buffer.pop(0)
self._num_buffer -= 1
else:
self.buffer[0] = memoryview(mv[sent:])
self.buffer[0] = mv[sent:]
del mv
logger.debug('flushed %d bytes to %s' % (sent, self.tag))
return sent
Expand Down
5 changes: 2 additions & 3 deletions proxy/http/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,10 +267,9 @@ def _discover_plugin_klass(self, protocol: int) -> Optional[Type['HttpProtocolHa

def _parse_first_request(self, data: memoryview) -> bool:
# Parse http request
#
# TODO(abhinavsingh): Remove .tobytes after parser is
# memoryview compliant
try:
# TODO(abhinavsingh): Remove .tobytes after parser is
# memoryview compliant
self.request.parse(data.tobytes())
except HttpProtocolException as e: # noqa: WPS329
self.work.queue(BAD_REQUEST_RESPONSE_PKT)
Expand Down
6 changes: 4 additions & 2 deletions proxy/http/server/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,10 @@ async def read_from_descriptors(self, r: Readables) -> bool:

def on_client_data(self, raw: memoryview) -> None:
if self.switched_protocol == httpProtocolTypes.WEBSOCKET:
# TODO(abhinavsingh): Remove .tobytes after websocket frame parser
# is memoryview compliant
# TODO(abhinavsingh): Do we really tobytes() here?
# Websocket parser currently doesn't depend on internal
# buffers, due to which it can directly parse out of
# memory views. But how about large payloads scenarios?
remaining = raw.tobytes()
frame = WebsocketFrame()
while remaining != b'':
Expand Down
2 changes: 1 addition & 1 deletion proxy/http/websocket/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def run_once(self) -> bool:
if mask & selectors.EVENT_READ and self.on_message:
# TODO: client recvbuf size flag currently not used here
raw = self.recv()
if raw is None or raw.tobytes() == b'':
if raw is None or raw == b'':
self.closed = True
return True
frame = WebsocketFrame()
Expand Down
2 changes: 1 addition & 1 deletion tests/http/test_protocol_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ async def assert_data_queued(
CRLF,
])
server.queue.assert_called_once()
self.assertEqual(server.queue.call_args_list[0][0][0].tobytes(), pkt)
self.assertEqual(server.queue.call_args_list[0][0][0], pkt)
server.buffer_size.return_value = len(pkt)

async def assert_data_queued_to_server(self, server: mock.Mock) -> None:
Expand Down