/
jwks.go
250 lines (230 loc) · 6.82 KB
/
jwks.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
package auth
// https://github.com/pascaldekloe/jwt
// https://github.com/dgrijalva/jwt-go
// https://github.com/golang-jwt/jwt
// https://github.com/MicahParks/keyfunc
import (
"crypto/rsa"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"math/big"
"net/http"
"strings"
"time"
"github.com/pascaldekloe/jwt"
// jwtgo "github.com/dgrijalva/jwt-go"
// "github.com/MicahParks/keyfunc"
)
// Keys represents one JSON Web Key (JWK, RFC 7517) taken from a provider's
// JWK Set. Besides the common members it maps the RSA members "n" and "e"
// and the X.509 certificate members.
type Keys struct {
	Kid     string   `json:"kid"`      // key ID; Init stores it to match the "kid" of tokens
	Kty     string   `json:"kty"`      // key type, e.g. "RSA"
	Alg     string   `json:"alg"`      // algorithm the key is intended for, e.g. "RS256"
	Use     string   `json:"use"`      // intended use, e.g. "sig" or "enc"
	N       string   `json:"n"`        // RSA modulus, base64url-encoded
	E       string   `json:"e"`        // RSA public exponent, base64url-encoded
	X5c     []string `json:"x5c"`      // X.509 certificate chain, base64 DER certificates
	X5y     string   `json:"x5t"`      // X.509 SHA-1 thumbprint (RFC 7517 member "x5t"; Go name kept for compatibility)
	Xt5S256 string   `json:"x5t#S256"` // X.509 SHA-256 thumbprint
}
// Certs is the JSON Web Key Set document served at a provider's jwks_uri.
//
// The field carries no json tag: when decoding, encoding/json matches the
// set's "keys" member onto it case-insensitively.
type Certs struct {
	Keys []Keys // the keys of the set, in document order
}
// OpenIDConfiguration is the part of an OpenID Provider's discovery document
// (served at "/.well-known/openid-configuration") that this package decodes.
// Each field holds the provider-metadata member named in its json tag.
type OpenIDConfiguration struct {
	Issuer                string   `json:"issuer"`
	AuthorizationEndpoint string   `json:"authorization_endpoint"`
	TokenEndpoint         string   `json:"token_endpoint"`
	IntrospectionEndpoint string   `json:"introspection_endpoint"`
	UserInfoEndpoint      string   `json:"userinfo_endpoint"`
	EndSessionEndpoint    string   `json:"end_session_endpoint"`
	JWKSUri               string   `json:"jwks_uri"` // fetched by Init to obtain the provider's keys
	ClaimsSupported       []string `json:"claims_supported"`
	ScopeSupported        []string `json:"scopes_supported"`
	RevocationEndpoint    string   `json:"revocation_endpoint"`
}
// publicKey pairs a decoded RSA verification key with the key ID ("kid")
// under which the provider published it.
type publicKey struct {
	key *rsa.PublicKey // decoded RSA public key
	kid string         // key ID, compared with the "kid" header of tokens
}
// Provider collects what Init learns about one OpenID provider: its base
// URL, its discovery configuration, the RSA keys it publishes and the raw
// JWK Set document those keys were decoded from.
type Provider struct {
	URL           string              // base URL passed to Init
	Configuration OpenIDConfiguration // decoded discovery document
	PublicKeys    []publicKey         // RSA keys decoded from the JWK Set
	JWKSBody      []byte              // raw JWK Set body fetched from jwks_uri
}
// String implements fmt.Stringer by rendering the provider as indented JSON.
//
// encoding/json skips unexported fields, so every PublicKeys entry appears
// as {}. JWKSBody is a []byte and is therefore shown base64-encoded. If
// marshalling fails, a short "Provider, error=..." text is returned instead.
func (p *Provider) String() string {
	data, marshalErr := json.MarshalIndent(p, "", " ")
	if marshalErr == nil {
		return string(data)
	}
	return fmt.Sprintf("Provider, error=%v", marshalErr)
}
// Init initializes the provider from its OpenID Connect discovery document.
//
// It fetches <purl>/.well-known/openid-configuration and then the JWK Set
// advertised by its jwks_uri. It decodes every RSA key of that set into an
// *rsa.PublicKey. Both requests use a client with a 30 second timeout and
// must answer 200 OK.
//
// On success p.URL, p.Configuration, p.JWKSBody and p.PublicKeys are all
// replaced, so calling Init again refreshes the keys instead of appending
// duplicates. On error p is left unchanged.
//
// Keys whose kty is not RSA, or whose RSA parameters cannot be decoded, are
// skipped, as RFC 7517 section 5 recommends for keys an implementation does
// not understand. Init fails only if no usable RSA key remains.
//
// When verbose > 0 the decoded configuration, skipped non-RSA keys and the
// resulting provider are logged.
func (p *Provider) Init(purl string, verbose int) error {
	// http.Get would use http.DefaultClient, which never times out; an
	// unresponsive provider would block Init forever.
	client := &http.Client{Timeout: 30 * time.Second}

	// fetch GETs target and returns the full response body. Any status other
	// than 200 is an error, so an HTML error page is never decoded as JSON.
	fetch := func(target string) ([]byte, error) {
		resp, err := client.Get(target)
		if err != nil {
			return nil, err
		}
		defer resp.Body.Close()
		body, err := io.ReadAll(resp.Body)
		if err != nil {
			return nil, fmt.Errorf("reading body of %s: %w", target, err)
		}
		if resp.StatusCode != http.StatusOK {
			return nil, fmt.Errorf("GET %s: unexpected status %s", target, resp.Status)
		}
		return body, nil
	}

	// Trim a trailing slash so "https://idp/" does not produce "//.well-known".
	confURL := strings.TrimSuffix(purl, "/") + "/.well-known/openid-configuration"
	confBody, err := fetch(confURL)
	if err != nil {
		return fmt.Errorf("fetching OpenID configuration of %s: %w", purl, err)
	}
	var conf OpenIDConfiguration
	if err := json.Unmarshal(confBody, &conf); err != nil {
		return fmt.Errorf("decoding OpenID configuration of %s: %w", purl, err)
	}
	if verbose > 0 {
		log.Println("provider configuration", conf)
	}
	if conf.JWKSUri == "" {
		return fmt.Errorf("OpenID configuration of %s has no jwks_uri", purl)
	}

	// Obtain the provider's public keys: fetch the JWK Set from jwks_uri and
	// decode the modulus and exponent of each RSA key.
	jwksBody, err := fetch(conf.JWKSUri)
	if err != nil {
		return fmt.Errorf("fetching JWK set of %s: %w", purl, err)
	}
	var certs Certs
	if err := json.Unmarshal(jwksBody, &certs); err != nil {
		return fmt.Errorf("decoding JWK set from %s: %w", conf.JWKSUri, err)
	}
	keys := make([]publicKey, 0, len(certs.Keys))
	for _, key := range certs.Keys {
		if !strings.EqualFold(key.Kty, "rsa") {
			if verbose > 0 {
				log.Printf("skipping key %q: not RSA kty key: %s", key.Kid, key.Kty)
			}
			continue
		}
		pub, err := getPublicKey(key.E, key.N)
		if err != nil {
			log.Printf("skipping key %q: unable to get public key: %v", key.Kid, err)
			continue
		}
		keys = append(keys, publicKey{pub, key.Kid})
	}
	if len(keys) == 0 {
		return fmt.Errorf("JWK set from %s contains no usable RSA key", conf.JWKSUri)
	}

	// Commit only once everything succeeded so a failed Init leaves p intact.
	p.URL = purl
	p.Configuration = conf
	p.JWKSBody = jwksBody
	p.PublicKeys = keys
	if verbose > 0 {
		log.Println("\n", p.String())
	}
	return nil
}
/*
// helper function to check given access token and return its claims
// it is based on github.com/dgrijalva/jwt-go and github.com/MicahParks/keyfunc go packages
func tokenClaims(provider Provider, accessToken string) (map[string]interface{}, error) {
out := make(map[string]interface{})
// Create the JWKS from the resource at the given URL.
jwks, err := keyfunc.New(provider.JWKSBody)
if err != nil {
return out, err
}
// Parse the JWT.
token, err := jwtgo.Parse(accessToken, jwks.KeyFunc)
if err != nil {
return out, err
}
// Check if the token is valid.
if !token.Valid {
msg := "The token is not valid"
return out, errors.New(msg)
}
if claims, ok := token.Claims.(jwtgo.MapClaims); ok {
for k, v := range claims {
out[k] = v
}
}
return out, nil
}
*/
// getPublicKey builds an RSA public key from the "e" (exponent) and "n"
// (modulus) members of an RSA JSON Web Key.
//
// It is based on the implementation of
// https://github.com/MicahParks/keyfunc/blob/master/rsa.go
//
// RFC 7518 section 6.3 defines both members as Base64urlUInt values: the
// unpadded base64url encoding of a big-endian unsigned integer
// (https://tools.ietf.org/html/rfc7518#section-6.3,
// https://tools.ietf.org/html/rfc7517#appendix-A.1). Trailing "=" padding
// is tolerated for providers that emit it anyway.
//
// An error is returned when either member is empty or not valid base64url,
// when the modulus is zero, or when the exponent lies outside
// [2, 2^31-1], the bounds crypto/rsa enforces when verifying signatures.
func getPublicKey(exp, mod string) (*rsa.PublicKey, error) {
	// decode converts one Base64urlUInt member into an integer; what names
	// the member in error messages.
	decode := func(what, value string) (*big.Int, error) {
		value = strings.TrimRight(value, "=")
		if value == "" {
			return nil, fmt.Errorf("empty RSA %s", what)
		}
		raw, err := base64.RawURLEncoding.DecodeString(value)
		if err != nil {
			return nil, fmt.Errorf("decoding RSA %s: %w", what, err)
		}
		return new(big.Int).SetBytes(raw), nil
	}

	e, err := decode("exponent", exp)
	if err != nil {
		return nil, err
	}
	n, err := decode("modulus", mod)
	if err != nil {
		return nil, err
	}

	// Range-check while still a big.Int, so oversized exponents are rejected
	// instead of being silently truncated when narrowed to int.
	if !e.IsInt64() || e.Int64() < 2 || e.Int64() > 1<<31-1 {
		return nil, fmt.Errorf("RSA exponent %v out of range", e)
	}
	if n.Sign() == 0 {
		return nil, errors.New("RSA modulus is zero")
	}
	return &rsa.PublicKey{N: n, E: int(e.Int64())}, nil
}
// tokenClaims verifies an RSA-signed JWT against the provider's public keys
// and returns its claims. It is built on the github.com/pascaldekloe/jwt
// package.
//
// The token is first parsed without signature verification only to read
// its key ID. Its signature is then checked with the matching provider key
// via jwt.RSACheck, and the claims' time constraints are checked against
// the current time with Claims.Valid.
//
// The returned map contains every entry of the verified claims set. On top
// of that, "exp" is overwritten with the expiry as Unix seconds (int64), and
// "sub", "iss" and "aud" are overwritten with the registered claim values.
// On error an empty, non-nil map is returned.
//
// NOTE(review): the issuer is not compared with provider.Configuration.Issuer
// and no audience check is made here; confirm callers do this if needed.
// NOTE(review): for a token without "exp", Expires.Time() presumably yields
// the zero time, making "exp" a large negative number; confirm with callers.
func tokenClaims(provider Provider, token string) (map[string]interface{}, error) {
	out := make(map[string]interface{})

	// Parse without checking the signature only to learn which key signed
	// the token; nothing from this unverified parse is returned.
	unverified, err := jwt.ParseWithoutCheck([]byte(token))
	if err != nil {
		return out, fmt.Errorf("parsing token: %w", err)
	}
	var pub *rsa.PublicKey
	for _, pubkey := range provider.PublicKeys {
		if pubkey.kid == unverified.KeyID {
			pub = pubkey.key
			break
		}
	}
	if pub == nil {
		return out, fmt.Errorf("key id %s not found", unverified.KeyID)
	}

	// Verify the signature; from here on only the verified claims are used.
	claims, err := jwt.RSACheck([]byte(token), pub)
	if err != nil {
		return out, fmt.Errorf("verifying token signature: %w", err)
	}
	if !claims.Valid(time.Now()) {
		msg := "The token is not valid"
		return out, errors.New(msg)
	}
	for k, v := range claims.Set {
		out[k] = v
	}
	out["exp"] = claims.Registered.Expires.Time().Unix()
	out["sub"] = claims.Subject
	out["iss"] = claims.Issuer
	out["aud"] = claims.Audiences
	return out, nil
}