Skip to content

Commit

Permalink
Added transaction block
Browse files Browse the repository at this point in the history
  • Loading branch information
Code-Hex committed May 22, 2018
1 parent 13044bd commit ed37829
Show file tree
Hide file tree
Showing 4 changed files with 279 additions and 0 deletions.
95 changes: 95 additions & 0 deletions eg/tm/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package main

import (
"fmt"
"os"
"time"

sqlx "github.com/Code-Hex/sqlx-transactionmanager"
"github.com/Code-Hex/sqlx-transactionmanager/tm"
_ "github.com/go-sql-driver/mysql"
)

type Person struct {
FirstName string `db:"first_name"`
LastName string `db:"last_name"`
Email string `db:"email"`
AddedAt time.Time `db:"added_at"`
}

func (p *Person) String() string {
return fmt.Sprintf("%s %s: (%s) %s", p.FirstName, p.LastName, p.Email, p.AddedAt.String())
}

func dsn() string {
// You can use environment vatiables from .envrc.
// See https://github.com/direnv/direnv If you want to use .envrc.
return os.Getenv("SQLX_MYSQL_DSN")
}

func loadDefaultFixture(db *sqlx.DB) {
tx := db.MustBeginTxm()
defer tx.MustRollback()
// If you want to know about tx.Rebind, See http://jmoiron.github.io/sqlx/#bindvars
tx.MustExec(tx.Rebind("INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)"), "Jason", "Moiron", "jmoiron@jmoiron.net")
tx.MustExec(tx.Rebind("INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)"), "John", "Doe", "johndoeDNE@gmail.net")
tx.Commit()
}

func Connect() *sqlx.DB {
db := sqlx.MustOpen("mysql", dsn())
if err := db.Ping(); err != nil {
panic(err)
}
return db
}

func main() {
Mysql = true // use mysql
db := Connect()
defer db.Close()

// See drivername
fmt.Printf("Using: %s\n", db.DriverName())

RunWithSchema(defaultSchema, db, DoTransaction(db))
}

// DoTransaction is example for transaction
// See transaction_manager_test.go if you want to know detail.
func DoTransaction(db *sqlx.DB) func(*sqlx.DB) {
return func(db *sqlx.DB) {
var p Person
if err := tm.Run(db, func(tx tm.Executor) error {
_, err := tx.Exec("INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)", "Al", "Paca", "x00.x7f@gmail.com")
if err != nil {
return err
}
_, err = tx.Exec("UPDATE person SET email = ? WHERE first_name = ? AND last_name = ?", "x@h.com", "Al", "Paca")
if err != nil {
return err
}

return tx.QueryRow("SELECT * FROM person LIMIT 1").Scan(&p.FirstName, &p.LastName, &p.Email, &p.AddedAt)
}); err != nil {
panic(err)
}
println(&p)

if err := tm.Runx(db, func(tx tm.Executorx) error {
tx.MustExec(tx.Rebind("INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)"), "Code", "Hex", "x00.x7f@gmail.com")
tx.MustExec(tx.Rebind("UPDATE person SET email = ? WHERE first_name = ? AND last_name = ?"), "a@b.com", "Code", "Hex")
if err := tx.Get(&p, "SELECT * FROM person ORDER BY first_name DESC LIMIT 1"); err != nil {
return err
}
return nil
}); err != nil {
panic(err)
}
println(&p)
}
}

func println(str fmt.Stringer) {
fmt.Println(str)
}
Binary file added eg/tm/tm
Binary file not shown.
88 changes: 88 additions & 0 deletions eg/tm/utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package main

import (
"fmt"
"strings"

sqlx "github.com/Code-Hex/sqlx-transactionmanager"
osqlx "github.com/jmoiron/sqlx"
)

var (
Postgres bool
Mysql bool
Sqlite bool
)

type Schema struct {
create string
drop string
}

func (s Schema) Postgres() (string, string) {
return s.create, s.drop
}

func (s Schema) MySQL() (string, string) {
return strings.Replace(s.create, `"`, "`", -1), s.drop
}

func (s Schema) Sqlite3() (string, string) {
return strings.Replace(s.create, `now()`, `CURRENT_TIMESTAMP`, -1), s.drop
}

var defaultSchema = Schema{
create: `
CREATE TABLE person (
first_name text,
last_name text,
email text,
added_at timestamp default now()
);
CREATE TABLE place (
country text,
city text NULL,
telcode integer
);
`,
drop: `
drop table person;
drop table place;
`,
}

func MultiExec(e osqlx.Execer, query string) {
stmts := strings.Split(query, ";\n")
if len(strings.Trim(stmts[len(stmts)-1], " \n\t\r")) == 0 {
stmts = stmts[:len(stmts)-1]
}
for _, s := range stmts {
_, err := e.Exec(s)
if err != nil {
fmt.Println(err, s)
}
}
}

func RunWithSchema(schema Schema, db *sqlx.DB, run func(db *sqlx.DB)) {
runner := func(create, drop string) {
defer func() { MultiExec(db, drop) }()
MultiExec(db, create)
run(db)
}

if Postgres {
create, drop := schema.Postgres()
runner(create, drop)
}
if Sqlite {
create, drop := schema.Sqlite3()
runner(create, drop)
}
if Mysql {
create, drop := schema.MySQL()
runner(create, drop)
}
}
96 changes: 96 additions & 0 deletions tm/transaction_block.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package tm

import (
"context"
"database/sql"

"github.com/jmoiron/sqlx"
)

// SQL interface implements for *sql.DB or wrapped it.
type SQL interface{ Begin() (*sql.Tx, error) }

// SQLx interface implements for *sqlx.DB or wrapped it.
type SQLx interface{ Beginx() (*sqlx.Tx, error) }

// Executor interface implements for *sql.Tx or wrapped it.
// It has'nt Commit and Rollback methods.
type Executor interface {
Exec(string, ...interface{}) (sql.Result, error)
ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
Prepare(string) (*sql.Stmt, error)
PrepareContext(context.Context, string) (*sql.Stmt, error)
Query(string, ...interface{}) (*sql.Rows, error)
QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
QueryRow(string, ...interface{}) *sql.Row
QueryRowContext(context.Context, string, ...interface{}) *sql.Row
Stmt(*sql.Stmt) *sql.Stmt
StmtContext(context.Context, *sql.Stmt) *sql.Stmt
}

// Executorx interface implements for *sqlx.Tx or wrapped it.
// It has'nt Commit and Rollback methods.
type Executorx interface {
Executor

Get(interface{}, string, ...interface{}) error
GetContext(context.Context, interface{}, string, ...interface{}) error
MustExec(string, ...interface{}) sql.Result
MustExecContext(context.Context, string, ...interface{}) sql.Result
NamedExec(string, interface{}) (sql.Result, error)
NamedExecContext(context.Context, string, interface{}) (sql.Result, error)
NamedQuery(string, interface{}) (*sqlx.Rows, error)
NamedStmt(stmt *sqlx.NamedStmt) *sqlx.NamedStmt
NamedStmtContext(context.Context, *sqlx.NamedStmt) *sqlx.NamedStmt
PrepareNamedContext(context.Context, string) (*sqlx.NamedStmt, error)
Preparex(string) (*sqlx.Stmt, error)
PreparexContext(context.Context, string) (*sqlx.Stmt, error)
QueryRowx(string, ...interface{}) *sqlx.Row
QueryRowxContext(context.Context, string, ...interface{}) *sqlx.Row
Queryx(string, ...interface{}) (*sqlx.Rows, error)
QueryxContext(context.Context, string, ...interface{}) (*sqlx.Rows, error)
Rebind(string) string
Select(interface{}, string, ...interface{}) error
SelectContext(context.Context, interface{}, string, ...interface{}) error
Stmtx(interface{}) *sqlx.Stmt
StmtxContext(context.Context, interface{}) *sqlx.Stmt
Unsafe() *sqlx.Tx
}

// TxnFunc implemtnts for func(Executor) error
type TxnFunc func(Executor) error

// TxnxFunc implemtnts for func(Executorx) error
type TxnxFunc func(Executorx) error

// Run begins transaction around TxnFunc.
// It returns error and rollbacks if TxnFunc is failed.
// It commits if TxnFunc is successed.
func Run(db SQL, f TxnFunc) error {
tx, err := db.Begin()
if err != nil {
return err
}
if err := f(tx); err != nil {
tx.Rollback()
return err
}
tx.Commit()
return nil
}

// Runx begins transaction around TxnxFunc.
// It returns error and rollbacks if TxnxFunc is failed.
// It commits if TxnxFunc is successed.
func Runx(db SQLx, f TxnxFunc) error {
tx, err := db.Beginx()
if err != nil {
return err
}
if err := f(tx); err != nil {
tx.Rollback()
return err
}
tx.Commit()
return nil
}

0 comments on commit ed37829

Please sign in to comment.