diff --git a/src/main/java/com/github/pgasync/Db.java b/src/main/java/com/github/pgasync/Db.java index d27f7c6..42c2759 100644 --- a/src/main/java/com/github/pgasync/Db.java +++ b/src/main/java/com/github/pgasync/Db.java @@ -5,7 +5,7 @@ * * @author Antti Laisi */ -public interface Db extends QueryExecutor, TransactionExecutor, Listenable, AutoCloseable { +public interface Db extends QueryExecutor, Listenable, AutoCloseable { /** * Closes the pool, blocks the calling thread until connections are closed. diff --git a/src/main/java/com/github/pgasync/QueryExecutor.java b/src/main/java/com/github/pgasync/QueryExecutor.java index c2ae195..18f8cdd 100644 --- a/src/main/java/com/github/pgasync/QueryExecutor.java +++ b/src/main/java/com/github/pgasync/QueryExecutor.java @@ -12,6 +12,11 @@ */ public interface QueryExecutor { + /** + * Begins a transaction. + */ + Observable begin(); + /** * Executes an anonymous prepared statement. Uses native PostgreSQL syntax with $arg instead of ? * to mark parameters. Supported parameter types are String, Character, Number, Time, Date, Timestamp @@ -34,6 +39,16 @@ public interface QueryExecutor { */ Observable querySet(String sql, Object... params); + /** + * Begins a transaction. + * + * @param onTransaction Called when transaction is successfully started. + * @param onError Called on exception thrown + */ + default void begin(Consumer onTransaction, Consumer onError) { + begin().subscribe(onTransaction::accept, onError::accept); + } + /** * Executes a simple query. * diff --git a/src/main/java/com/github/pgasync/TransactionExecutor.java b/src/main/java/com/github/pgasync/TransactionExecutor.java deleted file mode 100644 index 1b4ecb6..0000000 --- a/src/main/java/com/github/pgasync/TransactionExecutor.java +++ /dev/null @@ -1,28 +0,0 @@ -package com.github.pgasync; - -import rx.Observable; - -import java.util.function.Consumer; - -/** - * TransactionExecutor begins backend transactions. - * - * @author Antti Laisi - */ -public interface TransactionExecutor { - - /** - * Begins a transaction. - */ - Observable begin(); - - /** - * Begins a transaction. - * - * @param onTransaction Called when transaction is successfully started. - * @param onError Called on exception thrown - */ - default void begin(Consumer onTransaction, Consumer onError) { - begin().subscribe(onTransaction::accept, onError::accept); - } -} diff --git a/src/main/java/com/github/pgasync/impl/PgConnection.java b/src/main/java/com/github/pgasync/impl/PgConnection.java index 29a0f10..efc42e8 100644 --- a/src/main/java/com/github/pgasync/impl/PgConnection.java +++ b/src/main/java/com/github/pgasync/impl/PgConnection.java @@ -14,15 +14,7 @@ package com.github.pgasync.impl; -import com.github.pgasync.Connection; -import com.github.pgasync.ResultSet; -import com.github.pgasync.Row; -import com.github.pgasync.Transaction; -import com.github.pgasync.impl.conversion.DataConverter; -import com.github.pgasync.impl.message.*; -import rx.Observable; -import rx.Subscriber; -import rx.observers.Subscribers; +import static com.github.pgasync.impl.message.RowDescription.ColumnDescription; import java.util.ArrayList; import java.util.HashMap; @@ -32,7 +24,27 @@ import java.util.concurrent.TimeUnit; import java.util.logging.Logger; -import static com.github.pgasync.impl.message.RowDescription.ColumnDescription; +import com.github.pgasync.Connection; +import com.github.pgasync.ResultSet; +import com.github.pgasync.Row; +import com.github.pgasync.Transaction; +import com.github.pgasync.impl.conversion.DataConverter; +import com.github.pgasync.impl.message.Authentication; +import com.github.pgasync.impl.message.Bind; +import com.github.pgasync.impl.message.CommandComplete; +import com.github.pgasync.impl.message.DataRow; +import com.github.pgasync.impl.message.ExtendedQuery; +import com.github.pgasync.impl.message.Message; +import com.github.pgasync.impl.message.Parse; +import com.github.pgasync.impl.message.PasswordMessage; +import com.github.pgasync.impl.message.Query; +import com.github.pgasync.impl.message.ReadyForQuery; +import com.github.pgasync.impl.message.RowDescription; +import com.github.pgasync.impl.message.StartupMessage; + +import rx.Observable; +import rx.Subscriber; +import rx.observers.Subscribers; /** * A connection to PostgreSQL backed. The postmaster forks a backend process for @@ -184,6 +196,10 @@ static Map getColumns(ColumnDescription[] descriptions) { */ class PgConnectionTransaction implements Transaction { + @Override + public Observable begin() { + return querySet("SAVEPOINT sp_1").map(rs -> new PgConnectionNestedTransaction(1)); + } @Override public Observable commit() { return PgConnection.this.querySet("COMMIT") @@ -211,4 +227,30 @@ Observable doRollback(Throwable t) { } } + /** + * Nested Transaction using savepoints. + */ + class PgConnectionNestedTransaction extends PgConnectionTransaction { + + final int depth; + + PgConnectionNestedTransaction(int depth) { + this.depth = depth; + } + @Override + public Observable begin() { + return querySet("SAVEPOINT sp_" + (depth+1)) + .map(rs -> new PgConnectionNestedTransaction(depth+1)); + } + @Override + public Observable commit() { + return PgConnection.this.querySet("RELEASE SAVEPOINT sp_" + depth) + .map(rs -> (Void) null); + } + @Override + public Observable rollback() { + return PgConnection.this.querySet("ROLLBACK TO SAVEPOINT sp_" + depth) + .map(rs -> (Void) null); + } + } } diff --git a/src/main/java/com/github/pgasync/impl/PgConnectionPool.java b/src/main/java/com/github/pgasync/impl/PgConnectionPool.java index 066c6b5..b39d0f0 100644 --- a/src/main/java/com/github/pgasync/impl/PgConnectionPool.java +++ b/src/main/java/com/github/pgasync/impl/PgConnectionPool.java @@ -252,6 +252,12 @@ class ReleasingTransaction implements Transaction { this.transaction = transaction; } + @Override + public Observable begin() { + // Nested transactions should not release things automatically. + return transaction.begin(); + } + @Override public Observable rollback() { return transaction.rollback() diff --git a/src/test/java/com/github/pgasync/impl/TransactionTest.java b/src/test/java/com/github/pgasync/impl/TransactionTest.java index fd8df68..fc42e45 100644 --- a/src/test/java/com/github/pgasync/impl/TransactionTest.java +++ b/src/test/java/com/github/pgasync/impl/TransactionTest.java @@ -142,4 +142,61 @@ public void shouldInvalidateTxConnAfterError() throws Exception { assertEquals(0, dbr.query("SELECT ID FROM TX_TEST WHERE ID = 22").size()); } + @Test + public void shouldSupportNestedTransactions() throws Exception { + CountDownLatch sync = new CountDownLatch(1); + + dbr.db().begin((transaction) -> + transaction.begin((nested) -> + nested.query("INSERT INTO TX_TEST(ID) VALUES(19)", result -> { + assertEquals(1, result.updatedRows()); + nested.commit(() -> transaction.commit(sync::countDown, err), err); + }, err), + err), + err); + + assertTrue(sync.await(5, TimeUnit.SECONDS)); + assertEquals(1L, dbr.query("SELECT ID FROM TX_TEST WHERE ID = 19").size()); + } + + @Test + public void shouldRollbackNestedTransaction() throws Exception { + CountDownLatch sync = new CountDownLatch(1); + + dbr.db().begin((transaction) -> + transaction.query("INSERT INTO TX_TEST(ID) VALUES(24)", result -> { + assertEquals(1, result.updatedRows()); + transaction.begin((nested) -> + nested.query("INSERT INTO TX_TEST(ID) VALUES(23)", res2 -> { + assertEquals(1, res2.updatedRows()); + nested.rollback(() -> transaction.commit(sync::countDown, err), err); + }, err), err); + }, err), + err); + + assertTrue(sync.await(5, TimeUnit.SECONDS)); + assertEquals(1L, dbr.query("SELECT ID FROM TX_TEST WHERE ID = 24").size()); + assertEquals(0L, dbr.query("SELECT ID FROM TX_TEST WHERE ID = 23").size()); + } + + @Test + public void shouldRollbackNestedTransactionOnBackendError() throws Exception { + CountDownLatch sync = new CountDownLatch(1); + + dbr.db().begin((transaction) -> + transaction.query("INSERT INTO TX_TEST(ID) VALUES(25)", result -> { + assertEquals(1, result.updatedRows()); + transaction.begin((nested) -> + nested.query("INSERT INTO TX_TEST(ID) VALUES(26)", res2 -> { + assertEquals(1, res2.updatedRows()); + nested.query("INSERT INTO TD_TEST(ID) VALUES(26)", + fail, t -> transaction.commit(sync::countDown, err)); + }, err), err); + }, err), + err); + + assertTrue(sync.await(5, TimeUnit.SECONDS)); + assertEquals(1L, dbr.query("SELECT ID FROM TX_TEST WHERE ID = 25").size()); + assertEquals(0L, dbr.query("SELECT ID FROM TX_TEST WHERE ID = 26").size()); + } }