/
oidc.go
181 lines (153 loc) · 6 KB
/
oidc.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
package oidc
import (
"context"
"crypto/rand"
"crypto/sha1"
"encoding/base64"
"fmt"
"net/http"
"time"
"github.com/bluele/gcache"
"github.com/rs/xid"
"golang.org/x/oauth2"
"go.aporeto.io/trireme-lib/controller/pkg/usertokens/common"
oidc "github.com/coreos/go-oidc"
)
// TokenVerifier is an OIDC validator. It drives the authorization-code
// redirect flow (IssueRedirect, Callback) and verifies ID tokens (Validate).
// The exported fields are configuration that NewClient reads; the unexported
// fields are populated by NewClient and must not be set directly.
type TokenVerifier struct {
// ProviderURL is the provider URL handed to oidc.NewProvider for discovery.
ProviderURL string
// ClientID is the OAuth2 client ID; it is also the audience the ID token
// verifier is configured with.
ClientID string
// ClientSecret is the OAuth2 client secret used for the code exchange.
ClientSecret string
// RedirectURL is the callback URL registered with the provider.
RedirectURL string
// RedirectOnFail is reported by Validate as its "redirect" result when a
// token fails verification or its claims cannot be decoded.
RedirectOnFail bool
// RedirectOnNoToken, when true, makes Validate reject an empty token
// immediately and report true as its "redirect" result.
RedirectOnNoToken bool
// NonceSize is passed to randomSha1 as the number of random bytes used to
// derive each redirect state value.
NonceSize int
// CookieDuration is not referenced in this file.
// NOTE(review): presumably used by callers for the token cookie lifetime — confirm.
CookieDuration time.Duration
// Scopes holds the OAuth2 scopes intended to be requested from the provider.
// Check NewClient for how (or whether) it is applied.
Scopes []string
// provider is the discovered OIDC provider.
provider *oidc.Provider
// clientConfig is the OAuth2 configuration used by IssueRedirect and Callback.
clientConfig *oauth2.Config
// oauthVerifier verifies raw ID tokens in Validate.
oauthVerifier *oidc.IDTokenVerifier
// cache maps a raw token to its flattened claims ([]string).
cache gcache.Cache
// state maps an issued redirect state value to the origin URL (string).
state gcache.Cache
}
// NewClient completes the configuration of v and returns it ready for use.
//
// It performs OIDC discovery against v.ProviderURL (using ctx for the
// discovery request), builds the ID-token verifier and the OAuth2 client
// configuration, and allocates the two internal caches. The same pointer that
// was passed in is returned. If an error is returned, the verifier must not be
// used.
//
// The requested scopes are v.Scopes when it is non-empty. Otherwise they are
// "openid", "profile" and "email". The "openid" scope is always included,
// because Callback needs an id_token in the token response and OIDC providers
// only issue one for that scope.
func NewClient(ctx context.Context, v *TokenVerifier) (*TokenVerifier, error) {
	if v == nil {
		return nil, fmt.Errorf("Failed to initialize provider: nil TokenVerifier")
	}

	// Create a new generic OIDC provider based on the provider URL.
	// The library auto-discovers the provider's configuration; a
	// non-compliant provider is reported as an error here.
	provider, err := oidc.NewProvider(ctx, v.ProviderURL)
	if err != nil {
		return nil, fmt.Errorf("Failed to initialize provider: %s", err)
	}
	v.provider = provider

	v.oauthVerifier = provider.Verifier(&oidc.Config{
		ClientID: v.ClientID,
	})

	scopes := []string{oidc.ScopeOpenID, "profile", "email"}
	if len(v.Scopes) > 0 {
		// Build a fresh slice so the client config never aliases v.Scopes,
		// and make sure "openid" is present exactly once, first.
		scopes = make([]string, 0, len(v.Scopes)+1)
		scopes = append(scopes, oidc.ScopeOpenID)
		for _, s := range v.Scopes {
			if s != oidc.ScopeOpenID {
				scopes = append(scopes, s)
			}
		}
	}

	v.clientConfig = &oauth2.Config{
		ClientID:     v.ClientID,
		ClientSecret: v.ClientSecret,
		Endpoint:     provider.Endpoint(),
		RedirectURL:  v.RedirectURL,
		Scopes:       scopes,
	}

	// The first cache holds the state values issued with redirect requests,
	// mapped to the origin URL. Callback uses it to validate the provider's
	// response and guard against cross-site request forgery. Users currently
	// have 60 seconds to authenticate.
	v.state = gcache.New(2048).LRU().Expiration(60 * time.Second).Build()

	// The second cache holds the results of token validations, so that not
	// every request goes back to the verifier.
	v.cache = gcache.New(2048).LRU().Expiration(120 * time.Second).Build()

	return v, nil
}
// IssueRedirect returns the provider authorization URL that the user agent
// should be redirected to in order to start the login flow.
//
// A fresh state value is generated and stored in the short-lived state cache,
// mapped to originURL. Callback can then validate the provider's response and
// recover where the user was originally going. If the state cannot be stored,
// the empty string is returned.
//
// The state is kept only in this process's memory. The callback must therefore
// reach the same instance (for example through sticky load balancers or a
// re-used TCP session); otherwise a global state store, with extra calls to a
// central system, would be needed.
// TODO: add support for a global state.
//
// NOTE(review): when randomSha1 fails, the state falls back to an xid. xid
// values are designed for uniqueness rather than unpredictability, so confirm
// whether this is acceptable for an anti-CSRF state value.
func (v *TokenVerifier) IssueRedirect(originURL string) string {
	nonce, genErr := randomSha1(v.NonceSize)
	if genErr != nil {
		nonce = xid.New().String()
	}

	if setErr := v.state.Set(nonce, originURL); setErr == nil {
		return v.clientConfig.AuthCodeURL(nonce)
	}
	return ""
}
// Callback handles the identity provider's redirect back to the application at
// the end of the authorization-code flow.
//
// It performs these steps in order:
//   - Checks that the "state" query parameter is one issued by IssueRedirect
//     and consumes it, so each state value can be used at most once.
//   - Rejects provider error responses (an "error" query parameter).
//   - Exchanges the "code" query parameter for an OAuth2 token.
//   - Extracts the raw ID token from that token.
//
// The ID token's signature is not checked here; Validate does that.
//
// Returns the raw ID token, the original URL that initiated the flow (so the
// caller can send the client back there), the HTTP status to respond with
// (http.StatusTemporaryRedirect on success), and an error.
func (v *TokenVerifier) Callback(r *http.Request) (string, string, int, error) {
	query := r.URL.Query()

	// Validate the callback state against the one issued with the redirect,
	// recovering the original URL the client asked for.
	receivedState := query.Get("state")
	cached, err := v.state.Get(receivedState)
	if err != nil {
		return "", "", http.StatusBadRequest, fmt.Errorf("Bad state")
	}
	// Remove reports whether the entry was still present. Two concurrent
	// callbacks carrying the same state can both pass Get. Only the one that
	// actually removes the entry may continue, which makes the state single-use.
	if !v.state.Remove(receivedState) {
		return "", "", http.StatusBadRequest, fmt.Errorf("Bad state")
	}
	originURL, ok := cached.(string)
	if !ok {
		return "", "", http.StatusInternalServerError, fmt.Errorf("Bad state")
	}

	// The provider reports failures (e.g. the user refused consent) with an
	// "error" parameter and no code. Exchanging an empty code would only make a
	// doomed round-trip to the provider.
	if idpErr := query.Get("error"); idpErr != "" {
		return "", "", http.StatusUnauthorized, fmt.Errorf("Provider returned error: %q", idpErr)
	}

	// Exchange the authorization code for an OAuth2 token. This is the step
	// where the provider matches the code to the token.
	oauth2Token, err := v.clientConfig.Exchange(r.Context(), query.Get("code"))
	if err != nil {
		return "", "", http.StatusInternalServerError, fmt.Errorf("Bad code: %s", err)
	}

	// Extract the raw ID token.
	rawIDToken, ok := oauth2Token.Extra("id_token").(string)
	if !ok {
		return "", "", http.StatusInternalServerError, fmt.Errorf("Bad ID")
	}

	return rawIDToken, originURL, http.StatusTemporaryRedirect, nil
}
// Validate verifies an OIDC ID token and returns its claims, flattened into
// strings by common.FlattenClaim.
//
// Successful results are cached, keyed by the raw token, so repeated requests
// with the same token skip re-verification. A cache entry never outlives the
// token itself: it expires after 120 seconds or at the token's own expiry,
// whichever comes first.
//
// The boolean result tells the caller whether to redirect the client to the
// provider:
//   - If the token is empty and RedirectOnNoToken is set, the token is
//     rejected immediately and true is reported.
//   - If verification or claim decoding fails, RedirectOnFail is reported.
//   - Otherwise false is reported.
func (v *TokenVerifier) Validate(ctx context.Context, token string) ([]string, bool, error) {
	if len(token) == 0 && v.RedirectOnNoToken {
		return []string{}, v.RedirectOnNoToken, fmt.Errorf("Invalid token presented")
	}

	if data, err := v.cache.Get(token); err == nil {
		return data.([]string), false, nil
	}

	idToken, err := v.oauthVerifier.Verify(ctx, token)
	if err != nil {
		return []string{}, v.RedirectOnFail, fmt.Errorf("Token validation failed: %s", err)
	}

	// The ID token payload is plain JSON; decode all of it generically.
	claims := map[string]interface{}{}
	if err := idToken.Claims(&claims); err != nil {
		return []string{}, v.RedirectOnFail, fmt.Errorf("Unable to process claims: %s", err)
	}

	// Flatten the claims into a generic format.
	attributes := []string{}
	for name, value := range claims {
		attributes = append(attributes, common.FlattenClaim(name, value)...)
	}

	// Never let a cached result outlive the token. Otherwise a token expiring
	// in a few seconds would keep being accepted from the cache for the full
	// cache window after it has expired.
	// maxCacheTTL matches the expiration configured for v.cache in NewClient.
	const maxCacheTTL = 120 * time.Second
	ttl := maxCacheTTL
	if !idToken.Expiry.IsZero() {
		if remaining := time.Until(idToken.Expiry); remaining < ttl {
			ttl = remaining
		}
	}
	if ttl <= 0 {
		// The token is at or past its expiry; don't cache it at all.
		return attributes, false, nil
	}

	if err := v.cache.SetWithExpire(token, attributes, ttl); err != nil {
		return []string{}, false, fmt.Errorf("Cannot cache token")
	}

	return attributes, false, nil
}
// VerifierType reports the kind of tokens this verifier handles, which is
// always common.OIDC.
func (*TokenVerifier) VerifierType() common.JWTType {
	return common.OIDC
}
// randomSha1 returns a random nonce: the SHA-1 digest of nonceSourceSize bytes
// read from crypto/rand, encoded with padded standard base64. The result is
// always 28 characters long, regardless of nonceSourceSize.
//
// Sizes below minNonceSourceSize are raised to that minimum. Without this, a
// size of 0 (the zero value of TokenVerifier.NonceSize) would hash an empty
// input and return the same constant string on every call, and a negative
// size would make make() panic.
//
// Returns an error only if crypto/rand fails.
func randomSha1(nonceSourceSize int) (string, error) {
	// minNonceSourceSize is the smallest amount of entropy, in bytes
	// (128 bits), used to derive a nonce.
	const minNonceSourceSize = 16
	if nonceSourceSize < minNonceSourceSize {
		nonceSourceSize = minNonceSourceSize
	}

	nonceSource := make([]byte, nonceSourceSize)
	if _, err := rand.Read(nonceSource); err != nil {
		return "", err
	}

	sum := sha1.Sum(nonceSource)
	return base64.StdEncoding.EncodeToString(sum[:]), nil
}