/
gosamlserviceprovider.go
428 lines (368 loc) · 14.3 KB
/
gosamlserviceprovider.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
package samlprovider
import (
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/xml"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"regexp"
"strings"
securityprotocol "github.com/KvalitetsIT/gosecurityprotocol"
saml2 "github.com/russellhaering/gosaml2"
"github.com/russellhaering/gosaml2/types"
dsig "github.com/russellhaering/goxmldsig"
"go.uber.org/zap"
)
// SamlServiceProviderConfig holds the settings used to build a
// SamlServiceProvider (see NewSamlServiceProviderFromConfig).
//
// Several fields are not referenced anywhere in this file; they are
// presumably consumed by SamlHandler — NOTE(review): confirm there.
type SamlServiceProviderConfig struct {
	// ServiceProviderKeystore is the SP certificate/key pair handed to gosaml2
	// as SPKeyStore via dsig.TLSCertKeyStore. It is dereferenced when the
	// provider is created, so it is expected to be non-nil.
	ServiceProviderKeystore *tls.Certificate
	// EntityId is used as the SP issuer (ServiceProviderIssuer).
	EntityId string
	// CookieDomain is not referenced in this file (presumably the session cookie domain).
	CookieDomain string
	// CookiePath is not referenced in this file (presumably the session cookie path).
	CookiePath string
	// AudienceRestriction is passed to gosaml2 as AudienceURI.
	AudienceRestriction string
	// SignAuthnRequest is passed to gosaml2 as SignAuthnRequests.
	SignAuthnRequest bool
	// IdpMetaDataUrl is the URL the IdP metadata is downloaded from (HTTP GET).
	IdpMetaDataUrl string
	// SessionHeaderName is the request header that carries the session id
	// when a request is forwarded to the protected service.
	SessionHeaderName string
	// SessionExpiryHours is not referenced in this file; kept as a string —
	// TODO confirm how/where it is parsed.
	SessionExpiryHours string
	// SessiondataHeaderName, when non-empty, is the request header that carries
	// the base64-encoded JSON session data on forwarded requests.
	SessiondataHeaderName string
	// SkipSignatureValidation is passed to gosaml2 unchanged (presumably
	// disables IdP signature checks — verify against gosaml2 docs).
	SkipSignatureValidation bool
	// ExternalUrl is the externally visible base URL; the ACS and SLO URLs and
	// the relay state of authentication requests are built from it.
	ExternalUrl string
	// SamlMetadataPath is not referenced in this file.
	SamlMetadataPath string
	// SamlLogoutPath is not referenced in this file.
	SamlLogoutPath string
	// SamlSLOPath is the path (relative to ExternalUrl) of the SP single logout endpoint.
	SamlSLOPath string
	// SamlSSOPath is the path (relative to ExternalUrl) of the assertion consumer service.
	SamlSSOPath string
	// LogoutLandingPage is not referenced in this file.
	LogoutLandingPage string
	// RoleAttributeName is the user attribute holding the user's roles; role
	// checks in HandleService only run when this and AllowedRoles are set.
	RoleAttributeName string
	// AllowedRoles lists alternative role sets: entries are OR-ed, and the
	// whitespace-separated roles inside one entry are AND-ed,
	// e.g. ["admin public", "root"] means (admin AND public) OR root.
	AllowedRoles []string
	// Logger is used for all logging here and is called unconditionally, so
	// it is expected to be non-nil.
	Logger *zap.SugaredLogger
}
// SamlServiceProvider guards HTTP services behind a SAML login. It combines a
// gosaml2 service provider, a session cache and a SamlHandler for SAML
// protocol requests; see HandleService for the request flow.
type SamlServiceProvider struct {
	// sessionCache stores sessions; newSamlServiceProvider installs it wrapped
	// in a WrappingSessionCache.
	sessionCache securityprotocol.SessionCache
	// sessionHeaderName is the header that carries the session id on requests
	// forwarded to the protected service.
	sessionHeaderName string
	// SessiondataHeaderName, when non-empty, is the header that carries the
	// base64-encoded JSON session data on forwarded requests.
	SessiondataHeaderName string
	// externalUrl is the base URL used to build the relay state for logins.
	externalUrl string
	// SamlServiceProvider is the underlying gosaml2 provider.
	SamlServiceProvider *saml2.SAMLServiceProvider
	// SamlHandler handles SAML protocol requests and extracts session ids.
	SamlHandler *SamlHandler
	// Logger is used for all logging in this type's methods.
	Logger *zap.SugaredLogger
	// Config is the configuration the provider was built from; HandleService
	// reads AllowedRoles and RoleAttributeName from it (may be nil).
	Config *SamlServiceProviderConfig
}
// NewSamlServiceProviderFromConfig builds a SamlServiceProvider from config.
//
// The IdP metadata is downloaded and parsed while building the underlying
// gosaml2 provider; any error from that step is returned unchanged and no
// provider is created. sessionCache is used (wrapped) for session storage.
func NewSamlServiceProviderFromConfig(config *SamlServiceProviderConfig, sessionCache securityprotocol.SessionCache) (*SamlServiceProvider, error) {
	var (
		gosamlProvider *saml2.SAMLServiceProvider
		buildErr       error
	)
	if gosamlProvider, buildErr = createSamlServiceProvider(config); buildErr != nil {
		return nil, buildErr
	}
	return newSamlServiceProvider(gosamlProvider, sessionCache, config), nil
}
// newSamlServiceProvider assembles a SamlServiceProvider around an already
// built gosaml2 provider.
//
// sessionCache is wrapped with wrappingSessionCache before it is stored.
// Every field is populated before NewSamlHandler is called. The handler
// receives the provider itself, so it must never observe a partially
// initialised value (previously Logger and Config were still nil at that
// point).
func newSamlServiceProvider(samlServiceProvider *saml2.SAMLServiceProvider, sessionCache securityprotocol.SessionCache, config *SamlServiceProviderConfig) *SamlServiceProvider {
	s := &SamlServiceProvider{
		SamlServiceProvider:   samlServiceProvider,
		sessionCache:          wrappingSessionCache(sessionCache),
		sessionHeaderName:     config.SessionHeaderName,
		externalUrl:           config.ExternalUrl,
		SessiondataHeaderName: config.SessiondataHeaderName,
		Logger:                config.Logger,
		// Config is kept so HandleService can evaluate AllowedRoles/RoleAttributeName.
		// NOTE(review): the original carried "todo: ask Eva if this is okay" —
		// confirm that retaining the whole config here is intended.
		Config: config,
	}
	s.SamlHandler = NewSamlHandler(config, s)
	return s
}
// WrappingSessionCache decorates another securityprotocol.SessionCache.
// Only SaveSessionData adds behaviour: it removes " \n" sequences from the
// decoded authentication token before saving. The other methods delegate
// unchanged.
type WrappingSessionCache struct {
	// sessionCache is the cache every call is delegated to.
	sessionCache securityprotocol.SessionCache
}
// SaveSessionData normalises data's authentication token and then saves data
// in the wrapped cache.
//
// data.Authenticationtoken is base64-decoded (standard encoding). If the
// decoded text contains the two-character sequence " \n" (space, newline),
// every occurrence is removed and the token is re-encoded. This writes
// through to the caller's *data. A token that is not valid base64 makes the
// call fail with the decode error, and nothing is saved.
//
// Workaround referenced by the original comments:
// https://github.com/keycloak/keycloak/issues/14529 and
// https://www.w3.org/Signature/Drafts/WD-xml-c14n-20000907.html#Example-Chars
func (w WrappingSessionCache) SaveSessionData(data *securityprotocol.SessionData) error {
	const artifact = " \n"

	decoded, decodeErr := base64.StdEncoding.DecodeString(string(data.Authenticationtoken))
	if decodeErr != nil {
		return decodeErr
	}

	if xmlText := string(decoded); strings.Contains(xmlText, artifact) {
		cleaned := strings.ReplaceAll(xmlText, artifact, "")
		data.Authenticationtoken = base64.StdEncoding.EncodeToString([]byte(cleaned))
	}

	return w.sessionCache.SaveSessionData(data)
}
// FindSessionDataForSessionId looks sessionId up in the wrapped cache and
// returns its result unchanged.
func (w WrappingSessionCache) FindSessionDataForSessionId(sessionId string) (*securityprotocol.SessionData, error) {
	inner := w.sessionCache
	return inner.FindSessionDataForSessionId(sessionId)
}
// DeleteSessionData removes sessionId from the wrapped cache and returns its
// error unchanged.
func (w WrappingSessionCache) DeleteSessionData(sessionId string) error {
	inner := w.sessionCache
	return inner.DeleteSessionData(sessionId)
}
// wrappingSessionCache returns cache decorated with a *WrappingSessionCache
// (see WrappingSessionCache.SaveSessionData for the added behaviour).
func wrappingSessionCache(cache securityprotocol.SessionCache) securityprotocol.SessionCache {
	return &WrappingSessionCache{sessionCache: cache}
}
// GetSessionCache returns a pointer to samlServiceProvider's sessionCache
// field (not a copy of the interface value).
func GetSessionCache(samlServiceProvider *SamlServiceProvider) *securityprotocol.SessionCache {
	cacheField := &samlServiceProvider.sessionCache
	return cacheField
}
// createSamlServiceProvider builds the underlying gosaml2 service provider
// from config and the IdP metadata.
//
// The metadata is downloaded from config.IdpMetaDataUrl and unmarshalled as
// an EntityDescriptor. Every X509 certificate in the IdP's key descriptors
// goes into the store used to validate IdP signatures. The first
// SingleSignOnService and SingleLogoutService locations become the IdP SSO
// and SLO URLs.
//
// Errors are returned (and logged) when downloading, unmarshalling or
// certificate decoding fails, and in cases that previously panicked:
//   - the metadata has no IDPSSODescriptor;
//   - the metadata has no SSO or SLO endpoint;
//   - config.ServiceProviderKeystore is nil.
func createSamlServiceProvider(config *SamlServiceProviderConfig) (*saml2.SAMLServiceProvider, error) {
	// Read and parse the IdP metadata
	rawMetadata, err := DownloadIdpMetadata(config)
	if err != nil {
		config.Logger.Errorf("Error downloading IdP metadata: %s", err.Error())
		return nil, err
	}
	idpMetadata := &types.EntityDescriptor{}
	err = xml.Unmarshal(rawMetadata, idpMetadata)
	if err != nil {
		config.Logger.Errorf("Cannot unmarshal IDP metadata: %s", err.Error())
		return nil, err
	}

	// fail logs msg and returns it as an error. It replaces the nil
	// dereferences / index-out-of-range panics the checks below guard against.
	fail := func(msg string) (*saml2.SAMLServiceProvider, error) {
		config.Logger.Error(msg)
		return nil, errors.New(msg)
	}
	idpDescriptor := idpMetadata.IDPSSODescriptor
	if idpDescriptor == nil {
		return fail("IdP metadata contains no IDPSSODescriptor")
	}
	if len(idpDescriptor.SingleSignOnServices) == 0 {
		return fail("IdP metadata declares no SingleSignOnService")
	}
	if len(idpDescriptor.SingleLogoutServices) == 0 {
		return fail("IdP metadata declares no SingleLogoutService")
	}
	if config.ServiceProviderKeystore == nil {
		return fail("no service provider keystore configured")
	}

	// Collect the IdP certificates used to validate IdP signatures.
	certStore := dsig.MemoryX509CertificateStore{
		Roots: []*x509.Certificate{},
	}
	for _, kd := range idpDescriptor.KeyDescriptors {
		for idx, xcert := range kd.KeyInfo.X509Data.X509Certificates {
			if xcert.Data == "" {
				return nil, fmt.Errorf("metadata certificate(%d) must not be empty", idx)
			}
			certData, err := base64.StdEncoding.DecodeString(xcert.Data)
			if err != nil {
				config.Logger.Errorf("Error decoding certificate: %s", err.Error())
				return nil, err
			}
			idpCert, err := x509.ParseCertificate(certData)
			if err != nil {
				config.Logger.Errorf("Error parsing certificate: %s", err.Error())
				return nil, err
			}
			certStore.Roots = append(certStore.Roots, idpCert)
		}
	}

	spKeyStore := dsig.TLSCertKeyStore(*config.ServiceProviderKeystore)
	sp := &saml2.SAMLServiceProvider{
		IdentityProviderSSOURL:      idpDescriptor.SingleSignOnServices[0].Location,
		IdentityProviderSLOURL:      idpDescriptor.SingleLogoutServices[0].Location,
		IdentityProviderIssuer:      idpMetadata.EntityID,
		ServiceProviderIssuer:       config.EntityId,
		AssertionConsumerServiceURL: config.AssertionConsumerServiceUrl(),
		ServiceProviderSLOURL:       config.SloConsumerServiceUrl(),
		SignAuthnRequests:           config.SignAuthnRequest,
		AudienceURI:                 config.AudienceRestriction,
		IDPCertificateStore:         &certStore,
		SPKeyStore:                  spKeyStore,
		SkipSignatureValidation:     config.SkipSignatureValidation,
	}
	//signingContext := sp.SigningContext()
	//signingContext.Canonicalizer = dsig.MakeC14N10ExclusiveCanonicalizerWithPrefixList("")
	return sp, nil
}
// AssertionConsumerServiceUrl returns the absolute URL of the SP assertion
// consumer endpoint: ExternalUrl joined with SamlSSOPath.
func (config *SamlServiceProviderConfig) AssertionConsumerServiceUrl() string {
	base, ssoPath := config.ExternalUrl, config.SamlSSOPath
	return buildUrl(base, ssoPath)
}
// SloConsumerServiceUrl returns the absolute URL of the SP single logout
// endpoint: ExternalUrl joined with SamlSLOPath.
func (config *SamlServiceProviderConfig) SloConsumerServiceUrl() string {
	base, sloPath := config.ExternalUrl, config.SamlSLOPath
	return buildUrl(base, sloPath)
}
// Patterns used by buildUrl, compiled once instead of on every call.
var (
	// buildUrlTrailingSlash matches a single "/" at the end of the base URL.
	buildUrlTrailingSlash = regexp.MustCompile("/$")
	// buildUrlLeadingSlash matches the whole path with an optional leading "/".
	buildUrlLeadingSlash = regexp.MustCompile("^/?(.*)$")
)

// buildUrl joins baseUrl and path with a single "/" at the seam. One trailing
// "/" is dropped from baseUrl, and path gets a leading "/" unless it already
// has one. Other slashes are kept, e.g.
// buildUrl("https://sp/", "//x") == "https://sp//x" and buildUrl("", "") == "/".
func buildUrl(baseUrl string, path string) string {
	baseUrl = buildUrlTrailingSlash.ReplaceAllString(baseUrl, "")
	path = buildUrlLeadingSlash.ReplaceAllString(path, "/${1}")
	return baseUrl + path
}
// DownloadIdpMetadata fetches the IdP metadata from config.IdpMetaDataUrl
// with an HTTP GET. It returns the result of EntityDescriptor applied to the
// response body.
//
// Transport errors and body read errors are returned as-is. Any status other
// than 200 yields the error "Cannot download metadata". The response body is
// closed on every path.
func DownloadIdpMetadata(config *SamlServiceProviderConfig) ([]byte, error) {
	//download metadata from idp
	config.Logger.Infof("Downloading IDP metadata from: %s", config.IdpMetaDataUrl)
	resp, err := http.Get(config.IdpMetaDataUrl)
	if err != nil {
		config.Logger.Errorf("Cannot download metadata: %s", err.Error())
		return nil, err
	}
	// Deferred before the status check so the body is closed on the non-200 path too.
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// err is nil here; log the HTTP status instead of dereferencing it.
		config.Logger.Errorf("Cannot download metadata: unexpected HTTP status %s", resp.Status)
		return nil, errors.New("Cannot download metadata")
	}
	bodyBytes, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		config.Logger.Errorf("Cannot download metadata: %s", err.Error())
		return nil, err
	}
	return EntityDescriptor(bodyBytes)
}
// validateRole reports whether sessionData holds every role in roles.
//
// The session's roles are read from UserAttributes[attributeName]. It
// returns:
//   - an error if that attribute is absent;
//   - an error naming the first role (in roles order) that is missing;
//   - nil otherwise. With an empty roles slice this only requires the
//     attribute to be present.
//
// Iterating the caller's slice, rather than a map, makes the reported role
// deterministic.
func validateRole(roles []string, attributeName string, sessionData *securityprotocol.SessionData) error {
	presentedRoles, ok := sessionData.UserAttributes[attributeName]
	if !ok {
		return fmt.Errorf("no field with attribute name %s present", attributeName)
	}

	present := make(map[string]struct{}, len(presentedRoles))
	for _, role := range presentedRoles {
		present[role] = struct{}{}
	}

	for _, role := range roles {
		if _, ok := present[role]; !ok {
			return fmt.Errorf("role %s not set", role)
		}
	}
	return nil
}
// HandleService is the entry point for requests to the protected service.
//
// Requests that SamlHandler.isSamlProtocol recognises go to SamlHandler.
// Otherwise the session id comes from SamlHandler.GetSessionId and is looked
// up in the session cache. When a session is found:
//   - If Config.AllowedRoles and Config.RoleAttributeName are set, the
//     session must satisfy at least one AllowedRoles entry. Each entry is a
//     whitespace-separated list of roles that must all be present, e.g.
//     ["admin public", "root"] means (admin AND public) OR root. Blank entries
//     are ignored. A session matching no entry gets 401.
//   - Session-data requests (securityprotocol.IsRequestForSessionData) are
//     answered directly.
//   - Otherwise the session id is set on the sessionHeaderName header. If
//     SessiondataHeaderName is set, the base64 JSON session data is set on
//     that header. The request is then passed to service.
//
// A lookup failure returns 500. Requests without a known session are sent to
// the IdP via GenerateAuthenticationRequest.
func (a SamlServiceProvider) HandleService(w http.ResponseWriter, r *http.Request, service securityprotocol.HttpHandler) (int, error) {
	if a.SamlHandler.isSamlProtocol(r) {
		a.Logger.Debugf("Handling request as SAML")
		return a.SamlHandler.Handle(w, r)
	}

	// Get the session id
	sessionId := a.SamlHandler.GetSessionId(r)
	a.Logger.Debugf("SessionId: %s", sessionId)

	// The request identifies a session, check that the session is valid and get it
	if sessionId != "" {
		sessionData, err := a.sessionCache.FindSessionDataForSessionId(sessionId)
		if err != nil {
			a.Logger.Errorf("Cannot look up session in cache: %v", err.Error())
			return http.StatusInternalServerError, err
		}
		if sessionData != nil {
			if a.Config != nil && len(a.Config.AllowedRoles) > 0 && a.Config.RoleAttributeName != "" {
				roleErr := errors.New("could not find a valid role")
				for _, role := range a.Config.AllowedRoles {
					andRoles := strings.Fields(role)
					if len(andRoles) == 0 {
						// A blank entry requires no roles, so it would admit any session
						// that merely carries the role attribute; ignore it.
						continue
					}
					if validateRole(andRoles, a.Config.RoleAttributeName, sessionData) == nil {
						// A satisfied entry is enough (entries are OR-ed).
						roleErr = nil
						break
					}
				}
				if roleErr != nil {
					a.Logger.Error(roleErr.Error())
					return http.StatusUnauthorized, roleErr
				}
			}

			// Check if the user is requesting sessiondata
			if handlerFunc := securityprotocol.IsRequestForSessionData(sessionData, a.sessionCache, w, r); handlerFunc != nil {
				a.Logger.Debugf("Handling session data request")
				return handlerFunc()
			}

			// Set (not Add) so the service sees only the validated session id,
			// never a client-supplied value of the same header ahead of it.
			r.Header.Set(a.sessionHeaderName, sessionId)
			if len(a.SessiondataHeaderName) > 0 {
				sessionDataValue, err := getSessionDataValue(sessionData)
				if err != nil {
					a.Logger.Errorf("Error '%s' creating sessiondatavalue for header (sesssionid: %s)", err.Error(), sessionId)
					return http.StatusInternalServerError, err
				}
				r.Header.Set(a.SessiondataHeaderName, sessionDataValue)
			}
			return service.Handle(w, r)
		}
	}

	return a.GenerateAuthenticationRequest(w, r)
}
// getSessionDataValue serialises sessionData as JSON and returns it encoded
// with standard base64. Marshalling errors are returned with an empty string.
func getSessionDataValue(sessionData *securityprotocol.SessionData) (string, error) {
	raw, err := json.Marshal(sessionData)
	if err != nil {
		return "", err
	}
	return base64.StdEncoding.EncodeToString(raw), nil
}
// CreateLogoutResponse writes a successful SAML logout response for
// logoutRequest (built for logoutRequest.ID, with an empty relay state) to w.
//
// The body produced by gosaml2's BuildLogoutResponseBodyPostFromDocument is
// an HTML page that posts the response. It is served explicitly as text/html
// rather than relying on content sniffing.
//
// Building failures return 500 before anything is written. A failed write
// is logged and returned together with 200, the status already sent.
func (a *SamlServiceProvider) CreateLogoutResponse(logoutRequest *saml2.LogoutRequest, w http.ResponseWriter) (int, error) {
	status := saml2.StatusCodeSuccess
	relayState := ""
	responseDocTree, err := a.SamlServiceProvider.BuildLogoutResponseDocument(status, logoutRequest.ID)
	if err != nil {
		a.Logger.Errorf("Error building logout response: %s", err.Error())
		return http.StatusInternalServerError, err
	}
	responseBytes, err := a.SamlServiceProvider.BuildLogoutResponseBodyPostFromDocument(relayState, responseDocTree)
	if err != nil {
		a.Logger.Errorf("Error building logout response post from document: %s", err.Error())
		return http.StatusInternalServerError, err
	}

	w.Header().Set("Content-Type", "text/html; charset=utf-8")
	if _, err := w.Write(responseBytes); err != nil {
		a.Logger.Errorf("Error writing logout response: %s", err.Error())
		return http.StatusOK, err
	}
	return http.StatusOK, nil
}
// ParseLogoutPayload reads r.Body and validates it as an HTTP-POST SAML
// logout message.
//
// Only bodies that start with "SAMLResponse=" or "SAMLRequest=" are
// recognised. The value runs up to the next "&" (so a trailing
// "&RelayState=..." no longer corrupts it) and is URL-unescaped, or used raw
// if unescaping fails.
//
// Return values:
//   - a SAMLResponse yields (nil, logoutResponse, nil);
//   - a SAMLRequest yields (logoutRequest, nil, nil);
//   - an empty or unrecognised body yields (nil, nil, nil);
//   - body read and validation failures are returned as errors.
//
// The request body is fully consumed.
func (a *SamlServiceProvider) ParseLogoutPayload(r *http.Request) (*saml2.LogoutRequest, *types.LogoutResponse, error) {
	encodedRequest, err := ioutil.ReadAll(r.Body)
	if err != nil {
		a.Logger.Errorf("Error reading body of logout request: %s", err.Error())
		return nil, nil, err
	}
	encodedRequestString := string(encodedRequest)
	a.Logger.Debugf("Considering logout payload: %s", encodedRequestString)
	if len(encodedRequest) == 0 {
		return nil, nil, nil
	}

	// formValue returns the decoded value of the leading form field whose
	// "name=" prefix is given, and whether the body starts with that prefix.
	formValue := func(prefix string) (string, bool) {
		if !strings.HasPrefix(encodedRequestString, prefix) {
			return "", false
		}
		urlEncoded := strings.TrimPrefix(encodedRequestString, prefix)
		// Stop at the next form field; base64 (raw or %-escaped) never contains a literal '&'.
		if end := strings.IndexByte(urlEncoded, '&'); end >= 0 {
			urlEncoded = urlEncoded[:end]
		}
		urlDecoded, err := url.QueryUnescape(urlEncoded)
		if err != nil {
			// Lets assume it was not urlDecoded
			return urlEncoded, true
		}
		return urlDecoded, true
	}

	if urlDecoded, ok := formValue("SAMLResponse="); ok {
		a.Logger.Debugf("Processing payload: as SAMLResponse %s", urlDecoded)
		logoutResponse, err := a.SamlServiceProvider.ValidateEncodedLogoutResponsePOST(urlDecoded)
		if err != nil {
			a.Logger.Errorf("Error validating encoded logout response (decoded payload: %s) (error: %s)", urlDecoded, err.Error())
			return nil, logoutResponse, err
		}
		if logoutResponse == nil {
			a.Logger.Errorf("Could not validate logoutResponse: %s", encodedRequestString)
			return nil, nil, errors.New("Could not validate logoutResponse")
		}
		return nil, logoutResponse, nil
	}

	if urlDecoded, ok := formValue("SAMLRequest="); ok {
		a.Logger.Debugf("Processing payload: as SAMLRequest %s", urlDecoded)
		logoutRequest, err := a.SamlServiceProvider.ValidateEncodedLogoutRequestPOST(urlDecoded)
		if err != nil {
			a.Logger.Errorf("Error validating encoded logout request (decoded payload: %s) (error: %s)", urlDecoded, err.Error())
			return nil, nil, err
		}
		if logoutRequest == nil {
			a.Logger.Errorf("Could not validate logoutrequest: %s", encodedRequestString)
			return nil, nil, errors.New("Could not validate logout request")
		}
		return logoutRequest, nil, nil
	}

	a.Logger.Debugf("Could not determine payload: %s", encodedRequestString)
	return nil, nil, nil
}
// GenerateAuthenticationRequest redirects the client to the IdP via gosaml2's
// AuthRedirect. The relay state is the external URL of the current request
// (presumably used to return the user there after login).
//
// It returns 302 on success. If the redirect cannot be produced, the error is
// logged and returned with 500.
func (a SamlServiceProvider) GenerateAuthenticationRequest(w http.ResponseWriter, r *http.Request) (int, error) {
	a.Logger.Debugf("No Session found, redirecting to IDP")
	returnTo := buildUrl(a.externalUrl, r.RequestURI)
	if redirectErr := a.SamlServiceProvider.AuthRedirect(w, r, returnTo); redirectErr != nil {
		a.Logger.Errorf("Error generating authentication request: %s", redirectErr.Error())
		return http.StatusInternalServerError, redirectErr
	}
	return http.StatusFound, nil
}
// Metadata returns the SP metadata generated by gosaml2. Its SingleLogout
// services are replaced by a single HTTP-POST endpoint at the provider's
// ServiceProviderSLOURL. Errors from gosaml2 are logged and returned together
// with whatever descriptor it produced.
func (a *SamlServiceProvider) Metadata() (*types.EntityDescriptor, error) {
	descriptor, err := a.SamlServiceProvider.Metadata()
	if err != nil {
		a.Logger.Errorf("Error getting metadata from samlprovider: %s", err.Error())
		return descriptor, err
	}

	sloEndpoint := types.Endpoint{
		Binding:  saml2.BindingHttpPost,
		Location: a.SamlServiceProvider.ServiceProviderSLOURL,
	}
	descriptor.SPSSODescriptor.SingleLogoutServices = []types.Endpoint{sloEndpoint}
	return descriptor, nil
}