forked from pingcap/tidb
/
connection.go
390 lines (339 loc) · 10.8 KB
/
connection.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
// A Go driver for the MySQL X protocol for Go's database/sql package
// Based heavily on:
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
// Copyright 2016 Simon J Mudd.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"database/sql/driver"
"errors"
"fmt"
"log"
"net"
"strings"
"github.com/golang/protobuf/proto"
"github.com/pingcap/tipb/go-mysqlx"
"github.com/pingcap/tipb/go-mysqlx/Datatypes"
"github.com/pingcap/tipb/go-mysqlx/Resultset"
"github.com/pingcap/tipb/go-mysqlx/Sql"
)
// mysqlXConn is one client session speaking the MySQL X protocol. Open2 sets
// it up and returns it as the database/sql driver.Conn; Query, Close, Begin,
// Prepare and Exec are its driver-facing methods.
type mysqlXConn struct {
	buf     buffer       // raw bytes pulled in from network
	pb      *netProtobuf // holds a possible protobuf message that still needs processing
	netConn net.Conn     // set by Open2, reset to nil by Close; nil means "not usable" to Query/Close
	// NOTE(review): affectedRows and insertID are not read or written in this
	// file; presumably they describe the last executed statement — confirm.
	affectedRows uint64
	insertID     uint64
	cfg          *xconfig // dial network/address/timeout and DSN params, read by Open2 and handleParams
	// maxPacketAllowed and maxWriteSize are only assigned by code that is
	// currently commented out in Open2 (within this file).
	maxPacketAllowed int
	maxWriteSize     int
	// flags clientFlag
	// status statusFlag
	// sequence uint8
	parseTime bool       // set from the "parseTime" DSN parameter in handleParams
	strict    bool       // set from the "strict" DSN parameter in handleParams
	state     queryState // NOTE(review): not referenced in this file
	// capabilities holds what the server advertised; Open2 and handleParams
	// query it with Exists/Values. Presumably filled in by getCapabilities,
	// which Open2 calls first — confirm.
	capabilities   ServerCapabilities
	systemVariable []byte // NOTE(review): not referenced in this file
}
// capabilityTestUnknownCapability asks the server to set a capability whose
// name is presumably unknown to it ("randomCapability" = true) and hands back
// whatever error setScalarBoolCapability reports. It exists to probe how the
// server reacts; the only call site in this file is commented out in Open2.
func (mc *mysqlXConn) capabilityTestUnknownCapability() error {
	const bogusCapability = "randomCapability"
	return mc.setScalarBoolCapability(bogusCapability, true)
}
// Open2 is the second stage of opening a connection, run once the X protocol
// driver has been selected. In order it:
//   - dials mc.cfg.addr, using the dial function registered for mc.cfg.net
//     if there is one, otherwise a net.Dialer with mc.cfg.timeout;
//   - enables TCP keepalives on TCP connections;
//   - fetches the server capabilities;
//   - authenticates with MYSQL41 (the only mechanism handled here);
//   - applies the DSN parameters via handleParams.
//
// On success mc itself is returned as the driver.Conn. Every failure after a
// successful dial releases the network connection before returning, so a
// failed open never leaks a socket.
func (mc *mysqlXConn) Open2() (driver.Conn, error) {
	var err error

	// Connect to Server
	if dial, ok := dials[mc.cfg.net]; ok {
		mc.netConn, err = dial(mc.cfg.addr)
	} else {
		nd := net.Dialer{Timeout: mc.cfg.timeout}
		mc.netConn, err = nd.Dial(mc.cfg.net, mc.cfg.addr)
	}
	if err != nil {
		return nil, err
	}

	// abort drops the raw socket and returns cause. It deliberately does not
	// send an X protocol Close: the session may not be set up yet.
	abort := func(cause error) (driver.Conn, error) {
		if mc.netConn != nil {
			mc.netConn.Close() // best effort; cause is the error that matters
			mc.netConn = nil
		}
		return nil, cause
	}

	// Enable TCP Keepalives on TCP connections
	if tc, ok := mc.netConn.(*net.TCPConn); ok {
		if err := tc.SetKeepAlive(true); err != nil {
			// Don't send COM_QUIT before handshake.
			return abort(err)
		}
	}

	mc.buf = newBuffer(mc.netConn)

	// could/should be optional for performance? e.g. dsn has get_capabilities=0
	if err := mc.getCapabilities(); err != nil {
		return abort(fmt.Errorf("mysqlXConn.Open2: getCapabilities() failed: %v", err))
	}

	// can do some random checks here.
	// if err := mc.capabilityTestUnknownCapability(); err != nil {
	// 	return nil, fmt.Errorf("mysqlXConn.Open2: could not set unknown capability")
	// }

	if !mc.capabilities.Exists("authentication.mechanisms") {
		return abort(fmt.Errorf("mysqlXConn.Open2: did not find capability: authentication.mechanisms"))
	}

	// Current known capabilities: as of 5.7.14
	// "tls" (scalar string) only visible if TLS is configured
	// "authentication.mechanisms" (array string)
	// "doc.formats" (scalar string)
	// "node_type" (scalar string)
	// "plugin.version" (scalar string)
	// "client.pwd_expire_ok" (scalar bool)

	// Check and use the first one we can. SHOULD prioritise.
	values := mc.capabilities.Values("authentication.mechanisms")
	found := false
	for i := range values {
		if values[i].String() != "MYSQL41" {
			continue
		}
		found = true
		if err := mc.AuthenticateMySQL41(); err != nil {
			return abort(fmt.Errorf("Authentication failed: %v", err))
		}
		break
	}
	if !found {
		return abort(fmt.Errorf("mysqlXConn.Open2: could not find authentication.mechanism I can deal with. Found: %+v", values))
	}

	// // Get max allowed packet size
	// maxap, err := mc.getSystemVar("mysqlx_max_allowed_packet") // NOT THE SAME AS max_allowed_packet !!
	// if err != nil {
	// 	mc.Close()
	// 	return nil, err
	// }
	// mc.maxPacketAllowed = stringToInt(maxap) - 1
	// if mc.maxPacketAllowed < maxPacketSize {
	// 	mc.maxWriteSize = mc.maxPacketAllowed
	// }

	// Handle DSN Params
	if err := mc.handleParams(); err != nil {
		// The session is authenticated, so try an orderly protocol Close
		// first (best effort; its own error is secondary). Close can
		// return early without releasing the socket, so abort makes sure
		// it is gone before the failure is reported.
		mc.Close()
		return abort(err)
	}
	return mc, nil
}
// getSystemVar returns the value of the named MySQL system variable.
//
// FIXME: this is a stub. A real implementation must send
// "SELECT @@<name>" over the connection, check the result and return it.
// For now only "mysqlx_max_allowed_packet" is answered, with a hard-coded
// constant that is NOT read from the server; every other name yields an
// error.
func (mc *mysqlXConn) getSystemVar(name string) ([]byte, error) {
	switch name {
	case "mysqlx_max_allowed_packet":
		// placeholder value until the real lookup exists
		const hardCoded = "1048576"
		return []byte(hardCoded), nil
	default:
		return nil, fmt.Errorf("mysqlXConn.getSystemVar(%s) not implemented", name)
	}
}
// handleParams applies the parameters given in the DSN (mc.cfg.params) after
// the connection has been established. Recognised keys:
//
//   - "charset": comma-separated list; "SET NAMES <cs>" is tried for each
//     entry in order and the first one the server accepts is kept.
//   - "parseTime", "strict": boolean driver options stored on mc; a value
//     readBool does not accept is an error.
//   - "compress": always an error, compression is not implemented.
//   - "tls": requires the server to advertise exactly one boolean "tls"
//     capability, which is then set on the server.
//   - anything else is executed as "SET <param>=<val>".
//
// Processing stops at the first failure, whose error is returned.
func (mc *mysqlXConn) handleParams() (err error) {
	for param, val := range mc.cfg.params {
		switch param {
		// Charset
		case "charset":
			// Entries act as fallbacks: an error from a charset the server
			// does not know only matters if no later entry succeeds.
			for _, charset := range strings.Split(val, ",") {
				if err = mc.exec("SET NAMES " + charset); err == nil {
					break
				}
			}
			if err != nil {
				return err
			}

		// time.Time parsing
		case "parseTime":
			var isBool bool
			mc.parseTime, isBool = readBool(val)
			if !isBool {
				return errors.New("Invalid Bool value: " + val)
			}

		// Strict mode
		case "strict":
			var isBool bool
			mc.strict, isBool = readBool(val)
			if !isBool {
				return errors.New("Invalid Bool value: " + val)
			}

		// Compression
		case "compress":
			return errors.New("Compression not implemented yet")

		// TLS
		case "tls":
			// FIXME: the order and handling of TLS is still open:
			//  - do we try to go into TLS mode first?
			//  - do we go into TLS mode at the point that we process the parameter options?
			//  - do we go into TLS mode at the end of processing parameters, having remembered that we want to use TLS mode?
			//  - it may make sense to go into TLS mode early but that changes logic etc.

			// check if the server has advertised tls capabilities
			if !mc.capabilities.Exists("tls") {
				return errors.New("Server does not support TLS")
			}
			tls := mc.capabilities.Values("tls")
			// Exactly one value is expected; anything else is rejected
			// before tls[0] is touched (an empty slice would panic).
			if len(tls) != 1 {
				return fmt.Errorf("server tls capability returns unexpected result: len(tls) = %d, expecting 1", len(tls))
			}
			tlsType := tls[0].Type()
			if tlsType != "bool" {
				return fmt.Errorf("server tls capability type unexpected: %s, expecting bool", tlsType)
			}
			tlsValue := tls[0].Bool()
			// Setup or check the magic TLS config here
			// - should have been done by the app before
			// Tell the server we want to go in TLS mode.
			if err := mc.setScalarBoolCapability("tls", tlsValue); err != nil {
				return fmt.Errorf("Failed to set Capability TLS <fill in here>: %+v", err)
			}
			// TODO: wait for OK back and, if we get it, go into TLS mode.
			// NOTE(review): nothing in this function wraps mc.netConn in a
			// TLS client; unless setScalarBoolCapability does so, later
			// traffic is still plaintext — confirm before relying on "tls".

		// System Vars
		default:
			if err = mc.exec("SET " + param + "=" + val); err != nil {
				return err
			}
		}
	}
	return nil
}
// exec runs query purely for its effect (e.g. SET statements) when no
// resultset is expected. It reuses the normal Query path and then closes the
// returned rows, which handles whatever response packets the server sent.
func (mc *mysqlXConn) exec(query string) error {
	rows, err := mc.Query(query, nil)
	if err == nil {
		// Closing consumes the response; its error is the statement's outcome.
		return rows.Close()
	}
	return fmt.Errorf("mysqlXConn.exec failed: %+v", err)
}
// Begin satisfies driver.Conn. Transactions are not implemented yet, so every
// call fails.
func (mc *mysqlXConn) Begin() (driver.Tx, error) {
	return nil, errors.New("DEBUG: mysqlXConn.Begin() not implemented yet")
}
// Close shuts the session down. It sends an X protocol Close message, then
// reads server messages until an Ok (success) or an Error (returned to the
// caller) arrives. Notices are passed to processNotice and any other message
// type is skipped. Closing an already-closed connection is a no-op.
//
// Once shutdown has started, the network connection is released on every
// path, including failures, so a failed Close does not leak the socket. The
// error describing what went wrong is still returned.
func (mc *mysqlXConn) Close() error {
	if mc.netConn == nil {
		return nil
	}
	defer func() {
		// Best effort: the protocol-level error (if any) is what callers need.
		mc.netConn.Close()
		mc.netConn = nil
	}()

	// NOTE: mc does not track whether a query is still streaming results; if
	// one is, its pending input would need draining before this point.
	if err := mc.writeClose(); err != nil {
		return fmt.Errorf("mysqlXConn.Close failed: %v", err)
	}

	// wait for Ok or Error, and ignore others
	for {
		pb, err := mc.readMsg()
		if err != nil {
			return fmt.Errorf("mysqlXConn.Close failed waiting for response to Close: %v", err)
		}
		switch Mysqlx.ServerMessages_Type(pb.msgType) {
		case Mysqlx.ServerMessages_OK:
			ok := new(Mysqlx.Ok)
			if err := proto.Unmarshal(pb.payload, ok); err != nil {
				return fmt.Errorf("mysqlXConn.Close: Failed to read Ok: %v", err)
			}
			return nil
		case Mysqlx.ServerMessages_ERROR:
			// Decode the server's Error so the caller sees its code and text
			// (the read error is always nil at this point).
			serverErr := new(Mysqlx.Error)
			if err := proto.Unmarshal(pb.payload, serverErr); err != nil {
				return fmt.Errorf("mysqlXConn.Close: Failed to read Error: %v", err)
			}
			return fmt.Errorf("mysqlXConn.Close received: %+v", serverErr)
		case Mysqlx.ServerMessages_NOTICE:
			mc.processNotice("mysqlXConn.Close()") // process the notice message
		default:
			// not relevant to shutdown; keep waiting
		}
	}
}
// Prepare satisfies driver.Conn. Prepared statements are not implemented yet,
// so every call fails regardless of query.
func (mc *mysqlXConn) Prepare(query string) (driver.Stmt, error) {
	return nil, errors.New("DEBUG: mysqlXConn.Prepare() not implemented yet")
}
// Exec has the driver.Execer signature but is not implemented yet; every call
// fails. Statements issued by the driver itself (see handleParams) go through
// the unexported exec helper instead.
func (mc *mysqlXConn) Exec(query string, args []driver.Value) (driver.Result, error) {
	return nil, errors.New("DEBUG: mysqlXConn.Exec() not implemented yet")
}
// Query is the public interface to making a query via database/sql.
//
// args is overloaded by this driver:
//   - if args[0] is a string, it selects the StmtExecute namespace; otherwise
//     the namespace is "sql" and args[0] is ignored;
//   - args[1:] are bound as statement parameters. Strings and non-negative
//     integers (int, or int64 as produced by database/sql) are supported;
//     a negative integer is an error because it can only be sent unsigned
//     here; values of any other type are skipped.
//
// Only the StmtExecute request is written here. The server's response is
// consumed later through the returned mysqlXRows.
func (mc *mysqlXConn) Query(query string, args []driver.Value) (driver.Rows, error) {
	if mc.netConn == nil {
		errLog.Print(ErrInvalidConn)
		return nil, driver.ErrBadConn
	}

	nameSpace := "sql"
	var stmtArgs []*Mysqlx_Datatypes.Any
	if len(args) > 0 {
		if str, ok := args[0].(string); ok {
			nameSpace = str
		}
		for i := 1; i < len(args); i++ {
			var arg Mysqlx_Datatypes.Any
			switch v := args[i].(type) {
			case string:
				arg = setString([]byte(v))
			case int:
				if v < 0 {
					return nil, fmt.Errorf("mysqlXConn.Query: argument %d: negative integer %d cannot be sent as unsigned", i, v)
				}
				arg = setUint(uint64(v))
			case int64:
				if v < 0 {
					return nil, fmt.Errorf("mysqlXConn.Query: argument %d: negative integer %d cannot be sent as unsigned", i, v)
				}
				arg = setUint(uint64(v))
			default:
				// NOTE(review): skipping shifts later values onto earlier
				// placeholders; consider rejecting unsupported types instead.
				continue
			}
			stmtArgs = append(stmtArgs, &arg)
		}
	}

	stmtExecute := &Mysqlx_Sql.StmtExecute{
		Namespace: &nameSpace,
		Stmt:      []byte(query),
		Args:      stmtArgs,
	}

	// write a StmtExecute packet with the given query to the network
	// - we DO NOT process the result as this will be done later.
	if err := mc.writeStmtExecute(stmtExecute); err != nil {
		return nil, fmt.Errorf("mysqlXConn.Query(%q,...) failed: %v", query, err)
	}

	// return the iterator
	return &mysqlXRows{
		columns: nil, // be explicit about expectations
		err:     nil, // be explicit about expectations
		mc:      mc,
		state:   queryStateWaitingColumnMetaData,
	}, nil
}
// printableColumnMetaData decodes pb's payload as a Mysqlx_Resultset
// ColumnMetaData message and renders every field on one line (presumably for
// debug output). If the payload cannot be decoded, the problem is logged and
// a placeholder string is returned instead of terminating the process.
func printableColumnMetaData(pb *netProtobuf) string {
	p := new(Mysqlx_Resultset.ColumnMetaData)
	if err := proto.Unmarshal(pb.payload, p); err != nil {
		log.Printf("printableColumnMetaData: error unmarshaling ColumnMetaData p: %v", err)
		return fmt.Sprintf("<unparsable ColumnMetaData: %v>", err)
	}
	return fmt.Sprintf("Type: %v, Name: %q, OriginalName: %q, Table: %q, Schema: %q, Catalog: %q, Collation: %v, FractionalDigits: %v, Length: %v, Flags: %v, ContentType: %v",
		p.GetType(),
		string(p.GetName()),
		string(p.GetOriginalName()),
		string(p.GetTable()),
		string(p.GetSchema()),
		string(p.GetCatalog()),
		p.GetCollation(),
		p.GetFractionalDigits(),
		p.GetLength(),
		p.GetFlags(),
		p.GetContentType())
}