diff --git a/api_definition.go b/api_definition.go index 6da813638bf7..f96d7feb963c 100644 --- a/api_definition.go +++ b/api_definition.go @@ -342,14 +342,14 @@ func (a APIDefinitionLoader) FromDashboardService(endpoint, secret string) ([]*A } // FromCloud will connect and download ApiDefintions from a Mongo DB instance. -func (a APIDefinitionLoader) FromRPC(orgId string) []*APISpec { +func (a APIDefinitionLoader) FromRPC(orgId string) ([]*APISpec, error) { if rpc.IsEmergencyMode() { return LoadDefinitionsFromRPCBackup() } store := RPCStorageHandler{} if !store.Connect() { - return nil + return nil, errors.New("Can't connect RPC layer") } // enable segments @@ -364,18 +364,19 @@ func (a APIDefinitionLoader) FromRPC(orgId string) []*APISpec { //store.Disconnect() if rpc.LoadCount() > 0 { - saveRPCDefinitionsBackup(apiCollection) + if err := saveRPCDefinitionsBackup(apiCollection); err != nil { + return nil, err + } } return a.processRPCDefinitions(apiCollection) } -func (a APIDefinitionLoader) processRPCDefinitions(apiCollection string) []*APISpec { +func (a APIDefinitionLoader) processRPCDefinitions(apiCollection string) ([]*APISpec, error) { var apiDefs []*apidef.APIDefinition if err := json.Unmarshal([]byte(apiCollection), &apiDefs); err != nil { - log.Error("Failed decode: ", err) - return nil + return nil, err } var specs []*APISpec @@ -396,7 +397,7 @@ func (a APIDefinitionLoader) processRPCDefinitions(apiCollection string) []*APIS specs = append(specs, spec) } - return specs + return specs, nil } func (a APIDefinitionLoader) ParseDefinition(r io.Reader) *apidef.APIDefinition { diff --git a/main.go b/main.go index d04ceb012e01..f74420dccc6c 100644 --- a/main.go +++ b/main.go @@ -230,7 +230,7 @@ func buildConnStr(resource string) string { return config.Global().DBAppConfOptions.ConnectionString + resource } -func syncAPISpecs() int { +func syncAPISpecs() (int, error) { loader := APIDefinitionLoader{} apisMu.Lock() @@ -241,7 +241,7 @@ func syncAPISpecs() int { tmpSpecs, err := loader.FromDashboardService(connStr, config.Global().NodeSecret) if err != nil { log.Error("failed to load API specs: ", err) - return 0 + return 0, err } apiSpecs = tmpSpecs @@ -250,7 +250,11 @@ func syncAPISpecs() int { } else if config.Global().SlaveOptions.UseRPC { mainLog.Debug("Using RPC Configuration") - apiSpecs = loader.FromRPC(config.Global().SlaveOptions.RPCKey) + var err error + apiSpecs, err = loader.FromRPC(config.Global().SlaveOptions.RPCKey) + if err != nil { + return 0, err + } } else { apiSpecs = loader.FromDir(config.Global().AppPath) } @@ -269,10 +273,10 @@ func syncAPISpecs() int { } } - return len(apiSpecs) + return len(apiSpecs), nil } -func syncPolicies() int { +func syncPolicies() (count int, err error) { var pols map[string]user.Policy mainLog.Info("Loading policies") @@ -288,15 +292,14 @@ func syncPolicies() int { mainLog.Info("Using Policies from Dashboard Service") pols = LoadPoliciesFromDashboard(connStr, config.Global().NodeSecret, config.Global().Policies.AllowExplicitPolicyID) - case "rpc": mainLog.Debug("Using Policies from RPC") - pols = LoadPoliciesFromRPC(config.Global().SlaveOptions.RPCKey) + pols, err = LoadPoliciesFromRPC(config.Global().SlaveOptions.RPCKey) default: // this is the only case now where we need a policy record name if config.Global().Policies.PolicyRecordName == "" { mainLog.Debug("No policy record name defined, skipping...") - return 0 + return 0, nil } pols = LoadPoliciesFromFile(config.Global().Policies.PolicyRecordName) } @@ -311,7 +314,7 @@ func syncPolicies() int { policiesByID = pols } - return len(pols) + return len(pols), err } // stripSlashes removes any trailing slashes from the request's URL @@ -605,14 +608,22 @@ func doReload() { } // Load the API Policies - syncPolicies() + if _, err := syncPolicies(); err != nil { + mainLog.Error("Error during syncing policies:", err.Error()) + return + } + // load the specs - count := syncAPISpecs() - // skip re-loading only if dashboard service reported 0 APIs - // and current registry had 0 APIs - if count == 0 && apisByIDLen() == 0 { - mainLog.Warning("No API Definitions found, not reloading") + if count, err := syncAPISpecs(); err != nil { + mainLog.Error("Error during syncing apis:", err.Error()) return + } else { + // skip re-loading only if dashboard service reported 0 APIs + // and current registry had 0 APIs + if count == 0 && apisByIDLen() == 0 { + mainLog.Warning("No API Definitions found, not reloading") + return + } } // We have updated specs, lets load those... diff --git a/mw_jwt_test.go b/mw_jwt_test.go index c5d876b49549..c787b691d685 100644 --- a/mw_jwt_test.go +++ b/mw_jwt_test.go @@ -108,7 +108,7 @@ func prepareGenericJWTSession(testName string, method string, claimName string, var sessionFunc JwtCreator switch method { default: - log.Warningf("Signing method '%s' is not recognised, defaulting to HMAC signature") + log.Warningf("Signing method '%s' is not recognised, defaulting to HMAC signature", method) method = HMACSign fallthrough case HMACSign: diff --git a/policy.go b/policy.go index 94a72b822e2f..cc6b7f1effea 100644 --- a/policy.go +++ b/policy.go @@ -2,6 +2,7 @@ package main import ( "encoding/json" + "errors" "io/ioutil" "net/http" "os" @@ -157,14 +158,14 @@ func parsePoliciesFromRPC(list string) (map[string]user.Policy, error) { return policies, nil } -func LoadPoliciesFromRPC(orgId string) map[string]user.Policy { +func LoadPoliciesFromRPC(orgId string) (map[string]user.Policy, error) { if rpc.IsEmergencyMode() { return LoadPoliciesFromRPCBackup() } store := &RPCStorageHandler{} if !store.Connect() { - return nil + return nil, errors.New("Policies backup: Failed connecting to database") } rpcPolicies := store.GetPolicies(orgId) @@ -175,10 +176,12 @@ func LoadPoliciesFromRPC(orgId string) map[string]user.Policy { log.WithFields(logrus.Fields{ "prefix": "policy", }).Error("Failed decode: ", err, rpcPolicies) - return nil + return nil, err } - saveRPCPoliciesBackup(rpcPolicies) + if err := saveRPCPoliciesBackup(rpcPolicies); err != nil { + return nil, err + } - return policies + return policies, nil } diff --git a/rpc_backup_handlers.go b/rpc_backup_handlers.go index b366f1e4b4e8..36998b520c18 100644 --- a/rpc_backup_handlers.go +++ b/rpc_backup_handlers.go @@ -5,6 +5,8 @@ import ( "crypto/cipher" "crypto/rand" "encoding/base64" + "encoding/json" + "errors" "io" "strings" @@ -28,7 +30,7 @@ func getTagListAsString() string { return tagList } -func LoadDefinitionsFromRPCBackup() []*APISpec { +func LoadDefinitionsFromRPCBackup() ([]*APISpec, error) { tagList := getTagListAsString() checkKey := BackupApiKeyBase + tagList @@ -37,8 +39,7 @@ func LoadDefinitionsFromRPCBackup() []*APISpec { log.Info("[RPC] --> Loading API definitions from backup") if !connected { - log.Error("[RPC] --> RPC Backup recovery failed: redis connection failed") - return nil + return nil, errors.New("[RPC] --> RPC Backup recovery failed: redis connection failed") } secret := rightPad2Len(config.Global().Secret, "=", 32) @@ -46,15 +47,18 @@ func LoadDefinitionsFromRPCBackup() []*APISpec { apiListAsString := decrypt([]byte(secret), cryptoText) if err != nil { - log.Error("[RPC] --> Failed to get node backup (", checkKey, "): ", err) - return nil + return nil, errors.New("[RPC] --> Failed to get node backup (" + checkKey + "): " + err.Error()) } a := APIDefinitionLoader{} return a.processRPCDefinitions(apiListAsString) } -func saveRPCDefinitionsBackup(list string) { +func saveRPCDefinitionsBackup(list string) error { + if !json.Valid([]byte(list)) { + return errors.New("--> RPC Backup save failure: wrong format, skipping.") + } + log.Info("Storing RPC Definitions backup") tagList := getTagListAsString() @@ -66,19 +70,20 @@ func saveRPCDefinitionsBackup(list string) { log.Info("--> Connected to DB") if !connected { - log.Error("--> RPC Backup save failed: redis connection failed") - return + return errors.New("--> RPC Backup save failed: redis connection failed") } secret := rightPad2Len(config.Global().Secret, "=", 32) cryptoText := encrypt([]byte(secret), list) err := store.SetKey(BackupApiKeyBase+tagList, cryptoText, -1) if err != nil { - log.Error("Failed to store node backup: ", err) + return errors.New("Failed to store node backup: " + err.Error()) } + + return nil } -func LoadPoliciesFromRPCBackup() map[string]user.Policy { +func LoadPoliciesFromRPCBackup() (map[string]user.Policy, error) { tagList := getTagListAsString() checkKey := BackupPolicyKeyBase + tagList @@ -88,8 +93,7 @@ func LoadPoliciesFromRPCBackup() map[string]user.Policy { log.Info("[RPC] Loading Policies from backup") if !connected { - log.Error("[RPC] --> RPC Policy Backup recovery failed: redis connection failed") - return nil + return nil, errors.New("[RPC] --> RPC Policy Backup recovery failed: redis connection failed") } secret := rightPad2Len(config.Global().Secret, "=", 32) @@ -97,21 +101,24 @@ func LoadPoliciesFromRPCBackup() map[string]user.Policy { listAsString := decrypt([]byte(secret), cryptoText) if err != nil { - log.Error("[RPC] --> Failed to get node policy backup (", checkKey, "): ", err) - return nil + return nil, errors.New("[RPC] --> Failed to get node policy backup (" + checkKey + "): " + err.Error()) } if policies, err := parsePoliciesFromRPC(listAsString); err != nil { log.WithFields(logrus.Fields{ "prefix": "policy", }).Error("Failed decode: ", err) - return nil + return nil, err } else { - return policies + return policies, nil } } -func saveRPCPoliciesBackup(list string) { +func saveRPCPoliciesBackup(list string) error { + if !json.Valid([]byte(list)) { + return errors.New("--> RPC Backup save failure: wrong format, skipping.") + } + log.Info("Storing RPC policies backup") tagList := getTagListAsString() @@ -123,16 +130,17 @@ func saveRPCPoliciesBackup(list string) { log.Info("--> Connected to DB") if !connected { - log.Error("--> RPC Backup save failed: redis connection failed") - return + return errors.New("--> RPC Backup save failed: redis connection failed") } secret := rightPad2Len(config.Global().Secret, "=", 32) cryptoText := encrypt([]byte(secret), list) err := store.SetKey(BackupPolicyKeyBase+tagList, cryptoText, -1) if err != nil { - log.Error("Failed to store node backup: ", err) + return errors.New("Failed to store node backup: " + err.Error()) } + + return nil } // encrypt string to base64 crypto using AES diff --git a/rpc_test.go b/rpc_test.go index 7c9197e75e7c..1ea3b808000e 100644 --- a/rpc_test.go +++ b/rpc_test.go @@ -3,6 +3,8 @@ package main import ( + "github.com/TykTechnologies/tyk/cli" + "github.com/gorilla/mux" "testing" "time" @@ -59,6 +61,119 @@ func stopRPCMock(server *gorpc.Server) { rpc.Reset() } +const apiDefListTest = `[{ + "api_id": "1", + "definition": { + "location": "header", + "key": "version" + }, + "auth": {"auth_header_name": "authorization"}, + "version_data": { + "versions": { + "v1": {"name": "v1"} + } + }, + "proxy": { + "listen_path": "/v1", + "target_url": "` + testHttpAny + `" + } +}]` + +const apiDefListTest2 = `[{ + "api_id": "1", + "definition": { + "location": "header", + "key": "version" + }, + "auth": {"auth_header_name": "authorization"}, + "version_data": { + "versions": { + "v1": {"name": "v1"} + } + }, + "proxy": { + "listen_path": "/v1", + "target_url": "` + testHttpAny + `" + } +}, +{ + "api_id": "2", + "definition": { + "location": "header", + "key": "version" + }, + "auth": {"auth_header_name": "authorization"}, + "version_data": { + "versions": { + "v2": {"name": "v2"} + } + }, + "proxy": { + "listen_path": "/v2", + "target_url": "` + testHttpAny + `" + } +}]` + +func TestSyncAPISpecsRPCFailure_CheckGlobals(t *testing.T) { + // Mock RPC + callCount := 0 + dispatcher := gorpc.NewDispatcher() + dispatcher.AddFunc("GetApiDefinitions", func(clientAddr string, dr *DefRequest) (string, error) { + if callCount == 0 { + callCount += 1 + return `[]`, nil + } + + if callCount == 1 { + callCount += 1 + return apiDefListTest, nil + } + + if callCount == 2 { + callCount += 1 + return apiDefListTest2, nil + } + + if callCount == 3 { + callCount += 1 + return "malformed json", nil + } + + // clean up + return `[]`, nil + }) + dispatcher.AddFunc("Login", func(clientAddr, userKey string) bool { + return true + }) + dispatcher.AddFunc("GetPolicies", func(orgId string) (string, error) { + return `[]`, nil + }) + + rpc := startRPCMock(dispatcher) + defer stopRPCMock(rpc) + + // Three cases: 1 API, 2 APIs and Malformed data + exp := []int{1, 4, 6, 6, 2} + if *cli.HTTPProfile { + exp = []int{4, 6, 8, 8, 4} + } + + for _, e := range exp { + doReload() + + rtCnt := 0 + mainRouter.Walk(func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error { + rtCnt += 1 + //fmt.Println(route.GetPathTemplate()) + return nil + }) + + if rtCnt != e { + t.Errorf("There should be %v routes, got %v", e, rtCnt) + } + } +} + // Our RPC layer too racy, but not harmul, mostly global variables like RPCIsClientConnected func TestSyncAPISpecsRPCFailure(t *testing.T) { // Mock RPC @@ -73,7 +188,7 @@ func TestSyncAPISpecsRPCFailure(t *testing.T) { rpc := startRPCMock(dispatcher) defer stopRPCMock(rpc) - count := syncAPISpecs() + count, _ := syncAPISpecs() if count != 0 { t.Error("Should return empty value for malformed rpc response", apiSpecs) } @@ -103,12 +218,12 @@ func TestSyncAPISpecsRPCSuccess(t *testing.T) { ts := newTykTestServer() defer ts.Close() - apiBackup := LoadDefinitionsFromRPCBackup() + apiBackup, _ := LoadDefinitionsFromRPCBackup() if len(apiBackup) != 1 { t.Fatal("Should have APIs in backup") } - policyBackup := LoadPoliciesFromRPCBackup() + policyBackup, _ := LoadPoliciesFromRPCBackup() if len(policyBackup) != 1 { t.Fatal("Should have Policies in backup") } @@ -118,7 +233,7 @@ func TestSyncAPISpecsRPCSuccess(t *testing.T) { {Path: "/sample", Headers: authHeaders, Code: 200}, }...) - count := syncAPISpecs() + count, _ := syncAPISpecs() if count != 1 { t.Error("Should return array with one spec", apiSpecs) } @@ -191,11 +306,11 @@ func TestSyncAPISpecsRPCSuccess(t *testing.T) { {Path: "/sample", Headers: notCachedAuth, Code: 200}, }...) - if count := syncAPISpecs(); count != 2 { + if count, _ := syncAPISpecs(); count != 2 { t.Error("Should fetch latest specs", count) } - if count := syncPolicies(); count != 2 { + if count, _ := syncPolicies(); count != 2 { t.Error("Should fetch latest policies", count) } })