Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions internal/http_handlers/csrf.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package http_handlers

import (
"net/http"
"strings"

"github.com/gin-gonic/gin"
)

// CSRFMiddleware protects against CSRF by requiring state-changing requests
// (POST, PUT, DELETE, PATCH) to include a custom header that browsers will
// not send cross-origin without a CORS preflight.
// OAuth callback POST routes are exempt as they originate from provider redirects.
func (h *httpProvider) CSRFMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
method := c.Request.Method
if method == "GET" || method == "HEAD" || method == "OPTIONS" {
c.Next()
return
}

// Exempt OAuth callback routes (provider POST redirects)
if strings.HasPrefix(c.Request.URL.Path, "/oauth_callback/") {
c.Next()
return
}

// Exempt /oauth/token (client credentials flow, no cookies)
if c.Request.URL.Path == "/oauth/token" || c.Request.URL.Path == "/oauth/revoke" {
c.Next()
return
}

// Require Content-Type to be application/json or the presence of
// X-Requested-With header. Browsers cannot send these cross-origin
// without a CORS preflight check succeeding first.
contentType := c.Request.Header.Get("Content-Type")
xRequestedWith := c.Request.Header.Get("X-Requested-With")

if strings.Contains(contentType, "application/json") || xRequestedWith != "" {
c.Next()
return
}

c.JSON(http.StatusForbidden, gin.H{
"error": "csrf_validation_failed",
"error_description": "State-changing requests must include Content-Type: application/json or X-Requested-With header",
})
c.Abort()
}
}
2 changes: 2 additions & 0 deletions internal/http_handlers/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ type Provider interface {
ContextMiddleware() gin.HandlerFunc
// CORSMiddleware is the middleware that adds the cors headers to the response
CORSMiddleware() gin.HandlerFunc
// CSRFMiddleware protects against CSRF on state-changing requests
CSRFMiddleware() gin.HandlerFunc
// LoggerMiddleware is the middleware that logs the request
LoggerMiddleware() gin.HandlerFunc
}
1 change: 1 addition & 0 deletions internal/server/http_routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ func (s *server) NewRouter() *gin.Engine {
router.Use(s.Dependencies.HTTPProvider.LoggerMiddleware())
router.Use(s.Dependencies.HTTPProvider.ContextMiddleware())
router.Use(s.Dependencies.HTTPProvider.CORSMiddleware())
router.Use(s.Dependencies.HTTPProvider.CSRFMiddleware())
router.Use(s.Dependencies.HTTPProvider.ClientCheckMiddleware())

router.GET("/", s.Dependencies.HTTPProvider.RootHandler())
Expand Down