diff --git a/aiohttp/errors.py b/aiohttp/errors.py index 2162e469168..35f848e0bb1 100644 --- a/aiohttp/errors.py +++ b/aiohttp/errors.py @@ -159,11 +159,7 @@ def __init__(self, line=''): self.line = line -class ParserError(Exception): - """Base parser error.""" - - -class LineLimitExceededParserError(ParserError): +class LineLimitExceededParserError(HttpBadRequest): """Line is too long.""" def __init__(self, msg, limit): diff --git a/aiohttp/server.py b/aiohttp/server.py index c24d29f4058..e28095badc3 100644 --- a/aiohttp/server.py +++ b/aiohttp/server.py @@ -6,11 +6,10 @@ import traceback import warnings from html import escape as html_escape -from math import ceil import aiohttp from aiohttp import errors, hdrs, helpers, streams -from aiohttp.helpers import _get_kwarg, ensure_future +from aiohttp.helpers import Timeout, _get_kwarg, ensure_future from aiohttp.log import access_logger, server_logger __all__ = ('ServerHttpProtocol',) @@ -53,15 +52,11 @@ class ServerHttpProtocol(aiohttp.StreamProtocol): :param keepalive_timeout: number of seconds before closing keep-alive connection - :type keepalive: int or None + :type keepalive_timeout: int or None :param bool tcp_keepalive: TCP keep-alive is on, default is on - :param int timeout: slow request timeout - - :param allowed_methods: (optional) List of allowed request methods. - Set to empty list to allow all methods. - :type allowed_methods: tuple + :param int slow_request_timeout: slow request timeout :param bool debug: enable debug mode @@ -86,8 +81,6 @@ class ServerHttpProtocol(aiohttp.StreamProtocol): _request_handler = None _reading_request = False _keep_alive = False # keep transport open - _keep_alive_handle = None # keep alive timer handle - _slow_request_timeout_handle = None # slow request timer handle def __init__(self, *, loop=None, keepalive_timeout=75, # NGINX default value is 75 secs @@ -138,6 +131,7 @@ def __init__(self, *, loop=None, access_log_format) else: self.access_logger = None + self._closing = False @property def keep_alive_timeout(self): @@ -157,6 +151,7 @@ def closing(self, timeout=15.0): self._keep_alive = False self._tcp_keep_alive = False self._keepalive_timeout = None + self._closing = True if (not self._reading_request and self.transport is not None): if self._request_handler: @@ -165,27 +160,12 @@ def closing(self, timeout=15.0): self.transport.close() self.transport = None - elif self.transport is not None and timeout: - if self._slow_request_timeout_handle is not None: - self._slow_request_timeout_handle.cancel() - - # use slow request timeout for closing - # connection_lost cleans timeout handler - now = self._loop.time() - self._slow_request_timeout_handle = self._loop.call_at( - ceil(now+timeout), self.cancel_slow_request) def connection_made(self, transport): super().connection_made(transport) self._request_handler = ensure_future(self.start(), loop=self._loop) - # start slow request timer - if self._slow_request_timeout: - now = self._loop.time() - self._slow_request_timeout_handle = self._loop.call_at( - ceil(now+self._slow_request_timeout), self.cancel_slow_request) - if self._tcp_keepalive: tcp_keepalive(self, transport) @@ -195,12 +175,6 @@ def connection_lost(self, exc): if self._request_handler is not None: self._request_handler.cancel() self._request_handler = None - if self._keep_alive_handle is not None: - self._keep_alive_handle.cancel() - self._keep_alive_handle = None - if self._slow_request_timeout_handle is not None: - self._slow_request_timeout_handle.cancel() - self._slow_request_timeout_handle = None def data_received(self, data): super().data_received(data) @@ -209,11 +183,6 @@ def data_received(self, data): if not self._reading_request: self._reading_request = True - # stop keep-alive timer - if self._keep_alive_handle is not None: - self._keep_alive_handle.cancel() - self._keep_alive_handle = None - def keep_alive(self, val): """Set keep-alive connection mode. @@ -233,16 +202,6 @@ def log_debug(self, *args, **kw): def log_exception(self, *args, **kw): self.logger.exception(*args, **kw) - def cancel_slow_request(self): - if self._request_handler is not None: - self._request_handler.cancel() - self._request_handler = None - - if self.transport is not None: - self.transport.close() - - self.log_debug('Close slow request.') - @asyncio.coroutine def start(self): """Start processing of incoming requests. @@ -255,44 +214,35 @@ def start(self): """ reader = self.reader - while True: - message = None - self._keep_alive = False - self._request_count += 1 - self._reading_request = False - - payload = None - try: - # read HTTP request method - prefix = reader.set_parser(self._request_prefix) - yield from prefix.read() - - # start reading request - self._reading_request = True - - # start slow request timer - if (self._slow_request_timeout and - self._slow_request_timeout_handle is None): - now = self._loop.time() - self._slow_request_timeout_handle = self._loop.call_at( - ceil(now+self._slow_request_timeout), - self.cancel_slow_request) - - # read request headers - httpstream = reader.set_parser(self._request_parser) - message = yield from httpstream.read() - - # cancel slow request timer - if self._slow_request_timeout_handle is not None: - self._slow_request_timeout_handle.cancel() - self._slow_request_timeout_handle = None + try: + while not self._closing: + message = None + self._keep_alive = False + self._request_count += 1 + self._reading_request = False + + payload = None + with Timeout(max(self._slow_request_timeout, + self._keepalive_timeout), + loop=self._loop): + # read HTTP request method + prefix = reader.set_parser(self._request_prefix) + yield from prefix.read() + + # start reading request + self._reading_request = True + + # start slow request timer + # read request headers + httpstream = reader.set_parser(self._request_parser) + message = yield from httpstream.read() # request may not have payload try: content_length = int( message.headers.get(hdrs.CONTENT_LENGTH, 0)) except ValueError: - content_length = 0 + raise errors.InvalidHeader(hdrs.CONTENT_LENGTH) from None if (content_length > 0 or message.method == 'CONNECT' or @@ -308,55 +258,39 @@ def start(self): yield from self.handle_request(message, payload) - except asyncio.CancelledError: - return - except errors.ClientDisconnectedError: - self.log_debug( - 'Ignored premature client disconnection #1.') - return - except errors.HttpProcessingError as exc: - if self.transport is not None: - yield from self.handle_error(exc.code, message, - None, exc, exc.headers, - exc.message) - except errors.LineLimitExceededParserError as exc: - yield from self.handle_error(400, message, None, exc) - except Exception as exc: - yield from self.handle_error(500, message, None, exc) - finally: - if self.transport is None: - self.log_debug( - 'Ignored premature client disconnection #2.') - return - if payload and not payload.is_eof(): self.log_debug('Uncompleted request.') - self._request_handler = None - self.transport.close() - return + self._closing = True else: reader.unset_parser() - - if self._request_handler: - if self._keep_alive and self._keepalive_timeout: - self.log_debug( - 'Start keep-alive timer for %s sec.', - self._keepalive_timeout) - now = self._loop.time() - self._keep_alive_handle = self._loop.call_at( - ceil(now+self._keepalive_timeout), - self.transport.close) - elif self._keep_alive: - # do nothing, rely on kernel or upstream server - pass - else: - self.log_debug('Close client connection.') - self._request_handler = None - self.transport.close() - return - else: - # connection is closed - return + if not self._keep_alive or not self._keepalive_timeout: + self._closing = True + + except asyncio.CancelledError: + self.log_debug( + 'Request handler cancelled.') + return + except asyncio.TimeoutError: + self.log_debug( + 'Request handler timed out.') + return + except errors.ClientDisconnectedError: + self.log_debug( + 'Ignored premature client disconnection #1.') + return + except errors.HttpProcessingError as exc: + yield from self.handle_error(exc.code, message, + None, exc, exc.headers, + exc.message) + except Exception as exc: + yield from self.handle_error(500, message, None, exc) + finally: + self._request_handler = None + if self.transport is None: + self.log_debug( + 'Ignored premature client disconnection #2.') + else: + self.transport.close() def handle_error(self, status=500, message=None, payload=None, exc=None, headers=None, reason=None): @@ -366,7 +300,7 @@ def handle_error(self, status=500, message=None, information. It always closes current connection.""" now = self._loop.time() try: - if self._request_handler is None: + if self.transport is None: # client has been disconnected during writing. return () diff --git a/tests/test_server.py b/tests/test_server.py index 8479e1f883b..75dc6845b56 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -52,55 +52,51 @@ def test_handle_request(srv): assert content.startswith(b'HTTP/1.1 404 Not Found\r\n') -def test_closing(srv): +@pytest.mark.run_loop +def test_closing(srv, loop): + transport = mock.Mock() + transport.drain.side_effect = [] + srv.connection_made(transport) + assert transport is srv.transport + + yield from asyncio.sleep(0, loop=loop) + + srv.reader.feed_data( + b'GET / HTTP/1.1\r\n' + b'Host: example.com\r\n' + b'Content-Length: 0\r\n\r\n') + srv._keep_alive = True - keep_alive_handle = mock.Mock() - srv._keep_alive_handle = keep_alive_handle - timeout_handle = mock.Mock() - srv._timeout_handle = timeout_handle - transport = srv.transport = mock.Mock() - request_handler = srv._request_handler = mock.Mock() - srv.writer = mock.Mock() + request_handler = srv._request_handler srv.closing() + + yield from asyncio.sleep(0.01, loop=loop) assert transport.close.called assert srv.transport is None - assert srv._keep_alive_handle is not None - assert not keep_alive_handle.cancel.called - - assert srv._timeout_handle is not None - assert not timeout_handle.cancel.called - assert srv._request_handler is None - assert request_handler.cancel.called + assert request_handler.done() def test_closing_during_reading(srv): srv._keep_alive = True srv._keep_alive_on = True srv._reading_request = True - srv._slow_request_timeout_handle = timeout_handle = mock.Mock() transport = srv.transport = mock.Mock() + assert not srv._closing + srv.closing() assert not transport.close.called assert srv.transport is not None - - # cancel existing slow request handler - assert srv._slow_request_timeout_handle is not None - assert timeout_handle.cancel.called - assert timeout_handle is not srv._slow_request_timeout_handle + assert srv._closing def test_double_closing(srv): srv._keep_alive = True - keep_alive_handle = mock.Mock() - srv._keep_alive_handle = keep_alive_handle - timeout_handle = mock.Mock() - srv._timeout_handle = timeout_handle transport = srv.transport = mock.Mock() srv.writer = mock.Mock() @@ -113,24 +109,13 @@ def test_double_closing(srv): assert not transport.close.called assert srv.transport is None - assert srv._keep_alive_handle is not None - assert not keep_alive_handle.cancel.called - - assert srv._timeout_handle is not None - assert not timeout_handle.cancel.called - def test_connection_made(srv): assert srv._request_handler is None srv.connection_made(mock.Mock()) assert srv._request_handler is not None - assert srv._slow_request_timeout_handle is None - - -def test_connection_made_without_timeout(srv): - srv.connection_made(mock.Mock()) - assert srv._slow_request_timeout_handle is None + assert not srv._closing def test_connection_made_with_keepaplive(srv): @@ -173,9 +158,6 @@ def test_connection_lost(srv, loop): srv.connection_made(mock.Mock()) srv.data_received(b'123') - timeout_handle = srv._slow_request_timeout_handle = mock.Mock() - keep_alive_handle = srv._keep_alive_handle = mock.Mock() - handle = srv._request_handler srv.connection_lost(None) yield from asyncio.sleep(0, loop=loop) @@ -183,15 +165,8 @@ def test_connection_lost(srv, loop): assert srv._request_handler is None assert handle.cancelled() - assert srv._keep_alive_handle is None - assert keep_alive_handle.cancel.called - - assert srv._slow_request_timeout_handle is None - assert timeout_handle.cancel.called - srv.connection_lost(None) assert srv._request_handler is None - assert srv._keep_alive_handle is None def test_srv_keep_alive(srv): @@ -204,9 +179,9 @@ def test_srv_keep_alive(srv): assert not srv._keep_alive -def test_srv_slow_request(make_srv, loop): +def test_slow_request(make_srv, loop): transport = mock.Mock() - srv = make_srv(timeout=0.01) + srv = make_srv(slow_request_timeout=0.01, keepalive_timeout=0) srv.connection_made(transport) srv.reader.feed_data( @@ -215,8 +190,6 @@ def test_srv_slow_request(make_srv, loop): loop.run_until_complete(srv._request_handler) assert transport.close.called - srv.connection_lost(None) - assert srv._slow_request_timeout_handle is None def test_bad_method(srv, loop): @@ -243,6 +216,20 @@ def test_line_too_long(srv, loop): b'HTTP/1.1 400 Bad Request\r\n') +def test_invalid_content_length(srv, loop): + transport = mock.Mock() + srv.connection_made(transport) + + srv.reader.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n' + b'Content-Length: sdgg\r\n\r\n') + + loop.run_until_complete(srv._request_handler) + assert transport.write.mock_calls[0][1][0].startswith( + b'HTTP/1.1 400 Bad Request\r\n') + + def test_handle_error(srv): transport = mock.Mock() srv.connection_made(transport) @@ -317,9 +304,10 @@ def test_handle_error_debug(srv): assert b'Traceback (most recent call last):' in content -def test_handle_error_500(make_srv): +def test_handle_error_500(make_srv, loop): log = mock.Mock() transport = mock.Mock() + transport.drain.return_value = () srv = make_srv(logger=log) srv.connection_made(transport) @@ -405,7 +393,7 @@ def cancel(): srv._request_handler.cancel() loop.run_until_complete( - asyncio.wait([srv._request_handler, cancel()], loop=loop)) + asyncio.gather(srv._request_handler, cancel(), loop=loop)) assert log.debug.called @@ -432,32 +420,29 @@ def test_handle_cancelled(make_srv, loop): def test_handle_400(srv, loop): transport = mock.Mock() + transport.drain.side_effect = [] srv.connection_made(transport) - srv.handle_error = mock.Mock() - srv.keep_alive(True) srv.reader.feed_data(b'GET / HT/asd\r\n\r\n') loop.run_until_complete(srv._request_handler) - assert srv.handle_error.called - assert 400 == srv.handle_error.call_args[0][0] - assert transport.close.called + + assert b'400 Bad Request' in srv.transport.write.call_args[0][0] def test_handle_500(srv, loop): transport = mock.Mock() + transport.drain.side_effect = [] srv.connection_made(transport) handle = srv.handle_request = mock.Mock() handle.side_effect = ValueError - srv.handle_error = mock.Mock() srv.reader.feed_data( b'GET / HTTP/1.0\r\n' b'Host: example.com\r\n\r\n') loop.run_until_complete(srv._request_handler) - assert srv.handle_error.called - assert 500 == srv.handle_error.call_args[0][0] + assert b'500 Internal Server Error' in srv.transport.write.call_args[0][0] def test_handle_error_no_handle_task(srv): @@ -501,10 +486,8 @@ def test_keep_alive_close_existing(make_srv, loop): transport = mock.Mock() srv = make_srv(keep_alive=0) srv.connection_made(transport) - assert srv._keep_alive_handle is None srv._keep_alive_period = 15 - keep_alive_handle = srv._keep_alive_handle = mock.Mock() srv.handle_request = mock.Mock() srv.handle_request.return_value = helpers.create_future(loop) srv.handle_request.return_value.set_result(1) @@ -514,20 +497,13 @@ def test_keep_alive_close_existing(make_srv, loop): b'HOST: example.com\r\n\r\n') loop.run_until_complete(srv._request_handler) - assert keep_alive_handle.cancel.called - assert srv._keep_alive_handle is None assert transport.close.called -def test_cancel_not_connected_handler(srv): - srv.cancel_slow_request() - - def test_srv_process_request_without_timeout(make_srv, loop): transport = mock.Mock() srv = make_srv(timeout=0) srv.connection_made(transport) - assert srv._slow_request_timeout_handle is None srv.reader.feed_data( b'GET / HTTP/1.0\r\n' @@ -535,7 +511,6 @@ def test_srv_process_request_without_timeout(make_srv, loop): loop.run_until_complete(srv._request_handler) assert transport.close.called - assert srv._slow_request_timeout_handle is None def test_keep_alive_timeout_default(srv):