From 59285d3e9c0ed585cad07f043506d6d800588613 Mon Sep 17 00:00:00 2001 From: Federico Martinez Date: Mon, 31 Jul 2023 09:54:57 -0300 Subject: [PATCH] adding a commit=True kwarg to save and other methods of active record --- sqlalchemy_mixins/activerecord.py | 38 +++++++------ sqlalchemy_mixins/tests/test_activerecord.py | 60 ++++++++++++++++++++ 2 files changed, 82 insertions(+), 16 deletions(-) diff --git a/sqlalchemy_mixins/activerecord.py b/sqlalchemy_mixins/activerecord.py index 7b082f1..122a60d 100644 --- a/sqlalchemy_mixins/activerecord.py +++ b/sqlalchemy_mixins/activerecord.py @@ -23,50 +23,56 @@ def fill(self, **kwargs): return self - def save(self): + def save(self, commit=True): """Saves the updated model to the current entity db. + :param commit: where to commit the transaction """ - try: - self.session.add(self) - self.session.commit() - return self - except: - self.session.rollback() - raise + self.session.add(self) + if commit: + self._commit_or_fail() + return self @classmethod - def create(cls, **kwargs): + def create(cls, commit=True, **kwargs): """Create and persist a new record for the model + :param commit: where to commit the transaction :param kwargs: attributes for the record :return: the new model instance """ - return cls().fill(**kwargs).save() + return cls().fill(**kwargs).save(commit=commit) - def update(self, **kwargs): + def update(self, commit=True, **kwargs): """Same as :meth:`fill` method but persists changes to database. + :param commit: where to commit the transaction """ - return self.fill(**kwargs).save() + return self.fill(**kwargs).save(commit=commit) - def delete(self): + def delete(self, commit=True): """Removes the model from the current entity session and mark for deletion. + :param commit: where to commit the transaction """ + self.session.delete(self) + if commit: + self._commit_or_fail() + + def _commit_or_fail(self): try: - self.session.delete(self) self.session.commit() except: self.session.rollback() raise @classmethod - def destroy(cls, *ids): + def destroy(cls, *ids, commit=True): """Delete the records with the given ids :type ids: list :param ids: primary key ids of records + :param commit: where to commit the transaction """ for pk in ids: obj = cls.find(pk) if obj: - obj.delete() + obj.delete(commit=commit) cls.session.flush() @classmethod diff --git a/sqlalchemy_mixins/tests/test_activerecord.py b/sqlalchemy_mixins/tests/test_activerecord.py index 8284b2a..1f8b0df 100644 --- a/sqlalchemy_mixins/tests/test_activerecord.py +++ b/sqlalchemy_mixins/tests/test_activerecord.py @@ -1,5 +1,6 @@ import unittest +import sqlalchemy import sqlalchemy as sa from sqlalchemy import create_engine from sqlalchemy.ext.hybrid import hybrid_property @@ -115,6 +116,30 @@ def test_fill_and_save(self): self.assertEqual(p11, sess.query(Post).first()) self.assertEqual(p11.archived, True) + def test_save_commits(self): + with self.assertRaises(sqlalchemy.exc.InvalidRequestError): + with sess.begin(): + u1 = User() + u1.fill(name='Bill u1') + u1.save() + u2 = User() + u2.fill(name='Bill u2') + u2.save() + self.assertEqual([u1, u2], sess.query(User).order_by(User.id.asc()).all()) + # The first user is saved even when the block raises a Exception + self.assertEqual([u1], sess.query(User).order_by(User.id.asc()).all()) + + def test_save_do_not_commit(self): + with sess.begin(): + u1 = User() + u1.fill(name='Bill u1') + u1.save(commit=False) + u2 = User() + u2.fill(name='Bill u2') + u2.save(commit=False) + + self.assertEqual([u1,u2], sess.query(User).order_by(User.id.asc()).all()) + def test_create(self): u1 = User.create(name='Bill u1') self.assertEqual(u1, sess.query(User).first()) @@ -158,6 +183,16 @@ def test_update(self): self.assertEqual(sess.query(Post).get(11).public, True) self.assertEqual(sess.query(Post).get(11).user, u2) + def test_update_no_commit(self): + u1 = User(name='Bill', id=1) + u1.save() + u1.update(name='Joe', commit=False) + self.assertEqual('Joe', sess.query(User).where(User.id==1).first().name) + sess.rollback() + self.assertEqual('Bill', sess.query(User).where(User.id==1).first().name) + + + def test_fill_wrong_attribute(self): u1 = User(name='Bill u1') sess.add(u1) @@ -179,6 +214,15 @@ def test_delete(self): u1.delete() self.assertEqual(sess.query(User).get(1), None) + def test_delete_without_commit(self): + u1 = User() + u1.save() + u1.delete(commit=False) + self.assertIsNone(sess.query(User).one_or_none()) + sess.rollback() + self.assertIsNotNone(sess.query(User).one_or_none()) + + def test_destroy(self): u1, u2, p11, p12, p13 = self._seed() @@ -186,6 +230,16 @@ def test_destroy(self): Post.destroy(11, 12) self.assertEqual(set(sess.query(Post).all()), {p13}) + + def test_destroy_no_commit(self): + u1, u2, p11, p12, p13 = self._seed() + sess.commit() + self.assertEqual(set(sess.query(Post).order_by(Post.id).all()), {p11, p12, p13}) + Post.destroy(11, 12, commit=False) + self.assertEqual(set(sess.query(Post).order_by(Post.id).all()), {p13}) + sess.rollback() + self.assertEqual(set(sess.query(Post).order_by(Post.id).all()), {p11, p12, p13}) + def test_all(self): u1, u2, p11, p12, p13 = self._seed() @@ -231,6 +285,12 @@ def test_create(self): u1 = UserAlternative.create(name='Bill u1') self.assertEqual(u1, sess.query(UserAlternative).first()) + def test_create_no_commit(self): + u1 = UserAlternative.create(name='Bill u1', commit=False) + self.assertEqual(u1, sess.query(UserAlternative).first()) + sess.rollback() + self.assertIsNone(sess.query(UserAlternative).one_or_none()) + if __name__ == '__main__': # pragma: no cover