/
db.go
256 lines (186 loc) · 5.45 KB
/
db.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
package authdb
import (
"context"
"encoding/json"
"time"
"github.com/bsm/redislock"
"github.com/go-redis/redis"
"github.com/opentracing/opentracing-go"
"github.com/pkg/errors"
)
// UserEntry is a single row in the database.
// It is encoded with encoding/json (default field names, no struct tags)
// when it is stored under the user's ID key.
type UserEntry struct {
// ID is some uniquely generated identifier attached to the user
ID string
// Email is the user's email address
Email string
// PasswordHashWithSalt is the hashed/salted password that will be used
// for comparison/validation
PasswordHashWithSalt string
}
// Db is a persistent database that stores user information.
// Every method takes the caller's context as its first argument.
type Db interface {
// Connect establishes the connection to the database and pings it
// to confirm it is reachable.
Connect(ctx context.Context) error
// Ping pings the database to check connectivity.
// Returns an error if the database has not been connected yet.
Ping(ctx context.Context) error
// CreateUser creates a user in the database.
// The entry's Email must be non-empty.
CreateUser(ctx context.Context, entry UserEntry) error
// GetIDByEmail returns the user's ID or empty string if not found.
// Returns an error if something unexpected goes wrong.
GetIDByEmail(ctx context.Context, email string) (string, error)
// GetUserByID returns the user's entry or nil if not found.
// Returns an error if something unexpected goes wrong.
GetUserByID(ctx context.Context, id string) (*UserEntry, error)
// GetSharedValue returns the value stored under key; if none exists,
// it atomically stores ifNotExist and returns that instead.
GetSharedValue(ctx context.Context, key string, ifNotExist string) (string, error)
// WaitForCreateUser blocks until the user with the given ID is seen
// being created in the database, or returns an error if the context
// finishes before that happens.
WaitForCreateUser(ctx context.Context, id string) error
}
// db is the Redis-backed implementation of Db.
type db struct {
// db is the Redis client; it is nil until Connect has been called.
db *redis.Client
// opts holds the connection settings supplied to New.
opts ConnectionOptions
}
// ConnectionOptions holds the settings used to reach the Redis server.
type ConnectionOptions struct {
// Address is handed to the Redis client as its server address (redis.Options.Addr).
Address string
}
// New returns a Db configured with opts. No connection is opened here:
// the underlying client stays nil until Connect is called.
func New(opts ConnectionOptions) Db {
	store := new(db)
	store.opts = opts
	return store
}
// Connect creates the Redis client for the configured address (empty
// password, database 0) and then pings the server to confirm that it is
// reachable. The ping result is returned as-is.
func (d *db) Connect(ctx context.Context) error {
	options := redis.Options{
		Addr:     d.opts.Address,
		Password: "",
		DB:       0,
	}
	d.db = redis.NewClient(&options)
	return d.Ping(ctx)
}
// startSpan starts a tracing span named operationName from ctx and tags it
// with the tags shared by the Redis operations in this package. It returns
// the span together with the context that carries it; the caller is
// responsible for finishing the span.
func startSpan(ctx context.Context, operationName string) (opentracing.Span, context.Context) {
	span, spanCtx := opentracing.StartSpanFromContext(ctx, operationName)
	for _, tag := range [...]struct{ key, value string }{
		{"db.type", "redis"},
		{"span.kind", "client"},
		{"component", "authdb"},
	} {
		span.SetTag(tag.key, tag.value)
	}
	return span, spanCtx
}
// Ping sends a PING to Redis inside a tracing span and returns the command's
// error, if any. If Connect has not been called yet (the client is nil) it
// returns a "db not connected" error instead.
func (d *db) Ping(ctx context.Context) error {
	span, _ := startSpan(ctx, "Redis Ping")
	defer span.Finish()
	span.SetTag("db.statement", "PING")

	if client := d.db; client != nil {
		return client.Ping().Err()
	}
	return errors.New("db not connected")
}
// CreateUser stores entry in Redis under two keys: the email key maps to the
// user's ID, and the ID key holds the JSON-encoded entry.
//
// Both keys are queued in one MULTI/EXEC transaction so that a failure cannot
// leave an email key pointing at an ID that has no stored entry.
//
// It returns an error if entry.Email or entry.ID is empty, if the database
// has not been connected, if the entry cannot be marshalled to JSON, or if
// the transaction fails. Every failure is also tagged on the tracing span.
func (d *db) CreateUser(ctx context.Context, entry UserEntry) error {
	span, _ := startSpan(ctx, "Redis CreateUser")
	defer span.Finish()

	// fail records err on the span before handing it back to the caller.
	fail := func(err error) error {
		span.SetTag("error", true)
		span.SetTag("error.object", err)
		return err
	}

	if entry.Email == "" {
		return fail(errors.New("must supply email"))
	}
	// An empty ID would make the email lookup indistinguishable from
	// "not found", because GetIDByEmail reports a missing user as "".
	if entry.ID == "" {
		return fail(errors.New("must supply ID"))
	}
	if d.db == nil {
		return fail(errors.New("db not connected"))
	}

	val, err := json.Marshal(entry)
	if err != nil {
		return fail(errors.Wrap(err, "failed to marshal to JSON"))
	}

	// The commands are only queued inside the callback; their errors are
	// reported by TxPipelined once the transaction has been executed.
	_, err = d.db.TxPipelined(func(pipe redis.Pipeliner) error {
		pipe.Set(keyEmail(entry.Email), entry.ID, 0)
		pipe.Set(keyID(entry.ID), val, 0)
		return nil
	})
	if err != nil {
		return fail(errors.Wrap(err, "failed to write user keys"))
	}
	return nil
}
// WaitForCreateUser blocks until the ID key for the given user is present in
// Redis, or until ctx finishes.
//
// It subscribes to keyspace notifications for the user's ID key, waits for
// the first reply on the subscription, and then checks whether the key
// already exists. That check closes the window in which the user is created
// before the subscription is active, which would otherwise leave the call
// blocked until ctx ends. If the key is not there yet, it waits for the first
// notification on that key.
//
// NOTE(review): this relies on the Redis server having keyspace
// notifications enabled; nothing in this file configures that. The ID is
// embedded in a PSubscribe pattern, so glob characters in id would widen the
// match.
//
// It returns nil once the user is seen, and an error if the database has not
// been connected, if subscribing or the existence check fails, or if ctx
// finishes first (the returned error then wraps ctx.Err()).
func (d *db) WaitForCreateUser(ctx context.Context, id string) error {
	span, _ := opentracing.StartSpanFromContext(ctx, "Wait for user creation")
	defer span.Finish()

	// fail records err on the span before handing it back to the caller.
	fail := func(err error) error {
		span.SetTag("error", true)
		span.SetTag("error.object", err)
		return err
	}

	if d.db == nil {
		return fail(errors.New("db not connected"))
	}

	key := keyID(id)
	ps := d.db.PSubscribe("__keyspace@*__:" + key)
	defer ps.Close()

	// Wait for the first reply on the subscription before checking the key,
	// so that a creation cannot slip in between the check and the subscribe.
	if _, err := ps.Receive(); err != nil {
		return fail(err)
	}

	// The user may have been created before we subscribed; in that case no
	// notification will ever arrive, so look for the key directly.
	found, err := d.db.Exists(key).Result()
	if err != nil {
		return fail(errors.Wrap(err, "failed to check for existing user"))
	}
	if found > 0 {
		return nil
	}

	select {
	case <-ps.Channel():
		return nil
	case <-ctx.Done():
		return fail(errors.Wrap(ctx.Err(), "context finished before message received"))
	}
}
// GetIDByEmail looks up the ID stored under the email key for the given
// address. A missing key is not an error: it yields ("", nil). Any other
// Redis failure is wrapped, tagged on the tracing span and returned.
func (d *db) GetIDByEmail(ctx context.Context, email string) (string, error) {
	span, _ := opentracing.StartSpanFromContext(ctx, "Get ID by email")
	defer span.Finish()

	id, err := d.db.Get(keyEmail(email)).Result()
	switch {
	case err == nil:
		return id, nil
	case err == redis.Nil:
		// redis.Nil means the key does not exist.
		return "", nil
	}

	wrapped := errors.Wrap(err, "failed to get key")
	span.SetTag("error", true)
	span.SetTag("error.object", wrapped)
	return "", wrapped
}
// GetUserByID fetches the JSON stored under the ID key for id and decodes it
// into a UserEntry. A missing key yields (nil, nil). A Redis failure or a
// value that cannot be decoded is wrapped, tagged on the tracing span and
// returned with a nil entry.
func (d *db) GetUserByID(ctx context.Context, id string) (*UserEntry, error) {
	span, _ := opentracing.StartSpanFromContext(ctx, "Get user by ID")
	defer span.Finish()

	// fail wraps cause with msg, records it on the span and returns it.
	fail := func(cause error, msg string) (*UserEntry, error) {
		wrapped := errors.Wrap(cause, msg)
		span.SetTag("error", true)
		span.SetTag("error.object", wrapped)
		return nil, wrapped
	}

	raw, err := d.db.Get(keyID(id)).Result()
	if err == redis.Nil {
		// redis.Nil means the key does not exist.
		return nil, nil
	}
	if err != nil {
		return fail(err, "failed to get key")
	}

	var entry UserEntry
	if err := json.Unmarshal([]byte(raw), &entry); err != nil {
		return fail(err, "found key but could not unmarshal json")
	}
	return &entry, nil
}
// keyEmail returns the Redis key under which the user ID for the given
// email address is stored.
func keyEmail(email string) string {
	const prefix = "email:"
	return prefix + email
}
// keyID returns the Redis key under which the JSON-encoded entry for the
// user with the given ID is stored.
func keyID(id string) string {
	return "id:" + id
}
// GetSharedValue returns the value stored under key. If no value is stored
// yet, ifNotExist is written first (with SETNX, so an existing value is never
// overwritten) and then returned.
//
// The write and the read happen while holding a short-lived lock on
// key+".lock" (50ms TTL). The lock is released on return; a failed release is
// ignored.
//
// NOTE(review): the lock is requested with nil options. If that means "no
// retry", a caller that loses the race for the lock gets an error rather
// than the shared value — confirm against the redislock documentation.
//
// It returns an error if the database has not been connected, if the lock
// cannot be obtained, or if either Redis command fails. Every failure is
// also tagged on the tracing span.
func (d *db) GetSharedValue(ctx context.Context, key string, ifNotExist string) (string, error) {
	span, _ := opentracing.StartSpanFromContext(ctx, "Get shared value")
	span.SetTag("auth.sharedvalue.key", key)
	defer span.Finish()

	// fail records err on the span before handing it back to the caller.
	fail := func(err error) (string, error) {
		span.SetTag("error", true)
		span.SetTag("error.object", err)
		return "", err
	}

	if d.db == nil {
		return fail(errors.New("db not connected"))
	}

	locker := redislock.New(d.db)
	lock, err := locker.Obtain(key+".lock", 50*time.Millisecond, nil)
	if err != nil {
		return fail(err)
	}
	defer lock.Release()

	// Previously this error was dropped, so a failed write could only show
	// up indirectly through the following read.
	if err := d.db.SetNX(key, ifNotExist, 0).Err(); err != nil {
		return fail(errors.Wrap(err, "failed to set shared value"))
	}

	value, err := d.db.Get(key).Result()
	if err != nil {
		return fail(err)
	}
	return value, nil
}