Skip to content

Commit

Permalink
fix: handle database switch correctly (arana-db#618)
Browse files Browse the repository at this point in the history
  • Loading branch information
jjeffcaii authored and linguowei committed Feb 25, 2023
1 parent 7974f48 commit 111fa25
Show file tree
Hide file tree
Showing 19 changed files with 404 additions and 173 deletions.
6 changes: 5 additions & 1 deletion pkg/boot/discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,11 @@ func (fp *discovery) ListNodes(ctx context.Context, tenant, cluster, group strin
if !ok {
return nil, nil
}
return append([]string{}, bingo.Nodes...), nil

nodes := make([]string, len(bingo.Nodes))
copy(nodes, bingo.Nodes)

return nodes, nil
}

func (fp *discovery) ListTables(ctx context.Context, tenant, cluster string) ([]string, error) {
Expand Down
3 changes: 1 addition & 2 deletions pkg/constants/mysql/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -589,8 +589,7 @@ const (
// SSNoDatabaseSelected is ER_NO_DB
SSNoDatabaseSelected = "3D000"

// SSSPNotExist is ER_SP_DOES_NOT_EXIST
SSSPNotExist = "42000"
SS42000 = "42000"
)

// Status flags. They are returned by the server in a few cases.
Expand Down
73 changes: 45 additions & 28 deletions pkg/executor/redirect.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package executor
import (
"bytes"
stdErrors "errors"
"fmt"
"strings"
"sync"
"time"
Expand All @@ -32,6 +33,8 @@ import (
pMysql "github.com/arana-db/parser/mysql"

"github.com/pkg/errors"

"golang.org/x/exp/slices"
)

import (
Expand Down Expand Up @@ -77,7 +80,7 @@ func IsErrMissingTx(err error) bool {
}

type RedirectExecutor struct {
localTransactionMap sync.Map // map[uint32]proto.Tx, (ConnectionID,Tx)
localTransactionMap sync.Map // map[uint32]proto.Tx, (connectionID,Tx)
}

func NewRedirectExecutor() *RedirectExecutor {
Expand All @@ -89,30 +92,35 @@ func (executor *RedirectExecutor) ProcessDistributedTransaction() bool {
}

func (executor *RedirectExecutor) InLocalTransaction(ctx *proto.Context) bool {
_, ok := executor.localTransactionMap.Load(ctx.ConnectionID)
_, ok := executor.localTransactionMap.Load(ctx.C.ID())
return ok
}

func (executor *RedirectExecutor) InGlobalTransaction(ctx *proto.Context) bool {
return false
}

func (executor *RedirectExecutor) ExecuteUseDB(ctx *proto.Context) error {
// TODO: check permission, target database should belong to same tenant.
// TODO: process transactions when database switched?
func (executor *RedirectExecutor) ExecuteUseDB(ctx *proto.Context, db string) error {
if ctx.C.Schema() == db {
return nil
}

clusters := security.DefaultTenantManager().GetClusters(ctx.C.Tenant())
if !slices.Contains(clusters, db) {
return mysqlErrors.NewSQLError(mConstants.ERBadDb, mConstants.SS42000, fmt.Sprintf("Unknown database '%s'", db))
}

if hasTx := executor.InLocalTransaction(ctx); hasTx {
// TODO: should commit existing TX when DB switched
log.Debugf("commit tx when db switched: conn=%s", ctx.C)
}

// bind schema
ctx.C.SetSchema(db)

// reset transient variables
ctx.C.SetTransientVariables(make(map[string]proto.Value))

// do nothing.
//resourcePool := resource.GetDataSourceManager().GetMasterResourcePool(executor.dataSources[0].Master.Name)
//r, err := resourcePool.Get(ctx)
//defer func() {
// resourcePool.Put(r)
//}()
//if err != nil {
// return err
//}
//backendConn := r.(*mysql.BackendConnection)
//db := string(ctx.Data[1:])
//return backendConn.WriteComInitDB(db)
return nil
}

Expand All @@ -121,7 +129,7 @@ func (executor *RedirectExecutor) ExecuteFieldList(ctx *proto.Context) ([]proto.
table := string(ctx.Data[1:index])
wildcard := string(ctx.Data[index+1:])

rt, err := runtime.Load(ctx.Schema)
rt, err := runtime.Load(ctx.C.Schema())
if err != nil {
return nil, errors.WithStack(err)
}
Expand All @@ -141,6 +149,15 @@ func (executor *RedirectExecutor) ExecuteFieldList(ctx *proto.Context) ([]proto.
}

func (executor *RedirectExecutor) doExecutorComQuery(ctx *proto.Context, act ast.StmtNode) (proto.Result, uint16, error) {
// switch DB
switch u := act.(type) {
case *ast.UseStmt:
if err := executor.ExecuteUseDB(ctx, u.DBName); err != nil {
return nil, 0, err
}
return resultx.New(), 0, nil
}

var (
start = time.Now()
schemaless bool // true if schema is not specified
Expand All @@ -158,23 +175,23 @@ func (executor *RedirectExecutor) doExecutorComQuery(ctx *proto.Context, act ast
trace.Extract(ctx, hints)
metrics.ParserDuration.Observe(time.Since(start).Seconds())

if len(ctx.Schema) < 1 {
if len(ctx.C.Schema()) < 1 {
// TODO: handle multiple clusters
clusters := security.DefaultTenantManager().GetClusters(ctx.Tenant)
clusters := security.DefaultTenantManager().GetClusters(ctx.C.Tenant())
if len(clusters) != 1 {
// reject if no schema specified
return nil, 0, mysqlErrors.NewSQLError(mConstants.ERNoDb, mConstants.SSNoDatabaseSelected, "No database selected")
}
schemaless = true
ctx.Schema = security.DefaultTenantManager().GetClusters(ctx.Tenant)[0]
ctx.C.SetSchema(security.DefaultTenantManager().GetClusters(ctx.C.Tenant())[0])
}

ctx.Stmt = &proto.Stmt{
Hints: hints,
StmtNode: act,
}

rt, err := runtime.Load(ctx.Schema)
rt, err := runtime.Load(ctx.C.Schema())
if err != nil {
return nil, 0, err
}
Expand Down Expand Up @@ -285,7 +302,7 @@ func (executor *RedirectExecutor) ExecutorComQuery(ctx *proto.Context, h func(re
query := ctx.GetQuery()
log.Debugf("ComQuery: %s", query)

charset, collation := getCharsetCollation(ctx.CharacterSet)
charset, collation := getCharsetCollation(ctx.C.CharacterSet())

switch strings.IndexByte(query, ';') {
case -1: // no ';' exists
Expand Down Expand Up @@ -352,7 +369,7 @@ func (executor *RedirectExecutor) ExecutorComStmtExecute(ctx *proto.Context) (pr
executable = tx
} else {
var rt runtime.Runtime
if rt, err = runtime.Load(ctx.Schema); err != nil {
if rt, err = runtime.Load(ctx.C.Schema()); err != nil {
return nil, 0, err
}
executable = rt
Expand Down Expand Up @@ -380,7 +397,7 @@ func (executor *RedirectExecutor) ConnectionClose(ctx *proto.Context) {
}

//resourcePool := resource.GetDataSourceManager().GetMasterResourcePool(executor.dataSources[0].Master.Name)
//r, ok := executor.localTransactionMap[ctx.ConnectionID]
//r, ok := executor.localTransactionMap[ctx.connectionID]
//if ok {
// defer func() {
// resourcePool.Put(r)
Expand All @@ -394,19 +411,19 @@ func (executor *RedirectExecutor) ConnectionClose(ctx *proto.Context) {
}

func (executor *RedirectExecutor) putTx(ctx *proto.Context, tx proto.Tx) {
executor.localTransactionMap.Store(ctx.ConnectionID, tx)
executor.localTransactionMap.Store(ctx.C.ID(), tx)
}

func (executor *RedirectExecutor) removeTx(ctx *proto.Context) (proto.Tx, bool) {
exist, ok := executor.localTransactionMap.LoadAndDelete(ctx.ConnectionID)
exist, ok := executor.localTransactionMap.LoadAndDelete(ctx.C.ID())
if !ok {
return nil, false
}
return exist.(proto.Tx), true
}

func (executor *RedirectExecutor) getTx(ctx *proto.Context) (proto.Tx, bool) {
exist, ok := executor.localTransactionMap.Load(ctx.ConnectionID)
exist, ok := executor.localTransactionMap.Load(ctx.C.ID())
if !ok {
return nil, false
}
Expand Down
27 changes: 21 additions & 6 deletions pkg/executor/redirect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,16 @@ import (
)

import (
"github.com/golang/mock/gomock"

"github.com/pkg/errors"

"github.com/stretchr/testify/assert"
)

import (
"github.com/arana-db/arana/pkg/proto"
"github.com/arana-db/arana/testdata"
)

func TestIsErrMissingTx(t *testing.T) {
Expand All @@ -42,21 +45,33 @@ func TestProcessDistributedTransaction(t *testing.T) {
}

func TestInGlobalTransaction(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

c := testdata.NewMockFrontConn(ctrl)
c.EXPECT().ID().Return(uint32(0)).AnyTimes()

redirect := NewRedirectExecutor()
assert.False(t, redirect.InGlobalTransaction(createContext()))
assert.False(t, redirect.InGlobalTransaction(createContext(c)))
}

func TestInLocalTransaction(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

c := testdata.NewMockFrontConn(ctrl)
c.EXPECT().ID().Return(uint32(0)).Times(1)

redirect := NewRedirectExecutor()
result := redirect.InLocalTransaction(createContext())
result := redirect.InLocalTransaction(createContext(c))
assert.False(t, result)
}

func createContext() *proto.Context {
func createContext(c proto.FrontConn) *proto.Context {
result := &proto.Context{
ConnectionID: 0,
Data: make([]byte, 0),
Stmt: nil,
C: c,
Data: make([]byte, 0),
Stmt: nil,
}
return result
}
4 changes: 2 additions & 2 deletions pkg/mysql/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ func (c *Connector) NewBackendConnection(ctx context.Context) (*BackendConnectio
// \ /
// >>> SYNC <<<
func (conn *BackendConnection) SyncVariables(vars map[string]proto.Value) error {
transient := conn.c.TransientVariables
transient := conn.c.TransientVariables()

if len(vars) < 1 && len(transient) < 1 {
return nil
Expand Down Expand Up @@ -799,7 +799,7 @@ func (conn *BackendConnection) parseInitialHandshakePacket(data []byte) (uint32,
}

// Read the connection id.
conn.c.ConnectionID, pos, ok = readUint32(data, pos)
conn.c.connectionID, pos, ok = readUint32(data, pos)
if !ok {
return 0, nil, "", err2.NewSQLError(mysql.CRMalformedPacket, mysql.SSUnknownSQLState, "parseInitialHandshakePacket: packet has no connection id")
}
Expand Down
62 changes: 48 additions & 14 deletions pkg/mysql/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,21 +91,21 @@ type Conn struct {
// If there are any ongoing reads or writes, they may get interrupted.
conn net.Conn

// Schema is the current database name.
Schema string
// schema is the current database name.
schema string

// Tenant is the current tenant login.
Tenant string
// tenant is the current tenant login.
tenant string

// ConnectionID is set:
// connectionID is set:
// - at Connect() time for clients, with the value returned by
// the server.
// - at accept time for the server.
ConnectionID uint32
connectionID uint32

// TransientVariables represents local transient variables.
// transientVariables represents local transient variables.
// These variables will always keep sync with backend mysql conns.
TransientVariables map[string]proto.Value
transientVariables map[string]proto.Value

// closed is set to true when Close() is called on the connection.
closed *atomic.Bool
Expand Down Expand Up @@ -149,7 +149,9 @@ type Conn struct {
// connection.
// It is set during the initial handshake.
// See the values in constants.go.
CharacterSet uint8
characterSet uint8

serverVersion string
}

// newConn is an internal method to create a Conn. Used by client and server
Expand All @@ -159,10 +161,42 @@ func newConn(conn net.Conn) *Conn {
conn: conn,
closed: atomic.NewBool(false),
bufferedReader: bufio.NewReaderSize(conn, connBufferSize),
TransientVariables: make(map[string]proto.Value),
transientVariables: make(map[string]proto.Value),
}
}

func (c *Conn) ServerVersion() string {
return c.serverVersion
}

func (c *Conn) CharacterSet() uint8 {
return c.characterSet
}

func (c *Conn) Schema() string {
return c.schema
}

func (c *Conn) SetSchema(schema string) {
c.schema = schema
}

func (c *Conn) Tenant() string {
return c.tenant
}

func (c *Conn) SetTenant(t string) {
c.tenant = t
}

func (c *Conn) TransientVariables() map[string]proto.Value {
return c.transientVariables
}

func (c *Conn) SetTransientVariables(v map[string]proto.Value) {
c.transientVariables = v
}

// startWriterBuffering starts using buffered writes. This should
// be terminated by a call to endWriteBuffering.
func (c *Conn) startWriterBuffering() {
Expand Down Expand Up @@ -631,13 +665,13 @@ func (c *Conn) RemoteAddr() net.Addr {
}

// ID returns the MySQL connection ID for this connection.
func (c *Conn) ID() int64 {
return int64(c.ConnectionID)
func (c *Conn) ID() uint32 {
return c.connectionID
}

// Ident returns a useful identification string for error logging
func (c *Conn) String() string {
return fmt.Sprintf("client %v (%s)", c.ConnectionID, c.RemoteAddr().String())
return fmt.Sprintf("client %v (%s)", c.ID(), c.RemoteAddr().String())
}

// Close closes the connection. It can be called from a different go
Expand Down Expand Up @@ -708,7 +742,7 @@ func (c *Conn) fixErrNoSuchTable(errorMessage string) string {
var sb strings.Builder
sb.Grow(len(errorMessage))
sb.WriteString("Table '")
sb.WriteString(c.Schema)
sb.WriteString(c.Schema())
sb.WriteByte('.')
sb.WriteString(matches[2])
sb.WriteString("' doesn't exist")
Expand Down
Loading

0 comments on commit 111fa25

Please sign in to comment.