Skip to content

Commit

Permalink
refactor: support manipulating the same data in different transaction… (
Browse files Browse the repository at this point in the history
#200)

* refactor: support manipulating the same data in different transaction branches
  • Loading branch information
dk-lockdown committed Jul 14, 2022
1 parent bb408a5 commit 040b4d1
Show file tree
Hide file tree
Showing 16 changed files with 186 additions and 44 deletions.
4 changes: 4 additions & 0 deletions pkg/dt/distributed_transaction_manger.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,10 @@ func (manager *DistributedTransactionManager) IsLockable(ctx context.Context, re
return manager.storageDriver.IsLockable(ctx, resourceID, lockKey)
}

func (manager *DistributedTransactionManager) IsLockableWithXID(ctx context.Context, resourceID, lockKey, xid string) (bool, error) {
return manager.storageDriver.IsLockableWithXID(ctx, resourceID, lockKey, xid)
}

func (manager *DistributedTransactionManager) branchCommit(bs *api.BranchSession) (api.BranchSession_BranchStatus, error) {
var (
status api.BranchSession_BranchStatus
Expand Down
84 changes: 77 additions & 7 deletions pkg/dt/storage/etcd/etcd.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package etcd
import (
"context"
"fmt"
"strings"
"sync"
"time"

Expand All @@ -38,6 +39,11 @@ const (
// We have set a buffer in order to reduce times of context switches.
incomingBufSize = 100
outgoingBufSize = 100

// LockKeyFormat lk/${XID}/${rowKey}
LockKeyFormat = "lk/%s/%s"
// BranchKeyFormat bs/${XID}/${BranchSessionID}
BranchKeyFormat = "bs/%s/%d"
)

type store struct {
Expand Down Expand Up @@ -124,19 +130,23 @@ func (s *store) AddBranchSession(ctx context.Context, branchSession *api.BranchS
comparisons = append(comparisons, clientv3.Compare(clientv3.ModRevision(branchSession.XID), "=", modRevision))
ops = append(ops, clientv3.OpPut(branchSession.BranchID, string(data)))
// branch transactions of global transaction
globalBranchKey := fmt.Sprintf("bs/%s/%d", branchSession.XID, branchSession.BranchSessionID)
ops = append(ops, clientv3.OpPut(globalBranchKey, branchSession.BranchID))
branchKey := fmt.Sprintf(BranchKeyFormat, branchSession.XID, branchSession.BranchSessionID)
ops = append(ops, clientv3.OpPut(branchKey, branchSession.BranchID))

if branchSession.Type == api.AT && branchSession.LockKey != "" {
rowKeys := misc.CollectRowKeys(branchSession.LockKey, branchSession.ResourceID)
rowKeys, err = s.filterRowKeys(ctx, rowKeys, branchSession.XID)
if err != nil {
return err
}
for _, rowKey := range rowKeys {
comparisons = append(comparisons, notFound(rowKey))
}

for _, rowKey := range rowKeys {
lockKey := fmt.Sprintf("lk/%s/%s", branchSession.XID, rowKey)
ops = append(ops, clientv3.OpPut(lockKey, rowKey))
ops = append(ops, clientv3.OpPut(rowKey, lockKey))
rowKeyValue := fmt.Sprintf(LockKeyFormat, branchSession.XID, rowKey)
ops = append(ops, clientv3.OpPut(rowKeyValue, rowKey))
ops = append(ops, clientv3.OpPut(rowKey, rowKeyValue))
}
}

Expand Down Expand Up @@ -329,7 +339,8 @@ func (s *store) ListBranchSession(ctx context.Context, applicationID string) ([]
return nil, nil
}
var result []*api.BranchSession
for _, kv := range resp.Kvs {
for i := len(resp.Kvs) - 1; i >= 0; i-- {
kv := resp.Kvs[i]
branchSession := &api.BranchSession{}
err = branchSession.Unmarshal(kv.Value)
if err != nil {
Expand Down Expand Up @@ -364,7 +375,8 @@ func (s *store) GetBranchSessionKeys(ctx context.Context, xid string) ([]string,
return nil, err
}
var result []string
for _, kv := range branchKeyResp.Kvs {
for i := len(branchKeyResp.Kvs) - 1; i >= 0; i-- {
kv := branchKeyResp.Kvs[i]
result = append(result, string(kv.Value))
}
return result, nil
Expand Down Expand Up @@ -467,6 +479,27 @@ func (s *store) IsLockable(ctx context.Context, resourceID string, lockKey strin
return true, nil
}

func (s *store) IsLockableWithXID(ctx context.Context, resourceID string, lockKey string, xid string) (bool, error) {
rowKeys := misc.CollectRowKeys(lockKey, resourceID)

for _, rowKey := range rowKeys {
resp, err := s.client.Get(ctx, rowKey, clientv3.WithSerializable())
if err != nil {
return false, err
}
if len(resp.Kvs) == 0 {
continue
}
// rowKeyValue: lk/${XID}/${rowKey}
if strings.Contains(string(resp.Kvs[0].Value), xid) {
continue
} else {
return false, nil
}
}
return true, nil
}

func (s *store) ReleaseLockKeys(ctx context.Context, resourceID string, lockKeys []string) (bool, error) {
var ops []clientv3.Op
for _, lockKey := range lockKeys {
Expand Down Expand Up @@ -495,6 +528,43 @@ func notFound(key string) clientv3.Cmp {
return clientv3.Compare(clientv3.ModRevision(key), "=", 0)
}

func (s *store) filterRowKeys(ctx context.Context, rowKeys []string, xid string) ([]string, error) {
var result []string
rowKeyValues, err := s.getRowKeyValues(ctx, rowKeys)
if err != nil {
return nil, err
}
for _, rowKey := range rowKeys {
if value, ok := rowKeyValues[rowKey]; ok {
if !strings.Contains(value, xid) {
result = append(result, rowKey)
}
} else {
result = append(result, rowKey)
}
}
return result, nil
}

func (s *store) getRowKeyValues(ctx context.Context, rowKeys []string) (map[string]string, error) {
var (
result = make(map[string]string)
resp *clientv3.GetResponse
err error
)
for _, rowKey := range rowKeys {
resp, err = s.client.Get(ctx, rowKey, clientv3.WithSerializable())
if err != nil {
return nil, err
}
if len(resp.Kvs) == 0 {
continue
}
result[rowKey] = string(resp.Kvs[0].Value)
}
return result, nil
}

func (s *store) WatchGlobalSessions(ctx context.Context, applicationID string) storage.Watcher {
prefix := fmt.Sprintf("gs/%s", applicationID)
wc := s.createWatchChan(ctx, prefix, s.initGlobalSessionRevision, true, true)
Expand Down
1 change: 1 addition & 0 deletions pkg/dt/storage/storagedriver.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ type Driver interface {
GetBranchSessionKeys(ctx context.Context, xid string) ([]string, error)
BranchReport(ctx context.Context, branchID string, status api.BranchSession_BranchStatus) error
IsLockable(ctx context.Context, resourceID string, lockKey string) (bool, error)
IsLockableWithXID(ctx context.Context, resourceID string, lockKey string, xid string) (bool, error)
ReleaseLockKeys(ctx context.Context, resourceID string, lockKeys []string) (bool, error)
WatchGlobalSessions(ctx context.Context, applicationID string) Watcher
WatchBranchSessions(ctx context.Context, applicationID string) Watcher
Expand Down
8 changes: 7 additions & 1 deletion pkg/filter/dt/exec/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,18 @@ type Executor interface {
GetTableName() string
}

type Executable interface {
type GlobalLockExecutor interface {
Executable(ctx context.Context, lockRetryInterval time.Duration, lockRetryTimes int) (bool, error)
GetTableMeta(ctx context.Context) (schema.TableMeta, error)
GetTableName() string
}

type Executable interface {
Executable(ctx context.Context, xid string, lockRetryInterval time.Duration, lockRetryTimes int) (bool, error)
GetTableMeta(ctx context.Context) (schema.TableMeta, error)
GetTableName() string
}

func BuildUndoItem(
isBinary bool,
sqlType constant.SQLType,
Expand Down
2 changes: 1 addition & 1 deletion pkg/filter/dt/exec/prepare_global_lock.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func NewPrepareGlobalLockExecutor(
isUpdate bool,
deleteStmt *ast.DeleteStmt,
updateStmt *ast.UpdateStmt,
args map[string]interface{}) Executable {
args map[string]interface{}) GlobalLockExecutor {
return &prepareGlobalLockExecutor{
conn: conn,
isUpdate: isUpdate,
Expand Down
13 changes: 12 additions & 1 deletion pkg/filter/dt/exec/prepare_global_lock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ func TestPrepareGlobalLock(t *testing.T) {
}
ctx = proto.WithPrepareStmt(ctx, protoStmt)

var executor Executable
var executor GlobalLockExecutor
if c.isUpdate {
updateStmt := stmt.(*ast.UpdateStmt)
executor = NewPrepareGlobalLockExecutor(&driver.BackendConnection{}, c.isUpdate, nil, updateStmt, protoStmt.BindVars)
Expand Down Expand Up @@ -258,6 +258,17 @@ func isLockablePatch() *gomonkey.Patches {
})
}

func isLockableWithXIDPatch() *gomonkey.Patches {
var transactionManager *dt.DistributedTransactionManager
return gomonkey.ApplyMethodFunc(transactionManager, "IsLockableWithXID", func(ctx context.Context, resourceID, lockKey, xid string) (bool, error) {
count++
if count < 5 {
return false, err
}
return true, nil
})
}

func beforeImagePatch() *gomonkey.Patches {
var executor *prepareGlobalLockExecutor
return gomonkey.ApplyMethodFunc(executor, "BeforeImage", func(ctx context.Context) (*schema.TableRecords, error) {
Expand Down
6 changes: 3 additions & 3 deletions pkg/filter/dt/exec/prepare_select_for_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func NewPrepareSelectForUpdateExecutor(
}
}

func (executor *prepareSelectForUpdateExecutor) Executable(ctx context.Context, lockRetryInterval time.Duration, lockRetryTimes int) (bool, error) {
func (executor *prepareSelectForUpdateExecutor) Executable(ctx context.Context, xid string, lockRetryInterval time.Duration, lockRetryTimes int) (bool, error) {
tableMeta, err := executor.GetTableMeta(ctx)
if err != nil {
return false, err
Expand All @@ -70,8 +70,8 @@ func (executor *prepareSelectForUpdateExecutor) Executable(ctx context.Context,
err error
)
for i := 0; i < lockRetryTimes; i++ {
lockable, err = dt.GetDistributedTransactionManager().IsLockable(ctx,
executor.conn.DataSourceName(), lockKeys)
lockable, err = dt.GetDistributedTransactionManager().IsLockableWithXID(ctx,
executor.conn.DataSourceName(), lockKeys, xid)
if lockable && err == nil {
break
}
Expand Down
8 changes: 5 additions & 3 deletions pkg/filter/dt/exec/prepare_select_for_update_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,16 @@ import (
func TestPrepareSelectForUpdate(t *testing.T) {
testCases := []*struct {
sql string
xid string
lockInterval time.Duration
lockTimes int
expectedTableName string
expectedWhereCondition string
expectedErr error
}{
{
sql: "select /*+ GlobalLock() */ * from T where id = ? for update",
sql: "select /*+ XID('123') */ * from T where id = ? for update",
xid: "123",
lockInterval: 5 * time.Millisecond,
lockTimes: 3,
expectedTableName: "`T`",
Expand All @@ -53,7 +55,7 @@ func TestPrepareSelectForUpdate(t *testing.T) {
},
}

patches1 := isLockablePatch()
patches1 := isLockableWithXIDPatch()
defer patches1.Reset()

patches2 := getPrepareTableMetaPatch()
Expand Down Expand Up @@ -94,7 +96,7 @@ func TestPrepareSelectForUpdate(t *testing.T) {
assert.Equal(t, c.expectedTableName, tableName)
whereCondition := executor.(*prepareSelectForUpdateExecutor).GetWhereCondition()
assert.Equal(t, c.expectedWhereCondition, whereCondition)
_, executeErr := executor.Executable(ctx, c.lockInterval, c.lockTimes)
_, executeErr := executor.Executable(ctx, c.xid, c.lockInterval, c.lockTimes)
assert.Equal(t, c.expectedErr, executeErr)
})
}
Expand Down
30 changes: 29 additions & 1 deletion pkg/filter/dt/exec/query_global_lock.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func NewQueryGlobalLockExecutor(
conn *driver.BackendConnection,
isUpdate bool,
deleteStmt *ast.DeleteStmt,
updateStmt *ast.UpdateStmt) Executable {
updateStmt *ast.UpdateStmt) GlobalLockExecutor {
return &queryGlobalLockExecutor{
conn: conn,
isUpdate: isUpdate,
Expand Down Expand Up @@ -81,6 +81,34 @@ func (executor *queryGlobalLockExecutor) Executable(ctx context.Context, lockRet
}
}

func (executor *queryGlobalLockExecutor) ExecutableWithXID(ctx context.Context, xid string, lockRetryInterval time.Duration, lockRetryTimes int) (bool, error) {
beforeImage, err := executor.BeforeImage(ctx)
if err != nil {
return false, err
}

lockKeys := schema.BuildLockKey(beforeImage)
if lockKeys == "" {
return true, nil
} else {
var (
err error
lockable bool
)
for i := 0; i < lockRetryTimes; i++ {
lockable, err = dt.GetDistributedTransactionManager().IsLockable(ctx,
executor.conn.DataSourceName(), lockKeys)
if err != nil {
time.Sleep(lockRetryInterval)
}
if lockable {
return true, nil
}
}
return false, err
}
}

func (executor *queryGlobalLockExecutor) GetTableMeta(ctx context.Context) (schema.TableMeta, error) {
dbName := executor.conn.DataSourceName()
db := resource.GetDBManager().GetDB(dbName)
Expand Down
2 changes: 1 addition & 1 deletion pkg/filter/dt/exec/query_global_lock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func TestQueryGlobalLock(t *testing.T) {
}
ctx = proto.WithPrepareStmt(ctx, protoStmt)

var executor Executable
var executor GlobalLockExecutor
if c.isUpdate {
updateStmt := stmt.(*ast.UpdateStmt)
executor = NewQueryGlobalLockExecutor(&driver.BackendConnection{}, c.isUpdate, nil, updateStmt)
Expand Down
6 changes: 3 additions & 3 deletions pkg/filter/dt/exec/query_select_for_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func NewQuerySelectForUpdateExecutor(
}
}

func (executor *querySelectForUpdateExecutor) Executable(ctx context.Context, lockRetryInterval time.Duration, lockRetryTimes int) (bool, error) {
func (executor *querySelectForUpdateExecutor) Executable(ctx context.Context, xid string, lockRetryInterval time.Duration, lockRetryTimes int) (bool, error) {
tableMeta, err := executor.GetTableMeta(ctx)
if err != nil {
return false, err
Expand All @@ -67,8 +67,8 @@ func (executor *querySelectForUpdateExecutor) Executable(ctx context.Context, lo
err error
)
for i := 0; i < lockRetryTimes; i++ {
lockable, err = dt.GetDistributedTransactionManager().IsLockable(ctx,
executor.conn.DataSourceName(), lockKeys)
lockable, err = dt.GetDistributedTransactionManager().IsLockableWithXID(ctx,
executor.conn.DataSourceName(), lockKeys, xid)
if lockable && err == nil {
break
}
Expand Down
8 changes: 5 additions & 3 deletions pkg/filter/dt/exec/query_select_for_update_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,16 @@ import (
func TestQuerySelectForUpdate(t *testing.T) {
testCases := []*struct {
sql string
xid string
lockInterval time.Duration
lockTimes int
expectedTableName string
expectedWhereCondition string
expectedErr error
}{
{
sql: "select /*+ GlobalLock() */ * from T where id = 10 for update",
sql: "select /*+ xid('123') */ * from T where id = 10 for update",
xid: "123",
lockInterval: 5 * time.Millisecond,
lockTimes: 3,
expectedTableName: "`T`",
Expand All @@ -53,7 +55,7 @@ func TestQuerySelectForUpdate(t *testing.T) {
},
}

patches1 := isLockablePatch()
patches1 := isLockableWithXIDPatch()
defer patches1.Reset()

patches2 := getQueryTableMetaPatch()
Expand Down Expand Up @@ -92,7 +94,7 @@ func TestQuerySelectForUpdate(t *testing.T) {
executor := NewQuerySelectForUpdateExecutor(&driver.BackendConnection{}, selectForUpdateStmt, &mysql.Result{})
tableName := executor.GetTableName()
assert.Equal(t, c.expectedTableName, tableName)
_, executeErr := executor.Executable(ctx, c.lockInterval, c.lockTimes)
_, executeErr := executor.Executable(ctx, c.xid, c.lockInterval, c.lockTimes)
assert.Equal(t, c.expectedErr, executeErr)
})
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/filter/dt/filter_mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ func (f *_mysqlFilter) PostHandle(ctx context.Context, result proto.Result, conn
err = f.processAfterQueryUpdate(newCtx, bc, stmtNode)
case *ast.SelectStmt:
if stmtNode.LockInfo != nil && stmtNode.LockInfo.LockType == ast.SelectLockForUpdate {
err = f.processSelectForQueryUpdate(ctx, bc, result, stmtNode)
err = f.processQuerySelectForUpdate(ctx, bc, result, stmtNode)
}
default:
return nil
Expand All @@ -168,7 +168,7 @@ func (f *_mysqlFilter) PostHandle(ctx context.Context, result proto.Result, conn
err = f.processAfterPrepareUpdate(newCtx, bc, stmt, stmtNode)
case *ast.SelectStmt:
if stmtNode.LockInfo != nil && stmtNode.LockInfo.LockType == ast.SelectLockForUpdate {
err = f.processSelectForPrepareUpdate(newCtx, bc, result, stmt, stmtNode)
err = f.processPrepareSelectForUpdate(newCtx, bc, result, stmt, stmtNode)
}
default:
return nil
Expand Down

0 comments on commit 040b4d1

Please sign in to comment.