diff --git a/internals/proxy/middlewares/policy.go b/internals/proxy/middlewares/policy.go index 2eb54c96..004f8a44 100644 --- a/internals/proxy/middlewares/policy.go +++ b/internals/proxy/middlewares/policy.go @@ -3,6 +3,7 @@ package middlewares import ( "errors" "net/http" + "reflect" "github.com/codeshelldev/secured-signal-api/internals/config/structure" log "github.com/codeshelldev/secured-signal-api/utils/logger" @@ -78,10 +79,46 @@ func getField(key string, body map[string]any, headers map[string][]string) (any return value, errors.New("field not found") } +func doPoliciesApply(body map[string]any, headers map[string][]string, policies map[string]structure.FieldPolicy) (bool, string) { + for key, policy := range policies { + value, err := getField(key, body, headers) + + if err != nil { + continue + } + + switch asserted := value.(type) { + case string: + policyValue, ok := policy.Value.(string) + + if ok && asserted == policyValue { + return true, key + } + case int: + policyValue, ok := policy.Value.(int); + + if ok && asserted == policyValue { + return true, key + } + case bool: + policyValue, ok := policy.Value.(bool) + + if ok && asserted == policyValue { + return true, key + } + default: + if reflect.DeepEqual(value, policy.Value) { + return true, key + } + } + } + + return false, "" +} + func doBlock(body map[string]any, headers map[string][]string, policies map[string]structure.FieldPolicy) (bool, string) { - if policies == nil { - return false, "" - } else if len(policies) <= 0 { + if len(policies) == 0 { + // default: allow all return false, "" } @@ -89,43 +126,28 @@ func doBlock(body map[string]any, headers map[string][]string, policies map[stri var cause string - var isExplictlyAllowed, isExplicitlyBlocked bool - - for field, policy := range allowed { - value, err := getField(field, body, headers) - - if value == policy.Value && err == nil { - isExplictlyAllowed = true - cause = field - break - } + isExplicitlyAllowed, cause := doPoliciesApply(body, headers, allowed) + isExplicitlyBlocked, cause := doPoliciesApply(body, headers, blocked) + + // explicit allow > block + if isExplicitlyAllowed { + return false, cause } - - for field, policy := range blocked { - value, err := getField(field, body, headers) - - if value == policy.Value && err == nil { - isExplicitlyBlocked = true - cause = field - break - } - } - - // Block all except explicitly Allowed - if len(blocked) == 0 && len(allowed) != 0 { - return !isExplictlyAllowed, cause + + if isExplicitlyBlocked { + return true, cause } - // Allow all except explicitly Blocked - if len(allowed) == 0 && len(blocked) != 0 { - return isExplicitlyBlocked, cause + // only allow policies -> block anything not allowed + if len(allowed) > 0 && len(blocked) == 0 { + return true, cause } - // Excplicitly Blocked except excplictly Allowed - if len(blocked) != 0 && len(allowed) != 0 { - return isExplicitlyBlocked && !isExplictlyAllowed, cause + // only block polcicies -> allow anything not blocked + if len(blocked) > 0 && len(allowed) == 0 { + return false, cause } - // Block all - return true, "" + // no match -> default: block all + return true, cause } diff --git a/utils/logger/logger.go b/utils/logger/logger.go index fe265a72..3d173f16 100644 --- a/utils/logger/logger.go +++ b/utils/logger/logger.go @@ -60,6 +60,12 @@ func Format(data ...any) string { res += value case int: res += strconv.Itoa(value) + case bool: + if value { + res += "true" + } else { + res += "false" + } default: lines := strings.Split(jsonutils.Pretty(value), "\n") diff --git a/utils/request/requestkeys/requestkeys.go b/utils/request/requestkeys/requestkeys.go index 7bb7e81b..33b7f64b 100644 --- a/utils/request/requestkeys/requestkeys.go +++ b/utils/request/requestkeys/requestkeys.go @@ -46,10 +46,16 @@ func PrefixHeaders(headers map[string][]string) map[string][]string { return res } -func GetFromBodyAndHeaders(field Field, body map[string]any, headers map[string][]string) any { +func PrefixBodyAndHeaders(body map[string]any, headers map[string][]string) (map[string]any, map[string][]string) { body = PrefixBody(body) headers = PrefixHeaders(headers) + return body, headers +} + +func GetFromBodyAndHeaders(field Field, body map[string]any, headers map[string][]string) any { + body, headers = PrefixBodyAndHeaders(body, headers) + switch(field.Prefix) { case BodyPrefix: return GetByField(field, body)