Skip to content

Commit

Permalink
allow transform response when path is equal to path of url rewrite
Browse files Browse the repository at this point in the history
  • Loading branch information
dencoded committed Dec 8, 2017
1 parent a0772e5 commit 98d9fdc
Show file tree
Hide file tree
Showing 6 changed files with 204 additions and 2 deletions.
13 changes: 13 additions & 0 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -1460,3 +1460,16 @@ func ctxSetVersionInfo(r *http.Request, v *apidef.VersionInfo) {
}
setCtxValue(r, VersionData, v)
}

func ctxSetUrlRewritePath(r *http.Request, path string) {
setCtxValue(r, UrlRewritePath, path)
}

func ctxGetUrlRewritePath(r *http.Request) string {
if v := r.Context().Value(UrlRewritePath); v != nil {
if strVal, ok := v.(string); ok {
return strVal
}
}
return ""
}
13 changes: 11 additions & 2 deletions api_definition.go
Original file line number Diff line number Diff line change
Expand Up @@ -863,10 +863,19 @@ func (a *APISpec) URLAllowedAndIgnored(r *http.Request, rxPaths []URLSpec, white
func (a *APISpec) CheckSpecMatchesStatus(r *http.Request, rxPaths []URLSpec, mode URLStatus) (bool, interface{}) {
// Check if ignored
for _, v := range rxPaths {
if mode != v.Status {
continue
}
match := v.Spec.MatchString(r.URL.Path)
// only return it it's what we are looking for
if !match || mode != v.Status {
continue
if !match {
// check for special case when using url_rewrites with transform_response
// and specifying the same "path" expression
if mode != TransformedResponse {
continue
} else if v.TransformResponseAction.Path != ctxGetUrlRewritePath(r) {
continue
}
}
switch v.Status {
case Ignored, BlackList, WhiteList, Cached:
Expand Down
2 changes: 2 additions & 0 deletions gateway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,7 @@ func getChain(spec *APISpec) http.Handler {
}
proxy := TykNewSingleHostReverseProxy(remote, spec)
proxyHandler := ProxyHandler(proxy, spec)
creeateResponseMiddlewareChain(spec)
baseMid := BaseMiddleware{spec, proxy}
chain := alice.New(mwList(
&IPWhiteListMiddleware{baseMid},
Expand All @@ -316,6 +317,7 @@ func getChain(spec *APISpec) http.Handler {
&AccessRightsCheck{baseMid},
&RateLimitAndQuotaCheck{baseMid},
&TransformHeaders{baseMid},
&URLRewriteMiddleware{baseMid},
)...).Then(proxyHandler)
return chain
}
Expand Down
1 change: 1 addition & 0 deletions handler_success.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ const (
RetainHost
TrackThisEndpoint
DoNotTrackThisEndpoint
UrlRewritePath
)

var SessionCache = cache.New(10*time.Second, 5*time.Second)
Expand Down
3 changes: 3 additions & 0 deletions mw_url_rewrite.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ func urlRewrite(meta *apidef.URLRewriteMeta, r *http.Request) (string, error) {
log.Debug("URL Re-written from: ", path)
log.Debug("URL Re-written to: ", newpath)

// put url_rewrite path to context to be used in ResponseTransformMiddleware
ctxSetUrlRewritePath(r, meta.Path)

// matched?? Set the modified path
// return newpath, nil
}
Expand Down
174 changes: 174 additions & 0 deletions res_handler_transform_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
package main

import (
"encoding/base64"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
)

func TestTransformResponseWithURLRewrite(t *testing.T) {
testTemplateBlob := base64.StdEncoding.EncodeToString([]byte(`{"http_method":"{{.Method}}"}`))

testData := map[string]struct {
apiSpec string
url string
expectedCode int
expectedBody string
}{
"just_transform_response": {
apiSpec: `
{
"api_id": "1",
"auth": {"auth_header_name": "authorization"},
"version_data": {
"not_versioned": true,
"versions": {
"v1": {
"name": "v1",
"use_extended_paths": true,
"extended_paths": {
"transform_response": [
{
"path": "get",
"method": "GET",
"template_data": {
"template_mode": "blob",
"template_source": "` + testTemplateBlob + `"
}
}
]
}
}
}
},
"response_processors":[{"name": "response_body_transform"}],
"proxy": {
"listen_path": "/v1",
"target_url": "` + testHttpAny + `"
}
}
`,
url: "/v1/get",
expectedCode: http.StatusOK,
expectedBody: `{"http_method":"GET"}`,
},
"transform_path_equal_to_rewrite_to": {
apiSpec: `
{
"api_id": "1",
"auth": {"auth_header_name": "authorization"},
"version_data": {
"not_versioned": true,
"versions": {
"v1": {
"name": "v1",
"use_extended_paths": true,
"extended_paths": {
"url_rewrites": [
{
"path": "abc",
"method": "GET",
"match_pattern": "abc",
"rewrite_to": "get"
}
],
"transform_response": [
{
"path": "get",
"method": "GET",
"template_data": {
"template_mode": "blob",
"template_source": "` + testTemplateBlob + `"
}
}
]
}
}
}
},
"response_processors":[{"name": "response_body_transform"}],
"proxy": {
"listen_path": "/v1",
"target_url": "` + testHttpAny + `"
}
}
`,
url: "/v1/abc",
expectedCode: http.StatusOK,
expectedBody: `{"http_method":"GET"}`,
},
"transform_path_equal_to_rewrite_path": {
apiSpec: `
{
"api_id": "1",
"auth": {"auth_header_name": "authorization"},
"version_data": {
"not_versioned": true,
"versions": {
"v1": {
"name": "v1",
"use_extended_paths": true,
"extended_paths": {
"url_rewrites": [
{
"path": "abc",
"method": "GET",
"match_pattern": "abc",
"rewrite_to": "get"
}
],
"transform_response": [
{
"path": "abc",
"method": "GET",
"template_data": {
"template_mode": "blob",
"template_source": "` + testTemplateBlob + `"
}
}
]
}
}
}
},
"response_processors":[{"name": "response_body_transform"}],
"proxy": {
"listen_path": "/v1",
"target_url": "` + testHttpAny + `"
}
}
`,
url: "/v1/abc",
expectedCode: http.StatusOK,
expectedBody: `{"http_method":"GET"}`,
},
}

for testName, test := range testData {
spec := createSpecTest(t, test.apiSpec)
session := createNonThrottledSession()
spec.SessionManager.UpdateSession("1234wer", session, 60)

recorder := httptest.NewRecorder()
req := testReq(t, http.MethodGet, test.url, nil)
req.Header.Set("authorization", "1234wer")
req.RemoteAddr = "127.0.0.1:80"

chain := getChain(spec)
chain.ServeHTTP(recorder, req)

if recorder.Code != 200 {
t.Fatalf("[%s] Invalid response code %d, should be 200\n", testName, recorder.Code)
}

// check that body was transformed
resp := recorder.Result()
bodyData, _ := ioutil.ReadAll(resp.Body)
body := string(bodyData)
if body != test.expectedBody {
t.Fatalf("[%s] Expected response body: '%s' Got response body: %s\n", testName, test.expectedBody, body)
}
}
}

0 comments on commit 98d9fdc

Please sign in to comment.