Skip to content

Commit

Permalink
feat(server): recover panics
Browse files Browse the repository at this point in the history
This adds a general panic recovery middleware to the server so any request can theoretically not cause unreoovered panics.

Signed-off-by: James Elliott <james-d-elliott@users.noreply.github.com>
  • Loading branch information
james-d-elliott committed Mar 4, 2024
1 parent f897782 commit 2c6a8e1
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 20 deletions.
18 changes: 4 additions & 14 deletions internal/middlewares/authelia_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ import (
)

// NewRequestLogger create a new request logger for the given request.
func NewRequestLogger(ctx *AutheliaCtx) *logrus.Entry {
func NewRequestLogger(ctx *fasthttp.RequestCtx) *logrus.Entry {
return logging.Logger().WithFields(logrus.Fields{
logging.FieldMethod: string(ctx.Method()),
logging.FieldPath: string(ctx.Path()),
logging.FieldRemoteIP: ctx.RemoteIP().String(),
logging.FieldRemoteIP: RequestCtxRemoteIP(ctx).String(),
})
}

Expand All @@ -37,7 +37,7 @@ func NewAutheliaCtx(requestCTX *fasthttp.RequestCtx, configuration schema.Config
ctx.RequestCtx = requestCTX
ctx.Providers = providers
ctx.Configuration = configuration
ctx.Logger = NewRequestLogger(ctx)
ctx.Logger = NewRequestLogger(ctx.RequestCtx)
ctx.Clock = clock.New()

return ctx
Expand Down Expand Up @@ -482,17 +482,7 @@ func (ctx *AutheliaCtx) SetJSONBody(value any) error {

// RemoteIP return the remote IP taking X-Forwarded-For header into account if provided.
func (ctx *AutheliaCtx) RemoteIP() net.IP {
if header := ctx.Request.Header.PeekBytes(headerXForwardedFor); len(header) != 0 {
ips := strings.SplitN(string(header), ",", 2)

if len(ips) != 0 {
if ip := net.ParseIP(strings.Trim(ips[0], " ")); ip != nil {
return ip
}
}
}

return ctx.RequestCtx.RemoteIP()
return RequestCtxRemoteIP(ctx.RequestCtx)
}

// GetXForwardedURL returns the parsed X-Forwarded-Proto, X-Forwarded-Host, and X-Forwarded-URI request header as a
Expand Down
38 changes: 37 additions & 1 deletion internal/middlewares/errors.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
package middlewares

import "errors"
import (
"errors"
"fmt"

"github.com/valyala/fasthttp"
)

var (
// ErrMissingXForwardedProto is returned on methods which require an X-Forwarded-Proto header.
Expand All @@ -15,3 +20,34 @@ var (
// ErrMissingXOriginalURL is returned on methods which require an X-Original-URL header.
ErrMissingXOriginalURL = errors.New("missing required X-Original-URL header")
)

// RecoverPanic recovers from panics and logs the error.
func RecoverPanic(next fasthttp.RequestHandler) fasthttp.RequestHandler {
return func(ctx *fasthttp.RequestCtx) {
defer func() {
if r := recover(); r != nil {
NewRequestLogger(ctx).WithError(recoverErr(r)).Error("Panic (recovered) occurred while handling requests, please report this error")

ctx.Response.Reset()
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
ctx.SetContentTypeBytes(contentTypeTextPlain)
ctx.SetBodyString(fmt.Sprintf("%d %s", fasthttp.StatusInternalServerError, fasthttp.StatusMessage(fasthttp.StatusInternalServerError)))
}
}()

next(ctx)
}
}

func recoverErr(i any) error {
switch v := i.(type) {
case nil:
return nil
case string:
return fmt.Errorf("recovered panic: %s", v)
case error:
return fmt.Errorf("recovered panic: %w", v)
default:
return fmt.Errorf("recovered panic with unknown type: %v", v)
}
}
9 changes: 5 additions & 4 deletions internal/middlewares/log_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@ import (
// LogRequest provides trace logging for all requests.
func LogRequest(next fasthttp.RequestHandler) fasthttp.RequestHandler {
return func(ctx *fasthttp.RequestCtx) {
autheliaCtx := &AutheliaCtx{RequestCtx: ctx}
logger := NewRequestLogger(autheliaCtx)
log := NewRequestLogger(ctx)

log.Trace("Request hit")

logger.Trace("Request hit")
next(ctx)
logger.Tracef("Replied (status=%d)", ctx.Response.StatusCode())

log.Tracef("Replied (status=%d)", ctx.Response.StatusCode())
}
}
30 changes: 30 additions & 0 deletions internal/middlewares/wrap.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package middlewares

import (
"net"
"strings"

"github.com/valyala/fasthttp"
)

Expand All @@ -12,3 +15,30 @@ func Wrap(middleware Basic, next fasthttp.RequestHandler) (handler fasthttp.Requ

return middleware(next)
}

// MultiWrap allows wrapping a handler with additional middlewares if they are not nil.
func MultiWrap(next fasthttp.RequestHandler, middlewares ...Basic) (handler fasthttp.RequestHandler) {
for i := len(middlewares) - 1; i >= 0; i-- {
if middlewares[i] == nil {
continue
}

next = middlewares[i](next)
}

return next
}

func RequestCtxRemoteIP(ctx *fasthttp.RequestCtx) net.IP {
if header := ctx.Request.Header.PeekBytes(headerXForwardedFor); len(header) != 0 {
ips := strings.SplitN(string(header), ",", 2)

if len(ips) != 0 {
if ip := net.ParseIP(strings.Trim(ips[0], " ")); ip != nil {
return ip
}
}
}

return ctx.RemoteIP()
}
2 changes: 1 addition & 1 deletion internal/server/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ func handleRouter(config *schema.Configuration, providers middlewares.Providers)
handler = middlewares.StripPath(config.Server.Address.RouterPath())(handler)
}

handler = middlewares.Wrap(middlewares.NewMetricsRequest(providers.Metrics), handler)
handler = middlewares.MultiWrap(handler, middlewares.RecoverPanic, middlewares.NewMetricsRequest(providers.Metrics))

return handler
}
Expand Down

0 comments on commit 2c6a8e1

Please sign in to comment.