Skip to content

Commit

Permalink
Don't buffer responses when using LoadAndSave()
Browse files Browse the repository at this point in the history
  • Loading branch information
alexedwards committed Mar 5, 2023
1 parent a07530f commit 62e546c
Showing 1 changed file with 43 additions and 42 deletions.
85 changes: 43 additions & 42 deletions session.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
package scs

import (
"bufio"
"bytes"
"context"
"log"
"net"
"net/http"
"time"

Expand Down Expand Up @@ -131,6 +128,8 @@ func NewSession() *SessionManager {
// the client in a cookie.
func (s *SessionManager) LoadAndSave(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Vary", "Cookie")

var token string
cookie, err := r.Cookie(s.Cookie.Name)
if err == nil {
Expand All @@ -144,33 +143,36 @@ func (s *SessionManager) LoadAndSave(next http.Handler) http.Handler {
}

sr := r.WithContext(ctx)
bw := &bufferedResponseWriter{ResponseWriter: w}
next.ServeHTTP(bw, sr)

if sr.MultipartForm != nil {
sr.MultipartForm.RemoveAll()
sw := &sessionResponseWriter{
ResponseWriter: w,
request: sr,
sessionManager: s,
}

switch s.Status(ctx) {
case Modified:
token, expiry, err := s.Commit(ctx)
if err != nil {
s.ErrorFunc(w, r, err)
return
}

s.WriteSessionCookie(ctx, w, token, expiry)
case Destroyed:
s.WriteSessionCookie(ctx, w, "", time.Time{})
next.ServeHTTP(sw, sr)

if !sw.written {
s.commitAndWriteSessionCookie(w, sr)
}
})
}

w.Header().Add("Vary", "Cookie")
func (s *SessionManager) commitAndWriteSessionCookie(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

if bw.code != 0 {
w.WriteHeader(bw.code)
switch s.Status(ctx) {
case Modified:
token, expiry, err := s.Commit(ctx)
if err != nil {
s.ErrorFunc(w, r, err)
return
}
w.Write(bw.buf.Bytes())
})

s.WriteSessionCookie(ctx, w, token, expiry)
case Destroyed:
s.WriteSessionCookie(ctx, w, "", time.Time{})
}
}

// WriteSessionCookie writes a cookie to the HTTP response with the provided
Expand Down Expand Up @@ -211,32 +213,31 @@ func defaultErrorFunc(w http.ResponseWriter, r *http.Request, err error) {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}

type bufferedResponseWriter struct {
type sessionResponseWriter struct {
http.ResponseWriter
buf bytes.Buffer
code int
wroteHeader bool
request *http.Request
sessionManager *SessionManager
written bool
}

func (bw *bufferedResponseWriter) Write(b []byte) (int, error) {
return bw.buf.Write(b)
func (sw *sessionResponseWriter) Write(b []byte) (int, error) {
if !sw.written {
sw.sessionManager.commitAndWriteSessionCookie(sw.ResponseWriter, sw.request)
sw.written = true
}

return sw.ResponseWriter.Write(b)
}

func (bw *bufferedResponseWriter) WriteHeader(code int) {
if !bw.wroteHeader {
bw.code = code
bw.wroteHeader = true
func (sw *sessionResponseWriter) WriteHeader(code int) {
if !sw.written {
sw.sessionManager.commitAndWriteSessionCookie(sw.ResponseWriter, sw.request)
sw.written = true
}
}

func (bw *bufferedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hj := bw.ResponseWriter.(http.Hijacker)
return hj.Hijack()
sw.ResponseWriter.WriteHeader(code)
}

func (bw *bufferedResponseWriter) Push(target string, opts *http.PushOptions) error {
if pusher, ok := bw.ResponseWriter.(http.Pusher); ok {
return pusher.Push(target, opts)
}
return http.ErrNotSupported
func (sw *sessionResponseWriter) Unwrap() http.ResponseWriter {
return sw.ResponseWriter
}

0 comments on commit 62e546c

Please sign in to comment.