/
auth_entry_point_middleware.go
61 lines (51 loc) · 2.02 KB
/
auth_entry_point_middleware.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
package webapp
import (
"net/http"
"github.com/authgear/authgear-server/pkg/auth/handler/webapp/viewmodels"
"github.com/authgear/authgear-server/pkg/auth/webapp"
"github.com/authgear/authgear-server/pkg/lib/config"
"github.com/authgear/authgear-server/pkg/lib/session"
"github.com/authgear/authgear-server/pkg/util/httputil"
"github.com/authgear/authgear-server/pkg/util/template"
)
// TemplateRequireOAuth is the "web/require_oauth.html" page, registered
// together with the shared Components. AuthEntryPointMiddleware.renderBlocked
// renders it when an anonymous visitor reaches an auth entry point on a
// default domain outside of an authorization-endpoint flow.
var TemplateRequireOAuth = template.RegisterHTML("web/require_oauth.html", Components...)
// AuthEntryPointMiddleware guards the auth entry point pages (see Handle).
// It redirects users who already have a session, blocks anonymous direct
// access on default domains, and otherwise lets the request through.
type AuthEntryPointMiddleware struct {
	// BaseViewModel builds the common view model embedded into the
	// blocked page's template data.
	BaseViewModel *viewmodels.BaseViewModeler
	// Renderer renders the blocked page (TemplateRequireOAuth).
	Renderer Renderer
	// AppHostSuffixes decides whether the request host is a default domain.
	AppHostSuffixes config.AppHostSuffixes
	// TrustProxy is forwarded to host / redirect-URI resolution; presumably
	// controls whether proxy-provided headers are honoured — confirm.
	TrustProxy config.TrustProxy
	// OAuthConfig is not referenced by the methods visible in this file;
	// NOTE(review): verify whether it is still needed.
	OAuthConfig *config.OAuthConfig
	// UIConfig is used to derive the default post-login redirect URI.
	UIConfig *config.UIConfig
	// OAuthClientResolver is used to derive the default post-login
	// redirect URI from the request.
	OAuthClientResolver WebappOAuthClientResolver
}
// Handle wraps next with the auth entry point access policy:
//
//   - a login flow started by the authorization endpoint (the web session
//     carries an OAuthSessionID) always proceeds to next;
//   - otherwise, a request that already has a user is redirected (302) to the
//     post-login redirect URI;
//   - otherwise, an anonymous request on a default domain gets the
//     "require OAuth" page instead of the entry point;
//   - everything else proceeds to next.
func (m *AuthEntryPointMiddleware) Handle(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ctx := r.Context()
		uid := session.GetUserID(ctx)
		ws := webapp.GetSession(ctx)

		// Stay in the auth entry point if login is triggered by the authz endpoint.
		triggeredByAuthz := ws != nil && ws.OAuthSessionID != ""

		trustProxy := bool(m.TrustProxy)
		onDefaultDomain := m.AppHostSuffixes.CheckIsDefaultDomain(httputil.GetHost(r, trustProxy))

		switch {
		case triggeredByAuthz:
			next.ServeHTTP(w, r)
		case uid != nil:
			fallback := webapp.DerivePostLoginRedirectURIFromRequest(r, m.OAuthClientResolver, m.UIConfig)
			target := webapp.GetRedirectURI(r, trustProxy, fallback)
			http.Redirect(w, r, target, http.StatusFound)
		case onDefaultDomain:
			m.renderBlocked(w, r)
		default:
			next.ServeHTTP(w, r)
		}
	})
}
// renderBlocked writes the TemplateRequireOAuth page to w. The template data
// contains only the embedded base view model for this request.
func (m *AuthEntryPointMiddleware) renderBlocked(w http.ResponseWriter, r *http.Request) {
	base := m.BaseViewModel.ViewModel(r, w)
	data := map[string]interface{}{}
	viewmodels.Embed(data, base)
	m.Renderer.RenderHTML(w, r, TemplateRequireOAuth, data)
}