Skip to content

Commit

Permalink
Prevent all queries in failed tx block
Browse files Browse the repository at this point in the history
  • Loading branch information
akaariai committed Sep 22, 2013
1 parent 9851484 commit 2df5618
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 40 deletions.
8 changes: 8 additions & 0 deletions django/db/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def rollback(self):
"""
self.validate_thread_sharing()
self.validate_no_atomic_block()
self.needs_rollback = False
self._rollback()
self.set_clean()

Expand Down Expand Up @@ -238,6 +239,7 @@ def savepoint_rollback(self, sid):
return

self.validate_thread_sharing()
self.needs_rollback = False
self._savepoint_rollback(sid)

def savepoint_commit(self, sid):
Expand Down Expand Up @@ -361,6 +363,12 @@ def validate_no_atomic_block(self):
raise TransactionManagementError(
"This is forbidden when an 'atomic' block is active.")

def validate_no_broken_transaction(self):
if self.needs_rollback:
raise TransactionManagementError(
"An error occurred in the current transaction. You can't "
"execute queries until the end of the 'atomic' block.")

def abort(self):
"""
Roll back any ongoing transaction and clean the transaction state
Expand Down
47 changes: 32 additions & 15 deletions django/db/backends/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,9 @@ def __init__(self, cursor, db):
self.cursor = cursor
self.db = db

SET_DIRTY_ATTRS = frozenset(['execute', 'executemany', 'callproc'])
WRAP_ERROR_ATTRS = frozenset([
'callproc', 'close', 'execute', 'executemany',
'fetchone', 'fetchmany', 'fetchall', 'nextset'])
WRAP_ERROR_ATTRS = frozenset(['fetchone', 'fetchmany', 'fetchall', 'nextset'])

def __getattr__(self, attr):
if attr in CursorWrapper.SET_DIRTY_ATTRS:
self.db.set_dirty()
cursor_attr = getattr(self.cursor, attr)
if attr in CursorWrapper.WRAP_ERROR_ATTRS:
return self.db.wrap_database_errors(cursor_attr)
Expand All @@ -36,18 +31,42 @@ def __getattr__(self, attr):
def __iter__(self):
return iter(self.cursor)

# The following methods cannot be implemented in __getattr__, because the
# code must run when the method is invoked, not just when it is accessed.

class CursorDebugWrapper(CursorWrapper):
def callproc(self, procname, params=None):
self.db.validate_no_broken_transaction()
self.db.set_dirty()
with self.db.wrap_database_errors:
if params is None:
return self.cursor.callproc(procname)
else:
return self.cursor.callproc(procname, params)

def execute(self, sql, params=None):
self.db.validate_no_broken_transaction()
self.db.set_dirty()
with self.db.wrap_database_errors:
if params is None:
return self.cursor.execute(sql)
else:
return self.cursor.execute(sql, params)

def executemany(self, sql, param_list):
self.db.validate_no_broken_transaction()
self.db.set_dirty()
with self.db.wrap_database_errors:
return self.cursor.executemany(sql, param_list)


class CursorDebugWrapper(CursorWrapper):

# XXX callproc isn't instrumented at this time.

def execute(self, sql, params=None):
start = time()
try:
with self.db.wrap_database_errors:
if params is None:
# params default might be backend specific
return self.cursor.execute(sql)
return self.cursor.execute(sql, params)
return super(CursorDebugWrapper, self).execute(sql, params)
finally:
stop = time()
duration = stop - start
Expand All @@ -61,11 +80,9 @@ def execute(self, sql, params=None):
)

def executemany(self, sql, param_list):
self.db.set_dirty()
start = time()
try:
with self.db.wrap_database_errors:
return self.cursor.executemany(sql, param_list)
return super(CursorDebugWrapper, self).executemany(sql, param_list)
finally:
stop = time()
duration = stop - start
Expand Down
9 changes: 5 additions & 4 deletions django/db/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@

from functools import wraps

from django.db import connections, DatabaseError, DEFAULT_DB_ALIAS
from django.db import (
connections, DEFAULT_DB_ALIAS,
DatabaseError, ProgrammingError)
from django.utils.decorators import available_attrs


class TransactionManagementError(Exception):
class TransactionManagementError(ProgrammingError):
"""
This exception is thrown when something bad happens with transaction
management.
This exception is thrown when transaction management is used improperly.
"""
pass

Expand Down
79 changes: 58 additions & 21 deletions tests/transactions/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,6 @@ def test_prevent_rollback(self):
connection.cursor().execute(
"SELECT no_such_col FROM transactions_reporter")
transaction.savepoint_rollback(sid)
# atomic block should rollback, but prevent it, as we just did it.
self.assertTrue(transaction.get_rollback())
transaction.set_rollback(False)
self.assertQuerysetEqual(Reporter.objects.all(), ['<Reporter: Tintin>'])


Expand Down Expand Up @@ -257,27 +254,35 @@ def test_merged_outer_rollback(self):
transaction.atomic(savepoint=False):
Reporter.objects.create(first_name="Tournesol")
raise Exception("Oops, that's his last name")
# It wasn't possible to roll back
self.assertEqual(Reporter.objects.count(), 3)
# It wasn't possible to roll back
self.assertEqual(Reporter.objects.count(), 3)
# No queries as tx marked for rollback
with self.assertRaises(transaction.TransactionManagementError):
Reporter.objects.count()
# Still impossible to query
with self.assertRaises(transaction.TransactionManagementError):
Reporter.objects.count()
# The outer block must roll back
self.assertQuerysetEqual(Reporter.objects.all(), [])

def test_merged_inner_savepoint_rollback(self):
with transaction.atomic():
Reporter.objects.create(first_name="Tintin")
r1 = Reporter.objects.create(first_name="Tintin")
with transaction.atomic():
Reporter.objects.create(first_name="Archibald", last_name="Haddock")
r2 = Reporter.objects.create(first_name="Archibald", last_name="Haddock")
with six.assertRaisesRegex(self, Exception, "Oops"), \
transaction.atomic(savepoint=False):
Reporter.objects.create(first_name="Tournesol")
r3 = Reporter.objects.create(first_name="Tournesol")
raise Exception("Oops, that's his last name")
# It wasn't possible to roll back
# No queries - tx marked for rollback.
with self.assertRaises(transaction.TransactionManagementError):
Reporter.objects.count()
# But, it is possible to continue the tx if explicitly asked
transaction.set_rollback(False)
self.assertEqual(Reporter.objects.count(), 3)
# The first block with a savepoint must roll back
self.assertEqual(Reporter.objects.count(), 1)
self.assertQuerysetEqual(Reporter.objects.all(), ['<Reporter: Tintin>'])
# The block isn't rolled back
self.assertEqual(Reporter.objects.count(), 3)
self.assertQuerysetEqual(
Reporter.objects.order_by('first_name'),
[r2, r1, r3], lambda x: x)

def test_merged_outer_rollback_after_inner_failure_and_inner_success(self):
with transaction.atomic():
Expand All @@ -287,13 +292,9 @@ def test_merged_outer_rollback_after_inner_failure_and_inner_success(self):
transaction.atomic(savepoint=False):
Reporter.objects.create(first_name="Haddock")
raise Exception("Oops, that's his last name")
# It wasn't possible to roll back
self.assertEqual(Reporter.objects.count(), 2)
# Inner block with a savepoint succeeds
with transaction.atomic(savepoint=False):
Reporter.objects.create(first_name="Archibald", last_name="Haddock")
# It still wasn't possible to roll back
self.assertEqual(Reporter.objects.count(), 3)
# It isn't possible to use the connection until rollback
with self.assertRaises(transaction.TransactionManagementError):
Reporter.objects.count()
# The outer block must rollback
self.assertQuerysetEqual(Reporter.objects.all(), [])

Expand Down Expand Up @@ -326,6 +327,42 @@ def test_atomic_prevents_calling_transaction_management_methods(self):
with self.assertRaises(transaction.TransactionManagementError):
transaction.leave_transaction_management()

def test_atomic_prevents_running_more_queries_after_an_error(self):
cursor = connection.cursor()
with transaction.atomic():
with self.assertRaises(DatabaseError):
with transaction.atomic(savepoint=False):
cursor.execute("INSERT INTO transactions_no_such_table (id) VALUES (0)")
with self.assertRaises(transaction.TransactionManagementError):
cursor.execute("INSERT INTO transactions_no_such_table (id) VALUES (0)")

class SaveInAtomicTests(TransactionTestCase):
available_apps = ['transactions']

def test_failed_save_prevents_queries(self):
r1 = Reporter.objects.create(first_name='foo', last_name='bar')
with transaction.atomic():
r2 = Reporter(first_name='foo', last_name='bar2', id=r1.id)
try:
r2.save(force_insert=True)
except IntegrityError:
# queries aren't possible, as r2.save marked the transaction
# as needing rollback
with self.assertRaises(transaction.TransactionManagementError):
r2.save(force_update=True)
self.assertEqual(Reporter.objects.get(pk=r1.pk).last_name, 'bar')

def test_failed_save_set_rollback(self):
r1 = Reporter.objects.create(first_name='foo', last_name='bar')
with transaction.atomic():
r2 = Reporter(first_name='foo', last_name='bar2', id=r1.id)
try:
r2.save(force_insert=True)
except IntegrityError:
# Explicit clear of rollback state
transaction.set_rollback(False)
r2.save(force_update=True)
self.assertEqual(Reporter.objects.get(pk=r1.pk).last_name, 'bar2')

class AtomicMiscTests(TransactionTestCase):

Expand Down

0 comments on commit 2df5618

Please sign in to comment.