Skip to content

Commit

Permalink
Refactor ServerHttpProtocol.start()
Browse files Browse the repository at this point in the history
  • Loading branch information
asvetlov committed Aug 8, 2016
1 parent 5b4811a commit 98e9c22
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 129 deletions.
6 changes: 1 addition & 5 deletions aiohttp/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,7 @@ def __init__(self, line=''):
self.line = line


class ParserError(Exception):
"""Base parser error."""


class LineLimitExceededParserError(ParserError):
class LineLimitExceededParserError(HttpBadRequest):
"""Line is too long."""

def __init__(self, msg, limit):
Expand Down
128 changes: 51 additions & 77 deletions aiohttp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import traceback
import warnings
from html import escape as html_escape
from math import ceil

import aiohttp
from aiohttp import errors, hdrs, helpers, streams
Expand Down Expand Up @@ -86,7 +85,6 @@ class ServerHttpProtocol(aiohttp.StreamProtocol):
_request_handler = None
_reading_request = False
_keep_alive = False # keep transport open
_keep_alive_handle = None # keep alive timer handle

def __init__(self, *, loop=None,
keepalive_timeout=75, # NGINX default value is 75 secs
Expand Down Expand Up @@ -181,9 +179,6 @@ def connection_lost(self, exc):
if self._request_handler is not None:
self._request_handler.cancel()
self._request_handler = None
if self._keep_alive_handle is not None:
self._keep_alive_handle.cancel()
self._keep_alive_handle = None

def data_received(self, data):
super().data_received(data)
Expand All @@ -192,11 +187,6 @@ def data_received(self, data):
if not self._reading_request:
self._reading_request = True

# stop keep-alive timer
if self._keep_alive_handle is not None:
self._keep_alive_handle.cancel()
self._keep_alive_handle = None

def keep_alive(self, val):
"""Set keep-alive connection mode.
Expand Down Expand Up @@ -228,23 +218,25 @@ def start(self):
"""
reader = self.reader

while not self._closing:
message = None
self._keep_alive = False
self._request_count += 1
self._reading_request = False

payload = None
try:
# read HTTP request method
prefix = reader.set_parser(self._request_prefix)
yield from prefix.read()

# start reading request
self._reading_request = True

# start slow request timer
with Timeout(self._slow_request_timeout, loop=self._loop):
try:
while not self._closing:
message = None
self._keep_alive = False
self._request_count += 1
self._reading_request = False

payload = None
with Timeout(max(self._slow_request_timeout,
self._keepalive_timeout),
loop=self._loop):
# read HTTP request method
prefix = reader.set_parser(self._request_prefix)
yield from prefix.read()

# start reading request
self._reading_request = True

# start slow request timer
# read request headers
httpstream = reader.set_parser(self._request_parser)
message = yield from httpstream.read()
Expand All @@ -254,7 +246,7 @@ def start(self):
content_length = int(
message.headers.get(hdrs.CONTENT_LENGTH, 0))
except ValueError:
content_length = 0
raise errors.InvalidHeader(hdrs.CONTENT_LENGTH) from None

if (content_length > 0 or
message.method == 'CONNECT' or
Expand All @@ -270,59 +262,41 @@ def start(self):

yield from self.handle_request(message, payload)

except asyncio.CancelledError:
return
except errors.ClientDisconnectedError:
self.log_debug(
'Ignored premature client disconnection #1.')
return
except errors.HttpProcessingError as exc:
if self.transport is not None:
yield from self.handle_error(exc.code, message,
None, exc, exc.headers,
exc.message)
except errors.LineLimitExceededParserError as exc:
yield from self.handle_error(400, message, None, exc)
except Exception as exc:
yield from self.handle_error(500, message, None, exc)
finally:
if self.transport is None:
self.log_debug(
'Ignored premature client disconnection #2.')
return
elif self._closing:
self._request_handler = None
self.transport.close()
return

if payload and not payload.is_eof():
self.log_debug('Uncompleted request.')
self._request_handler = None
self.transport.close()
return
self._closing = True
else:
reader.unset_parser()

if self._request_handler:
if self._keep_alive and self._keepalive_timeout:
self.log_debug(
'Start keep-alive timer for %s sec.',
self._keepalive_timeout)
now = self._loop.time()
self._keep_alive_handle = self._loop.call_at(
ceil(now+self._keepalive_timeout),
self.transport.close)
elif self._keep_alive:
# do nothing, rely on kernel or upstream server
pass
else:
self.log_debug('Close client connection.')
self._request_handler = None
self.transport.close()
return
else:
# connection is closed
return
if not self._keep_alive or not self._keepalive_timeout:
self._closing = True

except asyncio.CancelledError:
self.log_debug(
'Request handler cancelled.')
return
except asyncio.TimeoutError:
self.log_debug(
'Request handler timed out.')
return
except errors.ClientDisconnectedError:
self.log_debug(
'Ignored premature client disconnection #1.')
return
except errors.HttpProcessingError as exc:
if self.transport is not None:
yield from self.handle_error(exc.code, message,
None, exc, exc.headers,
exc.message)
except Exception as exc:
if self.transport is not None:
yield from self.handle_error(500, message, None, exc)
finally:
self._request_handler = None
if self.transport is None:
self.log_debug(
'Ignored premature client disconnection #2.')
else:
self.transport.close()

def handle_error(self, status=500, message=None,
payload=None, exc=None, headers=None, reason=None):
Expand Down
75 changes: 28 additions & 47 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,41 +54,30 @@ def test_handle_request(srv):

@pytest.mark.run_loop
def test_closing(srv, loop):
keep_alive_handle = mock.Mock()
srv._keep_alive_handle = keep_alive_handle
timeout_handle = mock.Mock()
srv._timeout_handle = timeout_handle
transport = mock.Mock()
transport.drain.result_value = asyncio.Future(loop=loop)
transport.drain.set_result(None)
transport.drain.side_effect = []
srv.connection_made(transport)
assert transport is srv.transport

yield from asyncio.sleep(0, loop=loop)

srv.reader.feed_data(
b'GET / HTTP/1.1\r\n'
b'Host: example.com\r\n'
b'Content-Length: 0\r\n\r\n')

srv._keep_alive = True
yield from asyncio.sleep(0.01, loop=loop)

request_handler = srv._request_handler

srv.closing()
return

yield from asyncio.sleep(0.01, loop=loop)
assert transport.close.called
assert srv.transport is None

assert srv._keep_alive_handle is not None
assert not keep_alive_handle.cancel.called

assert srv._timeout_handle is not None
assert not timeout_handle.cancel.called

assert srv._request_handler is None
assert request_handler.cancel()
assert request_handler.done()


def test_closing_during_reading(srv):
Expand All @@ -108,10 +97,6 @@ def test_closing_during_reading(srv):
def test_double_closing(srv):
srv._keep_alive = True

keep_alive_handle = mock.Mock()
srv._keep_alive_handle = keep_alive_handle
timeout_handle = mock.Mock()
srv._timeout_handle = timeout_handle
transport = srv.transport = mock.Mock()
srv.writer = mock.Mock()

Expand All @@ -124,12 +109,6 @@ def test_double_closing(srv):
assert not transport.close.called
assert srv.transport is None

assert srv._keep_alive_handle is not None
assert not keep_alive_handle.cancel.called

assert srv._timeout_handle is not None
assert not timeout_handle.cancel.called


def test_connection_made(srv):
assert srv._request_handler is None
Expand Down Expand Up @@ -179,21 +158,15 @@ def test_connection_lost(srv, loop):
srv.connection_made(mock.Mock())
srv.data_received(b'123')

keep_alive_handle = srv._keep_alive_handle = mock.Mock()

handle = srv._request_handler
srv.connection_lost(None)
yield from asyncio.sleep(0, loop=loop)

assert srv._request_handler is None
assert handle.cancelled()

assert srv._keep_alive_handle is None
assert keep_alive_handle.cancel.called

srv.connection_lost(None)
assert srv._request_handler is None
assert srv._keep_alive_handle is None


def test_srv_keep_alive(srv):
Expand All @@ -206,9 +179,9 @@ def test_srv_keep_alive(srv):
assert not srv._keep_alive


def test_srv_slow_request(make_srv, loop):
def test_slow_request(make_srv, loop):
transport = mock.Mock()
srv = make_srv(timeout=0.01)
srv = make_srv(slow_request_timeout=0.01, keepalive_timeout=0)
srv.connection_made(transport)

srv.reader.feed_data(
Expand Down Expand Up @@ -243,6 +216,20 @@ def test_line_too_long(srv, loop):
b'HTTP/1.1 400 Bad Request\r\n')


def test_invalid_content_length(srv, loop):
transport = mock.Mock()
srv.connection_made(transport)

srv.reader.feed_data(
b'GET / HTTP/1.0\r\n'
b'Host: example.com\r\n'
b'Content-Length: sdgg\r\n\r\n')

loop.run_until_complete(srv._request_handler)
assert transport.write.mock_calls[0][1][0].startswith(
b'HTTP/1.1 400 Bad Request\r\n')


def test_handle_error(srv):
transport = mock.Mock()
srv.connection_made(transport)
Expand Down Expand Up @@ -317,9 +304,10 @@ def test_handle_error_debug(srv):
assert b'Traceback (most recent call last):' in content


def test_handle_error_500(make_srv):
def test_handle_error_500(make_srv, loop):
log = mock.Mock()
transport = mock.Mock()
transport.drain.return_value = ()

srv = make_srv(logger=log)
srv.connection_made(transport)
Expand Down Expand Up @@ -405,7 +393,7 @@ def cancel():
srv._request_handler.cancel()

loop.run_until_complete(
asyncio.wait([srv._request_handler, cancel()], loop=loop))
asyncio.gather(srv._request_handler, cancel(), loop=loop))
assert log.debug.called


Expand All @@ -432,32 +420,29 @@ def test_handle_cancelled(make_srv, loop):

def test_handle_400(srv, loop):
transport = mock.Mock()
transport.drain.side_effect = []
srv.connection_made(transport)
srv.handle_error = mock.Mock()
srv.keep_alive(True)
srv.reader.feed_data(b'GET / HT/asd\r\n\r\n')

loop.run_until_complete(srv._request_handler)
assert srv.handle_error.called
assert 400 == srv.handle_error.call_args[0][0]
assert transport.close.called

assert b'400 Bad Request' in srv.transport.write.call_args[0][0]


def test_handle_500(srv, loop):
transport = mock.Mock()
transport.drain.side_effect = []
srv.connection_made(transport)

handle = srv.handle_request = mock.Mock()
handle.side_effect = ValueError
srv.handle_error = mock.Mock()

srv.reader.feed_data(
b'GET / HTTP/1.0\r\n'
b'Host: example.com\r\n\r\n')
loop.run_until_complete(srv._request_handler)

assert srv.handle_error.called
assert 500 == srv.handle_error.call_args[0][0]
assert b'500 Internal Server Error' in srv.transport.write.call_args[0][0]


def test_handle_error_no_handle_task(srv):
Expand Down Expand Up @@ -501,10 +486,8 @@ def test_keep_alive_close_existing(make_srv, loop):
transport = mock.Mock()
srv = make_srv(keep_alive=0)
srv.connection_made(transport)
assert srv._keep_alive_handle is None

srv._keep_alive_period = 15
keep_alive_handle = srv._keep_alive_handle = mock.Mock()
srv.handle_request = mock.Mock()
srv.handle_request.return_value = helpers.create_future(loop)
srv.handle_request.return_value.set_result(1)
Expand All @@ -514,8 +497,6 @@ def test_keep_alive_close_existing(make_srv, loop):
b'HOST: example.com\r\n\r\n')

loop.run_until_complete(srv._request_handler)
assert keep_alive_handle.cancel.called
assert srv._keep_alive_handle is None
assert transport.close.called


Expand Down

0 comments on commit 98e9c22

Please sign in to comment.