Skip to content

Commit

Permalink
Customize the input buffer size (#5065)
Browse files Browse the repository at this point in the history
  • Loading branch information
asvetlov committed Oct 17, 2020
1 parent 61cc4ff commit aa4ef4b
Show file tree
Hide file tree
Showing 18 changed files with 235 additions and 67 deletions.
1 change: 1 addition & 0 deletions CHANGES/4453.feature
@@ -0,0 +1 @@
Allow configuring the buffer size of input stream by passing ``read_bufsize`` argument.
22 changes: 14 additions & 8 deletions aiohttp/_http_parser.pyx
Expand Up @@ -303,6 +303,7 @@ cdef class HttpParser:
object _payload_exception
object _last_error
bint _auto_decompress
int _limit

str _content_encoding

Expand All @@ -324,7 +325,8 @@ cdef class HttpParser:
PyMem_Free(self._csettings)

cdef _init(self, cparser.http_parser_type mode,
object protocol, object loop, object timer=None,
object protocol, object loop, int limit,
object timer=None,
size_t max_line_size=8190, size_t max_headers=32768,
size_t max_field_size=8190, payload_exception=None,
bint response_with_body=True, bint read_until_eof=False,
Expand Down Expand Up @@ -370,6 +372,7 @@ cdef class HttpParser:
self._csettings.on_chunk_complete = cb_on_chunk_complete

self._last_error = None
self._limit = limit

cdef _process_header(self):
if self._raw_name:
Expand Down Expand Up @@ -454,7 +457,8 @@ cdef class HttpParser:
self._read_until_eof)
):
payload = StreamReader(
self._protocol, timer=self._timer, loop=self._loop)
self._protocol, timer=self._timer, loop=self._loop,
limit=self._limit)
else:
payload = EMPTY_PAYLOAD

Expand Down Expand Up @@ -563,11 +567,12 @@ cdef class HttpParser:

cdef class HttpRequestParser(HttpParser):

def __init__(self, protocol, loop, timer=None,
def __init__(self, protocol, loop, int limit, timer=None,
size_t max_line_size=8190, size_t max_headers=32768,
size_t max_field_size=8190, payload_exception=None,
bint response_with_body=True, bint read_until_eof=False):
self._init(cparser.HTTP_REQUEST, protocol, loop, timer,
bint response_with_body=True, bint read_until_eof=False,
):
self._init(cparser.HTTP_REQUEST, protocol, loop, limit, timer,
max_line_size, max_headers, max_field_size,
payload_exception, response_with_body, read_until_eof)

Expand All @@ -590,12 +595,13 @@ cdef class HttpRequestParser(HttpParser):

cdef class HttpResponseParser(HttpParser):

def __init__(self, protocol, loop, timer=None,
def __init__(self, protocol, loop, int limit, timer=None,
size_t max_line_size=8190, size_t max_headers=32768,
size_t max_field_size=8190, payload_exception=None,
bint response_with_body=True, bint read_until_eof=False,
bint auto_decompress=True):
self._init(cparser.HTTP_RESPONSE, protocol, loop, timer,
bint auto_decompress=True
):
self._init(cparser.HTTP_RESPONSE, protocol, loop, limit, timer,
max_line_size, max_headers, max_field_size,
payload_exception, response_with_body, read_until_eof,
auto_decompress)
Expand Down
25 changes: 18 additions & 7 deletions aiohttp/client.py
Expand Up @@ -180,7 +180,8 @@ class ClientSession:
'_timeout', '_raise_for_status', '_auto_decompress',
'_trust_env', '_default_headers', '_skip_auto_headers',
'_request_class', '_response_class',
'_ws_response_class', '_trace_configs'])
'_ws_response_class', '_trace_configs',
'_read_bufsize'])

_source_traceback = None

Expand All @@ -204,7 +205,8 @@ def __init__(self, *, connector: Optional[BaseConnector]=None,
auto_decompress: bool=True,
trust_env: bool=False,
requote_redirect_url: bool=True,
trace_configs: Optional[List[TraceConfig]]=None) -> None:
trace_configs: Optional[List[TraceConfig]]=None,
read_bufsize: int=2**16) -> None:

if loop is None:
if connector is not None:
Expand Down Expand Up @@ -265,6 +267,7 @@ def __init__(self, *, connector: Optional[BaseConnector]=None,
self._auto_decompress = auto_decompress
self._trust_env = trust_env
self._requote_redirect_url = requote_redirect_url
self._read_bufsize = read_bufsize

# Convert to list of tuples
if headers:
Expand Down Expand Up @@ -349,7 +352,8 @@ async def _request(
ssl_context: Optional[SSLContext]=None,
ssl: Optional[Union[SSLContext, bool, Fingerprint]]=None,
proxy_headers: Optional[LooseHeaders]=None,
trace_request_ctx: Optional[SimpleNamespace]=None
trace_request_ctx: Optional[SimpleNamespace]=None,
read_bufsize: Optional[int] = None
) -> ClientResponse:

# NOTE: timeout clamps existing connect and read timeouts. We cannot
Expand Down Expand Up @@ -407,6 +411,9 @@ async def _request(
tm = TimeoutHandle(self._loop, real_timeout.total)
handle = tm.start()

if read_bufsize is None:
read_bufsize = self._read_bufsize

traces = [
Trace(
self,
Expand Down Expand Up @@ -498,7 +505,8 @@ async def _request(
skip_payload=method.upper() == 'HEAD',
read_until_eof=read_until_eof,
auto_decompress=self._auto_decompress,
read_timeout=real_timeout.sock_read)
read_timeout=real_timeout.sock_read,
read_bufsize=read_bufsize)

try:
try:
Expand Down Expand Up @@ -805,7 +813,7 @@ async def _ws_connect(
transport = conn.transport
assert transport is not None
reader = FlowControlDataQueue(
conn_proto, limit=2 ** 16, loop=self._loop) # type: FlowControlDataQueue[WSMessage] # noqa
conn_proto, 2 ** 16, loop=self._loop) # type: FlowControlDataQueue[WSMessage] # noqa
conn_proto.set_parser(
WebSocketReader(reader, max_msg_size), reader)
writer = WebSocketWriter(
Expand Down Expand Up @@ -1149,6 +1157,7 @@ def request(
cookies: Optional[LooseCookies]=None,
version: HttpVersion=http.HttpVersion11,
connector: Optional[BaseConnector]=None,
read_bufsize: Optional[int] = None,
loop: Optional[asyncio.AbstractEventLoop]=None
) -> _SessionRequestContextManager:
"""Constructs and sends a request. Returns response object.
Expand Down Expand Up @@ -1210,5 +1219,7 @@ def request(
raise_for_status=raise_for_status,
read_until_eof=read_until_eof,
proxy=proxy,
proxy_auth=proxy_auth,),
session)
proxy_auth=proxy_auth,
read_bufsize=read_bufsize),
session
)
5 changes: 3 additions & 2 deletions aiohttp/client_proto.py
Expand Up @@ -137,14 +137,15 @@ def set_response_params(self, *, timer: BaseTimerContext=None,
skip_payload: bool=False,
read_until_eof: bool=False,
auto_decompress: bool=True,
read_timeout: Optional[float]=None) -> None:
read_timeout: Optional[float]=None,
read_bufsize: int = 2 ** 16) -> None:
self._skip_payload = skip_payload

self._read_timeout = read_timeout
self._reschedule_timeout()

self._parser = HttpResponseParser(
self, self._loop, timer=timer,
self, self._loop, read_bufsize, timer=timer,
payload_exception=ClientPayloadError,
response_with_body=not skip_payload,
read_until_eof=read_until_eof,
Expand Down
11 changes: 8 additions & 3 deletions aiohttp/http_parser.py
Expand Up @@ -168,6 +168,7 @@ class HttpParser(abc.ABC):

def __init__(self, protocol: Optional[BaseProtocol]=None,
loop: Optional[asyncio.AbstractEventLoop]=None,
limit: int=2**16,
max_line_size: int=8190,
max_headers: int=32768,
max_field_size: int=8190,
Expand Down Expand Up @@ -198,6 +199,7 @@ def __init__(self, protocol: Optional[BaseProtocol]=None,
self._payload = None
self._payload_parser = None # type: Optional[HttpPayloadParser]
self._auto_decompress = auto_decompress
self._limit = limit
self._headers_parser = HeadersParser(max_line_size,
max_headers,
max_field_size)
Expand Down Expand Up @@ -288,7 +290,8 @@ def feed_data(
if ((length is not None and length > 0) or
msg.chunked and not msg.upgrade):
payload = StreamReader(
self.protocol, timer=self.timer, loop=loop)
self.protocol, timer=self.timer, loop=loop,
limit=self._limit)
payload_parser = HttpPayloadParser(
payload, length=length,
chunked=msg.chunked, method=method,
Expand All @@ -300,7 +303,8 @@ def feed_data(
self._payload_parser = payload_parser
elif method == METH_CONNECT:
payload = StreamReader(
self.protocol, timer=self.timer, loop=loop)
self.protocol, timer=self.timer, loop=loop,
limit=self._limit)
self._upgraded = True
self._payload_parser = HttpPayloadParser(
payload, method=msg.method,
Expand All @@ -310,7 +314,8 @@ def feed_data(
if (getattr(msg, 'code', 100) >= 199 and
length is None and self.read_until_eof):
payload = StreamReader(
self.protocol, timer=self.timer, loop=loop)
self.protocol, timer=self.timer, loop=loop,
limit=self._limit)
payload_parser = HttpPayloadParser(
payload, length=length,
chunked=msg.chunked, method=method,
Expand Down
10 changes: 5 additions & 5 deletions aiohttp/payload.py
Expand Up @@ -33,7 +33,7 @@
parse_mimetype,
sentinel,
)
from .streams import DEFAULT_LIMIT, StreamReader
from .streams import StreamReader
from .typedefs import JSONEncoder, _CIMultiDict

__all__ = ('PAYLOAD_REGISTRY', 'get_payload', 'payload_type', 'Payload',
Expand Down Expand Up @@ -295,12 +295,12 @@ async def write(self, writer: AbstractStreamWriter) -> None:
loop = asyncio.get_event_loop()
try:
chunk = await loop.run_in_executor(
None, self._value.read, DEFAULT_LIMIT
None, self._value.read, 2**16
)
while chunk:
await writer.write(chunk)
chunk = await loop.run_in_executor(
None, self._value.read, DEFAULT_LIMIT
None, self._value.read, 2**16
)
finally:
await loop.run_in_executor(None, self._value.close)
Expand Down Expand Up @@ -345,12 +345,12 @@ async def write(self, writer: AbstractStreamWriter) -> None:
loop = asyncio.get_event_loop()
try:
chunk = await loop.run_in_executor(
None, self._value.read, DEFAULT_LIMIT
None, self._value.read, 2**16
)
while chunk:
await writer.write(chunk.encode(self._encoding))
chunk = await loop.run_in_executor(
None, self._value.read, DEFAULT_LIMIT
None, self._value.read, 2**16
)
finally:
await loop.run_in_executor(None, self._value.close)
Expand Down
14 changes: 7 additions & 7 deletions aiohttp/streams.py
Expand Up @@ -17,8 +17,6 @@
'EMPTY_PAYLOAD', 'EofStream', 'StreamReader', 'DataQueue',
'FlowControlDataQueue')

DEFAULT_LIMIT = 2 ** 16

_T = TypeVar('_T')


Expand Down Expand Up @@ -105,8 +103,7 @@ class StreamReader(AsyncStreamReaderMixin):

total_bytes = 0

def __init__(self, protocol: BaseProtocol,
*, limit: int=DEFAULT_LIMIT,
def __init__(self, protocol: BaseProtocol, limit: int, *,
timer: Optional[BaseTimerContext]=None,
loop: Optional[asyncio.AbstractEventLoop]=None) -> None:
self._protocol = protocol
Expand All @@ -133,14 +130,17 @@ def __repr__(self) -> str:
info.append('%d bytes' % self._size)
if self._eof:
info.append('eof')
if self._low_water != DEFAULT_LIMIT:
if self._low_water != 2 ** 16: # default limit
info.append('low=%d high=%d' % (self._low_water, self._high_water))
if self._waiter:
info.append('w=%r' % self._waiter)
if self._exception:
info.append('e=%r' % self._exception)
return '<%s>' % ' '.join(info)

def get_read_buffer_limits(self) -> Tuple[int, int]:
return (self._low_water, self._high_water)

def exception(self) -> Optional[BaseException]:
return self._exception

Expand Down Expand Up @@ -612,8 +612,8 @@ class FlowControlDataQueue(DataQueue[_T]):
It is a destination for parsed data."""

def __init__(self, protocol: BaseProtocol, *,
limit: int=DEFAULT_LIMIT,
def __init__(self, protocol: BaseProtocol,
limit: int, *,
loop: asyncio.AbstractEventLoop) -> None:
super().__init__(loop=loop)

Expand Down
5 changes: 3 additions & 2 deletions aiohttp/web_protocol.py
Expand Up @@ -129,7 +129,8 @@ def __init__(self, manager: 'Server', *,
max_line_size: int=8190,
max_headers: int=32768,
max_field_size: int=8190,
lingering_time: float=10.0):
lingering_time: float=10.0,
read_bufsize: int=2 ** 16):

super().__init__(loop)

Expand All @@ -156,7 +157,7 @@ def __init__(self, manager: 'Server', *,
self._upgrade = False
self._payload_parser = None # type: Any
self._request_parser = HttpRequestParser(
self, loop,
self, loop, read_bufsize,
max_line_size=max_line_size,
max_field_size=max_field_size,
max_headers=max_headers,
Expand Down
2 changes: 1 addition & 1 deletion aiohttp/web_ws.py
Expand Up @@ -228,7 +228,7 @@ def _post_start(self, request: BaseRequest,
loop = self._loop
assert loop is not None
self._reader = FlowControlDataQueue(
request._protocol, limit=2 ** 16, loop=loop)
request._protocol, 2 ** 16, loop=loop)
request.protocol.set_parser(WebSocketReader(
self._reader, self._max_msg_size, compress=self._compress))
# disable HTTP keepalive for WebSocket
Expand Down

0 comments on commit aa4ef4b

Please sign in to comment.