Skip to content

Commit

Permalink
Properly handle Sequence & inheritance (Closes #93).
Browse files Browse the repository at this point in the history
There was also a nasty bug: with class FactoryB(FactoryA), FactoryB's sequence
counter started at the value of FactoryA's counter when FactoryB was first called.
  • Loading branch information
rbarrois committed Sep 16, 2013
1 parent a8742c9 commit 7fe9dca
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 16 deletions.
7 changes: 5 additions & 2 deletions docs/changelog.rst
@@ -1,15 +1,18 @@
ChangeLog
=========

.. _v2.1.3:
.. _v2.2.0:

2.1.3 (current)
2.2.0 (current)
---------------

*Bugfix:*

- Removed duplicated :class:`~factory.alchemy.SQLAlchemyModelFactory` lurking in :mod:`factory`
(:issue:`83`)
- Properly handle sequences within object inheritance chains.
If FactoryA inherits from FactoryB, and their associated classes share the same link,
sequence counters will be shared (:issue:`93`)

*New:*

Expand Down
74 changes: 61 additions & 13 deletions factory/base.py
Expand Up @@ -186,7 +186,8 @@ def __new__(mcs, class_name, bases, attrs):
else:
# If inheriting the factory from a parent, keep a link to it.
# This allows to use the sequence counters from the parents.
if associated_class == inherited_associated_class:
if (inherited_associated_class is not None
and issubclass(associated_class, inherited_associated_class)):
attrs['_base_factory'] = base

# The CLASS_ATTRIBUTE_ASSOCIATED_CLASS must *not* be taken into
Expand All @@ -212,6 +213,32 @@ def __str__(cls):

# Factory base classes


class _Counter(object):
"""Simple, naive counter.
Attributes:
for_class (obj): the class this counter related to
seq (int): the next value
"""

def __init__(self, seq, for_class):
self.seq = seq
self.for_class = for_class

def next(self):
value = self.seq
self.seq += 1
return value

def reset(self, next_value=0):
self.seq = next_value

def __repr__(self):
return '<_Counter for %s.%s, next=%d>' % (
self.for_class.__module__, self.for_class.__name__, self.seq)


class BaseFactory(object):
"""Factory base support for sequences, attributes and stubs."""

Expand All @@ -224,10 +251,10 @@ def __new__(cls, *args, **kwargs):
raise FactoryError('You cannot instantiate BaseFactory')

# ID to use for the next 'declarations.Sequence' attribute.
_next_sequence = None
_counter = None

# Base factory, if this class was inherited from another factory. This is
# used for sharing the _next_sequence counter among factories for the same
# used for sharing the sequence _counter among factories for the same
# class.
_base_factory = None

Expand All @@ -245,18 +272,28 @@ def __new__(cls, *args, **kwargs):

@classmethod
def reset_sequence(cls, value=None, force=False):
"""Reset the sequence counter."""
"""Reset the sequence counter.
Args:
value (int or None): the new 'next' sequence value; if None,
recompute the next value from _setup_next_sequence().
force (bool): whether to force-reset parent sequence counters
in a factory inheritance chain.
"""
if cls._base_factory:
if force:
cls._base_factory.reset_sequence(value=value)
else:
raise ValueError(
"Cannot reset the sequence of a factory subclass. "
"Please call reset_sequence() on the root factory, "
"or call reset_sequence(forward=True)."
"or call reset_sequence(force=True)."
)
else:
cls._next_sequence = value
cls._setup_counter()
if value is None:
value = cls._setup_next_sequence()
cls._counter.reset(value)

@classmethod
def _setup_next_sequence(cls):
Expand All @@ -267,6 +304,19 @@ def _setup_next_sequence(cls):
"""
return 0

@classmethod
def _setup_counter(cls):
"""Ensures cls._counter is set for this class.
Due to the way inheritance works in Python, we need to ensure that the
``_counter`` attribute has been initialized for *this* Factory subclass,
not one of its parents.
"""
if cls._counter is None or cls._counter.for_class != cls:
first_seq = cls._setup_next_sequence()
cls._counter = _Counter(for_class=cls, seq=first_seq)
logger.debug("%r: Setting up next sequence (%d)", cls, first_seq)

@classmethod
def _generate_next_sequence(cls):
"""Retrieve a new sequence ID.
Expand All @@ -279,16 +329,14 @@ def _generate_next_sequence(cls):

# Rely upon our parents
if cls._base_factory:
logger.debug("%r: reusing sequence from %r", cls, cls._base_factory)
return cls._base_factory._generate_next_sequence()

# Make sure _next_sequence is initialized
if cls._next_sequence is None:
cls._next_sequence = cls._setup_next_sequence()
# Make sure _counter is initialized
cls._setup_counter()

# Pick current value, then increase class counter for the next call.
next_sequence = cls._next_sequence
cls._next_sequence += 1
return next_sequence
return cls._counter.next()

@classmethod
def attributes(cls, create=False, extra=None):
Expand Down Expand Up @@ -577,7 +625,7 @@ def simple_generate_batch(cls, create, size, **kwargs):
This class has the ability to support multiple ORMs by using custom creation
functions.
""",
})
})


# Backwards compatibility
Expand Down
2 changes: 1 addition & 1 deletion tests/test_alchemy.py
Expand Up @@ -65,7 +65,7 @@ class SQLAlchemyPkSequenceTestCase(unittest.TestCase):

def setUp(self):
super(SQLAlchemyPkSequenceTestCase, self).setUp()
StandardFactory.reset_sequence()
StandardFactory.reset_sequence(1)
NonIntegerPkFactory.FACTORY_SESSION.rollback()

def test_pk_first(self):
Expand Down
72 changes: 72 additions & 0 deletions tests/test_using.py
Expand Up @@ -730,6 +730,78 @@ class TestObjectFactory2(TestObjectFactory):
test_object_alt = TestObjectFactory.build()
self.assertEqual(None, test_object_alt.three)

def test_inheritance_and_sequences(self):
"""Sequence counters should be kept within an inheritance chain."""
class TestObjectFactory(factory.Factory):
FACTORY_FOR = TestObject

one = factory.Sequence(lambda n: n)

class TestObjectFactory2(TestObjectFactory):
FACTORY_FOR = TestObject

to1a = TestObjectFactory()
self.assertEqual(0, to1a.one)
to2a = TestObjectFactory2()
self.assertEqual(1, to2a.one)
to1b = TestObjectFactory()
self.assertEqual(2, to1b.one)
to2b = TestObjectFactory2()
self.assertEqual(3, to2b.one)

def test_inheritance_sequence_inheriting_objects(self):
"""Sequence counters are kept with inheritance, incl. misc objects."""
class TestObject2(TestObject):
pass

class TestObjectFactory(factory.Factory):
FACTORY_FOR = TestObject

one = factory.Sequence(lambda n: n)

class TestObjectFactory2(TestObjectFactory):
FACTORY_FOR = TestObject2

to1a = TestObjectFactory()
self.assertEqual(0, to1a.one)
to2a = TestObjectFactory2()
self.assertEqual(1, to2a.one)
to1b = TestObjectFactory()
self.assertEqual(2, to1b.one)
to2b = TestObjectFactory2()
self.assertEqual(3, to2b.one)

def test_inheritance_sequence_unrelated_objects(self):
"""Sequence counters are kept with inheritance, unrelated objects.
See issue https://github.com/rbarrois/factory_boy/issues/93
Problem: sequence counter is somewhat shared between factories
until the "slave" factory has been called.
"""

class TestObject2(object):
def __init__(self, one):
self.one = one

class TestObjectFactory(factory.Factory):
FACTORY_FOR = TestObject

one = factory.Sequence(lambda n: n)

class TestObjectFactory2(TestObjectFactory):
FACTORY_FOR = TestObject2

to1a = TestObjectFactory()
self.assertEqual(0, to1a.one)
to2a = TestObjectFactory2()
self.assertEqual(0, to2a.one)
to1b = TestObjectFactory()
self.assertEqual(1, to1b.one)
to2b = TestObjectFactory2()
self.assertEqual(1, to2b.one)


def test_inheritance_with_inherited_class(self):
class TestObjectFactory(factory.Factory):
FACTORY_FOR = TestObject
Expand Down

0 comments on commit 7fe9dca

Please sign in to comment.