-
Notifications
You must be signed in to change notification settings - Fork 0
/
authentication.go
81 lines (70 loc) · 2.27 KB
/
authentication.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
package security
import (
"context"
"errors"
"fmt"
"net/http"
"github.com/AleF83/RPGM/services/backend/appConfig"
jwt "github.com/dgrijalva/jwt-go"
"github.com/dgrijalva/jwt-go/request"
"github.com/lestrrat/go-jwx/jwk"
)
// NewAuthenticationMiddleware creates HTTP middleware that requires a valid JWT
// on every request.
//
// The token is extracted with request.OAuth2Extractor. The verification key is
// chosen in two steps:
//   - the token's "iss" claim selects the provider whose Issuer matches;
//   - the "kid" header selects the key from that provider's JWK set.
//
// If the token cannot be parsed or verified, the request is rejected with
// 401 Unauthorized. Otherwise next is called with the request returned by
// createRequestWithContext.
func NewAuthenticationMiddleware(providers map[string]appConfig.AuthProvider) func(http.Handler) http.Handler {
	keyFunc := func(t *jwt.Token) (interface{}, error) {
		// The key function runs before the signature is checked, so the claims
		// and header read here are untrusted. They are only used to pick which
		// configured key the token will then be verified against.
		claims, ok := t.Claims.(jwt.MapClaims)
		if !ok {
			// Guarded instead of a bare assertion, so that an unexpected claims
			// type yields a 401 rather than a panic.
			return nil, errors.New("unexpected JWT claims type")
		}
		issuer, ok := claims["iss"].(string)
		if !ok {
			return nil, errors.New("iss claim not found in JWT")
		}
		provider, ok := getProviderByIssuer(providers, issuer)
		if !ok {
			return nil, fmt.Errorf("no authentication provider configured for issuer %q", issuer)
		}
		keyID, ok := t.Header["kid"].(string)
		if !ok {
			return nil, errors.New("kid field not found in JWT header")
		}
		return getKeyByProvider(provider, keyID)
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			token, err := request.ParseFromRequest(r, request.OAuth2Extractor, keyFunc)
			if err != nil {
				http.Error(w, fmt.Sprintf("Authentication Error: %v", err), http.StatusUnauthorized)
				return
			}
			next.ServeHTTP(w, createRequestWithContext(r, token))
		})
	}
}
// getProviderByIssuer looks up the configured provider whose Issuer equals
// issuer.
//
// On a match it returns a pointer to a copy of that provider and true.
// Otherwise it returns nil and false. If several providers share the same
// Issuer, which one is returned depends on map iteration order, which Go
// leaves unspecified.
func getProviderByIssuer(providers map[string]appConfig.AuthProvider, issuer string) (*appConfig.AuthProvider, bool) {
	for name := range providers {
		candidate := providers[name]
		if candidate.Issuer != issuer {
			continue
		}
		return &candidate, true
	}
	return nil, false
}
// getKeyByProvider fetches the JWK set from provider.JWKsURL and returns the
// materialized key whose key ID equals keyID.
//
// It returns an error in three cases: the fetch fails, no key carries keyID,
// or more than one key does (the key would then be ambiguous).
//
// NOTE(review): the key set is fetched on every call. When this is used from
// the authentication middleware, that means one outbound request per incoming
// request. Consider caching the set with a refresh policy.
func getKeyByProvider(provider *appConfig.AuthProvider, keyID string) (interface{}, error) {
	keySet, err := jwk.FetchHTTP(provider.JWKsURL)
	if err != nil {
		return nil, err
	}
	switch keys := keySet.LookupKeyID(keyID); len(keys) {
	case 0:
		return nil, fmt.Errorf("no keys found for kid %q", keyID)
	case 1:
		return keys[0].Materialize()
	default:
		return nil, fmt.Errorf("more than one key found for kid %q", keyID)
	}
}
// contextKey is an unexported key type for values this package stores in a
// request context. Because the type is distinct, these keys cannot collide
// with context keys defined by other packages.
type contextKey string

// createRequestWithContext returns a shallow copy of r. The copy's context
// carries a user-info map (map[string]string) under contextKey("user").
//
// NOTE(review): the token t is not read, so the stored map is always empty.
// Presumably it is meant to be populated from t's claims. Confirm this with the
// consumers of contextKey("user").
func createRequestWithContext(r *http.Request, t *jwt.Token) *http.Request {
	ctx := context.WithValue(r.Context(), contextKey("user"), make(map[string]string))
	return r.WithContext(ctx)
}