Skip to content

Commit

Permalink
Factorize out SQLAlchemy transaction management
Browse files Browse the repository at this point in the history
  • Loading branch information
francoisfreitag committed Jun 1, 2022
1 parent 7541994 commit 0999e78
Showing 1 changed file with 14 additions and 18 deletions.
32 changes: 14 additions & 18 deletions tests/test_alchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,15 @@ class Meta:
text = factory.Sequence(lambda n: "text%s" % n)


class SQLAlchemyPkSequenceTestCase(unittest.TestCase):
class TransactionTestCase(unittest.TestCase):
def tearDown(self):
models.session.rollback()


class SQLAlchemyPkSequenceTestCase(TransactionTestCase):
def setUp(self):
super().setUp()
StandardFactory.reset_sequence(1)
NonIntegerPkFactory._meta.sqlalchemy_session.rollback()

def test_pk_first(self):
std = StandardFactory.build()
Expand Down Expand Up @@ -109,10 +113,7 @@ def test_pk_force_value(self):
self.assertEqual(0, std2.id)


class SQLAlchemyGetOrCreateTests(unittest.TestCase):
def setUp(self):
models.session.rollback()

class SQLAlchemyGetOrCreateTests(TransactionTestCase):
def test_simple_call(self):
obj1 = WithGetOrCreateFieldFactory(foo='foo1')
obj2 = WithGetOrCreateFieldFactory(foo='foo1')
Expand Down Expand Up @@ -144,10 +145,7 @@ def test_multicall(self):
)


class MultipleGetOrCreateFieldsTest(unittest.TestCase):
def setUp(self):
models.session.rollback()

class MultipleGetOrCreateFieldsTest(TransactionTestCase):
def test_one_defined(self):
obj1 = WithMultipleGetOrCreateFieldsFactory()
obj2 = WithMultipleGetOrCreateFieldsFactory(slug=obj1.slug)
Expand Down Expand Up @@ -216,11 +214,10 @@ class Meta:
model = models.StandardModel


class SQLAlchemyNonIntegerPkTestCase(unittest.TestCase):
def setUp(self):
super().setUp()
class SQLAlchemyNonIntegerPkTestCase(TransactionTestCase):
def tearDown(self):
super().tearDown()
NonIntegerPkFactory.reset_sequence()
NonIntegerPkFactory._meta.sqlalchemy_session.rollback()

def test_first(self):
nonint = NonIntegerPkFactory.build()
Expand Down Expand Up @@ -250,7 +247,7 @@ def test_force_pk(self):
self.assertEqual('foo0', nonint2.id)


class SQLAlchemyNoSessionTestCase(unittest.TestCase):
class SQLAlchemyNoSessionTestCase(TransactionTestCase):

def test_create_raises_exception_when_no_session_was_set(self):
with self.assertRaises(RuntimeError):
Expand All @@ -264,8 +261,7 @@ def test_build_does_not_raises_exception_when_no_session_was_set(self):
self.assertEqual(inst1.id, 1)


class SQLAlchemySessionFactoryTestCase(unittest.TestCase):

class SQLAlchemySessionFactoryTestCase(TransactionTestCase):
def test_create_get_session_from_sqlalchemy_session_factory(self):
class SessionGetterFactory(SQLAlchemyModelFactory):
class Meta:
Expand All @@ -292,7 +288,7 @@ class Meta:
id = factory.Sequence(lambda n: n)


class NameConflictTests(unittest.TestCase):
class NameConflictTests(TransactionTestCase):
"""Regression test for `TypeError: _save() got multiple values for argument 'session'`
See #775.
Expand Down

0 comments on commit 0999e78

Please sign in to comment.