Skip to content

Commit

Permalink
Client timeouts (#2972)
Browse files Browse the repository at this point in the history
  • Loading branch information
asvetlov committed May 11, 2018
1 parent 8cbe05c commit 7d136fb
Show file tree
Hide file tree
Showing 14 changed files with 372 additions and 210 deletions.
1 change: 1 addition & 0 deletions CHANGES/2768.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Implement ``ClientTimeout`` class and support socket read timeout.
84 changes: 69 additions & 15 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import warnings
from collections.abc import Coroutine

import attr
from multidict import CIMultiDict, MultiDict, MultiDictProxy, istr
from yarl import URL

Expand Down Expand Up @@ -38,11 +39,33 @@
__all__ = (client_exceptions.__all__ + # noqa
client_reqrep.__all__ + # noqa
connector_mod.__all__ + # noqa
('ClientSession', 'ClientWebSocketResponse', 'request'))
('ClientSession', 'ClientTimeout',
'ClientWebSocketResponse', 'request'))


# 5 Minute default read and connect timeout
DEFAULT_TIMEOUT = 5 * 60
@attr.s(frozen=True, slots=True)
class ClientTimeout:
total = attr.ib(type=float, default=None)
connect = attr.ib(type=float, default=None)
sock_read = attr.ib(type=float, default=None)
sock_connect = attr.ib(type=float, default=None)

# pool_queue_timeout = attr.ib(type=float, default=None)
# dns_resolution_timeout = attr.ib(type=float, default=None)
# socket_connect_timeout = attr.ib(type=float, default=None)
# connection_acquiring_timeout = attr.ib(type=float, default=None)
# new_connection_timeout = attr.ib(type=float, default=None)
# http_header_timeout = attr.ib(type=float, default=None)
# response_body_timeout = attr.ib(type=float, default=None)

# to create a timeout specific for a single request, either
# - create a completely new one to overwrite the default
# - or use http://www.attrs.org/en/stable/api.html#attr.evolve
# to overwrite the defaults


# 5 Minute default read timeout
DEFAULT_TIMEOUT = ClientTimeout(total=5*60)


class ClientSession:
Expand All @@ -52,8 +75,8 @@ class ClientSession:
'_source_traceback', '_connector',
'requote_redirect_url', '_loop', '_cookie_jar',
'_connector_owner', '_default_auth',
'_version', '_json_serialize', '_read_timeout',
'_conn_timeout', '_raise_for_status', '_auto_decompress',
'_version', '_json_serialize',
'_timeout', '_raise_for_status', '_auto_decompress',
'_trust_env', '_default_headers', '_skip_auto_headers',
'_request_class', '_response_class',
'_ws_response_class', '_trace_configs'])
Expand All @@ -71,6 +94,7 @@ def __init__(self, *, connector=None, loop=None, cookies=None,
version=http.HttpVersion11,
cookie_jar=None, connector_owner=True, raise_for_status=False,
read_timeout=sentinel, conn_timeout=None,
timeout=sentinel,
auto_decompress=True, trust_env=False,
trace_configs=None):

Expand Down Expand Up @@ -117,9 +141,26 @@ def __init__(self, *, connector=None, loop=None, cookies=None,
self._default_auth = auth
self._version = version
self._json_serialize = json_serialize
self._read_timeout = (read_timeout if read_timeout is not sentinel
else DEFAULT_TIMEOUT)
self._conn_timeout = conn_timeout
if timeout is not sentinel:
self._timeout = timeout
else:
self._timeout = DEFAULT_TIMEOUT
if read_timeout is not sentinel:
if timeout is not sentinel:
raise ValueError("read_timeout and timeout parameters "
"conflict, please setup "
"timeout.read")
else:
self._timeout = attr.evolve(self._timeout,
total=read_timeout)
if conn_timeout is not None:
if timeout is not sentinel:
raise ValueError("conn_timeout and timeout parameters "
"conflict, please setup "
"timeout.connect")
else:
self._timeout = attr.evolve(self._timeout,
connect=conn_timeout)
self._raise_for_status = raise_for_status
self._auto_decompress = auto_decompress
self._trust_env = trust_env
Expand Down Expand Up @@ -244,11 +285,14 @@ async def _request(self, method, url, *,
except ValueError:
raise InvalidURL(proxy)

if timeout is sentinel:
timeout = self._timeout
else:
if not isinstance(timeout, ClientTimeout):
timeout = ClientTimeout(total=timeout)
# timeout is cumulative for all request operations
# (request, redirects, responses, data consuming)
tm = TimeoutHandle(
self._loop,
timeout if timeout is not sentinel else self._read_timeout)
tm = TimeoutHandle(self._loop, timeout.total)
handle = tm.start()

traces = [
Expand Down Expand Up @@ -309,15 +353,17 @@ async def _request(self, method, url, *,
expect100=expect100, loop=self._loop,
response_class=self._response_class,
proxy=proxy, proxy_auth=proxy_auth, timer=timer,
session=self, auto_decompress=self._auto_decompress,
session=self,
ssl=ssl, proxy_headers=proxy_headers, traces=traces)

# connection timeout
try:
with CeilTimeout(self._conn_timeout, loop=self._loop):
with CeilTimeout(self._timeout.connect,
loop=self._loop):
conn = await self._connector.connect(
req,
traces=traces
traces=traces,
timeout=timeout
)
except asyncio.TimeoutError as exc:
raise ServerTimeoutError(
Expand All @@ -326,11 +372,19 @@ async def _request(self, method, url, *,

tcp_nodelay(conn.transport, True)
tcp_cork(conn.transport, False)

conn.protocol.set_response_params(
timer=timer,
skip_payload=method.upper() == 'HEAD',
read_until_eof=read_until_eof,
auto_decompress=self._auto_decompress,
read_timeout=timeout.sock_read)

try:
try:
resp = await req.send(conn)
try:
await resp.start(conn, read_until_eof)
await resp.start(conn)
except BaseException:
resp.close()
raise
Expand Down
55 changes: 49 additions & 6 deletions aiohttp/client_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from .base_protocol import BaseProtocol
from .client_exceptions import (ClientOSError, ClientPayloadError,
ServerDisconnectedError)
ServerDisconnectedError, ServerTimeoutError)
from .http import HttpResponseParser
from .streams import EMPTY_PAYLOAD, DataQueue

Expand All @@ -16,7 +16,6 @@ def __init__(self, *, loop=None):

self._should_close = False

self._message = None
self._payload = None
self._skip_payload = False
self._payload_parser = None
Expand All @@ -28,6 +27,9 @@ def __init__(self, *, loop=None):
self._upgraded = False
self._parser = None

self._read_timeout = None
self._read_timeout_handle = None

@property
def upgraded(self):
return self._upgraded
Expand Down Expand Up @@ -55,6 +57,8 @@ def is_connected(self):
return self.transport is not None

def connection_lost(self, exc):
self._drop_timeout()

if self._payload_parser is not None:
with suppress(Exception):
self._payload_parser.feed_eof()
Expand All @@ -78,15 +82,15 @@ def connection_lost(self, exc):

self._should_close = True
self._parser = None
self._message = None
self._payload = None
self._payload_parser = None
self._reading_paused = False

super().connection_lost(exc)

def eof_received(self):
pass
# should call parser.feed_eof() most likely
self._drop_timeout()

def pause_reading(self):
if not self._reading_paused:
Expand All @@ -95,6 +99,7 @@ def pause_reading(self):
except (AttributeError, NotImplementedError, RuntimeError):
pass
self._reading_paused = True
self._drop_timeout()

def resume_reading(self):
if self._reading_paused:
Expand All @@ -103,24 +108,33 @@ def resume_reading(self):
except (AttributeError, NotImplementedError, RuntimeError):
pass
self._reading_paused = False
self._reschedule_timeout()

def set_exception(self, exc):
self._should_close = True
self._drop_timeout()
super().set_exception(exc)

def set_parser(self, parser, payload):
self._payload = payload
self._payload_parser = parser

self._drop_timeout()

if self._tail:
data, self._tail = self._tail, b''
self.data_received(data)

def set_response_params(self, *, timer=None,
skip_payload=False,
read_until_eof=False,
auto_decompress=True):
auto_decompress=True,
read_timeout=None):
self._skip_payload = skip_payload

self._read_timeout = read_timeout
self._reschedule_timeout()

self._parser = HttpResponseParser(
self, self._loop, timer=timer,
payload_exception=ClientPayloadError,
Expand All @@ -131,6 +145,26 @@ def set_response_params(self, *, timer=None,
data, self._tail = self._tail, b''
self.data_received(data)

def _drop_timeout(self):
if self._read_timeout_handle is not None:
self._read_timeout_handle.cancel()
self._read_timeout_handle = None

def _reschedule_timeout(self):
timeout = self._read_timeout
if self._read_timeout_handle is not None:
self._read_timeout_handle.cancel()

if timeout:
self._read_timeout_handle = self._loop.call_later(
timeout, self._on_read_timeout)
else:
self._read_timeout_handle = None

def _on_read_timeout(self):
self.set_exception(
ServerTimeoutError("Timeout on reading data from socket"))

def data_received(self, data):
if not data:
return
Expand Down Expand Up @@ -161,17 +195,26 @@ def data_received(self, data):

self._upgraded = upgraded

payload = None
for message, payload in messages:
if message.should_close:
self._should_close = True

self._message = message
self._payload = payload

if self._skip_payload or message.code in (204, 304):
self.feed_data((message, EMPTY_PAYLOAD), 0)
else:
self.feed_data((message, payload), 0)
if payload is not None:
# new message(s) was processed
# register timeout handler unsubscribing
# either on end-of-stream or immediatelly for
# EMPTY_PAYLOAD
if payload is not EMPTY_PAYLOAD:
payload.on_eof(self._drop_timeout)
else:
self._drop_timeout()

if tail:
if upgraded:
Expand Down
17 changes: 4 additions & 13 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def __init__(self, method, url, *,
chunked=None, expect100=False,
loop=None, response_class=None,
proxy=None, proxy_auth=None,
timer=None, session=None, auto_decompress=True,
timer=None, session=None,
ssl=None,
proxy_headers=None,
traces=None):
Expand All @@ -214,7 +214,6 @@ def __init__(self, method, url, *,
self.length = None
self.response_class = response_class or ClientResponse
self._timer = timer if timer is not None else TimerNoop()
self._auto_decompress = auto_decompress
self._ssl = ssl

if loop.get_debug():
Expand Down Expand Up @@ -551,7 +550,6 @@ async def send(self, conn):
self.method, self.original_url,
writer=self._writer, continue100=self._continue, timer=self._timer,
request_info=self.request_info,
auto_decompress=self._auto_decompress,
traces=self._traces,
loop=self.loop,
session=self._session
Expand Down Expand Up @@ -597,7 +595,7 @@ class ClientResponse(HeadersMixin):

def __init__(self, method, url, *,
writer, continue100, timer,
request_info, auto_decompress,
request_info,
traces, loop, session):
assert isinstance(url, URL)

Expand All @@ -614,7 +612,6 @@ def __init__(self, method, url, *,
self._history = ()
self._request_info = request_info
self._timer = timer if timer is not None else TimerNoop()
self._auto_decompress = auto_decompress # True by default
self._cache = {} # required for @reify method decorator
self._traces = traces
self._loop = loop
Expand Down Expand Up @@ -735,23 +732,17 @@ def links(self):

return MultiDictProxy(links)

async def start(self, connection, read_until_eof=False):
async def start(self, connection):
"""Start response processing."""
self._closed = False
self._protocol = connection.protocol
self._connection = connection

connection.protocol.set_response_params(
timer=self._timer,
skip_payload=self.method.lower() == 'head',
read_until_eof=read_until_eof,
auto_decompress=self._auto_decompress)

with self._timer:
while True:
# read response
try:
(message, payload) = await self._protocol.read()
message, payload = await self._protocol.read()
except http.HttpProcessingError as exc:
raise ClientResponseError(
self.request_info, self.history,
Expand Down
Loading

0 comments on commit 7d136fb

Please sign in to comment.