/
locking_strategies.go
166 lines (144 loc) · 3.76 KB
/
locking_strategies.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
package orm
import (
"context"
"database/sql"
"fmt"
"sync"
"time"
"github.com/SeerLink/seerlink/core/store/dialects"
"github.com/pkg/errors"
"github.com/SeerLink/seerlink/core/logger"
"github.com/SeerLink/seerlink/core/store/models"
"go.uber.org/multierr"
)
// NewLockingStrategy picks the LockingStrategy that matches the dialect named
// by ct, so that callers get exclusive access to the orm. Only the Postgres
// family of dialects (Postgres, PostgresWithoutLock and
// TransactionWrappedPostgres) is supported; any other name yields an error.
func NewLockingStrategy(ct Connection) (LockingStrategy, error) {
	supported := ct.name == dialects.Postgres ||
		ct.name == dialects.PostgresWithoutLock ||
		ct.name == dialects.TransactionWrappedPostgres
	if !supported {
		return nil, fmt.Errorf("unable to create locking strategy for dialect %s and path %s", ct.dialect, ct.uri)
	}
	return NewPostgresLockingStrategy(ct)
}
// LockingStrategy abstracts taking and releasing exclusive access to some
// shared underlying resource, typically a file or a database.
type LockingStrategy interface {
	// Lock obtains exclusive access; timeout bounds how long an
	// implementation may wait for it.
	Lock(timeout models.Duration) error
	// Unlock gives up the exclusive access obtained by Lock.
	Unlock(timeout models.Duration) error
}
// PostgresLockingStrategy guarantees exclusive access by holding a Postgres
// advisory lock on one dedicated database connection.
type PostgresLockingStrategy struct {
	// db is the pool opened lazily by Lock and closed by Unlock.
	db *sql.DB
	// conn is the single connection pinned out of db. Advisory locks are
	// tied to a connection, so every lock query must go through it.
	conn *sql.Conn
	// m serializes Lock and Unlock against each other.
	m *sync.Mutex
	// config carries the connection URI plus the locking settings
	// (whether to lock, the advisory lock ID and the retry interval).
	config Connection
}
// NewPostgresLockingStrategy constructs a PostgresLockingStrategy for ct.
// No database connection is opened here; Lock does that lazily on first
// use. The returned error is always nil.
func NewPostgresLockingStrategy(ct Connection) (LockingStrategy, error) {
	strategy := &PostgresLockingStrategy{m: new(sync.Mutex)}
	strategy.config = ct
	return strategy, nil
}
// Lock takes the Postgres advisory lock, waiting until it is acquired or
// until timeout elapses. When timeout.IsInstant() reports true no deadline
// is applied, so the wait is bounded only by the lock becoming available.
//
// On first use Lock opens a connection pool from s.config.uri and pins one
// connection from it; that connection is reused by later calls until Unlock
// releases it. If s.config.locking is false, only the connection is
// established and no advisory lock is taken.
//
// Errors from opening the pool or pinning the connection are returned as-is.
// A failure while waiting for the lock is reported as ErrNoAdvisoryLock,
// with the underlying cause included in the message.
func (s *PostgresLockingStrategy) Lock(timeout models.Duration) error {
	s.m.Lock()
	defer s.m.Unlock()

	ctx := context.Background()
	if !timeout.IsInstant() {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, timeout.Duration())
		defer cancel()
	}

	if s.conn == nil {
		db, err := sql.Open(string(dialects.Postgres), s.config.uri)
		if err != nil {
			return err
		}
		// `database/sql`.DB does opaque connection pooling, but PG advisory
		// locks are per-connection, so pin a single connection for the lock.
		conn, err := db.Conn(ctx)
		if err != nil {
			// Close the pool we just opened. Otherwise it leaks: Unlock skips
			// cleanup while s.conn is nil, and the next Lock would open yet
			// another pool.
			return multierr.Combine(err, db.Close())
		}
		// Only publish the pair once both are valid, so s.db and s.conn are
		// always set (or unset) together.
		s.db = db
		s.conn = conn
	}

	if s.config.locking {
		if err := s.waitForLock(ctx); err != nil {
			return errors.Wrapf(ErrNoAdvisoryLock, "postgres advisory locking strategy failed on .Lock, timeout set to %v: %v, lock ID: %v", displayTimeout(timeout), err, s.config.advisoryLockID)
		}
	}
	return nil
}
// waitForLock repeatedly runs pg_try_advisory_lock(s.config.advisoryLockID)
// on the pinned connection until the lock is obtained or ctx is done.
// Between attempts it waits for one tick of s.config.lockRetryInterval and
// reports progress through logRetry. Note that time.NewTicker panics if
// lockRetryInterval is not positive.
//
// It returns nil once the lock is held, and the database error if a query
// or scan fails. If ctx ends while waiting between attempts, it returns
// ctx.Err() wrapped as "timeout expired while waiting for lock".
func (s *PostgresLockingStrategy) waitForLock(ctx context.Context) error {
	ticker := time.NewTicker(s.config.lockRetryInterval)
	defer ticker.Stop()

	retryCount := 0
	for {
		// pg_try_advisory_lock returns a single boolean row without waiting.
		// Row.Scan closes the result set and surfaces any query or iteration
		// error. The previous rows loop never checked rows.Err(), so such
		// errors looked like "lock not acquired".
		var gotLock bool
		row := s.conn.QueryRowContext(ctx, "SELECT pg_try_advisory_lock($1)", s.config.advisoryLockID)
		if err := row.Scan(&gotLock); err != nil {
			return err
		}
		if gotLock {
			return nil
		}

		select {
		case <-ticker.C:
			retryCount++
			logRetry(retryCount)
		case <-ctx.Done():
			return errors.Wrap(ctx.Err(), "timeout expired while waiting for lock")
		}
	}
}
// logRetry reports a failed attempt to take the lock, throttled so long
// waits do not flood the log. For positive counts it logs at count 1, at
// every power of two (2, 4, 8, 16, 32, ...), and at every multiple of 1000.
func logRetry(count int) {
	switch {
	case count == 1:
		logger.Infow("Could not get lock, retrying...", "failCount", count)
	case count%1000 == 0, count&(count-1) == 0:
		// count&(count-1) clears the lowest set bit; the result is zero
		// exactly when a positive count has a single bit set (a power of two).
		logger.Infow("Still waiting for lock...", "failCount", count)
	}
}
// Unlock releases the advisory lock by closing the pinned connection and
// then the pool it came from. It does not call pg_advisory_unlock; it relies
// on Postgres dropping session-level advisory locks when the session ends.
// The timeout argument is ignored.
//
// If Lock never established a connection, or Unlock already ran, it does
// nothing and returns nil. sql.ErrConnDone from either Close is not treated
// as a failure. Any other close errors are combined into the returned error.
// Afterwards s.conn and s.db are both reset to nil.
func (s *PostgresLockingStrategy) Unlock(timeout models.Duration) error {
	s.m.Lock()
	defer s.m.Unlock()

	if s.conn == nil {
		return nil
	}

	ignoreConnDone := func(err error) error {
		if err == sql.ErrConnDone {
			return nil
		}
		return err
	}

	// Close order matters: the connection goes back to the pool first, and
	// then closing the pool tears down its connections.
	closeErrs := []error{
		ignoreConnDone(s.conn.Close()),
		ignoreConnDone(s.db.Close()),
	}
	s.db = nil
	s.conn = nil
	return multierr.Combine(closeErrs...)
}