Skip to content

Commit

Permalink
Plugins: Make proxy endpoints not leak sensitive HTTP headers
Browse files Browse the repository at this point in the history
Fixes CVE-2022-31130

(cherry picked from commit 40b319d3d6a9945c05709ce8d4679407f6ccadf0)
  • Loading branch information
marefr authored and papagian committed Oct 11, 2022
1 parent c4e6e05 commit 9da278c
Show file tree
Hide file tree
Showing 9 changed files with 102 additions and 2 deletions.
9 changes: 9 additions & 0 deletions pkg/api/plugin_resource.go
Expand Up @@ -15,6 +15,7 @@ import (

"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/plugins/backendplugin"
"github.com/grafana/grafana/pkg/services/contexthandler"
"github.com/grafana/grafana/pkg/services/datasources"
"github.com/grafana/grafana/pkg/util/proxyutil"
"github.com/grafana/grafana/pkg/web"
Expand Down Expand Up @@ -118,6 +119,14 @@ func (hs *HTTPServer) makePluginResourceRequest(w http.ResponseWriter, req *http
hs.log.Warn("failed to unpack JSONData in datasource instance settings", "err", err)
}
}

list := contexthandler.AuthHTTPHeaderListFromContext(req.Context())
if list != nil {
for _, name := range list.Items {
req.Header.Del(name)
}
}

proxyutil.ClearCookieHeader(req, keepCookieModel.KeepCookies)
proxyutil.PrepareProxyRequest(req)

Expand Down
8 changes: 8 additions & 0 deletions pkg/api/plugins_test.go
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/grafana/grafana/pkg/infra/log/logtest"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/services/contexthandler"
"github.com/grafana/grafana/pkg/services/quota/quotatest"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/web/webtest"
Expand Down Expand Up @@ -271,6 +272,12 @@ func TestMakePluginResourceRequest(t *testing.T) {
pluginClient: pluginClient,
}
req := httptest.NewRequest(http.MethodGet, "/", nil)

const customHeader = "X-CUSTOM"
req.Header.Set(customHeader, "val")
ctx := contexthandler.WithAuthHTTPHeader(req.Context(), customHeader)
req = req.WithContext(ctx)

resp := httptest.NewRecorder()
pCtx := backend.PluginContext{}
err := hs.makePluginResourceRequest(resp, req, pCtx)
Expand All @@ -283,6 +290,7 @@ func TestMakePluginResourceRequest(t *testing.T) {
}

require.Equal(t, "sandbox", resp.Header().Get("Content-Security-Policy"))
require.Empty(t, req.Header.Get(customHeader))
}

func callGetPluginAsset(sc *scenarioContext) {
Expand Down
6 changes: 6 additions & 0 deletions pkg/middleware/middleware_basic_auth_test.go
Expand Up @@ -37,6 +37,9 @@ func TestMiddlewareBasicAuth(t *testing.T) {
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, orgID, sc.context.OrgId)
assert.Equal(t, models.ROLE_EDITOR, sc.context.OrgRole)
list := contexthandler.AuthHTTPHeaderListFromContext(sc.context.Req.Context())
require.NotNil(t, list)
require.EqualValues(t, []string{"Authorization"}, list.Items)
}, configure)

middlewareScenario(t, "Handle auth", func(t *testing.T, sc *scenarioContext) {
Expand Down Expand Up @@ -70,6 +73,9 @@ func TestMiddlewareBasicAuth(t *testing.T) {

assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, id, sc.context.UserId)
list := contexthandler.AuthHTTPHeaderListFromContext(sc.context.Req.Context())
require.NotNil(t, list)
require.EqualValues(t, []string{"Authorization"}, list.Items)
}, configure)

middlewareScenario(t, "Should return error if user is not found", func(t *testing.T, sc *scenarioContext) {
Expand Down
4 changes: 4 additions & 0 deletions pkg/middleware/middleware_jwt_auth_test.go
Expand Up @@ -6,6 +6,7 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/contexthandler"
Expand Down Expand Up @@ -55,6 +56,9 @@ func TestMiddlewareJWTAuth(t *testing.T) {
assert.Equal(t, orgID, sc.context.OrgId)
assert.Equal(t, id, sc.context.UserId)
assert.Equal(t, myUsername, sc.context.Login)
list := contexthandler.AuthHTTPHeaderListFromContext(sc.context.Req.Context())
require.NotNil(t, list)
require.EqualValues(t, []string{sc.cfg.JWTAuthHeaderName}, list.Items)
}, configure, configureUsernameClaim)

middlewareScenario(t, "Valid token with valid email claim", func(t *testing.T, sc *scenarioContext) {
Expand Down
5 changes: 5 additions & 0 deletions pkg/middleware/middleware_test.go
Expand Up @@ -396,6 +396,11 @@ func TestMiddlewareContext(t *testing.T) {
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, userID, sc.context.UserId)
assert.Equal(t, orgID, sc.context.OrgId)
list := contexthandler.AuthHTTPHeaderListFromContext(sc.context.Req.Context())
require.NotNil(t, list)
require.Contains(t, list.Items, sc.cfg.AuthProxyHeaderName)
require.Contains(t, list.Items, "X-WEBAUTH-GROUPS")
require.Contains(t, list.Items, "X-WEBAUTH-ROLE")
}, func(cfg *setting.Cfg) {
configure(cfg)
cfg.LDAPEnabled = false
Expand Down
3 changes: 3 additions & 0 deletions pkg/services/contexthandler/auth_jwt.go
Expand Up @@ -99,6 +99,9 @@ func (h *ContextHandler) initContextWithJWT(ctx *models.ReqContext, orgId int64)
return true
}

newCtx := WithAuthHTTPHeader(ctx.Req.Context(), h.Cfg.JWTAuthHeaderName)
*ctx.Req = *ctx.Req.WithContext(newCtx)

ctx.SignedInUser = query.Result
ctx.IsSignedIn = true

Expand Down
54 changes: 52 additions & 2 deletions pkg/services/contexthandler/contexthandler.go
Expand Up @@ -244,6 +244,9 @@ func (h *ContextHandler) initContextWithAPIKey(reqContext *models.ReqContext) bo
_, span := h.tracer.Start(reqContext.Req.Context(), "initContextWithAPIKey")
defer span.End()

ctx := WithAuthHTTPHeader(reqContext.Req.Context(), "Authorization")
*reqContext.Req = *reqContext.Req.WithContext(ctx)

var (
apikey *models.ApiKey
errKey error
Expand Down Expand Up @@ -326,7 +329,7 @@ func (h *ContextHandler) initContextWithBasicAuth(reqContext *models.ReqContext,
return false
}

ctx, span := h.tracer.Start(reqContext.Req.Context(), "initContextWithBasicAuth")
_, span := h.tracer.Start(reqContext.Req.Context(), "initContextWithBasicAuth")
defer span.End()

username, password, err := util.DecodeBasicAuthHeader(header)
Expand All @@ -335,12 +338,15 @@ func (h *ContextHandler) initContextWithBasicAuth(reqContext *models.ReqContext,
return true
}

ctx := WithAuthHTTPHeader(reqContext.Req.Context(), "Authorization")
*reqContext.Req = *reqContext.Req.WithContext(ctx)

authQuery := models.LoginUserQuery{
Username: username,
Password: password,
Cfg: h.Cfg,
}
if err := h.authenticator.AuthenticateUser(reqContext.Req.Context(), &authQuery); err != nil {
if err := h.authenticator.AuthenticateUser(ctx, &authQuery); err != nil {
reqContext.Logger.Debug(
"Failed to authorize the user",
"username", username,
Expand Down Expand Up @@ -571,6 +577,15 @@ func (h *ContextHandler) initContextWithAuthProxy(reqContext *models.ReqContext,

logger.Debug("Successfully got user info", "userID", user.UserId, "username", user.Login)

ctx := WithAuthHTTPHeader(reqContext.Req.Context(), h.Cfg.AuthProxyHeaderName)
for _, header := range h.Cfg.AuthProxyHeaders {
if header != "" {
ctx = WithAuthHTTPHeader(ctx, header)
}
}

*reqContext.Req = *reqContext.Req.WithContext(ctx)

// Add user info to context
reqContext.SignedInUser = user
reqContext.IsSignedIn = true
Expand All @@ -590,3 +605,38 @@ func (h *ContextHandler) initContextWithAuthProxy(reqContext *models.ReqContext,

return true
}

type authHTTPHeaderListContextKey struct{}

var authHTTPHeaderListKey = authHTTPHeaderListContextKey{}

// AuthHTTPHeaderList used to record HTTP headers that being when verifying authentication
// of an incoming HTTP request.
type AuthHTTPHeaderList struct {
Items []string
}

// WithAuthHTTPHeader returns a copy of parent in which the named HTTP header will be included
// and later retrievable by AuthHTTPHeaderListFromContext.
func WithAuthHTTPHeader(parent context.Context, name string) context.Context {
list := AuthHTTPHeaderListFromContext(parent)

if list == nil {
list = &AuthHTTPHeaderList{
Items: []string{},
}
}

list.Items = append(list.Items, name)

return context.WithValue(parent, authHTTPHeaderListKey, list)
}

// AuthHTTPHeaderListFromContext returns the AuthHTTPHeaderList in a context.Context, if any,
// and will include any HTTP headers used when verifying authentication of an incoming HTTP request.
func AuthHTTPHeaderListFromContext(c context.Context) *AuthHTTPHeaderList {
if list, ok := c.Value(authHTTPHeaderListKey).(*AuthHTTPHeaderList); ok {
return list
}
return nil
}
8 changes: 8 additions & 0 deletions pkg/util/proxyutil/reverse_proxy.go
Expand Up @@ -10,6 +10,7 @@ import (
"time"

glog "github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/services/contexthandler"
)

// StatusClientClosedRequest A non-standard status code introduced by nginx
Expand Down Expand Up @@ -66,6 +67,13 @@ func NewReverseProxy(logger glog.Logger, director func(*http.Request), opts ...R
// wrapDirector wraps a director and adds additional functionality.
func wrapDirector(d func(*http.Request)) func(req *http.Request) {
return func(req *http.Request) {
list := contexthandler.AuthHTTPHeaderListFromContext(req.Context())
if list != nil {
for _, name := range list.Items {
req.Header.Del(name)
}
}

d(req)
PrepareProxyRequest(req)

Expand Down
7 changes: 7 additions & 0 deletions pkg/util/proxyutil/reverse_proxy_test.go
Expand Up @@ -9,6 +9,7 @@ import (
"time"

"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/services/contexthandler"
"github.com/stretchr/testify/require"
)

Expand All @@ -30,6 +31,11 @@ func TestReverseProxy(t *testing.T) {
req.Header.Set("Referer", "https://test.com/api")
req.RemoteAddr = "10.0.0.1"

const customHeader = "X-CUSTOM"
req.Header.Set(customHeader, "val")
ctx := contexthandler.WithAuthHTTPHeader(req.Context(), customHeader)
req = req.WithContext(ctx)

rp := NewReverseProxy(log.New("test"), func(req *http.Request) {
req.Header.Set("X-KEY", "value")
})
Expand All @@ -49,6 +55,7 @@ func TestReverseProxy(t *testing.T) {
require.Empty(t, resp.Cookies())
require.Equal(t, "sandbox", resp.Header.Get("Content-Security-Policy"))
require.NoError(t, resp.Body.Close())
require.Empty(t, actualReq.Header.Get(customHeader))
})

t.Run("When proxying a request using WithModifyResponse should call it before default ModifyResponse func", func(t *testing.T) {
Expand Down

0 comments on commit 9da278c

Please sign in to comment.