diff --git a/cns/NetworkContainerContract.go b/cns/NetworkContainerContract.go index a328f95a82..b3617415ac 100644 --- a/cns/NetworkContainerContract.go +++ b/cns/NetworkContainerContract.go @@ -107,6 +107,7 @@ type Route struct { type SetOrchestratorTypeRequest struct { OrchestratorType string DncPartitionKey string + NodeID string } // CreateNetworkContainerResponse specifies response of creating a network container. diff --git a/cns/restserver/restserver.go b/cns/restserver/restserver.go index 578eed6780..4dabf072a3 100644 --- a/cns/restserver/restserver.go +++ b/cns/restserver/restserver.go @@ -70,6 +70,7 @@ type httpRestServiceState struct { Location string NetworkType string OrchestratorType string + NodeID string Initialized bool ContainerIDByOrchestratorContext map[string]string // OrchestratorContext is key and value is NetworkContainerID. ContainerStatus map[string]containerstatus // NetworkContainerID is key. @@ -978,9 +979,12 @@ func (service *HTTPRestService) restoreState() error { func (service *HTTPRestService) setOrchestratorType(w http.ResponseWriter, r *http.Request) { log.Printf("[Azure CNS] setOrchestratorType") - var req cns.SetOrchestratorTypeRequest - returnMessage := "" - returnCode := 0 + var ( + req cns.SetOrchestratorTypeRequest + returnMessage string + returnCode int + nodeID string + ) err := service.Listener.Decode(w, r, &req) if err != nil { @@ -990,24 +994,31 @@ func (service *HTTPRestService) setOrchestratorType(w http.ResponseWriter, r *ht service.lock.Lock() service.dncPartitionKey = req.DncPartitionKey + nodeID = service.state.NodeID - switch req.OrchestratorType { - case cns.ServiceFabric: - fallthrough - case cns.Kubernetes: - fallthrough - case cns.WebApps: - fallthrough - case cns.Batch: - fallthrough - case cns.DBforPostgreSQL: - fallthrough - case cns.AzureFirstParty: - service.state.OrchestratorType = req.OrchestratorType - service.saveState() - default: - returnMessage = fmt.Sprintf("Invalid Orchestrator type %v", req.OrchestratorType) - returnCode = UnsupportedOrchestratorType + if nodeID == "" || nodeID == req.NodeID { + switch req.OrchestratorType { + case cns.ServiceFabric: + fallthrough + case cns.Kubernetes: + fallthrough + case cns.WebApps: + fallthrough + case cns.Batch: + fallthrough + case cns.DBforPostgreSQL: + fallthrough + case cns.AzureFirstParty: + service.state.OrchestratorType = req.OrchestratorType + service.state.NodeID = req.NodeID + service.saveState() + default: + returnMessage = fmt.Sprintf("Invalid Orchestrator type %v", req.OrchestratorType) + returnCode = UnsupportedOrchestratorType + } + } else { + returnMessage = fmt.Sprintf("Invalid request since this node has already been registered as %s", nodeID) + returnCode = InvalidRequest } service.lock.Unlock()