-
Notifications
You must be signed in to change notification settings - Fork 6
/
handler.go
168 lines (147 loc) · 4.67 KB
/
handler.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
/*
Copyright 2021 Adevinta
*/
package saml
import (
"errors"
"net/http"
"net/url"
"time"
)
// tokenExpiresAt is the lifetime of an issued session token, measured from
// the moment it is generated. It sets both the token's "exp" claim and the
// expiry of the cookie that carries it.
const tokenExpiresAt = 6 * time.Hour

// redirectQueryParam names the request parameter LoginHandler reads to learn
// where the user should be sent after authenticating.
const redirectQueryParam = "redirect_to"
// Errors reported by the login callback. Their messages are written verbatim
// as HTTP response bodies.
var (
	// ErrSAMLRequest is reported when the SAML callback request's form
	// cannot be parsed.
	ErrSAMLRequest = errors.New("malformed SAML callback request")

	// ErrUserDataCallback is reported when the configured UserDataCallback
	// returns an error.
	ErrUserDataCallback = errors.New("error on user data callback")

	// ErrGeneratingToken is reported when the session (JWT) token cannot be
	// generated.
	ErrGeneratingToken = errors.New("error generating token")
)

// Errors reported when validating the post-login redirect target.
var (
	// ErrRelayStateInvalid is reported when the requested "redirect_to"
	// value is not a valid URL.
	ErrRelayStateInvalid = errors.New("invalid RelayState URL")

	// ErrUntrustedDomain is reported when the requested redirect would lead
	// to a domain outside the trusted list.
	ErrUntrustedDomain = errors.New("redirect to an untrusted domain was requested")
)
// UserDataCallback is invoked with the user data obtained from the SAML
// response, before any session token is issued. Returning a non-nil error
// aborts the login.
type UserDataCallback func(data UserData) error
// TokenGenerator mints a session token from a set of claims. It is designed
// with bearer tokens (OAuth / JWT style) in mind: the returned string is
// stored as-is as the value of the session cookie.
type TokenGenerator func(claims map[string]interface{}) (string, error)
// CallbackConfig holds the options that control how LoginCallbackHandler
// establishes the user's session.
type CallbackConfig struct {
	// CookieName is the name of the cookie that carries the session token.
	CookieName string
	// CookieDomain is the Domain attribute of the session cookie.
	CookieDomain string
	// CookieSecure sets the Secure attribute of the session cookie.
	CookieSecure bool
	// UserDataCallback, when non-nil, receives the user's data before a token
	// is issued; an error from it aborts the login.
	UserDataCallback UserDataCallback
	// TokenGenerator mints the session token. It must be set; the login
	// cannot complete without a token.
	TokenGenerator TokenGenerator
}
// Handler exposes the HTTP endpoints that drive SAML authentication.
type Handler interface {
	// LoginHandler starts the login flow by redirecting the user to the
	// identity provider.
	LoginHandler() http.HandlerFunc
	// LoginCallbackHandler consumes the identity provider's response and
	// establishes the user's session according to cfg.
	LoginCallbackHandler(cfg CallbackConfig) http.HandlerFunc
}
// handler is the package's Handler implementation.
type handler struct {
	// p is the SAML provider used to build IdP login URLs and to extract
	// user data from SAML responses.
	p Provider
	// trustedDomains lists the hostnames that off-host redirect targets are
	// allowed to point at.
	trustedDomains []string
}
// NewHandler builds a new SAML handler from a SAML provider and a list of
// trusted domains: the hostnames that post-login redirects may target.
//
// trustedDomains is copied, so later changes to the caller's slice cannot
// alter (or race with) the allow-list used while serving requests.
func NewHandler(provider Provider, trustedDomains []string) Handler {
	return &handler{
		p:              provider,
		trustedDomains: append([]string(nil), trustedDomains...),
	}
}
// LoginHandler returns the function to handle login requests through a SAML
// federated identity provider (IdP).
//
// The 'redirect_to' request parameter indicates where the user should be
// redirected once authentication through the IdP has completed. It is
// accepted when it stays on the current host (e.g. "/home") or when its
// hostname is one of the handler's trusted domains; otherwise the request is
// answered with 400 Bad Request. The accepted value is query-escaped and
// handed to the provider when building the IdP authentication URL.
func (h *handler) LoginHandler() http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		redirectPath := r.FormValue(redirectQueryParam)
		if err := h.validateRedirect(redirectPath); err != nil {
			writeResp(w, http.StatusBadRequest, err)
			return
		}
		authURL, err := h.p.BuildAuthURL(url.QueryEscape(redirectPath))
		if err != nil {
			// Without a usable IdP URL there is nowhere to send the user, so
			// fail explicitly instead of redirecting to a meaningless location.
			// The provider's own error text is not echoed to the client.
			writeResp(w, http.StatusInternalServerError, errors.New("error building IdP authentication URL"))
			return
		}
		http.Redirect(w, r, authURL, http.StatusFound)
	}
}

// LoginCallbackHandler returns the function to handle the SAML callback
// response after authentication has been performed through the identity
// provider.
//
// On success it sets the cfg.CookieName cookie to a freshly generated token
// and redirects to the RelayState destination. Error responses:
//   - 400 if the form cannot be parsed, the RelayState is not a valid or
//     trusted redirect target, the provider rejects the SAML response, or
//     cfg.UserDataCallback fails;
//   - 403 if the provider reports ErrNotInAudience;
//   - 500 if cfg.TokenGenerator is nil or fails.
func (h *handler) LoginCallbackHandler(cfg CallbackConfig) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		if err := r.ParseForm(); err != nil {
			writeResp(w, http.StatusBadRequest, ErrSAMLRequest)
			return
		}

		// RelayState arrives in this request's form, not from LoginHandler's
		// already-validated input, so it is checked again before it is used
		// as a redirect target. Checking first also avoids consuming the
		// SAML response for a request that will be rejected anyway.
		relayState, err := url.QueryUnescape(r.FormValue("RelayState"))
		if err != nil {
			writeResp(w, http.StatusBadRequest, ErrRelayStateInvalid)
			return
		}
		if err = h.validateRedirect(relayState); err != nil {
			writeResp(w, http.StatusBadRequest, err)
			return
		}

		userData, err := h.p.GetUserData(r.FormValue("SAMLResponse"))
		if err != nil {
			status := http.StatusBadRequest
			if errors.Is(err, ErrNotInAudience) {
				status = http.StatusForbidden
			}
			writeResp(w, status, err)
			return
		}

		if cfg.UserDataCallback != nil {
			if err = cfg.UserDataCallback(userData); err != nil {
				writeResp(w, http.StatusBadRequest, ErrUserDataCallback)
				return
			}
		}

		if cfg.TokenGenerator == nil {
			// Misconfiguration: report it instead of panicking on a nil func.
			writeResp(w, http.StatusInternalServerError, ErrGeneratingToken)
			return
		}
		issuedAt := time.Now()
		expiresAt := issuedAt.Add(tokenExpiresAt)
		claims := map[string]interface{}{
			"first_name": userData.FirstName,
			"last_name":  userData.LastName,
			"email":      userData.Email,
			"username":   userData.UserName,
			"iat":        issuedAt.Unix(),
			"exp":        expiresAt.Unix(),
			"sub":        userData.Email,
		}
		token, err := cfg.TokenGenerator(claims)
		if err != nil {
			// Minting the token is a server-side failure, not a client error.
			writeResp(w, http.StatusInternalServerError, ErrGeneratingToken)
			return
		}

		// NOTE(review): the cookie is not HttpOnly, presumably so client-side
		// code can read the bearer token; confirm before tightening.
		http.SetCookie(w, &http.Cookie{
			Path:    "/",
			Name:    cfg.CookieName,
			Value:   token,
			Expires: expiresAt,
			Domain:  cfg.CookieDomain,
			Secure:  cfg.CookieSecure,
		})
		http.Redirect(w, r, relayState, http.StatusFound)
	}
}

// validateRedirect checks whether target is an acceptable post-login redirect
// destination. It returns:
//   - nil for references that stay on the current host (including "");
//   - ErrRelayStateInvalid if target cannot be parsed as a URL;
//   - ErrUntrustedDomain if target would send the browser to a host that is
//     not in h.trustedDomains, or to an absolute URL with no host at all.
//
// Besides absolute URLs, scheme-relative references ("//evil.com/x") and
// their browser-equivalent spellings ("/\evil.com", "///evil.com") are
// treated as leaving the current host; url.IsAbs alone misses all of them.
func (h *handler) validateRedirect(target string) error {
	u, err := url.Parse(target)
	if err != nil {
		return ErrRelayStateInvalid
	}
	if !u.IsAbs() && u.Host == "" && !hasLeadingDoubleSlash(u.Path) {
		return nil
	}
	host := u.Hostname()
	if host == "" {
		// e.g. "javascript:alert(1)" or "///evil.com": there is no host that
		// could be matched against the allow-list.
		return ErrUntrustedDomain
	}
	for _, domain := range h.trustedDomains {
		if host == domain {
			return nil
		}
	}
	return ErrUntrustedDomain
}

// hasLeadingDoubleSlash reports whether p begins with two characters that are
// each '/' or '\'. Browsers treat '\' like '/' in http(s) URLs and read such
// a reference as "//host/...", i.e. as pointing at another host.
func hasLeadingDoubleSlash(p string) bool {
	isSep := func(c byte) bool { return c == '/' || c == '\\' }
	return len(p) >= 2 && isSep(p[0]) && isSep(p[1])
}
// writeResp writes err's message as a plain-text response body with the given
// HTTP status code. If err is nil, the standard status text is used instead.
//
// Content-Type and X-Content-Type-Options are set explicitly so browsers never
// content-sniff the body. Some messages (e.g. errors passed through from the
// SAML provider) may contain text this package does not control, and that
// text must not be rendered as HTML.
func writeResp(w http.ResponseWriter, code int, err error) {
	msg := http.StatusText(code)
	if err != nil {
		msg = err.Error()
	}
	hdr := w.Header()
	hdr.Set("Content-Type", "text/plain; charset=utf-8")
	hdr.Set("X-Content-Type-Options", "nosniff")
	w.WriteHeader(code)
	// A failed write means the client is gone; there is nobody left to tell.
	_, _ = w.Write([]byte(msg))
}