From 5999888688251c4b50dc27782d9baf634b6ec110 Mon Sep 17 00:00:00 2001 From: Abhinav Singh Date: Sat, 13 Jun 2020 15:54:10 +0530 Subject: [PATCH] Add DEFAULT_MAX_SEND_SIZE and handle SSLWantWriteError errors when dispatching to upstream servers --- proxy/common/constants.py | 1 + proxy/core/connection/connection.py | 8 ++++---- proxy/http/proxy/server.py | 9 ++++++--- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/proxy/common/constants.py b/proxy/common/constants.py index a4a32bf19e..fc9dac9372 100644 --- a/proxy/common/constants.py +++ b/proxy/common/constants.py @@ -74,3 +74,4 @@ DEFAULT_TIMEOUT = 10 DEFAULT_VERSION = False DEFAULT_HTTP_PORT = 80 +DEFAULT_MAX_SEND_SIZE = 16 * 1024 diff --git a/proxy/core/connection/connection.py b/proxy/core/connection/connection.py index 29c15c1e34..3aa72eebca 100644 --- a/proxy/core/connection/connection.py +++ b/proxy/core/connection/connection.py @@ -14,7 +14,7 @@ from abc import ABC, abstractmethod from typing import NamedTuple, Optional, Union, List -from ...common.constants import DEFAULT_BUFFER_SIZE +from ...common.constants import DEFAULT_BUFFER_SIZE, DEFAULT_MAX_SEND_SIZE logger = logging.getLogger(__name__) @@ -82,11 +82,11 @@ def flush(self) -> int: """Users must handle BrokenPipeError exceptions""" if not self.has_buffer(): return 0 - mv = self.buffer[0] - sent: int = self.send(mv.tobytes()) + mv = self.buffer[0].tobytes() + sent: int = self.send(mv[:DEFAULT_MAX_SEND_SIZE]) if sent == len(mv): self.buffer.pop(0) else: - self.buffer[0] = memoryview(mv.tobytes()[sent:]) + self.buffer[0] = memoryview(mv[sent:]) logger.debug('flushed %d bytes to %s' % (sent, self.tag)) return sent diff --git a/proxy/http/proxy/server.py b/proxy/http/proxy/server.py index e028c5be43..4b885d80ad 100644 --- a/proxy/http/proxy/server.py +++ b/proxy/http/proxy/server.py @@ -89,12 +89,15 @@ def write_to_descriptors(self, w: List[Union[int, HasFileno]]) -> bool: logger.debug('Server is write ready, flushing buffer') try: self.server.flush() + except ssl.SSLWantWriteError: + logger.warning('SSLWantWriteError while trying to flush to server, will retry') + return False except BrokenPipeError: logger.error( 'BrokenPipeError when flushing buffer for server') return True - except OSError: - logger.error('OSError when flushing buffer to server') + except OSError as e: + logger.exception('OSError when flushing buffer to server', exc_info=e) return True return False @@ -207,7 +210,6 @@ def on_client_data(self, raw: memoryview) -> Optional[memoryview]: if self.request.state == httpParserStates.COMPLETE and ( self.request.method != httpMethods.CONNECT or self.flags.tls_interception_enabled()): - if self.pipeline_request is not None and \ self.pipeline_request.is_connection_upgrade(): # Previous pipelined request was a WebSocket @@ -219,6 +221,7 @@ def on_client_data(self, raw: memoryview) -> Optional[memoryview]: if self.pipeline_request is None: self.pipeline_request = HttpParser( httpParserTypes.REQUEST_PARSER) + # TODO(abhinavsingh): Remove .tobytes after parser is # memoryview compliant self.pipeline_request.parse(raw.tobytes())