From 5ce720f004ab6caa75dab40ea58bffa0848d6282 Mon Sep 17 00:00:00 2001 From: Oleg Nykolyn Date: Tue, 9 Oct 2018 19:19:11 +0300 Subject: [PATCH] Fix nested object creation. --- django_object_manager/object_manager.py | 46 +++++++++++++++++-------- tests/functional/test_object_manager.py | 25 ++++++++++---- 2 files changed, 49 insertions(+), 22 deletions(-) diff --git a/django_object_manager/object_manager.py b/django_object_manager/object_manager.py index 3f2bee5..9cecabe 100644 --- a/django_object_manager/object_manager.py +++ b/django_object_manager/object_manager.py @@ -11,6 +11,21 @@ from pytz import utc +class ContextCallable: + """Callable helper used for context passing to ObjectManager.""" + + def __init__(self, object_manager, context): + """Initialize context callable.""" + self.context = context + self.object_manager = object_manager + + def __call__(self, *args, **kwargs): + """Create object(s).""" + return self.object_manager.call_with_context(self.context, + *args, + **kwargs) + + class ObjectManager: """Base class for test objects creation.""" @@ -20,7 +35,6 @@ class ObjectManager: def __init__(self): """Initialize object creator.""" - self.context = None self._instances = defaultdict(dict) self._converters = {ForeignKey: self._create_foreing, ManyToManyRel: self._create_m2m, @@ -41,38 +55,40 @@ def __getattribute__(self, item): return super().__getattribute__(item) name = item.split('_', 1)[1] if name.endswith('s') and name[:-1] in self._registered_models: - self.context = self.Context(name[:-1], True) - return self + return self.with_context(self.Context(name=name[:-1], many=True)) elif name.endswith('ies') and \ f'{name[:-3]}y' in self._registered_models: - self.context = self.Context(f'{name[:-3]}y', True) - return self + return self.with_context(self.Context(name=f'{name[:-3]}y', many=True)) elif name in self._registered_models: - self.context = self.Context(name, False) - return self + return self.with_context(self.Context(name=name, many=False)) else: raise RuntimeError(f'Unknown item: {item}, choices are: ' f'{self._registered_models.keys()}') + def with_context(self, context): + return ContextCallable(self, context) + def __call__(self, *args, **kwargs): """Create object(s).""" - assert self.context is not None, 'x() instead of x.create_model()' - if self.context.many and (args or kwargs): + raise RuntimeError( + 'object_manager() instead of object_manager.get_`model_name`()') + + def call_with_context(self, context, *args, **kwargs): + if context.many and (args or kwargs): raise ValueError('Multiple item creation needs no args') - name = self.context.name - if self.context.many: - return {key: self._get_or_create(name, key, **data) - for key, data in self._data[self.context.name].items()} + if context.many: + return {key: self._get_or_create(context.name, key, **data) + for key, data in self._data[context.name].items()} else: try: (key,) = args - item_data = self._data[self.context.name][key].copy() + item_data = self._data[context.name][key].copy() item_data.update(kwargs) params = item_data except ValueError: key = None params = kwargs - return self._get_or_create(name, key, + return self._get_or_create(context.name, key, _custom=bool(kwargs), **params) diff --git a/tests/functional/test_object_manager.py b/tests/functional/test_object_manager.py index 8809b71..297834f 100644 --- a/tests/functional/test_object_manager.py +++ b/tests/functional/test_object_manager.py @@ -81,12 +81,23 @@ def test_many_to_many_reversed_predefined(self): def test_many_to_many_forward_predefined(self): """Ensure that object with M2M relation can be created.""" - film = self.object_manager.get_film(name='Memento', - year=2000, - uploaded_by='bob', - categories=['crime', 'drama']) + self.object_manager.get_film(name='Memento', + year=2000, + uploaded_by='bob', + categories=['crime', 'drama']) + self.assertEqual(models.FilmCategory.objects.count(), 2) + + def test_inlined_object_creation(self): + """Ensure that nested object creation works.""" + self.object_manager.get_film( + name='Memento', + year=2000, + uploaded_by=self.object_manager.get_user('bob'), + categories=['crime', 'drama']) self.assertEqual(models.FilmCategory.objects.count(), 2) - # TODO - # def test_customized(self): - # """Ensure that customized objects are not cached.""" + def test_customized(self): + """Ensure that customized objects are not cached.""" + user_1 = self.object_manager.get_user('bob', email='bob@bob.com') + user_2 = self.object_manager.get_user('bob', email='bob@bob.com') + self.assertTrue(user_1 is not user_2)