diff --git a/api.go b/api.go index 2ee9c10d0fd..5f2ff099c71 100644 --- a/api.go +++ b/api.go @@ -22,6 +22,7 @@ import ( "github.com/TykTechnologies/tyk/apidef" "github.com/TykTechnologies/tyk/config" + "github.com/TykTechnologies/tyk/ctx" "github.com/TykTechnologies/tyk/storage" "github.com/TykTechnologies/tyk/user" ) @@ -1784,7 +1785,7 @@ func setCtxValue(r *http.Request, key, val interface{}) { } func ctxGetData(r *http.Request) map[string]interface{} { - if v := r.Context().Value(ContextData); v != nil { + if v := r.Context().Value(ctx.ContextData); v != nil { return v.(map[string]interface{}) } return nil @@ -1794,64 +1795,38 @@ func ctxSetData(r *http.Request, m map[string]interface{}) { if m == nil { panic("setting a nil context ContextData") } - setCtxValue(r, ContextData, m) + setCtxValue(r, ctx.ContextData, m) } func ctxGetSession(r *http.Request) *user.SessionState { - if v := r.Context().Value(SessionData); v != nil { - return v.(*user.SessionState) - } - return nil + return ctx.GetSession(r) } func ctxSetSession(r *http.Request, s *user.SessionState, token string, scheduleUpdate bool) { - if s == nil { - panic("setting a nil context SessionData") - } - - if token == "" { - token = ctxGetAuthToken(r) - } - - if s.KeyHashEmpty() { - s.SetKeyHash(storage.HashKey(token)) - } - - ctx := r.Context() - ctx = context.WithValue(ctx, SessionData, s) - ctx = context.WithValue(ctx, AuthToken, token) - - if scheduleUpdate { - ctx = context.WithValue(ctx, UpdateSession, true) - } - - setContext(r, ctx) + ctx.SetSession(r, s, token, scheduleUpdate) } func ctxScheduleSessionUpdate(r *http.Request) { - setCtxValue(r, UpdateSession, true) + setCtxValue(r, ctx.UpdateSession, true) } func ctxDisableSessionUpdate(r *http.Request) { - setCtxValue(r, UpdateSession, false) + setCtxValue(r, ctx.UpdateSession, false) } func ctxSessionUpdateScheduled(r *http.Request) bool { - if v := r.Context().Value(UpdateSession); v != nil { + if v := r.Context().Value(ctx.UpdateSession); v != nil { return v.(bool) } return false } func ctxGetAuthToken(r *http.Request) string { - if v := r.Context().Value(AuthToken); v != nil { - return v.(string) - } - return "" + return ctx.GetAuthToken(r) } func ctxGetTrackedPath(r *http.Request) string { - if v := r.Context().Value(TrackThisEndpoint); v != nil { + if v := r.Context().Value(ctx.TrackThisEndpoint); v != nil { return v.(string) } return "" @@ -1861,34 +1836,34 @@ func ctxSetTrackedPath(r *http.Request, p string) { if p == "" { panic("setting a nil context TrackThisEndpoint") } - setCtxValue(r, TrackThisEndpoint, p) + setCtxValue(r, ctx.TrackThisEndpoint, p) } func ctxGetDoNotTrack(r *http.Request) bool { - return r.Context().Value(DoNotTrackThisEndpoint) == true + return r.Context().Value(ctx.DoNotTrackThisEndpoint) == true } func ctxSetDoNotTrack(r *http.Request, b bool) { - setCtxValue(r, DoNotTrackThisEndpoint, b) + setCtxValue(r, ctx.DoNotTrackThisEndpoint, b) } func ctxGetVersionInfo(r *http.Request) *apidef.VersionInfo { - if v := r.Context().Value(VersionData); v != nil { + if v := r.Context().Value(ctx.VersionData); v != nil { return v.(*apidef.VersionInfo) } return nil } func ctxSetVersionInfo(r *http.Request, v *apidef.VersionInfo) { - setCtxValue(r, VersionData, v) + setCtxValue(r, ctx.VersionData, v) } func ctxSetOrigRequestURL(r *http.Request, url *url.URL) { - setCtxValue(r, OrigRequestURL, url) + setCtxValue(r, ctx.OrigRequestURL, url) } func ctxGetOrigRequestURL(r *http.Request) *url.URL { - if v := r.Context().Value(OrigRequestURL); v != nil { + if v := r.Context().Value(ctx.OrigRequestURL); v != nil { if urlVal, ok := v.(*url.URL); ok { return urlVal } @@ -1898,11 +1873,11 @@ func ctxGetOrigRequestURL(r *http.Request) *url.URL { } func ctxSetUrlRewritePath(r *http.Request, path string) { - setCtxValue(r, UrlRewritePath, path) + setCtxValue(r, ctx.UrlRewritePath, path) } func ctxGetUrlRewritePath(r *http.Request) string { - if v := r.Context().Value(UrlRewritePath); v != nil { + if v := r.Context().Value(ctx.UrlRewritePath); v != nil { if strVal, ok := v.(string); ok { return strVal } @@ -1911,7 +1886,7 @@ func ctxGetUrlRewritePath(r *http.Request) string { } func ctxSetCheckLoopLimits(r *http.Request, b bool) { - setCtxValue(r, CheckLoopLimits, b) + setCtxValue(r, ctx.CheckLoopLimits, b) } // Should we check Rate limits and Quotas? @@ -1921,7 +1896,7 @@ func ctxCheckLimits(r *http.Request) bool { return true } - if v := r.Context().Value(CheckLoopLimits); v != nil { + if v := r.Context().Value(ctx.CheckLoopLimits); v != nil { return v.(bool) } @@ -1929,11 +1904,11 @@ func ctxCheckLimits(r *http.Request) bool { } func ctxSetRequestMethod(r *http.Request, path string) { - setCtxValue(r, RequestMethod, path) + setCtxValue(r, ctx.RequestMethod, path) } func ctxGetRequestMethod(r *http.Request) string { - if v := r.Context().Value(RequestMethod); v != nil { + if v := r.Context().Value(ctx.RequestMethod); v != nil { if strVal, ok := v.(string); ok { return strVal } @@ -1942,11 +1917,11 @@ func ctxGetRequestMethod(r *http.Request) string { } func ctxGetDefaultVersion(r *http.Request) bool { - return r.Context().Value(VersionDefault) != nil + return r.Context().Value(ctx.VersionDefault) != nil } func ctxSetDefaultVersion(r *http.Request) { - setCtxValue(r, VersionDefault, true) + setCtxValue(r, ctx.VersionDefault, true) } func ctxLoopingEnabled(r *http.Request) bool { @@ -1954,7 +1929,7 @@ func ctxLoopingEnabled(r *http.Request) bool { } func ctxLoopLevel(r *http.Request) int { - if v := r.Context().Value(LoopLevel); v != nil { + if v := r.Context().Value(ctx.LoopLevel); v != nil { if intVal, ok := v.(int); ok { return intVal } @@ -1964,7 +1939,7 @@ func ctxLoopLevel(r *http.Request) int { } func ctxSetLoopLevel(r *http.Request, value int) { - setCtxValue(r, LoopLevel, value) + setCtxValue(r, ctx.LoopLevel, value) } func ctxIncLoopLevel(r *http.Request, loopLimit int) { @@ -1973,7 +1948,7 @@ func ctxIncLoopLevel(r *http.Request, loopLimit int) { } func ctxLoopLevelLimit(r *http.Request) int { - if v := r.Context().Value(LoopLevelLimit); v != nil { + if v := r.Context().Value(ctx.LoopLevelLimit); v != nil { if intVal, ok := v.(int); ok { return intVal } @@ -1985,12 +1960,12 @@ func ctxLoopLevelLimit(r *http.Request) int { func ctxSetLoopLimit(r *http.Request, limit int) { // Can be set only one time per request if ctxLoopLevelLimit(r) == 0 && limit > 0 { - setCtxValue(r, LoopLevelLimit, limit) + setCtxValue(r, ctx.LoopLevelLimit, limit) } } func ctxThrottleLevelLimit(r *http.Request) int { - if v := r.Context().Value(ThrottleLevelLimit); v != nil { + if v := r.Context().Value(ctx.ThrottleLevelLimit); v != nil { if intVal, ok := v.(int); ok { return intVal } @@ -2000,7 +1975,7 @@ func ctxThrottleLevelLimit(r *http.Request) int { } func ctxThrottleLevel(r *http.Request) int { - if v := r.Context().Value(ThrottleLevel); v != nil { + if v := r.Context().Value(ctx.ThrottleLevel); v != nil { if intVal, ok := v.(int); ok { return intVal } @@ -2012,12 +1987,12 @@ func ctxThrottleLevel(r *http.Request) int { func ctxSetThrottleLimit(r *http.Request, limit int) { // Can be set only one time per request if ctxThrottleLevelLimit(r) == 0 && limit > 0 { - setCtxValue(r, ThrottleLevelLimit, limit) + setCtxValue(r, ctx.ThrottleLevelLimit, limit) } } func ctxSetThrottleLevel(r *http.Request, value int) { - setCtxValue(r, ThrottleLevel, value) + setCtxValue(r, ctx.ThrottleLevel, value) } func ctxIncThrottleLevel(r *http.Request, throttleLimit int) { @@ -2026,9 +2001,9 @@ func ctxIncThrottleLevel(r *http.Request, throttleLimit int) { } func ctxTraceEnabled(r *http.Request) bool { - return r.Context().Value(Trace) != nil + return r.Context().Value(ctx.Trace) != nil } func ctxSetTrace(r *http.Request) { - setCtxValue(r, Trace, true) + setCtxValue(r, ctx.Trace, true) } diff --git a/api_loader.go b/api_loader.go index d0a55ee472c..113ae07a801 100644 --- a/api_loader.go +++ b/api_loader.go @@ -288,9 +288,6 @@ func processSpec(spec *APISpec, apisByListen map[string]int, BaseMiddleware: baseMid, Path: obj.Path, SymbolName: obj.Name, - Pre: true, - UseSession: obj.RequireSession, - Auth: false, }, ) } else { @@ -329,9 +326,6 @@ func processSpec(spec *APISpec, apisByListen map[string]int, BaseMiddleware: baseMid, Path: obj.Path, SymbolName: obj.Name, - Pre: false, - UseSession: obj.RequireSession, - Auth: false, }, ) } else { @@ -360,9 +354,6 @@ func processSpec(spec *APISpec, apisByListen map[string]int, BaseMiddleware: baseMid, Path: obj.Path, SymbolName: obj.Name, - Pre: true, - UseSession: obj.RequireSession, - Auth: false, }, ) } else { @@ -428,9 +419,6 @@ func processSpec(spec *APISpec, apisByListen map[string]int, BaseMiddleware: baseMid, Path: mwAuthCheckFunc.Path, SymbolName: mwAuthCheckFunc.Name, - Pre: true, - UseSession: false, - Auth: true, }, ) } @@ -450,9 +438,6 @@ func processSpec(spec *APISpec, apisByListen map[string]int, BaseMiddleware: baseMid, Path: obj.Path, SymbolName: obj.Name, - Pre: false, - UseSession: obj.RequireSession, - Auth: false, }, ) } else { @@ -486,9 +471,6 @@ func processSpec(spec *APISpec, apisByListen map[string]int, BaseMiddleware: baseMid, Path: obj.Path, SymbolName: obj.Name, - Pre: false, - UseSession: obj.RequireSession, - Auth: false, }, ) } else { @@ -581,7 +563,7 @@ func (d *DummyProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { handler = targetAPI.middlewareChain.ThisHandler } else { handler := ErrorHandler{*d.SH.Base()} - handler.HandleError(w, r, "Can't detect loop target", http.StatusInternalServerError) + handler.HandleError(w, r, "Can't detect loop target", http.StatusInternalServerError, true) return } } diff --git a/ctx/ctx.go b/ctx/ctx.go new file mode 100644 index 00000000000..5daa1f04dab --- /dev/null +++ b/ctx/ctx.go @@ -0,0 +1,79 @@ +package ctx + +import ( + "context" + "net/http" + + "github.com/TykTechnologies/tyk/storage" + "github.com/TykTechnologies/tyk/user" +) + +const ( + SessionData = iota + UpdateSession + AuthToken + HashedAuthToken + VersionData + VersionDefault + OrgSessionContext + ContextData + RetainHost + TrackThisEndpoint + DoNotTrackThisEndpoint + UrlRewritePath + RequestMethod + OrigRequestURL + LoopLevel + LoopLevelLimit + ThrottleLevel + ThrottleLevelLimit + Trace + CheckLoopLimits +) + +func setContext(r *http.Request, ctx context.Context) { + r2 := r.WithContext(ctx) + *r = *r2 +} + +func ctxSetSession(r *http.Request, s *user.SessionState, token string, scheduleUpdate bool) { + if s == nil { + panic("setting a nil context SessionData") + } + + if token == "" { + token = GetAuthToken(r) + } + + if s.KeyHashEmpty() { + s.SetKeyHash(storage.HashKey(token)) + } + + ctx := r.Context() + ctx = context.WithValue(ctx, SessionData, s) + ctx = context.WithValue(ctx, AuthToken, token) + + if scheduleUpdate { + ctx = context.WithValue(ctx, UpdateSession, true) + } + + setContext(r, ctx) +} + +func GetAuthToken(r *http.Request) string { + if v := r.Context().Value(AuthToken); v != nil { + return v.(string) + } + return "" +} + +func GetSession(r *http.Request) *user.SessionState { + if v := r.Context().Value(SessionData); v != nil { + return v.(*user.SessionState) + } + return nil +} + +func SetSession(r *http.Request, s *user.SessionState, token string, scheduleUpdate bool) { + ctxSetSession(r, s, token, scheduleUpdate) +} diff --git a/goplugin/goplugin.go b/goplugin/goplugin.go index 51f86b75f4c..cff4b15da00 100644 --- a/goplugin/goplugin.go +++ b/goplugin/goplugin.go @@ -1,49 +1,31 @@ +// +build !nogoplugin + package goplugin import ( + "errors" "net/http" - - "github.com/TykTechnologies/tyk/user" + "plugin" ) -// Logger provides interface to output to Tyk's logging system with log levels INFO, DEBUG, WARN and ERROR -type Logger interface { - Info(args ...interface{}) - Infof(format string, args ...interface{}) - Infoln(args ...interface{}) - - Debug(args ...interface{}) - Debugf(format string, args ...interface{}) - Debugln(args ...interface{}) - - Warning(args ...interface{}) - Warningf(format string, args ...interface{}) - Warningln(args ...interface{}) - - Error(args ...interface{}) - Errorf(format string, args ...interface{}) - Errorln(args ...interface{}) +func GetHandler(path string, symbol string) (http.HandlerFunc, error) { + // try to load plugin + loadedPlugin, err := plugin.Open(path) + if err != nil { + return nil, err + } + + // try to lookup function symbol + funcSymbol, err := loadedPlugin.Lookup(symbol) + if err != nil { + return nil, err + } + + // try to cast symbol to real func + pluginHandler, ok := funcSymbol.(func(http.ResponseWriter, *http.Request)) + if !ok { + return nil, errors.New("could not cast function symbol to http.HandlerFunc") + } + + return pluginHandler, nil } - -type APISpec struct { - OrgID string - APIID string - ConfigData map[string]interface{} -} - -// ProcessFunc type functions are called for "pre", "post", "post_key_auth" custom middleware methods -type ProcessFunc func( - http.ResponseWriter, - *http.Request, - *user.SessionState, - APISpec, - Logger, -) error - -// AuthFunc type function is called for "auth_check" custom middleware method -type AuthFunc func( - http.ResponseWriter, - *http.Request, - APISpec, - Logger, -) (session *user.SessionState, token string, err error) diff --git a/goplugin/no_goplugin.go b/goplugin/no_goplugin.go new file mode 100644 index 00000000000..46606bf7d1d --- /dev/null +++ b/goplugin/no_goplugin.go @@ -0,0 +1,12 @@ +// +build nogoplugin + +package goplugin + +import ( + "fmt" + "net/http" +) + +func GetHandler(path string, symbol string) (http.HandlerFunc, error) { + return nil, fmt.Errorf("goplugin.GetHandler is disabled, please disable build flag 'nogoplugin'") +} diff --git a/handler_success.go b/handler_success.go index 5f79d909e89..802cfe0c70e 100644 --- a/handler_success.go +++ b/handler_success.go @@ -13,35 +13,11 @@ import ( cache "github.com/pmylund/go-cache" "github.com/TykTechnologies/tyk/config" + "github.com/TykTechnologies/tyk/ctx" "github.com/TykTechnologies/tyk/request" "github.com/TykTechnologies/tyk/user" ) -// Enums for keys to be stored in a session context - this is how gorilla expects -// these to be implemented and is lifted pretty much from docs -const ( - SessionData = iota - UpdateSession - AuthToken - HashedAuthToken - VersionData - VersionDefault - OrgSessionContext - ContextData - RetainHost - TrackThisEndpoint - DoNotTrackThisEndpoint - UrlRewritePath - RequestMethod - OrigRequestURL - LoopLevel - LoopLevelLimit - ThrottleLevel - ThrottleLevelLimit - Trace - CheckLoopLimits -) - const ( keyDataDeveloperID = "tyk_developer_id" keyDataDeveloperEmail = "tyk_developer_email" @@ -274,7 +250,7 @@ func recordDetail(r *http.Request, globalConf config.Config) bool { } // We are, so get session data - ses := r.Context().Value(OrgSessionContext) + ses := r.Context().Value(ctx.OrgSessionContext) if ses == nil { // no session found, use global config return globalConf.AnalyticsConfig.EnableDetailedRecording diff --git a/mw_go_plugin.go b/mw_go_plugin.go index 1cbb2a93d46..b69892ff899 100644 --- a/mw_go_plugin.go +++ b/mw_go_plugin.go @@ -3,12 +3,10 @@ package main import ( "fmt" "net/http" - "plugin" "github.com/Sirupsen/logrus" "github.com/TykTechnologies/tyk/goplugin" - "github.com/TykTechnologies/tyk/user" ) // customResponseWriter is a wrapper around standard http.ResponseWriter @@ -36,14 +34,10 @@ func (w *customResponseWriter) WriteHeader(statusCode int) { // GoPluginMiddleware is a generic middleware that will execute Go-plugin code before continuing type GoPluginMiddleware struct { BaseMiddleware - Path string // path to .so file - SymbolName string // function symbol to look up - Pre bool - UseSession bool - Auth bool - mwProcessFunc goplugin.ProcessFunc - mwAuthFunc goplugin.AuthFunc - logger *logrus.Entry + Path string // path to .so file + SymbolName string // function symbol to look up + handler http.HandlerFunc + logger *logrus.Entry } func (m *GoPluginMiddleware) Name() string { @@ -54,48 +48,17 @@ func (m *GoPluginMiddleware) EnabledForSpec() bool { m.logger = log.WithFields(logrus.Fields{ "mwPath": m.Path, "mwSymbolName": m.SymbolName, - "isAuth": m.Auth, }) - if m.mwProcessFunc != nil || m.mwAuthFunc != nil { + if m.handler != nil { m.logger.Info("Go-plugin middleware is already initialized") return true } // try to load plugin - loadedPlugin, err := plugin.Open(m.Path) - if err != nil { - m.logger.WithError(err).Error("Could not load plugin") - return false - } - - // try to lookup function symbol - funcSymbol, err := loadedPlugin.Lookup(m.SymbolName) - if err != nil { - m.logger.WithError(err).Error("Could not look up symbol in loaded plugin") - return false - } - - // try to cast symbol to real func - var ok bool - if m.Auth { - m.mwAuthFunc, ok = funcSymbol.(func( - http.ResponseWriter, - *http.Request, - goplugin.APISpec, - goplugin.Logger, - ) (session *user.SessionState, token string, err error)) - } else { - m.mwProcessFunc, ok = funcSymbol.(func( - http.ResponseWriter, - *http.Request, - *user.SessionState, - goplugin.APISpec, - goplugin.Logger, - ) error) - } - if !ok { - m.logger.Error("Could not cast function symbol") + var err error + if m.handler, err = goplugin.GetHandler(m.Path, m.SymbolName); err != nil { + m.logger.WithError(err).Error("Could not load Go-plugin") return false } @@ -114,19 +77,12 @@ func (m *GoPluginMiddleware) ProcessRequest(w http.ResponseWriter, r *http.Reque // prepare data to call Go-plugin function + // get session hash before Go-plugin function call var prevMD5Hash string - var session *user.SessionState - if m.UseSession && !m.Pre && !m.Auth { // pass session if requested in meta and it is not auth_check or pre-process - session = ctxGetSession(r) + if session := ctxGetSession(r); session != nil { prevMD5Hash = session.MD5Hash() } - apiSpec := goplugin.APISpec{ - OrgID: m.Spec.OrgID, - APIID: m.Spec.APIID, - ConfigData: m.Spec.ConfigData, - } - // make sure request's body can be re-read again nopCloseRequestBody(r) @@ -135,42 +91,27 @@ func (m *GoPluginMiddleware) ProcessRequest(w http.ResponseWriter, r *http.Reque ResponseWriter: w, } - // run Go-plugin function - if m.Auth { - newSession, token, authErr := m.mwAuthFunc(rw, r, apiSpec, m.logger) - if authErr != nil { - err = authErr - } else { - // add to context session and token created my custom middleware - // schedule update so new session will be stored - ctxSetSession(r, newSession, token, true) - } - } else { - err = m.mwProcessFunc(rw, r, session, apiSpec, m.logger) - if err == nil { - // check if session was passed to custom middleware and modified - if session != nil && prevMD5Hash != session.MD5Hash() { - ctxScheduleSessionUpdate(r) - } + // call Go-plugin function + m.handler(rw, r) + + // check if we need to schedule session update in case session was updated by Go-plugin + // but update wasn't scheduled + if prevMD5Hash != "" { + if session := ctxGetSession(r); session != nil && prevMD5Hash != session.MD5Hash() { + ctxScheduleSessionUpdate(r) } } - // process returned error - if err != nil { - if rw.responseSent { + // check if response was sent + if rw.responseSent { + // check if response code was an error one + if rw.statusCodeSent >= http.StatusBadRequest { respCode = rw.statusCodeSent + err = fmt.Errorf("plugin function sent error response code: %d", rw.statusCodeSent) + m.logger.WithError(err).Error("Failed to process request with Go-plugin middleware func") } else { - m.logger.Warning("Go-plugin func returned error but didn't send response. Forcing 500 status") - w.WriteHeader(http.StatusInternalServerError) - respCode = http.StatusInternalServerError + respCode = mwStatusRespond // no need to continue passing this request down to reverse proxy } - m.logger.WithError(err).Error("Failed to run Go-plugin middleware func") - return - } - - // no errors, check if response was sent - if rw.responseSent { - respCode = mwStatusRespond // no need to continue passing this request down to reverse proxy } else { respCode = http.StatusOK } diff --git a/mw_organisation_activity.go b/mw_organisation_activity.go index 22d2bc63049..e219952a0e3 100644 --- a/mw_organisation_activity.go +++ b/mw_organisation_activity.go @@ -1,13 +1,12 @@ package main import ( + "errors" "net/http" "sync" - - "errors" - "time" + "github.com/TykTechnologies/tyk/ctx" "github.com/TykTechnologies/tyk/request" "github.com/TykTechnologies/tyk/user" ) @@ -171,7 +170,7 @@ func (k *OrganizationMonitor) ProcessRequestLive(r *http.Request, orgSession use } // Lets keep a reference of the org - setCtxValue(r, OrgSessionContext, orgSession) + setCtxValue(r, ctx.OrgSessionContext, orgSession) // Request is valid, carry on return nil, http.StatusOK @@ -198,7 +197,7 @@ func (k *OrganizationMonitor) ProcessRequestOffThread(r *http.Request, orgSessio // Lets keep a reference of the org // session might be updated by go-routine AllowAccessNext and we loose those changes here // but it is OK as we need it in context for detailed org logging - setCtxValue(r, OrgSessionContext, orgSession) + setCtxValue(r, ctx.OrgSessionContext, orgSession) orgSessionCopy := orgSession go k.AllowAccessNext( diff --git a/mw_url_rewrite.go b/mw_url_rewrite.go index 597e4b9bce7..ac4ffe876ff 100644 --- a/mw_url_rewrite.go +++ b/mw_url_rewrite.go @@ -11,6 +11,7 @@ import ( "strings" "github.com/TykTechnologies/tyk/apidef" + "github.com/TykTechnologies/tyk/ctx" "github.com/TykTechnologies/tyk/regexp" "github.com/TykTechnologies/tyk/user" ) @@ -327,7 +328,7 @@ func (m *URLRewriteMiddleware) CheckHostRewrite(oldPath, newTarget string, r *ht newAsURL, _ := url.Parse(newTarget) if newAsURL.Scheme != LoopScheme && oldAsURL.Host != newAsURL.Host { log.Debug("Detected a host rewrite in pattern!") - setCtxValue(r, RetainHost, true) + setCtxValue(r, ctx.RetainHost, true) } } diff --git a/reverse_proxy.go b/reverse_proxy.go index 7cd57768c99..bb78202eb3e 100644 --- a/reverse_proxy.go +++ b/reverse_proxy.go @@ -33,6 +33,7 @@ import ( "github.com/TykTechnologies/tyk/apidef" "github.com/TykTechnologies/tyk/config" + "github.com/TykTechnologies/tyk/ctx" "github.com/TykTechnologies/tyk/regexp" "github.com/TykTechnologies/tyk/user" ) @@ -217,7 +218,7 @@ func TykNewSingleHostReverseProxy(target *url.URL, spec *APISpec) *ReverseProxy targetToUse := target - if spec.URLRewriteEnabled && req.Context().Value(RetainHost) == true { + if spec.URLRewriteEnabled && req.Context().Value(ctx.RetainHost) == true { log.Debug("Detected host rewrite, overriding target") tmpTarget, err := url.Parse(req.URL.String()) if err != nil { @@ -552,17 +553,17 @@ func (p *ReverseProxy) WrappedServeHTTP(rw http.ResponseWriter, req *http.Reques } p.TykAPISpec.Unlock() - ctx := req.Context() + reqCtx := req.Context() if cn, ok := rw.(http.CloseNotifier); ok { var cancel context.CancelFunc - ctx, cancel = context.WithCancel(ctx) + reqCtx, cancel = context.WithCancel(reqCtx) defer cancel() notifyChan := cn.CloseNotify() go func() { select { case <-notifyChan: cancel() - case <-ctx.Done(): + case <-reqCtx.Done(): } }() } @@ -582,15 +583,15 @@ func (p *ReverseProxy) WrappedServeHTTP(rw http.ResponseWriter, req *http.Reques log.Debug("UPSTREAM REQUEST URL: ", req.URL) // We need to double set the context for the outbound request to reprocess the target - if p.TykAPISpec.URLRewriteEnabled && req.Context().Value(RetainHost) == true { + if p.TykAPISpec.URLRewriteEnabled && req.Context().Value(ctx.RetainHost) == true { log.Debug("Detected host rewrite, notifying director") - setCtxValue(outreq, RetainHost, true) + setCtxValue(outreq, ctx.RetainHost, true) } if req.ContentLength == 0 { outreq.Body = nil // Issue 16036: nil Body for http.Transport retries } - outreq = outreq.WithContext(ctx) + outreq = outreq.WithContext(reqCtx) outreq.Header = cloneHeader(req.Header)