Skip to content

Commit

Permalink
in sharding mode, support local transaction (#255)
Browse files Browse the repository at this point in the history
* in sharding mode, support local transaction
  • Loading branch information
dk-lockdown committed Sep 5, 2022
1 parent 4d6da29 commit 97c5e18
Show file tree
Hide file tree
Showing 17 changed files with 469 additions and 105 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ require (
go.opentelemetry.io/otel/sdk v1.9.0
go.opentelemetry.io/otel/trace v1.9.0
go.uber.org/multierr v1.7.0 // indirect
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a // indirect
golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac // indirect
golang.org/x/tools v0.1.10 // indirect
Expand Down
1 change: 1 addition & 0 deletions pkg/errors/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,6 @@ var (
// See https://github.com/go-sql-driver/mysql/pull/302
ErrBadConnNoWrite = errors.New("bad connection")

ErrTransactionClosed = errors.New("transaction closed")
ErrUnexpectedRead = errors.New("unexpected read from socket")
)
1 change: 1 addition & 0 deletions pkg/executor/read_write_splitting.go
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ func (executor *ReadWriteSplittingExecutor) ConnectionClose(ctx context.Context)
if _, err := tx.Rollback(ctx, nil); err != nil {
log.Error(err)
}
executor.localTransactionMap.Delete(connectionID)
}

func (executor *ReadWriteSplittingExecutor) doPreFilter(ctx context.Context) error {
Expand Down
87 changes: 72 additions & 15 deletions pkg/executor/sharding.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"encoding/json"
"fmt"
"strings"
"sync"

"github.com/pkg/errors"

Expand All @@ -42,10 +43,11 @@ type ShardingExecutor struct {
PreFilters []proto.DBPreFilter
PostFilters []proto.DBPostFilter

config *config.ShardingConfig
executors []proto.DBGroupExecutor
optimizer proto.Optimizer
localTransactionMap map[uint32]proto.Tx
config *config.ShardingConfig
executors []proto.DBGroupExecutor
optimizer proto.Optimizer
// map[uint32]proto.DBGroupTx
localTransactionMap *sync.Map
}

func NewShardingExecutor(conf *config.Executor) (proto.Executor, error) {
Expand Down Expand Up @@ -94,7 +96,7 @@ func NewShardingExecutor(conf *config.Executor) (proto.Executor, error) {
executors: executorSlice,
optimizer: optimize.NewOptimizer(conf.AppID,
globalTables, executorSlice, executorMap, algorithms, topologies),
localTransactionMap: make(map[uint32]proto.Tx, 0),
localTransactionMap: &sync.Map{},
}

for i := 0; i < len(conf.Filters); i++ {
Expand Down Expand Up @@ -153,7 +155,7 @@ func (executor *ShardingExecutor) ProcessDistributedTransaction() bool {

func (executor *ShardingExecutor) InLocalTransaction(ctx context.Context) bool {
connectionID := proto.ConnectionID(ctx)
_, ok := executor.localTransactionMap[connectionID]
_, ok := executor.localTransactionMap.Load(connectionID)
return ok
}

Expand Down Expand Up @@ -193,40 +195,86 @@ func (executor *ShardingExecutor) ExecutorComQuery(ctx context.Context, sql stri
var plan proto.Plan

log.Debugf("query: %s", sql)
connectionID := proto.ConnectionID(spanCtx)
queryStmt := proto.QueryStmt(spanCtx)
if queryStmt == nil {
return nil, 0, errors.New("query stmt should not be nil")
}

switch stmt := queryStmt.(type) {
case *ast.SetStmt:
for _, db := range executor.executors {
go func(dbGroup proto.DBGroupExecutor) {
if _, _, err := dbGroup.QueryAll(spanCtx, sql); err != nil {
log.Error(err)
}
}(db)
if shouldStartTransaction(stmt) {
tx := group.NewComplexTx(executor.optimizer)
executor.localTransactionMap.Store(connectionID, tx)
} else {
for _, db := range executor.executors {
go func(dbGroup proto.DBGroupExecutor) {
if _, _, err := dbGroup.QueryAll(spanCtx, sql); err != nil {
log.Error(err)
}
}(db)
}
}

return &mysql.Result{
AffectedRows: 0,
InsertId: 0,
}, 0, nil
case *ast.ShowStmt:
return executor.executors[0].Query(spanCtx, sql)
case *ast.BeginStmt:
tx := group.NewComplexTx(executor.optimizer)
executor.localTransactionMap.Store(connectionID, tx)
return &mysql.Result{
AffectedRows: 0,
InsertId: 0,
}, 0, nil
case *ast.CommitStmt:
txi, ok := executor.localTransactionMap.Load(connectionID)
if !ok {
return nil, 0, errors.New("there is no transaction")
}
defer executor.localTransactionMap.Delete(connectionID)
tx := txi.(proto.DBGroupTx)
// TODO add metrics
if result, err = tx.Commit(spanCtx); err != nil {
return nil, 0, err
}
return result, 0, err
case *ast.RollbackStmt:
txi, ok := executor.localTransactionMap.Load(connectionID)
if !ok {
return nil, 0, errors.New("there is no transaction")
}
defer executor.localTransactionMap.Delete(connectionID)
tx := txi.(proto.DBGroupTx)
// TODO add metrics
if result, err = tx.Rollback(spanCtx); err != nil {
return nil, 0, err
}
return result, 0, err
case *ast.SelectStmt:
if stmt.Fields != nil && len(stmt.Fields.Fields) > 0 {
if _, ok := stmt.Fields.Fields[0].Expr.(*ast.VariableExpr); ok {
return executor.executors[0].Query(spanCtx, sql)
}
}
txi, ok := executor.localTransactionMap.Load(connectionID)
if ok {
tx := txi.(proto.DBGroupTx)
return tx.Query(spanCtx, sql)
}
plan, err = executor.optimizer.Optimize(spanCtx, queryStmt)
if err != nil {
return nil, 0, err
}
proto.WithVariable(spanCtx, constant.TransactionTimeout, executor.config.TransactionTimeout)
return plan.Execute(spanCtx)
default:
txi, ok := executor.localTransactionMap.Load(connectionID)
if ok {
tx := txi.(proto.DBGroupTx)
return tx.Query(spanCtx, sql)
}
plan, err = executor.optimizer.Optimize(spanCtx, queryStmt)
if err != nil {
return nil, 0, err
Expand Down Expand Up @@ -265,6 +313,13 @@ func (executor *ShardingExecutor) ExecutorComStmtExecute(
parameterID := fmt.Sprintf("v%d", i+1)
args = append(args, stmt.BindVars[parameterID])
}

txi, ok := executor.localTransactionMap.Load(connectionID)
if ok {
tx := txi.(proto.DBGroupTx)
return tx.Execute(spanCtx, stmt.StmtNode, args...)
}

plan, err = executor.optimizer.Optimize(spanCtx, stmt.StmtNode, args...)
if err != nil {
return nil, 0, err
Expand All @@ -275,14 +330,16 @@ func (executor *ShardingExecutor) ExecutorComStmtExecute(

func (executor *ShardingExecutor) ConnectionClose(ctx context.Context) {
connectionID := proto.ConnectionID(ctx)
tx, ok := executor.localTransactionMap[connectionID]
txi, ok := executor.localTransactionMap.Load(connectionID)
if !ok {
return
}
// TODO add metrics
if _, err := tx.Rollback(ctx, nil); err != nil {
tx := txi.(proto.DBGroupTx)
if _, err := tx.Rollback(ctx); err != nil {
log.Error(err)
}
executor.localTransactionMap.Delete(connectionID)
}

func (executor *ShardingExecutor) doPreFilter(ctx context.Context) error {
Expand Down
170 changes: 170 additions & 0 deletions pkg/group/transaction.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
/*
* Copyright 2022 CECTC, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package group

import (
"context"

"github.com/pkg/errors"
"github.com/uber-go/atomic"
"go.opentelemetry.io/otel/attribute"
"golang.org/x/sync/errgroup"

err2 "github.com/cectc/dbpack/pkg/errors"
"github.com/cectc/dbpack/pkg/log"
"github.com/cectc/dbpack/pkg/mysql"
"github.com/cectc/dbpack/pkg/proto"
"github.com/cectc/dbpack/pkg/tracing"
"github.com/cectc/dbpack/third_party/parser/ast"
)

var (
txID = atomic.NewUint32(0)
)

type ComplexTx struct {
closed *atomic.Bool
id uint32
txs map[string]proto.Tx
optimizer proto.Optimizer
}

func NewComplexTx(optimizer proto.Optimizer) proto.DBGroupTx {
txID.Inc()
return &ComplexTx{
closed: atomic.NewBool(false),
id: txID.Load(),
txs: make(map[string]proto.Tx),
optimizer: optimizer,
}
}

func (tx *ComplexTx) Query(ctx context.Context, query string) (proto.Result, uint16, error) {
spanCtx, span := tracing.GetTraceSpan(ctx, tracing.GroupQuery)
defer span.End()

queryStmt := proto.QueryStmt(spanCtx)
if queryStmt == nil {
return nil, 0, errors.New("query stmt should not be nil")
}
plan, err := tx.optimizer.Optimize(spanCtx, queryStmt)
if err != nil {
return nil, 0, err
}
txCtx := proto.WithDBGroupTx(spanCtx, tx)
return plan.Execute(txCtx)
}

func (tx *ComplexTx) Execute(ctx context.Context, stmt ast.StmtNode, args ...interface{}) (proto.Result, uint16, error) {
spanCtx, span := tracing.GetTraceSpan(ctx, tracing.GroupExecute)
defer span.End()

plan, err := tx.optimizer.Optimize(spanCtx, stmt, args...)
if err != nil {
return nil, 0, err
}
txCtx := proto.WithDBGroupTx(spanCtx, tx)
return plan.Execute(txCtx)
}

func (tx *ComplexTx) Begin(ctx context.Context, executor proto.DBGroupExecutor) (proto.Tx, error) {
spanCtx, span := tracing.GetTraceSpan(ctx, tracing.GroupTransactionBegin)
span.SetAttributes(attribute.KeyValue{Key: "group", Value: attribute.StringValue(executor.GroupName())})
defer span.End()

if childTx, ok := tx.txs[executor.GroupName()]; ok {
return childTx, nil
}
masterCtx := proto.WithMaster(spanCtx)
childTx, _, err := executor.Begin(masterCtx)
if err != nil {
return nil, err
}
log.Debugf("DBGroup %s has begun local transaction!", executor.GroupName())
tx.txs[executor.GroupName()] = childTx
return childTx, nil
}

func (tx *ComplexTx) Commit(ctx context.Context) (result proto.Result, err error) {
spanCtx, span := tracing.GetTraceSpan(ctx, tracing.GroupTxCommit)
defer span.End()

if tx.closed.Load() {
return nil, err2.ErrTransactionClosed
}
defer tx.Close()

var g errgroup.Group
for group, child := range tx.txs {
// https://golang.org/doc/faq#closures_and_goroutines
group, child := group, child
g.Go(func() error {
_, err := child.Commit(spanCtx)
if err != nil {
log.Errorf("commit failed, db group: %s, err: %v", group, err)
return err
}
log.Debugf("DBGroup %s has committed local transaction!", group)
return nil
})
}

if err := g.Wait(); err != nil {
return nil, err
}
return &mysql.Result{
AffectedRows: 0,
InsertId: 0,
}, nil
}

func (tx *ComplexTx) Rollback(ctx context.Context) (result proto.Result, err error) {
spanCtx, span := tracing.GetTraceSpan(ctx, tracing.GroupTxCommit)
defer span.End()

if tx.closed.Load() {
return nil, err2.ErrTransactionClosed
}
defer tx.Close()

var g errgroup.Group
for group, child := range tx.txs {
group, child := group, child
g.Go(func() error {
_, err := child.Rollback(spanCtx, nil)
if err != nil {
log.Errorf("rollback failed, db group: %s, err: %v", group, err)
return err
}
log.Debugf("DBGroup %s has rollbacked local transaction!", group)
return nil
})
}

if err := g.Wait(); err != nil {
return nil, err
}
return &mysql.Result{
AffectedRows: 0,
InsertId: 0,
}, nil
}

func (tx *ComplexTx) Close() {
tx.closed.Swap(true)
tx.txs = nil
}

0 comments on commit 97c5e18

Please sign in to comment.