From 5fb28a86419ad4e760b65bcbd01d6a15c7cce6f4 Mon Sep 17 00:00:00 2001 From: PangXing Date: Fri, 20 May 2022 06:36:07 +0800 Subject: [PATCH 1/3] support alter table statement --- pkg/executor/redirect.go | 11 +- pkg/runtime/ast/alter_table.go | 183 ++++++++++++++++++++++ pkg/runtime/ast/ast.go | 180 +++++++++++++++++++++- pkg/runtime/ast/ast_test.go | 51 +++++- pkg/runtime/ast/model.go | 205 ++++++++++++++++++++++++- pkg/runtime/ast/proto.go | 32 ++-- pkg/runtime/optimize/optimizer.go | 49 +++++- pkg/runtime/optimize/optimizer_test.go | 75 ++++++++- pkg/runtime/plan/alter_table.go | 118 ++++++++++++++ pkg/runtime/plan/transparent.go | 8 +- test/integration_test.go | 21 ++- 11 files changed, 871 insertions(+), 62 deletions(-) create mode 100644 pkg/runtime/ast/alter_table.go create mode 100644 pkg/runtime/plan/alter_table.go diff --git a/pkg/executor/redirect.go b/pkg/executor/redirect.go index 3fcf6a58..be50e098 100644 --- a/pkg/executor/redirect.go +++ b/pkg/executor/redirect.go @@ -22,22 +22,19 @@ import ( stdErrors "errors" "sync" "time" -) -import ( "github.com/arana-db/parser" "github.com/arana-db/parser/ast" - "github.com/pkg/errors" -) -import ( mConstants "github.com/arana-db/arana/pkg/constants/mysql" "github.com/arana-db/arana/pkg/metrics" "github.com/arana-db/arana/pkg/mysql" + mysqlErrors "github.com/arana-db/arana/pkg/mysql/errors" "github.com/arana-db/arana/pkg/proto" "github.com/arana-db/arana/pkg/runtime" + rcontext "github.com/arana-db/arana/pkg/runtime/context" "github.com/arana-db/arana/pkg/security" "github.com/arana-db/arana/pkg/util/log" @@ -218,7 +215,7 @@ func (executor *RedirectExecutor) ExecutorComQuery(ctx *proto.Context) (proto.Re } else { err = errNoDatabaseSelected } - case *ast.InsertStmt, *ast.UpdateStmt, *ast.DeleteStmt: + case *ast.InsertStmt, *ast.UpdateStmt, *ast.DeleteStmt, *ast.AlterTableStmt: if schemaless { err = errNoDatabaseSelected } else { @@ -303,7 +300,7 @@ func (executor *RedirectExecutor) ExecutorComStmtExecute(ctx *proto.Context) (pr } switch ctx.Stmt.StmtNode.(type) { - case *ast.SelectStmt, *ast.InsertStmt, *ast.UpdateStmt, *ast.DeleteStmt: + case *ast.SelectStmt, *ast.InsertStmt, *ast.UpdateStmt, *ast.DeleteStmt, *ast.AlterTableStmt: default: ctx.Context = rcontext.WithDirect(ctx.Context) } diff --git a/pkg/runtime/ast/alter_table.go b/pkg/runtime/ast/alter_table.go new file mode 100644 index 00000000..a60f1596 --- /dev/null +++ b/pkg/runtime/ast/alter_table.go @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ast + +import ( + "strings" +) + +var ( + _ Statement = (*AlterTableStatement)(nil) + _ paramsCounter = (*AlterTableStatement)(nil) +) + +type AlterTableType uint8 + +const ( + _ AlterTableType = iota + AlterTableAddColumns + AlterTableDropColumn + AlterTableAddConstraint + AlterTableChangeColumn + AlterTableModifyColumn + AlterTableRenameTable + AlterTableRenameColumn +) + +type AlterTableSpecStatement struct { + Tp AlterTableType + OldColumnName ColumnNameExpressionAtom + NewColumnName ColumnNameExpressionAtom + NewColumns []*ColumnDefine + NewTable TableName + Position *ColumnPosition + Constraint *Constraint +} + +func (a *AlterTableSpecStatement) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error { + switch a.Tp { + case AlterTableDropColumn: + sb.WriteString("DROP COLUMN ") + if err := a.OldColumnName.Restore(flag, sb, args); err != nil { + return err + } + case AlterTableAddColumns: + sb.WriteString("ADD COLUMN ") + if len(a.NewColumns) == 1 { + if err := a.NewColumns[0].Restore(flag, sb, args); err != nil { + return err + } + if a.Position != nil { + sb.WriteString(" ") + if err := a.Position.Restore(flag, sb, args); err != nil { + return err + } + } + } else { + sb.WriteString("(") + for i, col := range a.NewColumns { + if i != 0 { + sb.WriteString(", ") + } + if err := col.Restore(flag, sb, args); err != nil { + return err + } + } + sb.WriteString(")") + } + case AlterTableAddConstraint: + sb.WriteString("ADD ") + if err := a.Constraint.Restore(flag, sb, args); err != nil { + return err + } + case AlterTableChangeColumn: + sb.WriteString("CHANGE COLUMN ") + if err := a.OldColumnName.Restore(flag, sb, args); err != nil { + return err + } + sb.WriteString(" ") + if err := a.NewColumns[0].Restore(flag, sb, args); err != nil { + return err + } + if a.Position != nil { + sb.WriteString(" ") + if err := a.Position.Restore(flag, sb, args); err != nil { + return err + } + } + case AlterTableModifyColumn: + sb.WriteString("MODIFY COLUMN ") + if err := a.NewColumns[0].Restore(flag, sb, args); err != nil { + return err + } + if a.Position != nil { + sb.WriteString(" ") + if err := a.Position.Restore(flag, sb, args); err != nil { + return err + } + } + case AlterTableRenameTable: + sb.WriteString("RENAME AS ") + if err := a.NewTable.Restore(flag, sb, args); err != nil { + return err + } + case AlterTableRenameColumn: + sb.WriteString("RENAME COLUMN ") + if err := a.OldColumnName.Restore(flag, sb, args); err != nil { + return err + } + sb.WriteString(" TO ") + if err := a.NewColumnName.Restore(flag, sb, args); err != nil { + return err + } + } + + return nil +} + +func (a *AlterTableSpecStatement) CntParams() int { + return 0 +} + +// AlterTableStatement represents mysql alter table statement. see https://dev.mysql.com/doc/refman/8.0/en/alter-table.html +type AlterTableStatement struct { + Table TableName + Specs []*AlterTableSpecStatement +} + +func (at *AlterTableStatement) ResetTable(table string) *AlterTableStatement { + ret := new(AlterTableStatement) + *ret = *at + + tableName := make(TableName, len(ret.Table)) + copy(tableName, ret.Table) + tableName[len(tableName)-1] = table + + ret.Table = tableName + return ret +} + +func (at *AlterTableStatement) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error { + sb.WriteString("ALTER TABLE ") + if err := at.Table.Restore(flag, sb, args); err != nil { + return err + } + for i, spec := range at.Specs { + if i == 0 { + sb.WriteString(" ") + } else { + sb.WriteString(", ") + } + if err := spec.Restore(flag, sb, args); err != nil { + return err + } + } + return nil +} + +func (at *AlterTableStatement) Validate() error { + return nil +} + +func (at *AlterTableStatement) CntParams() int { + return 0 +} + +func (at *AlterTableStatement) Mode() SQLType { + return SalterTable +} diff --git a/pkg/runtime/ast/ast.go b/pkg/runtime/ast/ast.go index a69c8a6c..9b0642b1 100644 --- a/pkg/runtime/ast/ast.go +++ b/pkg/runtime/ast/ast.go @@ -21,23 +21,17 @@ import ( "fmt" "strconv" "strings" -) -import ( + "github.com/arana-db/arana/pkg/runtime/cmp" + "github.com/arana-db/arana/pkg/runtime/logical" "github.com/arana-db/parser" "github.com/arana-db/parser/ast" "github.com/arana-db/parser/mysql" "github.com/arana-db/parser/opcode" "github.com/arana-db/parser/test_driver" - "github.com/pkg/errors" ) -import ( - "github.com/arana-db/arana/pkg/runtime/cmp" - "github.com/arana-db/arana/pkg/runtime/logical" -) - var ( _opcode2comparison = map[opcode.Op]cmp.Comparison{ opcode.EQ: cmp.Ceq, @@ -106,11 +100,181 @@ func FromStmtNode(node ast.StmtNode) (Statement, error) { return cc.convTruncateTableStmt(stmt), nil case *ast.DropTableStmt: return cc.convDropTableStmt(stmt), nil + case *ast.AlterTableStmt: + return cc.convAlterTableStmt(stmt), nil default: return nil, errors.Errorf("unimplement: stmt type %T!", stmt) } } +func (cc *convCtx) convAlterTableStmt(stmt *ast.AlterTableStmt) *AlterTableStatement { + var tableName TableName + if db := stmt.Table.Schema.O; len(db) > 0 { + tableName = append(tableName, db) + } + tableName = append(tableName, stmt.Table.Name.O) + + var specs []*AlterTableSpecStatement + for _, spec := range stmt.Specs { + switch spec.Tp { + case ast.AlterTableDropColumn: + specs = append(specs, &AlterTableSpecStatement{ + Tp: AlterTableDropColumn, + OldColumnName: cc.convColumn(spec.OldColumnName), + }) + case ast.AlterTableAddColumns: + specs = append(specs, &AlterTableSpecStatement{ + Tp: AlterTableAddColumns, + NewColumns: cc.convColumnDef(spec.NewColumns), + Position: cc.convColumnPostition(spec.Position), + }) + case ast.AlterTableAddConstraint: + specs = append(specs, &AlterTableSpecStatement{ + Tp: AlterTableAddConstraint, + Constraint: cc.convConstraint(spec.Constraint), + }) + case ast.AlterTableChangeColumn: + specs = append(specs, &AlterTableSpecStatement{ + Tp: AlterTableChangeColumn, + OldColumnName: cc.convColumn(spec.OldColumnName), + NewColumns: cc.convColumnDef(spec.NewColumns), + Position: cc.convColumnPostition(spec.Position), + }) + case ast.AlterTableModifyColumn: + specs = append(specs, &AlterTableSpecStatement{ + Tp: AlterTableModifyColumn, + NewColumns: cc.convColumnDef(spec.NewColumns), + Position: cc.convColumnPostition(spec.Position), + }) + case ast.AlterTableRenameTable: + var newTable TableName + if db := spec.NewTable.Schema.O; len(db) > 0 { + newTable = append(newTable, db) + } + newTable = append(newTable, spec.NewTable.Name.O) + specs = append(specs, &AlterTableSpecStatement{ + Tp: AlterTableRenameTable, + NewTable: newTable, + }) + case ast.AlterTableRenameColumn: + specs = append(specs, &AlterTableSpecStatement{ + Tp: AlterTableRenameColumn, + OldColumnName: cc.convColumn(spec.OldColumnName), + NewColumnName: cc.convColumn(spec.NewColumnName), + }) + } + } + return &AlterTableStatement{ + Table: tableName, + Specs: specs, + } +} + +func (cc *convCtx) convColumnDef(cds []*ast.ColumnDef) []*ColumnDefine { + var cols []*ColumnDefine + for _, col := range cds { + var opts []*ColumnOption + for _, opt := range col.Options { + switch opt.Tp { + case ast.ColumnOptionPrimaryKey: + opts = append(opts, &ColumnOption{Tp: ColumnOptionPrimaryKey}) + case ast.ColumnOptionNotNull: + opts = append(opts, &ColumnOption{Tp: ColumnOptionNotNull}) + case ast.ColumnOptionAutoIncrement: + opts = append(opts, &ColumnOption{Tp: ColumnOptionAutoIncrement}) + case ast.ColumnOptionDefaultValue: + opts = append(opts, &ColumnOption{ + Tp: ColumnOptionDefaultValue, + Expr: toExpressionNode(cc.convExpr(opt.Expr)), + }) + case ast.ColumnOptionUniqKey: + opts = append(opts, &ColumnOption{Tp: ColumnOptionUniqKey}) + case ast.ColumnOptionNull: + opts = append(opts, &ColumnOption{Tp: ColumnOptionNull}) + case ast.ColumnOptionComment: + opts = append(opts, &ColumnOption{ + Tp: ColumnOptionComment, + Expr: toExpressionNode(cc.convExpr(opt.Expr)), + }) + case ast.ColumnOptionCollate: + opts = append(opts, &ColumnOption{ + Tp: ColumnOptionComment, + StrVal: opt.StrValue, + }) + case ast.ColumnOptionColumnFormat: + opts = append(opts, &ColumnOption{ + Tp: ColumnOptionColumnFormat, + StrVal: opt.StrValue, + }) + case ast.ColumnOptionStorage: + opts = append(opts, &ColumnOption{ + Tp: ColumnOptionStorage, + StrVal: opt.StrValue, + }) + } + } + cols = append(cols, &ColumnDefine{ + Column: cc.convColumn(col.Name), + Tp: strings.ToUpper(col.Tp.String()), + Options: opts, + }) + } + return cols +} + +func (cc *convCtx) convColumnPostition(p *ast.ColumnPosition) *ColumnPosition { + var pos *ColumnPosition + if p != nil { + switch p.Tp { + case ast.ColumnPositionFirst: + pos = &ColumnPosition{ + Tp: ColumnPositionFirst, + } + case ast.ColumnPositionAfter: + pos = &ColumnPosition{ + Tp: ColumnPositionAfter, + Column: cc.convColumn(p.RelativeColumn), + } + } + } + return pos +} + +func (cc *convCtx) convConstraint(c *ast.Constraint) *Constraint { + if c == nil { + return nil + } + keys := make([]*IndexPartSpec, len(c.Keys)) + for i, k := range c.Keys { + keys[i] = &IndexPartSpec{ + Column: cc.convColumn(k.Column), + Expr: toExpressionNode(cc.convExpr(k.Expr)), + } + } + tp := ConstraintNoConstraint + switch c.Tp { + case ast.ConstraintPrimaryKey: + tp = ConstraintPrimaryKey + case ast.ConstraintKey: + tp = ConstraintKey + case ast.ConstraintIndex: + tp = ConstraintIndex + case ast.ConstraintUniq: + tp = ConstraintUniq + case ast.ConstraintUniqKey: + tp = ConstraintUniqKey + case ast.ConstraintUniqIndex: + tp = ConstraintUniqIndex + case ast.ConstraintFulltext: + tp = ConstraintFulltext + } + return &Constraint{ + Tp: tp, + Name: c.Name, + Keys: keys, + } +} + func (cc *convCtx) convDropTableStmt(stmt *ast.DropTableStmt) *DropTableStatement { var tables = make([]*TableName, len(stmt.Tables)) for i, table := range stmt.Tables { diff --git a/pkg/runtime/ast/ast_test.go b/pkg/runtime/ast/ast_test.go index ddb183a9..c7f9d969 100644 --- a/pkg/runtime/ast/ast_test.go +++ b/pkg/runtime/ast/ast_test.go @@ -20,9 +20,7 @@ package ast import ( "strings" "testing" -) -import ( "github.com/stretchr/testify/assert" ) @@ -365,3 +363,52 @@ func TestQuote(t *testing.T) { t.Log(sb.String()) assert.Equal(t, "SELECT `a``bc`", sb.String()) } + +func TestParse_AlterTableStmt(t *testing.T) { + type tt struct { + input string + expect string + } + + for _, it := range []tt{ + { + "alter table student drop nickname", + "ALTER TABLE `student` DROP COLUMN `nickname`", + }, + { + "alter table student add dept_id int not null default 0 after uid", + "ALTER TABLE `student` ADD COLUMN `dept_id` INT(11) NOT NULL DEFAULT 0 AFTER `uid`", + }, + { + "alter table student add index idx_name (name)", + "ALTER TABLE `student` ADD INDEX idx_name(`name`)", + }, + { + "alter table student change id uid bigint not null", + "ALTER TABLE `student` CHANGE COLUMN `id` `uid` BIGINT(20) NOT NULL", + }, + { + "alter table student modify uid bigint not null default 0", + "ALTER TABLE `student` MODIFY COLUMN `uid` BIGINT(20) NOT NULL DEFAULT 0", + }, + { + "alter table student rename to students", + "ALTER TABLE `student` RENAME AS `students`", + }, + { + "alter table student rename column name to nickname, rename column nickname to name", + "ALTER TABLE `student` RENAME COLUMN `name` TO `nickname`, RENAME COLUMN `nickname` TO `name`", + }, + } { + t.Run(it.input, func(t *testing.T) { + stmt, err := Parse(it.input) + assert.NoError(t, err) + assert.IsTypef(t, (*AlterTableStatement)(nil), stmt, "should be alter table statement") + + actual, err := RestoreToString(RestoreDefault, stmt.(Restorer)) + assert.NoError(t, err, "should restore ok") + assert.Equal(t, it.expect, actual) + }) + } + +} diff --git a/pkg/runtime/ast/model.go b/pkg/runtime/ast/model.go index 5b6d462b..759b6277 100644 --- a/pkg/runtime/ast/model.go +++ b/pkg/runtime/ast/model.go @@ -20,9 +20,7 @@ package ast import ( "fmt" "strings" -) -import ( "github.com/pkg/errors" ) @@ -382,3 +380,206 @@ func (u *UpdateElement) CntParams() int { } return u.Value.CntParams() } + +type ColumnOptionType uint8 + +const ( + _ ColumnOptionType = iota + ColumnOptionPrimaryKey + ColumnOptionNotNull + ColumnOptionAutoIncrement + ColumnOptionDefaultValue + ColumnOptionUniqKey + ColumnOptionNull + ColumnOptionComment + ColumnOptionCollate + ColumnOptionColumnFormat + ColumnOptionStorage +) + +type ColumnOption struct { + Tp ColumnOptionType + Expr ExpressionNode + StrVal string +} + +func (c *ColumnOption) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error { + switch c.Tp { + case ColumnOptionPrimaryKey: + sb.WriteString("PRIMARY KEY") + case ColumnOptionNotNull: + sb.WriteString("NOT NULL") + case ColumnOptionAutoIncrement: + sb.WriteString("AUTO_INCREMENT") + case ColumnOptionDefaultValue: + sb.WriteString("DEFAULT ") + if err := c.Expr.Restore(flag, sb, args); err != nil { + return err + } + case ColumnOptionUniqKey: + sb.WriteString("UNIQUE KEY") + case ColumnOptionNull: + sb.WriteString("NULL") + case ColumnOptionComment: + sb.WriteString("COMMENT ") + if err := c.Expr.Restore(flag, sb, args); err != nil { + return err + } + case ColumnOptionCollate: + if len(c.StrVal) == 0 { + return errors.New("Empty ColumnOption COLLATE") + } + sb.WriteString("COLLATE ") + sb.WriteString(c.StrVal) + case ColumnOptionColumnFormat: + sb.WriteString("COLUMN_FORMAT ") + sb.WriteString(c.StrVal) + case ColumnOptionStorage: + sb.WriteString("STORAGE ") + sb.WriteString(c.StrVal) + } + return nil +} + +func (c *ColumnOption) CntParams() int { + return 0 +} + +type ColumnDefine struct { + Column ColumnNameExpressionAtom + Tp string + Options []*ColumnOption +} + +func (c *ColumnDefine) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error { + if err := c.Column.Restore(flag, sb, args); err != nil { + return err + } + if len(c.Tp) > 0 { + sb.WriteString(" " + c.Tp) + } + for _, o := range c.Options { + sb.WriteString(" ") + if err := o.Restore(flag, sb, args); err != nil { + return err + } + } + return nil +} + +func (c *ColumnDefine) CntParams() int { + return 0 +} + +type ColumnPositionType uint8 + +const ( + _ ColumnPositionType = iota + ColumnPositionFirst + ColumnPositionAfter +) + +type ColumnPosition struct { + Tp ColumnPositionType + Column ColumnNameExpressionAtom +} + +func (c *ColumnPosition) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error { + switch c.Tp { + case ColumnPositionFirst: + sb.WriteString("FIRST") + case ColumnPositionAfter: + sb.WriteString("AFTER ") + if err := c.Column.Restore(flag, sb, args); err != nil { + return err + } + default: + return errors.New("invalid ColumnPositionType") + } + return nil +} + +func (c *ColumnPosition) CntParams() int { + return 0 +} + +type IndexPartSpec struct { + Column ColumnNameExpressionAtom + Expr ExpressionNode +} + +func (i *IndexPartSpec) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error { + if i.Expr != nil { + sb.WriteString("(") + if err := i.Expr.Restore(flag, sb, args); err != nil { + return err + } + sb.WriteString(")") + return nil + } + if err := i.Column.Restore(flag, sb, args); err != nil { + return err + } + return nil +} + +func (i *IndexPartSpec) CntParams() int { + return 0 +} + +type ConstraintType uint8 + +const ( + ConstraintNoConstraint ConstraintType = iota + ConstraintPrimaryKey + ConstraintKey + ConstraintIndex + ConstraintUniq + ConstraintUniqKey + ConstraintUniqIndex + ConstraintFulltext +) + +type Constraint struct { + Tp ConstraintType + Name string + Keys []*IndexPartSpec +} + +func (c *Constraint) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error { + switch c.Tp { + case ConstraintPrimaryKey: + sb.WriteString("PRIMARY KEY") + case ConstraintKey: + sb.WriteString("KEY") + case ConstraintIndex: + sb.WriteString("INDEX") + case ConstraintUniq: + sb.WriteString("UNIQUE") + case ConstraintUniqKey: + sb.WriteString("UNIQUE KEY") + case ConstraintUniqIndex: + sb.WriteString("UNIQUE INDEX") + case ConstraintFulltext: + sb.WriteString("FULLTEXT") + } + if len(c.Name) > 0 { + sb.WriteString(" ") + sb.WriteString(c.Name) + } + sb.WriteString("(") + for i, k := range c.Keys { + if i != 0 { + sb.WriteString(", ") + } + if err := k.Restore(flag, sb, args); err != nil { + return err + } + } + sb.WriteString(")") + return nil +} + +func (c *Constraint) CntParams() int { + return 0 +} diff --git a/pkg/runtime/ast/proto.go b/pkg/runtime/ast/proto.go index 4fb4dff5..c5109dcd 100644 --- a/pkg/runtime/ast/proto.go +++ b/pkg/runtime/ast/proto.go @@ -22,14 +22,15 @@ import ( ) const ( - _ SQLType = iota - Squery // QUERY - Sdelete // DELETE - Supdate // UPDATE - Sinsert // INSERT - Sreplace // REPLACE - Struncate // TRUNCATE - SdropTable // DROP TABLE + _ SQLType = iota + Squery // QUERY + Sdelete // DELETE + Supdate // UPDATE + Sinsert // INSERT + Sreplace // REPLACE + Struncate // TRUNCATE + SdropTable // DROP TABLE + SalterTable // ALTER TABLE ) type RestoreFlag uint32 @@ -44,13 +45,14 @@ type Restorer interface { } var _sqlTypeNames = [...]string{ - Squery: "QUERY", - Sdelete: "DELETE", - Supdate: "UPDATE", - Sinsert: "INSERT", - Sreplace: "REPLACE", - Struncate: "TRUNCATE", - SdropTable: "DROP TABLE", + Squery: "QUERY", + Sdelete: "DELETE", + Supdate: "UPDATE", + Sinsert: "INSERT", + Sreplace: "REPLACE", + Struncate: "TRUNCATE", + SdropTable: "DROP TABLE", + SalterTable: "ALTER TABLE", } // SQLType represents the type of SQL. diff --git a/pkg/runtime/optimize/optimizer.go b/pkg/runtime/optimize/optimizer.go index c61388a7..ee00cd38 100644 --- a/pkg/runtime/optimize/optimizer.go +++ b/pkg/runtime/optimize/optimizer.go @@ -21,20 +21,16 @@ import ( "context" stdErrors "errors" "strings" -) -import ( - "github.com/arana-db/parser/ast" - - "github.com/pkg/errors" -) - -import ( "github.com/arana-db/arana/pkg/proto" "github.com/arana-db/arana/pkg/proto/rule" "github.com/arana-db/arana/pkg/proto/schema_manager" + "github.com/arana-db/parser/ast" + "github.com/pkg/errors" + rast "github.com/arana-db/arana/pkg/runtime/ast" "github.com/arana-db/arana/pkg/runtime/cmp" + rcontext "github.com/arana-db/arana/pkg/runtime/context" "github.com/arana-db/arana/pkg/runtime/plan" "github.com/arana-db/arana/pkg/transformer" @@ -120,6 +116,8 @@ func (o optimizer) doOptimize(ctx context.Context, conn proto.VConn, stmt rast.S return o.optimizeShowVariables(ctx, t, args) case *rast.DescribeStatement: return o.optimizeDescribeStatement(ctx, t, args) + case *rast.AlterTableStatement: + return o.optimizeAlterTable(ctx, t, args) } //TODO implement all statements @@ -131,6 +129,41 @@ const ( _supported ) +func (o optimizer) optimizeAlterTable(ctx context.Context, stmt *rast.AlterTableStatement, args []interface{}) (proto.Plan, error) { + var ( + ret = plan.NewAlterTablePlan(stmt) + ru = rcontext.Rule(ctx) + table = stmt.Table + vt *rule.VTable + ok bool + ) + ret.BindArgs(args) + + // non-sharding update + if vt, ok = ru.VTable(table.Suffix()); !ok { + return ret, nil + } + + //TODO alter table table or column to new name , should update sharding info + + // exit if full-scan is disabled + if !vt.AllowFullScan() { + return nil, errDenyFullScan + } + + // sharding + shards := rule.DatabaseTables{} + topology := vt.Topology() + topology.Each(func(dbIdx, tbIdx int) bool { + if d, t, ok := topology.Render(dbIdx, tbIdx); ok { + shards[d] = append(shards[d], t) + } + return true + }) + ret.Shards = shards + return ret, nil +} + func (o optimizer) optimizeDropTable(ctx context.Context, stmt *rast.DropTableStatement, args []interface{}) (proto.Plan, error) { ru := rcontext.Rule(ctx) //table shard diff --git a/pkg/runtime/optimize/optimizer_test.go b/pkg/runtime/optimize/optimizer_test.go index 8f705f69..cc28a007 100644 --- a/pkg/runtime/optimize/optimizer_test.go +++ b/pkg/runtime/optimize/optimizer_test.go @@ -19,21 +19,17 @@ package optimize import ( "context" + "fmt" "strings" "testing" -) -import ( + "github.com/arana-db/arana/pkg/mysql" + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/proto/rule" "github.com/arana-db/parser" - "github.com/golang/mock/gomock" - "github.com/stretchr/testify/assert" -) -import ( - "github.com/arana-db/arana/pkg/mysql" - "github.com/arana-db/arana/pkg/proto" rcontext "github.com/arana-db/arana/pkg/runtime/context" "github.com/arana-db/arana/testdata" ) @@ -166,3 +162,66 @@ func TestOptimizer_OptimizeInsert(t *testing.T) { }) } + +func TestOptimizer_OptimizeAlterTable(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + conn := testdata.NewMockVConn(ctrl) + loader := testdata.NewMockSchemaLoader(ctrl) + + conn.EXPECT().Exec(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, db string, sql string, args ...interface{}) (proto.Result, error) { + t.Logf("fake exec: db='%s', sql=\"%s\", args=%v\n", db, sql, args) + return &mysql.Result{}, nil + }).AnyTimes() + + var ( + ctx = context.Background() + opt = optimizer{schemaLoader: loader} + ru rule.Rule + tab rule.VTable + topo rule.Topology + ) + + topo.SetRender(func(_ int) string { + return "fake_db" + }, func(i int) string { + return fmt.Sprintf("student_%04d", i) + }) + tables := make([]int, 0, 8) + for i := 0; i < 8; i++ { + tables = append(tables, i) + } + topo.SetTopology(0, tables...) + tab.SetTopology(&topo) + tab.SetAllowFullScan(true) + ru.SetVTable("student", &tab) + + t.Run("sharding", func(t *testing.T) { + sql := "alter table student add dept_id int not null default 0 after uid" + + p := parser.New() + stmt, _ := p.ParseOneStmt(sql, "", "") + + plan, err := opt.Optimize(rcontext.WithRule(ctx, &ru), conn, stmt) + assert.NoError(t, err) + + _, err = plan.ExecIn(ctx, conn) + assert.NoError(t, err) + + }) + + t.Run("non-sharding", func(t *testing.T) { + sql := "alter table employees add index idx_name (first_name)" + + p := parser.New() + stmt, _ := p.ParseOneStmt(sql, "", "") + + plan, err := opt.Optimize(rcontext.WithRule(ctx, &ru), conn, stmt) + assert.NoError(t, err) + + _, err = plan.ExecIn(ctx, conn) + assert.NoError(t, err) + }) +} diff --git a/pkg/runtime/plan/alter_table.go b/pkg/runtime/plan/alter_table.go new file mode 100644 index 00000000..dac10c96 --- /dev/null +++ b/pkg/runtime/plan/alter_table.go @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package plan + +import ( + "context" + "strings" + + "github.com/arana-db/arana/pkg/mysql" + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/proto/rule" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/util/log" + "github.com/pkg/errors" + uatomic "go.uber.org/atomic" + "golang.org/x/sync/errgroup" +) + +var _ proto.Plan = (*AlterTablePlan)(nil) + +type AlterTablePlan struct { + basePlan + stmt *ast.AlterTableStatement + Shards rule.DatabaseTables +} + +func NewAlterTablePlan(stmt *ast.AlterTableStatement) *AlterTablePlan { + return &AlterTablePlan{stmt: stmt} +} + +func (d *AlterTablePlan) Type() proto.PlanType { + return proto.PlanTypeExec +} + +func (at *AlterTablePlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) { + if at.Shards == nil { + // non-sharding alter table + var sb strings.Builder + if err := at.stmt.Restore(ast.RestoreDefault, &sb, nil); err != nil { + return nil, err + } + return conn.Exec(ctx, "", sb.String(), at.args...) + } + var ( + affects = uatomic.NewUint64(0) + cnt = uatomic.NewUint32(0) + ) + + var g errgroup.Group + + // sharding alter table + for k, v := range at.Shards { + // do copy for goroutine-safe + var ( + db = k + tables = v + ) + // execute concurrent for each phy database + g.Go(func() error { + var ( + sb strings.Builder + args []int + res proto.Result + err error + ) + + sb.Grow(256) + + for _, table := range tables { + if err = at.stmt.ResetTable(table).Restore(ast.RestoreDefault, &sb, &args); err != nil { + return errors.WithStack(err) + } + + if res, err = conn.Exec(ctx, db, sb.String(), at.toArgs(args)...); err != nil { + return errors.WithStack(err) + } + + n, _ := res.RowsAffected() + affects.Add(n) + cnt.Inc() + + // cleanup + if len(args) > 0 { + args = args[:0] + } + sb.Reset() + } + + return nil + }) + } + + if err := g.Wait(); err != nil { + return nil, err + } + + log.Debugf("sharding alter table success: batch=%d, affects=%d", cnt.Load(), affects.Load()) + + return &mysql.Result{ + AffectedRows: affects.Load(), + DataChan: make(chan proto.Row, 1), + }, nil +} diff --git a/pkg/runtime/plan/transparent.go b/pkg/runtime/plan/transparent.go index fea0e2cf..40c154d1 100644 --- a/pkg/runtime/plan/transparent.go +++ b/pkg/runtime/plan/transparent.go @@ -20,14 +20,10 @@ package plan import ( "context" "strings" -) -import ( + "github.com/arana-db/arana/pkg/proto" "github.com/pkg/errors" -) -import ( - "github.com/arana-db/arana/pkg/proto" rast "github.com/arana-db/arana/pkg/runtime/ast" ) @@ -45,7 +41,7 @@ type TransparentPlan struct { func Transparent(stmt rast.Statement, args []interface{}) *TransparentPlan { var typ proto.PlanType switch stmt.Mode() { - case rast.Sinsert, rast.Sdelete, rast.Sreplace, rast.Supdate, rast.Struncate, rast.SdropTable: + case rast.Sinsert, rast.Sdelete, rast.Sreplace, rast.Supdate, rast.Struncate, rast.SdropTable, rast.SalterTable: typ = proto.PlanTypeExec default: typ = proto.PlanTypeQuery diff --git a/test/integration_test.go b/test/integration_test.go index a90b6e7a..67b2b752 100644 --- a/test/integration_test.go +++ b/test/integration_test.go @@ -21,17 +21,12 @@ import ( "fmt" "testing" "time" -) -import ( + "github.com/arana-db/arana/pkg/util/rand2" _ "github.com/go-sql-driver/mysql" // register mysql - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" -) -import ( - "github.com/arana-db/arana/pkg/util/rand2" utils "github.com/arana-db/arana/pkg/util/tableprint" ) @@ -431,3 +426,17 @@ func (s *IntegrationSuite) TestShardingAgg() { db.Exec("DELETE FROM student WHERE uid >= 9527") } + +func (s *IntegrationSuite) TestAlterTable() { + var ( + db = s.DB() + t = s.T() + ) + + result, err := db.Exec(`alter table employees add dept_no char(4) not null default "" after emp_no`) + assert.NoErrorf(t, err, "alter table error: %v", err) + affected, err := result.RowsAffected() + assert.NoErrorf(t, err, "alter table error: %v", err) + + assert.Equal(t, int64(1), affected) +} From 6399602354c969c77e6609a34c28be3c93d13b23 Mon Sep 17 00:00:00 2001 From: PangXing Date: Fri, 20 May 2022 07:59:36 +0800 Subject: [PATCH 2/3] fix alter table test --- test/integration_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/integration_test.go b/test/integration_test.go index 67b2b752..1fed9e6d 100644 --- a/test/integration_test.go +++ b/test/integration_test.go @@ -438,5 +438,5 @@ func (s *IntegrationSuite) TestAlterTable() { affected, err := result.RowsAffected() assert.NoErrorf(t, err, "alter table error: %v", err) - assert.Equal(t, int64(1), affected) + assert.Equal(t, int64(0), affected) } From 7ae076ca994daebf0c1a5d6b6c72da753207e5fe Mon Sep 17 00:00:00 2001 From: PangXing Date: Fri, 20 May 2022 18:14:15 +0800 Subject: [PATCH 3/3] format imports by imports-formatter --- pkg/executor/redirect.go | 7 +++++-- pkg/runtime/ast/ast.go | 10 ++++++++-- pkg/runtime/ast/ast_test.go | 2 ++ pkg/runtime/ast/model.go | 2 ++ pkg/runtime/optimize/optimizer.go | 12 ++++++++---- pkg/runtime/optimize/optimizer_test.go | 12 +++++++++--- pkg/runtime/plan/alter_table.go | 13 ++++++++++--- pkg/runtime/plan/transparent.go | 6 +++++- test/integration_test.go | 7 ++++++- 9 files changed, 55 insertions(+), 16 deletions(-) diff --git a/pkg/executor/redirect.go b/pkg/executor/redirect.go index be50e098..74fc62d3 100644 --- a/pkg/executor/redirect.go +++ b/pkg/executor/redirect.go @@ -22,19 +22,22 @@ import ( stdErrors "errors" "sync" "time" +) +import ( "github.com/arana-db/parser" "github.com/arana-db/parser/ast" + "github.com/pkg/errors" +) +import ( mConstants "github.com/arana-db/arana/pkg/constants/mysql" "github.com/arana-db/arana/pkg/metrics" "github.com/arana-db/arana/pkg/mysql" - mysqlErrors "github.com/arana-db/arana/pkg/mysql/errors" "github.com/arana-db/arana/pkg/proto" "github.com/arana-db/arana/pkg/runtime" - rcontext "github.com/arana-db/arana/pkg/runtime/context" "github.com/arana-db/arana/pkg/security" "github.com/arana-db/arana/pkg/util/log" diff --git a/pkg/runtime/ast/ast.go b/pkg/runtime/ast/ast.go index 9b0642b1..97a91300 100644 --- a/pkg/runtime/ast/ast.go +++ b/pkg/runtime/ast/ast.go @@ -21,17 +21,23 @@ import ( "fmt" "strconv" "strings" +) - "github.com/arana-db/arana/pkg/runtime/cmp" - "github.com/arana-db/arana/pkg/runtime/logical" +import ( "github.com/arana-db/parser" "github.com/arana-db/parser/ast" "github.com/arana-db/parser/mysql" "github.com/arana-db/parser/opcode" "github.com/arana-db/parser/test_driver" + "github.com/pkg/errors" ) +import ( + "github.com/arana-db/arana/pkg/runtime/cmp" + "github.com/arana-db/arana/pkg/runtime/logical" +) + var ( _opcode2comparison = map[opcode.Op]cmp.Comparison{ opcode.EQ: cmp.Ceq, diff --git a/pkg/runtime/ast/ast_test.go b/pkg/runtime/ast/ast_test.go index c7f9d969..90b4a565 100644 --- a/pkg/runtime/ast/ast_test.go +++ b/pkg/runtime/ast/ast_test.go @@ -20,7 +20,9 @@ package ast import ( "strings" "testing" +) +import ( "github.com/stretchr/testify/assert" ) diff --git a/pkg/runtime/ast/model.go b/pkg/runtime/ast/model.go index 759b6277..f6bfcb5b 100644 --- a/pkg/runtime/ast/model.go +++ b/pkg/runtime/ast/model.go @@ -20,7 +20,9 @@ package ast import ( "fmt" "strings" +) +import ( "github.com/pkg/errors" ) diff --git a/pkg/runtime/optimize/optimizer.go b/pkg/runtime/optimize/optimizer.go index ee00cd38..a4d20950 100644 --- a/pkg/runtime/optimize/optimizer.go +++ b/pkg/runtime/optimize/optimizer.go @@ -21,16 +21,20 @@ import ( "context" stdErrors "errors" "strings" +) - "github.com/arana-db/arana/pkg/proto" - "github.com/arana-db/arana/pkg/proto/rule" - "github.com/arana-db/arana/pkg/proto/schema_manager" +import ( "github.com/arana-db/parser/ast" + "github.com/pkg/errors" +) +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/proto/rule" + "github.com/arana-db/arana/pkg/proto/schema_manager" rast "github.com/arana-db/arana/pkg/runtime/ast" "github.com/arana-db/arana/pkg/runtime/cmp" - rcontext "github.com/arana-db/arana/pkg/runtime/context" "github.com/arana-db/arana/pkg/runtime/plan" "github.com/arana-db/arana/pkg/transformer" diff --git a/pkg/runtime/optimize/optimizer_test.go b/pkg/runtime/optimize/optimizer_test.go index cc28a007..80569ce6 100644 --- a/pkg/runtime/optimize/optimizer_test.go +++ b/pkg/runtime/optimize/optimizer_test.go @@ -22,14 +22,20 @@ import ( "fmt" "strings" "testing" +) - "github.com/arana-db/arana/pkg/mysql" - "github.com/arana-db/arana/pkg/proto" - "github.com/arana-db/arana/pkg/proto/rule" +import ( "github.com/arana-db/parser" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" +) +import ( + "github.com/arana-db/arana/pkg/mysql" + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/proto/rule" rcontext "github.com/arana-db/arana/pkg/runtime/context" "github.com/arana-db/arana/testdata" ) diff --git a/pkg/runtime/plan/alter_table.go b/pkg/runtime/plan/alter_table.go index dac10c96..c04cd560 100644 --- a/pkg/runtime/plan/alter_table.go +++ b/pkg/runtime/plan/alter_table.go @@ -20,15 +20,22 @@ package plan import ( "context" "strings" +) + +import ( + "github.com/pkg/errors" + uatomic "go.uber.org/atomic" + + "golang.org/x/sync/errgroup" +) + +import ( "github.com/arana-db/arana/pkg/mysql" "github.com/arana-db/arana/pkg/proto" "github.com/arana-db/arana/pkg/proto/rule" "github.com/arana-db/arana/pkg/runtime/ast" "github.com/arana-db/arana/pkg/util/log" - "github.com/pkg/errors" - uatomic "go.uber.org/atomic" - "golang.org/x/sync/errgroup" ) var _ proto.Plan = (*AlterTablePlan)(nil) diff --git a/pkg/runtime/plan/transparent.go b/pkg/runtime/plan/transparent.go index 40c154d1..a4b67a36 100644 --- a/pkg/runtime/plan/transparent.go +++ b/pkg/runtime/plan/transparent.go @@ -20,10 +20,14 @@ package plan import ( "context" "strings" +) - "github.com/arana-db/arana/pkg/proto" +import ( "github.com/pkg/errors" +) +import ( + "github.com/arana-db/arana/pkg/proto" rast "github.com/arana-db/arana/pkg/runtime/ast" ) diff --git a/test/integration_test.go b/test/integration_test.go index 1fed9e6d..85411e00 100644 --- a/test/integration_test.go +++ b/test/integration_test.go @@ -21,12 +21,17 @@ import ( "fmt" "testing" "time" +) - "github.com/arana-db/arana/pkg/util/rand2" +import ( _ "github.com/go-sql-driver/mysql" // register mysql + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" +) +import ( + "github.com/arana-db/arana/pkg/util/rand2" utils "github.com/arana-db/arana/pkg/util/tableprint" )