Skip to content

Commit

Permalink
feature: support cross-database update transactions (#193)
Browse files Browse the repository at this point in the history
* feature: support cross-database update transactions
  • Loading branch information
dk-lockdown committed Jul 19, 2022
1 parent 3ecb2cc commit b184688
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 10 deletions.
4 changes: 2 additions & 2 deletions pkg/plan/delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ func (p *DeletePlan) Execute(ctx context.Context, hints ...*ast.TableOptimizerHi
sql := sb.String()
log.Debugf("delete, db name: %s, sql: %s", p.Database, sql)

pp := parser.New()
stmtNode, err := pp.ParseOneStmt(sql, "", "")
_parser := parser.New()
stmtNode, err := _parser.ParseOneStmt(sql, "", "")
if err != nil {
return nil, 0, errors.WithStack(err)
}
Expand Down
70 changes: 63 additions & 7 deletions pkg/plan/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,19 @@ package plan

import (
"context"
"fmt"
"strings"

"github.com/pkg/errors"

"github.com/cectc/dbpack/pkg/constant"
"github.com/cectc/dbpack/pkg/dt"
"github.com/cectc/dbpack/pkg/log"
"github.com/cectc/dbpack/pkg/misc"
"github.com/cectc/dbpack/pkg/mysql"
"github.com/cectc/dbpack/pkg/proto"
"github.com/cectc/dbpack/pkg/visitor"
"github.com/cectc/dbpack/third_party/parser"
"github.com/cectc/dbpack/third_party/parser/ast"
"github.com/cectc/dbpack/third_party/parser/format"
)
Expand Down Expand Up @@ -53,17 +58,27 @@ func (p *UpdatePlan) Execute(ctx context.Context, hints ...*ast.TableOptimizerHi
}
for _, table := range p.Tables {
sb.Reset()
if err = p.generate(&sb, table); err != nil {
if err = p.generate(&sb, table, hints...); err != nil {
return nil, 0, errors.Wrap(err, "failed to generate sql")
}
sql := sb.String()
log.Debugf("update, db name: %s, sql: %s", p.Database, sql)

_parser := parser.New()
stmtNode, err := _parser.ParseOneStmt(sql, "", "")
if err != nil {
return nil, 0, errors.WithStack(err)
}
stmtNode.Accept(&visitor.ParamVisitor{})

commandType := proto.CommandType(ctx)
switch commandType {
case constant.ComQuery:
ctx := proto.WithQueryStmt(ctx, stmtNode)
result, warns, err = tx.Query(ctx, sql)
case constant.ComStmtExecute:
stmt := generateStatement(sql, stmtNode, p.Args)
ctx := proto.WithPrepareStmt(ctx, stmt)
result, warns, err = tx.ExecuteSql(ctx, sql, p.Args...)
default:
continue
Expand All @@ -87,10 +102,24 @@ func (p *UpdatePlan) Execute(ctx context.Context, hints ...*ast.TableOptimizerHi
return mysqlResult, warnings, nil
}

func (p *UpdatePlan) generate(sb *strings.Builder, table string) error {
func (p *UpdatePlan) generate(sb *strings.Builder, table string, hints ...*ast.TableOptimizerHint) error {
ctx := format.NewRestoreCtx(format.DefaultRestoreFlags, sb)
ctx.WriteKeyWord("UPDATE ")
// todo add xid hint for distributed transaction

if len(hints) != 0 {
ctx.WritePlain("/*+ ")
for i, tableHint := range hints {
if i != 0 {
ctx.WritePlain(" ")
}
if err := tableHint.Restore(ctx); err != nil {
return errors.Wrapf(err, "An error occurred while restoring UpdateStmt.TableHints[%d], HintName: %s",
i, tableHint.HintName.String())
}
}
ctx.WritePlain("*/ ")
}

ctx.WritePlain(table)
ctx.WriteKeyWord(" SET ")
for i, assignment := range p.Stmt.List {
Expand Down Expand Up @@ -137,18 +166,45 @@ type MultiUpdatePlan struct {
Plans []*UpdatePlan
}

func (p *MultiUpdatePlan) Execute(ctx context.Context, _ ...*ast.TableOptimizerHint) (proto.Result, uint16, error) {
func (p *MultiUpdatePlan) Execute(ctx context.Context, _ ...*ast.TableOptimizerHint) (result proto.Result, warns uint16, err error) {
var (
affectedRows uint64
warnings uint16
affected uint64
hints []*ast.TableOptimizerHint
)
// todo distributed transaction
if has, _ := misc.HasXIDHint(p.Stmt.TableHints); !has {
tableName := p.Stmt.TableRefs.TableRefs.Left.(*ast.TableSource).Source.(*ast.TableName).Name.String()
transactionManager := dt.GetDistributedTransactionManager()
timeoutVariable := proto.Variable(ctx, constant.TransactionTimeout)
timeout, ok := timeoutVariable.(int32)
if !ok {
return nil, 0, errors.New("transaction timeout must be of type int32")
}
var xid string
xid, err = transactionManager.Begin(ctx, fmt.Sprintf("UPDATE_%s", tableName), timeout)
if err != nil {
return nil, 0, err
}
hints = append(hints, misc.NewXIDHint(xid))
defer func() {
if err != nil {
if _, rollbackErr := transactionManager.Rollback(ctx, xid); rollbackErr != nil {
log.Error(err)
}
} else {
if _, commitErr := transactionManager.Commit(ctx, xid); commitErr != nil {
log.Error(err)
}
}
}()
}
for _, pl := range p.Plans {
result, warns, err := pl.Execute(ctx)
result, warns, err = pl.Execute(ctx, hints...)
if err != nil {
return nil, 0, err
}
affected, err := result.RowsAffected()
affected, err = result.RowsAffected()
if err != nil {
return nil, 0, errors.WithStack(err)
}
Expand Down
2 changes: 1 addition & 1 deletion test/shd/sharding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ func (suite *_ShardingSuite) TestDeleteDrugResource() {
suite.Assert().Nil(err)
affectedRows, err := result.RowsAffected()
suite.Assert().Nil(err)
suite.Assert().Equal(int64(1), affectedRows)
suite.Assert().Equal(int64(11), affectedRows)
time.Sleep(10 * time.Second)
}

Expand Down

0 comments on commit b184688

Please sign in to comment.