Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TT-11683]: fixed header forwarding #6174

Merged
merged 2 commits into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ spec:
trigger:
type: http
httpRequest:
method: GET
method: POST
url: tyk:8080/test-graphql-tracing-invalid/test-graphql-tracing-invalid
body: "{\n \"query\": \"{\\n country(code: \\\"NG\\\"){\\n name\\n }\\n}\"\n}"
headers:
Expand Down
41 changes: 36 additions & 5 deletions gateway/mw_graphql_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,37 @@ func DetermineGraphQLEngineTransportType(apiSpec *APISpec) GraphQLEngineTranspor
return GraphQLEngineTransportTypeMultiUpstream
}

type contextKey struct{}

var graphqlProxyContextInfo = contextKey{}

type GraphQLProxyOnlyContextValues struct {
forwardedRequest *http.Request
upstreamResponse *http.Response
ignoreForwardedHeaders map[string]bool
}

func SetProxyOnlyContextValue(ctx context.Context, req *http.Request) context.Context {
value := &GraphQLProxyOnlyContextValues{
forwardedRequest: req,
ignoreForwardedHeaders: map[string]bool{
http.CanonicalHeaderKey("date"): true,
http.CanonicalHeaderKey("content-type"): true,
http.CanonicalHeaderKey("content-length"): true,
},
}

return context.WithValue(ctx, graphqlProxyContextInfo, value)
}

func GetProxyOnlyContextValue(ctx context.Context) *GraphQLProxyOnlyContextValues {
val, ok := ctx.Value(graphqlProxyContextInfo).(*GraphQLProxyOnlyContextValues)
if !ok {
return nil
}
return val
}

type GraphQLProxyOnlyContext struct {
context.Context
forwardedRequest *http.Request
Expand Down Expand Up @@ -71,16 +102,16 @@ func NewGraphQLEngineTransport(transportType GraphQLEngineTransportType, origina
func (g *GraphQLEngineTransport) RoundTrip(request *http.Request) (res *http.Response, err error) {
switch g.transportType {
case GraphQLEngineTransportTypeProxyOnly:
proxyOnlyCtx, ok := request.Context().(*GraphQLProxyOnlyContext)
if ok {
return g.handleProxyOnly(proxyOnlyCtx, request)
val := GetProxyOnlyContextValue(request.Context())
if val != nil {
return g.handleProxyOnly(val, request)
}
}

return g.originalTransport.RoundTrip(request)
}

func (g *GraphQLEngineTransport) handleProxyOnly(proxyOnlyCtx *GraphQLProxyOnlyContext, request *http.Request) (*http.Response, error) {
func (g *GraphQLEngineTransport) handleProxyOnly(proxyOnlyCtx *GraphQLProxyOnlyContextValues, request *http.Request) (*http.Response, error) {
request.Method = proxyOnlyCtx.forwardedRequest.Method
g.setProxyOnlyHeaders(proxyOnlyCtx, request)

Expand Down Expand Up @@ -113,7 +144,7 @@ func (g *GraphQLEngineTransport) handleProxyOnly(proxyOnlyCtx *GraphQLProxyOnlyC
return response, err
}

func (g *GraphQLEngineTransport) setProxyOnlyHeaders(proxyOnlyCtx *GraphQLProxyOnlyContext, r *http.Request) {
func (g *GraphQLEngineTransport) setProxyOnlyHeaders(proxyOnlyCtx *GraphQLProxyOnlyContextValues, r *http.Request) {
for forwardedHeaderKey, forwardedHeaderValues := range proxyOnlyCtx.forwardedRequest.Header {
if proxyOnlyCtx.ignoreForwardedHeaders[forwardedHeaderKey] {
continue
Expand Down
2 changes: 1 addition & 1 deletion gateway/mw_graphql_transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func TestGraphQLEngineTransport_RoundTrip(t *testing.T) {

forwardedRequest.Header.Set("X-Custom-Key", "custom-value")
forwardedRequest.Header.Set("X-Other-Value", "other-value")
ctx := NewGraphQLProxyOnlyContext(context.Background(), forwardedRequest)
ctx := SetProxyOnlyContextValue(context.Background(), forwardedRequest)

httpClient := http.Client{
Transport: NewGraphQLEngineTransport(GraphQLEngineTransportTypeProxyOnly, http.DefaultTransport),
Expand Down
39 changes: 36 additions & 3 deletions gateway/reverse_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ import (
"sync"
"time"

"github.com/TykTechnologies/graphql-go-tools/pkg/engine/datasource/httpclient"

"github.com/buger/jsonparser"

"github.com/gorilla/websocket"
Expand Down Expand Up @@ -1087,7 +1089,7 @@ func (p *ReverseProxy) handleGraphQLEngineWebsocketUpgrade(roundTripper *TykRoun
return nil, true, nil
}

func returnErrorsFromUpstream(proxyOnlyCtx *GraphQLProxyOnlyContext, resultWriter *graphql.EngineResultWriter) error {
func returnErrorsFromUpstream(proxyOnlyCtx *GraphQLProxyOnlyContextValues, resultWriter *graphql.EngineResultWriter) error {
body, ok := proxyOnlyCtx.upstreamResponse.Body.(*nopCloserBuffer)
if !ok {
// Response body already read by graphql-go-tools, and it's not re-readable. Quit silently.
Expand Down Expand Up @@ -1151,7 +1153,7 @@ func (p *ReverseProxy) handoverRequestToGraphQLExecutionEngine(roundTripper *Tyk
span := otel.SpanFromContext(outreq.Context())
reqCtx := otel.ContextWithSpan(context.Background(), span)
if isProxyOnly {
reqCtx = NewGraphQLProxyOnlyContext(reqCtx, outreq)
reqCtx = SetProxyOnlyContextValue(reqCtx, outreq)
}

resultWriter := graphql.NewEngineResultWriter()
Expand Down Expand Up @@ -1179,7 +1181,7 @@ func (p *ReverseProxy) handoverRequestToGraphQLExecutionEngine(roundTripper *Tyk
header.Set("Content-Type", "application/json")

if isProxyOnly {
proxyOnlyCtx := reqCtx.(*GraphQLProxyOnlyContext)
proxyOnlyCtx := GetProxyOnlyContextValue(reqCtx)
// There is a case in the proxy-only mode where the request can be handled
// by the library without calling the upstream.
// This is a valid query for proxy-only mode: query { __typename }
Expand All @@ -1188,6 +1190,9 @@ func (p *ReverseProxy) handoverRequestToGraphQLExecutionEngine(roundTripper *Tyk
if proxyOnlyCtx.upstreamResponse != nil {
header = proxyOnlyCtx.upstreamResponse.Header
httpStatus = proxyOnlyCtx.upstreamResponse.StatusCode
// change the value of the header's content encoding to use the content encoding defined by the accept encoding
contentEncoding := selectContentEncodingToBeUsed(proxyOnlyCtx.forwardedRequest.Header.Get(httpclient.AcceptEncodingHeader))
header.Set(httpclient.ContentEncodingHeader, contentEncoding)
if p.TykAPISpec.GraphQL.Proxy.UseResponseExtensions.OnErrorForwarding && httpStatus >= http.StatusBadRequest {
err = returnErrorsFromUpstream(proxyOnlyCtx, &resultWriter)
if err != nil {
Expand All @@ -1204,6 +1209,34 @@ func (p *ReverseProxy) handoverRequestToGraphQLExecutionEngine(roundTripper *Tyk
return nil, false, errors.New("graphql configuration is invalid")
}

// selectContentEncodingToBeUsed selects the encoding value to be returned based on the IETF standards
// if acceptedEncoding is a list of comma separated strings br,gzip, deflate; then it selects the first supported one
// if it is a single value then it returns that value
// if no supported encoding is found, it returns the last value
func selectContentEncodingToBeUsed(acceptedEncoding string) string {
supportedHeaders := map[string]struct{}{
"gzip": {},
"deflate": {},
"br": {},
}

values := strings.Split(acceptedEncoding, ",")
if len(values) < 2 {
return values[0]
}

for i, e := range values {
enc := strings.TrimSpace(e)
if _, ok := supportedHeaders[enc]; ok {
return enc
}
if i == len(values)-1 {
return enc
}
}
return ""
}

func (p *ReverseProxy) handoverWebSocketConnectionToGraphQLExecutionEngine(roundTripper *TykRoundTripper, conn net.Conn, req *http.Request) {
p.TykAPISpec.GraphQLExecutor.Client.Transport = NewGraphQLEngineTransport(DetermineGraphQLEngineTransportType(p.TykAPISpec), roundTripper)

Expand Down
38 changes: 38 additions & 0 deletions gateway/reverse_proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1000,6 +1000,44 @@ func TestGraphQL_ProxyOnlyHeaders(t *testing.T) {
})
}

func TestGraphQL_ProxyOnlyPassHeadersWithOTel(t *testing.T) {
g := StartTest(func(globalConf *config.Config) {
globalConf.OpenTelemetry.Enabled = true
})
defer g.Close()

spec := BuildAPI(func(spec *APISpec) {
spec.Name = "tyk-api"
spec.APIID = "tyk-api"
spec.GraphQL.Enabled = true
spec.GraphQL.ExecutionMode = apidef.GraphQLExecutionModeProxyOnly
spec.GraphQL.Schema = gqlCountriesSchema
spec.GraphQL.Version = apidef.GraphQLConfigVersion2
spec.Proxy.TargetURL = TestHttpAny + "/dynamic"
spec.Proxy.ListenPath = "/"
})[0]

g.Gw.LoadAPI(spec)
g.AddDynamicHandler("/dynamic", func(writer http.ResponseWriter, r *http.Request) {
if gotten := r.Header.Get("custom-client-header"); gotten != "custom-value" {
t.Errorf("expected upstream to recieve header `custom-client-header` with value of `custom-value`, instead got %s", gotten)
}
})

_, err := g.Run(t, test.TestCase{
Path: "/",
Headers: map[string]string{
"custom-client-header": "custom-value",
},
Method: http.MethodPost,
Data: graphql.Request{
Query: gqlContinentQuery,
},
})

assert.NoError(t, err)
}

func TestGraphQL_InternalDataSource(t *testing.T) {
g := StartTest(nil)
defer g.Close()
Expand Down
Loading