Skip to content

Commit

Permalink
feat: support query on global table (#249)
Browse files Browse the repository at this point in the history
  • Loading branch information
dk-lockdown committed Aug 22, 2022
1 parent 5caeb6c commit 1d2d32d
Show file tree
Hide file tree
Showing 11 changed files with 244 additions and 18 deletions.
3 changes: 3 additions & 0 deletions docker/conf/config_shd.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ app_config:
data_sources:
- name: world_1
weight: r10w10
global_tables:
- country
- countrylanguage
logic_tables:
- db_name: world
table_name: city
Expand Down
1 change: 1 addition & 0 deletions pkg/config/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ type (

ShardingConfig struct {
DBGroups []*DataSourceRefGroup `yaml:"db_groups" json:"db_groups"`
GlobalTables []string `yaml:"global_tables" json:"global_tables"`
LogicTables []*LogicTable `yaml:"logic_tables" json:"logic_tables"`
TransactionTimeout int32 `yaml:"transaction_timeout" json:"transaction_timeout"`
}
Expand Down
35 changes: 24 additions & 11 deletions pkg/executor/sharding.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"encoding/json"
"fmt"
"strings"

"github.com/pkg/errors"

Expand Down Expand Up @@ -52,6 +53,7 @@ func NewShardingExecutor(conf *config.Executor) (proto.Executor, error) {
err error
content []byte
shardingConfig *config.ShardingConfig
globalTables = make(map[string]bool)
executorSlice []proto.DBGroupExecutor
executorMap = make(map[string]proto.DBGroupExecutor)
algorithms map[string]cond.ShardingAlgorithm
Expand All @@ -67,6 +69,10 @@ func NewShardingExecutor(conf *config.Executor) (proto.Executor, error) {
return nil, err
}

for _, globalTable := range shardingConfig.GlobalTables {
globalTables[strings.ToLower(globalTable)] = true
}

for _, groupConfig := range shardingConfig.DBGroups {
dbGroup, err := group.NewDBGroup(conf.AppID, groupConfig.Name, groupConfig.LBAlgorithm, groupConfig.DataSources)
if err != nil {
Expand All @@ -76,17 +82,18 @@ func NewShardingExecutor(conf *config.Executor) (proto.Executor, error) {
executorMap[dbGroup.GroupName()] = dbGroup
}

algorithms, topologies, err = convertLogicTableConfigsToShardingAlgorithms(shardingConfig.LogicTables)
algorithms, topologies, err = convertShardingAlgorithmsAndTopologies(shardingConfig.LogicTables)
if err != nil {
return nil, errors.WithStack(err)
}

executor := &ShardingExecutor{
PreFilters: make([]proto.DBPreFilter, 0),
PostFilters: make([]proto.DBPostFilter, 0),
config: shardingConfig,
executors: executorSlice,
optimizer: optimize.NewOptimizer(conf.AppID, executorMap, algorithms, topologies),
PreFilters: make([]proto.DBPreFilter, 0),
PostFilters: make([]proto.DBPostFilter, 0),
config: shardingConfig,
executors: executorSlice,
optimizer: optimize.NewOptimizer(conf.AppID,
globalTables, executorSlice, executorMap, algorithms, topologies),
localTransactionMap: make(map[uint32]proto.Tx, 0),
}

Expand All @@ -108,7 +115,7 @@ func NewShardingExecutor(conf *config.Executor) (proto.Executor, error) {
return executor, nil
}

func convertLogicTableConfigsToShardingAlgorithms(logicTables []*config.LogicTable) (
func convertShardingAlgorithmsAndTopologies(logicTables []*config.LogicTable) (
map[string]cond.ShardingAlgorithm,
map[string]*topo.Topology,
error) {
Expand Down Expand Up @@ -231,21 +238,27 @@ func (executor *ShardingExecutor) ExecutorComQuery(ctx context.Context, sql stri

func (executor *ShardingExecutor) ExecutorComStmtExecute(
ctx context.Context, stmt *proto.Stmt) (result proto.Result, warns uint16, err error) {
spanCtx, span := tracing.GetTraceSpan(ctx, tracing.SHDComStmtExecute)
defer span.End()

if err = executor.doPreFilter(ctx); err != nil {
return nil, 0, err
}
defer func() {
err = executor.doPostFilter(ctx, result, err)
if err == nil {
result, err = decodeResult(result)
}
err = executor.doPostFilter(spanCtx, result, err)
if err != nil {
span.RecordError(err)
}
}()

var (
args []interface{}
plan proto.Plan
)

spanCtx, span := tracing.GetTraceSpan(ctx, tracing.SHDComStmtExecute)
defer span.End()

connectionID := proto.ConnectionID(ctx)
log.Debugf("connectionID: %d, prepare: %s", connectionID, stmt.SqlText)
for i := 0; i < len(stmt.BindVars); i++ {
Expand Down
168 changes: 168 additions & 0 deletions pkg/executor/sharding_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
/*
* Copyright 2022 CECTC, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package executor

import (
"context"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
"gopkg.in/yaml.v3"

"github.com/cectc/dbpack/pkg/config"
"github.com/cectc/dbpack/pkg/constant"
"github.com/cectc/dbpack/pkg/mysql"
"github.com/cectc/dbpack/pkg/proto"
"github.com/cectc/dbpack/pkg/visitor"
"github.com/cectc/dbpack/testdata"
"github.com/cectc/dbpack/third_party/parser"
)

var conf = `
name: redirect
mode: shd
config:
transaction_timeout: 60000
db_groups:
- name: world_0
load_balance_algorithm: RandomWeight
data_sources:
- name: world_0
weight: r10w10
- name: world_1
load_balance_algorithm: RandomWeight
data_sources:
- name: world_1
weight: r10w10
global_tables:
- country
- countrylanguage
logic_tables:
- db_name: world
table_name: city
allow_full_scan: true
sharding_rule:
column: id
sharding_algorithm: NumberMod
topology:
"0": 0-4
"1": 5-9`

type _ShardingExecutorTestSuite struct {
suite.Suite
environment *testdata.ShardingTestEnvironment
executor proto.Executor
}

func TestMergeResult(t *testing.T) {
suite.Run(t, new(_ShardingExecutorTestSuite))
}

func (suite *_ShardingExecutorTestSuite) SetupSuite() {
environment := testdata.NewShardingTestEnvironment(suite.T())
environment.RegisterDBResource(suite.T())
suite.environment = environment

executorConfig := suite.unmarshalExecutorConfig()
executorConfig.AppID = "test"
shardingExecutor, err := NewShardingExecutor(executorConfig)
if err != nil {
suite.T().Fatal(err)
}
suite.executor = shardingExecutor
}

func (suite *_ShardingExecutorTestSuite) TestQueryGlobalTable() {
sql := "select code, name, continent, region from country where code = 'CHN'"
p := parser.New()
stmt, err := p.ParseOneStmt(sql, "", "")
assert.Nil(suite.T(), err)
stmt.Accept(&visitor.ParamVisitor{})

ctx := proto.WithVariableMap(context.Background())
ctx = proto.WithConnectionID(ctx, 1)
ctx = proto.WithCommandType(ctx, constant.ComQuery)
ctx = proto.WithQueryStmt(ctx, stmt)

result, warns, err := suite.executor.ExecutorComQuery(ctx, sql)
assert.Nil(suite.T(), err)
assert.Equal(suite.T(), uint16(0), warns)
mysqlResult := result.(*mysql.Result)
for i, row := range mysqlResult.Rows {
suite.T().Logf("---------- row %d ----------", i)
textRow := row.(*mysql.TextRow)
for j, value := range textRow.Values {
switch value.Val.(type) {
case string, []byte:
suite.T().Logf("%d: %s", j, value.Val)
default:
suite.T().Logf("%d: %v", j, value.Val)
}
}
}
}

func (suite *_ShardingExecutorTestSuite) TestPrepareExecuteGlobalTable() {
sql := "select code, name, continent, region from country where code = ?"
p := parser.New()
stmt, err := p.ParseOneStmt(sql, "", "")
assert.Nil(suite.T(), err)
stmt.Accept(&visitor.ParamVisitor{})

protoStmt := &proto.Stmt{
SqlText: sql,
ParamsCount: 1,
BindVars: map[string]interface{}{
"v1": "CHN",
},
StmtNode: stmt,
}
ctx := proto.WithVariableMap(context.Background())
ctx = proto.WithConnectionID(ctx, 1)
ctx = proto.WithCommandType(ctx, constant.ComStmtExecute)
ctx = proto.WithPrepareStmt(ctx, protoStmt)

result, warns, err := suite.executor.ExecutorComStmtExecute(ctx, protoStmt)
assert.Nil(suite.T(), err)
assert.Equal(suite.T(), uint16(0), warns)
mysqlResult := result.(*mysql.Result)
for i, row := range mysqlResult.Rows {
suite.T().Logf("---------- row %d ----------", i)
binaryRow := row.(*mysql.BinaryRow)
for j, value := range binaryRow.Values {
switch value.Val.(type) {
case string, []byte:
suite.T().Logf("%d: %s", j, value.Val)
default:
suite.T().Logf("%d: %v", j, value.Val)
}
}
}
}

func (suite *_ShardingExecutorTestSuite) unmarshalExecutorConfig() *config.Executor {
var executorConfig *config.Executor
if err := yaml.Unmarshal([]byte(conf), &executorConfig); err != nil {
suite.T().Fatal(err)
}
return executorConfig
}

func (suite *_ShardingExecutorTestSuite) TearDownSuite() {
suite.environment.Shutdown(suite.T())
}
8 changes: 8 additions & 0 deletions pkg/optimize/optimize_select.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package optimize
import (
"context"
"sort"
"strings"

"github.com/pkg/errors"

Expand All @@ -42,6 +43,13 @@ func (o Optimizer) optimizeSelect(ctx context.Context, stmt *ast.SelectStmt, arg
)
tableName := stmt.From.TableRefs.Left.(*ast.TableSource).Source.(*ast.TableName).Name.String()

if o.globalTables[strings.ToLower(tableName)] {
return &plan.QueryDirectlyPlan{
Stmt: stmt,
Args: args,
Executor: o.executors[0],
}, nil
}
if alg, exists = o.algorithms[tableName]; !exists {
return nil, errors.New("sharding algorithm should not be nil")
}
Expand Down
8 changes: 7 additions & 1 deletion pkg/optimize/optimizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ import (
)

type Optimizer struct {
appid string
appid string
globalTables map[string]bool
executors []proto.DBGroupExecutor
// dbName -> DBGroupExecutor
dbGroupExecutors map[string]proto.DBGroupExecutor
// tableName -> ShardingAlgorithm
Expand All @@ -38,11 +40,15 @@ type Optimizer struct {
}

func NewOptimizer(appid string,
globalTables map[string]bool,
executors []proto.DBGroupExecutor,
dbGroupExecutors map[string]proto.DBGroupExecutor,
algorithms map[string]cond.ShardingAlgorithm,
topologies map[string]*topo.Topology) proto.Optimizer {
return &Optimizer{
appid: appid,
globalTables: globalTables,
executors: executors,
dbGroupExecutors: dbGroupExecutors,
algorithms: algorithms,
topologies: topologies,
Expand Down
2 changes: 1 addition & 1 deletion pkg/plan/delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ func (p *DeletePlan) Execute(ctx context.Context, hints ...*ast.TableOptimizerHi
}

func (p *DeletePlan) generate(sb *strings.Builder, table string, hints ...*ast.TableOptimizerHint) error {
ctx := format.NewRestoreCtx(format.DefaultRestoreFlags, sb)
ctx := format.NewRestoreCtx(format.DefaultRestoreFlags|format.RestoreStringWithoutDefaultCharset, sb)
ctx.WriteKeyWord("DELETE ")

if len(hints) != 0 {
Expand Down
2 changes: 1 addition & 1 deletion pkg/plan/insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func (p *InsertPlan) Execute(ctx context.Context, hints ...*ast.TableOptimizerHi
}

func (p *InsertPlan) generate(sb *strings.Builder) (err error) {
ctx := format.NewRestoreCtx(format.DefaultRestoreFlags, sb)
ctx := format.NewRestoreCtx(format.DefaultRestoreFlags|format.RestoreStringWithoutDefaultCharset, sb)

ctx.WriteKeyWord("INSERT ")
ctx.WriteKeyWord("INTO ")
Expand Down
32 changes: 30 additions & 2 deletions pkg/plan/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,34 @@ import (

const FuncColumns = "FuncColumns"

type QueryDirectlyPlan struct {
Stmt *ast.SelectStmt
Args []interface{}
Executor proto.DBGroupExecutor
}

func (p *QueryDirectlyPlan) Execute(ctx context.Context, hints ...*ast.TableOptimizerHint) (proto.Result, uint16, error) {
var (
sb strings.Builder
sql string
err error
)
restoreCtx := format.NewRestoreCtx(format.DefaultRestoreFlags|format.RestoreStringWithoutDefaultCharset, &sb)
if err = p.Stmt.Restore(restoreCtx); err != nil {
return nil, 0, errors.WithStack(err)
}
sql = sb.String()
commandType := proto.CommandType(ctx)
switch commandType {
case constant.ComQuery:
return p.Executor.Query(ctx, sql)
case constant.ComStmtExecute:
return p.Executor.PrepareQuery(ctx, sql, p.Args...)
default:
return nil, 0, nil
}
}

type QueryOnSingleDBPlan struct {
Database string
Tables []string
Expand Down Expand Up @@ -108,7 +136,7 @@ func (p *QueryOnSingleDBPlan) generate(ctx context.Context, sb *strings.Builder,
}
sb.WriteString(") t ")
if p.Stmt.OrderBy != nil {
restoreCtx := format.NewRestoreCtx(format.DefaultRestoreFlags, sb)
restoreCtx := format.NewRestoreCtx(format.DefaultRestoreFlags|format.RestoreStringWithoutDefaultCharset, sb)
if err := p.Stmt.OrderBy.Restore(restoreCtx); err != nil {
return errors.WithStack(err)
}
Expand Down Expand Up @@ -212,7 +240,7 @@ func (p *QueryOnMultiDBPlan) Execute(ctx context.Context, _ ...*ast.TableOptimiz
}

func generateSelect(table string, stmt *ast.SelectStmt, sb *strings.Builder, limit *Limit) error {
ctx := format.NewRestoreCtx(format.DefaultRestoreFlags, sb)
ctx := format.NewRestoreCtx(format.DefaultRestoreFlags|format.RestoreStringWithoutDefaultCharset, sb)
ctx.WriteKeyWord(stmt.Kind.String())
ctx.WritePlain(" ")

Expand Down

0 comments on commit 1d2d32d

Please sign in to comment.