Skip to content

Commit

Permalink
Merge pull request #3 from Code-Hex/change/rollback
Browse files Browse the repository at this point in the history
Changes MustRollback roles
  • Loading branch information
Code-Hex authored Jun 25, 2018
2 parents 1718484 + 724a728 commit 99ac97e
Show file tree
Hide file tree
Showing 8 changed files with 117 additions and 107 deletions.
62 changes: 58 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ Transaction handling for database it extends https://github.com/jmoiron/sqlx

## Synopsis

### Standard
<details>
<summary>Standard</summary>

```go
db := sqlx.MustOpen("mysql", dsn())
Expand All @@ -33,18 +34,70 @@ tx.MustExec("UPDATE person SET email = ? WHERE first_name = ? AND last_name = ?"

var p Person
if err := tx.Get(&p, "SELECT * FROM person LIMIT 1"); err != nil {
panic(err)
return err
}

// transaction commits
if err := tx.Commit(); err != nil {
panic(err)
return err
}

fmt.Println(p)
```
</details>

<details>
<summary>Nested Transaction</summary>

```go
db := sqlx.MustOpen("mysql", dsn())

func() {
// We should prepare to recover from panic.
defer func() {
if r := recover(); r != nil {
// Do something recover process
}
}()
// Start nested transaction.
// To be simple, we will cause panic if something sql process if failed.
func() {
// starts transaction statements
tx, err := db.BeginTxm()
if err != nil {
panic(err)
}
// Do rollbacks if fail something in nested transaction.
defer tx.Rollback()
func() {
// You don't need error handle in already began transaction.
tx2, _ := db.BeginTxm()
defer tx2.Rollback()
tx2.MustExec("INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)", "Code", "Hex", "x00.x7f@gmail.com")
// Do something processing.
// You should cause panic() if something failed.
if err := tx2.Commit(); err != nil {
panic(err)
}
}()
tx.MustExec("UPDATE person SET email = ? WHERE first_name = ? AND last_name = ?", "a@b.com", "Code", "Hex")
if err := tx.Commit(); err != nil {
panic(err)
}
}()
}()

var p Person
if err := tx.Get(&p, "SELECT * FROM person LIMIT 1"); err != nil {
return err
}

fmt.Println(p)
```
</details>

### Transaction block
<details>
<summary>Transaction block</summary>

```go
var p Person
Expand Down Expand Up @@ -76,6 +129,7 @@ if err := tm.Runx(db, func(tx tm.Executorx) error {
}
println(&p)
```
</details>

## Description

Expand Down
2 changes: 1 addition & 1 deletion eg/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func dsn() string {

func loadDefaultFixture(db *sqlx.DB) {
tx := db.MustBeginTxm()
defer tx.MustRollback()
defer tx.Rollback()
// If you want to know about tx.Rebind, See http://jmoiron.github.io/sqlx/#bindvars
tx.MustExec(tx.Rebind("INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)"), "Jason", "Moiron", "jmoiron@jmoiron.net")
tx.MustExec(tx.Rebind("INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)"), "John", "Doe", "johndoeDNE@gmail.net")
Expand Down
2 changes: 1 addition & 1 deletion eg/tm/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func dsn() string {

func loadDefaultFixture(db *sqlx.DB) {
tx := db.MustBeginTxm()
defer tx.MustRollback()
defer tx.Rollback()
// If you want to know about tx.Rebind, See http://jmoiron.github.io/sqlx/#bindvars
tx.MustExec(tx.Rebind("INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)"), "Jason", "Moiron", "jmoiron@jmoiron.net")
tx.MustExec(tx.Rebind("INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)"), "John", "Doe", "johndoeDNE@gmail.net")
Expand Down
5 changes: 0 additions & 5 deletions errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,6 @@ import (
)

func TestErrors(t *testing.T) {
txer := new(NestedBeginTxErr)
if txer.Error() != beginTxErrMsg {
t.Fatal("Something error")
}

cterr := new(NestedCommitErr)
if cterr.Error() != commitErrMsg {
t.Fatal("Something error")
Expand Down
8 changes: 0 additions & 8 deletions transaction_error.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,3 @@ type NestedCommitErr struct{}
func (n *NestedCommitErr) Error() string {
return commitErrMsg
}

// NestedBeginTxErr is an error type to notice that
// restart transaction in already begun transaction.
type NestedBeginTxErr struct{}

func (n *NestedBeginTxErr) Error() string {
return beginTxErrMsg
}
36 changes: 14 additions & 22 deletions transaction_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,11 @@ func Open(driverName, dataSourceName string) (*DB, error) {
if err != nil {
return nil, err
}
return &DB{DB: db, activeTx: &activeTx{}}, err
return &DB{
DB: db,
activeTx: &activeTx{},
rollbacked: &rollbacked{},
}, nil
}

// MustOpen returns only pointer of DB struct to manage transaction.
Expand Down Expand Up @@ -67,7 +71,7 @@ func (db *DB) setTx(tx *sqlxx.Tx) {
db.tx = &Txm{
Tx: tx,
activeTx: db.activeTx,
rollbacked: &rollbacked{},
rollbacked: db.rollbacked,
}
}

Expand All @@ -90,7 +94,7 @@ func (db *DB) BeginTxm() (*Txm, error) {
db.setTx(tx)
return db.getTxm(), nil
}
return db.getTxm(), new(NestedBeginTxErr)
return db.getTxm(), nil
}

// MustBeginTxm is like BeginTxm but panics
Expand Down Expand Up @@ -118,7 +122,7 @@ func (db *DB) BeginTxmx(ctx context.Context, opts *sql.TxOptions) (*Txm, error)
db.setTx(tx)
return db.getTxm(), nil
}
return db.getTxm(), new(NestedBeginTxErr)
return db.getTxm(), nil
}

// MustBeginTxmx is like BeginTxmx but panics
Expand All @@ -134,23 +138,19 @@ func (db *DB) MustBeginTxmx(ctx context.Context, opts *sql.TxOptions) (*Txm, err
// Commit commits the transaction.
func (t *Txm) Commit() error {
if t.rollbacked.already() {
return new(NestedCommitErr)
panic(new(NestedCommitErr))
}
t.activeTx.decrement()
if !t.activeTx.has() {
return t.Tx.Commit()
if err := t.Tx.Commit(); err != nil {
return err
}
t.reset()
return nil
}
return nil
}

// MustCommit is like Commit but panics if Commit is failed.
func (t *Txm) MustCommit() {
defer t.reset()
if err := t.Tx.Commit(); err != nil {
panic(err)
}
}

// Rollback rollbacks the transaction.
func (t *Txm) Rollback() error {
if !t.activeTx.has() {
Expand All @@ -164,14 +164,6 @@ func (t *Txm) Rollback() error {
return t.Tx.Rollback()
}

// MustRollback is like Rollback but panics if Rollback is failed.
func (t *Txm) MustRollback() {
defer t.reset()
if err := t.Tx.Rollback(); err != nil {
panic(err)
}
}

// In expands slice values in args, returning the modified query string
// and a new arg list that can be executed by a database. The `query` should
// use the `?` bindVar. The return value uses the `?` bindVar.
Expand Down
16 changes: 11 additions & 5 deletions transaction_manager_atomic_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package sqlx

import (
"fmt"
"sync"
"testing"
)
Expand All @@ -18,21 +19,25 @@ func TestAtomicCount(t *testing.T) {
go func(d *DB) {
defer wg.Done()
_, err := db.BeginTxm()
if e, ok := err.(*NestedBeginTxErr); !ok {
panic(e)
if err != nil {
panic(err)
}
}(db)
}
wg.Wait()

if uint64(times) != db.activeTx.get() {
t.Fatalf("Failed to atomic count in db activeTx: %d, expected %d", db.activeTx.get(), times)
panic(
fmt.Sprintf("Failed to atomic count in db activeTx: %d, expected %d", db.activeTx.get(), times),
)
}
if uint64(times) != tx.activeTx.get() {
t.Fatalf("Failed to atomic count in tx activeTx: %d, expected %d", tx.activeTx.get(), times)
panic(
fmt.Sprintf("Failed to atomic count in tx activeTx: %d, expected %d", tx.activeTx.get(), times),
)
}

for i := 1; i < times; i++ {
for i := 1; i <= times; i++ {
wg.Add(1)
go func(txm *Txm) {
defer wg.Done()
Expand All @@ -47,5 +52,6 @@ func TestAtomicCount(t *testing.T) {
if err := tx.Rollback(); tx.activeTx.has() || err != nil {
t.Fatalf("Failed to many rollback: error(%s), activeTx(%d)", err, tx.activeTx.get())
}
tx.reset()
})
}
93 changes: 32 additions & 61 deletions transaction_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,7 @@ func TestNestedCommit(t *testing.T) {
nested := func(db *DB) {
tx, err := db.BeginTxm()
if err != nil {
if _, ok := err.(*NestedBeginTxErr); !ok {
t.Fatal(err)
}
t.Fatal(err)
}
if tx == nil {
t.Fatal("Failed to return tx")
Expand All @@ -221,9 +219,7 @@ func TestNestedCommit(t *testing.T) {
nestedmore := func(db *DB) {
tx, err := db.BeginTxm()
if err != nil {
if _, ok := err.(*NestedBeginTxErr); !ok {
t.Fatal(err)
}
t.Fatal(err)
}
nested(db)
if tx == nil {
Expand Down Expand Up @@ -260,76 +256,51 @@ func TestNestedRollback(t *testing.T) {
nested := func(db *DB) {
tx, err := db.BeginTxm()
if err != nil {
if _, ok := err.(*NestedBeginTxErr); !ok {
t.Fatal(err)
}
}
if tx == nil {
t.Fatal("Failed to return tx")
t.Fatal(err)
}
defer tx.Rollback()
tx.MustExec(tx.Rebind("INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)"), "Code", "Hex", "x00.x7f@gmail.com")
if !tx.activeTx.has() {
t.Fatal("Failed having active transaction in nested BEGIN")
}
panic("Something failed")
// Maybe we will `tx.Commit()` at last
}
RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T) {
nestedmore := func(db *DB) {
tx, err := db.BeginTxm()
if err != nil {
t.Fatal(err)
}
tx.MustExec(tx.Rebind("INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)"), "Code", "Hex", "x00.x7f@gmail.com")
tx.MustExec(tx.Rebind("UPDATE person SET email = ? WHERE first_name = ? AND last_name = ?"), "a@b.com", "Code", "Hex")

// I will try begin 4 times
nested(db)
defer tx.Rollback()
nested(db)
nestedmore := func(db *DB) {
tx, err := db.BeginTxm()
if err != nil {
if _, ok := err.(*NestedBeginTxErr); !ok {
tx.MustExec(tx.Rebind("INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)"), "Code", "Hex", "x00.x7f@gmail.com")
tx.Commit() // maybe will not be reach
}
RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T) {
func() {
// Panic handler
defer func() {
r := recover()
if s, ok := r.(string); !ok || s != "Something failed" {
t.Fatalf("Failed to cause panic: %s", s)
}
}()
func() {
tx, err := db.BeginTxm()
if err != nil {
t.Fatal(err)
}
}
nested(db)
if tx == nil {
t.Fatal("Failed to return tx")
}
if !tx.activeTx.has() {
t.Fatal("Failed having active transaction in nested BEGIN")
}
}
nestedmore(db)

tx.Rollback() // count rollbacked +1, It will not rollback

if tx.rollbacked.times() != 1 {
t.Fatalf("Failed to count rollbacked: %d, expected 1", tx.rollbacked.times())
}

if e, ok := tx.Commit().(*NestedCommitErr); e == nil || !ok {
t.Fatal("Failed to get nested commit err")
}

tx.rollbacked.reset()

// 4 times of nested begin
// We should stop when count is 1
// because rollback can be done per transaction
for i := 1; i < 4; i++ {
if err := tx.Commit(); err != nil {
t.Fatal(err)
}
}

if tx.activeTx.get() != 1 {
t.Fatalf("Failed to decrease count: %d", tx.activeTx.get())
}

tx.Rollback() // activeTx count is 1, So it will rollback
defer tx.Rollback()
tx.MustExec(tx.Rebind("INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)"), "Code", "Hex", "x00.x7f@gmail.com")
nestedmore(db)
tx.Commit() // maybe will not be reach
}()
}()

var author Person
if err := db.Get(&author, "SELECT * FROM person LIMIT 1"); err != sql.ErrNoRows {
if err := db.Get(&author, "SELECT * FROM person WHERE first_name = 'Code' AND last_name = 'Hex'"); err != sql.ErrNoRows {
t.Fatal(
errors.Wrapf(err, "rollback test is failed\n %s\n %s\n",
errors.Errorf("rollback test is failed\n %s\n %s\n",
fmt.Sprintf("rollbacked in nested transaction: %d", db.rollbacked.times()),
fmt.Sprintf("active tx counter: %d", db.activeTx.get()),
),
Expand Down

0 comments on commit 99ac97e

Please sign in to comment.