-
Notifications
You must be signed in to change notification settings - Fork 43
/
auth.go
434 lines (358 loc) · 11.4 KB
/
auth.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
package jwt
import (
"errors"
"io/ioutil"
"net/http"
"reflect"
"strconv"
"time"
jwtGo "github.com/dgrijalva/jwt-go"
)
// Auth is a middleware that provides jwt based authentication.
// Initialise it with New; the zero value has no keys, options or handlers set.
type Auth struct {
// signKey is the key used to sign tokens. The key builders leave it nil when
// Options.VerifyOnlyServer is set.
signKey interface{}
// verifyKey is the key built from HMACKey / PublicKeyLocation; presumably used
// to verify incoming tokens (its use is outside this file).
verifyKey interface{}
// options is the configuration passed to New, with defaults filled in.
options Options
// Handlers for when an error occurs
// errorHandler answers requests whose processing failed with a non-4xx error;
// unauthorizedHandler answers those whose jwtError carries a 4xx Type.
errorHandler http.Handler
unauthorizedHandler http.Handler
// funcs for checking and revoking refresh tokens
// revokeRefreshToken is called by NullifyTokens with the refresh token's
// StandardClaims.Id; checkTokenId is not called in this file.
revokeRefreshToken TokenRevoker
checkTokenId TokenIdChecker
}
// Options configures an Auth instance and is passed to New. New replaces
// zero/negative durations and empty token names with package defaults.
type Options struct {
// SigningMethodString selects the algorithm: "HS256", "HS384", "HS512",
// "RS256", "RS384", "RS512", "ES256", "ES384" or "ES512". Any other value
// makes New fail.
SigningMethodString string
// PrivateKeyLocation is the path of the PEM-encoded private (signing) key for
// RS*/ES* methods. Required unless VerifyOnlyServer is set.
PrivateKeyLocation string
// PublicKeyLocation is the path of the PEM-encoded public key for RS*/ES*
// methods. Always required for those methods.
PublicKeyLocation string
// HMACKey is the shared secret for HS* methods; it must be non-empty.
HMACKey []byte
// VerifyOnlyServer disables token issuance: no signing key is loaded,
// IssueNewTokens returns an error, and Process does not write credentials
// back to the response.
VerifyOnlyServer bool
// BearerTokens carries tokens in headers instead of cookies. It also changes
// the defaults for AuthTokenName and RefreshTokenName.
BearerTokens bool
// RefreshTokenValidTime is the refresh-token validity duration; values <= 0
// are replaced by defaultRefreshTokenValidTime.
RefreshTokenValidTime time.Duration
// AuthTokenValidTime is the auth-token validity duration; values <= 0 are
// replaced by defaultAuthTokenValidTime.
AuthTokenValidTime time.Duration
// AuthTokenName is the header/cookie name of the auth token. Defaults to
// "X-Auth-Token" with BearerTokens, "AuthToken" otherwise.
AuthTokenName string
// RefreshTokenName is the header/cookie name of the refresh token. Defaults to
// "X-Refresh-Token" with BearerTokens, "RefreshToken" otherwise.
RefreshTokenName string
// CSRFTokenName is the response header name for the CSRF token
// (default "X-CSRF-Token").
CSRFTokenName string
// Debug presumably enables the package's debug logging (myLog) — TODO confirm.
Debug bool
// IsDevEnv relaxes cookie security for development; e.g. NullifyTokens sets
// Secure: !IsDevEnv on the cookies it writes.
IsDevEnv bool
}
// Default token lifetimes. New substitutes these when the corresponding
// Options duration is zero or negative.
const (
	defaultRefreshTokenValidTime = 3 * 24 * time.Hour // 72h
	defaultAuthTokenValidTime    = 15 * time.Minute
)

// Default token names. The Bearer* names are used when Options.BearerTokens is
// set, the Cookie* names otherwise; the CSRF name applies in both modes.
const (
	defaultBearerAuthTokenName    = "X-Auth-Token"
	defaultBearerRefreshTokenName = "X-Refresh-Token"
	defaultCookieAuthTokenName    = "AuthToken"
	defaultCookieRefreshTokenName = "RefreshToken"
	defaultCSRFTokenName          = "X-CSRF-Token"
)
// ClaimsType holds the claims encoded in the jwt.
type ClaimsType struct {
// Standard claims are the standard jwt claims from the ietf standard
// https://tools.ietf.org/html/rfc7519
jwtGo.StandardClaims
// Csrf presumably carries the CSRF token bound to this jwt — TODO confirm
// against the credential-building code.
Csrf string
// CustomClaims holds application-defined claims.
CustomClaims map[string]interface{}
}
// defaultTokenRevoker is the TokenRevoker installed by New. It keeps no
// revocation state: the id is ignored and the call always succeeds.
func defaultTokenRevoker(tokenId string) error {
	var noErr error
	return noErr
}
// TokenRevoker revokes the token identified by tokenId. Auth calls it from
// NullifyTokens with the refresh token's StandardClaims.Id. The default
// installed by New does nothing and returns nil.
type TokenRevoker func(tokenId string) error
// defaultCheckTokenId is the TokenIdChecker installed by New. With no
// revocation store to consult, every token id is reported as valid
// (i.e. not revoked).
func defaultCheckTokenId(tokenId string) bool {
	const notRevoked = true
	return notRevoked
}
// TokenIdChecker reports whether a token id is still valid: true if it has not
// been revoked, false otherwise. The default installed by New accepts every id.
type TokenIdChecker func(tokenId string) bool
// defaultErrorHandler is the error handler installed by New. It answers the
// request with a plain-text 500 Internal Server Error.
func defaultErrorHandler(w http.ResponseWriter, r *http.Request) {
	http.Error(w, "Internal Server Error", http.StatusInternalServerError)
}
// defaultUnauthorizedHandler is the 401 handler installed by New. It answers
// the request with a plain-text 401 Unauthorized.
func defaultUnauthorizedHandler(w http.ResponseWriter, r *http.Request) {
	http.Error(w, "Unauthorized", http.StatusUnauthorized)
}
// New initialises auth with the supplied options.
//
// Zero or negative token lifetimes are replaced by defaultRefreshTokenValidTime
// and defaultAuthTokenValidTime. Empty token names fall back to the Bearer*
// defaults when o.BearerTokens is set and to the Cookie* defaults otherwise;
// an empty CSRFTokenName becomes defaultCSRFTokenName. The signing and
// verification keys are then built for o.SigningMethodString, and the default
// error, unauthorized, revocation and token-id-check hooks are installed
// (override them with the Set* methods).
//
// auth must be non-nil. It is only modified when New returns nil; a
// key-building error is returned unchanged.
func New(auth *Auth, o Options) error {
	if auth == nil {
		return errors.New("New: auth must be a non-nil *Auth")
	}

	// o is a copy, so filling in defaults never touches the caller's Options.
	if o.RefreshTokenValidTime <= 0 {
		o.RefreshTokenValidTime = defaultRefreshTokenValidTime
	}
	if o.AuthTokenValidTime <= 0 {
		o.AuthTokenValidTime = defaultAuthTokenValidTime
	}

	if o.BearerTokens {
		if o.AuthTokenName == "" {
			o.AuthTokenName = defaultBearerAuthTokenName
		}
		if o.RefreshTokenName == "" {
			o.RefreshTokenName = defaultBearerRefreshTokenName
		}
	} else {
		if o.AuthTokenName == "" {
			o.AuthTokenName = defaultCookieAuthTokenName
		}
		if o.RefreshTokenName == "" {
			o.RefreshTokenName = defaultCookieRefreshTokenName
		}
	}

	if o.CSRFTokenName == "" {
		o.CSRFTokenName = defaultCSRFTokenName
	}

	// Build the keys before touching auth so a failure leaves it unchanged.
	signKey, verifyKey, err := o.buildSignAndVerifyKeys()
	if err != nil {
		return err
	}

	auth.signKey = signKey
	auth.verifyKey = verifyKey
	auth.options = o
	auth.errorHandler = http.HandlerFunc(defaultErrorHandler)
	auth.unauthorizedHandler = http.HandlerFunc(defaultUnauthorizedHandler)
	auth.revokeRefreshToken = defaultTokenRevoker
	auth.checkTokenId = defaultCheckTokenId
	return nil
}
// buildSignAndVerifyKeys dispatches on SigningMethodString to the HMAC, RSA or
// ECDSA key builder and returns its results. Unknown method names yield nil
// keys and an error.
func (o *Options) buildSignAndVerifyKeys() (signKey interface{}, verifyKey interface{}, err error) {
	switch o.SigningMethodString {
	case "HS256", "HS384", "HS512":
		return o.buildHMACKeys()
	case "RS256", "RS384", "RS512":
		return o.buildRSAKeys()
	case "ES256", "ES384", "ES512":
		return o.buildESKeys()
	default:
		return nil, nil, errors.New("Signing method string not recognized!")
	}
}
// buildHMACKeys uses HMACKey as both the signing and the verification key.
// A verify-only server gets no signing key. An empty HMACKey is an error.
func (o *Options) buildHMACKeys() (signKey interface{}, verifyKey interface{}, err error) {
	secret := o.HMACKey
	if len(secret) == 0 {
		return nil, nil, errors.New("When using an HMAC-SHA signing method, please provide an HMACKey")
	}
	if o.VerifyOnlyServer {
		// Verification-only servers never hold the signing key.
		return nil, secret, nil
	}
	return secret, secret, nil
}
// buildRSAKeys loads the PEM-encoded RSA key pair named by PrivateKeyLocation
// and PublicKeyLocation.
//
// The private (signing) key is only required and loaded when VerifyOnlyServer
// is false; otherwise signKey is nil. On any error both keys are returned as
// nil, so callers never receive a half-built pair or a typed-nil key wrapped
// in a non-nil interface.
func (o *Options) buildRSAKeys() (signKey interface{}, verifyKey interface{}, err error) {
	// check to make sure the provided options are valid
	if o.PrivateKeyLocation == "" && !o.VerifyOnlyServer {
		return nil, nil, errors.New("Private key location is required!")
	}
	if o.PublicKeyLocation == "" {
		return nil, nil, errors.New("Public key location is required!")
	}

	// The signing key is only needed by servers that issue tokens.
	if !o.VerifyOnlyServer {
		signBytes, readErr := ioutil.ReadFile(o.PrivateKeyLocation)
		if readErr != nil {
			return nil, nil, readErr
		}
		privateKey, parseErr := jwtGo.ParseRSAPrivateKeyFromPEM(signBytes)
		if parseErr != nil {
			return nil, nil, parseErr
		}
		signKey = privateKey
	}

	verifyBytes, readErr := ioutil.ReadFile(o.PublicKeyLocation)
	if readErr != nil {
		return nil, nil, readErr
	}
	publicKey, parseErr := jwtGo.ParseRSAPublicKeyFromPEM(verifyBytes)
	if parseErr != nil {
		return nil, nil, parseErr
	}
	return signKey, publicKey, nil
}
// buildESKeys loads the PEM-encoded ECDSA key pair named by PrivateKeyLocation
// and PublicKeyLocation.
//
// The private (signing) key is only required and loaded when VerifyOnlyServer
// is false; otherwise signKey is nil. On any error both keys are returned as
// nil, so callers never receive a half-built pair or a typed-nil key wrapped
// in a non-nil interface.
func (o *Options) buildESKeys() (signKey interface{}, verifyKey interface{}, err error) {
	// check to make sure the provided options are valid
	if o.PrivateKeyLocation == "" && !o.VerifyOnlyServer {
		return nil, nil, errors.New("Private key location is required!")
	}
	if o.PublicKeyLocation == "" {
		return nil, nil, errors.New("Public key location is required!")
	}

	// The signing key is only needed by servers that issue tokens.
	if !o.VerifyOnlyServer {
		signBytes, readErr := ioutil.ReadFile(o.PrivateKeyLocation)
		if readErr != nil {
			return nil, nil, readErr
		}
		privateKey, parseErr := jwtGo.ParseECPrivateKeyFromPEM(signBytes)
		if parseErr != nil {
			return nil, nil, parseErr
		}
		signKey = privateKey
	}

	verifyBytes, readErr := ioutil.ReadFile(o.PublicKeyLocation)
	if readErr != nil {
		return nil, nil, readErr
	}
	publicKey, parseErr := jwtGo.ParseECPublicKeyFromPEM(verifyBytes)
	if parseErr != nil {
		return nil, nil, parseErr
	}
	return signKey, publicKey, nil
}
// SetErrorHandler replaces the handler that Handler and HandlerFuncWithNext
// use to answer requests whose processing failed with an error that is not
// classified as 4xx. New installs a plain 500 responder by default.
func (a *Auth) SetErrorHandler(handler http.Handler) {
	a.errorHandler = handler
}
// SetUnauthorizedHandler replaces the handler used for processing errors
// classified as 4xx. New installs a plain 401 responder by default.
func (a *Auth) SetUnauthorizedHandler(handler http.Handler) {
	a.unauthorizedHandler = handler
}
// SetRevokeTokenFunction replaces the TokenRevoker that NullifyTokens calls
// with the refresh token's id. The default installed by New is a no-op.
func (a *Auth) SetRevokeTokenFunction(revoker TokenRevoker) {
	a.revokeRefreshToken = revoker
}
// SetCheckTokenIdFunction replaces the TokenIdChecker used to decide whether a
// token id is still valid. The default installed by New accepts every id.
func (a *Auth) SetCheckTokenIdFunction(checker TokenIdChecker) {
	a.checkTokenId = checker
}
// Handler wraps h with JWT processing for use with the standard net/http lib.
//
// Every request first goes through Process. If that fails, the client's tokens
// are nullified (best effort) and the request is answered by the unauthorized
// handler for 4xx-class errors or the error handler otherwise; h is not called.
// On success h is called. A nil h is allowed (HandlerFunc(nil) forwards one
// here), in which case the wrapper only performs the token checks.
func (a *Auth) Handler(h http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Process the request. If it returns an error,
		// that indicates the request should not continue.
		jwtErr := a.Process(w, r)
		if jwtErr != nil {
			a.myLog("Error processing jwts\n" + jwtErr.Error())
			// Best effort: the response below already reports the failure,
			// so an error while clearing tokens is deliberately ignored.
			_ = a.NullifyTokens(w, r)

			// jwtErr is statically a *jwtError, so the reflect comparison is
			// always true; routing is decided by the status class in Type
			// (which appears to hold an HTTP status code — Process passes 500).
			// Drop the reflect check together with the reflect import.
			var j jwtError
			if reflect.TypeOf(jwtErr) == reflect.TypeOf(&j) && jwtErr.Type/100 == 4 {
				a.unauthorizedHandler.ServeHTTP(w, r)
				return
			}
			a.errorHandler.ServeHTTP(w, r)
			return
		}

		// Calling ServeHTTP on a nil interface would panic.
		if h != nil {
			h.ServeHTTP(w, r)
		}
	})
}
// HandlerFunc is the http.HandlerFunc counterpart of Handler. A nil fn is
// passed on as a nil http.Handler rather than as a non-nil interface that
// wraps a nil func.
func (a *Auth) HandlerFunc(fn http.HandlerFunc) http.Handler {
	var next http.Handler
	if fn != nil {
		next = fn
	}
	return a.Handler(next)
}
// HandlerFuncWithNext is a special implementation for Negroni, but could be
// used elsewhere.
//
// next is called only when Process succeeds (and only if next is non-nil).
// If Process fails, the client's tokens are nullified (best effort) and the
// request is answered by the unauthorized handler for 4xx-class errors or the
// error handler otherwise.
func (a *Auth) HandlerFuncWithNext(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
	jwtErr := a.Process(w, r)
	if jwtErr == nil {
		// Success with no next handler is still success: do not fall into the
		// error path (which would dereference a nil jwtErr and clear valid tokens).
		if next != nil {
			next(w, r)
		}
		return
	}

	a.myLog("Error processing jwts\n" + jwtErr.Error())
	// Best effort: the response below already reports the failure, so an
	// error while clearing tokens is deliberately ignored.
	_ = a.NullifyTokens(w, r)

	// jwtErr is statically a *jwtError, so the reflect comparison is always
	// true; routing is decided by the status class in Type. Drop the reflect
	// check together with the reflect import.
	var j jwtError
	if reflect.TypeOf(jwtErr) == reflect.TypeOf(&j) && jwtErr.Type/100 == 4 {
		a.unauthorizedHandler.ServeHTTP(w, r)
		return
	}
	a.errorHandler.ServeHTTP(w, r)
}
// Process authenticates r and reports whether the middleware chain may continue.
//
// OPTIONS requests pass straight through. Otherwise credentials are read from
// r and validated (validateAndUpdateCredentials may also refresh them). Unless
// the server is verify-only, the credentials are then written back to w. A
// non-nil *jwtError means the chain should stop; every failure raised here is
// tagged with status 500.
func (a *Auth) Process(w http.ResponseWriter, r *http.Request) *jwtError {
	// Cookies are not included with OPTIONS requests, so there is nothing to check.
	if r.Method == http.MethodOptions {
		a.myLog("Method is OPTIONS")
		return nil
	}

	var creds credentials
	if err := a.buildCredentialsFromRequest(r, &creds); err != nil {
		return newJwtError(err, http.StatusInternalServerError)
	}
	if err := creds.validateAndUpdateCredentials(); err != nil {
		return newJwtError(err, http.StatusInternalServerError)
	}
	a.myLog("Successfully checked / refreshed jwts")

	// Verify-only servers hold no signing key, so they never write tokens back.
	if a.options.VerifyOnlyServer {
		return nil
	}
	if err := a.setCredentialsOnResponseWriter(w, &creds); err != nil {
		return newJwtError(err, http.StatusInternalServerError)
	}
	return nil
}
// IssueNewTokens builds fresh credentials from claims and writes them to w.
//
// It returns an error on verify-only servers, which hold no signing key.
// Errors from building or writing the credentials are returned as-is, so
// callers can inspect them with errors.Is / errors.As.
func (a *Auth) IssueNewTokens(w http.ResponseWriter, claims *ClaimsType) error {
	if a.options.VerifyOnlyServer {
		a.myLog("Server is not authorized to issue new tokens")
		return errors.New("Server is not authorized to issue new tokens")
	}

	var c credentials
	if err := a.buildCredentialsFromClaims(&c, claims); err != nil {
		return err
	}
	if err := a.setCredentialsOnResponseWriter(w, &c); err != nil {
		return err
	}
	return nil
}
// NullifyTokens invalidates the tokens carried by r.
//
// With BearerTokens, the auth/refresh token headers are set to ""; otherwise
// both cookies are overwritten with empty, already-expired values. The refresh
// token's id is passed to the configured TokenRevoker. The CSRF header is set
// to "", and the Auth-Expiry / Refresh-Expiry headers are set to a Unix time
// 1000 hours in the past.
//
// It returns an error, and writes nothing, when credentials cannot be built
// from r. It also returns an error, after still clearing the client-side
// tokens, when the refresh token's claims are not a *ClaimsType or when the
// revoker fails.
// note @adam-hanna: what if there are no credentials in the request?
func (a *Auth) NullifyTokens(w http.ResponseWriter, r *http.Request) error {
	var c credentials
	if err := a.buildCredentialsFromRequest(r, &c); err != nil {
		a.myLog("Err building credentials\n" + err.Error())
		return err
	}

	// One timestamp far in the past for every expiry written below.
	expired := time.Now().Add(-1000 * time.Hour)

	if a.options.BearerTokens {
		// tokens are not in cookies
		setHeader(w, a.options.AuthTokenName, "")
		setHeader(w, a.options.RefreshTokenName, "")
	} else {
		for _, name := range []string{a.options.AuthTokenName, a.options.RefreshTokenName} {
			http.SetCookie(w, &http.Cookie{
				Name:     name,
				Value:    "",
				Expires:  expired,
				HttpOnly: true,
				Secure:   !a.options.IsDevEnv,
			})
		}
	}

	// Revoke server-side. A failure here must not stop the client-side
	// clearing below; it is reported once everything has been written.
	var revokeErr error
	if refreshTokenClaims, ok := c.RefreshToken.Token.Claims.(*ClaimsType); ok && refreshTokenClaims != nil {
		revokeErr = a.revokeRefreshToken(refreshTokenClaims.StandardClaims.Id)
	} else {
		revokeErr = errors.New("refresh token claims are not of type *ClaimsType")
	}

	setHeader(w, a.options.CSRFTokenName, "")
	expiredUnix := strconv.FormatInt(expired.Unix(), 10)
	setHeader(w, "Auth-Expiry", expiredUnix)
	setHeader(w, "Refresh-Expiry", expiredUnix)

	if revokeErr != nil {
		a.myLog("Err revoking refresh token\n" + revokeErr.Error())
		return revokeErr
	}
	a.myLog("Successfully nullified tokens and csrf string")
	return nil
}
// GrabTokenClaims extracts the claims from the request's auth token.
// note: we always grab from the authToken
//
// It returns an error when credentials cannot be built from r (the original
// error is returned unchanged), or when the auth token's claims are not a
// non-nil *ClaimsType. In the second case it no longer panics.
func (a *Auth) GrabTokenClaims(r *http.Request) (ClaimsType, error) {
	var c credentials
	if err := a.buildCredentialsFromRequest(r, &c); err != nil {
		a.myLog("Err grabbing credentials \n" + err.Error())
		return ClaimsType{}, err
	}

	claims, ok := c.AuthToken.Token.Claims.(*ClaimsType)
	if !ok || claims == nil {
		return ClaimsType{}, errors.New("auth token claims are not of type *ClaimsType")
	}
	return *claims, nil
}