Skip to content

Commit

Permalink
feat(server): handle head method (#5003)
Browse files Browse the repository at this point in the history
This implements some HEAD method handlers for various static resources and the /api/health endpoint.
  • Loading branch information
james-d-elliott committed Feb 28, 2023
1 parent f68e5cf commit a345490
Show file tree
Hide file tree
Showing 8 changed files with 251 additions and 65 deletions.
8 changes: 8 additions & 0 deletions api/openapi.yml
Expand Up @@ -76,6 +76,14 @@ paths:
schema:
$ref: '#/components/schemas/handlers.configuration.PasswordPolicyConfigurationBody'
/api/health:
head:
tags:
- State
summary: Application Health
description: The health check endpoint provides information about the health of Authelia.
responses:
"200":
description: Successful Operation
get:
tags:
- State
Expand Down
32 changes: 6 additions & 26 deletions internal/handlers/handler_oidc_wellknown.go
@@ -1,8 +1,6 @@
package handlers

import (
"net/url"

"github.com/valyala/fasthttp"

"github.com/authelia/authelia/v4/internal/middlewares"
Expand All @@ -11,20 +9,11 @@ import (
// OpenIDConnectConfigurationWellKnownGET handles requests to a .well-known endpoint (RFC5785) which returns the
// OpenID Connect Discovery 1.0 metadata.
//
// https://datatracker.ietf.org/doc/html/rfc5785
// RFC5785: Defining Well-Known URIs (https://datatracker.ietf.org/doc/html/rfc5785)
//
// https://openid.net/specs/openid-connect-discovery-1_0.html
// OpenID Connect Discovery 1.0 (https://openid.net/specs/openid-connect-discovery-1_0.html)
func OpenIDConnectConfigurationWellKnownGET(ctx *middlewares.AutheliaCtx) {
var (
issuer *url.URL
err error
)

issuer = ctx.RootURL()

wellKnown := ctx.Providers.OpenIDConnect.GetOpenIDConnectWellKnownConfiguration(issuer.String())

if err = ctx.ReplyJSON(wellKnown, fasthttp.StatusOK); err != nil {
if err := ctx.ReplyJSON(ctx.Providers.OpenIDConnect.GetOpenIDConnectWellKnownConfiguration(ctx.RootURL().String()), fasthttp.StatusOK); err != nil {
ctx.Logger.Errorf("Error occurred in JSON encode: %+v", err)

// TODO: Determine if this is the appropriate error code here.
Expand All @@ -37,20 +26,11 @@ func OpenIDConnectConfigurationWellKnownGET(ctx *middlewares.AutheliaCtx) {
// OAuthAuthorizationServerWellKnownGET handles requests to a .well-known endpoint (RFC5785) which returns the
// OAuth 2.0 Authorization Server Metadata (RFC8414).
//
// https://datatracker.ietf.org/doc/html/rfc5785
// RFC5785: Defining Well-Known URIs (https://datatracker.ietf.org/doc/html/rfc5785)
//
// https://datatracker.ietf.org/doc/html/rfc8414
// RFC8414: OAuth 2.0 Authorization Server Metadata (https://datatracker.ietf.org/doc/html/rfc8414)
func OAuthAuthorizationServerWellKnownGET(ctx *middlewares.AutheliaCtx) {
var (
issuer *url.URL
err error
)

issuer = ctx.RootURL()

wellKnown := ctx.Providers.OpenIDConnect.GetOAuth2WellKnownConfiguration(issuer.String())

if err = ctx.ReplyJSON(wellKnown, fasthttp.StatusOK); err != nil {
if err := ctx.ReplyJSON(ctx.Providers.OpenIDConnect.GetOAuth2WellKnownConfiguration(ctx.RootURL().String()), fasthttp.StatusOK); err != nil {
ctx.Logger.Errorf("Error occurred in JSON encode: %+v", err)

// TODO: Determine if this is the appropriate error code here.
Expand Down
20 changes: 18 additions & 2 deletions internal/server/asset.go
Expand Up @@ -11,6 +11,7 @@ import (
"net/http"
"path"
"path/filepath"
"strconv"
"strings"

"github.com/valyala/fasthttp"
Expand Down Expand Up @@ -64,7 +65,15 @@ func newPublicHTMLEmbeddedHandler() fasthttp.RequestHandler {
}

ctx.SetContentType(contentType)
ctx.SetBody(data)

switch {
case ctx.IsHead():
ctx.Response.ResetBody()
ctx.Response.SkipBody = true
ctx.Response.Header.Set(fasthttp.HeaderContentLength, strconv.Itoa(len(data)))
default:
ctx.SetBody(data)
}
}
}

Expand Down Expand Up @@ -182,7 +191,14 @@ func newLocalesEmbeddedHandler() (handler fasthttp.RequestHandler) {

middlewares.SetContentTypeApplicationJSON(ctx)

ctx.SetBody(data)
switch {
case ctx.IsHead():
ctx.Response.ResetBody()
ctx.Response.SkipBody = true
ctx.Response.Header.Set(fasthttp.HeaderContentLength, strconv.Itoa(len(data)))
default:
ctx.SetBody(data)
}
}
}

Expand Down
32 changes: 29 additions & 3 deletions internal/server/handlers.go
@@ -1,6 +1,7 @@
package server

import (
"fmt"
"net"
"os"
"path"
Expand Down Expand Up @@ -77,10 +78,10 @@ func handleError() func(ctx *fasthttp.RequestCtx, err error) {

func handleNotFound(next fasthttp.RequestHandler) fasthttp.RequestHandler {
return func(ctx *fasthttp.RequestCtx) {
path := strings.ToLower(string(ctx.Path()))
uri := strings.ToLower(string(ctx.Path()))

for i := 0; i < len(dirsHTTPServer); i++ {
if path == dirsHTTPServer[i].name || strings.HasPrefix(path, dirsHTTPServer[i].prefix) {
if uri == dirsHTTPServer[i].name || strings.HasPrefix(uri, dirsHTTPServer[i].prefix) {
handlers.SetStatusCodeResponse(ctx, fasthttp.StatusNotFound)

return
Expand All @@ -91,6 +92,13 @@ func handleNotFound(next fasthttp.RequestHandler) fasthttp.RequestHandler {
}
}

func handleMethodNotAllowed(ctx *fasthttp.RequestCtx) {
middlewares.SetContentTypeTextPlain(ctx)

ctx.SetStatusCode(fasthttp.StatusMethodNotAllowed)
ctx.SetBodyString(fmt.Sprintf("%d %s", fasthttp.StatusMethodNotAllowed, fasthttp.StatusMessage(fasthttp.StatusMethodNotAllowed)))
}

//nolint:gocyclo
func handleRouter(config *schema.Configuration, providers middlewares.Providers) fasthttp.RequestHandler {
log := logging.Logger()
Expand All @@ -115,29 +123,45 @@ func handleRouter(config *schema.Configuration, providers middlewares.Providers)
r := router.New()

// Static Assets.
r.HEAD("/", bridge(serveIndexHandler))
r.GET("/", bridge(serveIndexHandler))

for _, f := range filesRoot {
r.HEAD("/"+f, handlerPublicHTML)
r.GET("/"+f, handlerPublicHTML)
}

r.HEAD("/favicon.ico", middlewares.AssetOverride(config.Server.AssetPath, 0, handlerPublicHTML))
r.GET("/favicon.ico", middlewares.AssetOverride(config.Server.AssetPath, 0, handlerPublicHTML))

r.HEAD("/static/media/logo.png", middlewares.AssetOverride(config.Server.AssetPath, 2, handlerPublicHTML))
r.GET("/static/media/logo.png", middlewares.AssetOverride(config.Server.AssetPath, 2, handlerPublicHTML))

r.HEAD("/static/{filepath:*}", handlerPublicHTML)
r.GET("/static/{filepath:*}", handlerPublicHTML)

// Locales.
r.HEAD("/locales/{language:[a-z]{1,3}}-{variant:[a-zA-Z0-9-]+}/{namespace:[a-z]+}.json", middlewares.AssetOverride(config.Server.AssetPath, 0, handlerLocales))
r.GET("/locales/{language:[a-z]{1,3}}-{variant:[a-zA-Z0-9-]+}/{namespace:[a-z]+}.json", middlewares.AssetOverride(config.Server.AssetPath, 0, handlerLocales))

r.HEAD("/locales/{language:[a-z]{1,3}}/{namespace:[a-z]+}.json", middlewares.AssetOverride(config.Server.AssetPath, 0, handlerLocales))
r.GET("/locales/{language:[a-z]{1,3}}/{namespace:[a-z]+}.json", middlewares.AssetOverride(config.Server.AssetPath, 0, handlerLocales))

// Swagger.
r.HEAD("/api/", bridge(serveOpenAPIHandler))
r.GET("/api/", bridge(serveOpenAPIHandler))
r.OPTIONS("/api/", policyCORSPublicGET.HandleOPTIONS)

r.HEAD("/api/index.html", bridge(serveOpenAPIHandler))
r.GET("/api/index.html", bridge(serveOpenAPIHandler))
r.OPTIONS("/api/index.html", policyCORSPublicGET.HandleOPTIONS)

r.HEAD("/api/openapi.yml", policyCORSPublicGET.Middleware(bridge(serveOpenAPISpecHandler)))
r.GET("/api/openapi.yml", policyCORSPublicGET.Middleware(bridge(serveOpenAPISpecHandler)))
r.OPTIONS("/api/openapi.yml", policyCORSPublicGET.HandleOPTIONS)

for _, file := range filesSwagger {
r.HEAD("/api/"+file, handlerPublicHTML)
r.GET("/api/"+file, handlerPublicHTML)
}

Expand All @@ -150,7 +174,9 @@ func handleRouter(config *schema.Configuration, providers middlewares.Providers)
WithPostMiddlewares(middlewares.Require1FA).
Build()

r.HEAD("/api/health", middlewareAPI(handlers.HealthGET))
r.GET("/api/health", middlewareAPI(handlers.HealthGET))

r.GET("/api/state", middlewareAPI(handlers.StateGET))

r.GET("/api/configuration", middleware1FA(handlers.ConfigurationGET))
Expand Down Expand Up @@ -356,7 +382,7 @@ func handleRouter(config *schema.Configuration, providers middlewares.Providers)
}

r.HandleMethodNotAllowed = true
r.MethodNotAllowed = handlers.Status(fasthttp.StatusMethodNotAllowed)
r.MethodNotAllowed = handleMethodNotAllowed
r.NotFound = handleNotFound(bridge(serveIndexHandler))

handler := middlewares.LogRequest(r.Handler)
Expand Down
50 changes: 44 additions & 6 deletions internal/server/template.go
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/hex"
"fmt"
"os"
"path"
"path/filepath"
"strconv"
"strings"
Expand All @@ -25,7 +26,7 @@ import (
// and generate a nonce to support a restrictive CSP while using material-ui.
func ServeTemplatedFile(t templates.Template, opts *TemplatedFileOptions) middlewares.RequestHandler {
isDevEnvironment := os.Getenv(environment) == dev
ext := filepath.Ext(t.Name())
ext := path.Ext(t.Name())

return func(ctx *middlewares.AutheliaCtx) {
var err error
Expand Down Expand Up @@ -67,18 +68,34 @@ func ServeTemplatedFile(t templates.Template, opts *TemplatedFileOptions) middle
rememberMe = strconv.FormatBool(!provider.Config.DisableRememberMe)
}

if err = t.Execute(ctx.Response.BodyWriter(), opts.CommonData(ctx.BasePath(), ctx.RootURLSlash().String(), nonce, logoOverride, rememberMe)); err != nil {
ctx.RequestCtx.Error("an error occurred", 503)
data := &bytes.Buffer{}

if err = t.Execute(data, opts.CommonData(ctx.BasePath(), ctx.RootURLSlash().String(), nonce, logoOverride, rememberMe)); err != nil {
ctx.RequestCtx.Error("an error occurred", fasthttp.StatusServiceUnavailable)
ctx.Logger.WithError(err).Errorf("Error occcurred rendering template")

return
}

switch {
case ctx.IsHead():
ctx.Response.ResetBody()
ctx.Response.SkipBody = true
ctx.Response.Header.Set(fasthttp.HeaderContentLength, strconv.Itoa(data.Len()))
default:
if _, err = data.WriteTo(ctx.Response.BodyWriter()); err != nil {
ctx.RequestCtx.Error("an error occurred", fasthttp.StatusServiceUnavailable)
ctx.Logger.WithError(err).Errorf("Error occcurred writing body")

return
}
}
}
}

// ServeTemplatedOpenAPI serves templated OpenAPI related files.
func ServeTemplatedOpenAPI(t templates.Template, opts *TemplatedFileOptions) middlewares.RequestHandler {
ext := filepath.Ext(t.Name())
ext := path.Ext(t.Name())

spec := ext == extYML

Expand All @@ -103,12 +120,28 @@ func ServeTemplatedOpenAPI(t templates.Template, opts *TemplatedFileOptions) mid

var err error

if err = t.Execute(ctx.Response.BodyWriter(), opts.OpenAPIData(ctx.BasePath(), ctx.RootURLSlash().String(), nonce)); err != nil {
ctx.RequestCtx.Error("an error occurred", 503)
data := &bytes.Buffer{}

if err = t.Execute(data, opts.OpenAPIData(ctx.BasePath(), ctx.RootURLSlash().String(), nonce)); err != nil {
ctx.RequestCtx.Error("an error occurred", fasthttp.StatusServiceUnavailable)
ctx.Logger.WithError(err).Errorf("Error occcurred rendering template")

return
}

switch {
case ctx.IsHead():
ctx.Response.ResetBody()
ctx.Response.SkipBody = true
ctx.Response.Header.Set(fasthttp.HeaderContentLength, strconv.Itoa(data.Len()))
default:
if _, err = data.WriteTo(ctx.Response.BodyWriter()); err != nil {
ctx.RequestCtx.Error("an error occurred", fasthttp.StatusServiceUnavailable)
ctx.Logger.WithError(err).Errorf("Error occcurred writing body")

return
}
}
}
}

Expand Down Expand Up @@ -139,6 +172,11 @@ func ETagRootURL(next middlewares.RequestHandler) middlewares.RequestHandler {

next(ctx)

if ctx.Response.SkipBody || ctx.Response.StatusCode() != fasthttp.StatusOK {
// Skip generating the ETag as the response body should be empty.
return
}

mu.Lock()

h.Write(ctx.Response.Body())
Expand Down

0 comments on commit a345490

Please sign in to comment.