diff --git a/.gitignore b/.gitignore index 94f7bb3dead0..d1c9a90dc3af 100644 --- a/.gitignore +++ b/.gitignore @@ -42,6 +42,7 @@ petstore.json *.pdf *.mmdb *.cov +*.so !testdata/*.mmdb *.pid coprocess_gen_test.go diff --git a/apidef/api_definitions.go b/apidef/api_definitions.go index 0d194fbc04b8..44bdf90f8bf7 100644 --- a/apidef/api_definitions.go +++ b/apidef/api_definitions.go @@ -38,10 +38,11 @@ const ( RequestXML RequestInputType = "xml" RequestJSON RequestInputType = "json" - OttoDriver MiddlewareDriver = "otto" - PythonDriver MiddlewareDriver = "python" - LuaDriver MiddlewareDriver = "lua" - GrpcDriver MiddlewareDriver = "grpc" + OttoDriver MiddlewareDriver = "otto" + PythonDriver MiddlewareDriver = "python" + LuaDriver MiddlewareDriver = "lua" + GrpcDriver MiddlewareDriver = "grpc" + GoPluginDriver MiddlewareDriver = "goplugin" BodySource IdExtractorSource = "body" HeaderSource IdExtractorSource = "header" @@ -362,6 +363,7 @@ type APIDefinition struct { PinnedPublicKeys map[string]string `bson:"pinned_public_keys" json:"pinned_public_keys"` EnableJWT bool `bson:"enable_jwt" json:"enable_jwt"` UseStandardAuth bool `bson:"use_standard_auth" json:"use_standard_auth"` + UseGoPluginAuth bool `bson:"use_go_plugin_auth" json:"use_go_plugin_auth"` EnableCoProcessAuth bool `bson:"enable_coprocess_auth" json:"enable_coprocess_auth"` JWTSigningMethod string `bson:"jwt_signing_method" json:"jwt_signing_method"` JWTSource string `bson:"jwt_source" json:"jwt_source"` diff --git a/apidef/schema.go b/apidef/schema.go index 71ac10a40cfd..7fa7fd91e189 100644 --- a/apidef/schema.go +++ b/apidef/schema.go @@ -48,9 +48,12 @@ const Schema = `{ "openid_options": { "type": ["object", "null"] }, - "use_standard_auth":{ + "use_standard_auth": { "type": "boolean" }, + "use_go_plugin_auth": { + "type": "boolean" + }, "enable_coprocess_auth": { "type": "boolean" }, diff --git a/bin/ci-test.sh b/bin/ci-test.sh index 42f4366982ea..49bf42a65058 100755 --- a/bin/ci-test.sh +++ b/bin/ci-test.sh @@ -2,8 +2,8 @@ MATRIX=( - "-tags 'coprocess python'" - "-tags 'coprocess grpc'" + "-tags 'coprocess python goplugin'" + "-tags 'coprocess grpc goplugin'" ) TEST_TIMEOUT=2m @@ -40,6 +40,9 @@ i=0 go get -t +# build Go-plugin used in tests +go build -o ./test/goplugins/goplugins.so -buildmode=plugin ./test/goplugins || fatal "building Go-plugin failed" + # need to do per-pkg because go test doesn't support a single coverage # profile for multiple pkgs for pkg in $PKGS; do @@ -55,7 +58,13 @@ if [[ ! $LATEST_GO ]]; then exit 0 fi +# build Go-plugin used in tests but with race support +mv ./test/goplugins/goplugins.so ./test/goplugins/goplugins_old.so +go build -race -o ./test/goplugins/goplugins.so -buildmode=plugin ./test/goplugins \ + || fatal "building Go-plugin with race failed" + go test -race -v -timeout $TEST_TIMEOUT $PKGS || fatal "go test -race failed" +mv ./test/goplugins/goplugins_old.so ./test/goplugins/goplugins.so for opts in "${MATRIX[@]}"; do show go vet $opts $PKGS || fatal "go vet errored" diff --git a/ctx/ctx.go b/ctx/ctx.go new file mode 100644 index 000000000000..5daa1f04dab8 --- /dev/null +++ b/ctx/ctx.go @@ -0,0 +1,79 @@ +package ctx + +import ( + "context" + "net/http" + + "github.com/TykTechnologies/tyk/storage" + "github.com/TykTechnologies/tyk/user" +) + +const ( + SessionData = iota + UpdateSession + AuthToken + HashedAuthToken + VersionData + VersionDefault + OrgSessionContext + ContextData + RetainHost + TrackThisEndpoint + DoNotTrackThisEndpoint + UrlRewritePath + RequestMethod + OrigRequestURL + LoopLevel + LoopLevelLimit + ThrottleLevel + ThrottleLevelLimit + Trace + CheckLoopLimits +) + +func setContext(r *http.Request, ctx context.Context) { + r2 := r.WithContext(ctx) + *r = *r2 +} + +func ctxSetSession(r *http.Request, s *user.SessionState, token string, scheduleUpdate bool) { + if s == nil { + panic("setting a nil context SessionData") + } + + if token == "" { + token = GetAuthToken(r) + } + + if s.KeyHashEmpty() { + s.SetKeyHash(storage.HashKey(token)) + } + + ctx := r.Context() + ctx = context.WithValue(ctx, SessionData, s) + ctx = context.WithValue(ctx, AuthToken, token) + + if scheduleUpdate { + ctx = context.WithValue(ctx, UpdateSession, true) + } + + setContext(r, ctx) +} + +func GetAuthToken(r *http.Request) string { + if v := r.Context().Value(AuthToken); v != nil { + return v.(string) + } + return "" +} + +func GetSession(r *http.Request) *user.SessionState { + if v := r.Context().Value(SessionData); v != nil { + return v.(*user.SessionState) + } + return nil +} + +func SetSession(r *http.Request, s *user.SessionState, token string, scheduleUpdate bool) { + ctxSetSession(r, s, token, scheduleUpdate) +} diff --git a/gateway/api.go b/gateway/api.go index ccd4010247c5..8708f834bdea 100644 --- a/gateway/api.go +++ b/gateway/api.go @@ -48,6 +48,7 @@ import ( "github.com/TykTechnologies/tyk/apidef" "github.com/TykTechnologies/tyk/config" + "github.com/TykTechnologies/tyk/ctx" "github.com/TykTechnologies/tyk/storage" "github.com/TykTechnologies/tyk/user" ) @@ -1884,7 +1885,7 @@ func setCtxValue(r *http.Request, key, val interface{}) { } func ctxGetData(r *http.Request) map[string]interface{} { - if v := r.Context().Value(ContextData); v != nil { + if v := r.Context().Value(ctx.ContextData); v != nil { return v.(map[string]interface{}) } return nil @@ -1894,64 +1895,38 @@ func ctxSetData(r *http.Request, m map[string]interface{}) { if m == nil { panic("setting a nil context ContextData") } - setCtxValue(r, ContextData, m) + setCtxValue(r, ctx.ContextData, m) } func ctxGetSession(r *http.Request) *user.SessionState { - if v := r.Context().Value(SessionData); v != nil { - return v.(*user.SessionState) - } - return nil + return ctx.GetSession(r) } func ctxSetSession(r *http.Request, s *user.SessionState, token string, scheduleUpdate bool) { - if s == nil { - panic("setting a nil context SessionData") - } - - if token == "" { - token = ctxGetAuthToken(r) - } - - if s.KeyHashEmpty() { - s.SetKeyHash(storage.HashKey(token)) - } - - ctx := r.Context() - ctx = context.WithValue(ctx, SessionData, s) - ctx = context.WithValue(ctx, AuthToken, token) - - if scheduleUpdate { - ctx = context.WithValue(ctx, UpdateSession, true) - } - - setContext(r, ctx) + ctx.SetSession(r, s, token, scheduleUpdate) } func ctxScheduleSessionUpdate(r *http.Request) { - setCtxValue(r, UpdateSession, true) + setCtxValue(r, ctx.UpdateSession, true) } func ctxDisableSessionUpdate(r *http.Request) { - setCtxValue(r, UpdateSession, false) + setCtxValue(r, ctx.UpdateSession, false) } func ctxSessionUpdateScheduled(r *http.Request) bool { - if v := r.Context().Value(UpdateSession); v != nil { + if v := r.Context().Value(ctx.UpdateSession); v != nil { return v.(bool) } return false } func ctxGetAuthToken(r *http.Request) string { - if v := r.Context().Value(AuthToken); v != nil { - return v.(string) - } - return "" + return ctx.GetAuthToken(r) } func ctxGetTrackedPath(r *http.Request) string { - if v := r.Context().Value(TrackThisEndpoint); v != nil { + if v := r.Context().Value(ctx.TrackThisEndpoint); v != nil { return v.(string) } return "" @@ -1961,34 +1936,34 @@ func ctxSetTrackedPath(r *http.Request, p string) { if p == "" { panic("setting a nil context TrackThisEndpoint") } - setCtxValue(r, TrackThisEndpoint, p) + setCtxValue(r, ctx.TrackThisEndpoint, p) } func ctxGetDoNotTrack(r *http.Request) bool { - return r.Context().Value(DoNotTrackThisEndpoint) == true + return r.Context().Value(ctx.DoNotTrackThisEndpoint) == true } func ctxSetDoNotTrack(r *http.Request, b bool) { - setCtxValue(r, DoNotTrackThisEndpoint, b) + setCtxValue(r, ctx.DoNotTrackThisEndpoint, b) } func ctxGetVersionInfo(r *http.Request) *apidef.VersionInfo { - if v := r.Context().Value(VersionData); v != nil { + if v := r.Context().Value(ctx.VersionData); v != nil { return v.(*apidef.VersionInfo) } return nil } func ctxSetVersionInfo(r *http.Request, v *apidef.VersionInfo) { - setCtxValue(r, VersionData, v) + setCtxValue(r, ctx.VersionData, v) } func ctxSetOrigRequestURL(r *http.Request, url *url.URL) { - setCtxValue(r, OrigRequestURL, url) + setCtxValue(r, ctx.OrigRequestURL, url) } func ctxGetOrigRequestURL(r *http.Request) *url.URL { - if v := r.Context().Value(OrigRequestURL); v != nil { + if v := r.Context().Value(ctx.OrigRequestURL); v != nil { if urlVal, ok := v.(*url.URL); ok { return urlVal } @@ -1998,11 +1973,11 @@ func ctxGetOrigRequestURL(r *http.Request) *url.URL { } func ctxSetUrlRewritePath(r *http.Request, path string) { - setCtxValue(r, UrlRewritePath, path) + setCtxValue(r, ctx.UrlRewritePath, path) } func ctxGetUrlRewritePath(r *http.Request) string { - if v := r.Context().Value(UrlRewritePath); v != nil { + if v := r.Context().Value(ctx.UrlRewritePath); v != nil { if strVal, ok := v.(string); ok { return strVal } @@ -2011,7 +1986,7 @@ func ctxGetUrlRewritePath(r *http.Request) string { } func ctxSetCheckLoopLimits(r *http.Request, b bool) { - setCtxValue(r, CheckLoopLimits, b) + setCtxValue(r, ctx.CheckLoopLimits, b) } // Should we check Rate limits and Quotas? @@ -2021,7 +1996,7 @@ func ctxCheckLimits(r *http.Request) bool { return true } - if v := r.Context().Value(CheckLoopLimits); v != nil { + if v := r.Context().Value(ctx.CheckLoopLimits); v != nil { return v.(bool) } @@ -2029,11 +2004,11 @@ func ctxCheckLimits(r *http.Request) bool { } func ctxSetRequestMethod(r *http.Request, path string) { - setCtxValue(r, RequestMethod, path) + setCtxValue(r, ctx.RequestMethod, path) } func ctxGetRequestMethod(r *http.Request) string { - if v := r.Context().Value(RequestMethod); v != nil { + if v := r.Context().Value(ctx.RequestMethod); v != nil { if strVal, ok := v.(string); ok { return strVal } @@ -2042,11 +2017,11 @@ func ctxGetRequestMethod(r *http.Request) string { } func ctxGetDefaultVersion(r *http.Request) bool { - return r.Context().Value(VersionDefault) != nil + return r.Context().Value(ctx.VersionDefault) != nil } func ctxSetDefaultVersion(r *http.Request) { - setCtxValue(r, VersionDefault, true) + setCtxValue(r, ctx.VersionDefault, true) } func ctxLoopingEnabled(r *http.Request) bool { @@ -2054,7 +2029,7 @@ func ctxLoopingEnabled(r *http.Request) bool { } func ctxLoopLevel(r *http.Request) int { - if v := r.Context().Value(LoopLevel); v != nil { + if v := r.Context().Value(ctx.LoopLevel); v != nil { if intVal, ok := v.(int); ok { return intVal } @@ -2064,7 +2039,7 @@ func ctxLoopLevel(r *http.Request) int { } func ctxSetLoopLevel(r *http.Request, value int) { - setCtxValue(r, LoopLevel, value) + setCtxValue(r, ctx.LoopLevel, value) } func ctxIncLoopLevel(r *http.Request, loopLimit int) { @@ -2073,7 +2048,7 @@ func ctxIncLoopLevel(r *http.Request, loopLimit int) { } func ctxLoopLevelLimit(r *http.Request) int { - if v := r.Context().Value(LoopLevelLimit); v != nil { + if v := r.Context().Value(ctx.LoopLevelLimit); v != nil { if intVal, ok := v.(int); ok { return intVal } @@ -2085,12 +2060,12 @@ func ctxLoopLevelLimit(r *http.Request) int { func ctxSetLoopLimit(r *http.Request, limit int) { // Can be set only one time per request if ctxLoopLevelLimit(r) == 0 && limit > 0 { - setCtxValue(r, LoopLevelLimit, limit) + setCtxValue(r, ctx.LoopLevelLimit, limit) } } func ctxThrottleLevelLimit(r *http.Request) int { - if v := r.Context().Value(ThrottleLevelLimit); v != nil { + if v := r.Context().Value(ctx.ThrottleLevelLimit); v != nil { if intVal, ok := v.(int); ok { return intVal } @@ -2100,7 +2075,7 @@ func ctxThrottleLevelLimit(r *http.Request) int { } func ctxThrottleLevel(r *http.Request) int { - if v := r.Context().Value(ThrottleLevel); v != nil { + if v := r.Context().Value(ctx.ThrottleLevel); v != nil { if intVal, ok := v.(int); ok { return intVal } @@ -2112,12 +2087,12 @@ func ctxThrottleLevel(r *http.Request) int { func ctxSetThrottleLimit(r *http.Request, limit int) { // Can be set only one time per request if ctxThrottleLevelLimit(r) == 0 && limit > 0 { - setCtxValue(r, ThrottleLevelLimit, limit) + setCtxValue(r, ctx.ThrottleLevelLimit, limit) } } func ctxSetThrottleLevel(r *http.Request, value int) { - setCtxValue(r, ThrottleLevel, value) + setCtxValue(r, ctx.ThrottleLevel, value) } func ctxIncThrottleLevel(r *http.Request, throttleLimit int) { @@ -2126,9 +2101,9 @@ func ctxIncThrottleLevel(r *http.Request, throttleLimit int) { } func ctxTraceEnabled(r *http.Request) bool { - return r.Context().Value(Trace) != nil + return r.Context().Value(ctx.Trace) != nil } func ctxSetTrace(r *http.Request) { - setCtxValue(r, Trace, true) + setCtxValue(r, ctx.Trace, true) } diff --git a/gateway/api_definition.go b/gateway/api_definition.go index 1554556c70b2..8d4b3f1e6f89 100644 --- a/gateway/api_definition.go +++ b/gateway/api_definition.go @@ -102,7 +102,7 @@ const ( StatusURLRewrite RequestStatus = "URL Rewritten" StatusVirtualPath RequestStatus = "Virtual Endpoint" StatusRequestSizeControlled RequestStatus = "Request Size Limited" - StatusRequesTracked RequestStatus = "Request Tracked" + StatusRequestTracked RequestStatus = "Request Tracked" StatusRequestNotTracked RequestStatus = "Request Not Tracked" StatusValidateJSON RequestStatus = "Validate JSON" StatusInternal RequestStatus = "Internal path" @@ -971,7 +971,7 @@ func (a *APISpec) getURLStatus(stat URLStatus) RequestStatus { case MethodTransformed: return StatusMethodTransformed case RequestTracked: - return StatusRequesTracked + return StatusRequestTracked case RequestNotTracked: return StatusRequestNotTracked case ValidateJSONRequest: diff --git a/gateway/api_loader.go b/gateway/api_loader.go index 8656a1fc0de9..42dc98391536 100644 --- a/gateway/api_loader.go +++ b/gateway/api_loader.go @@ -184,10 +184,10 @@ func processSpec(spec *APISpec, apisByListen map[string]int, sessionStore = rpcAuthStore } - // Health checkers are initialised per spec so that each API handler has it's own connection and redis sotorage pool + // Health checkers are initialised per spec so that each API handler has it's own connection and redis storage pool spec.Init(authStore, sessionStore, healthStore, orgStore) - //Set up all the JSVM middleware + // Set up all the JSVM middleware var mwAuthCheckFunc apidef.MiddlewareDefinition mwPreFuncs := []apidef.MiddlewareDefinition{} mwPostFuncs := []apidef.MiddlewareDefinition{} @@ -282,6 +282,15 @@ func processSpec(spec *APISpec, apisByListen map[string]int, coprocessLog.Debug("Registering coprocess middleware, hook name: ", obj.Name, "hook type: Pre", ", driver: ", mwDriver) mwAppendEnabled(&chainArray, &CoProcessMiddleware{baseMid, coprocess.HookType_Pre, obj.Name, mwDriver}) + } else if mwDriver == apidef.GoPluginDriver { + mwAppendEnabled( + &chainArray, + &GoPluginMiddleware{ + BaseMiddleware: baseMid, + Path: obj.Path, + SymbolName: obj.Name, + }, + ) } else { chainArray = append(chainArray, createDynamicMiddleware(obj.Name, true, obj.RequireSession, baseMid)) } @@ -321,6 +330,7 @@ func processSpec(spec *APISpec, apisByListen map[string]int, coprocessAuth := EnableCoProcess && mwDriver != apidef.OttoDriver && spec.EnableCoProcessAuth ottoAuth := !coprocessAuth && mwDriver == apidef.OttoDriver && spec.EnableCoProcessAuth + gopluginAuth := !coprocessAuth && !ottoAuth && mwDriver == apidef.GoPluginDriver && spec.UseGoPluginAuth if coprocessAuth { // TODO: check if mwAuthCheckFunc is available/valid @@ -336,6 +346,17 @@ func processSpec(spec *APISpec, apisByListen map[string]int, authArray = append(authArray, createDynamicMiddleware(mwAuthCheckFunc.Name, true, false, baseMid)) } + if gopluginAuth { + mwAppendEnabled( + &authArray, + &GoPluginMiddleware{ + BaseMiddleware: baseMid, + Path: mwAuthCheckFunc.Path, + SymbolName: mwAuthCheckFunc.Name, + }, + ) + } + if spec.UseStandardAuth || len(authArray) == 0 { logger.Info("Checking security policy: Token") authArray = append(authArray, createMiddleware(&AuthKey{baseMid})) @@ -344,8 +365,19 @@ func processSpec(spec *APISpec, apisByListen map[string]int, chainArray = append(chainArray, authArray...) for _, obj := range mwPostAuthCheckFuncs { - coprocessLog.Debug("Registering coprocess middleware, hook name: ", obj.Name, "hook type: Pre", ", driver: ", mwDriver) - mwAppendEnabled(&chainArray, &CoProcessMiddleware{baseMid, coprocess.HookType_PostKeyAuth, obj.Name, mwDriver}) + if mwDriver == apidef.GoPluginDriver { + mwAppendEnabled( + &chainArray, + &GoPluginMiddleware{ + BaseMiddleware: baseMid, + Path: obj.Path, + SymbolName: obj.Name, + }, + ) + } else { + coprocessLog.Debug("Registering coprocess middleware, hook name: ", obj.Name, "hook type: Pre", ", driver: ", mwDriver) + mwAppendEnabled(&chainArray, &CoProcessMiddleware{baseMid, coprocess.HookType_PostKeyAuth, obj.Name, mwDriver}) + } } mwAppendEnabled(&chainArray, &StripAuth{baseMid}) @@ -369,6 +401,15 @@ func processSpec(spec *APISpec, apisByListen map[string]int, if mwDriver != apidef.OttoDriver { coprocessLog.Debug("Registering coprocess middleware, hook name: ", obj.Name, "hook type: Post", ", driver: ", mwDriver) mwAppendEnabled(&chainArray, &CoProcessMiddleware{baseMid, coprocess.HookType_Post, obj.Name, mwDriver}) + } else if mwDriver == apidef.GoPluginDriver { + mwAppendEnabled( + &chainArray, + &GoPluginMiddleware{ + BaseMiddleware: baseMid, + Path: obj.Path, + SymbolName: obj.Name, + }, + ) } else { chainArray = append(chainArray, createDynamicMiddleware(obj.Name, false, obj.RequireSession, baseMid)) } @@ -439,7 +480,7 @@ func (d *DummyProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if found, err := isLoop(r); found { if err != nil { handler := ErrorHandler{*d.SH.Base()} - handler.HandleError(w, r, err.Error(), http.StatusInternalServerError) + handler.HandleError(w, r, err.Error(), http.StatusInternalServerError, true) return } @@ -458,7 +499,7 @@ func (d *DummyProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { handler = targetAPI.middlewareChain } else { handler := ErrorHandler{*d.SH.Base()} - handler.HandleError(w, r, "Can't detect loop target", http.StatusInternalServerError) + handler.HandleError(w, r, "Can't detect loop target", http.StatusInternalServerError, true) return } } diff --git a/gateway/handler_error.go b/gateway/handler_error.go index 5ceaea7445e5..c8a351d46857 100644 --- a/gateway/handler_error.go +++ b/gateway/handler_error.go @@ -30,45 +30,57 @@ type ErrorHandler struct { } // HandleError is the actual error handler and will store the error details in analytics if analytics processing is enabled. -func (e *ErrorHandler) HandleError(w http.ResponseWriter, r *http.Request, errMsg string, errCode int) { +func (e *ErrorHandler) HandleError(w http.ResponseWriter, r *http.Request, errMsg string, errCode int, writeResponse bool) { defer e.Base().UpdateRequestSession(r) - var templateExtension string - var contentType string + if writeResponse { + var templateExtension string + var contentType string + + switch r.Header.Get("Content-Type") { + case "application/xml": + templateExtension = "xml" + contentType = "application/xml" + default: + templateExtension = "json" + contentType = "application/json" + } - switch r.Header.Get("Content-Type") { - case "application/xml": - templateExtension = "xml" - contentType = "application/xml" - default: - templateExtension = "json" - contentType = "application/json" - } + w.Header().Set("Content-Type", contentType) - w.Header().Set("Content-Type", contentType) + templateName := "error_" + strconv.Itoa(errCode) + "." + templateExtension - templateName := "error_" + strconv.Itoa(errCode) + "." + templateExtension + // Try to use an error template that matches the HTTP error code and the content type: 500.json, 400.xml, etc. + tmpl := templates.Lookup(templateName) - // Try to use an error template that matches the HTTP error code and the content type: 500.json, 400.xml, etc. - tmpl := templates.Lookup(templateName) + // Fallback to a generic error template, but match the content type: error.json, error.xml, etc. + if tmpl == nil { + templateName = defaultTemplateName + "." + templateExtension + tmpl = templates.Lookup(templateName) + } - // Fallback to a generic error template, but match the content type: error.json, error.xml, etc. - if tmpl == nil { - templateName = defaultTemplateName + "." + templateExtension - tmpl = templates.Lookup(templateName) - } + // If no template is available for this content type, fallback to "error.json". + if tmpl == nil { + templateName = defaultTemplateName + "." + defaultTemplateFormat + tmpl = templates.Lookup(templateName) + w.Header().Set("Content-Type", defaultContentType) + } - // If no template is available for this content type, fallback to "error.json". - if tmpl == nil { - templateName = defaultTemplateName + "." + defaultTemplateFormat - tmpl = templates.Lookup(templateName) - w.Header().Set("Content-Type", defaultContentType) - } + //If the config option is not set or is false, add the header + if !e.Spec.GlobalConfig.HideGeneratorHeader { + w.Header().Add("X-Generator", "tyk.io") + } - // Need to return the correct error code! - w.WriteHeader(errCode) - apiError := APIError{errMsg} - tmpl.Execute(w, &apiError) + // Close connections + if e.Spec.GlobalConfig.CloseConnections { + w.Header().Add("Connection", "close") + } + + // Need to return the correct error code! + w.WriteHeader(errCode) + apiError := APIError{errMsg} + tmpl.Execute(w, &apiError) + } if memProfFile != nil { pprof.WriteHeapProfile(memProfFile) @@ -192,16 +204,6 @@ func (e *ErrorHandler) HandleError(w http.ResponseWriter, r *http.Request, errMs // Report in health check reportHealthValue(e.Spec, BlockedRequestLog, "-1") - //If the config option is not set or is false, add the header - if !e.Spec.GlobalConfig.HideGeneratorHeader { - w.Header().Add("X-Generator", "tyk.io") - } - - // Close connections - if e.Spec.GlobalConfig.CloseConnections { - w.Header().Add("Connection", "close") - } - if memProfFile != nil { pprof.WriteHeapProfile(memProfFile) } diff --git a/gateway/handler_success.go b/gateway/handler_success.go index 46c04fb9515e..39a3bf3c69d7 100644 --- a/gateway/handler_success.go +++ b/gateway/handler_success.go @@ -10,35 +10,12 @@ import ( "strings" "time" + cache "github.com/pmylund/go-cache" + "github.com/TykTechnologies/tyk/config" + "github.com/TykTechnologies/tyk/ctx" "github.com/TykTechnologies/tyk/request" "github.com/TykTechnologies/tyk/user" - cache "github.com/pmylund/go-cache" -) - -// Enums for keys to be stored in a session context - this is how gorilla expects -// these to be implemented and is lifted pretty much from docs -const ( - SessionData = iota - UpdateSession - AuthToken - HashedAuthToken - VersionData - VersionDefault - OrgSessionContext - ContextData - RetainHost - TrackThisEndpoint - DoNotTrackThisEndpoint - UrlRewritePath - RequestMethod - OrigRequestURL - LoopLevel - LoopLevelLimit - ThrottleLevel - ThrottleLevelLimit - Trace - CheckLoopLimits ) const ( @@ -281,7 +258,7 @@ func recordDetail(r *http.Request, globalConf config.Config) bool { } // We are, so get session data - ses := r.Context().Value(OrgSessionContext) + ses := r.Context().Value(ctx.OrgSessionContext) if ses == nil { // no session found, use global config return globalConf.AnalyticsConfig.EnableDetailedRecording diff --git a/gateway/middleware.go b/gateway/middleware.go index 6fb8f8047f71..180c06afd640 100644 --- a/gateway/middleware.go +++ b/gateway/middleware.go @@ -94,8 +94,12 @@ func createMiddleware(mw TykMiddleware) func(http.Handler) http.Handler { } err, errCode := mw.ProcessRequest(w, r, mwConf) if err != nil { + // GoPluginMiddleware are expected to send response in case of error + // but we still want to record error + _, isGoPlugin := mw.(*GoPluginMiddleware) + handler := ErrorHandler{*mw.Base()} - handler.HandleError(w, r, err.Error(), errCode) + handler.HandleError(w, r, err.Error(), errCode, !isGoPlugin) meta["error"] = err.Error() diff --git a/gateway/mw_go_plugin.go b/gateway/mw_go_plugin.go new file mode 100644 index 000000000000..855d8b62b8c9 --- /dev/null +++ b/gateway/mw_go_plugin.go @@ -0,0 +1,164 @@ +package gateway + +import ( + "bytes" + "fmt" + "io/ioutil" + "net/http" + "time" + + "github.com/Sirupsen/logrus" + + "github.com/TykTechnologies/tyk/goplugin" +) + +// customResponseWriter is a wrapper around standard http.ResponseWriter +// plus it tracks if response was sent and what status code was sent +type customResponseWriter struct { + http.ResponseWriter + responseSent bool + statusCodeSent int + copyData bool + data []byte + dataLength int64 +} + +func (w *customResponseWriter) Write(b []byte) (int, error) { + w.responseSent = true + if w.statusCodeSent == 0 { + w.statusCodeSent = http.StatusOK // no WriteHeader was called so it will be set to StatusOK in actual ResponseWriter + } + + // send actual data + num, err := w.ResponseWriter.Write(b) + + // copy data sent + if w.copyData { + if w.data == nil { + w.data = make([]byte, num) + copy(w.data, b[:num]) + } else { + w.data = append(w.data, b[:num]...) + } + } + + // count how many bytes we sent + w.dataLength += int64(num) + + return num, err +} + +func (w *customResponseWriter) WriteHeader(statusCode int) { + w.responseSent = true + w.statusCodeSent = statusCode + w.ResponseWriter.WriteHeader(statusCode) +} + +func (w *customResponseWriter) getHttpResponse(r *http.Request) *http.Response { + // craft response on the fly for analytics + httpResponse := &http.Response{ + Status: http.StatusText(w.statusCodeSent), + StatusCode: w.statusCodeSent, + Header: w.ResponseWriter.Header(), // TODO: worth to think about trailer headers + Proto: r.Proto, + ProtoMajor: r.ProtoMajor, + ProtoMinor: r.ProtoMinor, + Request: r, + ContentLength: w.dataLength, + } + if w.copyData { + httpResponse.Body = ioutil.NopCloser(bytes.NewReader(w.data)) + } + + return httpResponse +} + +// GoPluginMiddleware is a generic middleware that will execute Go-plugin code before continuing +type GoPluginMiddleware struct { + BaseMiddleware + Path string // path to .so file + SymbolName string // function symbol to look up + handler http.HandlerFunc + logger *logrus.Entry + successHandler *SuccessHandler // to record analytics +} + +func (m *GoPluginMiddleware) Name() string { + return "GoPluginMiddleware: " + m.Path + ":" + m.SymbolName +} + +func (m *GoPluginMiddleware) EnabledForSpec() bool { + m.logger = log.WithFields(logrus.Fields{ + "mwPath": m.Path, + "mwSymbolName": m.SymbolName, + }) + + if m.handler != nil { + m.logger.Info("Go-plugin middleware is already initialized") + return true + } + + // try to load plugin + var err error + if m.handler, err = goplugin.GetHandler(m.Path, m.SymbolName); err != nil { + m.logger.WithError(err).Error("Could not load Go-plugin") + return false + } + + // to record 2XX hits in analytics + m.successHandler = &SuccessHandler{BaseMiddleware: m.BaseMiddleware} + + return true +} + +func (m *GoPluginMiddleware) ProcessRequest(w http.ResponseWriter, r *http.Request, conf interface{}) (err error, respCode int) { + // make sure tyk recover in case Go-plugin function panics + defer func() { + if e := recover(); e != nil { + err = fmt.Errorf("%v", e) + respCode = http.StatusInternalServerError + m.logger.WithError(err).Error("Recovered from panic while running Go-plugin middleware func") + } + }() + + // prepare data to call Go-plugin function + + // make sure request's body can be re-read again + nopCloseRequestBody(r) + + // wrap ResponseWriter to check if response was sent + rw := &customResponseWriter{ + ResponseWriter: w, + copyData: recordDetail(r, m.Spec.GlobalConfig), + } + + // call Go-plugin function + t1 := time.Now() + m.handler(rw, r) + t2 := time.Now() + + // calculate latency + ms := float64(t2.UnixNano()-t1.UnixNano()) * 0.000001 + m.logger.WithField("ms", ms).Debug("Go-plugin request processing took") + + // check if response was sent + if rw.responseSent { + // check if response code was an error one + if rw.statusCodeSent >= http.StatusBadRequest { + // base middleware will report this error to analytics if needed + respCode = rw.statusCodeSent + err = fmt.Errorf("plugin function sent error response code: %d", rw.statusCodeSent) + m.logger.WithError(err).Error("Failed to process request with Go-plugin middleware func") + } else { + // record 2XX to analytics + m.successHandler.RecordHit(r, int64(ms), rw.statusCodeSent, rw.getHttpResponse(r)) + + // no need to continue passing this request down to reverse proxy + respCode = mwStatusRespond + } + } else { + respCode = http.StatusOK + } + + return +} diff --git a/gateway/mw_go_plugin_test.go b/gateway/mw_go_plugin_test.go new file mode 100644 index 000000000000..9d8e3e944686 --- /dev/null +++ b/gateway/mw_go_plugin_test.go @@ -0,0 +1,89 @@ +// +build goplugin + +package gateway + +import ( + "net/http" + "testing" + "time" + + "github.com/TykTechnologies/tyk/apidef" + "github.com/TykTechnologies/tyk/test" +) + +// TestGoPluginMWs tests all possible Go-plugin MWs ("pre", "auth_check", "post_key_auth" and "post") +// Please see ./test/goplugins/test_goplugins.go for plugin implementation details +func TestGoPluginMWs(t *testing.T) { + ts := StartTest() + defer ts.Close() + + BuildAndLoadAPI(func(spec *APISpec) { + spec.APIID = "plugin_api" + spec.Proxy.ListenPath = "/goplugin" + spec.UseKeylessAccess = false + spec.UseStandardAuth = false + spec.UseGoPluginAuth = true + spec.CustomMiddleware = apidef.MiddlewareSection{ + Driver: apidef.GoPluginDriver, + Pre: []apidef.MiddlewareDefinition{ + { + Name: "MyPluginPre", + Path: "../test/goplugins/goplugins.so", + }, + }, + AuthCheck: apidef.MiddlewareDefinition{ + Name: "MyPluginAuthCheck", + Path: "../test/goplugins/goplugins.so", + }, + PostKeyAuth: []apidef.MiddlewareDefinition{ + { + Name: "MyPluginPostKeyAuth", + Path: "../test/goplugins/goplugins.so", + }, + }, + Post: []apidef.MiddlewareDefinition{ + { + Name: "MyPluginPost", + Path: "../test/goplugins/goplugins.so", + }, + }, + } + }) + + time.Sleep(1 * time.Second) + + t.Run("Run Go-plugin auth failed", func(t *testing.T) { + ts.Run(t, []test.TestCase{ + { + Path: "/goplugin/plugin_hit", + Headers: map[string]string{"Authorization": "invalid_token"}, + HeadersMatch: map[string]string{ + "X-Auth-Result": "failed", + }, + Code: http.StatusForbidden, + }, + }...) + }) + + t.Run("Run Go-plugin all middle-wares", func(t *testing.T) { + ts.Run(t, []test.TestCase{ + { + Path: "/goplugin/plugin_hit", + Headers: map[string]string{"Authorization": "abc"}, + Code: http.StatusOK, + HeadersMatch: map[string]string{ + "X-Initial-URI": "/goplugin/plugin_hit", + "X-Auth-Result": "OK", + "X-Session-Alias": "abc-session", + }, + BodyMatch: `"message":"post message"`, + }, + { + Method: "DELETE", + Path: "/tyk/keys/abc", + AdminAuth: true, + Code: http.StatusOK, + BodyMatch: `"action":"deleted"`}, + }...) + }) +} diff --git a/gateway/mw_js_plugin.go b/gateway/mw_js_plugin.go index 7db6014401ed..3832024aa0da 100644 --- a/gateway/mw_js_plugin.go +++ b/gateway/mw_js_plugin.go @@ -91,12 +91,13 @@ func (d *DynamicMiddleware) ProcessRequest(w http.ResponseWriter, r *http.Reques logger := d.Logger() // Create the proxy object - defer r.Body.Close() originalBody, err := ioutil.ReadAll(r.Body) if err != nil { logger.WithError(err).Error("Failed to read request body") return nil, http.StatusOK } + defer r.Body.Close() + headers := r.Header host := r.Host if host == "" && r.URL != nil { @@ -239,7 +240,7 @@ func (d *DynamicMiddleware) ProcessRequest(w http.ResponseWriter, r *http.Reques r.URL.RawQuery = values.Encode() - // Save the sesison data (if modified) + // Save the session data (if modified) if !d.Pre && d.UseSession { newMeta := mapStrsToIfaces(newRequestData.SessionMeta) if !reflect.DeepEqual(session.MetaData, newMeta) { diff --git a/gateway/mw_organisation_activity.go b/gateway/mw_organisation_activity.go index 25bccdbd3dc8..679a79e2fe45 100644 --- a/gateway/mw_organisation_activity.go +++ b/gateway/mw_organisation_activity.go @@ -1,13 +1,12 @@ package gateway import ( + "errors" "net/http" "sync" - - "errors" - "time" + "github.com/TykTechnologies/tyk/ctx" "github.com/TykTechnologies/tyk/request" "github.com/TykTechnologies/tyk/user" ) @@ -171,7 +170,7 @@ func (k *OrganizationMonitor) ProcessRequestLive(r *http.Request, orgSession use } // Lets keep a reference of the org - setCtxValue(r, OrgSessionContext, orgSession) + setCtxValue(r, ctx.OrgSessionContext, orgSession) // Request is valid, carry on return nil, http.StatusOK @@ -198,7 +197,7 @@ func (k *OrganizationMonitor) ProcessRequestOffThread(r *http.Request, orgSessio // Lets keep a reference of the org // session might be updated by go-routine AllowAccessNext and we loose those changes here // but it is OK as we need it in context for detailed org logging - setCtxValue(r, OrgSessionContext, orgSession) + setCtxValue(r, ctx.OrgSessionContext, orgSession) orgSessionCopy := orgSession go k.AllowAccessNext( diff --git a/gateway/mw_url_rewrite.go b/gateway/mw_url_rewrite.go index 97d57df78dde..06b03deb33ee 100644 --- a/gateway/mw_url_rewrite.go +++ b/gateway/mw_url_rewrite.go @@ -13,6 +13,7 @@ import ( "github.com/Sirupsen/logrus" "github.com/TykTechnologies/tyk/apidef" + "github.com/TykTechnologies/tyk/ctx" "github.com/TykTechnologies/tyk/regexp" "github.com/TykTechnologies/tyk/user" ) @@ -334,7 +335,7 @@ func (m *URLRewriteMiddleware) CheckHostRewrite(oldPath, newTarget string, r *ht newAsURL, _ := url.Parse(newTarget) if newAsURL.Scheme != LoopScheme && oldAsURL.Host != newAsURL.Host { log.Debug("Detected a host rewrite in pattern!") - setCtxValue(r, RetainHost, true) + setCtxValue(r, ctx.RetainHost, true) } } diff --git a/gateway/reverse_proxy.go b/gateway/reverse_proxy.go index c8ea602fe5b3..9ffb8e48fd78 100644 --- a/gateway/reverse_proxy.go +++ b/gateway/reverse_proxy.go @@ -33,6 +33,7 @@ import ( "github.com/TykTechnologies/tyk/apidef" "github.com/TykTechnologies/tyk/config" + "github.com/TykTechnologies/tyk/ctx" "github.com/TykTechnologies/tyk/regexp" "github.com/TykTechnologies/tyk/user" ) @@ -224,7 +225,7 @@ func TykNewSingleHostReverseProxy(target *url.URL, spec *APISpec) *ReverseProxy targetToUse := target - if spec.URLRewriteEnabled && req.Context().Value(RetainHost) == true { + if spec.URLRewriteEnabled && req.Context().Value(ctx.RetainHost) == true { log.Debug("Detected host rewrite, overriding target") tmpTarget, err := url.Parse(req.URL.String()) if err != nil { @@ -563,17 +564,17 @@ func (p *ReverseProxy) WrappedServeHTTP(rw http.ResponseWriter, req *http.Reques } p.TykAPISpec.Unlock() - ctx := req.Context() + reqCtx := req.Context() if cn, ok := rw.(http.CloseNotifier); ok { var cancel context.CancelFunc - ctx, cancel = context.WithCancel(ctx) + reqCtx, cancel = context.WithCancel(reqCtx) defer cancel() notifyChan := cn.CloseNotify() go func() { select { case <-notifyChan: cancel() - case <-ctx.Done(): + case <-reqCtx.Done(): } }() } @@ -593,15 +594,15 @@ func (p *ReverseProxy) WrappedServeHTTP(rw http.ResponseWriter, req *http.Reques log.Debug("UPSTREAM REQUEST URL: ", req.URL) // We need to double set the context for the outbound request to reprocess the target - if p.TykAPISpec.URLRewriteEnabled && req.Context().Value(RetainHost) == true { + if p.TykAPISpec.URLRewriteEnabled && req.Context().Value(ctx.RetainHost) == true { log.Debug("Detected host rewrite, notifying director") - setCtxValue(outreq, RetainHost, true) + setCtxValue(outreq, ctx.RetainHost, true) } if req.ContentLength == 0 { outreq.Body = nil // Issue 16036: nil Body for http.Transport retries } - outreq = outreq.WithContext(ctx) + outreq = outreq.WithContext(reqCtx) outreq.Header = cloneHeader(req.Header) @@ -665,7 +666,7 @@ func (p *ReverseProxy) WrappedServeHTTP(rw http.ResponseWriter, req *http.Reques if breakerEnforced { if !breakerConf.CB.Ready() { log.Debug("ON REQUEST: Circuit Breaker is in OPEN state") - p.ErrorHandler.HandleError(rw, logreq, "Service temporarily unavailable.", 503) + p.ErrorHandler.HandleError(rw, logreq, "Service temporarily unavailable.", 503, true) return nil } log.Debug("ON REQUEST: Circuit Breaker is in CLOSED or HALF-OPEN state") @@ -699,7 +700,7 @@ func (p *ReverseProxy) WrappedServeHTTP(rw http.ResponseWriter, req *http.Reques }).Error("http: proxy error: ", err) if strings.Contains(err.Error(), "timeout awaiting response headers") { - p.ErrorHandler.HandleError(rw, logreq, "Upstream service reached hard timeout.", http.StatusGatewayTimeout) + p.ErrorHandler.HandleError(rw, logreq, "Upstream service reached hard timeout.", http.StatusGatewayTimeout, true) if p.TykAPISpec.Proxy.ServiceDiscovery.UseDiscoveryService { if ServiceCache != nil { @@ -711,16 +712,16 @@ func (p *ReverseProxy) WrappedServeHTTP(rw http.ResponseWriter, req *http.Reques } if strings.Contains(err.Error(), "context canceled") { - p.ErrorHandler.HandleError(rw, logreq, "Client closed request", 499) + p.ErrorHandler.HandleError(rw, logreq, "Client closed request", 499, true) return nil } if strings.Contains(err.Error(), "no such host") { - p.ErrorHandler.HandleError(rw, logreq, "Upstream host lookup failed", http.StatusInternalServerError) + p.ErrorHandler.HandleError(rw, logreq, "Upstream host lookup failed", http.StatusInternalServerError, true) return nil } - p.ErrorHandler.HandleError(rw, logreq, "There was a problem proxying the request", http.StatusInternalServerError) + p.ErrorHandler.HandleError(rw, logreq, "There was a problem proxying the request", http.StatusInternalServerError, true) return nil } diff --git a/gateway/reverse_proxy_test.go b/gateway/reverse_proxy_test.go index 80c8866856f8..e3f2b4e66217 100644 --- a/gateway/reverse_proxy_test.go +++ b/gateway/reverse_proxy_test.go @@ -12,14 +12,12 @@ import ( "text/template" "time" - "github.com/TykTechnologies/tyk/test" - - "github.com/TykTechnologies/tyk/dnscache" - - "github.com/TykTechnologies/tyk/config" - "github.com/TykTechnologies/tyk/apidef" + "github.com/TykTechnologies/tyk/config" + "github.com/TykTechnologies/tyk/ctx" + "github.com/TykTechnologies/tyk/dnscache" "github.com/TykTechnologies/tyk/request" + "github.com/TykTechnologies/tyk/test" ) func TestCopyHeader_NoDuplicateCORSHeaders(t *testing.T) { @@ -94,7 +92,7 @@ func TestReverseProxyRetainHost(t *testing.T) { req := testReq(t, http.MethodGet, tc.inURL, nil) req.URL.Path = tc.inPath if tc.retainHost { - setCtxValue(req, RetainHost, true) + setCtxValue(req, ctx.RetainHost, true) } proxy := TykNewSingleHostReverseProxy(target, spec) diff --git a/gateway/testutil.go b/gateway/testutil.go index b5a74580c0a6..a508d911108b 100644 --- a/gateway/testutil.go +++ b/gateway/testutil.go @@ -656,8 +656,8 @@ func BuildAPI(apiGens ...func(spec *APISpec)) (specs []*APISpec) { panic(err) } - specs = append(specs, spec) gen(spec) + specs = append(specs, spec) } return specs diff --git a/goplugin/goplugin.go b/goplugin/goplugin.go new file mode 100644 index 000000000000..3fd15681d9b3 --- /dev/null +++ b/goplugin/goplugin.go @@ -0,0 +1,31 @@ +// +build goplugin + +package goplugin + +import ( + "errors" + "net/http" + "plugin" +) + +func GetHandler(path string, symbol string) (http.HandlerFunc, error) { + // try to load plugin + loadedPlugin, err := plugin.Open(path) + if err != nil { + return nil, err + } + + // try to lookup function symbol + funcSymbol, err := loadedPlugin.Lookup(symbol) + if err != nil { + return nil, err + } + + // try to cast symbol to real func + pluginHandler, ok := funcSymbol.(func(http.ResponseWriter, *http.Request)) + if !ok { + return nil, errors.New("could not cast function symbol to http.HandlerFunc") + } + + return pluginHandler, nil +} diff --git a/goplugin/no_goplugin.go b/goplugin/no_goplugin.go new file mode 100644 index 000000000000..8a0ffb273e79 --- /dev/null +++ b/goplugin/no_goplugin.go @@ -0,0 +1,12 @@ +// +build !goplugin + +package goplugin + +import ( + "fmt" + "net/http" +) + +func GetHandler(path string, symbol string) (http.HandlerFunc, error) { + return nil, fmt.Errorf("goplugin.GetHandler is disabled, please disable build flag 'nogoplugin'") +} diff --git a/test/goplugins/test_goplugin.go b/test/goplugins/test_goplugin.go new file mode 100644 index 000000000000..2ede4e9ae219 --- /dev/null +++ b/test/goplugins/test_goplugin.go @@ -0,0 +1,75 @@ +package main + +import ( + "encoding/json" + "net/http" + + "github.com/TykTechnologies/tyk/ctx" + "github.com/TykTechnologies/tyk/user" +) + +// MyPluginPre checks if session is NOT present, adds custom header +// with initial URI path and will be used as "pre" custom MW +func MyPluginPre(rw http.ResponseWriter, r *http.Request) { + session := ctx.GetSession(r) + if session != nil { + rw.WriteHeader(http.StatusInternalServerError) + return + } + + rw.Header().Add("X-Initial-URI", r.URL.RequestURI()) +} + +// MyPluginAuthCheck does custom auth and will be used as +// "auth_check" custom MW +func MyPluginAuthCheck(rw http.ResponseWriter, r *http.Request) { + // perform auth (only one token "abc" is allowed) + token := r.Header.Get("Authorization") + if token != "abc" { + rw.Header().Add("X-Auth-Result", "failed") + rw.WriteHeader(http.StatusForbidden) + return + } + + // create session + session := &user.SessionState{ + OrgID: "default", + Alias: "abc-session", + } + ctx.SetSession(r, session, token, true) + + rw.Header().Add("X-Auth-Result", "OK") +} + +// MyPluginPostKeyAuth checks if session is present, adds custom header with session-alias +// and will be used as "post_key_auth" custom MW +func MyPluginPostKeyAuth(rw http.ResponseWriter, r *http.Request) { + session := ctx.GetSession(r) + if session == nil { + rw.Header().Add("X-Session-Alias", "not found") + rw.WriteHeader(http.StatusInternalServerError) + return + } + + rw.Header().Add("X-Session-Alias", session.Alias) +} + +// MyPluginPost prepares and sends reply, will be used as "post" custom MW +func MyPluginPost(rw http.ResponseWriter, r *http.Request) { + + replyData := map[string]interface{}{ + "message": "post message", + } + + jsonData, err := json.Marshal(replyData) + if err != nil { + rw.WriteHeader(http.StatusInternalServerError) + return + } + + rw.Header().Set("Content-Type", "application/json") + rw.WriteHeader(http.StatusOK) + rw.Write(jsonData) +} + +func main() {} diff --git a/user/session.go b/user/session.go index 92cdce6014ef..a9e925b336b8 100644 --- a/user/session.go +++ b/user/session.go @@ -1,6 +1,9 @@ package user import ( + "crypto/md5" + "fmt" + "github.com/TykTechnologies/tyk/config" logger "github.com/TykTechnologies/tyk/log" ) @@ -91,6 +94,10 @@ type SessionState struct { keyHash string } +func (s *SessionState) MD5Hash() string { + return fmt.Sprintf("%x", md5.Sum([]byte(fmt.Sprintf("%+v", s)))) +} + func (s *SessionState) KeyHash() string { if s.keyHash == "" { panic("KeyHash cache not found. You should call `SetKeyHash` before.")