/
cors.go
57 lines (45 loc) · 1.58 KB
/
cors.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
package middleware
import (
"net/http"
"github.com/authgear/authgear-server/pkg/util/log"
)
// CORSMiddlewareLogger wraps a *log.Logger in its own named type so that it
// can be requested as the logger dependency of CORSMiddleware.
type CORSMiddlewareLogger struct {
	*log.Logger
}

// NewCORSMiddlewareLogger returns a CORSMiddlewareLogger backed by a logger
// that lf creates under the name "cors-middleware".
func NewCORSMiddlewareLogger(lf *log.Factory) CORSMiddlewareLogger {
	named := lf.New("cors-middleware")
	return CORSMiddlewareLogger{Logger: named}
}
// CORSMiddleware adds CORS response headers to requests whose Origin header
// is accepted by the configured allowed origins. Per the original design,
// those allowed origins come from the app config and an environment variable.
type CORSMiddleware struct {
	// Matcher prepares, per request, the origin matcher that decides whether
	// the request's Origin is allowed.
	Matcher *CORSMatcher
	// Logger is this middleware's named logger (see NewCORSMiddlewareLogger).
	Logger CORSMiddlewareLogger
}
// Handle wraps next with CORS handling.
//
// Every response gets "Vary: Origin". When the request carries an Origin
// header that the prepared origin matcher accepts, the handler:
//   - reflects that origin in Access-Control-Allow-Origin;
//   - allows credentials;
//   - sets a 15-minute Access-Control-Max-Age;
//   - echoes any Access-Control-Request-Method / Access-Control-Request-Headers
//     back as Access-Control-Allow-Methods / Access-Control-Allow-Headers.
//
// If the origin matcher cannot be prepared, the error is logged and the
// request continues without CORS headers (best effort; no origin is trusted).
//
// Every OPTIONS request, whether or not its origin matched, is answered here
// with 200 OK and is not passed to next. All other methods are forwarded to
// next.
func (m *CORSMiddleware) Handle(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// The headers written below depend on the request's Origin, so caches
		// must key on it even when no CORS headers end up being set.
		w.Header().Add("Vary", "Origin")

		// The matcher is only needed when there is an Origin to check.
		if origin := r.Header.Get("Origin"); origin != "" {
			matcher, err := m.Matcher.PrepareOriginMatcher(r)
			if err != nil {
				// Best effort: without a matcher no origin is trusted, so no
				// CORS headers are written; record why instead of dropping it.
				m.Logger.WithError(err).Warn("failed to prepare CORS origin matcher; skipping CORS headers")
			} else if matcher.MatchOrigin(origin) {
				h := w.Header()
				h.Set("Access-Control-Allow-Origin", origin)
				h.Set("Access-Control-Allow-Credentials", "true")
				h.Set("Access-Control-Max-Age", "900") // 15 mins
				if method := r.Header.Get("Access-Control-Request-Method"); method != "" {
					h.Set("Access-Control-Allow-Methods", method)
				}
				if headers := r.Header.Get("Access-Control-Request-Headers"); headers != "" {
					h.Set("Access-Control-Allow-Headers", headers)
				}
			}
		}

		if r.Method == http.MethodOptions {
			w.WriteHeader(http.StatusOK)
			return
		}
		next.ServeHTTP(w, r)
	})
}