-
Notifications
You must be signed in to change notification settings - Fork 3
/
analyzer.go
343 lines (319 loc) · 8.28 KB
/
analyzer.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
/*
Copyright 2017 Google Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package sqlparser
// analyzer.go contains utility analysis functions.
import (
"fmt"
"strconv"
"strings"
"unicode"
"github.com/abulo/ratel/elasticsearch/sqlparser/dependency/sqltypes"
"github.com/pkg/errors"
)
// These constants are used to identify the SQL statement type.
const (
StmtSelect = iota
StmtStream
StmtInsert
StmtReplace
StmtUpdate
StmtDelete
StmtDDL
StmtBegin
StmtCommit
StmtRollback
StmtSet
StmtShow
StmtUse
StmtOther
StmtUnknown
StmtComment
)
// Preview analyzes the beginning of the query using a simpler and faster
// textual comparison to identify the statement type.
func Preview(sql string) int {
trimmed := StripLeadingComments(sql)
firstWord := trimmed
if end := strings.IndexFunc(trimmed, unicode.IsSpace); end != -1 {
firstWord = trimmed[:end]
}
firstWord = strings.TrimLeftFunc(firstWord, func(r rune) bool { return !unicode.IsLetter(r) })
// Comparison is done in order of priority.
loweredFirstWord := strings.ToLower(firstWord)
switch loweredFirstWord {
case "select":
return StmtSelect
case "stream":
return StmtStream
case "insert":
return StmtInsert
case "replace":
return StmtReplace
case "update":
return StmtUpdate
case "delete":
return StmtDelete
}
// For the following statements it is not sufficient to rely
// on loweredFirstWord. This is because they are not statements
// in the grammar and we are relying on Preview to parse them.
// For instance, we don't want: "BEGIN JUNK" to be parsed
// as StmtBegin.
trimmedNoComments, _ := SplitMarginComments(trimmed)
switch strings.ToLower(trimmedNoComments) {
case "begin", "start transaction":
return StmtBegin
case "commit":
return StmtCommit
case "rollback":
return StmtRollback
}
switch loweredFirstWord {
case "create", "alter", "rename", "drop", "truncate":
return StmtDDL
case "set":
return StmtSet
case "show":
return StmtShow
case "use":
return StmtUse
case "analyze", "describe", "desc", "explain", "repair", "optimize":
return StmtOther
}
if strings.Index(trimmed, "/*!") == 0 {
return StmtComment
}
return StmtUnknown
}
// StmtType returns the statement type as a string
func StmtType(stmtType int) string {
switch stmtType {
case StmtSelect:
return "SELECT"
case StmtStream:
return "STREAM"
case StmtInsert:
return "INSERT"
case StmtReplace:
return "REPLACE"
case StmtUpdate:
return "UPDATE"
case StmtDelete:
return "DELETE"
case StmtDDL:
return "DDL"
case StmtBegin:
return "BEGIN"
case StmtCommit:
return "COMMIT"
case StmtRollback:
return "ROLLBACK"
case StmtSet:
return "SET"
case StmtShow:
return "SHOW"
case StmtUse:
return "USE"
case StmtOther:
return "OTHER"
default:
return "UNKNOWN"
}
}
// IsDML returns true if the query is an INSERT, UPDATE or DELETE statement.
func IsDML(sql string) bool {
switch Preview(sql) {
case StmtInsert, StmtReplace, StmtUpdate, StmtDelete:
return true
}
return false
}
// GetTableName returns the table name from the SimpleTableExpr
// only if it's a simple expression. Otherwise, it returns "".
func GetTableName(node SimpleTableExpr) TableIdent {
if n, ok := node.(TableName); ok && n.Qualifier.IsEmpty() {
return n.Name
}
// sub-select or '.' expression
return NewTableIdent("")
}
// IsColName returns true if the Expr is a *ColName.
func IsColName(node Expr) bool {
_, ok := node.(*ColName)
return ok
}
// IsValue returns true if the Expr is a string, integral or value arg.
// NULL is not considered to be a value.
func IsValue(node Expr) bool {
switch v := node.(type) {
case *SQLVal:
switch v.Type {
case StrVal, HexVal, IntVal, ValArg:
return true
}
}
return false
}
// IsNull returns true if the Expr is SQL NULL
func IsNull(node Expr) bool {
switch node.(type) {
case *NullVal:
return true
}
return false
}
// IsSimpleTuple returns true if the Expr is a ValTuple that
// contains simple values or if it's a list arg.
func IsSimpleTuple(node Expr) bool {
switch vals := node.(type) {
case ValTuple:
for _, n := range vals {
if !IsValue(n) {
return false
}
}
return true
case ListArg:
return true
}
// It's a subquery
return false
}
// NewPlanValue builds a sqltypes.PlanValue from an Expr.
func NewPlanValue(node Expr) (sqltypes.PlanValue, error) {
switch node := node.(type) {
case *SQLVal:
switch node.Type {
case ValArg:
return sqltypes.PlanValue{Key: string(node.Val[1:])}, nil
case IntVal:
n, err := sqltypes.NewIntegral(string(node.Val))
if err != nil {
return sqltypes.PlanValue{}, fmt.Errorf("%v", err)
}
return sqltypes.PlanValue{Value: n}, nil
case StrVal:
return sqltypes.PlanValue{Value: sqltypes.MakeTrusted(sqltypes.VarBinary, node.Val)}, nil
case HexVal:
v, err := node.HexDecode()
if err != nil {
return sqltypes.PlanValue{}, fmt.Errorf("%v", err)
}
return sqltypes.PlanValue{Value: sqltypes.MakeTrusted(sqltypes.VarBinary, v)}, nil
}
case ListArg:
return sqltypes.PlanValue{ListKey: string(node[2:])}, nil
case ValTuple:
pv := sqltypes.PlanValue{
Values: make([]sqltypes.PlanValue, 0, len(node)),
}
for _, val := range node {
innerpv, err := NewPlanValue(val)
if err != nil {
return sqltypes.PlanValue{}, err
}
if innerpv.ListKey != "" || innerpv.Values != nil {
return sqltypes.PlanValue{}, errors.New("unsupported: nested lists")
}
pv.Values = append(pv.Values, innerpv)
}
return pv, nil
case *NullVal:
return sqltypes.PlanValue{}, nil
}
return sqltypes.PlanValue{}, fmt.Errorf("expression is too complex '%v'", String(node))
}
// StringIn is a convenience function that returns
// true if str matches any of the values.
func StringIn(str string, values ...string) bool {
for _, val := range values {
if str == val {
return true
}
}
return false
}
// SetKey is the extracted key from one SetExpr
type SetKey struct {
Key string
Scope string
}
// ExtractSetValues returns a map of key-value pairs
// if the query is a SET statement. Values can be bool, int64 or string.
// Since set variable names are case insensitive, all keys are returned
// as lower case.
func ExtractSetValues(sql string) (keyValues map[SetKey]interface{}, scope string, err error) {
stmt, err := Parse(sql)
if err != nil {
return nil, "", err
}
setStmt, ok := stmt.(*Set)
if !ok {
return nil, "", fmt.Errorf("ast did not yield *sqlparser.Set: %T", stmt)
}
result := make(map[SetKey]interface{})
for _, expr := range setStmt.Exprs {
scope := SessionStr
key := expr.Name.Lowered()
switch {
case strings.HasPrefix(key, "@@global."):
scope = GlobalStr
key = strings.TrimPrefix(key, "@@global.")
case strings.HasPrefix(key, "@@session."):
key = strings.TrimPrefix(key, "@@session.")
case strings.HasPrefix(key, "@@"):
key = strings.TrimPrefix(key, "@@")
}
if strings.HasPrefix(expr.Name.Lowered(), "@@") {
if setStmt.Scope != "" && scope != "" {
return nil, "", fmt.Errorf("unsupported in set: mixed using of variable scope")
}
_, out := NewStringTokenizer(key).Scan()
key = string(out)
}
setKey := SetKey{
Key: key,
Scope: scope,
}
switch expr := expr.Expr.(type) {
case *SQLVal:
switch expr.Type {
case StrVal:
result[setKey] = strings.ToLower(string(expr.Val))
case IntVal:
num, err := strconv.ParseInt(string(expr.Val), 0, 64)
if err != nil {
return nil, "", err
}
result[setKey] = num
default:
return nil, "", fmt.Errorf("invalid value type: %v", String(expr))
}
case BoolVal:
var val int64
if expr {
val = 1
}
result[setKey] = val
case *ColName:
result[setKey] = expr.Name.String()
case *NullVal:
result[setKey] = nil
case *Default:
result[setKey] = "default"
default:
return nil, "", fmt.Errorf("invalid syntax: %s", String(expr))
}
}
return result, strings.ToLower(setStmt.Scope), nil
}