diff --git a/pyramid_crud/views.py b/pyramid_crud/views.py index 543cd4e..291b8fd 100644 --- a/pyramid_crud/views.py +++ b/pyramid_crud/views.py @@ -804,7 +804,7 @@ def edit(self): else: is_new = True form = self.Form(self.request.POST, csrf_context=self.request) - form.session = self.request.dbsession + form.session = self.dbsession # Prepare return values retparams = {'form': form, 'is_new': is_new} diff --git a/tests/test_views.py b/tests/test_views.py index 36c3548..5e7d0c1 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -32,9 +32,11 @@ def csrf_token(session, pyramid_request): class TestCRUDView(object): - @pytest.fixture(autouse=True) - def _prepare_view(self, pyramid_request, DBSession, form_factory, + @pytest.fixture(autouse=True, params=[True, False]) + def _prepare_view(self, pyramid_request, DBSession, form_factory, request, model_factory): + + request_dbsession = request.param self.request = pyramid_request self.request.POST = MultiDict(self.request.POST) self.Model = model_factory([Column('test_text', String), @@ -44,13 +46,20 @@ def _prepare_view(self, pyramid_request, DBSession, form_factory, self.Model.id.info["label"] = "ID" self.Model.__str__ = lambda self: 'ModelStr' self.Form = form_factory(model=self.Model, base=forms.CSRFModelForm) - self.request.dbsession = DBSession + self.session = DBSession - class MyView(CRUDView): - Form = self.Form - url_path = '/test' - self.View = MyView + # Create view + view_attrs = { + 'Form': self.Form, + 'url_path': '/test', + } + if request_dbsession: + self.request.dbsession = DBSession + else: + view_attrs["dbsession"] = DBSession + + self.View = type('MyView', (CRUDView,), view_attrs) self.View.routes = { 'list': 'tests.test_views.MyView.list', 'delete': 'tests.test_views.MyView.delete',