forked from go-chi/chi
-
Notifications
You must be signed in to change notification settings - Fork 1
/
auth.go
251 lines (214 loc) · 6.78 KB
/
auth.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
package middleware
import (
"context"
"errors"
"net/http"
"github.com/lestrrat-go/jwx/v2/jwt"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.philip.id/phi"
"go.philip.id/phi/jwtauth"
)
// Token is the authenticated identity that the middlewares in this package
// attach to a request context (stored as a value under TOKEN_CONTEXT).
type Token struct {
	// ID identifies the token: the JWT "jti" claim for bearer auth, or
	// whatever the configured token check function assigns for basic auth.
	ID      string `json:"id"`
	// Subject is the JWT "sub" claim for bearer auth; for basic auth it is
	// whatever the configured token check function assigns.
	Subject string `json:"subject"`
}
// TOKEN_TYPE is the dedicated key type for storing the Token in a request
// context. A package-specific named type keeps the key from colliding with
// plain-string keys set by other packages.
type TOKEN_TYPE string

// TOKEN_CONTEXT is the context key under which the auth middlewares store
// the Token value (not a pointer). Use it to read or set the token directly.
const TOKEN_CONTEXT = TOKEN_TYPE("_token")
// Pluggable hooks consulted by the middlewares in this file. Replace them
// through SetUnauthorizedFunc and SetTokenCheckFunc.
var (
	// unauthorizedFunc builds the error handed to phi.ErrorHandler when a
	// mandatory auth check fails.
	unauthorizedFunc func() *phi.Error = phi.Unauthorized

	// tokenCheckFunc resolves basic-auth credentials to a Token. The default,
	// ImplementAccessCheck, rejects every request.
	tokenCheckFunc func(username, password string) (*Token, error) = ImplementAccessCheck
)
// SetUnauthorizedFunc sets the function that builds the error sent by
// JWTOrAPIAuth, JWTAuth and APIAuth when a request is not authenticated.
//
// The default is phi.Unauthorized. Passing nil restores that default instead
// of installing a nil hook that would panic on the next rejected request.
//
// The hook is a plain package variable without synchronization: call this
// during setup, before the server starts handling requests.
func SetUnauthorizedFunc(fn func() *phi.Error) {
	if fn == nil {
		fn = phi.Unauthorized
	}
	unauthorizedFunc = fn
}
// SetTokenCheckFunc installs the function that turns basic-auth credentials
// into a Token, e.g. by checking username and password against a database.
// It is used by APIAuth, APIAuthOptional and the basic-auth fallback of
// JWTOrAPIAuth / JWTOrAPIAuthOptional.
//
// The function must return a non-nil *Token whenever it returns a nil error,
// because the middlewares dereference the token on success. Passing nil
// restores the default, ImplementAccessCheck, which rejects all credentials.
//
// The hook is a plain package variable without synchronization: call this
// during setup, before the server starts handling requests.
//
// Example (mongopiet; the username is treated as an API token and the
// password is ignored):
//
//	func TokenCheck(username, password string) (*middleware.Token, error) {
//		a, err := database.FindOne("apiTokens", bson.M{"token": username})
//		if err != nil {
//			return nil, errors.New("not found")
//		}
//
//		return &middleware.Token{
//			ID:      a.User.Hex(),
//			Subject: a.Source,
//		}, nil
//	}
func SetTokenCheckFunc(fn func(username, password string) (*Token, error)) {
	if fn == nil {
		fn = ImplementAccessCheck
	}
	tokenCheckFunc = fn
}
// ImplementAccessCheck is the placeholder token check installed by default.
// It ignores the credentials and always fails with "not implemented", so
// every basic-auth request is rejected until a real check is installed via
// SetTokenCheckFunc. Do not rely on it in production.
func ImplementAccessCheck(username, password string) (*Token, error) {
	notImplemented := errors.New("not implemented")
	return nil, notImplemented
}
// GetToken returns the Token that the auth middlewares stored under
// TOKEN_CONTEXT in r's context. It returns nil when no Token value is present.
// The returned pointer refers to a copy, so modifying it does not change the
// value held by the context.
func GetToken(r *phi.Request) *Token {
	stored, found := r.Context().Value(TOKEN_CONTEXT).(Token)
	if !found {
		return nil
	}
	return &stored
}
// GetUserID is an opinionated helper that interprets the ID of the request's
// Token as a hex-encoded MongoDB primitive.ObjectID.
//
// It returns nil when no Token is stored in the context or when the ID is
// not a valid ObjectID hex string.
func GetUserID(r *phi.Request) *primitive.ObjectID {
	tok := GetToken(r)
	if tok == nil {
		return nil
	}
	oid, parseErr := primitive.ObjectIDFromHex(tok.ID)
	if parseErr != nil {
		return nil
	}
	return &oid
}
// JWTOrAPIAuth accepts either a valid JWT bearer token or HTTP basic auth.
// It rejects the request when neither is present.
//
// The bearer token is tried first. Basic auth, resolved through the hook set
// with SetTokenCheckFunc, is consulted only when no valid JWT is found. On
// success the resulting Token is stored under TOKEN_CONTEXT. Otherwise the
// request is answered through phi.ErrorHandler with the error built by the
// hook set via SetUnauthorizedFunc (phi.Unauthorized by default).
//
// Useful for endpoints that serve a frontend and an API at the same time.
//
// Retrieve the token in a handler with either of:
//
//	token := r.Context().Value(middleware.TOKEN_CONTEXT).(middleware.Token)
//	token := middleware.GetToken(r) // only works with *phi.Request
func JWTOrAPIAuth(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		token, err := checkBearer(r)
		if err != nil {
			// No usable JWT: fall back to basic-auth credentials.
			token, err = checkBasic(r)
		}
		if err != nil {
			phi.ErrorHandler(w, r, unauthorizedFunc())
			return
		}
		ctx := context.WithValue(r.Context(), TOKEN_CONTEXT, *token)
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}
// JWTOrAPIAuthOptional is the non-enforcing variant of JWTOrAPIAuth.
//
// It tries the JWT bearer token first and basic auth second. When one
// succeeds, it stores the Token under TOKEN_CONTEXT. When neither succeeds,
// the request is still passed on, just without a token.
//
// Useful for resources that everyone may access but that respond
// differently to authenticated users.
func JWTOrAPIAuthOptional(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		tok, authErr := checkBearer(r)
		if authErr != nil {
			tok, authErr = checkBasic(r)
		}
		if authErr != nil {
			// Anonymous request: continue untouched.
			next.ServeHTTP(w, r)
			return
		}
		next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), TOKEN_CONTEXT, *tok)))
	})
}
// JWTAuth requires a valid JWT bearer token on every request.
//
// The token is read from the request context via jwtauth.FromContext. A
// jwtauth verifier middleware must therefore run earlier in the chain
// (presumably jwtauth.Verifier; confirm in the router setup). On success the
// Token is stored under TOKEN_CONTEXT. Otherwise the request is answered
// through phi.ErrorHandler with the error built by the hook set via
// SetUnauthorizedFunc (phi.Unauthorized by default).
//
// Intended for frontend authentication: the JWT has to accompany every
// request and must be refreshed after it expires.
//
// Retrieve the token in a handler with either of:
//
//	token := r.Context().Value(middleware.TOKEN_CONTEXT).(middleware.Token)
//	token := middleware.GetToken(r) // only works with *phi.Request
func JWTAuth(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		token, err := checkBearer(r)
		if err != nil {
			phi.ErrorHandler(w, r, unauthorizedFunc())
			return
		}
		ctx := context.WithValue(r.Context(), TOKEN_CONTEXT, *token)
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}
// JWTAuthOptional is the non-enforcing variant of JWTAuth.
//
// When the request carries a valid JWT bearer token, the Token is stored
// under TOKEN_CONTEXT. Otherwise the request is passed on unchanged, without
// a token.
//
// Useful for resources that everyone may access but that respond
// differently to authenticated users.
func JWTAuthOptional(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		token, err := checkBearer(r)
		if err != nil {
			// No valid JWT: serve the request anonymously. (Previously the
			// condition was inverted, so valid tokens were dropped and
			// missing ones caused a nil dereference.)
			next.ServeHTTP(w, r)
			return
		}
		ctx := context.WithValue(r.Context(), TOKEN_CONTEXT, *token)
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}
// APIAuth requires HTTP basic-auth credentials on every request.
//
// The credentials are resolved to a Token by the hook set with
// SetTokenCheckFunc, and the Token is stored under TOKEN_CONTEXT. Requests
// without credentials, or whose credentials are rejected, are answered
// through phi.ErrorHandler with the error built by the hook set via
// SetUnauthorizedFunc (phi.Unauthorized by default).
//
// Intended for API access: such credentials are typically long-lived and
// should never be exposed to a browser client.
//
// Retrieve the token in a handler with either of:
//
//	token := r.Context().Value(middleware.TOKEN_CONTEXT).(middleware.Token)
//	token := middleware.GetToken(r) // only works with *phi.Request
func APIAuth(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		token, err := checkBasic(r)
		if err != nil {
			phi.ErrorHandler(w, r, unauthorizedFunc())
			return
		}
		ctx := context.WithValue(r.Context(), TOKEN_CONTEXT, *token)
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}
// APIAuthOptional is the non-enforcing variant of APIAuth.
//
// When valid basic-auth credentials are supplied, the resulting Token is
// stored under TOKEN_CONTEXT. Otherwise the request is passed on unchanged,
// without a token.
//
// Useful for resources that everyone may access but that respond
// differently to authenticated users.
func APIAuthOptional(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if tok, authErr := checkBasic(r); authErr == nil {
			next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), TOKEN_CONTEXT, *tok)))
			return
		}
		// Anonymous request: continue untouched.
		next.ServeHTTP(w, r)
	})
}
// checkBearer builds a Token from the JWT that jwtauth placed in the request
// context. A jwtauth verifier middleware must run before this (presumably
// jwtauth.Verifier; confirm in the router setup).
//
// Any error from jwtauth.FromContext is returned as is. A missing token, or
// one that fails jwt.Validate, yields a "token invalid" error. On success the
// Token carries the JWT ID and subject claims.
func checkBearer(r *http.Request) (*Token, error) {
	jwtToken, _, ctxErr := jwtauth.FromContext(r.Context())
	if ctxErr != nil {
		return nil, ctxErr
	}
	if jwtToken == nil || jwt.Validate(jwtToken) != nil {
		return nil, errors.New("token invalid")
	}
	// TODO: add more claim parsing here
	return &Token{
		ID:      jwtToken.JwtID(),
		Subject: jwtToken.Subject(),
	}, nil
}
// checkBasic authenticates r via HTTP basic-auth credentials.
//
// The username/password pair is passed to the configured tokenCheckFunc (see
// SetTokenCheckFunc). An error is returned in three cases:
//   - the request has no basic-auth header ("no basic auth found");
//   - tokenCheckFunc rejects the credentials (its error is returned as is);
//   - tokenCheckFunc returns neither a token nor an error.
//
// The last case matters because every caller dereferences the returned token
// when err is nil.
func checkBasic(r *http.Request) (*Token, error) {
	username, password, ok := r.BasicAuth()
	if !ok {
		return nil, errors.New("no basic auth found")
	}
	t, err := tokenCheckFunc(username, password)
	if err != nil {
		return nil, err
	}
	if t == nil {
		return nil, errors.New("token check returned no token")
	}
	return t, nil
}