Skip to content

Commit

Permalink
Merge pull request #24 from imbolc/master
Browse files Browse the repository at this point in the history
save session with HTTPException
  • Loading branch information
asvetlov committed Jan 4, 2016
2 parents d67ab43 + 808e801 commit 32def6f
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 4 deletions.
19 changes: 15 additions & 4 deletions aiohttp_session/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,17 +114,25 @@ def get_session(request):
return session


def session_middleware(storage):
def session_middleware(storage, *, max_http_status=400):

assert isinstance(storage, AbstractStorage), storage
assert isinstance(max_http_status, int), max_http_status

@asyncio.coroutine
def factory(app, handler):

@asyncio.coroutine
def middleware(request):
request[STORAGE_KEY] = storage
response = yield from handler(request)
raise_response = False
try:
response = yield from handler(request)
except web.HTTPException as exc:
if exc.status_code > max_http_status:
raise exc
response = exc
raise_response = True
if not isinstance(response, web.StreamResponse):
raise RuntimeError("Expect response, not {!r}", type(response))
if not isinstance(response, web.Response):
Expand All @@ -137,17 +145,20 @@ def middleware(request):
if session is not None:
if session._changed:
yield from storage.save_session(request, response, session)
if raise_response:
raise response
return response

return middleware

return factory


def setup(app, storage):
def setup(app, storage, *, max_http_status=400):
"""Setup the library in aiohttp fashion."""

app.middlewares.append(session_middleware(storage))
app.middlewares.append(
session_middleware(storage, max_http_status=max_http_status))


class AbstractStorage(metaclass=abc.ABCMeta):
Expand Down
90 changes: 90 additions & 0 deletions tests/test_http_exception.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import asyncio
import socket
import unittest

from aiohttp import web, request
from aiohttp_session import (session_middleware,
get_session, SimpleCookieStorage)


class TestHttpException(unittest.TestCase):

def setUp(self):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(None)
self.srv = None
self.handler = None

def tearDown(self):
self.loop.run_until_complete(self.handler.finish_connections())
self.srv.close()
self.loop.stop()
self.loop.run_forever()
self.loop.close()

def find_unused_port(self):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(('127.0.0.1', 0))
port = s.getsockname()[1]
s.close()
return port

@asyncio.coroutine
def create_server(self, routes, max_http_status):
middleware = session_middleware(SimpleCookieStorage(),
max_http_status=max_http_status)
app = web.Application(middlewares=[middleware], loop=self.loop)
for method, path, handler in routes:
app.router.add_route(method, path, handler)

port = self.find_unused_port()
handler = app.make_handler()
srv = yield from self.loop.create_server(
handler, '127.0.0.1', port)
url = "http://127.0.0.1:{}".format(port)
self.handler = handler
self.srv = srv
return app, srv, url

def test_exceptions(self):

@asyncio.coroutine
def save(request):
session = yield from get_session(request)
session['message'] = 'works'
raise web.HTTPFound('/show')

@asyncio.coroutine
def show(request):
session = yield from get_session(request)
message = session.get('message')
return web.Response(text=str(message))

def get_routes():
return [
['GET', '/save', save],
['GET', '/show', show],
]

@asyncio.coroutine
def go_good_http_status():
_, _, url = yield from self.create_server(get_routes(),
max_http_status=400)
resp = yield from request('GET', url + '/save', loop=self.loop)
self.assertEqual(200, resp.status)
self.assertEqual(resp.url[-5:], '/show')
text = yield from resp.text()
assert text == 'works'

@asyncio.coroutine
def go_bad_http_status():
_, _, url = yield from self.create_server(get_routes(),
max_http_status=200)
resp = yield from request('GET', url + '/save', loop=self.loop)
self.assertEqual(200, resp.status)
self.assertEqual(resp.url[-5:], '/show')
text = yield from resp.text()
assert text == 'None'

self.loop.run_until_complete(go_good_http_status())
self.loop.run_until_complete(go_bad_http_status())

0 comments on commit 32def6f

Please sign in to comment.