diff --git a/handlers/allowedparams.go b/handlers/allowedparams.go new file mode 100644 index 0000000..8a5f062 --- /dev/null +++ b/handlers/allowedparams.go @@ -0,0 +1,36 @@ +package handlers + +import ( + "net/http" + + "strings" + + "github.com/go-kit/kit/log" +) + +func NewAllowedParams(l log.Logger, allowedParams []string) func(h http.Handler) http.Handler { + var params = make(map[string]bool, len(allowedParams)) + for _, p := range allowedParams { + if p == "" { + continue + } + + params[p] = true + } + + return func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestParams := r.URL.Query() + + for p := range requestParams { + if _, exists := params[p]; !exists { + l.Log("error", "parameter is not white-listed", "parameter", p, "allowed", strings.Join(allowedParams, ",")) + http.Error(w, "Unregisterd parameter", http.StatusNotAcceptable) + return + } + } + + h.ServeHTTP(w, r) + }) + } +} diff --git a/main.go b/main.go index 0f350d8..a849103 100644 --- a/main.go +++ b/main.go @@ -19,11 +19,12 @@ import ( ) var ( - allowedHosts argumentList - imaginaryURL string - listenPort int64 - bucketRate float64 - bucketSize int64 + allowedHosts argumentList + allowedImaginaryParams string + imaginaryURL string + listenPort int64 + bucketRate float64 + bucketSize int64 Version = "dev" logger = log.With( @@ -50,6 +51,7 @@ func init() { flag.Int64Var(&listenPort, "listen-port", 8080, "Port to listen on") flag.Float64Var(&bucketRate, "bucket-rate", 20, "Rate limiter bucket fill rate (req/s)") flag.Int64Var(&bucketSize, "bucket-size", 500, "Rate limiter bucket size (burst capacity)") + flag.StringVar(&allowedImaginaryParams, "allowed-params", "", "A comma seperated list of parameters allows to be sent upstream. If empty, everything is allowed.") } @@ -97,10 +99,25 @@ type httpHandler func(h http.Handler) http.Handler func decorateHandler(h http.Handler, b *ratelimit.Bucket) http.Handler { decorators := []httpHandler{ - handlers.NewRateLimitHandler(b, logger), - handlers.NewIgnoreFaviconRequests(), handlers.NewValidateURLParameter(logger, allowedHosts), } + + if allowedImaginaryParams != "" { + decorators = append( + decorators, + handlers.NewAllowedParams( + logger, + strings.Split(allowedImaginaryParams, ","), + )) + } + + // Defining early needed handlers last + decorators = append( + decorators, + handlers.NewIgnoreFaviconRequests(), + handlers.NewRateLimitHandler(b, logger), + ) + var handler http.Handler = h for _, d := range decorators { handler = d(handler)