Skip to content

Commit

Permalink
Prevent multiple Set-Cookie headers when calling RegenerateToken
Browse files Browse the repository at this point in the history
  • Loading branch information
aeneasr committed Oct 14, 2020
1 parent ec1bc1f commit 3dd5c4d
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 16 deletions.
17 changes: 17 additions & 0 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ type csrfContext struct {
token string
// reason for the failure of CSRF check
reason error
// wasSent is true if `Set-Cookie` was called
// for the `name=csrf_token` already. This prevents
// duplicate `Set-Cookie: csrf_token` headers.
// For more information see:
// https://github.com/justinas/nosurf/pull/61
wasSent bool
}

// Token takes an HTTP request and returns
Expand Down Expand Up @@ -53,6 +59,17 @@ func ctxSetToken(req *http.Request, token []byte) {
ctx.token = b64encode(maskToken(token))
}

func ctxSetSent(req *http.Request) {
ctx := req.Context().Value(nosurfKey).(*csrfContext)
ctx.wasSent = true
}

func ctxWasSent(req *http.Request) bool {
ctx := req.Context().Value(nosurfKey).(*csrfContext)

return ctx.wasSent
}

func ctxSetReason(req *http.Request, reason error) {
ctx := req.Context().Value(nosurfKey).(*csrfContext)
if ctx.token == "" {
Expand Down
32 changes: 32 additions & 0 deletions context_legacy.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ type csrfContext struct {
token string
// reason for the failure of CSRF check
reason error
// wasSent is true if `Set-Cookie` was called
// for the `name=csrf_token` already. This prevents
// duplicate `Set-Cookie: csrf_token` headers.
// For more information see:
// https://github.com/justinas/nosurf/pull/61
wasSent bool
}

var (
Expand Down Expand Up @@ -79,6 +85,32 @@ func ctxSetToken(req *http.Request, token []byte) *http.Request {
return req
}

func ctxSetSent(req *http.Request) {
cmMutex.Lock()
defer cmMutex.Unlock()

ctx, ok := contextMap[req]
if !ok {
ctx = new(csrfContext)
contextMap[req] = ctx
}

ctx.wasSent = true
}

func ctxWasSent(req *http.Request) bool {
cmMutex.RLock()
defer cmMutex.RUnlock()

ctx, ok := contextMap[req]

if !ok {
return false
}

return ctx.wasSent
}

func ctxSetReason(req *http.Request, reason error) *http.Request {
cmMutex.Lock()
defer cmMutex.Unlock()
Expand Down
11 changes: 11 additions & 0 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,16 @@ func (h *CSRFHandler) handleFailure(w http.ResponseWriter, r *http.Request) {

// Generates a new token, sets it on the given request and returns it
func (h *CSRFHandler) RegenerateToken(w http.ResponseWriter, r *http.Request) string {
if ctxWasSent(r) {
// The CSRF Cookie was set already by an earlier call to `RegenerateToken`
// in the same request context. It therefore does not make sense to regenerate
// it again as it will lead to two or more `Set-Cookie` instructions which will in turn
// cause CSRF to fail depending on the resulting order of the `Set-Cookie` instructions.
//
// No warning is necessary as the only caller to `setTokenCookie` is `RegenerateToken`.
return Token(r)
}

token := generateToken()
h.setTokenCookie(w, r, token)

Expand All @@ -210,6 +220,7 @@ func (h *CSRFHandler) setTokenCookie(w http.ResponseWriter, r *http.Request, tok
cookie.Value = b64encode(token)

http.SetCookie(w, &cookie)
ctxSetSent(r)

}

Expand Down
17 changes: 17 additions & 0 deletions handler_go17_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,20 @@ func TestContextIsAccessibleWithContext(t *testing.T) {

hand.ServeHTTP(writer, req)
}

func TestNoDoubleCookie(t *testing.T) {
var n *CSRFHandler
n = New(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
n.RegenerateToken(w, r)
}))

r := httptest.NewRequest("GET", "http://dummy.us", nil)
w := httptest.NewRecorder()

n.ServeHTTP(w, r)

count := len(w.Result().Cookies())
if count > 1 {
t.Errorf("Expected one CSRF cookie, got %d", count)
}
}
18 changes: 18 additions & 0 deletions handler_legacy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package nosurf
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
)

Expand All @@ -20,3 +21,20 @@ func TestClearsContextAfterTheRequest(t *testing.T) {
t.Errorf("Instead, the context entry remains: %v", contextMap[req])
}
}

func TestNoDoubleCookie(t *testing.T) {
var n *CSRFHandler
n = New(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
n.RegenerateToken(w, r)
}))

r := httptest.NewRequest("GET", "http://dummy.us", nil)
w := httptest.NewRecorder()

n.ServeHTTP(w, r)

count := strings.Count(w.HeaderMap.Get("Set-Cookie"), "csrf_token")
if count > 1 {
t.Errorf("Expected one CSRF cookie, got %d", count)
}
}
16 changes: 0 additions & 16 deletions handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,6 @@ import (
"testing"
)

func TestNoDoubleCookie(t *testing.T) {
var n *CSRFHandler
n = New(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
n.RegenerateToken(w, r)
}))

r := httptest.NewRequest("GET", "http://dummy.us", nil)
w := httptest.NewRecorder()

n.ServeHTTP(w, r)

if len(w.Result().Cookies()) > 1 {
t.Errorf("Expected one CSRF cookie, got %d", len(w.Result().Cookies()))
}
}

func TestDefaultFailureHandler(t *testing.T) {
writer := httptest.NewRecorder()
req := dummyGet()
Expand Down

0 comments on commit 3dd5c4d

Please sign in to comment.