-
Notifications
You must be signed in to change notification settings - Fork 1
/
scs_sessions.go
103 lines (83 loc) · 2.84 KB
/
scs_sessions.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
package scsmiddleware
import (
"time"
"github.com/Nigel2392/router/v3"
"github.com/Nigel2392/router/v3/middleware"
"github.com/Nigel2392/router/v3/request"
"github.com/Nigel2392/router/v3/request/writer"
"github.com/alexedwards/scs/v2"
)
// scsRequestSession adapts an *scs.SessionManager to the per-request session
// value that SessionMiddleware stores in r.Session.
//
// Every method resolves the context of the request currently held in r at
// call time and forwards the operation to the scs manager with that context.
// The middleware replaces r.Request with a request whose context carries the
// loaded session data before handlers run.
type scsRequestSession struct {
	r     *request.Request
	store *scs.SessionManager
}

// Get returns the session value stored under key, as reported by the scs
// manager's Get.
func (sess *scsRequestSession) Get(key string) interface{} {
	ctx := sess.r.Request.Context()
	return sess.store.Get(ctx, key)
}

// Set stores value under key in the session via the scs manager's Put.
func (sess *scsRequestSession) Set(key string, value interface{}) {
	ctx := sess.r.Request.Context()
	sess.store.Put(ctx, key, value)
}

// Destroy destroys the session via the scs manager and returns its error.
func (sess *scsRequestSession) Destroy() error {
	ctx := sess.r.Request.Context()
	return sess.store.Destroy(ctx)
}

// Exists reports whether key is present in the session, per the scs
// manager's Exists.
func (sess *scsRequestSession) Exists(key string) bool {
	ctx := sess.r.Request.Context()
	return sess.store.Exists(ctx, key)
}

// Delete removes key from the session via the scs manager's Remove.
func (sess *scsRequestSession) Delete(key string) {
	ctx := sess.r.Request.Context()
	sess.store.Remove(ctx, key)
}

// RenewToken asks the scs manager to renew the session token and returns
// its error.
func (sess *scsRequestSession) RenewToken() error {
	ctx := sess.r.Request.Context()
	return sess.store.RenewToken(ctx)
}
// SessionMiddleware returns a router middleware that loads and saves scs
// sessions around every request.
//
// It is a customized version of scs's own middleware, which cannot be used
// directly because it does not support the router.Handler interface.
//
// For each request it:
//   - reads the session token from the cookie named store.Cookie.Name (an
//     absent cookie yields an empty token) and loads the session data into
//     the request context with store.Load;
//   - gives the handler a clearable writer (writer.NewClearable) wrapping the
//     real response writer, so the session cookie can still be written after
//     the handler has run;
//   - exposes the session to handlers through r.Session;
//   - after the handler returns, commits a modified session and writes its
//     cookie, or writes an empty cookie for a destroyed session, adds
//     "Vary: Cookie", and calls Finalize on the clearable writer (presumably
//     flushing the handler's output to the real writer — verify against
//     the writer package).
//
// Load and commit errors are logged through middleware.DEFAULT_LOGGER (when
// it is non-nil) and passed to store.ErrorFunc. On a commit error the
// clearable writer is not finalized.
//
// r.Response is restored to the original writer whenever the middleware
// exits, including when next panics. r.Request keeps the session-bearing
// request, because r.Session reads its context.
func SessionMiddleware(store *scs.SessionManager) func(next router.Handler) router.Handler {
	return func(next router.Handler) router.Handler {
		return router.HandleFunc(func(r *request.Request) {
			var token string
			if cookie, err := r.Request.Cookie(store.Cookie.Name); err == nil {
				token = cookie.Value
			}

			ctx, err := store.Load(r.Request.Context(), token)
			if err != nil {
				if middleware.DEFAULT_LOGGER != nil {
					middleware.DEFAULT_LOGGER.Error(middleware.FormatMessage(r, "ERROR", "[%s] Error loading session: %v", r.IP(), err))
				}
				store.ErrorFunc(r.Response, r.Request, err)
				return
			}

			// Keep the real writer: session cookie headers are written to it
			// directly while the handler's output sits in the clearable writer.
			oldWriter := r.Response
			bw := writer.NewClearable(oldWriter)

			// Put the real writer back on r on every exit path. This includes
			// a panic in next and the commit-error return. Otherwise outer
			// middleware (e.g. a panic recoverer writing an error page) would
			// write into a buffer that is never flushed.
			defer func() { r.Response = oldWriter }()

			sr := r.Request.WithContext(ctx)
			r.Response = bw
			r.Request = sr
			r.Session = &scsRequestSession{r: r, store: store}

			next.ServeHTTP(r)

			// Best-effort cleanup of temporary files from multipart parsing.
			if sr.MultipartForm != nil {
				sr.MultipartForm.RemoveAll()
			}

			switch store.Status(ctx) {
			case scs.Modified:
				newToken, expiry, commitErr := store.Commit(ctx)
				if commitErr != nil {
					if middleware.DEFAULT_LOGGER != nil {
						middleware.DEFAULT_LOGGER.Error(middleware.FormatMessage(r, "ERROR", "[%s] Error committing session: %v", r.IP(), commitErr))
					}
					store.ErrorFunc(oldWriter, r.Request, commitErr)
					return
				}
				store.WriteSessionCookie(ctx, oldWriter, newToken, expiry)
			case scs.Destroyed:
				store.WriteSessionCookie(ctx, oldWriter, "", time.Time{})
			}

			// Add Vary to the same writer that received the session cookie.
			// This does not depend on whatever a downstream handler may have
			// left in r.Response.
			request.AddHeader(oldWriter, "Vary", "Cookie")
			bw.Finalize()
		})
	}
}