Skip to content

Commit

Permalink
feat: support xa protocol (#261)
Browse files Browse the repository at this point in the history
* feat: support xa protocol

* feat: support xa transaction
  • Loading branch information
dk-lockdown committed Sep 5, 2022
1 parent 97c5e18 commit 6611c17
Show file tree
Hide file tree
Showing 15 changed files with 16,813 additions and 18,806 deletions.
18 changes: 18 additions & 0 deletions pkg/executor/read_write_splitting.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,24 @@ func (executor *ReadWriteSplittingExecutor) ExecutorComQuery(
return nil, 0, err
}
return result, 0, err
case *ast.XAStartStmt:
tx, result, err = executor.dbGroup.XAStart(spanCtx, sqlText)
if err != nil {
return nil, 0, err
}
executor.localTransactionMap.Store(connectionID, tx)
return result, 0, nil
case *ast.XAPrepareStmt:
txi, ok := executor.localTransactionMap.Load(connectionID)
if !ok {
return nil, 0, errors.New("there is no transaction")
}
defer executor.localTransactionMap.Delete(connectionID)
tx = txi.(proto.Tx)
if result, err = tx.XAPrepare(ctx, sqlText); err != nil {
return nil, 0, err
}
return result, 0, err
case *ast.InsertStmt, *ast.DeleteStmt, *ast.UpdateStmt:
txi, ok := executor.localTransactionMap.Load(connectionID)
if ok {
Expand Down
18 changes: 18 additions & 0 deletions pkg/executor/single_db.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,24 @@ func (executor *SingleDBExecutor) ExecutorComQuery(
return nil, 0, err
}
return result, 0, err
case *ast.XAStartStmt:
tx, result, err = db.XAStart(spanCtx, sqlText)
if err != nil {
return nil, 0, err
}
executor.localTransactionMap.Store(connectionID, tx)
return result, 0, nil
case *ast.XAPrepareStmt:
txi, ok := executor.localTransactionMap.Load(connectionID)
if !ok {
return nil, 0, errors.New("there is no transaction")
}
defer executor.localTransactionMap.Delete(connectionID)
tx = txi.(proto.Tx)
if result, err = tx.XAPrepare(ctx, sqlText); err != nil {
return nil, 0, err
}
return result, 0, err
default:
txi, ok := executor.localTransactionMap.Load(connectionID)
if ok {
Expand Down
5 changes: 5 additions & 0 deletions pkg/group/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ func (group *DBGroup) Begin(ctx context.Context) (proto.Tx, proto.Result, error)
return dbs[0].Begin(ctx)
}

func (group *DBGroup) XAStart(ctx context.Context, sql string) (proto.Tx, proto.Result, error) {
dbs := group.getAvailableMasters()
return dbs[0].XAStart(ctx, sql)
}

func (group *DBGroup) Query(ctx context.Context, query string) (proto.Result, uint16, error) {
db := group.pick(ctx)
return db.Query(ctx, query)
Expand Down
3 changes: 3 additions & 0 deletions pkg/proto/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ type (
ExecuteSql(ctx context.Context, sql string, args ...interface{}) (Result, uint16, error)
ExecuteSqlDirectly(sql string, args ...interface{}) (Result, uint16, error)
Begin(ctx context.Context) (Tx, Result, error)
XAStart(ctx context.Context, sql string) (Tx, Result, error)
}

Tx interface {
Expand All @@ -174,6 +175,7 @@ type (
Commit(ctx context.Context) (Result, error)
Rollback(ctx context.Context, stmt *ast.RollbackStmt) (Result, error)
ReleaseSavepoint(ctx context.Context, savepoint string) (result Result, err error)
XAPrepare(ctx context.Context, sql string) (Result, error)
}

DBManager interface {
Expand All @@ -190,6 +192,7 @@ type (
PrepareQuery(ctx context.Context, query string, args ...interface{}) (Result, uint16, error)
PrepareExecute(ctx context.Context, query string, args ...interface{}) (Result, uint16, error)
PrepareExecuteStmt(ctx context.Context, stmt *Stmt) (Result, uint16, error)
XAStart(ctx context.Context, sql string) (Tx, Result, error)
}

DBGroupTx interface {
Expand Down
30 changes: 30 additions & 0 deletions pkg/sql/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,36 @@ func (db *DB) Begin(ctx context.Context) (proto.Tx, proto.Result, error) {
}, result, nil
}

func (db *DB) XAStart(ctx context.Context, sql string) (proto.Tx, proto.Result, error) {
var (
result proto.Result
conn *driver.BackendConnection
err error
)

spanCtx, span := tracing.GetTraceSpan(ctx, tracing.DBXAStart)
span.SetAttributes(attribute.KeyValue{Key: "db", Value: attribute.StringValue(db.name)})
defer span.End()

r, err := db.pool.Get(spanCtx)
if err != nil {
err = errors.WithStack(err)
return nil, nil, err
}
conn = r.(*driver.BackendConnection)

if result, err = conn.Execute(ctx, sql, false); err != nil {
db.pool.Put(r)
return nil, nil, err
}

return &Tx{
closed: atomic.NewBool(false),
db: db,
conn: conn,
}, result, nil
}

func (db *DB) SetConnectionPreFilters(filters []proto.DBConnectionPreFilter) {
db.connectionPreFilters = filters
}
Expand Down
17 changes: 17 additions & 0 deletions pkg/sql/tx.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,23 @@ func (tx *Tx) Rollback(ctx context.Context, stmt *ast.RollbackStmt) (result prot
return
}

func (tx *Tx) XAPrepare(ctx context.Context, sql string) (result proto.Result, err error) {
_, span := tracing.GetTraceSpan(ctx, tracing.TxXAPrepare)
span.SetAttributes(attribute.KeyValue{Key: "db", Value: attribute.StringValue(tx.db.name)})
defer span.End()

if tx.closed.Load() {
return nil, nil
}
if tx.db == nil || tx.db.IsClosed() {
return nil, err2.ErrInvalidConn
}
result, err = tx.conn.Execute(ctx, sql, false)
tx.db.pool.Put(tx.conn)
tx.Close()
return
}

func (tx *Tx) ReleaseSavepoint(ctx context.Context, savepoint string) (result proto.Result, err error) {
_, span := tracing.GetTraceSpan(ctx, tracing.TxReleaseSavePoint)
span.SetAttributes(attribute.KeyValue{Key: "db", Value: attribute.StringValue(tx.db.name)})
Expand Down
10 changes: 6 additions & 4 deletions pkg/tracing/constant.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,19 +66,21 @@ const (
DBExecStmt = "db_exec_stmt"
DBExecFieldList = "db_exec_field_list"
DBLocalTransactionBegin = "db_tx_begin"
DBXAStart = "db_xa_start"

// group
GroupQuery = "group_query"
GroupExecute = "group_execute"
GroupTransactionBegin = "group_transaction_begin"
GroupTransactionBegin = "group_tx_begin"

// tx
TxQuery = "tx_query"
TxExecSQL = "tx_exec_sql"
TxExecStmt = "tx_exec_stmt"
TxCommit = "db_local_transaction_commit"
TxRollback = "db_local_transaction_rollback"
TxReleaseSavePoint = "db_local_transaction_release_savepoint"
TxCommit = "db_tx_commit"
TxRollback = "db_tx_rollback"
TxReleaseSavePoint = "db_tx_release_savepoint"
TxXAPrepare = "db_xa_prepare"

// group tx
GroupTxQuery = "group_tx_query"
Expand Down
31 changes: 31 additions & 0 deletions test/rws/read_write_splitting_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@
package rws

import (
"context"
"database/sql"
"testing"
"time"

_ "github.com/go-sql-driver/mysql" // register mysql
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
)

Expand Down Expand Up @@ -79,6 +81,14 @@ func (suite *_ReadWriteSplittingSuite) SetupSuite() {
suite.Equal(int64(1), affected)
}
}

result, err = masterDB.Exec(insertEmployee, 100005, "1992-05-03", "jane", "lewis", "M", "2014-09-01")
if suite.NoErrorf(err, "insert row error: %v", err) {
affected, err := result.RowsAffected()
if suite.NoErrorf(err, "insert row error: %v", err) {
suite.Equal(int64(1), affected)
}
}
}
}

Expand Down Expand Up @@ -230,5 +240,26 @@ func (suite *_ReadWriteSplittingSuite) TestUpdateEncryption() {
}
}

func (suite *_ReadWriteSplittingSuite) TestXATransaction() {
ctx := context.Background()
conn, err := suite.db.Conn(ctx)
assert.Nil(suite.T(), err)
_, err = conn.ExecContext(ctx, "XA START 'abc'")
assert.Nil(suite.T(), err)
result, err := conn.ExecContext(ctx, deleteEmployee, 100005)
if suite.NoErrorf(err, "delete row error: %v", err) {
affected, err := result.RowsAffected()
if suite.NoErrorf(err, "delete row error: %v", err) {
suite.Equal(int64(1), affected)
}
}
_, err = conn.ExecContext(ctx, "XA END 'abc'")
assert.Nil(suite.T(), err)
_, err = conn.ExecContext(ctx, "XA PREPARE 'abc'")
assert.Nil(suite.T(), err)
_, err = conn.ExecContext(ctx, "XA COMMIT 'abc'")
assert.Nil(suite.T(), err)
}

func (suite *_ReadWriteSplittingSuite) TearDownSuite() {
}
30 changes: 30 additions & 0 deletions test/sdb/crud_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@
package sdb

import (
"context"
"database/sql"
"testing"
"time"

_ "github.com/go-sql-driver/mysql" // register mysql
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
)

Expand Down Expand Up @@ -62,6 +64,13 @@ func (suite *_CRUDSuite) SetupSuite() {
suite.Equal(int64(1), affected)
}
}
result, err = suite.db.Exec(insertEmployee, 100005, "1992-05-03", "jane", "lewis", "M", "2014-09-01")
if suite.NoErrorf(err, "insert row error: %v", err) {
affected, err := result.RowsAffected()
if suite.NoErrorf(err, "insert row error: %v", err) {
suite.Equal(int64(1), affected)
}
}
}

func (suite *_CRUDSuite) TestDelete() {
Expand Down Expand Up @@ -157,6 +166,27 @@ func (suite *_CRUDSuite) TestUpdateEncryption() {
}
}

func (suite *_CRUDSuite) TestXATransaction() {
ctx := context.Background()
conn, err := suite.db.Conn(ctx)
assert.Nil(suite.T(), err)
_, err = conn.ExecContext(ctx, "XA START 'abc'")
assert.Nil(suite.T(), err)
result, err := conn.ExecContext(ctx, deleteEmployee, 100005)
if suite.NoErrorf(err, "delete row error: %v", err) {
affected, err := result.RowsAffected()
if suite.NoErrorf(err, "delete row error: %v", err) {
suite.Equal(int64(1), affected)
}
}
_, err = conn.ExecContext(ctx, "XA END 'abc'")
assert.Nil(suite.T(), err)
_, err = conn.ExecContext(ctx, "XA PREPARE 'abc'")
assert.Nil(suite.T(), err)
_, err = conn.ExecContext(ctx, "XA COMMIT 'abc'")
assert.Nil(suite.T(), err)
}

func (suite *_CRUDSuite) TearDownSuite() {
result, err := suite.db.Exec(deleteEmployee, 100001)
if suite.NoErrorf(err, "delete row error: %v", err) {
Expand Down
16 changes: 16 additions & 0 deletions testdata/mock_db.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 15 additions & 0 deletions testdata/mock_tx.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 6611c17

Please sign in to comment.