Skip to content

Commit

Permalink
Merge 7ae076c into df363ef
Browse files Browse the repository at this point in the history
  • Loading branch information
PangXing committed May 20, 2022
2 parents df363ef + 7ae076c commit 6e6e023
Show file tree
Hide file tree
Showing 11 changed files with 866 additions and 18 deletions.
4 changes: 2 additions & 2 deletions pkg/executor/redirect.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ func (executor *RedirectExecutor) ExecutorComQuery(ctx *proto.Context) (proto.Re
} else {
err = errNoDatabaseSelected
}
case *ast.InsertStmt, *ast.UpdateStmt, *ast.DeleteStmt:
case *ast.InsertStmt, *ast.UpdateStmt, *ast.DeleteStmt, *ast.AlterTableStmt:
if schemaless {
err = errNoDatabaseSelected
} else {
Expand Down Expand Up @@ -303,7 +303,7 @@ func (executor *RedirectExecutor) ExecutorComStmtExecute(ctx *proto.Context) (pr
}

switch ctx.Stmt.StmtNode.(type) {
case *ast.SelectStmt, *ast.InsertStmt, *ast.UpdateStmt, *ast.DeleteStmt:
case *ast.SelectStmt, *ast.InsertStmt, *ast.UpdateStmt, *ast.DeleteStmt, *ast.AlterTableStmt:
default:
ctx.Context = rcontext.WithDirect(ctx.Context)
}
Expand Down
183 changes: 183 additions & 0 deletions pkg/runtime/ast/alter_table.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ast

import (
"strings"
)

var (
_ Statement = (*AlterTableStatement)(nil)
_ paramsCounter = (*AlterTableStatement)(nil)
)

type AlterTableType uint8

const (
_ AlterTableType = iota
AlterTableAddColumns
AlterTableDropColumn
AlterTableAddConstraint
AlterTableChangeColumn
AlterTableModifyColumn
AlterTableRenameTable
AlterTableRenameColumn
)

type AlterTableSpecStatement struct {
Tp AlterTableType
OldColumnName ColumnNameExpressionAtom
NewColumnName ColumnNameExpressionAtom
NewColumns []*ColumnDefine
NewTable TableName
Position *ColumnPosition
Constraint *Constraint
}

func (a *AlterTableSpecStatement) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error {
switch a.Tp {
case AlterTableDropColumn:
sb.WriteString("DROP COLUMN ")
if err := a.OldColumnName.Restore(flag, sb, args); err != nil {
return err
}
case AlterTableAddColumns:
sb.WriteString("ADD COLUMN ")
if len(a.NewColumns) == 1 {
if err := a.NewColumns[0].Restore(flag, sb, args); err != nil {
return err
}
if a.Position != nil {
sb.WriteString(" ")
if err := a.Position.Restore(flag, sb, args); err != nil {
return err
}
}
} else {
sb.WriteString("(")
for i, col := range a.NewColumns {
if i != 0 {
sb.WriteString(", ")
}
if err := col.Restore(flag, sb, args); err != nil {
return err
}
}
sb.WriteString(")")
}
case AlterTableAddConstraint:
sb.WriteString("ADD ")
if err := a.Constraint.Restore(flag, sb, args); err != nil {
return err
}
case AlterTableChangeColumn:
sb.WriteString("CHANGE COLUMN ")
if err := a.OldColumnName.Restore(flag, sb, args); err != nil {
return err
}
sb.WriteString(" ")
if err := a.NewColumns[0].Restore(flag, sb, args); err != nil {
return err
}
if a.Position != nil {
sb.WriteString(" ")
if err := a.Position.Restore(flag, sb, args); err != nil {
return err
}
}
case AlterTableModifyColumn:
sb.WriteString("MODIFY COLUMN ")
if err := a.NewColumns[0].Restore(flag, sb, args); err != nil {
return err
}
if a.Position != nil {
sb.WriteString(" ")
if err := a.Position.Restore(flag, sb, args); err != nil {
return err
}
}
case AlterTableRenameTable:
sb.WriteString("RENAME AS ")
if err := a.NewTable.Restore(flag, sb, args); err != nil {
return err
}
case AlterTableRenameColumn:
sb.WriteString("RENAME COLUMN ")
if err := a.OldColumnName.Restore(flag, sb, args); err != nil {
return err
}
sb.WriteString(" TO ")
if err := a.NewColumnName.Restore(flag, sb, args); err != nil {
return err
}
}

return nil
}

func (a *AlterTableSpecStatement) CntParams() int {
return 0
}

// AlterTableStatement represents mysql alter table statement. see https://dev.mysql.com/doc/refman/8.0/en/alter-table.html
type AlterTableStatement struct {
Table TableName
Specs []*AlterTableSpecStatement
}

func (at *AlterTableStatement) ResetTable(table string) *AlterTableStatement {
ret := new(AlterTableStatement)
*ret = *at

tableName := make(TableName, len(ret.Table))
copy(tableName, ret.Table)
tableName[len(tableName)-1] = table

ret.Table = tableName
return ret
}

func (at *AlterTableStatement) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error {
sb.WriteString("ALTER TABLE ")
if err := at.Table.Restore(flag, sb, args); err != nil {
return err
}
for i, spec := range at.Specs {
if i == 0 {
sb.WriteString(" ")
} else {
sb.WriteString(", ")
}
if err := spec.Restore(flag, sb, args); err != nil {
return err
}
}
return nil
}

func (at *AlterTableStatement) Validate() error {
return nil
}

func (at *AlterTableStatement) CntParams() int {
return 0
}

func (at *AlterTableStatement) Mode() SQLType {
return SalterTable
}
170 changes: 170 additions & 0 deletions pkg/runtime/ast/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,181 @@ func FromStmtNode(node ast.StmtNode) (Statement, error) {
return cc.convTruncateTableStmt(stmt), nil
case *ast.DropTableStmt:
return cc.convDropTableStmt(stmt), nil
case *ast.AlterTableStmt:
return cc.convAlterTableStmt(stmt), nil
default:
return nil, errors.Errorf("unimplement: stmt type %T!", stmt)
}
}

func (cc *convCtx) convAlterTableStmt(stmt *ast.AlterTableStmt) *AlterTableStatement {
var tableName TableName
if db := stmt.Table.Schema.O; len(db) > 0 {
tableName = append(tableName, db)
}
tableName = append(tableName, stmt.Table.Name.O)

var specs []*AlterTableSpecStatement
for _, spec := range stmt.Specs {
switch spec.Tp {
case ast.AlterTableDropColumn:
specs = append(specs, &AlterTableSpecStatement{
Tp: AlterTableDropColumn,
OldColumnName: cc.convColumn(spec.OldColumnName),
})
case ast.AlterTableAddColumns:
specs = append(specs, &AlterTableSpecStatement{
Tp: AlterTableAddColumns,
NewColumns: cc.convColumnDef(spec.NewColumns),
Position: cc.convColumnPostition(spec.Position),
})
case ast.AlterTableAddConstraint:
specs = append(specs, &AlterTableSpecStatement{
Tp: AlterTableAddConstraint,
Constraint: cc.convConstraint(spec.Constraint),
})
case ast.AlterTableChangeColumn:
specs = append(specs, &AlterTableSpecStatement{
Tp: AlterTableChangeColumn,
OldColumnName: cc.convColumn(spec.OldColumnName),
NewColumns: cc.convColumnDef(spec.NewColumns),
Position: cc.convColumnPostition(spec.Position),
})
case ast.AlterTableModifyColumn:
specs = append(specs, &AlterTableSpecStatement{
Tp: AlterTableModifyColumn,
NewColumns: cc.convColumnDef(spec.NewColumns),
Position: cc.convColumnPostition(spec.Position),
})
case ast.AlterTableRenameTable:
var newTable TableName
if db := spec.NewTable.Schema.O; len(db) > 0 {
newTable = append(newTable, db)
}
newTable = append(newTable, spec.NewTable.Name.O)
specs = append(specs, &AlterTableSpecStatement{
Tp: AlterTableRenameTable,
NewTable: newTable,
})
case ast.AlterTableRenameColumn:
specs = append(specs, &AlterTableSpecStatement{
Tp: AlterTableRenameColumn,
OldColumnName: cc.convColumn(spec.OldColumnName),
NewColumnName: cc.convColumn(spec.NewColumnName),
})
}
}
return &AlterTableStatement{
Table: tableName,
Specs: specs,
}
}

func (cc *convCtx) convColumnDef(cds []*ast.ColumnDef) []*ColumnDefine {
var cols []*ColumnDefine
for _, col := range cds {
var opts []*ColumnOption
for _, opt := range col.Options {
switch opt.Tp {
case ast.ColumnOptionPrimaryKey:
opts = append(opts, &ColumnOption{Tp: ColumnOptionPrimaryKey})
case ast.ColumnOptionNotNull:
opts = append(opts, &ColumnOption{Tp: ColumnOptionNotNull})
case ast.ColumnOptionAutoIncrement:
opts = append(opts, &ColumnOption{Tp: ColumnOptionAutoIncrement})
case ast.ColumnOptionDefaultValue:
opts = append(opts, &ColumnOption{
Tp: ColumnOptionDefaultValue,
Expr: toExpressionNode(cc.convExpr(opt.Expr)),
})
case ast.ColumnOptionUniqKey:
opts = append(opts, &ColumnOption{Tp: ColumnOptionUniqKey})
case ast.ColumnOptionNull:
opts = append(opts, &ColumnOption{Tp: ColumnOptionNull})
case ast.ColumnOptionComment:
opts = append(opts, &ColumnOption{
Tp: ColumnOptionComment,
Expr: toExpressionNode(cc.convExpr(opt.Expr)),
})
case ast.ColumnOptionCollate:
opts = append(opts, &ColumnOption{
Tp: ColumnOptionComment,
StrVal: opt.StrValue,
})
case ast.ColumnOptionColumnFormat:
opts = append(opts, &ColumnOption{
Tp: ColumnOptionColumnFormat,
StrVal: opt.StrValue,
})
case ast.ColumnOptionStorage:
opts = append(opts, &ColumnOption{
Tp: ColumnOptionStorage,
StrVal: opt.StrValue,
})
}
}
cols = append(cols, &ColumnDefine{
Column: cc.convColumn(col.Name),
Tp: strings.ToUpper(col.Tp.String()),
Options: opts,
})
}
return cols
}

func (cc *convCtx) convColumnPostition(p *ast.ColumnPosition) *ColumnPosition {
var pos *ColumnPosition
if p != nil {
switch p.Tp {
case ast.ColumnPositionFirst:
pos = &ColumnPosition{
Tp: ColumnPositionFirst,
}
case ast.ColumnPositionAfter:
pos = &ColumnPosition{
Tp: ColumnPositionAfter,
Column: cc.convColumn(p.RelativeColumn),
}
}
}
return pos
}

func (cc *convCtx) convConstraint(c *ast.Constraint) *Constraint {
if c == nil {
return nil
}
keys := make([]*IndexPartSpec, len(c.Keys))
for i, k := range c.Keys {
keys[i] = &IndexPartSpec{
Column: cc.convColumn(k.Column),
Expr: toExpressionNode(cc.convExpr(k.Expr)),
}
}
tp := ConstraintNoConstraint
switch c.Tp {
case ast.ConstraintPrimaryKey:
tp = ConstraintPrimaryKey
case ast.ConstraintKey:
tp = ConstraintKey
case ast.ConstraintIndex:
tp = ConstraintIndex
case ast.ConstraintUniq:
tp = ConstraintUniq
case ast.ConstraintUniqKey:
tp = ConstraintUniqKey
case ast.ConstraintUniqIndex:
tp = ConstraintUniqIndex
case ast.ConstraintFulltext:
tp = ConstraintFulltext
}
return &Constraint{
Tp: tp,
Name: c.Name,
Keys: keys,
}
}

func (cc *convCtx) convDropTableStmt(stmt *ast.DropTableStmt) *DropTableStatement {
var tables = make([]*TableName, len(stmt.Tables))
for i, table := range stmt.Tables {
Expand Down
Loading

0 comments on commit 6e6e023

Please sign in to comment.