Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

tween-based transaction management

  • Loading branch information...
commit aec122143452993115faef9f17fb9c1ef32fcf87 1 parent 37cdf29
@mcdonc mcdonc authored
Showing with 211 additions and 206 deletions.
  1. +73 −74 pyramid_tm/__init__.py
  2. +138 −132 pyramid_tm/tests.py
View
147 pyramid_tm/__init__.py
@@ -1,10 +1,13 @@
-import pyramid.events
+from pyramid.util import DottedNameResolver
+from pyramid.tweens import EXCVIEW
import transaction
+resolver = DottedNameResolver(None)
-def default_commit_veto(environ, status, headers):
- '''When used as a commit veto, the logic in this function will cause the
+def default_commit_veto(request, response):
+ """
+ When used as a commit veto, the logic in this function will cause the
transaction to be committed if:
- An ``X-Tm`` header with the value ``commit`` exists.
@@ -21,82 +24,78 @@ def default_commit_veto(environ, status, headers):
- The status code starts with ``4`` or ``5``.
Otherwise the transaction will be committed by default.
- '''
-
- abort_compat = False
- for header_name, header_value in headers:
- header_name = header_name.lower()
- if header_name == 'x-tm':
- if header_value.lower() == 'commit':
- return False
- return True
- # x-tm honored before x-tm-abort compatability
- elif header_name == 'x-tm-abort':
- abort_compat = True
- if abort_compat:
+ """
+ xtm = response.headers.get('x-tm')
+ if xtm is not None:
+ if xtm == 'commit':
+ return False
return True
+ status = response.status
for bad in ('4', '5'):
if status.startswith(bad):
return True
return False
-
-class TMSubscriber(object):
- '''A NewRequest subscriber that knows about commit_veto.
- '''
-
- transaction = staticmethod(transaction)
-
- def __init__(self, commit_veto):
- self.commit_veto = commit_veto
-
- def __call__(self, event):
- if 'repoze.tm.active' in event.request.environ:
- return
-
- self.begin()
- event.request.add_finished_callback(self.process)
- event.request.add_response_callback(self.process)
-
- def begin(self):
- self.transaction.begin()
-
- def commit(self):
- self.transaction.get().commit()
-
- def abort(self):
- self.transaction.get().abort()
-
- def process(self, request, response=None):
- if getattr(request, '_transaction_committed', False):
- return False
-
- request._transaction_committed = True
- transaction = self.transaction
-
- # ZODB 3.8 + has isDoomed
- if hasattr(transaction, 'isDoomed') and transaction.isDoomed():
- return self.abort()
-
- if request.exception is not None:
- return self.abort()
-
- if response is not None and self.commit_veto is not None:
- environ = request.environ
- status, headers = response.status, response.headerlist
-
- if self.commit_veto(environ, status, headers):
- return self.abort()
-
- return self.commit()
-
+def tm_tween_factory(handler, registry, transaction=transaction):
+ # transaction parameterized for testing purposes
+ commit_veto = registry.settings.get('pyramid_tm.commit_veto')
+ if commit_veto is not None:
+ commit_veto = resolver.resolve(commit_veto)
+
+ def tm_tween(request):
+ if 'repoze.tm.active' in request.environ:
+ return handler(request)
+
+ t = transaction.get()
+ t.begin()
+
+ try:
+ response = handler(request)
+ except:
+ t.abort()
+ raise
+
+ if transaction.isDoomed():
+ t.abort()
+ elif commit_veto is not None:
+ veto = commit_veto(request, response)
+ if veto:
+ t.abort()
+ else:
+ t.commit()
+ else:
+ t.commit()
+
+ return response
+
+ return tm_tween
def includeme(config):
- '''Setup the NewRequest subscriber for bootstrapping transactions.
- '''
-
- commit_veto = config.registry.settings.get('pyramid_tm.commit_veto',
- default_commit_veto)
- commit_veto = config.maybe_dotted(commit_veto)
- subscriber = TMSubscriber(commit_veto)
- config.add_subscriber(subscriber, pyramid.events.NewRequest)
+ """
+ Set up a 'tween' to do transaction management using the ``transaction``
+ package. The tween will be slotted between the main Pyramid app and the
+ Pyramid exception view handler.
+
+ For every request it handles, the tween will begin a transaction by
+ calling ``transaction.begin()``, and will then call the downstream
+ handler (usually the main Pyramid application request handler) to obtain
+ a response. When attempting to call the downstream handler:
+
+ - If an exception is raised by downstream handler while attempting to
+ obtain a response, the transaction will be rolled back
+ (``transaction.abort()`` will be called).
+
+ - If no exception is raised by the downstream handler, but the
+ transaction is doomed (``transaction.doom()`` has been called), the
+ transaction will be rolled back.
+
+ - If the deployment configuration specifies a ``pyramid_tm.commit_veto``
+ setting, and the transaction management tween receives a response from
+ the downstream handler, the commit veto hook will be called. If it
+ returns True, the transaction will be rolled back. If it returns
+ False, the transaction will be committed.
+
+ - If none of the above conditions are True, the transaction will be
+ committed (via ``transaction.commit()``).
+ """
+ config.add_tween(tm_tween_factory, above=EXCVIEW)
View
270 pyramid_tm/tests.py
@@ -1,146 +1,158 @@
import unittest
-
class TestDefaultCommitVeto(unittest.TestCase):
-
- def _callFUT(self, status, headers=()):
+ def _callFUT(self, response, request=None):
from pyramid_tm import default_commit_veto
- return default_commit_veto(None, status, headers)
+ return default_commit_veto(request, response)
- def test_it_true_5XX(self):
- self.failUnless(self._callFUT('500 Server Error'))
- self.failUnless(self._callFUT('503 Service Unavailable'))
+ def test_it_true_500(self):
+ response = DummyResponse('500 Server Error')
+ self.failUnless(self._callFUT(response))
- def test_it_true_4XX(self):
- self.failUnless(self._callFUT('400 Bad Request'))
- self.failUnless(self._callFUT('411 Length Required'))
+ def test_it_true_503(self):
+ response = DummyResponse('503 Service Unavailable')
+ self.failUnless(self._callFUT(response))
- def test_it_false_2XX(self):
- self.failIf(self._callFUT('200 OK'))
- self.failIf(self._callFUT('201 Created'))
+ def test_it_true_400(self):
+ response = DummyResponse('400 Bad Request')
+ self.failUnless(self._callFUT(response))
- def test_it_false_3XX(self):
- self.failIf(self._callFUT('301 Moved Permanently'))
- self.failIf(self._callFUT('302 Found'))
+ def test_it_true_411(self):
+ response = DummyResponse('411 Length Required')
+ self.failUnless(self._callFUT(response))
- def test_it_true_x_tm_abort_specific(self):
- self.failUnless(self._callFUT('200 OK', [('X-Tm-Abort', True)]))
+ def test_it_false_200(self):
+ response = DummyResponse('200 OK')
+ self.failIf(self._callFUT(response))
- def test_it_false_x_tm_commit(self):
- self.failIf(self._callFUT('200 OK', [('X-Tm', 'commit')]))
+ def test_it_false_201(self):
+ response = DummyResponse('201 Created')
+ self.failIf(self._callFUT(response))
- def test_it_true_x_tm_abort(self):
- self.failUnless(self._callFUT('200 OK', [('X-Tm', 'abort')]))
+ def test_it_false_301(self):
+ response = DummyResponse('301 Moved Permanently')
+ self.failIf(self._callFUT(response))
- def test_it_true_x_tm_anythingelse(self):
- self.failUnless(self._callFUT('200 OK', [('X-Tm', '')]))
+ def test_it_false_302(self):
+ response = DummyResponse('302 Found')
+ self.failIf(self._callFUT(response))
- def test_x_tm_generic_precedes_x_tm_abort_specific(self):
- self.failIf(self._callFUT('200 OK', [('X-Tm', 'commit'),
- ('X-Tm-Abort', True)]))
+ def test_it_false_x_tm_commit(self):
+ response = DummyResponse('200 OK', {'x-tm':'commit'})
+ self.failIf(self._callFUT(response))
+ def test_it_true_x_tm_abort(self):
+ response = DummyResponse('200 OK', {'x-tm':'abort'})
+ self.failUnless(self._callFUT(response))
-class TestTMSubscriber(unittest.TestCase):
+ def test_it_true_x_tm_anythingelse(self):
+ response = DummyResponse('200 OK', {'x-tm':''})
+ self.failUnless(self._callFUT(response))
+class Test_tm_tween_factory(unittest.TestCase):
def setUp(self):
- from pyramid_tm import TMSubscriber
- self.subscriber = TMSubscriber(None)
- self.subscriber.transaction = MockTransaction()
-
- def test_basics(self):
- subscriber = self.subscriber
- transaction = subscriber.transaction
-
- subscriber.begin()
- self.assertTrue(transaction.began)
-
- subscriber.commit()
- self.assertTrue(transaction.committed)
-
- subscriber.abort()
- self.assertTrue(transaction.aborted)
-
- def test_calling(self):
- subscriber = self.subscriber
-
- # no callbacks should be registered if it thinks repoze.tm is alive
- m = Mock(request=MockRequest())
- m.request.environ['repoze.tm.active'] = True
- subscriber(m)
- self.assertEqual(len(m.request.finished_callbacks), 0)
-
- # with repoze.tm not alive, we should get regular callbacks
- del m.request.environ['repoze.tm.active']
- subscriber(m)
- self.assertEqual(len(m.request.finished_callbacks), 1)
- self.assertEqual(len(m.request.response_callbacks), 1)
-
- def build_reqres(self):
- response = Mock(status='100', headerlist=[])
- request = Mock(exception=None, environ={})
- return request, response
-
- def test_process_commit(self):
- subscriber = self.subscriber
- subscriber.commit_veto = lambda x, y, z: None
- request, response = self.build_reqres()
- subscriber.process(request, response)
- self.assertTrue(hasattr(request, '_transaction_committed'))
- self.assertTrue(subscriber.transaction.committed)
-
- def test_process_bypass(self):
- subscriber = self.subscriber
- request, response = self.build_reqres()
- subscriber.process(request, response)
- self.assertTrue(hasattr(request, '_transaction_committed'))
- self.assertFalse(subscriber.process(request, response))
-
- def test_process_abort1(self):
- request, response = self.build_reqres()
- subscriber = self.subscriber
- subscriber.transaction.isDoomed = lambda: True
- subscriber.process(request, response)
- self.assertTrue(hasattr(request, '_transaction_committed'))
- self.assertTrue(self.subscriber.transaction.aborted)
-
- def test_process_abort2(self):
- request, response = self.build_reqres()
- subscriber = self.subscriber
- request.exception = Mock()
- subscriber.process(request, response)
- self.assertTrue(hasattr(request, '_transaction_committed'))
- self.assertTrue(self.subscriber.transaction.aborted)
-
- def test_process_abort3(self):
- request, response = self.build_reqres()
- subscriber = self.subscriber
- subscriber.commit_veto = lambda x, y, z: True
- request = Mock(exception=None, environ={})
- subscriber.process(request, response)
- self.assertTrue(hasattr(request, '_transaction_committed'))
- self.assertTrue(subscriber.transaction.aborted)
-
-
-class TestIncludeMe(unittest.TestCase):
-
+ self.txn = DummyTransaction()
+ self.request = DummyRequest()
+ self.response = DummyResponse()
+ self.registry = DummyRegistry()
+
+ def _callFUT(self, handler=None, registry=None, request=None, txn=None):
+ if handler is None:
+ def handler(request):
+ return self.response
+ if registry is None:
+ registry = self.registry
+ if request is None:
+ request = self.request
+ if txn is None:
+ txn = self.txn
+ from pyramid_tm import tm_tween_factory
+ factory = tm_tween_factory(handler, registry, txn)
+ return factory(request)
+
+ def test_repoze_tm_active(self):
+ request = DummyRequest()
+ request.environ['repoze.tm.active'] = True
+ result = self._callFUT(request=request)
+ self.assertEqual(result, self.response)
+ self.assertFalse(self.txn.began)
+
+ def test_handler_exception(self):
+ def handler(request):
+ raise NotImplementedError
+ self.assertRaises(NotImplementedError, self._callFUT, handler=handler)
+ self.assertTrue(self.txn.began)
+ self.assertTrue(self.txn.aborted)
+ self.assertFalse(self.txn.committed)
+
+ def test_handler_isdoomed(self):
+ txn = DummyTransaction(True)
+ self._callFUT(txn=txn)
+ self.assertTrue(txn.began)
+ self.assertTrue(txn.aborted)
+ self.assertFalse(txn.committed)
+
+ def test_commit_veto_true(self):
+ registry = DummyRegistry(
+ {'pyramid_tm.commit_veto':'pyramid_tm.tests.veto_true'})
+ result = self._callFUT(registry=registry)
+ self.assertEqual(result, self.response)
+ self.assertTrue(self.txn.began)
+ self.assertTrue(self.txn.aborted)
+ self.assertFalse(self.txn.committed)
+
+ def test_commit_veto_false(self):
+ registry = DummyRegistry(
+ {'pyramid_tm.commit_veto':'pyramid_tm.tests.veto_false'})
+ result = self._callFUT(registry=registry)
+ self.assertEqual(result, self.response)
+ self.assertTrue(self.txn.began)
+ self.assertFalse(self.txn.aborted)
+ self.assertTrue(self.txn.committed)
+
+ def test_commitonly(self):
+ result = self._callFUT()
+ self.assertEqual(result, self.response)
+ self.assertTrue(self.txn.began)
+ self.assertFalse(self.txn.aborted)
+ self.assertTrue(self.txn.committed)
+
+def veto_true(request, response):
+ return True
+
+def veto_false(request, response):
+ return False
+
+
+class Test_includeme(unittest.TestCase):
def test_it(self):
from pyramid_tm import includeme
-
- m = MockConfig()
- includeme(m)
- self.assertEqual(len(m.subscribers), 1)
+ config = DummyConfig()
+ includeme(config)
+ self.assertEqual(len(config.tweens), 1)
-class Mock(object):
+class Dummy(object):
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
+class DummyRegistry(object):
+ def __init__(self, settings=None):
+ if settings is None:
+ settings = {}
+ self.settings = settings
-class MockTransaction(object):
+class DummyTransaction(object):
began = False
committed = False
aborted = False
+ def __init__(self, doomed=False):
+ self.doomed = doomed
+
+ def isDoomed(self):
+ return self.doomed
+
def begin(self):
self.began = True
@@ -154,27 +166,21 @@ def abort(self):
self.aborted = True
-class MockRequest(object):
-
+class DummyRequest(object):
def __init__(self):
self.environ = {}
- self.finished_callbacks = []
- self.response_callbacks = []
-
- def add_finished_callback(self, cb):
- self.finished_callbacks.append(cb)
- def add_response_callback(self, cb):
- self.response_callbacks.append(cb)
+class DummyResponse(object):
+ def __init__(self, status='200 OK', headers=None):
+ self.status = status
+ if headers is None:
+ headers = {}
+ self.headers = headers
-
-class MockConfig(object):
+class DummyConfig(object):
def __init__(self):
- self.registry = Mock(settings={})
- self.subscribers = []
-
- def maybe_dotted(self, x):
- return x
+ self.registry = Dummy(settings={})
+ self.tweens = []
- def add_subscriber(self, x, y):
- self.subscribers.append((x, y))
+ def add_tween(self, x, above=None, below=None):
+ self.tweens.append((x, above, below))
Please sign in to comment.
Something went wrong with that request. Please try again.