Skip to content

Commit

Permalink
feat: support savepoint statement (#230)
Browse files Browse the repository at this point in the history
  • Loading branch information
dk-lockdown committed Aug 6, 2022
1 parent 0daddb0 commit 146f75d
Show file tree
Hide file tree
Showing 16 changed files with 11,842 additions and 10,889 deletions.
6 changes: 3 additions & 3 deletions pkg/dt/mysql_undo_log_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func (manager MysqlUndoLogManager) Undo(db proto.DB, xid string) ([]string, erro
tableMeta, err := meta.GetTableMetaCache().GetTableMeta(
proto.WithSchema(context.Background(), sqlUndoLog.SchemaName), db, sqlUndoLog.TableName)
if err != nil {
if _, err := tx.Rollback(context.Background()); err != nil {
if _, err := tx.Rollback(context.Background(), nil); err != nil {
return lockKeys, err
}
return lockKeys, err
Expand All @@ -124,7 +124,7 @@ func (manager MysqlUndoLogManager) Undo(db proto.DB, xid string) ([]string, erro
sqlUndoLog.SetTableMeta(tableMeta)
err = NewMysqlUndoExecutor(sqlUndoLog).Execute(tx)
if err != nil {
if _, err := tx.Rollback(context.Background()); err != nil {
if _, err := tx.Rollback(context.Background(), nil); err != nil {
return lockKeys, err
}
return lockKeys, err
Expand All @@ -135,7 +135,7 @@ func (manager MysqlUndoLogManager) Undo(db proto.DB, xid string) ([]string, erro
if exists {
_, _, err := tx.ExecuteSql(context.Background(), DeleteUndoLogByXIDSql, xid)
if err != nil {
if _, err := tx.Rollback(context.Background()); err != nil {
if _, err := tx.Rollback(context.Background(), nil); err != nil {
return lockKeys, err
}
return lockKeys, err
Expand Down
4 changes: 2 additions & 2 deletions pkg/executor/read_write_splitting.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ func (executor *ReadWriteSplittingExecutor) ExecutorComQuery(
defer executor.localTransactionMap.Delete(connectionID)
tx = txi.(proto.Tx)
// TODO add metrics
if result, err = tx.Rollback(spanCtx); err != nil {
if result, err = tx.Rollback(spanCtx, stmt); err != nil {
return nil, 0, err
}
return result, 0, err
Expand Down Expand Up @@ -352,7 +352,7 @@ func (executor *ReadWriteSplittingExecutor) ConnectionClose(ctx context.Context)
return
}
tx := txi.(proto.Tx)
if _, err := tx.Rollback(ctx); err != nil {
if _, err := tx.Rollback(ctx, nil); err != nil {
log.Error(err)
}
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/executor/read_write_splitting_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func TestReadWriteSplittingExecutor(t *testing.T) {
tx.EXPECT().Query(gomock.Any(), gomock.Any()).Return(&mysql.Result{}, uint16(0), nil).MaxTimes(100)
tx.EXPECT().ExecuteStmt(gomock.Any(), gomock.Any()).Return(&mysql.Result{}, uint16(0), nil).MaxTimes(100)
tx.EXPECT().Commit(gomock.Any()).Return(&mysql.Result{}, nil).MaxTimes(100)
tx.EXPECT().Rollback(gomock.Any()).Return(&mysql.Result{}, nil).MaxTimes(100)
tx.EXPECT().Rollback(gomock.Any(), gomock.Any()).Return(&mysql.Result{}, nil).MaxTimes(100)

manager := testdata.NewMockDBManager(ctrl)
manager.EXPECT().GetDB(gomock.Any()).AnyTimes().Return(db)
Expand Down
2 changes: 1 addition & 1 deletion pkg/executor/sharding.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ func (executor *ShardingExecutor) ConnectionClose(ctx context.Context) {
return
}
// TODO add metrics
if _, err := tx.Rollback(ctx); err != nil {
if _, err := tx.Rollback(ctx, nil); err != nil {
log.Error(err)
}
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/executor/single_db.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ func (executor *SingleDBExecutor) ExecutorComQuery(
defer executor.localTransactionMap.Delete(connectionID)
tx = txi.(proto.Tx)
// TODO add metrics
if result, err = tx.Rollback(spanCtx); err != nil {
if result, err = tx.Rollback(spanCtx, stmt); err != nil {
return nil, 0, err
}
return result, 0, err
Expand Down Expand Up @@ -264,7 +264,7 @@ func (executor *SingleDBExecutor) ConnectionClose(ctx context.Context) {
return
}
tx := txi.(proto.Tx)
if _, err := tx.Rollback(ctx); err != nil {
if _, err := tx.Rollback(ctx, nil); err != nil {
log.Error(err)
}
executor.localTransactionMap.Delete(connectionID)
Expand Down
2 changes: 1 addition & 1 deletion pkg/executor/single_db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func TestSingleDBExecutor(t *testing.T) {
tx.EXPECT().Query(gomock.Any(), gomock.Any()).Return(&mysql.Result{}, uint16(0), nil).MaxTimes(100)
tx.EXPECT().ExecuteStmt(gomock.Any(), gomock.Any()).Return(&mysql.Result{}, uint16(0), nil).MaxTimes(100)
tx.EXPECT().Commit(gomock.Any()).Return(&mysql.Result{}, nil).MaxTimes(100)
tx.EXPECT().Rollback(gomock.Any()).Return(&mysql.Result{}, nil).MaxTimes(100)
tx.EXPECT().Rollback(gomock.Any(), gomock.Any()).Return(&mysql.Result{}, nil).MaxTimes(100)

manager := testdata.NewMockDBManager(ctrl)
manager.EXPECT().GetDB(gomock.Any()).AnyTimes().Return(db)
Expand Down
2 changes: 1 addition & 1 deletion pkg/proto/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ type (
ExecuteStmt(ctx context.Context, stmt *Stmt) (Result, uint16, error)
ExecuteSql(ctx context.Context, sql string, args ...interface{}) (Result, uint16, error)
Commit(ctx context.Context) (Result, error)
Rollback(ctx context.Context) (Result, error)
Rollback(ctx context.Context, stmt *ast.RollbackStmt) (Result, error)
}

// Executor ...
Expand Down
9 changes: 7 additions & 2 deletions pkg/sql/tx.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
err2 "github.com/cectc/dbpack/pkg/errors"
"github.com/cectc/dbpack/pkg/proto"
"github.com/cectc/dbpack/pkg/tracing"
"github.com/cectc/dbpack/third_party/parser/ast"
)

type Tx struct {
Expand Down Expand Up @@ -130,7 +131,7 @@ func (tx *Tx) Commit(ctx context.Context) (result proto.Result, err error) {
return
}

func (tx *Tx) Rollback(ctx context.Context) (result proto.Result, err error) {
func (tx *Tx) Rollback(ctx context.Context, stmt *ast.RollbackStmt) (result proto.Result, err error) {
_, span := tracing.GetTraceSpan(ctx, tracing.TxRollback)
span.SetAttributes(attribute.KeyValue{Key: "db", Value: attribute.StringValue(tx.db.name)})
defer span.End()
Expand All @@ -141,7 +142,11 @@ func (tx *Tx) Rollback(ctx context.Context) (result proto.Result, err error) {
if tx.db == nil || tx.db.IsClosed() {
return nil, err2.ErrInvalidConn
}
result, err = tx.conn.Execute("ROLLBACK", false)
if stmt != nil && stmt.SavepointName != "" {
result, err = tx.conn.Execute(fmt.Sprintf("ROLLBACK TO %s", stmt.SavepointName), false)
} else {
result, err = tx.conn.Execute("ROLLBACK", false)
}
tx.db.pool.Put(tx.conn)
tx.Close()
return
Expand Down
3 changes: 1 addition & 2 deletions testdata/mock_db.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 1 addition & 2 deletions testdata/mock_db_manager.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 6 additions & 6 deletions testdata/mock_tx.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

56 changes: 54 additions & 2 deletions third_party/parser/ast/misc.go
Original file line number Diff line number Diff line change
Expand Up @@ -599,11 +599,17 @@ type RollbackStmt struct {
stmtNode
// CompletionType overwrites system variable `completion_type` within transaction
CompletionType CompletionType
// SavepointName is the savepoint name.
SavepointName string
}

// Restore implements Node interface.
func (n *RollbackStmt) Restore(ctx *format.RestoreCtx) error {
ctx.WriteKeyWord("ROLLBACK")
if n.SavepointName != "" {
ctx.WritePlain(" TO ")
ctx.WritePlain(n.SavepointName)
}
if err := n.CompletionType.Restore(ctx); err != nil {
return errors.Annotate(err, "An error occurred while restore RollbackStmt.CompletionType")
}
Expand Down Expand Up @@ -871,6 +877,48 @@ func (n *KillStmt) Accept(v Visitor) (Node, bool) {
return v.Leave(n)
}

// SavepointStmt is the statement of SAVEPOINT.
type SavepointStmt struct {
stmtNode
// Name is the savepoint name.
Name string
}

// Restore implements Node interface.
func (n *SavepointStmt) Restore(ctx *format.RestoreCtx) error {
ctx.WriteKeyWord("SAVEPOINT ")
ctx.WritePlain(n.Name)
return nil
}

// Accept implements Node Accept interface.
func (n *SavepointStmt) Accept(v Visitor) (Node, bool) {
newNode, _ := v.Enter(n)
n = newNode.(*SavepointStmt)
return v.Leave(n)
}

// ReleaseSavepointStmt is the statement of RELEASE SAVEPOINT.
type ReleaseSavepointStmt struct {
stmtNode
// Name is the savepoint name.
Name string
}

// Restore implements Node interface.
func (n *ReleaseSavepointStmt) Restore(ctx *format.RestoreCtx) error {
ctx.WriteKeyWord("RELEASE SAVEPOINT ")
ctx.WritePlain(n.Name)
return nil
}

// Accept implements Node Accept interface.
func (n *ReleaseSavepointStmt) Accept(v Visitor) (Node, bool) {
newNode, _ := v.Enter(n)
n = newNode.(*ReleaseSavepointStmt)
return v.Leave(n)
}

// SetStmt is the statement to set variables.
type SetStmt struct {
stmtNode
Expand Down Expand Up @@ -1907,6 +1955,12 @@ func (n *ShowSlow) Restore(ctx *format.RestoreCtx) error {
return nil
}

// LimitSimple is the struct for Admin statement limit option.
type LimitSimple struct {
Count uint64
Offset uint64
}

// AdminStmt is the struct for Admin statement.
type AdminStmt struct {
stmtNode
Expand Down Expand Up @@ -3385,8 +3439,6 @@ func (n *TableOptimizerHint) Restore(ctx *format.RestoreCtx) error {
ctx.WriteString(hintData.VarName)
ctx.WritePlain(", ")
ctx.WriteString(hintData.Value)
case "xid":
ctx.WriteString(n.HintData.(model.CIStr).String())
}
ctx.WritePlain(")")
return nil
Expand Down

0 comments on commit 146f75d

Please sign in to comment.