Skip to content

Commit

Permalink
Do not transact shard dbs in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
millerdev committed Jul 2, 2020
1 parent 54e74fa commit eadd947
Showing 1 changed file with 42 additions and 1 deletion.
43 changes: 42 additions & 1 deletion corehq/form_processor/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from couchdbkit import ResourceNotFound
from django.conf import settings
from django.db import transaction
from django.test.utils import override_settings
from nose.plugins.attrib import attr
from nose.tools import nottest
Expand Down Expand Up @@ -182,7 +183,7 @@ def partitioned(cls):
Marks a test to be run with the partitioned database settings in
addition to the non-partitioned database settings.
"""
return attr(sql_backend=True)(cls)
return patch_shard_db_transactions(attr(sql_backend=True)(cls))


def only_run_with_non_partitioned_database(cls):
Expand All @@ -209,6 +210,46 @@ def use_sql_backend(cls):
return partitioned(override_settings(TESTS_SHOULD_USE_SQL_BACKEND=True)(cls))


def patch_shard_db_transactions(cls):
"""Patch shard db transaction management on test class
Do not use a transaction per test on shard databases because proxy
queries cannot see changes in uncommitted transactions in shard dbs.
This means that changes to shard dbs will not be rolled back at the
end of each test; test cleanup must be done manually.
:param cls: A subclass of `django.test.TestCase`
"""
assert hasattr(cls, "_enter_atomics") and hasattr(cls, "_rollback_atomics"), cls
shard_cfg = getattr(settings, "PARTITION_DATABASE_CONFIG", None)
if not shard_cfg:
return cls
shard_dbs = {"proxy"} | set(shard_cfg["shards"])

@classmethod
def _enter_atomics(cls):
atomics = {}
for db_name in cls._databases_names():
if db_name in shard_dbs:
continue
atomics[db_name] = transaction.atomic(using=db_name)
atomics[db_name].__enter__()
return atomics
cls._enter_atomics = _enter_atomics

@classmethod
def _rollback_atomics(cls, atomics):
"""Rollback atomic blocks opened by the previous method."""
for db_name in reversed(cls._databases_names()):
if db_name in shard_dbs:
continue
transaction.set_rollback(True, using=db_name)
atomics[db_name].__exit__(None, None, None)
cls._rollback_atomics = _rollback_atomics

return cls


@nottest
def create_form_for_test(
domain, case_id=None, attachments=None, save=True, state=XFormInstanceSQL.NORMAL,
Expand Down

0 comments on commit eadd947

Please sign in to comment.