Skip to content

Commit

Permalink
converted Go-plugin handler func to func(http.ResponseWriter, *http.R…
Browse files Browse the repository at this point in the history
…equest), added ctx helpers
  • Loading branch information
dencoded committed Mar 18, 2019
1 parent ff5805c commit 54c3adb
Show file tree
Hide file tree
Showing 10 changed files with 192 additions and 244 deletions.
95 changes: 35 additions & 60 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (

"github.com/TykTechnologies/tyk/apidef"
"github.com/TykTechnologies/tyk/config"
"github.com/TykTechnologies/tyk/ctx"
"github.com/TykTechnologies/tyk/storage"
"github.com/TykTechnologies/tyk/user"
)
Expand Down Expand Up @@ -1784,7 +1785,7 @@ func setCtxValue(r *http.Request, key, val interface{}) {
}

func ctxGetData(r *http.Request) map[string]interface{} {
if v := r.Context().Value(ContextData); v != nil {
if v := r.Context().Value(ctx.ContextData); v != nil {
return v.(map[string]interface{})
}
return nil
Expand All @@ -1794,64 +1795,38 @@ func ctxSetData(r *http.Request, m map[string]interface{}) {
if m == nil {
panic("setting a nil context ContextData")
}
setCtxValue(r, ContextData, m)
setCtxValue(r, ctx.ContextData, m)
}

func ctxGetSession(r *http.Request) *user.SessionState {
if v := r.Context().Value(SessionData); v != nil {
return v.(*user.SessionState)
}
return nil
return ctx.GetSession(r)
}

func ctxSetSession(r *http.Request, s *user.SessionState, token string, scheduleUpdate bool) {
if s == nil {
panic("setting a nil context SessionData")
}

if token == "" {
token = ctxGetAuthToken(r)
}

if s.KeyHashEmpty() {
s.SetKeyHash(storage.HashKey(token))
}

ctx := r.Context()
ctx = context.WithValue(ctx, SessionData, s)
ctx = context.WithValue(ctx, AuthToken, token)

if scheduleUpdate {
ctx = context.WithValue(ctx, UpdateSession, true)
}

setContext(r, ctx)
ctx.SetSession(r, s, token, scheduleUpdate)
}

func ctxScheduleSessionUpdate(r *http.Request) {
setCtxValue(r, UpdateSession, true)
setCtxValue(r, ctx.UpdateSession, true)
}

func ctxDisableSessionUpdate(r *http.Request) {
setCtxValue(r, UpdateSession, false)
setCtxValue(r, ctx.UpdateSession, false)
}

func ctxSessionUpdateScheduled(r *http.Request) bool {
if v := r.Context().Value(UpdateSession); v != nil {
if v := r.Context().Value(ctx.UpdateSession); v != nil {
return v.(bool)
}
return false
}

func ctxGetAuthToken(r *http.Request) string {
if v := r.Context().Value(AuthToken); v != nil {
return v.(string)
}
return ""
return ctx.GetAuthToken(r)
}

func ctxGetTrackedPath(r *http.Request) string {
if v := r.Context().Value(TrackThisEndpoint); v != nil {
if v := r.Context().Value(ctx.TrackThisEndpoint); v != nil {
return v.(string)
}
return ""
Expand All @@ -1861,34 +1836,34 @@ func ctxSetTrackedPath(r *http.Request, p string) {
if p == "" {
panic("setting a nil context TrackThisEndpoint")
}
setCtxValue(r, TrackThisEndpoint, p)
setCtxValue(r, ctx.TrackThisEndpoint, p)
}

func ctxGetDoNotTrack(r *http.Request) bool {
return r.Context().Value(DoNotTrackThisEndpoint) == true
return r.Context().Value(ctx.DoNotTrackThisEndpoint) == true
}

func ctxSetDoNotTrack(r *http.Request, b bool) {
setCtxValue(r, DoNotTrackThisEndpoint, b)
setCtxValue(r, ctx.DoNotTrackThisEndpoint, b)
}

func ctxGetVersionInfo(r *http.Request) *apidef.VersionInfo {
if v := r.Context().Value(VersionData); v != nil {
if v := r.Context().Value(ctx.VersionData); v != nil {
return v.(*apidef.VersionInfo)
}
return nil
}

func ctxSetVersionInfo(r *http.Request, v *apidef.VersionInfo) {
setCtxValue(r, VersionData, v)
setCtxValue(r, ctx.VersionData, v)
}

func ctxSetOrigRequestURL(r *http.Request, url *url.URL) {
setCtxValue(r, OrigRequestURL, url)
setCtxValue(r, ctx.OrigRequestURL, url)
}

func ctxGetOrigRequestURL(r *http.Request) *url.URL {
if v := r.Context().Value(OrigRequestURL); v != nil {
if v := r.Context().Value(ctx.OrigRequestURL); v != nil {
if urlVal, ok := v.(*url.URL); ok {
return urlVal
}
Expand All @@ -1898,11 +1873,11 @@ func ctxGetOrigRequestURL(r *http.Request) *url.URL {
}

func ctxSetUrlRewritePath(r *http.Request, path string) {
setCtxValue(r, UrlRewritePath, path)
setCtxValue(r, ctx.UrlRewritePath, path)
}

func ctxGetUrlRewritePath(r *http.Request) string {
if v := r.Context().Value(UrlRewritePath); v != nil {
if v := r.Context().Value(ctx.UrlRewritePath); v != nil {
if strVal, ok := v.(string); ok {
return strVal
}
Expand All @@ -1911,7 +1886,7 @@ func ctxGetUrlRewritePath(r *http.Request) string {
}

func ctxSetCheckLoopLimits(r *http.Request, b bool) {
setCtxValue(r, CheckLoopLimits, b)
setCtxValue(r, ctx.CheckLoopLimits, b)
}

// Should we check Rate limits and Quotas?
Expand All @@ -1921,19 +1896,19 @@ func ctxCheckLimits(r *http.Request) bool {
return true
}

if v := r.Context().Value(CheckLoopLimits); v != nil {
if v := r.Context().Value(ctx.CheckLoopLimits); v != nil {
return v.(bool)
}

return false
}

func ctxSetRequestMethod(r *http.Request, path string) {
setCtxValue(r, RequestMethod, path)
setCtxValue(r, ctx.RequestMethod, path)
}

func ctxGetRequestMethod(r *http.Request) string {
if v := r.Context().Value(RequestMethod); v != nil {
if v := r.Context().Value(ctx.RequestMethod); v != nil {
if strVal, ok := v.(string); ok {
return strVal
}
Expand All @@ -1942,19 +1917,19 @@ func ctxGetRequestMethod(r *http.Request) string {
}

func ctxGetDefaultVersion(r *http.Request) bool {
return r.Context().Value(VersionDefault) != nil
return r.Context().Value(ctx.VersionDefault) != nil
}

func ctxSetDefaultVersion(r *http.Request) {
setCtxValue(r, VersionDefault, true)
setCtxValue(r, ctx.VersionDefault, true)
}

func ctxLoopingEnabled(r *http.Request) bool {
return ctxLoopLevel(r) > 0
}

func ctxLoopLevel(r *http.Request) int {
if v := r.Context().Value(LoopLevel); v != nil {
if v := r.Context().Value(ctx.LoopLevel); v != nil {
if intVal, ok := v.(int); ok {
return intVal
}
Expand All @@ -1964,7 +1939,7 @@ func ctxLoopLevel(r *http.Request) int {
}

func ctxSetLoopLevel(r *http.Request, value int) {
setCtxValue(r, LoopLevel, value)
setCtxValue(r, ctx.LoopLevel, value)
}

func ctxIncLoopLevel(r *http.Request, loopLimit int) {
Expand All @@ -1973,7 +1948,7 @@ func ctxIncLoopLevel(r *http.Request, loopLimit int) {
}

func ctxLoopLevelLimit(r *http.Request) int {
if v := r.Context().Value(LoopLevelLimit); v != nil {
if v := r.Context().Value(ctx.LoopLevelLimit); v != nil {
if intVal, ok := v.(int); ok {
return intVal
}
Expand All @@ -1985,12 +1960,12 @@ func ctxLoopLevelLimit(r *http.Request) int {
func ctxSetLoopLimit(r *http.Request, limit int) {
// Can be set only one time per request
if ctxLoopLevelLimit(r) == 0 && limit > 0 {
setCtxValue(r, LoopLevelLimit, limit)
setCtxValue(r, ctx.LoopLevelLimit, limit)
}
}

func ctxThrottleLevelLimit(r *http.Request) int {
if v := r.Context().Value(ThrottleLevelLimit); v != nil {
if v := r.Context().Value(ctx.ThrottleLevelLimit); v != nil {
if intVal, ok := v.(int); ok {
return intVal
}
Expand All @@ -2000,7 +1975,7 @@ func ctxThrottleLevelLimit(r *http.Request) int {
}

func ctxThrottleLevel(r *http.Request) int {
if v := r.Context().Value(ThrottleLevel); v != nil {
if v := r.Context().Value(ctx.ThrottleLevel); v != nil {
if intVal, ok := v.(int); ok {
return intVal
}
Expand All @@ -2012,12 +1987,12 @@ func ctxThrottleLevel(r *http.Request) int {
func ctxSetThrottleLimit(r *http.Request, limit int) {
// Can be set only one time per request
if ctxThrottleLevelLimit(r) == 0 && limit > 0 {
setCtxValue(r, ThrottleLevelLimit, limit)
setCtxValue(r, ctx.ThrottleLevelLimit, limit)
}
}

func ctxSetThrottleLevel(r *http.Request, value int) {
setCtxValue(r, ThrottleLevel, value)
setCtxValue(r, ctx.ThrottleLevel, value)
}

func ctxIncThrottleLevel(r *http.Request, throttleLimit int) {
Expand All @@ -2026,9 +2001,9 @@ func ctxIncThrottleLevel(r *http.Request, throttleLimit int) {
}

func ctxTraceEnabled(r *http.Request) bool {
return r.Context().Value(Trace) != nil
return r.Context().Value(ctx.Trace) != nil
}

func ctxSetTrace(r *http.Request) {
setCtxValue(r, Trace, true)
setCtxValue(r, ctx.Trace, true)
}
20 changes: 1 addition & 19 deletions api_loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,9 +288,6 @@ func processSpec(spec *APISpec, apisByListen map[string]int,
BaseMiddleware: baseMid,
Path: obj.Path,
SymbolName: obj.Name,
Pre: true,
UseSession: obj.RequireSession,
Auth: false,
},
)
} else {
Expand Down Expand Up @@ -329,9 +326,6 @@ func processSpec(spec *APISpec, apisByListen map[string]int,
BaseMiddleware: baseMid,
Path: obj.Path,
SymbolName: obj.Name,
Pre: false,
UseSession: obj.RequireSession,
Auth: false,
},
)
} else {
Expand Down Expand Up @@ -360,9 +354,6 @@ func processSpec(spec *APISpec, apisByListen map[string]int,
BaseMiddleware: baseMid,
Path: obj.Path,
SymbolName: obj.Name,
Pre: true,
UseSession: obj.RequireSession,
Auth: false,
},
)
} else {
Expand Down Expand Up @@ -428,9 +419,6 @@ func processSpec(spec *APISpec, apisByListen map[string]int,
BaseMiddleware: baseMid,
Path: mwAuthCheckFunc.Path,
SymbolName: mwAuthCheckFunc.Name,
Pre: true,
UseSession: false,
Auth: true,
},
)
}
Expand All @@ -450,9 +438,6 @@ func processSpec(spec *APISpec, apisByListen map[string]int,
BaseMiddleware: baseMid,
Path: obj.Path,
SymbolName: obj.Name,
Pre: false,
UseSession: obj.RequireSession,
Auth: false,
},
)
} else {
Expand Down Expand Up @@ -486,9 +471,6 @@ func processSpec(spec *APISpec, apisByListen map[string]int,
BaseMiddleware: baseMid,
Path: obj.Path,
SymbolName: obj.Name,
Pre: false,
UseSession: obj.RequireSession,
Auth: false,
},
)
} else {
Expand Down Expand Up @@ -581,7 +563,7 @@ func (d *DummyProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
handler = targetAPI.middlewareChain.ThisHandler
} else {
handler := ErrorHandler{*d.SH.Base()}
handler.HandleError(w, r, "Can't detect loop target", http.StatusInternalServerError)
handler.HandleError(w, r, "Can't detect loop target", http.StatusInternalServerError, true)
return
}
}
Expand Down
Loading

0 comments on commit 54c3adb

Please sign in to comment.