Skip to content

Commit

Permalink
Handle Aiohttp Exceptions as valid responses (#59)
Browse files Browse the repository at this point in the history
* Handle Aiohttp Exceptions as valid responses

This PR fixes a bug that till now handled all exceptions as just 500,
while valid Aiohttp Exceptions have to be handled as valid responses,
they are used to notify none 2XX responses, perhaps a 401 is notified
with a HTTPUnauthorized exception.

* Added CHANGELOG

* Removed invalid files

* Fixed typos

* Keep raising the Aiohttp exceptions to the outer middlewares
  • Loading branch information
pfreixes authored and haotianw465 committed May 16, 2018
1 parent 3c5d0eb commit dc55f7c
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 13 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
CHANGELOG
=========

unreleased
==========
* bugfix: Handle Aiohttp Exceptions as valid responses `PR59 <https://github.com/aws/aws-xray-sdk-python/pull/59>`_.

1.1
===
* feature: Added Sqlalchemy parameterized query capture. `PR34 <https://github.com/aws/aws-xray-sdk-python/pull/34>`_
Expand Down
27 changes: 15 additions & 12 deletions aws_xray_sdk/ext/aiohttp/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""
import traceback
from aiohttp import web
from aiohttp.web_exceptions import HTTPException

from aws_xray_sdk.core import xray_recorder
from aws_xray_sdk.core.models import http
Expand Down Expand Up @@ -55,25 +56,27 @@ async def middleware(request, handler):
try:
# Call next middleware or request handler
response = await handler(request)
except HTTPException as exc:
# Non 2XX responses are raised as HTTPExceptions
response = exc
raise
except Exception as err:
# Store exception information including the stacktrace to the segment
segment = xray_recorder.current_segment()
response = None
segment.put_http_meta(http.STATUS, 500)
stack = traceback.extract_stack(limit=xray_recorder.max_trace_back)
segment.add_exception(err, stack)
xray_recorder.end_segment()
raise
finally:
if response is not None:
segment.put_http_meta(http.STATUS, response.status)
if 'Content-Length' in response.headers:
length = int(response.headers['Content-Length'])
segment.put_http_meta(http.CONTENT_LENGTH, length)

# Store response metadata into the current segment
segment.put_http_meta(http.STATUS, response.status)
header_str = prepare_response_header(xray_header, segment)
response.headers[http.XRAY_HEADER] = header_str

if 'Content-Length' in response.headers:
length = int(response.headers['Content-Length'])
segment.put_http_meta(http.CONTENT_LENGTH, length)

header_str = prepare_response_header(xray_header, segment)
response.headers[http.XRAY_HEADER] = header_str
xray_recorder.end_segment()

# Close segment so it can be dispatched off to the daemon
xray_recorder.end_segment()
return response
75 changes: 74 additions & 1 deletion tests/ext/aiohttp/test_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from unittest.mock import patch

from aiohttp import web
from aiohttp.web_exceptions import HTTPUnauthorized
import pytest

from aws_xray_sdk.core.emitters.udp_emitter import UDPEmitter
Expand Down Expand Up @@ -48,14 +49,25 @@ async def handle_ok(self, request: web.Request) -> web.Response:
"""
Handle / request
"""
return web.Response(text="ok")
if "content_length" in request.query:
headers = {'Content-Length': request.query['content_length']}
else:
headers = None

return web.Response(text="ok", headers=headers)

async def handle_error(self, request: web.Request) -> web.Response:
"""
Handle /error which returns a 404
"""
return web.Response(text="not found", status=404)

async def handle_unauthorized(self, request: web.Request) -> web.Response:
"""
Handle /unauthorized which returns a 401
"""
raise HTTPUnauthorized()

async def handle_exception(self, request: web.Request) -> web.Response:
"""
Handle /exception which raises a KeyError
Expand All @@ -74,6 +86,7 @@ def get_app(self) -> web.Application:
app.router.add_get('/', self.handle_ok)
app.router.add_get('/error', self.handle_error)
app.router.add_get('/exception', self.handle_exception)
app.router.add_get('/unauthorized', self.handle_unauthorized)
app.router.add_get('/delay', self.handle_delay)

return app
Expand Down Expand Up @@ -124,6 +137,41 @@ async def test_ok(test_client, loop, recorder):
assert response['status'] == 200


async def test_ok_x_forwarded_for(test_client, loop, recorder):
"""
Test a normal response with x_forwarded_for headers
:param test_client: AioHttp test client fixture
:param loop: Eventloop fixture
:param recorder: X-Ray recorder fixture
"""
client = await test_client(ServerTest.app(loop=loop))

resp = await client.get('/', headers={'X-Forwarded-For': 'foo'})
assert resp.status == 200

segment = recorder.emitter.pop()
assert segment.http['request']['client_ip'] == 'foo'
assert segment.http['request']['x_forwarded_for']


async def test_ok_content_length(test_client, loop, recorder):
"""
Test a normal response with content length as response header
:param test_client: AioHttp test client fixture
:param loop: Eventloop fixture
:param recorder: X-Ray recorder fixture
"""
client = await test_client(ServerTest.app(loop=loop))

resp = await client.get('/?content_length=100')
assert resp.status == 200

segment = recorder.emitter.pop()
assert segment.http['response']['content_length'] == 100


async def test_error(test_client, loop, recorder):
"""
Test a 4XX response
Expand Down Expand Up @@ -176,6 +224,31 @@ async def test_exception(test_client, loop, recorder):
assert exception.type == 'KeyError'


async def test_unhauthorized(test_client, loop, recorder):
"""
Test a 401 response
:param test_client: AioHttp test client fixture
:param loop: Eventloop fixture
:param recorder: X-Ray recorder fixture
"""
client = await test_client(ServerTest.app(loop=loop))

resp = await client.get('/unauthorized')
assert resp.status == 401

segment = recorder.emitter.pop()
assert not segment.in_progress
assert segment.error

request = segment.http['request']
response = segment.http['response']
assert request['method'] == 'GET'
assert request['url'] == 'http://127.0.0.1:{port}/unauthorized'.format(port=client.port)
assert request['client_ip'] == '127.0.0.1'
assert response['status'] == 401


async def test_response_trace_header(test_client, loop, recorder):
client = await test_client(ServerTest.app(loop=loop))
resp = await client.get('/')
Expand Down

0 comments on commit dc55f7c

Please sign in to comment.