/
scope.go
80 lines (70 loc) · 2.19 KB
/
scope.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
package oauth
import (
"encoding/json"
"net/http"
"strings"
"github.com/authgear/authgear-server/pkg/lib/oauth/protocol"
"github.com/authgear/authgear-server/pkg/lib/session"
"github.com/authgear/authgear-server/pkg/lib/session/idpsession"
)
const (
	// FullAccessScope is the scope string that SessionScopes reports for
	// every IDP session.
	FullAccessScope = "https://authgear.com/scopes/full-access"
	// FullUserInfoScope is the scope string identifying full user-info
	// access. NOTE(review): it is not referenced in this file; its exact
	// meaning is defined by its callers.
	FullUserInfoScope = "https://authgear.com/scopes/full-userinfo"
)
// SessionScopes returns the OAuth scopes held by s.
//
// An *idpsession.IDPSession always yields a fresh slice containing only
// FullAccessScope. An *OfflineGrant yields its own Scopes field, which is
// not a copy, so callers must not modify it. Any other session type,
// including a nil session, is treated as a programming error and panics.
func SessionScopes(s session.Session) []string {
	if _, isIDP := s.(*idpsession.IDPSession); isIDP {
		return []string{FullAccessScope}
	}
	if grant, isGrant := s.(*OfflineGrant); isGrant {
		return grant.Scopes
	}
	panic("oauth: unexpected session type")
}
// RequireScope returns middleware that lets a request through only if the
// request's session holds at least one of the given scopes.
//
// If no scopes are given, only the presence of a session is checked. When
// the check fails, next is not called. Instead the response gets the status
// chosen by checkAuthz, a WWW-Authenticate header built from the error
// response, and the error response encoded as JSON.
func RequireScope(scopes ...string) func(http.Handler) http.Handler {
	requiredScopes := make(map[string]struct{}, len(scopes))
	for _, s := range scopes {
		requiredScopes[s] = struct{}{}
	}
	// Space-separated list reported in the "scope" field of an
	// insufficient_scope error.
	scope := strings.Join(scopes, " ")

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
			// Named s (not session) so the session package is not shadowed.
			s := session.GetSession(r.Context())
			status, errResp := checkAuthz(s, requiredScopes, scope)
			if errResp == nil {
				next.ServeHTTP(rw, r)
				return
			}

			// Encode before writing anything. Once WriteHeader has been
			// called the status is fixed, so an encoding failure must be
			// caught first for the 500 below to actually reach the client.
			body, err := json.Marshal(errResp)
			if err != nil {
				http.Error(rw, err.Error(), http.StatusInternalServerError)
				return
			}
			// Match json.Encoder output, which terminates each value with a newline.
			body = append(body, '\n')

			rw.Header().Add("WWW-Authenticate", errResp.ToWWWAuthenticateHeader())
			rw.Header().Set("Content-Type", "application/json")
			rw.WriteHeader(status)
			// The status is already sent; a write error (e.g. client gone)
			// cannot be reported to the client, so it is deliberately ignored.
			_, _ = rw.Write(body)
		})
	}
}
// checkAuthz decides whether sess may proceed.
//
// It returns http.StatusUnauthorized with an invalid_grant error response when
// sess is nil. When requiredScopes is non-empty and none of the session's
// scopes (per SessionScopes) is in it, it returns http.StatusForbidden with an
// insufficient_scope error response whose "scope" entry is set to scope.
// Otherwise it returns http.StatusOK and a nil error response.
func checkAuthz(sess session.Session, requiredScopes map[string]struct{}, scope string) (int, protocol.ErrorResponse) {
	if sess == nil {
		return http.StatusUnauthorized, protocol.NewErrorResponse("invalid_grant", "invalid session")
	}

	// With no required scopes, a present session is sufficient. SessionScopes
	// is deliberately not consulted in that case.
	if len(requiredScopes) == 0 {
		return http.StatusOK, nil
	}

	for _, granted := range SessionScopes(sess) {
		if _, wanted := requiredScopes[granted]; wanted {
			return http.StatusOK, nil
		}
	}

	denied := protocol.NewErrorResponse("insufficient_scope", "required scope not granted")
	denied["scope"] = scope
	return http.StatusForbidden, denied
}