-
Notifications
You must be signed in to change notification settings - Fork 26
/
middleware.go
313 lines (278 loc) · 10.2 KB
/
middleware.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
package auth
import (
"context"
"fmt"
"net/http"
"os"
log "github.com/sirupsen/logrus"
"github.com/appbaseio/reactivesearch-api/middleware"
"github.com/appbaseio/reactivesearch-api/middleware/classify"
"github.com/appbaseio/reactivesearch-api/middleware/validate"
"github.com/appbaseio/reactivesearch-api/model/category"
"github.com/appbaseio/reactivesearch-api/model/credential"
"github.com/appbaseio/reactivesearch-api/model/index"
"github.com/appbaseio/reactivesearch-api/model/op"
"github.com/appbaseio/reactivesearch-api/model/permission"
"github.com/appbaseio/reactivesearch-api/model/trackplugin"
"github.com/appbaseio/reactivesearch-api/model/user"
"github.com/appbaseio/reactivesearch-api/plugins/telemetry"
"github.com/dgrijalva/jwt-go"
"github.com/dgrijalva/jwt-go/request"
"github.com/gorilla/mux"
"golang.org/x/crypto/bcrypt"
)
// chain is the middleware pipeline used for the auth plugin's routes. It
// embeds middleware.Fifo and composes handlers through Adapt.
type chain struct {
	middleware.Fifo
}

// Wrap decorates h with every middleware returned by list and returns the
// resulting handler; the actual composition is delegated to Adapt.
func (c *chain) Wrap(h http.HandlerFunc) http.HandlerFunc {
	mws := list()
	return c.Adapt(h, mws...)
}
// list returns the middlewares applied to auth routes, in the order they are
// handed to Adapt.
func list() []middleware.Middleware {
	// Classification first: basicAuth reads the category and op that these
	// put into the request context.
	mws := []middleware.Middleware{
		classifyCategory,
		classifyIndices,
		classify.Op(),
	}
	// Then authentication, validation and telemetry recording.
	mws = append(mws,
		BasicAuth(),
		validate.Operation(),
		validate.Category(),
		telemetry.Recorder(),
	)
	return mws
}
// classifyIndices is a middleware that records a single index name in the
// request context before calling h. The name is read from the
// envPublicKeyEsIndex environment variable on every request, falling back to
// defaultPublicKeyEsIndex when that variable is unset or empty.
func classifyIndices(h http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, req *http.Request) {
		indices := []string{defaultPublicKeyEsIndex}
		if fromEnv := os.Getenv(envPublicKeyEsIndex); fromEnv != "" {
			indices[0] = fromEnv
		}
		h(w, req.WithContext(index.NewContext(req.Context(), indices)))
	}
}
// classifyCategory is a middleware that tags every request with the
// category.Auth category in its context before calling h.
func classifyCategory(h http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, req *http.Request) {
		// A fresh variable per request, since the context stores a pointer.
		authCategory := category.Auth
		h(w, req.WithContext(category.NewContext(req.Context(), &authCategory)))
	}
}
// BasicAuth returns the authentication middleware bound to the *Auth returned
// by Instance(). Despite its name, the middleware accepts either HTTP Basic
// Auth credentials or a JWT in the Authorization header.
func BasicAuth() middleware.Middleware {
	auth := Instance()
	return auth.basicAuth
}
// basicAuth returns a middleware that authenticates each request with either
// HTTP Basic Auth credentials or an RSA-signed JWT taken from the
// Authorization header, then calls h.
//
// The request context must already carry a *category.Category and an *op.Op
// (in this package's chain they are set by classifyCategory and classify.Op);
// otherwise the request fails with 500.
//
// Basic Auth: the credential is looked up by username (cache first, then
// a.es) and the password is verified. JWT: the token is verified against
// a.jwtRsaPublicKey and its role claim (a.jwtRoleKey, else "role") selects the
// credential via a.es.getRolePermission. The resolved user or permission is
// then checked against the request category and stored in the request context.
// Every authentication failure is answered with 401 and a Basic challenge.
func (a *Auth) basicAuth(h http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, req *http.Request) {
		// unauthorized writes a 401 with a Basic Auth challenge. It captures the
		// req variable, so context attached further down is what telemetry sees.
		unauthorized := func(msg string) {
			w.Header().Set("www-authenticate", "Basic realm=\"Authentication Required\"")
			telemetry.WriteBackErrorWithTelemetry(req, w, msg, http.StatusUnauthorized)
		}

		ctx := req.Context()
		reqCategory, err := category.FromContext(ctx)
		if err != nil {
			log.Errorln(logTag, ": *category.Category not found in request context:", err)
			telemetry.WriteBackErrorWithTelemetry(req, w, "error occurred while authenticating the request", http.StatusInternalServerError)
			return
		}
		reqOp, err := op.FromContext(ctx)
		if err != nil {
			log.Errorln(logTag, ": *op.Op not found the request context:", err)
			telemetry.WriteBackErrorWithTelemetry(req, w, "error occurred while authenticating the request", http.StatusInternalServerError)
			return
		}

		username, password, hasBasicAuth := req.BasicAuth()
		jwtToken, err := request.ParseFromRequest(req, request.AuthorizationHeaderExtractor, func(token *jwt.Token) (interface{}, error) {
			if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
				return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
			}
			if a.jwtRsaPublicKey == nil {
				return nil, fmt.Errorf("No Public Key Registered")
			}
			return a.jwtRsaPublicKey, nil
		})
		if !hasBasicAuth && err != nil {
			var msg string
			if err == request.ErrNoTokenInRequest {
				msg = "Basic Auth or JWT is required"
			} else {
				msg = fmt.Sprintf("Unable to parse JWT: %v", err)
			}
			unauthorized(msg)
			return
		}

		// For JWT requests the role claim selects the credential; the custom
		// claim key takes precedence over the default "role" claim.
		role := ""
		if !hasBasicAuth {
			claims, ok := jwtToken.Claims.(jwt.MapClaims)
			if !ok || !jwtToken.Valid {
				unauthorized("Invalid JWT")
				return
			}
			var rawRole interface{}
			if a.jwtRoleKey != "" && claims[a.jwtRoleKey] != nil {
				rawRole = claims[a.jwtRoleKey]
			} else {
				rawRole = claims["role"]
			}
			// Checked assertion: a non-string (or null) claim must not panic the
			// handler, and an empty role must not fall through to the username
			// lookup below, which would run with an empty username and no
			// password check.
			role, ok = rawRole.(string)
			if !ok || role == "" {
				unauthorized("Invalid JWT")
				return
			}
		}

		// we don't know if the credentials provided here are of a 'user' or a 'permission'
		var obj credential.AuthCredential
		if role != "" {
			obj, err = a.es.getRolePermission(ctx, role)
			if err != nil || obj == nil {
				log.Errorln(logTag, ":", err)
				unauthorized(fmt.Sprintf("No API credentials match with provided role: %s", role))
				return
			}
		} else {
			obj, err = a.getCredential(ctx, username)
			if err != nil || obj == nil {
				log.Warnln(logTag, ":", err)
				unauthorized(fmt.Sprintf("No API credentials match with provided username: %s", username))
				return
			}
		}

		authenticated := false
		errorMsg := "invalid credentials provided"
		switch cred := obj.(type) {
		case *user.User:
			// track `user` middleware
			ctx := trackplugin.TrackPlugin(ctx, "au")
			req = req.WithContext(ctx)
			// Passwords are only checked for Basic Auth. IsPasswordExist presumably
			// reports a pair previously stored by SavePassword, letting us skip
			// the bcrypt comparison — TODO confirm.
			if hasBasicAuth && !IsPasswordExist(cred.Username, password) &&
				bcrypt.CompareHashAndPassword([]byte(cred.Password), []byte(password)) != nil {
				unauthorized("invalid password")
				return
			}
			if hasBasicAuth {
				// Remember the verified password. Never done for JWT requests:
				// their password is empty and was not checked.
				SavePassword(cred.Username, password)
			}

			switch {
			case (req.Method == http.MethodGet || req.Method == http.MethodHead) && req.RequestURI == "/":
				// ignore es auth for root route to fetch the cluster details
				authenticated = true
			case cred.IsAdmin != nil && *cred.IsAdmin:
				authenticated = true
			case reqCategory.IsFromES():
				// elasticsearch requests are allowed per request category
				if cred.HasCategory(*reqCategory) {
					authenticated = true
				} else {
					errorMsg = "user not allowed to access elasticsearch"
				}
			case reqCategory.IsFromRS():
				// reactivesearch API requests need the `reactivesearch` category
				if cred.HasCategory(category.ReactiveSearch) {
					authenticated = true
				} else {
					errorMsg = "user not allowed to access reactivesearch API"
				}
			default:
				authenticated = true
			}

			// Cache the user under its username. Only for Basic Auth: JWT
			// requests have no username and would pollute the "" key.
			if hasBasicAuth {
				if _, ok := GetCachedCredential(username); !ok {
					SaveCredentialToCache(username, cred)
				}
			}
			// store request user and credential identifier in the context
			ctx = credential.NewContext(ctx, credential.User)
			ctx = user.NewContext(ctx, cred)
			req = req.WithContext(ctx)

		case *permission.Permission:
			// track `permission` middleware
			ctx := trackplugin.TrackPlugin(ctx, "ap")
			req = req.WithContext(ctx)
			// NOTE(review): plain string comparison of the stored password;
			// consider a constant-time comparison (crypto/subtle).
			if hasBasicAuth && cred.Password != password {
				unauthorized("invalid password")
				return
			}
			// ignore es auth for root route to fetch the cluster details.
			// NOTE(review): unlike the user branch this does not admit HEAD;
			// kept as-is to avoid widening access — confirm intent.
			if req.Method == http.MethodGet && req.RequestURI == "/" {
				authenticated = true
			} else if cred.HasCategory(*reqCategory) {
				authenticated = true
			} else {
				str := (*reqCategory).String()
				errorMsg = "credential is not allowed to access" + " " + str
			}

			// Cache the permission under its username (Basic Auth only, see above).
			if hasBasicAuth {
				if _, ok := GetCachedCredential(username); !ok {
					SaveCredentialToCache(username, cred)
				}
			}
			// store the request permission and credential identifier in the context
			ctx = credential.NewContext(ctx, credential.Permission)
			ctx = permission.NewContext(ctx, cred)
			req = req.WithContext(ctx)

		default:
			log.Println(logTag, ": unreachable state ...")
		}

		if !authenticated {
			unauthorized(errorMsg)
			return
		}

		// On write/delete operations, evict the cached credential named by the
		// route's {username} variable (presumably the one being modified).
		if *reqOp == op.Write || *reqOp == op.Delete {
			RemoveCredentialFromCache(mux.Vars(req)["username"])
		}
		h(w, req)
	}
}
// getCredential resolves username to its stored credential (a user or a
// permission). The in-memory cache is consulted first; on a miss the lookup
// is delegated to a.es without populating the cache here.
func (a *Auth) getCredential(ctx context.Context, username string) (credential.AuthCredential, error) {
	if cached, found := GetCachedCredential(username); found {
		return cached, nil
	}
	return a.es.getCredential(ctx, username)
}
// GetCachedCredential looks up username in the credential cache while holding
// its mutex. It returns the cached credential and true on a hit, or nil and
// false on a miss.
func GetCachedCredential(username string) (credential.AuthCredential, bool) {
	CredentialCache.mu.Lock()
	defer CredentialCache.mu.Unlock()
	cred, found := CredentialCache.cache[username]
	return cred, found
}
// GetCachedCredentials returns a snapshot of every credential currently in the
// cache, taken while holding the cache mutex. The order follows map iteration
// and is therefore unspecified; the result is nil when the cache is empty.
func GetCachedCredentials() []credential.AuthCredential {
	CredentialCache.mu.Lock()
	defer CredentialCache.mu.Unlock()
	var snapshot []credential.AuthCredential
	for key := range CredentialCache.cache {
		snapshot = append(snapshot, CredentialCache.cache[key])
	}
	return snapshot
}
// RemoveCredentialFromCache evicts the entry for username from the credential
// cache. Removing a username that is not cached is a no-op.
func RemoveCredentialFromCache(username string) {
	CredentialCache.mu.Lock()
	delete(CredentialCache.cache, username) // delete on a map cannot panic
	CredentialCache.mu.Unlock()
}
// SaveCredentialToCache stores c in the credential cache under username,
// replacing any existing entry. A nil credential is logged and ignored.
//
// NOTE(review): the nil check only catches a nil interface value; a typed nil
// pointer wrapped in the interface would still be stored.
func SaveCredentialToCache(username string, c credential.AuthCredential) {
	if c == nil {
		log.Println(logTag, ": cannot cache 'nil' credential, skipping...")
		return
	}
	CredentialCache.mu.Lock()
	// Unlock via defer so a panicking map write (e.g. an uninitialised map)
	// cannot leave the cache locked forever; matches the other cache helpers.
	defer CredentialCache.mu.Unlock()
	CredentialCache.cache[username] = c
}