forked from hashicorp/vault
-
Notifications
You must be signed in to change notification settings - Fork 0
/
mysql.go
704 lines (594 loc) · 18.4 KB
/
mysql.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
package mysql
import (
"context"
"crypto/tls"
"crypto/x509"
"database/sql"
"errors"
"fmt"
"io/ioutil"
"math"
"net/url"
"sort"
"strconv"
"strings"
"sync"
"time"
log "github.com/hashicorp/go-hclog"
"github.com/armon/go-metrics"
mysql "github.com/go-sql-driver/mysql"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/helper/strutil"
"github.com/hashicorp/vault/physical"
)
// Compile-time checks: MySQLBackend must implement both the plain and the HA
// physical backend interfaces, and MySQLHALock must implement physical.Lock.
var (
	_ physical.Backend   = (*MySQLBackend)(nil)
	_ physical.HABackend = (*MySQLBackend)(nil)
	_ physical.Lock      = (*MySQLHALock)(nil)
)

// mysqlTLSKey is the name under which the custom TLS configuration is
// registered with the MySQL driver and referenced as tls=<key> in the DSN.
// The driver reserves "true", "false" and "skip-verify", so a different,
// unreserved name is used.
const mysqlTLSKey = "default"
// MySQLBackend is a physical backend that stores Vault entries in a MySQL
// table and, when HA is enabled, uses a second table to record the leader.
type MySQLBackend struct {
	// Connection pool plus the prepared statements, keyed by operation
	// name ("put", "get", "delete", "list" and, with HA, "get_lock" and
	// "used_lock").
	client     *sql.DB
	statements map[string]*sql.Stmt

	// Backtick-quoted `database`.`table` names for data and lock tables.
	dbTable     string
	dbLockTable string

	logger     log.Logger
	permitPool *physical.PermitPool

	// Raw backend configuration, kept so lock connections can be opened
	// with the same settings.
	conf map[string]string

	redirectHost string
	redirectPort int64
	haEnabled    bool
}
// NewMySQLBackend constructs a MySQL backend using the given API client and
// server address and credential for accessing mysql database.
//
// In addition to the connection settings read by NewMySQLClient it reads:
//   - "database" (default "vault")
//   - "table" (default "vault")
//   - "max_parallel"
//   - "ha_enabled" (default "false")
//   - "lock_table" (default "<table>_lock")
//
// It creates the database and the data table if they are missing, and the
// lock table too when HA is enabled. If any setup step fails, the connection
// pool opened here is closed before the error is returned.
func NewMySQLBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) {
	db, err := NewMySQLClient(conf, logger)
	if err != nil {
		return nil, err
	}

	m, err := newMySQLBackendWithClient(db, conf, logger)
	if err != nil {
		// Best effort: the setup error is the one worth reporting.
		db.Close()
		return nil, err
	}
	return m, nil
}

// newMySQLBackendWithClient bootstraps schema and tables and prepares the
// statements for NewMySQLBackend on an already opened pool. The caller owns
// db and must close it if this returns an error.
func newMySQLBackendWithClient(db *sql.DB, conf map[string]string, logger log.Logger) (*MySQLBackend, error) {
	database, ok := conf["database"]
	if !ok {
		database = "vault"
	}
	table, ok := conf["table"]
	if !ok {
		table = "vault"
	}
	dbTable := "`" + database + "`.`" + table + "`"

	maxParInt := physical.DefaultParallelOperations
	if maxParStr, ok := conf["max_parallel"]; ok {
		parsed, err := strconv.Atoi(maxParStr)
		if err != nil {
			return nil, errwrap.Wrapf("failed parsing max_parallel parameter: {{err}}", err)
		}
		maxParInt = parsed
		if logger.IsDebug() {
			logger.Debug("max_parallel set", "max_parallel", maxParInt)
		}
	}

	// Probe for the schema and table. mysqlRowExists closes each result set
	// before returning. An open result set pins its connection, and with
	// max_parallel=1 the pool has only one, so the next query would block
	// forever.
	schemaExist, err := mysqlRowExists(db, "SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?", database)
	if err != nil {
		return nil, errwrap.Wrapf("failed to check mysql schema exist: {{err}}", err)
	}
	tableExist, err := mysqlRowExists(db, "SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_NAME = ? AND TABLE_SCHEMA = ?", table, database)
	if err != nil {
		return nil, errwrap.Wrapf("failed to check mysql table exist: {{err}}", err)
	}

	// Create the required database if it doesn't exist.
	if !schemaExist {
		if _, err := db.Exec("CREATE DATABASE IF NOT EXISTS `" + database + "`"); err != nil {
			return nil, errwrap.Wrapf("failed to create mysql database: {{err}}", err)
		}
	}

	// Create the required table if it doesn't exist.
	if !tableExist {
		createQuery := "CREATE TABLE IF NOT EXISTS " + dbTable +
			" (vault_key varbinary(512), vault_value mediumblob, PRIMARY KEY (vault_key))"
		if _, err := db.Exec(createQuery); err != nil {
			return nil, errwrap.Wrapf("failed to create mysql table: {{err}}", err)
		}
	}

	// Default value for ha_enabled
	haEnabledStr, ok := conf["ha_enabled"]
	if !ok {
		haEnabledStr = "false"
	}
	haEnabled, err := strconv.ParseBool(haEnabledStr)
	if err != nil {
		return nil, fmt.Errorf("value [%v] of 'ha_enabled' could not be understood", haEnabledStr)
	}

	locktable, ok := conf["lock_table"]
	if !ok {
		locktable = table + "_lock"
	}
	dbLockTable := "`" + database + "`.`" + locktable + "`"

	// Only create lock table if ha_enabled is true
	if haEnabled {
		lockTableExist, err := mysqlRowExists(db, "SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_NAME = ? AND TABLE_SCHEMA = ?", locktable, database)
		if err != nil {
			return nil, errwrap.Wrapf("failed to check mysql table exist: {{err}}", err)
		}
		if !lockTableExist {
			createQuery := "CREATE TABLE IF NOT EXISTS " + dbLockTable +
				" (node_job varchar(512), current_leader varchar(512), PRIMARY KEY (node_job))"
			if _, err := db.Exec(createQuery); err != nil {
				return nil, errwrap.Wrapf("failed to create mysql table: {{err}}", err)
			}
		}
	}

	// Setup the backend.
	m := &MySQLBackend{
		dbTable:     dbTable,
		dbLockTable: dbLockTable,
		client:      db,
		statements:  make(map[string]*sql.Stmt),
		logger:      logger,
		permitPool:  physical.NewPermitPool(maxParInt),
		conf:        conf,
		haEnabled:   haEnabled,
	}

	// Prepare all the statements required
	statements := map[string]string{
		"put": "INSERT INTO " + dbTable +
			" VALUES( ?, ? ) ON DUPLICATE KEY UPDATE vault_value=VALUES(vault_value)",
		"get":    "SELECT vault_value FROM " + dbTable + " WHERE vault_key = ?",
		"delete": "DELETE FROM " + dbTable + " WHERE vault_key = ?",
		"list":   "SELECT vault_key FROM " + dbTable + " WHERE vault_key LIKE ?",
	}
	// Only prepare ha-related statements if we need them
	if haEnabled {
		statements["get_lock"] = "SELECT current_leader FROM " + dbLockTable + " WHERE node_job = ?"
		statements["used_lock"] = "SELECT IS_USED_LOCK(?)"
	}
	for name, query := range statements {
		if err := m.prepare(name, query); err != nil {
			return nil, err
		}
	}
	return m, nil
}

// mysqlRowExists runs query with args and reports whether it produced at
// least one row. The result set is always closed before returning. An error
// raised while fetching the row is reported rather than treated as
// "no rows".
func mysqlRowExists(db *sql.DB, query string, args ...interface{}) (bool, error) {
	rows, err := db.Query(query, args...)
	if err != nil {
		return false, err
	}
	defer rows.Close()

	found := rows.Next()
	if err := rows.Err(); err != nil {
		return false, err
	}
	return found, nil
}
// NewMySQLClient builds a *sql.DB for the MySQL server described by conf.
//
// Required keys: "username" and "password" (both non-empty).
//
// Optional keys:
//   - "address" (default "127.0.0.1:3306")
//   - "max_idle_connections"
//   - "max_connection_lifetime" (seconds)
//   - "max_parallel": maximum open connections, default
//     physical.DefaultParallelOperations
//   - "tls_ca_file": a CA bundle; registers a TLS config under mysqlTLSKey
//     and adds tls=<key> to the DSN
//
// A zero idle/lifetime value leaves the database/sql default in place.
// sql.Open does not dial, so an unreachable server is only noticed on first
// use.
func NewMySQLClient(conf map[string]string, logger log.Logger) (*sql.DB, error) {
	// A missing key and an empty value are rejected alike.
	username := conf["username"]
	if username == "" {
		return nil, fmt.Errorf("missing username")
	}
	password := conf["password"]
	if password == "" {
		return nil, fmt.Errorf("missing password")
	}

	address := "127.0.0.1:3306"
	if configured, present := conf["address"]; present {
		address = configured
	}

	var idleConns, connLifetimeSecs int

	if raw, present := conf["max_idle_connections"]; present {
		parsed, err := strconv.Atoi(raw)
		if err != nil {
			return nil, errwrap.Wrapf("failed parsing max_idle_connections parameter: {{err}}", err)
		}
		idleConns = parsed
		if logger.IsDebug() {
			logger.Debug("max_idle_connections set", "max_idle_connections", idleConns)
		}
	}

	if raw, present := conf["max_connection_lifetime"]; present {
		parsed, err := strconv.Atoi(raw)
		if err != nil {
			return nil, errwrap.Wrapf("failed parsing max_connection_lifetime parameter: {{err}}", err)
		}
		connLifetimeSecs = parsed
		if logger.IsDebug() {
			logger.Debug("max_connection_lifetime set", "max_connection_lifetime", connLifetimeSecs)
		}
	}

	openConns := physical.DefaultParallelOperations
	if raw, present := conf["max_parallel"]; present {
		parsed, err := strconv.Atoi(raw)
		if err != nil {
			return nil, errwrap.Wrapf("failed parsing max_parallel parameter: {{err}}", err)
		}
		openConns = parsed
		if logger.IsDebug() {
			logger.Debug("max_parallel set", "max_parallel", openConns)
		}
	}

	params := url.Values{}
	if caFile, present := conf["tls_ca_file"]; present {
		if err := setupMySQLTLSConfig(caFile); err != nil {
			return nil, errwrap.Wrapf("failed register TLS config: {{err}}", err)
		}
		params.Add("tls", mysqlTLSKey)
	}

	// Create MySQL handle for the database.
	dsn := username + ":" + password + "@tcp(" + address + ")/?" + params.Encode()
	db, err := sql.Open("mysql", dsn)
	if err != nil {
		return nil, errwrap.Wrapf("failed to connect to mysql: {{err}}", err)
	}

	db.SetMaxOpenConns(openConns)
	if idleConns != 0 {
		db.SetMaxIdleConns(idleConns)
	}
	if connLifetimeSecs != 0 {
		db.SetConnMaxLifetime(time.Duration(connLifetimeSecs) * time.Second)
	}
	return db, nil
}
// prepare compiles query on the backend's pool and stores the statement in
// m.statements under name. A failure is wrapped with the statement name.
func (m *MySQLBackend) prepare(name, query string) error {
	stmt, err := m.client.Prepare(query)
	if err == nil {
		m.statements[name] = stmt
		return nil
	}
	return errwrap.Wrapf(fmt.Sprintf("failed to prepare %q: {{err}}", name), err)
}
// Put stores entry. If a row with the same key already exists, its value is
// overwritten (the "put" statement is INSERT ... ON DUPLICATE KEY UPDATE).
// The call holds one permit-pool slot while it runs. ctx is not used.
func (m *MySQLBackend) Put(ctx context.Context, entry *physical.Entry) error {
	defer metrics.MeasureSince([]string{"mysql", "put"}, time.Now())

	m.permitPool.Acquire()
	defer m.permitPool.Release()

	_, err := m.statements["put"].Exec(entry.Key, entry.Value)
	return err
}
// Get fetches the entry stored under key. A key with no row is reported as
// (nil, nil) rather than as an error. ctx is not used.
func (m *MySQLBackend) Get(ctx context.Context, key string) (*physical.Entry, error) {
	defer metrics.MeasureSince([]string{"mysql", "get"}, time.Now())

	m.permitPool.Acquire()
	defer m.permitPool.Release()

	var value []byte
	switch err := m.statements["get"].QueryRow(key).Scan(&value); {
	case err == sql.ErrNoRows:
		return nil, nil
	case err != nil:
		return nil, err
	}

	return &physical.Entry{Key: key, Value: value}, nil
}
// Delete permanently removes the row for key. The number of affected rows is
// not checked, so deleting a missing key succeeds. ctx is not used.
func (m *MySQLBackend) Delete(ctx context.Context, key string) error {
	defer metrics.MeasureSince([]string{"mysql", "delete"}, time.Now())

	m.permitPool.Acquire()
	defer m.permitPool.Release()

	if _, err := m.statements["delete"].Exec(key); err != nil {
		return err
	}
	return nil
}
// List is used to list all the keys under a given
// prefix, up to the next prefix.
//
// Leaf keys directly under prefix are returned with prefix stripped. Deeper
// keys are collapsed into one "folder/" entry each. The result is sorted.
//
// The query matches vault_key LIKE prefix+"%". Any "%" or "_" inside prefix
// is also a LIKE wildcard, so rows that do not really start with prefix can
// come back. They are filtered out here.
// NOTE(review): a "\" in prefix acts as MySQL's default LIKE escape and can
// hide matching rows; Vault keys are not expected to contain it.
func (m *MySQLBackend) List(ctx context.Context, prefix string) ([]string, error) {
	defer metrics.MeasureSince([]string{"mysql", "list"}, time.Now())

	m.permitPool.Acquire()
	defer m.permitPool.Release()

	// Add the % wildcard to the prefix to do the prefix search
	likePrefix := prefix + "%"
	rows, err := m.statements["list"].Query(likePrefix)
	if err != nil {
		return nil, errwrap.Wrapf("failed to execute statement: {{err}}", err)
	}
	// Return the connection to the pool even when we bail out mid-iteration.
	defer rows.Close()

	var keys []string
	for rows.Next() {
		var key string
		if err := rows.Scan(&key); err != nil {
			return nil, errwrap.Wrapf("failed to scan rows: {{err}}", err)
		}
		if !strings.HasPrefix(key, prefix) {
			// Matched only through a wildcard character inside prefix.
			continue
		}

		rest := key[len(prefix):]
		if i := strings.Index(rest, "/"); i == -1 {
			// Add objects only from the current 'folder'
			keys = append(keys, rest)
		} else {
			// Add truncated 'folder' paths
			keys = strutil.AppendIfMissing(keys, rest[:i+1])
		}
	}
	// A mid-stream failure ends the loop early; don't report a partial list.
	if err := rows.Err(); err != nil {
		return nil, errwrap.Wrapf("failed to iterate rows: {{err}}", err)
	}

	sort.Strings(keys)
	return keys, nil
}
// LockWith returns an HA lock for key that has not yet been acquired. value
// is what this node writes as the current leader once the lock is obtained.
// It never returns an error.
func (m *MySQLBackend) LockWith(key, value string) (physical.Lock, error) {
	return &MySQLHALock{
		in:     m,
		key:    key,
		value:  value,
		logger: m.logger,
	}, nil
}
// HAEnabled reports whether the backend was configured with ha_enabled=true.
// When it was, the lock table and HA statements were prepared.
func (m *MySQLBackend) HAEnabled() bool { return m.haEnabled }
// MySQLHALock is the physical.Lock handed out by MySQLBackend.LockWith.
// The MySQL named lock itself is taken on a separate connection pool held by
// the MySQLLock in lock. Lock and Unlock update held and leaderCh while
// holding localLock.
type MySQLHALock struct {
	in     *MySQLBackend
	logger log.Logger

	key   string
	value string

	localLock sync.Mutex
	held      bool

	leaderCh chan struct{}
	stopCh   <-chan struct{}
	lock     *MySQLLock
}
// Lock acquires the HA lock. It blocks until the background attempt
// succeeds, the attempt fails, or stopCh is closed.
//
// Return values:
//   - On success: a leader channel that monitorLock closes if it later
//     detects the lock was lost.
//   - If this instance already holds the lock, or the attempt fails: an error.
//   - If stopCh fires first: (nil, nil). A lock the background attempt
//     obtains afterwards is then released again.
//
// localLock is held for the whole wait.
func (i *MySQLHALock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) {
	i.localLock.Lock()
	defer i.localLock.Unlock()
	if i.held {
		return nil, fmt.Errorf("lock already held")
	}
	// Attempt an async acquisition
	// failLock and releaseCh have a buffer of one so that neither side blocks
	// on them after the other has moved on (e.g. this function already
	// returned because stopCh fired).
	didLock := make(chan struct{})
	failLock := make(chan error, 1)
	releaseCh := make(chan bool, 1)
	go i.attemptLock(i.key, i.value, didLock, failLock, releaseCh)
	// Wait for lock acquisition, failure, or shutdown
	select {
	case <-didLock:
		// Tell attemptLock to keep the lock.
		releaseCh <- false
	case err := <-failLock:
		return nil, err
	case <-stopCh:
		// Ask attemptLock to release the lock if it does get it.
		// NOTE(review): GET_LOCK waits up to math.MaxInt32 seconds, so the
		// attempt goroutine may stay blocked for a very long time.
		releaseCh <- true
		return nil, nil
	}
	// Create the leader channel
	i.held = true
	i.leaderCh = make(chan struct{})
	go i.monitorLock(i.leaderCh)
	// Stored for reference; not read elsewhere in the visible code.
	i.stopCh = stopCh
	return i.leaderCh, nil
}
// attemptLock runs in its own goroutine, started by Lock. It opens a
// dedicated MySQLLock and blocks in lock.Lock until the named lock is
// obtained, then reports the outcome:
//   - On failure: exactly one error is sent on failLock. Any connection pool
//     opened for the attempt is closed, so an abandoned session cannot keep
//     holding the MySQL lock and block later attempts.
//   - On success: i.lock is set, didLock is closed, and the goroutine waits
//     for Lock's decision on releaseCh, releasing the lock again if asked to.
func (i *MySQLHALock) attemptLock(key, value string, didLock chan struct{}, failLock chan error, releaseCh chan bool) {
	lock, err := NewMySQLLock(i.in, i.logger, key, value)
	if err != nil {
		// lock is nil here; continuing would dereference it.
		failLock <- err
		return
	}

	if err := lock.Lock(); err != nil {
		// lock.Lock can fail after GET_LOCK already succeeded (e.g. recording
		// the leader failed). Closing the pool ends the session and so frees
		// the MySQL lock. Best effort: the acquisition error is what matters.
		lock.Unlock()
		failLock <- err
		return
	}

	// Set node value only once the lock is really held.
	i.lock = lock

	// Signal that lock is held
	close(didLock)

	// Handle an early abort
	if release := <-releaseCh; release {
		lock.Unlock()
	}
}
// monitorLock polls every 5 seconds while this node is leader. It asks
// hasLock whether the MySQL lock is still ours and closes leaderCh as soon
// as hasLock returns an error.
//
// It exits without closing leaderCh once this lock instance is no longer
// held, or leaderCh is no longer the current leader channel (re-locked).
// That way each Lock/Unlock cycle does not leave a polling goroutine behind.
func (i *MySQLHALock) monitorLock(leaderCh chan struct{}) {
	for {
		// Lock and Unlock update held/leaderCh under localLock.
		i.localLock.Lock()
		stillLeader := i.held && i.leaderCh == leaderCh
		i.localLock.Unlock()
		if !stillLeader {
			return
		}

		// The only way to lose this lock is if someone is
		// logging into the DB and altering system tables or you lose a connection in
		// which case you will lose the lock anyway.
		if err := i.hasLock(i.key); err != nil {
			// Somehow we lost the lock.... likely because the connection holding
			// the lock was closed or someone was playing around with the locks in the DB.
			close(leaderCh)
			return
		}

		time.Sleep(5 * time.Second)
	}
}
// Unlock releases the MySQL lock by closing its dedicated connection pool,
// if this instance holds it. Calling it while the lock is not held is a
// no-op. If closing fails, the error is returned and the lock is still
// considered held.
func (i *MySQLHALock) Unlock() error {
	i.localLock.Lock()
	defer i.localLock.Unlock()

	if !i.held {
		return nil
	}
	if err := i.lock.Unlock(); err != nil {
		return err
	}
	i.held = false
	return nil
}
// hasLock will check if a lock is held by checking the current lock id against our known ID.
//
// It runs IS_USED_LOCK(key) on the backend's main pool and returns:
//   - nil when the lock is free (NULL result) or held by the connection
//     recorded in GlobalLockID;
//   - ErrLockHeld when a different connection holds it;
//   - the query error when the check itself fails.
func (i *MySQLHALock) hasLock(key string) error {
	var result sql.NullInt64
	err := i.in.statements["used_lock"].QueryRow(key).Scan(&result)

	// The error must be inspected before result.Valid: a failed Scan leaves
	// result invalid, which would otherwise hide the failure.
	switch {
	case err == sql.ErrNoRows:
		return nil
	case err != nil:
		return err
	case !result.Valid:
		// This is not an error to us since it just means the lock isn't held
		return nil
	case result.Int64 != GlobalLockID:
		// IS_USED_LOCK returns the ID of the connection that holds the lock.
		return ErrLockHeld
	}
	return nil
}
// GetLeader returns the current_leader value stored under node_job "leader"
// in the lock table, i.e. the value written by MySQLLock.becomeLeader.
// It returns sql.ErrNoRows when no leader has been recorded yet. Any other
// query error is also returned instead of being dropped.
func (i *MySQLHALock) GetLeader() (string, error) {
	defer metrics.MeasureSince([]string{"mysql", "lock_get"}, time.Now())

	var result string
	if err := i.in.statements["get_lock"].QueryRow("leader").Scan(&result); err != nil {
		return "", err
	}
	return result, nil
}
// Value reports whether a leader is recorded and, if so, its value. When no
// leader row exists yet, no node holds the lock, so it returns
// (false, "", nil) instead of surfacing sql.ErrNoRows as a failure.
// Other lookup errors are returned unchanged.
func (i *MySQLHALock) Value() (bool, string, error) {
	leaderkey, err := i.GetLeader()
	if err == sql.ErrNoRows {
		return false, "", nil
	}
	if err != nil {
		return false, "", err
	}
	return true, leaderkey, nil
}
// MySQLLock grabs and releases a MySQL named lock with the built-in GET_LOCK
// function. Such locks are released when the owning connection goes away.
// That is why each MySQLLock works on its own connection pool (in), which
// Unlock simply closes.
type MySQLLock struct {
	// parentConn is the backend this lock belongs to; in is the dedicated
	// pool the lock is taken on.
	parentConn *MySQLBackend
	in         *sql.DB

	// statements holds the prepared "put" statement used by becomeLeader.
	statements map[string]*sql.Stmt
	logger     log.Logger

	key   string
	value string
}
// Errors specific to trying to grab a lock in MySQL, plus the connection ID
// of the lock this process last acquired.
var (
	// GlobalLockID is the MySQL connection ID that held the named lock when
	// this process last acquired it (set by MySQLLock.Lock). hasLock compares
	// IS_USED_LOCK against it to detect that another connection took the lock.
	// NOTE(review): unsynchronised package-level state shared by every lock in
	// the process.
	GlobalLockID int64

	// ErrLockHeld is returned when another vault instance already has a lock held for the given key.
	ErrLockHeld = errors.New("mysql: lock already held")

	// ErrUnlockFailed is returned by MySQLLock.Unlock when closing the lock's
	// connection pool fails.
	ErrUnlockFailed = errors.New("mysql: unable to release lock, already released or not held by this session")

	// ErrClaimFailed indicates that recording this node as the new leader in
	// the lock table failed.
	ErrClaimFailed = errors.New("mysql: unable to update DB with new leader information")

	// ErrSettingGlobalID is returned when the lock was obtained but the ID of
	// the connection holding it could not be determined.
	ErrSettingGlobalID = errors.New("mysql: getting global lock id failed")
)
// NewMySQLLock creates a MySQLLock for key on a fresh connection pool built
// from in's configuration. It also prepares the upsert used to record value
// as the current leader in in's lock table.
//
// It returns an error if the pool cannot be created or the statement cannot
// be prepared. In the latter case the new pool is closed again.
func NewMySQLLock(in *MySQLBackend, l log.Logger, key, value string) (*MySQLLock, error) {
	// Create a new MySQL connection so we can close this and have no effect on
	// the rest of the MySQL backend and any cleanup that might need to be done.
	conn, err := NewMySQLClient(in.conf, in.logger)
	if err != nil {
		return nil, err
	}

	m := &MySQLLock{
		parentConn: in,
		in:         conn,
		logger:     l,
		statements: make(map[string]*sql.Stmt),
		key:        key,
		value:      value,
	}

	statements := map[string]string{
		"put": "INSERT INTO " + in.dbLockTable +
			" VALUES( ?, ? ) ON DUPLICATE KEY UPDATE current_leader=VALUES(current_leader)",
	}
	for name, query := range statements {
		if err := m.prepare(name, query); err != nil {
			// Don't leak the dedicated pool; the prepare error is what matters.
			conn.Close()
			return nil, err
		}
	}
	return m, nil
}
// prepare compiles query on the lock's own pool and stores the statement in
// m.statements under name. A failure is wrapped with the statement name.
func (m *MySQLLock) prepare(name, query string) error {
	stmt, err := m.in.Prepare(query)
	if err == nil {
		m.statements[name] = stmt
		return nil
	}
	return errwrap.Wrapf(fmt.Sprintf("failed to prepare %q: {{err}}", name), err)
}
// becomeLeader upserts the row node_job="leader" in the lock table with this
// node's value. Standby nodes read that row (GetLeader) to learn who is
// active.
func (i *MySQLLock) becomeLeader() error {
	_, err := i.statements["put"].Exec("leader", i.value)
	return err
}
// Lock will try to get a lock for an indefinite amount of time
// based on the given key that has been requested.
//
// On success it records this node as leader and stores, in GlobalLockID, the
// ID of the connection that holds the lock. It returns:
//   - ErrLockHeld if GET_LOCK did not return 1;
//   - ErrClaimFailed if recording the leader failed;
//   - ErrSettingGlobalID if the holder's connection ID is unknown;
//   - the query error if the lock query itself fails.
func (i *MySQLLock) Lock() error {
	defer metrics.MeasureSince([]string{"mysql", "get_lock"}, time.Now())

	var lock sql.NullInt64
	var connectionID sql.NullInt64
	// Lock timeout math.MaxInt32 instead of -1 solves compatibility issues with
	// different MySQL flavours i.e. MariaDB.
	//
	// GET_LOCK and IS_USED_LOCK share one statement, so both run on the same
	// connection. QueryRow closes the result set after Scan, which returns that
	// connection to the pool before becomeLeader needs one. Keeping the rows
	// open would pin it and deadlock a single-connection pool (max_parallel=1).
	err := i.in.QueryRow("SELECT GET_LOCK(?, ?), IS_USED_LOCK(?)", i.key, math.MaxInt32, i.key).Scan(&lock, &connectionID)
	if err != nil {
		return err
	}

	// 1 is returned from GET_LOCK if it was able to get the lock
	// 0 if it failed and NULL if some strange error happened.
	// https://dev.mysql.com/doc/refman/8.0/en/miscellaneous-functions.html#function_get-lock
	if !lock.Valid || lock.Int64 != 1 {
		return ErrLockHeld
	}

	// Since we have the lock alert the rest of the cluster
	// that we are now the active leader.
	if err := i.becomeLeader(); err != nil {
		return ErrClaimFailed
	}

	// This will return the connection ID of NULL if an error happens
	// https://dev.mysql.com/doc/refman/8.0/en/miscellaneous-functions.html#function_is-used-lock
	if !connectionID.Valid {
		return ErrSettingGlobalID
	}

	GlobalLockID = connectionID.Int64
	return nil
}
// Unlock releases the named lock by closing this lock's entire connection
// pool. MySQL frees a GET_LOCK lock when the owning session ends, whereas an
// explicit release must come from the exact connection that took the lock,
// and that connection is hard to target through database/sql. Closing the
// pool is reliable and cannot hang on a release.
// A close failure is reported as ErrUnlockFailed.
func (i *MySQLLock) Unlock() error {
	if err := i.in.Close(); err != nil {
		return ErrUnlockFailed
	}
	return nil
}
// setupMySQLTLSConfig loads the PEM-encoded CA certificates from tlsCaFile.
// It registers a tls.Config trusting them with the MySQL driver under
// mysqlTLSKey, so a DSN carrying tls=<mysqlTLSKey> uses it, e.g.
//
//	foo:bar@tcp(127.0.0.1:3306)/dbname?tls=default
//
// It fails if the file cannot be read, contains no usable certificate, or
// the driver rejects the registration.
func setupMySQLTLSConfig(tlsCaFile string) error {
	pem, err := ioutil.ReadFile(tlsCaFile)
	if err != nil {
		return err
	}

	rootCertPool := x509.NewCertPool()
	if !rootCertPool.AppendCertsFromPEM(pem) {
		// err is nil at this point; returning it would report success and
		// register a config that trusts no CA at all.
		return fmt.Errorf("no valid CA certificates found in %q", tlsCaFile)
	}

	return mysql.RegisterTLSConfig(mysqlTLSKey, &tls.Config{
		RootCAs: rootCertPool,
	})
}